Source code for torchref.kinetic.refinement

"""
Kinetic refinement orchestrator.

Uses LossState + ModelCollection + DatasetCollection to refine multiple
structural models against multiple time-resolved datasets with kinetics-
constrained occupancy fractions.

Example
-------
::

    from torchref import ModelFT, ReflectionData, DatasetCollection
    from torchref.kinetic import ModelCollection, KineticRefinement

    # Base models
    model_dark = ModelFT(max_res=1.5).load_pdb("dark.pdb")
    model_light = ModelFT(max_res=1.5).load_pdb("light.pdb")

    # Collections
    models = ModelCollection([model_dark, model_light])
    models.add_dark()
    models.add_timepoint("1ps", [0.9, 0.1])

    datasets = DatasetCollection()
    datasets.add_dataset("dark", ReflectionData().load_mtz("dark.mtz"))
    datasets.add_dataset("1ps", ReflectionData().load_mtz("1ps.mtz"))

    ref = KineticRefinement(datasets, models)
    ref.setup()
    ref.refine(macro_cycles=5)
"""

from typing import TYPE_CHECKING, Dict, List, Optional

import torch
from torch import nn

from torchref.refinement.loss_state import LossState, create_loss_state
from torchref.kinetic.targets import (
    CollectionDifferenceTarget,
    CollectionMLTarget,
    MultiModelGeometryTarget,
    MultiModelADPTarget,
    KineticPriorTarget,
)
from torchref.utils.device_mixin import DeviceMixin

if TYPE_CHECKING:
    from torchref.io.datasets.collection import DatasetCollection
    from torchref.model.model_collection import ModelCollection
    from torchref.scaling import Scaler


