Source code for torchref.refinement.targets.xray.base

from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union

import torch

from torchref.utils.stats import (
    VERBOSITY_DEBUG,
    VERBOSITY_DETAILED,
    VERBOSITY_STANDARD,
    StatEntry,
    stat,
)

from ..base import DataTarget

if TYPE_CHECKING:
    from torchref.io import ReflectionData
    from torchref.model.model import Model
    from torchref.model.model_ft import ModelFT
    from torchref.scaling.scaler_base import Scaler


class XrayTarget(DataTarget):
    """
    Base class for X-ray targets.

    Provides common functionality for accessing F_obs, F_calc, etc.

    Supports two modes of operation:

    1. With Model: Computes F_calc from model on each forward pass
    2. Without Model: Uses pre-computed F_calc passed to forward()/get_data()

    Parameters
    ----------
    data : ReflectionData, optional
        Reference to the ReflectionData object. Required for forward().
    model : Model or ModelFT, optional
        Reference to Model object for F_calc computation.
        If None, fcalc must be provided to forward().
    scaler : Scaler, optional
        Reference to the Scaler object.
    use_work_set : bool, optional
        If True, compute loss on work set; if False, on test set.
        Default is True.
    sigma_mode : {'raw', 'effective'}, optional
        Which sigma to use in the likelihood. Default is 'raw'.
        See ``__init__`` for details.
    verbose : int, optional
        Verbosity level. Default is 0.

    Attributes
    ----------
    use_work_set : bool
        Whether to use work set or test set.
    sigma_mode : str
        Active sigma mode, ``'raw'`` or ``'effective'``.
    name : str
        ``'xray_work'`` or ``'xray_test'`` depending on ``use_work_set``.
    """

    name: str = "xray"  # Overridden per instance based on work/test set

    def __init__(
        self,
        data: Optional["ReflectionData"] = None,
        model: Optional[Union["Model", "ModelFT"]] = None,
        scaler: Optional["Scaler"] = None,
        use_work_set: bool = True,
        sigma_mode: str = "raw",
        verbose: int = 0,
    ):
        """
        Initialize X-ray target.

        Parameters
        ----------
        data : ReflectionData, optional
            Reference to the ReflectionData object. Required for forward().
        model : Model or ModelFT, optional
            Reference to Model object for F_calc computation.
            If None, fcalc must be provided to forward().
        scaler : Scaler, optional
            Reference to the Scaler object.
        use_work_set : bool, optional
            If True, compute loss on work set; if False, on test set.
            Default is True.
        sigma_mode : str, optional
            Which sigma to use in the likelihood. Options:

            - ``'raw'`` (default): use the raw experimental sigmas from
              the data file. Empirically gives the best Rfree across the
              mid-resolution regime (1.5-3.0 A) when paired with
              appropriate group weights.
            - ``'effective'``: use per-shell effective sigmas estimated
              from scaling residuals (capped SIGMAA-style correction).
              Opt-in for high-resolution refinement (< 1.5 A) or datasets
              with known sigma miscalibration.

            Note: ``Scaler.estimate_sigma_eff`` is *always* called so the
            estimates are available regardless of which mode the target
            uses.
        verbose : int, optional
            Verbosity level. Default is 0.

        Raises
        ------
        ValueError
            If ``sigma_mode`` is not ``'effective'`` or ``'raw'``.
        """
        super().__init__(data=data, model=model, scaler=scaler, verbose=verbose)
        self.use_work_set = use_work_set
        if sigma_mode not in ("effective", "raw"):
            raise ValueError(
                f"sigma_mode must be 'effective' or 'raw', got {sigma_mode!r}"
            )
        self.sigma_mode = sigma_mode
        # Set name based on work/test set
        self.name = "xray_work" if use_work_set else "xray_test"

        # Cache for the bookkeeping pieces of get_data: F_obs_sel, sigma_sel,
        # mask, centric_sel. These depend on the data (rfree, validity,
        # centric), the ReflectionData scaling parameters (log_scale,
        # U_aniso) and, in 'effective' sigma mode, the scaler's sigma_eff.
        # They are rebuilt only when that fingerprint changes (see
        # ``_data_fingerprint``).
        self._cached_get_data = None
        self._cached_get_data_fp = None
        self._cached_sigma_mode = None
        # Strong references to the tensors the fingerprint was taken from.
        # Keeping them alive guarantees their id()/data_ptr() cannot be
        # handed to a *new* tensor (with _version reset to 0) while the cache
        # is considered valid.
        self._cached_get_data_refs = None

    def _fingerprint_sources(self):
        """Return ``(label, tensor)`` pairs the cached bookkeeping depends on.

        Includes ``self._data.log_scale`` / ``self._data.U_aniso`` when they
        are tensors and, when ``sigma_mode == 'effective'``, the scaler's
        ``sigma_eff`` tensor (which ``_build_get_data_cache`` substitutes for
        the raw sigmas). Missing or non-tensor attributes are skipped, so a
        tensor appearing later (e.g. ``sigma_eff`` estimated after the first
        call) changes the fingerprint.
        """
        sources = []
        for attr in ("log_scale", "U_aniso"):
            t = getattr(self._data, attr, None)
            if isinstance(t, torch.Tensor):
                sources.append((attr, t))
        if self.sigma_mode == "effective" and self._scaler is not None:
            sigma_eff = getattr(self._scaler, "sigma_eff", None)
            if isinstance(sigma_eff, torch.Tensor):
                sources.append(("sigma_eff", sigma_eff))
        return sources

    def _data_fingerprint(self, sources=None):
        """Fingerprint everything the cached get_data pieces depend on.

        Used to detect when the bookkeeping cache (F_obs_sel / sigma_sel /
        mask / centric_sel) must be rebuilt. Each probed tensor contributes
        ``(label, id, data_ptr, _version)``: in-place updates bump
        ``_version``, and reassignment changes ``id`` / ``data_ptr``.

        Parameters
        ----------
        sources : list of (str, torch.Tensor), optional
            Pre-computed result of ``_fingerprint_sources()``; computed here
            when omitted.

        Returns
        -------
        tuple
            Hashable, comparable fingerprint. A change in ``sigma_mode`` is
            tracked separately via ``self._cached_sigma_mode``.
        """
        if sources is None:
            sources = self._fingerprint_sources()
        return tuple(
            (label, id(t), t.data_ptr(), t._version) for label, t in sources
        )

    def _build_get_data_cache(self):
        """Recompute the bookkeeping tensors that don't depend on fcalc.

        Returns
        -------
        tuple
            ``(F_obs_sel, sigma_sel, mask, centric_sel)``. Outside ``mask``,
            F_obs is filled with 0 and sigma with 1; ``centric_sel`` is None
            when ``self._data.centric`` is None.
        """
        _hkl, F_obs, sigma_F_obs, rfree_mask = self._data()

        # Sigma selection: use per-shell effective sigma from scaler if
        # requested and shape-compatible; otherwise keep the raw sigmas.
        if self.sigma_mode == "effective" and self._scaler is not None:
            sigma_eff = getattr(self._scaler, "sigma_eff", None)
            if sigma_eff is not None and sigma_eff.shape == sigma_F_obs.shape:
                sigma_F_obs = sigma_eff

        centric_all = self._data.centric

        # True entries of rfree_mask select the work set; the test set is
        # its complement.
        rfree_bool = rfree_mask.bool()
        set_mask = rfree_bool if self.use_work_set else ~rfree_bool

        if hasattr(F_obs, "get_mask"):
            # Masked container: additionally drop reflections flagged invalid.
            mask = F_obs.get_mask() & set_mask
            F_obs_data = F_obs.get_data()
            sigma_data = (
                sigma_F_obs.get_data()
                if hasattr(sigma_F_obs, "get_mask")
                else sigma_F_obs
            )
        else:
            mask = set_mask
            F_obs_data = F_obs
            sigma_data = sigma_F_obs

        F_obs_sel = torch.where(mask, F_obs_data, torch.zeros_like(F_obs_data))
        sigma_sel = torch.where(mask, sigma_data, torch.ones_like(sigma_data))
        centric_sel = (centric_all & mask) if centric_all is not None else None
        return F_obs_sel, sigma_sel, mask, centric_sel

    def reset_get_data_cache(self):
        """Drop the cached bookkeeping tensors.

        Call this if you mutate ``self._data.log_scale`` /
        ``self._data.U_aniso`` in-place outside of the normal fingerprint-
        tracked flow, or if you want to free the memory.
        """
        self._cached_get_data = None
        self._cached_get_data_fp = None
        self._cached_sigma_mode = None
        self._cached_get_data_refs = None

    def get_data(
        self, fcalc: Optional[torch.Tensor] = None
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        Optional[torch.Tensor],
        torch.Tensor,
    ]:
        """
        Get F_obs, F_calc, sigma, centric flags and mask for the selected set.

        Bookkeeping tensors (F_obs_sel, sigma_sel, mask, centric_sel) are
        cached and reused as long as the upstream scaling parameters
        (``log_scale``, ``U_aniso``) of the ReflectionData, the scaler's
        ``sigma_eff`` (in 'effective' sigma mode) and ``sigma_mode`` itself
        are unchanged. Only ``F_calc_sel`` is recomputed on each call.

        Parameters
        ----------
        fcalc : torch.Tensor, optional
            Pre-computed structure factors. If provided, uses these instead
            of computing from model. Useful when model is not set.

        Returns
        -------
        tuple
            ``(F_obs_sel, F_calc_sel, sigma_sel, centric_sel, mask)``;
            ``centric_sel`` is None when the data has no centric flags.
        """
        sources = self._fingerprint_sources()
        fp = self._data_fingerprint(sources)
        if (
            self._cached_get_data is None
            or self._cached_get_data_fp != fp
            or self._cached_sigma_mode != self.sigma_mode
        ):
            # NOTE(review): if log_scale / U_aniso require grad, the cached
            # tensors may carry an autograd graph that is reused until the
            # fingerprint changes — confirm this is safe for scale refinement.
            self._cached_get_data = self._build_get_data_cache()
            self._cached_get_data_fp = fp
            self._cached_sigma_mode = self.sigma_mode
            self._cached_get_data_refs = tuple(t for _, t in sources)
        F_obs_sel, sigma_sel, mask, centric_sel = self._cached_get_data

        # F_calc must always be computed fresh — depends on the live model
        # state (xyz / adp / occ).
        if fcalc is not None:
            F_calc = self.get_F_calc_scaled(fcalc=fcalc)
        else:
            F_calc = self.get_F_calc_scaled(self._data.hkl, recalc=False)
        F_calc_sel = torch.where(mask, F_calc, torch.zeros_like(F_calc))
        return F_obs_sel, F_calc_sel, sigma_sel, centric_sel, mask

    def stats(self, fcalc: Optional[torch.Tensor] = None) -> Dict[str, StatEntry]:
        """
        Get statistics for this X-ray target.

        Parameters
        ----------
        fcalc : torch.Tensor, optional
            Pre-computed structure factors.

        Returns
        -------
        dict
            Statistics dict with StatEntry values (``loss``, ``n``,
            ``rwork``, ``rfree``) containing verbosity levels.
        """
        # Only the selection mask is needed here (for the reflection count).
        *_, mask = self.get_data(fcalc=fcalc)
        loss = self.forward(fcalc=fcalc)
        rwork, rfree = self.get_rfactor()
        return {
            "loss": stat(loss.item(), VERBOSITY_STANDARD),
            "n": stat(mask.sum().item(), VERBOSITY_DEBUG),
            "rwork": stat(rwork, VERBOSITY_STANDARD),
            "rfree": stat(rfree, VERBOSITY_STANDARD),
        }