# Source code for torchref.scaling.scaler

"""
A class for scaling and post corrections of scattering factors.

Currently implements:
- Overall scale per resolution bin
- B-factor per resolution bin
- Anisotropy correction
- Solvent model correction

This module provides the full-featured `Scaler` class that maintains a reference
to a Model object. For a model-independent scaler, see `ScalerBase`.
"""

from typing import Optional, Tuple, TYPE_CHECKING

import torch
import torch.nn as nn

from torchref.config import get_default_device
from torchref.io import ReflectionData
from torchref.base.metrics import bin_wise_rfactors, get_rfactors, nll_xray, nll_xray_lognormal
from torchref.base.reciprocal import get_scattering_vectors
from torchref.scaling.scaler_base import ScalerBase
from torchref.scaling.solvent import SolventModel
from torchref.utils.device_resolution import resolve_device
from torchref.utils.utils import ModuleReference

if TYPE_CHECKING:
    from torchref.model import Model


class Scaler(ScalerBase):
    """
    Full-featured scaler with Model integration.

    Extends ScalerBase by maintaining a reference to a Model object and
    providing convenience methods that automatically compute F_calc when
    it is not provided.

    Supports two initialization patterns:

    1. Empty initialization (for state_dict loading)::

        scaler = Scaler()  # Creates empty shell
        scaler.load_state_dict(torch.load('scaler.pt'))

    2. Full initialization with model and data::

        scaler = Scaler(model, reflection_data, nbins=20)
        scaler.initialize()

    Parameters
    ----------
    model : Model, optional
        Model object for structure factor calculation.
    data : ReflectionData, optional
        ReflectionData object with observed data.
    nbins : int, default 20
        Number of resolution bins.
    verbose : int, default 1
        Verbosity level.
    device : torch.device, optional
        Computation device. If ``None``, it is derived from ``model`` and
        then ``data`` (see ``__init__``).

    Attributes
    ----------
    device : torch.device
        Current computation device.
    nbins : int
        Number of resolution bins.

    Notes
    -----
    The model is held through a ``ModuleReference`` stored in ``_model_ref``
    so that it is not registered as a submodule and its state does not end
    up in the scaler's ``state_dict``. ``scaler.model = model`` is routed
    through the ``model`` property setter by ``__setattr__`` to keep that
    guarantee for plain attribute assignment as well.
    """

    def __init__(
        self,
        model: Optional["Model"] = None,
        data: Optional[ReflectionData] = None,
        nbins: int = 20,
        verbose: int = 1,
        device: Optional[torch.device] = None,
    ):
        """
        Initialize Scaler.

        If model and data are provided, fully initializes the scaler.
        If not provided (empty init), creates a shell ready for
        load_state_dict().

        Parameters
        ----------
        model : Model, optional
            Model object for structure factor calculation.
        data : ReflectionData, optional
            ReflectionData object with observed data.
        nbins : int, default 20
            Number of resolution bins.
        verbose : int, default 1
            Verbosity level.
        device : torch.device, optional
            Computation device. If ``None``, derived from ``model`` then
            ``data`` (model wins on mismatch); otherwise forces both onto
            the explicit device. See
            :func:`torchref.utils.resolve_device`.
        """
        # Pin model+data onto a single device before super().__init__
        # registers buffers from ``data.hkl`` / ``data.cell``.
        resolved_device = resolve_device(model, data, device=device)

        # Initialize base class with data
        super().__init__(
            data=data,
            nbins=nbins,
            verbose=verbose,
            device=resolved_device,
        )

        # Wrap in ModuleReference to avoid registering the model as a
        # submodule (which would leak its state into the scaler's state_dict).
        self._model_ref = ModuleReference(model) if model is not None else None

    @property
    def model(self):
        """Access the model object (not a registered submodule)."""
        if self._model_ref is None:
            return None
        return self._model_ref.module

    @model.setter
    def model(self, value):
        """
        Set the model reference.

        Note: Uses object.__setattr__ to bypass PyTorch's
        nn.Module.__setattr__, which would intercept nn.Module assignments
        and register them as submodules.
        """
        ref = ModuleReference(value) if value is not None else None
        object.__setattr__(self, "_model_ref", ref)

    def __setattr__(self, name, value):
        """
        Route ``self.model = ...`` through the ``model`` property setter.

        nn.Module.__setattr__ intercepts assignments of nn.Module instances
        and registers them as submodules, bypassing the property setter
        entirely. Without this override, ``scaler.model = model`` would
        register the model as a submodule while the ``model`` getter kept
        returning the previous reference.
        """
        if name == "model":
            Scaler.model.fset(self, value)
            return
        super().__setattr__(name, value)

    def set_model_and_data(self, model: "Model", data: ReflectionData):
        """
        Set model and data references after empty initialization.

        This is useful when loading from state_dict and then needing to
        reconnect to model/data objects.

        Parameters
        ----------
        model : Model
            Model object for structure factor calculation.
        data : ReflectionData
            ReflectionData object with observed data.
        """
        # Store the wrapped reference directly in ``_model_ref`` so the model
        # itself is never handed to nn.Module.__setattr__, which would
        # register it as a submodule.
        self._model_ref = ModuleReference(model) if model is not None else None

        # Use parent class method for data
        self.set_data(data)

    def initialize(self, fcalc: Optional[torch.Tensor] = None) -> "Scaler":
        """
        Initialize scaling parameters.

        Runs, in order: initial scale calculation, solvent setup and
        anisotropy-correction setup. If fcalc is not provided, computes it
        from the internal model.

        Parameters
        ----------
        fcalc : torch.Tensor, optional
            Calculated structure factors. If None, computed from model.

        Returns
        -------
        Scaler
            This instance, to allow call chaining.
        """
        if fcalc is None:
            fcalc = self.compute_fcalc()

        self.calc_initial_scale(fcalc)
        self.setup_solvent()
        self.setup_anisotropy_correction()
        return self

    def compute_fcalc(self) -> torch.Tensor:
        """
        Compute F_calc from internal model.

        Returns
        -------
        torch.Tensor
            Calculated structure factors.

        Raises
        ------
        RuntimeError
            If no model is set.
        """
        model = self.model
        if model is None:
            raise RuntimeError("No model set and no fcalc provided")
        return model(self.hkl)

    def calc_initial_scale(self, fcalc: Optional[torch.Tensor] = None):
        """
        Calculate initial scale factors.

        If fcalc is not provided, computes it from the internal model.

        Parameters
        ----------
        fcalc : torch.Tensor, optional
            Calculated structure factors. If None, computed from model.

        Returns
        -------
        torch.nn.Parameter
            The log scale parameter for each resolution bin.
        """
        if fcalc is None:
            fcalc = self.compute_fcalc()
        return super().calc_initial_scale(fcalc)

    def fit_anisotropy(self, fcalc: Optional[torch.Tensor] = None, nsteps: int = 100):
        """
        Fit anisotropic correction.

        If fcalc is not provided, computes it from the internal model.

        Parameters
        ----------
        fcalc : torch.Tensor, optional
            Calculated structure factors. If None, computed from model.
        nsteps : int, default 100
            Number of optimization steps.
        """
        if fcalc is None:
            fcalc = self.compute_fcalc()
        return super().fit_anisotropy(fcalc, nsteps=nsteps)

    def setup_solvent(self):
        """
        Setup solvent model using internal model.

        Creates a SolventModel using the internal model reference, updates
        it, and invalidates the cached raw solvent structure factors.

        Raises
        ------
        RuntimeError
            If no model is set.
        """
        if self.model is None:
            raise RuntimeError("Model required for solvent setup")

        self.solvent = SolventModel(
            self.model,
            device=self.device,
            radius=1.1,
            k_solvent=0.35,
            b_solvent=46.0,
            verbose=self.verbose,
        )
        self.solvent.update_solvent()
        self._f_sol_raw = None  # Invalidate cached raw solvent SFs

    def fit_all_scales(self, fcalc: Optional[torch.Tensor] = None):
        """
        Fit all scale parameters.

        If fcalc is not provided, computes it from the internal model.

        Parameters
        ----------
        fcalc : torch.Tensor, optional
            Calculated structure factors. If None, computed from model.
        """
        if fcalc is None:
            fcalc = self.compute_fcalc()
        return super().fit_all_scales(fcalc)

    def screen_solvent_params(
        self,
        fcalc: Optional[torch.Tensor] = None,
        steps: int = 15,
        use_low_res_weighting: bool = True,
        low_res_cutoff: float = 5.0,
        fit_on_low_res_only: bool = True,
        low_res_limit: float = 3.5,
    ):
        """
        Screen solvent parameters using grid search.

        If fcalc is not provided, computes it from the internal model.

        Parameters
        ----------
        fcalc : torch.Tensor, optional
            Calculated structure factors. If None, computed from model.
        steps : int, default 15
            Number of grid points for each parameter.
        use_low_res_weighting : bool, default True
            If True, weight low-resolution reflections more heavily.
        low_res_cutoff : float, default 5.0
            Resolution cutoff for weighting in Angstroms.
        fit_on_low_res_only : bool, default True
            If True, fit using only low-resolution reflections.
        low_res_limit : float, default 3.5
            Resolution limit for low-res only fitting in Angstroms.
        """
        if fcalc is None:
            fcalc = self.compute_fcalc()
        return super().screen_solvent_params(
            fcalc,
            steps=steps,
            use_low_res_weighting=use_low_res_weighting,
            low_res_cutoff=low_res_cutoff,
            fit_on_low_res_only=fit_on_low_res_only,
            low_res_limit=low_res_limit,
        )

    def refine_lbfgs(
        self,
        fcalc: Optional[torch.Tensor] = None,
        nsteps: int = 3,
        lr: float = 1.0,
        max_iter: int = 200,
        history_size: int = 10,
        verbose: bool = True,
    ):
        """
        Refine scale parameters using LBFGS optimizer.

        If fcalc is not provided, computes it from the internal model.

        Parameters
        ----------
        fcalc : torch.Tensor, optional
            Calculated structure factors. If None, computed from model.
        nsteps : int, default 3
            Number of LBFGS steps.
        lr : float, default 1.0
            Learning rate (typically 1.0 for LBFGS).
        max_iter : int, default 200
            Maximum iterations per line search.
        history_size : int, default 10
            Number of previous gradients to store for Hessian approximation.
        verbose : bool, default True
            Print progress information.

        Returns
        -------
        dict
            Dictionary with refinement metrics.
        """
        if fcalc is None:
            fcalc = self.compute_fcalc()
        return super().refine_lbfgs(
            fcalc,
            nsteps=nsteps,
            lr=lr,
            max_iter=max_iter,
            history_size=history_size,
            verbose=verbose,
        )

    def rfactor(
        self, fcalc: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Calculate R-factors.

        If fcalc is not provided, computes it from the internal model.

        Parameters
        ----------
        fcalc : torch.Tensor, optional
            Calculated structure factors. If None, computed from model.

        Returns
        -------
        tuple
            R-work and R-free values.
        """
        if fcalc is None:
            fcalc = self.compute_fcalc()
        return super().rfactor(fcalc)

    def bin_wise_rfactor(self, fcalc: Optional[torch.Tensor] = None):
        """
        Calculate bin-wise R-factors.

        If fcalc is not provided, computes it from the internal model.

        Parameters
        ----------
        fcalc : torch.Tensor, optional
            Calculated structure factors. If None, computed from model.

        Returns
        -------
        mean_res_per_bin : torch.Tensor
            Mean resolution per bin.
        rwork_per_bin : torch.Tensor
            R-work per bin.
        rfree_per_bin : torch.Tensor
            R-free per bin.
        """
        if fcalc is None:
            fcalc = self.compute_fcalc()
        return super().bin_wise_rfactor(fcalc)

    def get_binwise_mean_intensity(self, fcalc: Optional[torch.Tensor] = None):
        """
        Get bin-wise mean intensities.

        If fcalc is not provided, computes it from the internal model.

        Parameters
        ----------
        fcalc : torch.Tensor, optional
            Calculated structure factors. If None, computed from model.

        Returns
        -------
        tuple
            Mean observed intensity, mean calculated intensity, and mean
            resolution per bin.
        """
        if fcalc is None:
            fcalc = self.compute_fcalc()
        return super().get_binwise_mean_intensity(fcalc)

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """
        Return a dictionary containing the complete state of the Scaler.

        This includes:

        - All registered buffers and parameters (via parent class)
        - Scaler-specific metadata (nbins, etc.)
        - Solvent model state (if initialized)

        Note: Model and data references are NOT saved (managed separately).

        Parameters
        ----------
        destination : dict, optional
            Optional dict to populate.
        prefix : str, default ''
            Prefix for parameter names.
        keep_vars : bool, default False
            Whether to keep variables in computational graph.

        Returns
        -------
        dict
            Complete state dictionary.
        """
        # Use parent class implementation
        return super().state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars
        )

    def load_state_dict(self, state_dict, strict=True):
        """
        Load the Scaler state from a dictionary.

        Note: This assumes model and data are already set via __init__,
        set_model_and_data() or assignment. If the state contains a
        ``"solvent"`` entry but no solvent module exists yet, one is created
        first, provided a model is available.

        Parameters
        ----------
        state_dict : dict
            Dictionary containing scaler state.
        strict : bool, default True
            Whether to strictly enforce that keys match.

        Returns
        -------
        object
            Whatever the parent class ``load_state_dict`` returns.
        """
        # Extract solvent model state if it exists
        solvent_state = state_dict.get("solvent", None)

        # If solvent state exists but module doesn't, instantiate it so the
        # parent implementation has a module to load the state into.
        if solvent_state is not None and not hasattr(self, "solvent"):
            # ``model`` is a class-level property, so it always exists; only
            # its value needs checking.
            if self.model is not None:
                self.solvent = SolventModel(
                    model=self.model, device=self.device, verbose=self.verbose
                )

        # Use parent class implementation (handles removing 'solvent' from state_dict)
        return super().load_state_dict(state_dict, strict=strict)