torchref.kinetic.refinement module

Kinetic refinement orchestrator.

Uses LossState, ModelCollection, and DatasetCollection to refine multiple structural models against multiple time-resolved datasets with kinetics-constrained occupancy fractions.

Example

from torchref import ModelFT, ReflectionData, DatasetCollection
from torchref.kinetic import ModelCollection, KineticRefinement

# Base models
model_dark = ModelFT(max_res=1.5).load_pdb("dark.pdb")
model_light = ModelFT(max_res=1.5).load_pdb("light.pdb")

# Collections
models = ModelCollection([model_dark, model_light])
models.add_dark()
models.add_timepoint("1ps", [0.9, 0.1])

datasets = DatasetCollection()
datasets.add_dataset("dark", ReflectionData().load_mtz("dark.mtz"))
datasets.add_dataset("1ps", ReflectionData().load_mtz("1ps.mtz"))

ref = KineticRefinement(datasets, models)
ref.setup()
ref.refine(macro_cycles=5)
class torchref.kinetic.refinement.KineticRefinement(dataset_collection, model_collection, xray_weight_difference=2.0, xray_weight_ml=1.0, geometry_weight=10.0, adp_weight=3.0, kinetic_prior_weight=0.0, device=None, verbose=1)

Bases: DeviceMixin, Module

Orchestrator for kinetic refinement of time-resolved data.

Manages LossState registration, scaler initialization, and optimization loops for alternating structure / fraction refinement.

Parameters:
  • dataset_collection (DatasetCollection) – Collection of reflection datasets keyed by timepoint name.

  • model_collection (ModelCollection) – Collection of mixed models keyed by timepoint name.

  • xray_weight_difference (float) – Weight for the difference X-ray target.

  • xray_weight_ml (float) – Weight for the ML amplitude target.

  • geometry_weight (float) – Weight for geometry restraints.

  • adp_weight (float) – Weight for ADP restraints.

  • kinetic_prior_weight (float) – Weight for kinetic prior regularization (0 to disable).

  • device (torch.device, optional) – Computation device. If None, inferred from model collection.

  • verbose (int) – Verbosity level.

setup(cif_paths=None, kinetic_model=None, timepoints_map=None)

One-shot initialization: scalers, restraints, targets, LossState.

Parameters:
  • cif_paths (List[str], optional) – Paths to CIF restraint dictionaries for ligands.

  • kinetic_model (occupancies_kinetics, optional) – Kinetic model for prior regularization. If None, the kinetic prior target is not created.

  • timepoints_map (Dict[str, int], optional) – Maps timepoint names to indices into the kinetic model’s time axis. Required when kinetic_model is provided.
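The timepoints_map argument can be read as a lookup from timepoint names to positions on the kinetic model's time axis. The sketch below shows one way such a mapping might be built in plain Python. The helper build_timepoints_map, the time axis values, and the delays dictionary are hypothetical and are not part of torchref; the nearest-index rule is an assumption made for illustration.

```python
from typing import Dict, Sequence


def build_timepoints_map(delays: Dict[str, float],
                         time_axis: Sequence[float]) -> Dict[str, int]:
    """Map each timepoint name to the index of the nearest time-axis entry."""
    mapping = {}
    for name, delay in delays.items():
        # Pick the index whose time value is closest to the requested delay.
        mapping[name] = min(range(len(time_axis)),
                            key=lambda i: abs(time_axis[i] - delay))
    return mapping


# Hypothetical time axis (in picoseconds) and named pump-probe delays.
time_axis = [0.0, 1.0, 3.0, 10.0]
delays = {"1ps": 1.0, "10ps": 10.0}

timepoints_map = build_timepoints_map(delays, time_axis)
print(timepoints_map)
```

A dictionary of this shape is what the Dict[str, int] annotation suggests; whether setup() validates the indices against the kinetic model is not stated here.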

set_weights(**kwargs)

Set target weights by name.

Parameters:
  • **kwargs – Keyword arguments mapping target paths to weights, e.g. set_weights(geometry=5.0, adp=2.0).
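As an illustration of the keyword-to-weight pattern, the following sketch uses a hypothetical WeightTable class as a stand-in for whatever registry KineticRefinement keeps internally. Rejecting unknown names is a design choice of this sketch, not documented behaviour of set_weights.

```python
class WeightTable:
    """Hypothetical stand-in for a registry of named target weights."""

    def __init__(self, **initial):
        self._weights = {name: float(value) for name, value in initial.items()}

    def set_weights(self, **kwargs):
        """Update existing weights by name; unknown names raise KeyError."""
        for name, value in kwargs.items():
            if name not in self._weights:
                raise KeyError(f"unknown target: {name!r}")
            self._weights[name] = float(value)

    def as_dict(self):
        """Return a copy so callers cannot mutate the table directly."""
        return dict(self._weights)


# Start from the constructor defaults for geometry and ADP restraints.
table = WeightTable(geometry=10.0, adp=3.0)
table.set_weights(geometry=5.0, adp=2.0)
print(table.as_dict())
```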

get_loss(log_values=False)

Evaluate all targets and return weighted total loss.
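A weighted total loss of this kind is commonly a sum of per-target losses, each multiplied by its weight. The sketch below assumes that form and uses the constructor's default weights; the target names and the raw loss values are made up for illustration and do not come from torchref.

```python
def weighted_total(raw_losses, weights):
    """Sum each raw target loss multiplied by its weight."""
    return sum(weights[name] * value for name, value in raw_losses.items())


# Default weights from the KineticRefinement signature (key names are hypothetical).
weights = {
    "xray_difference": 2.0,
    "xray_ml": 1.0,
    "geometry": 10.0,
    "adp": 3.0,
    "kinetic_prior": 0.0,  # a weight of 0 disables the kinetic prior
}

# Made-up unweighted loss values for a single evaluation.
raw_losses = {
    "xray_difference": 0.5,
    "xray_ml": 1.5,
    "geometry": 0.02,
    "adp": 0.1,
    "kinetic_prior": 4.0,
}

total = weighted_total(raw_losses, weights)
print(f"total loss: {total:.3f}")
```

With a zero weight the kinetic prior term contributes nothing to the total, however large its raw value.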

print_loss_summary()

Print breakdown of current losses.

refine(macro_cycles=5, niter=10, max_iter=50)

Full refinement: structures + fractions jointly.

Parameters:
  • macro_cycles (int) – Number of macro-cycles.

  • niter (int) – LBFGS outer iterations per macro-cycle.

  • max_iter (int) – LBFGS inner iterations per step.

refine_structures(niter=10, max_iter=50)

Refine base model structures only (xyz, adp).

Freezes fractions during optimization.

refine_fractions(niter=10, max_iter=50)

Refine per-timepoint fractions only.

Freezes structures during optimization.

refine_alternating(n_cycles=5, niter_structures=10, niter_fractions=5, refit_prior_every=2, max_iter=50)

Alternating refinement: structures → fractions → refit prior.

Parameters:
  • n_cycles (int) – Number of alternating cycles.

  • niter_structures (int) – LBFGS iterations for structure refinement.

  • niter_fractions (int) – LBFGS iterations for fraction refinement.

  • refit_prior_every (int) – Refit kinetic prior every N cycles (0 to disable).

  • max_iter (int) – LBFGS inner iterations per step.
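The order of operations implied by the parameters above can be sketched as a simple schedule. This is only an illustration: the function alternating_schedule is hypothetical, and counting cycles from 1 so that the prior is refit on cycles 2, 4, ... is an assumption about how refit_prior_every is interpreted.

```python
def alternating_schedule(n_cycles=5, refit_prior_every=2):
    """Return the (cycle, step) pairs an alternating refinement might run."""
    steps = []
    for cycle in range(1, n_cycles + 1):
        steps.append((cycle, "structures"))  # fractions held fixed
        steps.append((cycle, "fractions"))   # structures held fixed
        # A value of 0 disables refitting of the kinetic prior.
        if refit_prior_every > 0 and cycle % refit_prior_every == 0:
            steps.append((cycle, "refit_prior"))
    return steps


schedule = alternating_schedule()
refit_cycles = [cycle for cycle, step in schedule if step == "refit_prior"]

for cycle, step in schedule:
    print(cycle, step)
```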

refit_kinetic_prior(niter=50, lr=0.01)

Refit the kinetic model to match current free fractions.

This is the M-step in the EM-style alternation.
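To make the idea of refitting a kinetic model to the current fractions concrete, here is a minimal sketch assuming a two-state A → B scheme with a single rate constant k, so that the predicted fraction of B at time t is 1 - exp(-k t). The model, the plain gradient-descent update, and the step count and learning rate are all assumptions chosen for this toy problem; they are not torchref's kinetic model, optimizer, or defaults.

```python
import math


def predicted_fraction(k, t):
    """Fraction of state B at time t for a first-order A -> B process."""
    return 1.0 - math.exp(-k * t)


def fit_rate(times, fractions, k0=0.1, niter=500, lr=0.05):
    """Fit k by gradient descent on the mean squared error to the fractions."""
    k = k0
    n = len(times)
    for _ in range(niter):
        grad = 0.0
        for t, f in zip(times, fractions):
            residual = predicted_fraction(k, t) - f
            # d(prediction)/dk = t * exp(-k t)
            grad += 2.0 * residual * t * math.exp(-k * t) / n
        k -= lr * grad
    return k


# Synthetic "current free fractions", generated from a known rate of 0.5.
times = [0.0, 1.0, 2.0, 4.0]
current_fractions = [predicted_fraction(0.5, t) for t in times]

k_fit = fit_rate(times, current_fractions)
print(f"fitted rate: {k_fit:.4f}")
```

In this sketch the fractions are treated as fixed data and only the kinetic parameter moves, which mirrors the role of an M-step.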

refine_kinetics(niter=200, lr=0.01)

Optimize kinetic model parameters directly against the X-ray targets.

Component structure factors (base ModelFT) are frozen. The kinetic model predictions are injected as fraction overrides into each timepoint’s mixed model, so the gradient path is:

kinetic params → occupancies → fractions → F_calc → X-ray loss

After optimization, the free fraction parameters are updated to match the final kinetic predictions.

Parameters:
  • niter (int) – Number of Adam optimizer steps.

  • lr (float) – Learning rate.
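The forward half of the gradient path described above can be sketched without any autograd machinery. The example assumes a two-state A → B kinetic scheme, made-up complex structure factors for two frozen components, and a simple squared amplitude difference as the X-ray loss; none of these are taken from torchref, and all function names are hypothetical.

```python
import math


def fractions_from_rate(k, t):
    """kinetic params -> occupancies -> fractions for a two-state A -> B model."""
    occ_b = 1.0 - math.exp(-k * t)
    return [1.0 - occ_b, occ_b]


def mixed_structure_factors(fractions, components):
    """fractions -> F_calc: fraction-weighted sum of frozen component factors."""
    n_refl = len(components[0])
    return [sum(f * comp[h] for f, comp in zip(fractions, components))
            for h in range(n_refl)]


def amplitude_loss(k, t, components, f_obs):
    """F_calc -> X-ray loss: mean squared difference of amplitudes."""
    f_calc = mixed_structure_factors(fractions_from_rate(k, t), components)
    return sum((abs(fc) - fo) ** 2 for fc, fo in zip(f_calc, f_obs)) / len(f_obs)


# Made-up structure factors for three reflections of two frozen components.
f_a = [10 + 2j, -3 + 4j, 5 - 1j]
f_b = [8 - 1j, -1 + 6j, 7 + 2j]

# Synthetic observed amplitudes generated from a known rate and delay.
true_k, delay = 0.7, 1.0
f_obs = [abs(fc) for fc in
         mixed_structure_factors(fractions_from_rate(true_k, delay), [f_a, f_b])]

loss_at_truth = amplitude_loss(true_k, delay, [f_a, f_b], f_obs)
loss_off = amplitude_loss(0.2, delay, [f_a, f_b], f_obs)
print(loss_at_truth, loss_off)
```

Because the component structure factors never change, the loss depends on the kinetic parameter only through the fractions, which is what lets an optimizer update the kinetic model directly.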

write_pdbs(outdir)

Write base model PDBs to outdir.