# Source code for torchref.refinement.targets.occupancy_floor_diagnostic

"""
Occupancy Floor Diagnostic for Time-Resolved Crystallography.

This module provides tools to estimate a lower bound on the activation fraction
by analyzing electron density. The key insight is that negative electron density
is unphysical - you cannot remove more electrons than were present.

For atoms that move in the excited state (e.g., waters, ligands):
- The dark state has density ρ_dark at the original position
- If the atom completely leaves, the light state has ρ_light ≈ 0 there
- The observed difference is: Δρ = α × (ρ_light - ρ_dark) = -α × ρ_dark
- The depth of the negative peak gives: α = |Δρ| / ρ_dark

If α is underestimated, the model must predict ρ_light < 0 to fit the data,
which is unphysical. This provides a floor on α.
"""

import torch
from torch import nn
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

from torchref.utils.device_mixin import DeviceMixin

if TYPE_CHECKING:
    from torchref.model import ModelFT, MixedModel


class OccupancyFloorDiagnostic:
    """
    Diagnostic tool to estimate activation fraction floor from electron density.

    Evaluates the electron density of the dark and light models at the dark
    atom positions and looks for negative light-state density, which is
    unphysical and indicates that the activation fraction is too small.

    Parameters
    ----------
    model_dark : ModelFT
        The dark/ground state model.
    model_light : ModelFT
        The light/excited state model (the refined one, not MixedModel).
    grid_spacing : float, optional
        Grid spacing in Angstroms. Default is 0.5. Stored on the instance;
        none of the methods defined in this class read it.
    negative_threshold : float, optional
        Threshold below which density is considered "significantly negative".
        Default is -0.5. Stored on the instance; none of the methods defined
        in this class read it.

    Examples
    --------
    Basic usage::

        diagnostic = OccupancyFloorDiagnostic(model_dark, model_light_refine)
        result = diagnostic.analyze_at_dark_positions(hkl)
        print(f"Correction factor: {result['correction_factor']:.3f}")
    """

    def __init__(
        self,
        model_dark: "ModelFT",
        model_light: "ModelFT",
        grid_spacing: float = 0.5,
        negative_threshold: float = -0.5,
    ):
        self.model_dark = model_dark
        self.model_light = model_light
        self.grid_spacing = grid_spacing
        self.negative_threshold = negative_threshold

    def compute_density_at_positions(
        self,
        model: "ModelFT",
        positions: torch.Tensor,
        hkl: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute electron density at specific positions using Fourier summation.

        This is a simplified calculation that evaluates
        ``Σ_h |F(h)| * cos(2π * h·r + φ(h)) / M`` with gradient tracking
        disabled, where ``M`` is the number of reflections.

        Parameters
        ----------
        model : ModelFT
            Model to compute density from. Called as ``model(hkl, recalc=True)``
            and expected to return complex structure factors, shape (M,).
        positions : torch.Tensor
            Positions in fractional coordinates, shape (N, 3).
        hkl : torch.Tensor
            Miller indices, shape (M, 3).

        Returns
        -------
        torch.Tensor
            Electron density values at each position, shape (N,). The values
            are only normalised by the reflection count, so they are on an
            arbitrary scale.
        """
        with torch.no_grad():
            fcalc = model(hkl, recalc=True)

            # h·r for every position/reflection pair -> (N, M).
            # torch.matmul does not promote dtypes, so the Miller indices are
            # cast to the dtype of the positions (float32 or float64).
            h_dot_r = torch.matmul(positions, hkl.T.to(positions.dtype))

            phase = torch.angle(fcalc)  # (M,)
            amplitude = torch.abs(fcalc)  # (M,)

            # NOTE(review): this uses a "+φ" phase term, i.e. Re[F * exp(+2πi h·r)].
            # Whether that matches the sign convention of the structure factors
            # returned by the model cannot be determined from this file — confirm.
            density = (
                amplitude.unsqueeze(0)
                * torch.cos(2 * torch.pi * h_dot_r + phase.unsqueeze(0))
            ).sum(dim=1)

            # Normalize by number of reflections (approximate)
            density = density / len(hkl)

        return density

    def analyze_at_dark_positions(
        self,
        hkl: torch.Tensor,
        atom_mask: Optional[torch.Tensor] = None,
    ) -> Dict:
        """
        Analyze light model density at dark atom positions.

        Parameters
        ----------
        hkl : torch.Tensor
            Miller indices for Fourier calculation.
        atom_mask : torch.Tensor, optional
            Boolean mask selecting which atoms to analyze (e.g., waters only).

        Returns
        -------
        dict
            Dictionary with analysis results:

            - 'rho_dark': Dark model density at the (selected) atom positions
            - 'rho_light': Light model density at the same positions
            - 'rho_ratio': ρ_light / (ρ_dark + 1e-6)
            - 'negative_mask': Boolean mask of atoms with negative light density
            - 'n_negative': Number of atoms with negative light density
            - 'n_total': Number of atoms analyzed
            - 'fraction_negative': n_negative / n_total (0.0 if no atoms)
            - 'min_rho_light': Minimum light density (NaN if no atoms)
            - 'correction_factor': max(-ρ_light / (ρ_dark + 1e-6)) over the
              negative atoms, 0.0 if there are none. Heuristic, not exact.
            - 'worst_atoms': Integer indices of up to five atoms with the
              lowest light density (empty if no atom is negative). Indices
              refer to the selection after ``atom_mask`` has been applied.
        """
        # Dark atom positions converted to fractional coordinates
        xyz_dark = self.model_dark.xyz()
        cell = self.model_dark.cell
        frac_dark = cell.cartesian_to_fractional(xyz_dark)

        if atom_mask is not None:
            frac_dark = frac_dark[atom_mask]

        # Density of both models evaluated at the dark positions
        rho_dark = self.compute_density_at_positions(self.model_dark, frac_dark, hkl)
        rho_light = self.compute_density_at_positions(self.model_light, frac_dark, hkl)

        negative_mask = rho_light < 0

        # Density ratio (small offset guards against division by zero)
        rho_ratio = rho_light / (rho_dark + 1e-6)

        n_total = len(rho_light)

        # If the light density is negative at a dark site, the current α is too
        # small. As a proxy for how much α has to grow, use the "missing"
        # density relative to the dark density at the same site.
        if negative_mask.any():
            missing = -rho_light[negative_mask]
            rho_dark_at_neg = rho_dark[negative_mask]
            correction_factor = (missing / (rho_dark_at_neg + 1e-6)).max().item()
            worst_idx = torch.argsort(rho_light)[:5]  # 5 lowest densities
        else:
            correction_factor = 0.0
            # Keep an integer index tensor so the result type does not depend
            # on which branch was taken.
            worst_idx = torch.tensor([], dtype=torch.long, device=rho_light.device)

        return {
            'rho_dark': rho_dark,
            'rho_light': rho_light,
            'rho_ratio': rho_ratio,
            'negative_mask': negative_mask,
            'n_negative': negative_mask.sum().item(),
            'n_total': n_total,
            # Reductions over an empty selection would raise / give NaN.
            'fraction_negative': negative_mask.float().mean().item() if n_total else 0.0,
            'min_rho_light': rho_light.min().item() if n_total else float('nan'),
            'correction_factor': correction_factor,
            'worst_atoms': worst_idx,
        }

    def estimate_alpha_floor_from_difference_map(
        self,
        hkl: torch.Tensor,
        delta_F_obs: torch.Tensor,
        sigma_diff: torch.Tensor,
        n_peaks: int = 10,
        sigma_cutoff: float = 3.0,
    ) -> Dict:
        """
        Estimate alpha floor from significant negative difference amplitudes.

        Selects the reflections whose ``ΔF/σ`` is below ``-sigma_cutoff``,
        keeps the ``n_peaks`` most negative ones and estimates
        ``α ≈ |ΔF| / |F_dark|`` for each of them.

        Parameters
        ----------
        hkl : torch.Tensor
            Miller indices.
        delta_F_obs : torch.Tensor
            Observed difference amplitudes (can be negative).
        sigma_diff : torch.Tensor
            Uncertainties on difference amplitudes.
        n_peaks : int, optional
            Number of peaks to analyze. Default is 10.
        sigma_cutoff : float, optional
            Minimum significance (|ΔF|/σ) for peaks. Default is 3.0.

        Returns
        -------
        dict
            Dictionary with keys 'alpha_floor' (largest per-peak estimate, 0.0
            if nothing is significant), 'alpha_estimates', 'peak_indices',
            'peak_dF' (lists, empty if nothing is significant) and
            'n_negative_peaks'. When nothing is significant a 'message' entry
            is included as well.
        """
        # Find significant negative differences
        significance = delta_F_obs / sigma_diff
        negative_sig = significance < -sigma_cutoff

        if not negative_sig.any():
            return {
                'alpha_floor': 0.0,
                'message': 'No significant negative peaks found',
                'n_negative_peaks': 0,
                # Same keys as the regular result so callers need no special case
                'alpha_estimates': [],
                'peak_indices': [],
                'peak_dF': [],
            }

        # Most negative ΔF values first
        neg_indices = torch.where(negative_sig)[0]
        neg_values = delta_F_obs[neg_indices]
        sorted_idx = torch.argsort(neg_values)[:n_peaks]
        peak_indices = neg_indices[sorted_idx]

        # The negative ΔF is attributed to atoms leaving their dark positions:
        # |ΔF_neg| ≈ α × F_dark, hence α ≈ |ΔF_neg| / F_dark.
        with torch.no_grad():
            fcalc_dark = self.model_dark(hkl, recalc=True)
            F_dark = torch.abs(fcalc_dark)

        neg_dF = torch.abs(delta_F_obs[peak_indices])
        F_dark_at_peaks = F_dark[peak_indices]

        # α ≈ |ΔF| / F_dark (rough estimate)
        alpha_estimates = neg_dF / (F_dark_at_peaks + 1e-6)

        return {
            'alpha_floor': alpha_estimates.max().item(),
            'alpha_estimates': alpha_estimates.tolist(),
            'peak_indices': peak_indices.tolist(),
            'peak_dF': neg_dF.tolist(),
            'n_negative_peaks': len(peak_indices),
        }
class NegativeDensityPenalty(DeviceMixin, nn.Module):
    """
    Loss term that penalizes low/negative electron density in the MIXED model.

    This provides a soft constraint that prevents the activation fraction
    from being too small (which would require unphysical negative density).

    The key insight: the MIXED state (not pure light) should have non-negative
    density everywhere. If α is too small and atoms have moved, the mixed
    model might predict negative density at some positions, which is
    unphysical.

    The density is evaluated at the dark atom positions, which are converted
    to fractional coordinates once at construction time.

    Parameters
    ----------
    mixed_model : MixedModel
        The mixed model (combines dark and light states with fractions).
    model_dark : ModelFT
        The dark/ground state model (provides reference positions to check).
    hkl : torch.Tensor
        Miller indices for density calculation.
    atom_mask : torch.Tensor, optional
        Mask selecting which atoms to monitor.
    check_grid : bool, optional
        Default is False. Stored on the instance; ``forward`` as defined here
        does not read it, so no grid check is performed.
    margin : float, optional
        Density level below which the penalty becomes active. Default is 0.1,
        the value that was previously hard-coded in ``forward``.
    """

    def __init__(
        self,
        mixed_model: "MixedModel",
        model_dark: "ModelFT",
        hkl: torch.Tensor,
        atom_mask: Optional[torch.Tensor] = None,
        check_grid: bool = False,
        margin: float = 0.1,
    ):
        super().__init__()
        self.mixed_model = mixed_model
        self.model_dark = model_dark
        self.register_buffer('hkl', hkl)
        self.atom_mask = atom_mask
        self.check_grid = check_grid
        self.margin = margin

        # Pre-compute dark positions in fractional coordinates
        with torch.no_grad():
            xyz_dark = model_dark.xyz()
            cell = model_dark.cell
            frac_dark = cell.cartesian_to_fractional(xyz_dark)
            if atom_mask is not None:
                frac_dark = frac_dark[atom_mask]
        self.register_buffer('frac_positions', frac_dark)

    def forward(self) -> torch.Tensor:
        """
        Compute penalty for low/negative density in the mixed model.

        Returns
        -------
        torch.Tensor
            Scalar penalty ``mean(relu(margin - ρ)²)`` over the monitored
            positions. It is zero only if the density is at least ``margin``
            at every position, and zero if no positions are monitored.
        """
        # Mixed model structure factors (includes the α weighting)
        fcalc_mixed = self.mixed_model(self.hkl, recalc=True)

        # h·r for every position/reflection pair. torch.matmul does not
        # promote dtypes, so cast hkl to the dtype of the stored positions.
        h_dot_r = torch.matmul(
            self.frac_positions, self.hkl.T.to(self.frac_positions.dtype)
        )

        phase = torch.angle(fcalc_mixed)
        amplitude = torch.abs(fcalc_mixed)

        density = (
            amplitude.unsqueeze(0)
            * torch.cos(2 * torch.pi * h_dot_r + phase.unsqueeze(0))
        ).sum(dim=1)
        density = density / len(self.hkl)

        if density.numel() == 0:
            # Nothing to monitor: mean() of an empty tensor would be NaN.
            return density.sum()

        # ReLU-like penalty on the shortfall below the margin.
        # NOTE(review): with a positive margin this also penalizes small
        # positive densities in [0, margin); confirm that this (rather than
        # a tolerance band below zero) is the intended behaviour.
        shortfall = torch.relu(self.margin - density)

        # Mean squared shortfall
        return (shortfall ** 2).mean()
class DisplacementRegularizer(DeviceMixin, nn.Module):
    """
    Regularizer that penalizes atomic displacements from a reference structure.

    This directly breaks the α-δF degeneracy by favoring solutions where atoms
    haven't moved too far from the dark structure, which implies larger α.

    The loss is the mean over atoms of the squared displacement length:
    ``mean_i(|xyz_light_i - xyz_dark_i|²)``. The dark positions are captured
    once at construction time.

    Parameters
    ----------
    model_light : ModelFT
        The light model being refined.
    model_dark : ModelFT
        The dark reference model.
    atom_mask : torch.Tensor, optional
        Boolean mask selecting which atoms to include.
    max_displacement : float, optional
        Default is 2.0 Å. Stored on the instance; ``forward`` as defined here
        does not read it — every displacement is penalized, not only those
        beyond this value.
    """

    def __init__(
        self,
        model_light: "ModelFT",
        model_dark: "ModelFT",
        atom_mask: Optional[torch.Tensor] = None,
        max_displacement: float = 2.0,
    ):
        super().__init__()
        self.model_light = model_light
        self.model_dark = model_dark
        self.atom_mask = atom_mask
        self.max_displacement = max_displacement

        # Store reference dark positions (a copy, so later changes to the
        # dark model do not move the reference)
        with torch.no_grad():
            xyz_dark = model_dark.xyz()
            if atom_mask is not None:
                xyz_dark = xyz_dark[atom_mask]
        self.register_buffer('xyz_dark_ref', xyz_dark.clone())

    def forward(self) -> torch.Tensor:
        """
        Compute displacement penalty.

        Returns
        -------
        torch.Tensor
            Mean squared displacement of the selected atoms from the stored
            dark reference (zero if no atoms are selected).
        """
        xyz_light = self.model_light.xyz()
        if self.atom_mask is not None:
            xyz_light = xyz_light[self.atom_mask]

        # Squared displacement length per atom
        displacement = xyz_light - self.xyz_dark_ref
        dist_sq = (displacement ** 2).sum(dim=1)

        if dist_sq.numel() == 0:
            # Empty selection: mean() would be NaN and poison the total loss.
            return dist_sq.sum()

        # Penalize ALL movement proportionally (mean squared displacement)
        # This directly favors smaller movements → larger α
        return dist_sq.mean()
class DifferenceAmplitudeRegularizer(DeviceMixin, nn.Module):
    """
    Regularizer that encourages consistency between α and difference amplitudes.

    The key insight: calculated and observed difference amplitudes should
    agree. If α is too small, the model compensates by making larger
    structural changes, which distorts the pattern of ΔF_calc vs ΔF_obs.

    This regularizer penalizes deviations from the expected relationship
    ``ΔF_calc ≈ ΔF_obs`` with ``ΔF_calc = |F_mixed| - |F_dark|`` and
    ``ΔF_obs = F_light_obs - F_dark_obs``, restricted to reflections with a
    significant observed difference.

    Parameters
    ----------
    dataset_collection : DatasetCollection
        Collection with 'dark' and 'light' datasets. Each entry is called with
        no arguments and the second element of the returned 4-tuple is used as
        the observed amplitudes; ``dataset_collection.hkl`` supplies the
        Miller indices.
    mixed_model : MixedModel
        The mixed model being refined.
    model_dark : ModelFT
        The dark reference model.
    significance_fraction : float, optional
        A reflection counts as significant when ``|ΔF_obs|`` exceeds this
        fraction of the observed dark amplitude. Default is 0.1.
    min_reflections : int, optional
        Minimum number of significant reflections required; below this the
        regularizer returns zero. Default is 100.
    """

    def __init__(
        self,
        dataset_collection,
        mixed_model: "MixedModel",
        model_dark: "ModelFT",
        significance_fraction: float = 0.1,
        min_reflections: int = 100,
    ):
        super().__init__()
        self.mixed_model = mixed_model
        self.model_dark = model_dark
        self._dataset_collection = dataset_collection
        self.significance_fraction = significance_fraction
        self.min_reflections = min_reflections

        # Get observed amplitudes
        _, F_light, _, _ = dataset_collection['light']()
        _, F_dark, _, _ = dataset_collection['dark']()

        # Unwrap container objects that expose their tensor via get_data()
        if hasattr(F_light, "get_data"):
            F_light = F_light.get_data()
        if hasattr(F_dark, "get_data"):
            F_dark = F_dark.get_data()

        # NOTE(review): the subtraction assumes both datasets list the same
        # reflections in the same order — confirm against DatasetCollection.
        self.register_buffer('_dF_obs', F_light - F_dark)
        self.register_buffer('_F_obs_dark', F_dark)

    @property
    def hkl(self):
        """Miller indices taken from the dataset collection."""
        return self._dataset_collection.hkl

    def forward(self) -> torch.Tensor:
        """
        Compute regularization loss.

        Returns
        -------
        torch.Tensor
            Mean of ``((ΔF_calc - ΔF_obs) / (|ΔF_obs| + 1e-6))²`` over the
            significant reflections, or a constant zero tensor when fewer
            than ``min_reflections`` reflections are significant.
        """
        hkl = self.hkl

        fcalc_mixed = self.mixed_model(hkl, recalc=True)
        fcalc_dark = self.model_dark(hkl, recalc=True)

        dF_calc = torch.abs(fcalc_mixed) - torch.abs(fcalc_dark)

        # Only consider reflections with significant observed difference
        significant = (
            torch.abs(self._dF_obs) > self.significance_fraction * self._F_obs_dark
        )

        if significant.sum() < self.min_reflections:
            return torch.tensor(0.0, device=hkl.device)

        dF_obs_sig = self._dF_obs[significant]
        dF_calc_sig = dF_calc[significant]

        # The residual between calculated and observed differences
        # Should be small when both α and structure are correct
        residual = dF_calc_sig - dF_obs_sig

        # Normalized by observed amplitude to make it scale-invariant
        normalized_residual = residual / (torch.abs(dF_obs_sig) + 1e-6)

        # Penalize large residuals
        return (normalized_residual ** 2).mean()