"""
A class for scaling and post corrections of scattering factors.
Currently implements:
- Overall scale per resolution bin
- B-factor per resolution bin
- Anisotropy correction
- Solvent model correction
This module provides the full-featured `Scaler` class that maintains a reference
to a Model object. For a model-independent scaler, see `ScalerBase`.
"""
from typing import Optional, Tuple, TYPE_CHECKING
import torch
import torch.nn as nn
from torchref.config import get_default_device
from torchref.io import ReflectionData
from torchref.base.metrics import bin_wise_rfactors, get_rfactors, nll_xray, nll_xray_lognormal
from torchref.base.reciprocal import get_scattering_vectors
from torchref.scaling.scaler_base import ScalerBase
from torchref.scaling.solvent import SolventModel
from torchref.utils.device_resolution import resolve_device
from torchref.utils.utils import ModuleReference
if TYPE_CHECKING:
from torchref.model import Model
class Scaler(ScalerBase):
    """
    Full-featured scaler with Model integration.

    Extends ScalerBase by keeping a non-owning reference to a Model object
    and providing convenience wrappers that compute F_calc from that model
    whenever ``fcalc`` is not supplied.

    Supports two initialization patterns:

    1. Empty initialization (for state_dict loading)::

        scaler = Scaler()  # Creates empty shell
        scaler.load_state_dict(torch.load('scaler.pt'))

    2. Full initialization with model and data::

        scaler = Scaler(model, reflection_data, nbins=20)
        scaler.initialize()

    Parameters
    ----------
    model : Model, optional
        Model object for structure factor calculation.
    data : ReflectionData, optional
        ReflectionData object with observed data.
    nbins : int, default 20
        Number of resolution bins.
    verbose : int, default 1
        Verbosity level.
    device : torch.device, optional
        Computation device. If ``None``, it is resolved from ``model`` and
        ``data`` via :func:`torchref.utils.resolve_device`.

    Attributes
    ----------
    device : torch.device
        Current computation device.
    nbins : int
        Number of resolution bins.
    """

    # Hyperparameters for every SolventModel this class constructs. Both
    # setup_solvent() and load_state_dict() build through
    # _build_solvent_model(), so a reloaded scaler gets a solvent model
    # configured exactly like a freshly initialized one.
    _SOLVENT_RADIUS = 1.1
    _SOLVENT_K_INIT = 0.35
    _SOLVENT_B_INIT = 46.0

    def __init__(
        self,
        model: Optional["Model"] = None,
        data: Optional[ReflectionData] = None,
        nbins: int = 20,
        verbose: int = 1,
        device: Optional[torch.device] = None,
    ):
        """
        Initialize Scaler.

        If model and data are provided, fully initializes the scaler.
        If not provided (empty init), creates a shell ready for load_state_dict().

        Parameters
        ----------
        model : Model, optional
            Model object for structure factor calculation.
        data : ReflectionData, optional
            ReflectionData object with observed data.
        nbins : int, default 20
            Number of resolution bins.
        verbose : int, default 1
            Verbosity level.
        device : torch.device, optional
            Computation device. If ``None``, derived from ``model``
            then ``data`` (model wins on mismatch); otherwise forces
            both onto the explicit device. See
            :func:`torchref.utils.resolve_device`.
        """
        # Pin model+data onto a single device before super().__init__
        # registers buffers from ``data.hkl`` / ``data.cell``.
        resolved_device = resolve_device(model, data, device=device)
        super().__init__(
            data=data,
            nbins=nbins,
            verbose=verbose,
            device=resolved_device,
        )
        # Wrapped in ModuleReference so the model is NOT a registered
        # submodule (which would leak its state into the scaler's state_dict).
        self._set_model_ref(model)

    def __setattr__(self, name, value):
        """
        Route ``self.model = ...`` through the non-registering model reference.

        Without this override, nn.Module.__setattr__ registers any nn.Module
        value as a submodule (leaking it into state_dict) and never reaches
        the ``model`` property setter, so the getter would keep returning the
        previous reference. All other attributes use the normal nn.Module path.
        """
        if name == "model":
            self._set_model_ref(value)
        else:
            super().__setattr__(name, value)

    def _set_model_ref(self, value) -> None:
        """Store ``value`` behind a ModuleReference, or clear it when None."""
        ref = ModuleReference(value) if value is not None else None
        # object.__setattr__ bypasses nn.Module.__setattr__, so the reference
        # is kept as a plain instance attribute, never a submodule.
        object.__setattr__(self, "_model_ref", ref)

    @property
    def model(self):
        """Access the model object (not a registered submodule), or None."""
        # __dict__ lookup tolerates access before __init__ has stored a ref.
        ref = self.__dict__.get("_model_ref")
        if ref is None:
            return None
        return ref.module

    @model.setter
    def model(self, value):
        """Set the model reference without registering it as a submodule."""
        self._set_model_ref(value)

    def set_model_and_data(self, model: "Model", data: ReflectionData):
        """
        Set model and data references after empty initialization.

        This is useful when loading from state_dict and then needing
        to reconnect to model/data objects.

        Parameters
        ----------
        model : Model
            Model object for structure factor calculation.
        data : ReflectionData
            ReflectionData object with observed data.
        """
        # NOTE(review): an already-built ``self.solvent`` keeps whatever model
        # it was constructed with; call setup_solvent() again if the model
        # changed — confirm whether that should happen automatically.
        self._set_model_ref(model)
        # Data handling is delegated to the parent class.
        self.set_data(data)

    def _resolve_fcalc(self, fcalc: Optional[torch.Tensor]) -> torch.Tensor:
        """Return ``fcalc`` unchanged, or compute it from the model when None."""
        return self.compute_fcalc() if fcalc is None else fcalc

    def initialize(self, fcalc: Optional[torch.Tensor] = None) -> "Scaler":
        """
        Initialize scaling parameters.

        Computes initial per-bin scales, sets up the solvent model and the
        anisotropy correction. If fcalc is not provided, computes it from
        the internal model.

        Parameters
        ----------
        fcalc : torch.Tensor, optional
            Calculated structure factors. If None, computed from model.

        Returns
        -------
        Scaler
            ``self``, to allow chaining.

        Raises
        ------
        RuntimeError
            If no model is set (the solvent setup always needs one).
        """
        fcalc = self._resolve_fcalc(fcalc)
        if self.model is None:
            # setup_solvent() would raise this anyway; fail before the scale
            # parameters are modified so the scaler is not left half-initialized.
            raise RuntimeError("Model required for solvent setup")
        self.calc_initial_scale(fcalc)
        self.setup_solvent()
        self.setup_anisotropy_correction()
        return self

    def compute_fcalc(self) -> torch.Tensor:
        """
        Compute F_calc from internal model.

        Returns
        -------
        torch.Tensor
            Calculated structure factors for ``self.hkl``.

        Raises
        ------
        RuntimeError
            If no model is set.
        """
        if self.model is None:
            raise RuntimeError("No model set and no fcalc provided")
        return self.model(self.hkl)

    def calc_initial_scale(self, fcalc: Optional[torch.Tensor] = None):
        """
        Calculate initial scale factors.

        If fcalc is not provided, computes it from the internal model.

        Parameters
        ----------
        fcalc : torch.Tensor, optional
            Calculated structure factors. If None, computed from model.

        Returns
        -------
        torch.nn.Parameter
            The log scale parameter for each resolution bin.
        """
        return super().calc_initial_scale(self._resolve_fcalc(fcalc))

    def fit_anisotropy(self, fcalc: Optional[torch.Tensor] = None, nsteps: int = 100):
        """
        Fit anisotropic correction.

        If fcalc is not provided, computes it from the internal model.

        Parameters
        ----------
        fcalc : torch.Tensor, optional
            Calculated structure factors. If None, computed from model.
        nsteps : int, default 100
            Number of optimization steps.
        """
        return super().fit_anisotropy(self._resolve_fcalc(fcalc), nsteps=nsteps)

    def _build_solvent_model(self) -> SolventModel:
        """Construct a SolventModel bound to the current model with class defaults."""
        return SolventModel(
            self.model,
            device=self.device,
            radius=self._SOLVENT_RADIUS,
            k_solvent=self._SOLVENT_K_INIT,
            b_solvent=self._SOLVENT_B_INIT,
            verbose=self.verbose,
        )

    def setup_solvent(self):
        """
        Set up the solvent model from the internal model reference.

        Builds a fresh SolventModel, calls its ``update_solvent()`` and
        invalidates the cached raw solvent structure factors.

        Raises
        ------
        RuntimeError
            If no model is set.
        """
        if self.model is None:
            raise RuntimeError("Model required for solvent setup")
        self.solvent = self._build_solvent_model()
        self.solvent.update_solvent()
        self._f_sol_raw = None  # Invalidate cached raw solvent SFs

    def fit_all_scales(self, fcalc: Optional[torch.Tensor] = None):
        """
        Fit all scale parameters.

        If fcalc is not provided, computes it from the internal model.

        Parameters
        ----------
        fcalc : torch.Tensor, optional
            Calculated structure factors. If None, computed from model.
        """
        return super().fit_all_scales(self._resolve_fcalc(fcalc))

    def screen_solvent_params(
        self,
        fcalc: Optional[torch.Tensor] = None,
        steps: int = 15,
        use_low_res_weighting: bool = True,
        low_res_cutoff: float = 5.0,
        fit_on_low_res_only: bool = True,
        low_res_limit: float = 3.5,
    ):
        """
        Screen solvent parameters using grid search.

        If fcalc is not provided, computes it from the internal model.

        Parameters
        ----------
        fcalc : torch.Tensor, optional
            Calculated structure factors. If None, computed from model.
        steps : int, default 15
            Number of grid points for each parameter.
        use_low_res_weighting : bool, default True
            If True, weight low-resolution reflections more heavily.
        low_res_cutoff : float, default 5.0
            Resolution cutoff for weighting in Angstroms.
        fit_on_low_res_only : bool, default True
            If True, fit using only low-resolution reflections.
        low_res_limit : float, default 3.5
            Resolution limit for low-res only fitting in Angstroms.
        """
        return super().screen_solvent_params(
            self._resolve_fcalc(fcalc),
            steps=steps,
            use_low_res_weighting=use_low_res_weighting,
            low_res_cutoff=low_res_cutoff,
            fit_on_low_res_only=fit_on_low_res_only,
            low_res_limit=low_res_limit,
        )

    def refine_lbfgs(
        self,
        fcalc: Optional[torch.Tensor] = None,
        nsteps: int = 3,
        lr: float = 1.0,
        max_iter: int = 200,
        history_size: int = 10,
        verbose: bool = True,
    ):
        """
        Refine scale parameters using the LBFGS optimizer.

        If fcalc is not provided, computes it from the internal model.

        Parameters
        ----------
        fcalc : torch.Tensor, optional
            Calculated structure factors. If None, computed from model.
        nsteps : int, default 3
            Number of outer LBFGS steps.
        lr : float, default 1.0
            Learning rate (typically 1.0 for LBFGS).
        max_iter : int, default 200
            Maximum inner iterations per LBFGS step (forwarded to
            ScalerBase.refine_lbfgs).
        history_size : int, default 10
            Number of previous gradients to store for Hessian approximation.
        verbose : bool, default True
            Print progress information.

        Returns
        -------
        dict
            Dictionary with refinement metrics.
        """
        return super().refine_lbfgs(
            self._resolve_fcalc(fcalc),
            nsteps=nsteps,
            lr=lr,
            max_iter=max_iter,
            history_size=history_size,
            verbose=verbose,
        )

    def rfactor(
        self, fcalc: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Calculate R-factors.

        If fcalc is not provided, computes it from the internal model.

        Parameters
        ----------
        fcalc : torch.Tensor, optional
            Calculated structure factors. If None, computed from model.

        Returns
        -------
        tuple
            R-work and R-free values.
        """
        return super().rfactor(self._resolve_fcalc(fcalc))

    def bin_wise_rfactor(self, fcalc: Optional[torch.Tensor] = None):
        """
        Calculate bin-wise R-factors.

        If fcalc is not provided, computes it from the internal model.

        Parameters
        ----------
        fcalc : torch.Tensor, optional
            Calculated structure factors. If None, computed from model.

        Returns
        -------
        mean_res_per_bin : torch.Tensor
            Mean resolution per bin.
        rwork_per_bin : torch.Tensor
            R-work per bin.
        rfree_per_bin : torch.Tensor
            R-free per bin.
        """
        return super().bin_wise_rfactor(self._resolve_fcalc(fcalc))

    def get_binwise_mean_intensity(self, fcalc: Optional[torch.Tensor] = None):
        """
        Get bin-wise mean intensities.

        If fcalc is not provided, computes it from the internal model.

        Parameters
        ----------
        fcalc : torch.Tensor, optional
            Calculated structure factors. If None, computed from model.

        Returns
        -------
        tuple
            Mean observed intensity, mean calculated intensity, and mean
            resolution per bin.
        """
        return super().get_binwise_mean_intensity(self._resolve_fcalc(fcalc))

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """
        Return a dictionary containing the state of the Scaler.

        Delegates entirely to ScalerBase.state_dict. The model is held through
        a ModuleReference rather than as a submodule, so its state is not
        included here (model and data are managed separately).

        Parameters
        ----------
        destination : dict, optional
            Optional dict to populate.
        prefix : str, default ''
            Prefix for parameter names.
        keep_vars : bool, default False
            If True, returned tensors are not detached from autograd.

        Returns
        -------
        dict
            State dictionary produced by the parent class.
        """
        return super().state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars
        )

    def load_state_dict(self, state_dict, strict=True):
        """
        Load the Scaler state from a dictionary.

        If ``state_dict`` carries a ``"solvent"`` entry but no solvent model
        exists yet, one is built first (with the same hyperparameters as
        setup_solvent) so the parent loader has a target. That requires a
        model to be set; without one, no solvent model is created here.

        Parameters
        ----------
        state_dict : dict
            Dictionary containing scaler state.
        strict : bool, default True
            Whether to strictly enforce that keys match.
        """
        solvent_state = state_dict.get("solvent", None)
        # Treat a missing attribute and a None placeholder alike.
        needs_solvent = getattr(self, "solvent", None) is None
        if solvent_state is not None and needs_solvent and self.model is not None:
            self.solvent = self._build_solvent_model()
        # Parent implementation handles removing 'solvent' from state_dict.
        return super().load_state_dict(state_dict, strict=strict)