class KineticRefinement(DeviceMixin, nn.Module):
    """
    Orchestrator for kinetic refinement of time-resolved data.

    Manages LossState registration, scaler initialization, and optimization
    loops for alternating structure / fraction refinement.

    Parameters
    ----------
    dataset_collection : DatasetCollection
        Collection of reflection datasets keyed by timepoint name.
    model_collection : ModelCollection
        Collection of mixed models keyed by timepoint name.
    xray_weight_difference : float
        Weight for the difference X-ray target.
    xray_weight_ml : float
        Weight for the ML amplitude target.
    geometry_weight : float
        Weight for geometry restraints.
    adp_weight : float
        Weight for ADP restraints.
    kinetic_prior_weight : float
        Weight for kinetic prior regularization (0 to disable).
    device : torch.device, optional
        Computation device. If None, inferred from model collection.
    verbose : int
        Verbosity level.

    Notes
    -----
    :meth:`setup` must be called before any loss evaluation or
    optimization method; those methods raise ``RuntimeError`` otherwise.
    """

    def __init__(
        self,
        dataset_collection: "DatasetCollection",
        model_collection: "ModelCollection",
        xray_weight_difference: float = 2.0,
        xray_weight_ml: float = 1.0,
        geometry_weight: float = 10.0,
        adp_weight: float = 3.0,
        kinetic_prior_weight: float = 0.0,
        device: Optional[torch.device] = None,
        verbose: int = 1,
    ):
        super().__init__()
        self.dataset_collection = dataset_collection
        self.model_collection = model_collection
        self.verbose = verbose

        if device is None:
            device = model_collection.device
        self._device = device

        # Default weights, keyed by the LossState target path.
        self._weights = {
            "xray/difference": xray_weight_difference,
            "xray/ml": xray_weight_ml,
            "geometry": geometry_weight,
            "adp": adp_weight,
            "kinetic_prior": kinetic_prior_weight,
        }

        # Placeholders (populated by setup())
        self.scaler: Optional["Scaler"] = None
        self.loss_state: Optional[LossState] = None
        self.kinetic_prior_target: Optional[KineticPriorTarget] = None
        self._diff_target: Optional[CollectionDifferenceTarget] = None
        self._ml_target: Optional[CollectionMLTarget] = None
        self._kinetic_model = None
        self._timepoints_map: Optional[Dict[str, int]] = None

    # ------------------------------------------------------------------
    # Setup
    # ------------------------------------------------------------------

    def setup(
        self,
        cif_paths: Optional[List[str]] = None,
        kinetic_model=None,
        timepoints_map: Optional[Dict[str, int]] = None,
    ):
        """
        One-shot initialization: scalers, restraints, targets, LossState.

        Calling ``setup()`` again rebuilds the scaler, targets and LossState;
        kinetic-prior state from an earlier call is discarded unless a new
        *kinetic_model* is supplied.

        Parameters
        ----------
        cif_paths : List[str], optional
            Paths to CIF restraint dictionaries for ligands.
        kinetic_model : occupancies_kinetics, optional
            Kinetic model for prior regularization. If None, the kinetic
            prior target is not created.
        timepoints_map : Dict[str, int], optional
            Maps timepoint names to indices into the kinetic model's time
            axis. Required when *kinetic_model* is provided.

        Raises
        ------
        ValueError
            If *kinetic_model* is given without *timepoints_map*.
        """
        # Validate before the (expensive) scaler refinement below; previously
        # a kinetic model without a map was silently ignored.
        if kinetic_model is not None and timepoints_map is None:
            raise ValueError(
                "timepoints_map is required when kinetic_model is provided."
            )

        dc = self.dataset_collection
        mc = self.model_collection
        device = self._device

        # ---- CIF restraints on base models ----
        if cif_paths:
            for model in mc.base_models:
                if hasattr(model, "set_restraints_cif"):
                    model.set_restraints_cif(cif_paths)

        # ---- Scalers ----
        self._setup_scalers()

        # ---- Targets ----
        diff_target = CollectionDifferenceTarget(
            dc, mc, scaler=self.scaler, verbose=self.verbose,
        )
        ml_target = CollectionMLTarget(
            dc, mc, scaler=self.scaler, verbose=self.verbose,
        )
        # Store direct references for refine_kinetics()
        self._diff_target = diff_target
        self._ml_target = ml_target

        geom_target = MultiModelGeometryTarget(mc, verbose=self.verbose)
        adp_target = MultiModelADPTarget(mc, verbose=self.verbose)

        # ---- LossState ----
        self.loss_state = create_loss_state(device=device)
        self.loss_state.register_target("xray/difference", diff_target)
        self.loss_state.register_target("xray/ml", ml_target)
        self.loss_state.register_target("geometry", geom_target)
        self.loss_state.register_target("adp", adp_target)

        # ---- Kinetic prior (optional) ----
        # Reset first so a previous setup() call cannot leave behind a prior
        # that refine_alternating() / refine_kinetics() would still use.
        self._kinetic_model = None
        self._timepoints_map = None
        self.kinetic_prior_target = None
        if kinetic_model is not None:
            self._kinetic_model = kinetic_model
            self._timepoints_map = timepoints_map
            self.kinetic_prior_target = KineticPriorTarget(
                mc, kinetic_model, timepoints_map, verbose=self.verbose
            )
            self.loss_state.register_target(
                "kinetic_prior", self.kinetic_prior_target
            )

        # ---- Set weights ----
        self.loss_state.set_weights(self._weights)

        if self.verbose > 0:
            print("KineticRefinement setup complete.")
            print(f" Targets: {list(self.loss_state.targets.keys())}")
            print(f" Weights: {self._weights}")

    def _setup_scalers(self):
        """Initialize a joint scaler for all datasets.

        Uses ``CollectionScaler`` which shares scale parameters (log_scale,
        U, k_sol, B_sol) across **all** data–model pairs and creates
        per-component solvent models for each base structure.  The solvent
        contribution for a mixed model is the fraction-weighted sum of
        individual component solvent SFs.
        """
        from torchref.scaling.collection_scaler import CollectionScaler

        self.scaler = CollectionScaler(
            dataset_collection=self.dataset_collection,
            model_collection=self.model_collection,
            device=self._device,
            verbose=max(0, self.verbose - 1),
        )
        self.scaler.initialize()
        self.scaler.refine_lbfgs_joint()

        if self.verbose > 0:
            print(" Joint scaler initialized and refined (all datasets + component solvents).")

    # ------------------------------------------------------------------
    # Weight management
    # ------------------------------------------------------------------

    def set_weights(self, **kwargs):
        """
        Set target weights by name.

        Parameters
        ----------
        **kwargs
            Keyword arguments mapping target paths to weights.
            E.g. ``set_weights(geometry=5.0, adp=2.0)``. The short names
            ``difference`` and ``ml`` are expanded to ``xray/difference``
            and ``xray/ml``; any other key is used verbatim.
        """
        # Map short names to full paths
        mapping = {
            "difference": "xray/difference",
            "ml": "xray/ml",
            "geometry": "geometry",
            "adp": "adp",
            "kinetic_prior": "kinetic_prior",
        }
        for key, val in kwargs.items():
            full_key = mapping.get(key, key)
            self._weights[full_key] = val
            # Before setup() the weight is only remembered; setup() applies it.
            if self.loss_state is not None:
                self.loss_state.set_weight(full_key, val)

    # ------------------------------------------------------------------
    # Loss computation
    # ------------------------------------------------------------------

    def _require_loss_state(self) -> "LossState":
        """Return the LossState, raising a clear error if setup() has not run."""
        if self.loss_state is None:
            raise RuntimeError(
                "KineticRefinement.setup() must be called before computing "
                "losses or running refinement."
            )
        return self.loss_state

    def get_loss(self, log_values: bool = False) -> torch.Tensor:
        """Evaluate all targets and return weighted total loss.

        Raises
        ------
        RuntimeError
            If :meth:`setup` has not been called.
        """
        return self._require_loss_state().aggregate(log_values=log_values)

    def print_loss_summary(self):
        """Print breakdown of current losses.

        Raises
        ------
        RuntimeError
            If :meth:`setup` has not been called.
        """
        loss_state = self._require_loss_state()
        loss_state.aggregate(log_values=True)
        loss_state.summary()

    # ------------------------------------------------------------------
    # Optimization loops
    # ------------------------------------------------------------------

    def _collect_parameters(self, structures=True, fractions=True):
        """Collect trainable parameters for optimization.

        Parameters
        ----------
        structures : bool
            Include ``requires_grad`` parameters of every base model.
        fractions : bool
            Include each timepoint's ``fraction_params`` if trainable.

        Returns
        -------
        list
            Unique trainable parameters, in first-seen order. Trainable
            scaler parameters are always appended.
        """
        params = []
        mc = self.model_collection

        if structures:
            for model in mc.base_models:
                params.extend(
                    p for p in model.parameters() if p.requires_grad
                )

        if fractions:
            for name in mc.timepoint_names:
                p = mc[name].fraction_params
                if p.requires_grad:
                    params.append(p)

        # Scaler parameters (if not frozen)
        if self.scaler is not None:
            params.extend(
                p for p in self.scaler.parameters() if p.requires_grad
            )

        # De-duplicate by identity (order preserved): the same tensor must not
        # appear twice in an optimizer, e.g. if it is reachable from both a
        # model and the scaler.
        return list({id(p): p for p in params}.values())

    def _lbfgs_loop(self, params, niter=10, max_iter=50, lr=1.0):
        """Run L-BFGS optimization loop via :meth:`LossState.run`.

        ``LossState.run`` handles the closure, NaN validation, and
        automatically disables ``requires_grad`` on loss-relevant leaves
        outside ``params`` for the duration of the run. Does nothing when
        *params* is empty.
        """
        if not params:
            return

        loss_state = self._require_loss_state()
        optimizer = torch.optim.LBFGS(
            params, lr=lr, max_iter=max_iter, line_search_fn="strong_wolfe"
        )
        loss_state.run(
            optimizer,
            nsteps=niter,
            log=False,
            context="kinetic.refinement._lbfgs_loop",
        )

        if self.verbose > 0:
            final_loss = self.get_loss(log_values=True)
            print(f" Optimization complete: loss = {final_loss.item():.6f}")

    def refine(self, macro_cycles: int = 5, niter: int = 10, max_iter: int = 50):
        """
        Full refinement: structures + fractions jointly.

        Parameters
        ----------
        macro_cycles : int
            Number of macro-cycles.
        niter : int
            LBFGS outer iterations per macro-cycle.
        max_iter : int
            LBFGS inner iterations per step.
        """
        for cycle in range(macro_cycles):
            if self.verbose > 0:
                print(f"\n--- Macro-cycle {cycle+1}/{macro_cycles} ---")

            params = self._collect_parameters(structures=True, fractions=True)
            self._lbfgs_loop(params, niter=niter, max_iter=max_iter)

            # Update solvent masks after structure changes
            if hasattr(self.scaler, "update_all_solvent"):
                self.scaler.update_all_solvent()

    def refine_structures(self, niter: int = 10, max_iter: int = 50):
        """
        Refine base model structures only (xyz, adp).

        Freezes fractions during optimization; they are unfrozen again
        afterwards, even if the optimization raises.
        """
        mc = self.model_collection
        mc.freeze_all_fractions()
        try:
            if self.verbose > 0:
                print("\n--- Refining structures (fractions frozen) ---")

            params = self._collect_parameters(structures=True, fractions=False)
            self._lbfgs_loop(params, niter=niter, max_iter=max_iter)
        finally:
            # Never leave the collection frozen after a failed run.
            mc.unfreeze_all_fractions()

    def refine_fractions(self, niter: int = 10, max_iter: int = 50):
        """
        Refine per-timepoint fractions only.

        Freezes structures during optimization; they are unfrozen again
        afterwards, even if the optimization raises.
        """
        mc = self.model_collection
        mc.freeze_structures()
        try:
            if self.verbose > 0:
                print("\n--- Refining fractions (structures frozen) ---")

            params = self._collect_parameters(structures=False, fractions=True)
            self._lbfgs_loop(params, niter=niter, max_iter=max_iter)
        finally:
            # Never leave the collection frozen after a failed run.
            mc.unfreeze_structures()

    def refine_alternating(
        self,
        n_cycles: int = 5,
        niter_structures: int = 10,
        niter_fractions: int = 5,
        refit_prior_every: int = 2,
        max_iter: int = 50,
    ):
        """
        Alternating refinement: structures → fractions → refit prior.

        Parameters
        ----------
        n_cycles : int
            Number of alternating cycles.
        niter_structures : int
            LBFGS iterations for structure refinement.
        niter_fractions : int
            LBFGS iterations for fraction refinement.
        refit_prior_every : int
            Refit kinetic prior every N cycles (0 to disable).
        max_iter : int
            LBFGS inner iterations per step.
        """
        for cycle in range(n_cycles):
            if self.verbose > 0:
                print(f"\n{'='*50}")
                print(f"Alternating cycle {cycle+1}/{n_cycles}")
                print(f"{'='*50}")

            # 1. Refine structures
            self.refine_structures(niter=niter_structures, max_iter=max_iter)

            # Update solvent masks after structure changes
            if hasattr(self.scaler, "update_all_solvent"):
                self.scaler.update_all_solvent()

            # 2. Refine fractions
            self.refine_fractions(niter=niter_fractions, max_iter=max_iter)

            # 3. Refit kinetic prior (if applicable)
            if (
                refit_prior_every > 0
                and self.kinetic_prior_target is not None
                and (cycle + 1) % refit_prior_every == 0
            ):
                if self.verbose > 0:
                    print("\n--- Refitting kinetic prior ---")
                self.refit_kinetic_prior()

            # Print fractions
            if self.verbose > 0:
                fracs = self.model_collection.get_all_fractions()
                for name, f in fracs.items():
                    frac_str = ", ".join(f"{v:.3f}" for v in f.detach().tolist())
                    print(f" {name}: [{frac_str}]")

    def refit_kinetic_prior(self, niter: int = 50, lr: float = 1e-2):
        """
        Refit the kinetic model to match current free fractions.

        This is the M-step in the EM-style alternation. No-op when no
        kinetic prior was configured in :meth:`setup`.
        """
        if self.kinetic_prior_target is not None:
            self.kinetic_prior_target.refit_prior(niter=niter, lr=lr)

    def refine_kinetics(self, niter: int = 200, lr: float = 1e-2):
        """
        Optimize kinetic model parameters directly against the X-ray targets.

        Component structure factors (base ModelFT) are frozen. The kinetic
        model predictions are injected as fraction overrides into each
        timepoint's mixed model, so the gradient path is::

            kinetic params → occupancies → fractions → F_calc → X-ray loss

        After optimization, the free fraction parameters are updated to
        match the final kinetic predictions. Structures are unfrozen and
        overrides cleared even if an optimization step raises.

        Parameters
        ----------
        niter : int
            Number of Adam optimizer steps. ``0`` skips optimization but
            still syncs the fractions to the current kinetic predictions.
        lr : float
            Learning rate.

        Raises
        ------
        RuntimeError
            If :meth:`setup` was not called with a kinetic model.
        """
        if self._kinetic_model is None:
            raise RuntimeError("No kinetic model configured. Call setup() with kinetic_model first.")

        kinetic_model = self._kinetic_model
        timepoints_map = self._timepoints_map
        mc = self.model_collection

        # Freeze component structure factors; restored in ``finally``.
        mc.freeze_structures()
        try:
            # Collect only kinetic model parameters
            params = [p for p in kinetic_model.parameters() if p.requires_grad]
            if not params:
                if self.verbose > 0:
                    print(" No refinable kinetic parameters found.")
                return

            # All model keys that have kinetic indices (including dark)
            overrides = {
                tp_name: timepoints_map[tp_name]
                for tp_name in mc.keys()
                if tp_name in timepoints_map
            }

            loss = self._optimize_kinetic_params(
                kinetic_model, params, overrides, niter=niter, lr=lr
            )
            self._sync_fractions_from_kinetics(kinetic_model, overrides)
        finally:
            mc.unfreeze_structures()

        if self.verbose > 0:
            if loss is None:
                # niter <= 0: no loss was ever evaluated.
                print(f" Kinetic refinement: {niter} steps, no X-ray loss evaluated.")
            else:
                print(f" Kinetic refinement: {niter} steps, final X-ray loss = {loss.item():.6f}")

    def _optimize_kinetic_params(self, kinetic_model, params, overrides, niter, lr):
        """Run Adam on *params* against the weighted X-ray loss.

        Returns the loss tensor of the last step, or ``None`` if no step ran.
        """
        mc = self.model_collection
        optimizer = torch.optim.Adam(params, lr=lr)

        w_diff = self._weights.get("xray/difference", 1.0)
        w_ml = self._weights.get("xray/ml", 1.0)

        loss = None
        for step in range(niter):
            optimizer.zero_grad()

            # Kinetic model predictions: [n_states, n_timepoints]
            kinetic_occ = kinetic_model()
            try:
                # Set fraction overrides on ALL models (including dark)
                for tp_name, t_idx in overrides.items():
                    mc[tp_name].set_fraction_override(kinetic_occ[:, t_idx])

                # X-ray loss only (difference + ML); restraints are not included.
                loss = w_diff * self._diff_target() + w_ml * self._ml_target()
                loss.backward()
            finally:
                # Clear overrides after the backward pass, and also on failure
                # so mixed models never keep a stale override.
                for tp_name in overrides:
                    mc[tp_name].clear_fraction_override()

            optimizer.step()

            if self.verbose > 1 and (step + 1) % 10 == 0:
                print(f" Kinetic opt step {step+1}/{niter}: loss = {loss.item():.6f}")

        return loss

    def _sync_fractions_from_kinetics(self, kinetic_model, overrides):
        """Write the current kinetic predictions into the free fraction parameters (dark skipped)."""
        mc = self.model_collection
        with torch.no_grad():
            kinetic_occ = kinetic_model()
            for tp_name, t_idx in overrides.items():
                if tp_name == mc.dark_key:
                    continue  # dark fractions stay frozen at [1,0,...,0]
                predicted = kinetic_occ[:, t_idx]
                # Stored as log of clamped occupancies; presumably the mixed
                # model normalizes fraction_params downstream — TODO confirm.
                mc[tp_name].fraction_params.data = torch.log(
                    predicted.clamp(min=1e-6)
                )

    # ------------------------------------------------------------------
    # Output
    # ------------------------------------------------------------------

    def write_pdbs(self, outdir: str):
        """Write base model PDBs to *outdir*."""
        self.model_collection.write_pdbs(outdir)

    def __repr__(self):
        n_tp = len(self.model_collection)
        n_ds = self.dataset_collection.n_datasets
        n_base = self.model_collection.n_base_models
        return (
            f"KineticRefinement({n_base} base models, "
            f"{n_tp} timepoints, {n_ds} datasets)"
        )