Source code for torchref.scaling.collection_scaler

"""
Joint scaler for multi-dataset / multi-model kinetic refinement.

``CollectionScaler`` extends ``ScalerBase`` to operate on paired
``DatasetCollection`` + ``ModelCollection`` instances.  A single set of
scale parameters (log_scale, U, k_sol, B_sol, phase) is shared across
**all** data–model pairs, preventing artificial scale differences that
corrupt the difference signal in time-resolved refinement.

Per-component solvent models are created for each base model in the
``ModelCollection``.  For a mixed model at any timepoint the solvent
contribution is the linear combination of individual component solvent
structure factors, weighted by the same population fractions as the
structural models.
"""

from typing import TYPE_CHECKING, Dict, List, Optional

import torch
import torch.nn as nn

from torchref.base.metrics import nll_xray, get_rfactors
from torchref.config import get_default_device
from torchref.scaling.scaler_base import ScalerBase
from torchref.scaling.solvent import SolventModel
from torchref.utils.utils import ModuleReference

if TYPE_CHECKING:
    from torchref.io.datasets.collection import DatasetCollection
    from torchref.model.model_collection import ModelCollection


class CollectionScaler(ScalerBase):
    """
    Joint scaler for DatasetCollection + ModelCollection.

    Shares scale parameters (log_scale, U, bin_wise_bfactor, k_sol, B_sol,
    phase) across **all** data–model pairs.  Manages per-component solvent
    models so that the bulk-solvent contribution for a mixed model is the
    fraction-weighted sum of individual component solvent SFs.

    Parameters
    ----------
    dataset_collection : DatasetCollection
        Collection of reflection datasets keyed by timepoint name.
    model_collection : ModelCollection
        Collection of mixed models keyed by timepoint name.
    nbins : int
        Number of resolution bins.
    verbose : int
        Verbosity level.
    device : torch.device, optional
        Computation device.  ``None`` (the default) resolves to
        ``get_default_device()`` when the scaler is constructed.

    Examples
    --------
    ::

        scaler = CollectionScaler(datasets, models, device=device)
        scaler.initialize()
        scaler.refine_lbfgs_joint()

        # In a target: scale a mixed-model F_calc with matching solvent
        f_scaled = scaler.forward_mixed(f_calc, model.fractions)
    """

    def __init__(
        self,
        dataset_collection: "DatasetCollection",
        model_collection: "ModelCollection",
        nbins: int = 20,
        verbose: int = 1,
        device: Optional[torch.device] = None,
    ):
        # Resolve the device per instance: a call placed in the signature
        # would be evaluated only once, when this module is imported.
        if device is None:
            device = get_default_device()

        # Bind to the dark/reference dataset for bins and scattering vectors
        dark_data = dataset_collection[model_collection.dark_key]
        super().__init__(
            data=dark_data,
            nbins=nbins,
            verbose=verbose,
            device=device,
        )
        self._dataset_collection = dataset_collection
        self._model_collection = model_collection

        # Per-component solvent models (one per base model)
        self._component_solvent_models: nn.ModuleList = nn.ModuleList()
        # Cached raw solvent SFs per component index
        self._f_sol_raw_components: Dict[int, torch.Tensor] = {}

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _matched_keys(self) -> List[str]:
        """
        Timepoint names present in both collections, dark key first.

        Follows the order (and any repetition) of
        ``[dark_key] + timepoint_names`` from the model collection and drops
        names that the dataset collection does not contain.
        """
        mc = self._model_collection
        dc = self._dataset_collection
        return [name for name in [mc.dark_key] + mc.timepoint_names if name in dc]

    def _precompute_detached_pairs(self) -> list:
        """
        Evaluate every matched model once, without building a graph.

        Returns
        -------
        list of tuple
            ``(name, fcalc, fractions, fobs, sigma, rfree)`` per matched
            timepoint, in ``_matched_keys()`` order.  ``fcalc`` and
            ``fractions`` are detached; the observation tensors are returned
            as produced by calling the dataset with no arguments.
        """
        dc = self._dataset_collection
        mc = self._model_collection
        pairs = []
        for name in self._matched_keys():
            data = dc[name]
            model = mc[name]
            hkl, fobs, sigma, rfree = data()
            with torch.no_grad():
                fcalc = model(hkl)
            pairs.append(
                (name, fcalc.detach(), model.fractions.detach(), fobs, sigma, rfree)
            )
        return pairs

    # ------------------------------------------------------------------
    # Initialization
    # ------------------------------------------------------------------

    def initialize(self) -> "CollectionScaler":
        """
        One-shot initialization.

        Runs, in order: the joint initial log-scale estimate, creation of
        the per-component solvent models, and ``setup_anisotropy_correction``.

        Returns
        -------
        CollectionScaler
            Self, for method chaining.
        """
        self._calc_initial_scale_joint()
        self._setup_component_solvent_models()
        self.setup_anisotropy_correction()
        return self

    def _calc_initial_scale_joint(self):
        """
        Compute initial bin-wise log-scale using ALL data–model pairs.

        Averages ``log(F_obs / |F_calc|)`` per resolution bin across every
        matched timepoint in the collections.  Bins with no contributing
        reflections end up with a log-scale of (essentially) zero.
        """
        dc = self._dataset_collection
        mc = self._model_collection

        scales = torch.zeros(self.nbins, device=self.device, dtype=torch.float32)
        counts = torch.zeros(self.nbins, device=self.device, dtype=torch.float32)
        n_pairs = 0

        for name in self._matched_keys():
            data = dc[name]
            model = mc[name]
            hkl, fobs, _sigma, rfree = data(mask=False)
            with torch.no_grad():
                fcalc = model(hkl)
            fcalc_amp = torch.abs(fcalc).clamp(min=1e-3).to(fobs.dtype)
            fobs_clamped = fobs.clamp(min=1e-3)

            # Mask: work set, positive intensities
            if hasattr(data, "I") and data.I is not None:
                pos_mask = data.I > 0
            else:
                pos_mask = torch.ones_like(fobs, dtype=torch.bool)
            mask = (data.masks() & rfree & pos_mask).to(torch.bool)

            bins = self.bins[mask].to(torch.int64).to(self.device)
            # Cast to the accumulator dtype: scatter_add_ requires src and
            # self to share a dtype, and fobs need not be float32.
            log_ratios = (
                torch.log(fobs_clamped[mask]) - torch.log(fcalc_amp[mask])
            ).to(device=self.device, dtype=scales.dtype)

            scales.scatter_add_(0, bins, log_ratios)
            counts.scatter_add_(0, bins, torch.ones_like(log_ratios))
            n_pairs += 1

        log_scale = scales / (counts + 1e-6)
        self.log_scale = nn.Parameter(log_scale.detach())

        if self.verbose > 0:
            print(
                f"Joint initial scale from {n_pairs} data-model pairs "
                f"({self.nbins} bins)."
            )

    # ------------------------------------------------------------------
    # Component solvent models
    # ------------------------------------------------------------------

    def _setup_component_solvent_models(self):
        """
        Create a SolventModel for each base model in the ModelCollection.

        The **first** component's SolventModel is also stored as
        ``self.solvent`` — it owns the shared k_sol / B_sol / phase
        parameters that are optimised during scaling.

        The remaining component models are used only for their raw solvent
        structure factors (mask FFT); their own k/B/phase parameters are
        frozen (``requires_grad=False``).
        """
        mc = self._model_collection
        self._component_solvent_models = nn.ModuleList()
        self._f_sol_raw_components = {}

        for i, base_model in enumerate(mc.base_models):
            sol = SolventModel(
                base_model,
                device=self.device,
                radius=1.1,
                k_solvent=0.35,
                b_solvent=46.0,
                verbose=max(0, self.verbose - 1),
            )
            sol.update_solvent()
            self._component_solvent_models.append(sol)

            if i == 0:
                # Primary solvent model — owns the shared learnable params
                self.solvent = sol
            else:
                # Freeze non-primary solvent params (only mask FFT used)
                sol.requires_grad_(False)

        # Invalidate any old cache
        self._f_sol_raw = None
        self._f_sol_raw_components = {}

        if self.verbose > 0:
            print(
                f" Created {len(self._component_solvent_models)} component "
                f"solvent models."
            )

    def _get_component_f_sol_raw(self, idx: int) -> torch.Tensor:
        """
        Get (and cache) raw solvent SFs for component *idx*.

        The value is computed from ``self.hkl`` on first request and reused
        until the cache is cleared (``invalidate_solvent_cache``,
        ``update_all_solvent`` or ``_setup_component_solvent_models``).

        Parameters
        ----------
        idx : int
            Index into ``_component_solvent_models``.

        Returns
        -------
        torch.Tensor
            Raw (un-damped) complex solvent structure factors.
        """
        if idx not in self._f_sol_raw_components:
            sol = self._component_solvent_models[idx]
            self._f_sol_raw_components[idx] = sol.get_rec_solvent(self.hkl)
        return self._f_sol_raw_components[idx]

    def get_mixed_solvent_raw(self, fractions: torch.Tensor) -> torch.Tensor:
        """
        Compute fraction-weighted raw solvent SFs.

        ``f_sol_mixed = sum_i(w_i * f_sol_raw_i)``

        Parameters
        ----------
        fractions : torch.Tensor
            Population fractions, shape ``(n_base_models,)``.  Only the first
            ``len(component_solvent_models)`` entries are read.

        Returns
        -------
        torch.Tensor or None
            Mixed raw solvent structure factors (complex, un-damped), or
            ``None`` when no component solvent models exist yet.
        """
        f_sol_mixed = None
        for idx in range(len(self._component_solvent_models)):
            contribution = fractions[idx] * self._get_component_f_sol_raw(idx)
            f_sol_mixed = (
                contribution if f_sol_mixed is None else f_sol_mixed + contribution
            )
        return f_sol_mixed

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------

    def forward_mixed(
        self,
        fcalc: torch.Tensor,
        fractions: torch.Tensor,
    ) -> torch.Tensor:
        """
        Scale *fcalc* using the shared parameters **and** a fraction-weighted
        solvent contribution.

        The mixed raw solvent is passed to ``ScalerBase.forward()`` through
        its ``f_sol_override`` argument; no scaler state is modified.

        Parameters
        ----------
        fcalc : torch.Tensor
            Calculated structure factors (complex).
        fractions : torch.Tensor
            Population fractions for the mixed model, shape
            ``(n_base_models,)``.

        Returns
        -------
        torch.Tensor
            Scaled structure factors.
        """
        f_sol_raw_mixed = self.get_mixed_solvent_raw(fractions)
        return super().forward(fcalc, f_sol_override=f_sol_raw_mixed)

    # ------------------------------------------------------------------
    # Joint LBFGS refinement
    # ------------------------------------------------------------------

    def refine_lbfgs_joint(
        self,
        nsteps: int = 3,
        lr: float = 1.0,
        max_iter: int = 200,
        history_size: int = 10,
        verbose: bool = True,
    ) -> dict:
        """
        Refine scale parameters using LBFGS against **all** datasets.

        The objective averages ``nll_xray`` over every matched dataset–model
        pair (pairs with a non-finite NLL are skipped) and adds
        ``sum(U ** 2)``, so a single set of scale parameters is fitted
        jointly.

        Parameters
        ----------
        nsteps : int
            Number of LBFGS outer steps.
        lr : float
            Learning rate (typically 1.0 for LBFGS).
        max_iter : int
            Maximum line-search iterations per step.
        history_size : int
            LBFGS history size.
        verbose : bool
            Print progress.

        Returns
        -------
        dict
            ``{"target", "steps", "rwork", "rfree"}``; the R-factor lists hold
            one entry computed on the dark dataset after refinement, or stay
            empty if the dark dataset is not in the collection.
        """
        from torchref.refinement.loss_state import LossState

        # Pre-compute all fcalc (detached)
        pairs = self._precompute_detached_pairs()

        # Wrap the joint NLL + U-penalty as a LossState target. fcalc is
        # detached, so the only leaves in the autograd graph are the
        # scaler's own parameters — LossState's probe picks them up at
        # registration time, and validate_loss inside state.step handles
        # NaN/Inf rejection so no per-target try/except is needed.
        scaler_self = self

        class _CollectionScalerJointTarget(nn.Module):
            name = "scaler/joint"

            def forward(self):
                total = torch.tensor(0.0, device=scaler_self.device)
                n = 0
                for _name, fc, fracs, fobs_n, sigma_n, _rfree in pairs:
                    f_sol_raw = scaler_self.get_mixed_solvent_raw(fracs)
                    scaled = super(CollectionScaler, scaler_self).forward(
                        fc, f_sol_override=f_sol_raw
                    )
                    loss = nll_xray(fobs_n, scaled, sigma_n)
                    if torch.isfinite(loss):
                        total = total + loss
                        n += 1
                if n > 0:
                    total = total / n
                u_penalty = torch.sum(scaler_self.U ** 2)
                return total + u_penalty

        state = LossState(device=self.device)
        state.register_target("scaler/joint", _CollectionScalerJointTarget())

        optimizer = torch.optim.LBFGS(
            self.parameters(),
            lr=lr,
            max_iter=max_iter,
            history_size=history_size,
            line_search_fn="strong_wolfe",
        )

        metrics = {
            "target": "scales_joint",
            "steps": [],
            "rwork": [],
            "rfree": [],
        }

        if verbose and self.verbose > 0:
            print("Refining scales jointly with LBFGS...")

        state.run(
            optimizer,
            nsteps=nsteps,
            log=False,
            context="collection_scaler.refine_lbfgs_joint",
        )

        # Evaluate metrics once on the dark dataset after refinement.
        dark_name = self._model_collection.dark_key
        dark_pair = next((p for p in pairs if p[0] == dark_name), None)
        if dark_pair is not None:
            _name, fc, fracs, fobs, _sigma, rfree = dark_pair
            with torch.no_grad():
                f_sol_raw = self.get_mixed_solvent_raw(fracs)
                scaled = super().forward(fc, f_sol_override=f_sol_raw)
                rwork, rfree_val = get_rfactors(
                    torch.abs(fobs), torch.abs(scaled), rfree
                )
            metrics["steps"].append(nsteps)
            metrics["rwork"].append(rwork)
            metrics["rfree"].append(rfree_val)

        if verbose and self.verbose > 0 and metrics["rwork"]:
            print(
                f"Joint scale refinement complete. "
                f"rwork: {metrics['rwork'][-1]:.4f}, "
                f"rfree: {metrics['rfree'][-1]:.4f}"
            )

        return metrics

    # ------------------------------------------------------------------
    # Solvent parameter screening
    # ------------------------------------------------------------------

    def screen_solvent_params_joint(self, steps: int = 15):
        """
        Grid-search k_sol / B_sol jointly across all datasets.

        Evaluates a ``steps x steps`` grid: k_sol log-spaced over
        [0.1, 0.6] and B_sol linearly spaced over [30, 100].  The score of a
        grid point is ``mean(0.5 * ((F_obs - |F_scaled|) / sigma) ** 2)``
        over the reflections selected by each dataset's ``rfree`` flag,
        summed over datasets.  The best point is written back into the
        primary solvent model (``self.solvent``); if no grid point scores
        finitely, the original values are restored.

        Parameters
        ----------
        steps : int
            Grid points per parameter.

        Raises
        ------
        RuntimeError
            If ``initialize()`` has not created the component solvent models.
        """
        if not self._component_solvent_models:
            raise RuntimeError("No component solvent models. Call initialize() first.")

        # The mixed raw solvent depends only on the fractions, not on
        # k_sol / B_sol, so compute it once per pair instead of per grid point.
        pairs = [
            (fc, self.get_mixed_solvent_raw(fracs), fobs, sigma, rfree)
            for _name, fc, fracs, fobs, sigma, rfree in self._precompute_detached_pairs()
        ]

        sol = self.solvent
        best_log_k = sol.log_k_solvent.detach().clone()
        best_b = sol.b_solvent.detach().clone()
        best_loss = float("inf")

        log_k_grid = torch.linspace(
            torch.log(torch.tensor(0.1, device=self.device)),
            torch.log(torch.tensor(0.6, device=self.device)),
            steps=steps,
            device=self.device,
        )
        b_grid = torch.linspace(30.0, 100.0, steps=steps, device=self.device)

        # Pure evaluation: no autograd graph is needed for the
        # steps * steps * len(pairs) forward passes.  copy_ (rather than
        # reassigning .data) keeps each parameter's shape, dtype and storage.
        with torch.no_grad():
            for log_k in log_k_grid:
                for b in b_grid:
                    sol.log_k_solvent.copy_(log_k)
                    sol.b_solvent.copy_(b)

                    total = 0.0
                    for fc, f_sol_raw, fobs, sigma, rfree in pairs:
                        scaled = super().forward(fc, f_sol_override=f_sol_raw)
                        diff = fobs[rfree] - torch.abs(scaled[rfree])
                        sigma_safe = sigma[rfree].clamp(min=1e-3)
                        total += (0.5 * (diff ** 2) / sigma_safe ** 2).mean().item()

                    # NaN totals never compare smaller, so they are skipped.
                    if total < best_loss:
                        best_loss = total
                        best_log_k = log_k.clone()
                        best_b = b.clone()

            sol.log_k_solvent.copy_(best_log_k)
            sol.b_solvent.copy_(best_b)

        if self.verbose > 0:
            k_val = torch.exp(best_log_k).item()
            print(
                f"Joint solvent screening: k_sol={k_val:.4f}, "
                f"B_sol={best_b.item():.1f}, NLL={best_loss:.4f}"
            )

    # ------------------------------------------------------------------
    # Solvent mask updates
    # ------------------------------------------------------------------

    def update_all_solvent(self):
        """
        Recompute solvent masks for all component models.

        Call this after structure refinement changes base-model coordinates.
        Cached raw solvent SFs are cleared first so they are rebuilt from the
        new masks on next use.
        """
        self.invalidate_solvent_cache()
        for sol in self._component_solvent_models:
            sol.update_solvent()
        if self.verbose > 0:
            print(" Updated all component solvent masks.")

    def invalidate_solvent_cache(self):
        """Clear cached raw solvent SFs (forces recomputation on next call)."""
        self._f_sol_raw_components = {}
        self._f_sol_raw = None

    # ------------------------------------------------------------------
    # Convenience
    # ------------------------------------------------------------------

    @property
    def component_solvent_models(self) -> nn.ModuleList:
        """Per-component SolventModel instances (read-only)."""
        return self._component_solvent_models

    def __repr__(self):
        n_comp = len(self._component_solvent_models)
        n_ds = (
            self._dataset_collection.n_datasets
            if self._dataset_collection is not None
            else 0
        )
        return (
            f"CollectionScaler({n_comp} component solvents, "
            f"{n_ds} datasets, {self.nbins} bins)"
        )