Source code for torchref.refinement.targets.realspace

"""
Real-Space Targets for Crystallographic Refinement.

This module provides target (loss) functions that compare electron density
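maps in real space rather than reciprocal space. Three targets are provided:

1. RealSpaceCorrelationTarget: Maximizes RSCC between 2mFo-DFc and Fcalc density
2. RealSpaceDifferenceTarget: Minimizes mean squared Fo-Fc difference density
3. RealSpaceExtrapolatedTarget: Maximizes RSCC between an extrapolated
   pure-light map (built from dark/light data) and the light model's Fcalc density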

All targets can use a molecular mask (inverse of the solvent mask, enabled by
``mask_solvent``) to restrict the comparison to the protein region. They follow
the phase detachment pattern from PhaseInformedDifferenceTarget to ensure
correct gradient flow.
"""

from typing import TYPE_CHECKING, Dict, Optional, Tuple

import torch

from torchref.base.reciprocal.grid_operations import place_on_grid
from torchref.symmetry.grid_utils import calculate_optimal_grid_size
from torchref.symmetry.reciprocal_symmetry import expand_hkl
from torchref.utils.stats import (
    VERBOSITY_DEBUG,
    VERBOSITY_DETAILED,
    VERBOSITY_STANDARD,
    StatEntry,
    stat,
)

from .base import DataTarget

if TYPE_CHECKING:
    from torchref.io.datasets import ReflectionData, DatasetCollection
    from torchref.model import MixedModel
    from torchref.model.model_ft import ModelFT
    from torchref.scaling.scaler_base import Scaler


class RealSpaceTarget(DataTarget):
    """
    Base class for real-space electron density targets.

    Inherits from DataTarget to get model, data, and scaler references.
    Provides common infrastructure for computing observed maps, model
    density, and molecular masks used by the concrete subclasses.

    Gradient Flow Design
    --------------------
    - Model density: gradients flow through Fcalc -> grid -> IFFT -> density
    - Observed map (2mFo-DFc): phases and |Fcalc| detached, no gradients
    - Observed map (Fo-Fc): |Fcalc| retains gradients, phases detached
    - Molecular mask: boolean, no gradients

    Parameters
    ----------
    data : ReflectionData
        Observed reflection data.
    model : ModelFT
        Model for computing Fcalc.
    scaler : Scaler, optional
        Scaler for Fcalc (applied before map coefficient computation).
    map_type : str
        ``"2mFo-DFc"`` or ``"Fo-Fc"``.
    mask_solvent : bool
        Whether to apply molecular mask. Default True.
    solvent_radius : float
        Probe radius for mask dilation in Angstroms. Default 1.1.
    erosion_radius : float
        Radius for mask erosion in Angstroms. Default 0.9.
    verbose : int
        Verbosity level. Default 0.
    target_value : float
        Target value for loss. Default 0.0.
    sigma : float
        Sigma for weighting. Default 0.5.

    Raises
    ------
    ValueError
        If ``map_type`` is not one of ``VALID_MAP_TYPES``.
    """

    VALID_MAP_TYPES = ("2mFo-DFc", "Fo-Fc")

    def __init__(
        self,
        data: "ReflectionData" = None,
        model: "ModelFT" = None,
        scaler: "Scaler" = None,
        map_type: str = "2mFo-DFc",
        mask_solvent: bool = True,
        solvent_radius: float = 1.1,
        erosion_radius: float = 0.9,
        verbose: int = 0,
        target_value: float = 0.0,
        sigma: float = 0.5,
    ):
        super().__init__(
            data=data,
            model=model,
            scaler=scaler,
            verbose=verbose,
            target_value=target_value,
            sigma=sigma,
        )
        if map_type not in self.VALID_MAP_TYPES:
            raise ValueError(
                f"map_type must be one of {self.VALID_MAP_TYPES}, got '{map_type}'"
            )
        self.map_type = map_type
        self._mask_solvent = mask_solvent
        self._solvent_radius = solvent_radius
        self._erosion_radius = erosion_radius

        # Caches (not registered as buffers since they're lazily computed)
        self._data_p1 = None
        self._molecular_mask = None
        self._gridsize = None

        # P1 expansion cache (ASU -> P1 mapping), filled by _ensure_p1_expansion
        self._hkl_p1 = None
        self._p1_indices = None
        self._p1_phase_shifts = None

    def _ensure_grid(self):
        """
        Ensure the model's real-space grid is set up.

        Raises
        ------
        RuntimeError
            If no model has been set on this target.
        """
        if self._model is None:
            raise RuntimeError("No model set for RealSpaceTarget")
        if self._model.real_space_grid is None:
            self._model.setup_grid()

    def _get_data_p1(self) -> "ReflectionData":
        """Return P1-expanded ReflectionData, cached after first call."""
        if self._data_p1 is None:
            self._data_p1 = self._data.expand_to_p1()
        return self._data_p1

    def _ensure_p1_expansion(self):
        """
        Compute and cache the ASU -> P1 expansion mapping.

        Stores the P1 Miller indices, the index of the source ASU reflection
        for each P1 reflection, and the symmetry phase shift to apply.
        Falls back to ``"P1"`` when the data has no space group.
        """
        if self._hkl_p1 is not None:
            return
        hkl_p1, indices, phase_shifts = expand_hkl(
            self._data.hkl,
            self._data.spacegroup or "P1",
            include_friedel=True,
            remove_absences=True,
            device=self._data.hkl.device,
        )
        self._hkl_p1 = hkl_p1
        self._p1_indices = indices
        self._p1_phase_shifts = phase_shifts

    def _expand_to_p1(self, fcalc: torch.Tensor) -> torch.Tensor:
        """
        Expand ASU complex structure factors to P1 using the cached mapping.

        Parameters
        ----------
        fcalc : torch.Tensor
            Complex structure factors indexed like the ASU reflection list.

        Returns
        -------
        torch.Tensor
            Complex P1 structure factors, ``fcalc[indices] * exp(i * shift)``.
        """
        self._ensure_p1_expansion()
        fcalc_p1 = fcalc[self._p1_indices]
        return fcalc_p1 * torch.exp(1j * self._p1_phase_shifts)

    def _get_gridsize(self) -> Tuple[int, int, int]:
        """
        Get grid size for map computation.

        Uses the model's FFT grid size to ensure compatibility with the
        molecular mask (which is built on the model's grid). The value is
        cached until ``update_mask`` invalidates it.
        """
        if self._gridsize is not None:
            return self._gridsize
        self._ensure_grid()
        gs = self._model.fft.gridsize
        self._gridsize = tuple(int(x) for x in gs)
        return self._gridsize

    def _coefficients_to_map(self, coefficients_p1: torch.Tensor) -> torch.Tensor:
        """
        Place P1 map coefficients on the FFT grid and transform to real space.

        Parameters
        ----------
        coefficients_p1 : torch.Tensor
            Complex coefficients aligned with ``self._hkl_p1``.

        Returns
        -------
        torch.Tensor
            Real part of the inverse FFT (3D density map).
        """
        gridsize = self._get_gridsize()
        grid = place_on_grid(
            self._hkl_p1, coefficients_p1, gridsize, enforce_hermitian=False
        )
        return torch.fft.ifftn(grid, dim=(0, 1, 2), norm="forward").real

    def _compute_observed_map(self) -> torch.Tensor:
        """
        Compute observed electron density map.

        For ``"2mFo-DFc"``: ``(2*Fobs - |Fcalc|) * exp(i * phi_calc)`` with
        both |Fcalc| and phases detached (no gradients on observed side).

        For ``"Fo-Fc"``: ``(Fobs - |Fcalc|) * exp(i * phi_calc)`` with
        |Fcalc| retaining gradients and phases detached.

        Scaling is applied at ASU level before P1 expansion.

        Returns
        -------
        torch.Tensor
            3D real-space density map.

        Raises
        ------
        ValueError
            If ``self.map_type`` was changed to an unsupported value.
        """
        self._ensure_p1_expansion()

        # Expand Fobs to P1 using the same index mapping as Fcalc
        # (amplitudes are invariant under symmetry, no phase shift needed)
        fobs_p1 = self._data.F[self._p1_indices]

        # Compute and scale Fcalc at ASU level, then expand to P1
        fcalc_p1 = self._expand_to_p1(self.get_fcalc_scaled())

        # Detach phases (following PhaseInformedDifferenceTarget pattern)
        phi_calc = torch.angle(fcalc_p1).detach()

        if self.map_type == "2mFo-DFc":
            # Fully detached observed side
            fcalc_amp = fcalc_p1.abs().detach()
            coefficients = (2.0 * fobs_p1 - fcalc_amp) * torch.exp(1j * phi_calc)
        elif self.map_type == "Fo-Fc":
            # |Fcalc| retains gradients, phases detached
            fcalc_amp = fcalc_p1.abs()
            coefficients = (fobs_p1 - fcalc_amp) * torch.exp(1j * phi_calc)
        else:
            raise ValueError(f"Unknown map_type: {self.map_type}")

        return self._coefficients_to_map(coefficients)

    def _compute_model_density(self) -> torch.Tensor:
        """
        Compute model electron density via Fcalc -> grid -> IFFT.

        Scaling is applied at ASU level before P1 expansion. Retains the full
        autograd graph for gradient flow through model parameters.

        Returns
        -------
        torch.Tensor
            3D real-space model density map.
        """
        self._ensure_p1_expansion()
        fcalc_p1 = self._expand_to_p1(self.get_fcalc_scaled())
        return self._coefficients_to_map(fcalc_p1)

    def _build_molecular_mask(self):
        """
        Build molecular mask using SolventModel.

        The molecular mask is the inverse of the solvent mask:
        True = protein region, False = solvent region.
        """
        from torchref.scaling.solvent import SolventModel

        self._ensure_grid()
        with torch.no_grad():
            solvent = SolventModel(
                model=self._model,
                radius=self._solvent_radius,
                erosion_radius=self._erosion_radius,
                optimize_phase=False,
                verbose=0,
            )
            solvent_mask = solvent.get_solvent_mask()  # True = solvent
            self._molecular_mask = ~solvent_mask  # True = protein

    def _get_molecular_mask(self) -> torch.Tensor:
        """Get molecular mask, building on first call."""
        if self._molecular_mask is None:
            self._build_molecular_mask()
        return self._molecular_mask

    def update_mask(self):
        """
        Explicitly recompute the molecular mask.

        The cached map grid size is invalidated too. Both the mask and the map
        grid are then read from the model's current grid, so they cannot
        disagree in shape after the model grid changes.
        """
        self._molecular_mask = None
        self._gridsize = None
        self._build_molecular_mask()
class RealSpaceCorrelationTarget(RealSpaceTarget):
    """
    Real-space correlation coefficient (RSCC) target.

    Computes RSCC between a 2mFo-DFc observed map and Fcalc model density
    within the molecular mask. The loss is ``1 - RSCC``.

    The observed map uses detached model phases and amplitudes, so
    gradients flow only through the model density side.

    Parameters
    ----------
    data : ReflectionData
        Observed reflection data.
    model : ModelFT
        Model for computing Fcalc.
    scaler : Scaler, optional
        Scaler for Fcalc.
    mask_solvent : bool
        Whether to apply molecular mask. Default True.
    solvent_radius : float
        Probe radius for mask in Angstroms. Default 1.1.
    erosion_radius : float
        Radius for mask erosion in Angstroms. Default 0.9.
    verbose : int
        Verbosity level. Default 0.
    """

    name: str = "realspace/correlation"

    def __init__(
        self,
        data: "ReflectionData" = None,
        model: "ModelFT" = None,
        scaler: "Scaler" = None,
        mask_solvent: bool = True,
        solvent_radius: float = 1.1,
        erosion_radius: float = 0.9,
        verbose: int = 0,
    ):
        super().__init__(
            data=data,
            model=model,
            scaler=scaler,
            map_type="2mFo-DFc",
            mask_solvent=mask_solvent,
            solvent_radius=solvent_radius,
            erosion_radius=erosion_radius,
            verbose=verbose,
            target_value=0.0,
            sigma=0.5,
        )

    def _masked_values(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute both maps and return the voxel values compared by the RSCC.

        Returns
        -------
        tuple of torch.Tensor
            ``(obs_vals, calc_vals)``: 1D observed-map and model-density
            values, restricted to the molecular mask when ``mask_solvent`` is
            enabled, otherwise every voxel.
        """
        obs_map = self._compute_observed_map()
        model_density = self._compute_model_density()
        if self._mask_solvent:
            mask = self._get_molecular_mask()
            return obs_map[mask], model_density[mask]
        return obs_map.flatten(), model_density.flatten()

    @staticmethod
    def _pearson(
        obs_vals: torch.Tensor, calc_vals: torch.Tensor, eps: float = 1e-8
    ) -> torch.Tensor:
        """
        Pearson correlation ``cov(obs, calc) / (std(obs) * std(calc))``.

        ``eps`` is added inside each square root so that a flat map does not
        produce a zero denominator.
        """
        obs_centered = obs_vals - obs_vals.mean()
        calc_centered = calc_vals - calc_vals.mean()
        cov = (obs_centered * calc_centered).mean()
        std_obs = torch.sqrt((obs_centered**2).mean() + eps)
        std_calc = torch.sqrt((calc_centered**2).mean() + eps)
        return cov / (std_obs * std_calc)

    def forward(self) -> torch.Tensor:
        """
        Compute 1 - RSCC loss.

        Returns
        -------
        torch.Tensor
            Scalar loss value (1 - RSCC).
        """
        obs_vals, calc_vals = self._masked_values()
        return 1.0 - self._pearson(obs_vals, calc_vals)

    def stats(self) -> Dict[str, StatEntry]:
        """
        Get statistics for the correlation target.

        The maps are evaluated only once. The voxel count is taken from the
        values actually compared, instead of recomputing the model density
        just to read its size.

        Returns
        -------
        dict
            Dictionary with loss, rscc, and n_voxels.
        """
        with torch.no_grad():
            obs_vals, calc_vals = self._masked_values()
            loss = 1.0 - self._pearson(obs_vals, calc_vals)
            rscc = 1.0 - loss.item()
            # Boolean indexing yields exactly mask.sum() elements, and the
            # unmasked path yields every voxel, so numel() is the voxel count.
            n_voxels = int(calc_vals.numel())

        return {
            "loss": stat(loss.item(), VERBOSITY_STANDARD),
            "rscc": stat(rscc, VERBOSITY_STANDARD),
            "n_voxels": stat(n_voxels, VERBOSITY_DETAILED),
        }
class RealSpaceDifferenceTarget(RealSpaceTarget):
    """
    Real-space Fo-Fc difference density target.

    The loss is the mean of the squared Fo-Fc map values, taken over the
    molecular mask (or over the whole grid when masking is disabled). Large
    positive or negative difference features therefore increase the loss.

    |Fcalc| keeps its gradient inside the Fo-Fc coefficients, while the
    phases are detached. This gives the model a direct gradient signal.

    Parameters
    ----------
    data : ReflectionData
        Observed reflection data.
    model : ModelFT
        Model for computing Fcalc.
    scaler : Scaler, optional
        Scaler for Fcalc.
    mask_solvent : bool
        Whether to apply molecular mask. Default True.
    solvent_radius : float
        Probe radius for mask in Angstroms. Default 1.1.
    erosion_radius : float
        Radius for mask erosion in Angstroms. Default 0.9.
    verbose : int
        Verbosity level. Default 0.
    """

    name: str = "realspace/difference"

    def __init__(
        self,
        data: "ReflectionData" = None,
        model: "ModelFT" = None,
        scaler: "Scaler" = None,
        mask_solvent: bool = True,
        solvent_radius: float = 1.1,
        erosion_radius: float = 0.9,
        verbose: int = 0,
    ):
        # This target always works on the Fo-Fc map.
        super().__init__(
            data=data,
            model=model,
            scaler=scaler,
            map_type="Fo-Fc",
            mask_solvent=mask_solvent,
            solvent_radius=solvent_radius,
            erosion_radius=erosion_radius,
            verbose=verbose,
            target_value=0.0,
            sigma=0.5,
        )

    def _difference_values(self) -> torch.Tensor:
        """Fo-Fc map values inside the molecular mask, or all voxels when unmasked."""
        diff_map = self._compute_observed_map()
        if not self._mask_solvent:
            return diff_map.flatten()
        return diff_map[self._get_molecular_mask()]

    def forward(self) -> torch.Tensor:
        """
        Compute mean squared Fo-Fc difference density.

        Returns
        -------
        torch.Tensor
            Scalar loss value (mean squared difference density).
        """
        values = self._difference_values()
        return (values**2).mean()

    def stats(self) -> Dict[str, StatEntry]:
        """
        Summarize the Fo-Fc map over the evaluated voxels.

        Returns
        -------
        dict
            Dictionary with loss, rms_diff, mean_abs_diff, peak values,
            and n_voxels.
        """
        with torch.no_grad():
            values = self._difference_values()
            mean_sq = (values**2).mean()
            loss_value = mean_sq.item()
            rms_value = torch.sqrt(mean_sq).item()
            mean_abs_value = values.abs().mean().item()
            pos_peak = values.max().item()
            neg_peak = values.min().item()
            voxel_count = int(values.numel())

        return {
            "loss": stat(loss_value, VERBOSITY_STANDARD),
            "rms_diff": stat(rms_value, VERBOSITY_STANDARD),
            "mean_abs_diff": stat(mean_abs_value, VERBOSITY_DETAILED),
            "max_pos_peak": stat(pos_peak, VERBOSITY_DETAILED),
            "max_neg_peak": stat(neg_peak, VERBOSITY_DETAILED),
            "n_voxels": stat(voxel_count, VERBOSITY_DETAILED),
        }
class RealSpaceExtrapolatedTarget(RealSpaceTarget):
    """
    Real-space correlation target using extrapolated pure-light density.

    Computes the RSCC between an extrapolated pure-light electron density map
    and the light model's Fcalc density within the molecular mask. The loss
    is ``1 - RSCC``.

    The extrapolation combines observed dark/light amplitudes with
    model-derived phases:

        F_extra = (F_light * exp(i*phi_mixed) - w_dark * F_dark * exp(i*phi_dark)) / w_light

    where w_dark, w_light are population fractions from the mixed model.
    A small ``w_light`` amplifies the extrapolated coefficients because of
    the division.

    Parameters
    ----------
    dataset_collection : DatasetCollection
        Collection containing 'dark' and 'light' datasets (aligned HKL).
    model_dark : ModelFT
        Dark-state model (for dark phases).
    model_light : ModelFT
        Light-state model (gradients flow through this model's density).
    model_mixed : MixedModel
        Mixed model (for mixed-state phases and population fractions).
    scaler_dark : Scaler, optional
        Scaler for dark Fcalc.
    scaler_mixed : Scaler, optional
        Scaler for mixed Fcalc.
    scaler_light : Scaler, optional
        Scaler for light model Fcalc (model density side).
    mask_solvent : bool, optional
        Whether to apply molecular mask. Default True.
    solvent_radius : float, optional
        Probe radius for mask in Angstroms. Default 1.1.
    erosion_radius : float, optional
        Radius for mask erosion in Angstroms. Default 0.9.
    verbose : int, optional
        Verbosity level. Default 0.

    Raises
    ------
    ValueError
        If ``dataset_collection`` lacks a 'dark' or 'light' dataset.
    """

    # NOTE(review): sibling targets use "realspace/<kind>" names; this string
    # is left unchanged because it may be referenced externally.
    name: str = "realspace_extrapolated"

    def __init__(
        self,
        dataset_collection: "DatasetCollection",
        model_dark: "ModelFT" = None,
        model_light: "ModelFT" = None,
        model_mixed: "MixedModel" = None,
        scaler_dark: "Scaler" = None,
        scaler_mixed: "Scaler" = None,
        scaler_light: "Scaler" = None,
        mask_solvent: bool = True,
        solvent_radius: float = 1.1,
        erosion_radius: float = 0.9,
        verbose: int = 0,
    ):
        if "dark" not in dataset_collection:
            raise ValueError("DatasetCollection must contain a 'dark' dataset")
        if "light" not in dataset_collection:
            raise ValueError("DatasetCollection must contain a 'light' dataset")

        # Parent uses model_light for mask/model density, light data for P1 grid.
        # The parent also resets the P1 expansion cache (_hkl_p1, _p1_indices,
        # _p1_phase_shifts), which _ensure_p1_expansion fills lazily.
        super().__init__(
            data=dataset_collection["light"],
            model=model_light,
            scaler=scaler_light,
            map_type="2mFo-DFc",  # overridden by _compute_observed_map
            mask_solvent=mask_solvent,
            solvent_radius=solvent_radius,
            erosion_radius=erosion_radius,
            verbose=verbose,
            target_value=0.0,
            sigma=0.5,
        )

        # Additional models and scalers for phase computation
        self.add_module("_model_dark", model_dark)
        self.add_module("_model_mixed", model_mixed)
        self.add_module("_scaler_dark", scaler_dark)
        self.add_module("_scaler_mixed", scaler_mixed)

        # Store references to datasets
        self._data_dark = dataset_collection["dark"]
        self._data_light = dataset_collection["light"]

        # Precompute observed data as buffers
        self._setup_data()

    def _setup_data(self):
        """
        Extract observed data from both datasets and store it as buffers.

        Registers ``_hkl``, ``_F_obs_light``, ``_F_obs_dark`` and
        ``_valid_mask``. A reflection is valid only if it is unmasked in both
        datasets (when they carry a mask) and both amplitudes are finite.
        Invalid amplitudes are zeroed so they cannot inject NaN/inf into the
        extrapolated map.
        """
        hkl, F_light, _sigma_light, _rfree_light = self._data_light()
        _, F_dark, _, _ = self._data_dark()

        # Handle MaskedTensor
        valid_light = valid_dark = None
        if hasattr(F_light, "get_mask"):
            valid_light = F_light.get_mask()
            F_light = F_light.get_data()
        if hasattr(F_dark, "get_mask"):
            valid_dark = F_dark.get_mask()
            F_dark = F_dark.get_data()

        # Combined validity mask; non-finite amplitudes are treated as missing
        # even when the data carries no explicit mask.
        valid_mask = torch.isfinite(F_light) & torch.isfinite(F_dark)
        if valid_light is not None:
            valid_mask = valid_mask & valid_light
        if valid_dark is not None:
            valid_mask = valid_mask & valid_dark

        # Zero invalid values to prevent NaN propagation
        F_light = torch.where(valid_mask, F_light, torch.zeros_like(F_light))
        F_dark = torch.where(valid_mask, F_dark, torch.zeros_like(F_dark))

        self.register_buffer("_hkl", hkl)
        self.register_buffer("_F_obs_light", F_light)
        self.register_buffer("_F_obs_dark", F_dark)
        self.register_buffer("_valid_mask", valid_mask)

    def _ensure_p1_expansion(self):
        """
        Expand this target's HKL buffer to P1, caching the result.

        Falls back to ``"P1"`` when the light dataset has no space group,
        matching the base-class behaviour.
        """
        if self._hkl_p1 is not None:
            return
        hkl_p1, indices, phase_shifts = expand_hkl(
            self._hkl,
            self._data_light.spacegroup or "P1",
            include_friedel=True,
            remove_absences=True,
            device=self._hkl.device,
        )
        self._hkl_p1 = hkl_p1
        self._p1_indices = indices
        self._p1_phase_shifts = phase_shifts

    def _asu_to_map(self, coefficients: torch.Tensor) -> torch.Tensor:
        """
        Expand coefficients indexed like ``self._hkl`` to P1 and IFFT them.

        Parameters
        ----------
        coefficients : torch.Tensor
            Complex ASU coefficients aligned with ``self._hkl``.

        Returns
        -------
        torch.Tensor
            Real part of the inverse FFT on the model grid (3D map).
        """
        # _expand_to_p1 applies the symmetry phase shifts using this class's
        # _ensure_p1_expansion (built from self._hkl).
        coefficients_p1 = self._expand_to_p1(coefficients)
        gridsize = self._get_gridsize()
        grid = place_on_grid(
            self._hkl_p1, coefficients_p1, gridsize, enforce_hermitian=False
        )
        return torch.fft.ifftn(grid, dim=(0, 1, 2), norm="forward").real

    def _compute_observed_map(self) -> torch.Tensor:
        """
        Compute extrapolated pure-light electron density map.

        Phases are detached to prevent circular gradients where atoms can
        minimise loss by rotating model phases rather than matching the
        true structure. Population fractions retain gradients. This gives a
        self-consistency signal for fraction refinement: the preferred
        fractions are those whose extrapolated map agrees best with the
        model density.

        Returns
        -------
        torch.Tensor
            3D real-space density map of the extrapolated pure-light state.
        """
        self._ensure_p1_expansion()

        # Compute model phases at ASU HKL — detached to prevent circular
        # gradients where atoms minimise loss by rotating phases rather than
        # matching the true structure. Phases still update each step from
        # the current model; they just don't contribute to this target's gradient.
        fcalc_dark = self._model_dark(self._hkl)
        if self._scaler_dark is not None:
            fcalc_dark = self._scaler_dark(fcalc_dark)
        phi_dark = torch.angle(fcalc_dark).detach()

        fcalc_mixed = self._model_mixed(self._hkl)
        if self._scaler_mixed is not None:
            fcalc_mixed = self._scaler_mixed(fcalc_mixed)
        phi_mixed = torch.angle(fcalc_mixed).detach()

        # Population fractions — gradients retained for fraction refinement
        fractions = self._model_mixed.fractions
        w_dark = fractions[0]
        w_light = fractions[1]

        # Phase observed amplitudes and extrapolate
        F_obs_dark_phased = self._F_obs_dark * torch.exp(1j * phi_dark)
        F_obs_light_phased = self._F_obs_light * torch.exp(1j * phi_mixed)
        F_extra = (F_obs_light_phased - w_dark * F_obs_dark_phased) / w_light

        # Zero out invalid reflections
        F_extra = torch.where(self._valid_mask, F_extra, torch.zeros_like(F_extra))

        return self._asu_to_map(F_extra)

    def _compute_model_density(self) -> torch.Tensor:
        """
        Compute model density from light model Fcalc.

        Uses the P1 expansion computed by this target (not the parent's
        data_p1 cache) for consistency with the observed map.

        Returns
        -------
        torch.Tensor
            3D real-space model density map.
        """
        self._ensure_p1_expansion()
        fcalc = self._model(self._hkl)
        if self._scaler is not None:
            fcalc = self._scaler(fcalc)
        return self._asu_to_map(fcalc)

    def _masked_values(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute both maps and return the voxel values compared by the RSCC.

        Returns
        -------
        tuple of torch.Tensor
            ``(obs_vals, calc_vals)`` restricted to the molecular mask when
            ``mask_solvent`` is enabled, otherwise every voxel (flattened).
        """
        obs_map = self._compute_observed_map()
        model_density = self._compute_model_density()
        if self._mask_solvent:
            mask = self._get_molecular_mask()
            return obs_map[mask], model_density[mask]
        return obs_map.flatten(), model_density.flatten()

    @staticmethod
    def _pearson(
        obs_vals: torch.Tensor, calc_vals: torch.Tensor, eps: float = 1e-8
    ) -> torch.Tensor:
        """Pearson correlation with ``eps`` inside each square root."""
        obs_centered = obs_vals - obs_vals.mean()
        calc_centered = calc_vals - calc_vals.mean()
        cov = (obs_centered * calc_centered).mean()
        std_obs = torch.sqrt((obs_centered**2).mean() + eps)
        std_calc = torch.sqrt((calc_centered**2).mean() + eps)
        return cov / (std_obs * std_calc)

    def forward(self) -> torch.Tensor:
        """
        Compute 1 - RSCC between extrapolated map and model density.

        Returns
        -------
        torch.Tensor
            Scalar loss value (1 - RSCC).
        """
        obs_vals, calc_vals = self._masked_values()
        return 1.0 - self._pearson(obs_vals, calc_vals)

    def stats(self) -> Dict[str, StatEntry]:
        """
        Get statistics for the extrapolated real-space target.

        The maps are evaluated once. The voxel count comes from the compared
        values instead of running the light model again.

        Returns
        -------
        dict
            Dictionary with loss, rscc, and n_voxels.
        """
        with torch.no_grad():
            obs_vals, calc_vals = self._masked_values()
            loss = 1.0 - self._pearson(obs_vals, calc_vals)
            rscc = 1.0 - loss.item()
            # Equals mask.sum() when masked, total voxel count otherwise.
            n_voxels = int(calc_vals.numel())

        return {
            "loss": stat(loss.item(), VERBOSITY_STANDARD),
            "rscc": stat(rscc, VERBOSITY_STANDARD),
            "n_voxels": stat(n_voxels, VERBOSITY_DETAILED),
        }