"""
Model collection for time-resolved kinetic refinement.
Provides ModelCollection — a named dictionary of MixedModel instances at
different timepoints that share the same base structural models (ModelFT).
Keys match DatasetCollection keys so targets can automatically pair them.
"""
import math
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple

import torch
from torch import nn

from torchref.config import get_default_device
from torchref.utils.device_mixin import DeviceMovementMixin

if TYPE_CHECKING:
    from torchref.model.model_ft import ModelFT
    from torchref.model.mixed_model import MixedModel
class _SharedMixedModel(DeviceMovementMixin, nn.Module):
    """
    MixedModel variant that references shared base models without re-registering them.

    Standard MixedModel wraps models in nn.ModuleList, which causes
    double-registration when the same ModelFT objects appear in multiple
    timepoints. This class stores the shared models as a plain list
    (no ownership) and only owns its own fraction parameters.

    Fractions are stored as ``fraction_params = log(f)`` and read back as
    ``softmax(fraction_params)`` (unless an override is set), so the
    learnable fractions stay positive and sum to 1.

    Parameters
    ----------
    base_models : List[ModelFT]
        Shared structural models (not re-registered as submodules here).
    initial_fractions : List[float]
        Initial population fractions (must be finite and sum to 1 within 1e-3).
    frozen_fractions : bool
        If True, fractions are excluded from optimization.
    device : torch.device, optional
        Device for fraction parameters. Defaults to ``base_models[0].device``
        when that attribute exists, otherwise ``get_default_device()``.

    Raises
    ------
    ValueError
        If the number of fractions differs from the number of models, or the
        fractions do not sum to 1 (including non-finite sums such as NaN).
    """

    def __init__(
        self,
        base_models: List["ModelFT"],
        initial_fractions: List[float],
        frozen_fractions: bool = False,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        # Store as plain list — the parent ModelCollection owns the ModuleList,
        # so the models are deliberately not registered as submodules here.
        self._base_models = base_models

        n = len(base_models)
        if len(initial_fractions) != n:
            raise ValueError(
                f"Number of fractions ({len(initial_fractions)}) must match "
                f"number of models ({n})."
            )

        total = sum(initial_fractions)
        # NaN compares False against everything, so ``abs(nan - 1.0) > tol``
        # alone would let NaN fractions through; reject non-finite sums too.
        if not math.isfinite(total) or abs(total - 1.0) > 1e-3:
            raise ValueError(f"Initial fractions must sum to 1.0, got {total:.6f}.")

        # Normalize to handle floating point drift
        initial_fractions = [f / total for f in initial_fractions]

        if device is None:
            device = (
                base_models[0].device
                if hasattr(base_models[0], "device")
                else get_default_device()
            )

        fractions_tensor = torch.tensor(
            initial_fractions, dtype=torch.float32, device=device
        )
        # Log-parameterization: softmax(log f) recovers f for normalized f.
        # Zero fractions are clamped so the log stays finite.
        theta = torch.log(fractions_tensor.clamp(min=1e-6))
        self.fraction_params = nn.Parameter(theta, requires_grad=not frozen_fractions)

        # Optional override: when set, fractions property returns this tensor
        # instead of softmax(fraction_params). Used by refine_kinetics() to
        # route kinetic model predictions directly into the F_calc computation.
        self._fraction_override: Optional[torch.Tensor] = None

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    @property
    def fractions(self) -> torch.Tensor:
        """Normalized population fractions (sum to 1).

        When a fraction override is active (set by ``set_fraction_override``),
        returns the override tensor instead of softmax(fraction_params).
        This allows kinetic model predictions to flow directly into F_calc.
        """
        if self._fraction_override is not None:
            return self._fraction_override
        return torch.softmax(self.fraction_params, dim=0)

    @property
    def models(self) -> List["ModelFT"]:
        """The shared base models (read-only reference)."""
        return self._base_models

    @property
    def cell(self):
        """Unit cell of the first base model."""
        return self._base_models[0].cell

    @property
    def spacegroup(self):
        """Space group of the first base model."""
        return self._base_models[0].spacegroup

    @property
    def device(self):
        """Device reported by the first base model."""
        return self._base_models[0].device

    @property
    def dtype_float(self):
        """Floating-point dtype reported by the first base model."""
        return self._base_models[0].dtype_float

    @property
    def real_space_grid(self):
        """Real-space grid of the first base model."""
        return self._base_models[0].real_space_grid

    @property
    def fft(self):
        """FFT helper of the first base model."""
        return self._base_models[0].fft

    @property
    def gridsize(self):
        """Grid size of the first base model."""
        return self._base_models[0].gridsize

    @property
    def map_symmetry(self):
        """Map symmetry helper of the first base model."""
        return self._base_models[0].map_symmetry

    @property
    def inv_fractional_matrix(self):
        """Cell ``inv_fractional_matrix`` cast to ``dtype_float``."""
        return self.cell.inv_fractional_matrix.to(dtype=self.dtype_float)

    @property
    def fractional_matrix(self):
        """Cell ``fractional_matrix`` cast to ``dtype_float``."""
        return self.cell.fractional_matrix.to(dtype=self.dtype_float)

    # ------------------------------------------------------------------
    # Grid / density helpers (delegate to base models)
    # ------------------------------------------------------------------

    def setup_grid(self, max_res=None, gridsize=None):
        """Call ``setup_grid`` with the same arguments on every base model."""
        for model in self._base_models:
            model.setup_grid(max_res=max_res, gridsize=gridsize)

    def get_radius(self, min_radius_Angstrom: float = 4.0) -> int:
        """Delegate to the first base model's ``get_radius``."""
        return self._base_models[0].get_radius(min_radius_Angstrom)

    def build_complete_map(self) -> torch.Tensor:
        """Mixed electron density: sum_i w_i * density_i."""
        fractions = self.fractions
        density = None
        for i, model in enumerate(self._base_models):
            weighted = fractions[i] * model.build_complete_map()
            density = weighted if density is None else density + weighted
        return density

    # ------------------------------------------------------------------
    # Forward: weighted structure factors
    # ------------------------------------------------------------------

    def forward(self, hkl: torch.Tensor, recalc: bool = False) -> torch.Tensor:
        """
        Compute f_mixed = sum_i w_i * f_i.

        Parameters
        ----------
        hkl : torch.Tensor
            Miller indices, shape (n_reflections, 3).
        recalc : bool
            Force recalculation of structure factors.

        Returns
        -------
        torch.Tensor
            Mixed complex structure factors.
        """
        fractions = self.fractions
        f_mixed = None
        for i, model in enumerate(self._base_models):
            f_i = model(hkl, recalc=recalc)
            weighted_f = fractions[i] * f_i
            f_mixed = weighted_f if f_mixed is None else f_mixed + weighted_f
        return f_mixed

    # ------------------------------------------------------------------
    # Freeze / unfreeze
    # ------------------------------------------------------------------

    def freeze_fractions(self):
        """Exclude ``fraction_params`` from gradient updates."""
        self.fraction_params.requires_grad = False

    def unfreeze_fractions(self):
        """Include ``fraction_params`` in gradient updates."""
        self.fraction_params.requires_grad = True

    def set_fraction_override(self, fractions: torch.Tensor):
        """Override fractions with an external tensor (e.g. from kinetic model).

        While active, ``self.fractions`` returns this tensor instead of
        ``softmax(fraction_params)``, allowing gradients to flow through
        the external source.

        Raises
        ------
        ValueError
            If ``fractions`` does not have one entry per base model along
            dim 0 (``forward`` indexes it as ``fractions[i]``).
        """
        n = len(self._base_models)
        if fractions.dim() == 0 or fractions.shape[0] != n:
            raise ValueError(
                f"Fraction override must have {n} entries along dim 0, "
                f"got shape {tuple(fractions.shape)}."
            )
        self._fraction_override = fractions

    def clear_fraction_override(self):
        """Remove fraction override, reverting to softmax(fraction_params)."""
        self._fraction_override = None

    # ------------------------------------------------------------------
    # Convenience
    # ------------------------------------------------------------------

    def get_vdw_radii(self):
        """Delegate to the first base model's ``get_vdw_radii``."""
        return self._base_models[0].get_vdw_radii()

    def xyz(self):
        """Delegate to the first base model's ``xyz``."""
        return self._base_models[0].xyz()

    def get_individual_fcalc(self, hkl, recalc=True):
        """Unweighted structure factors from each base model, in order."""
        return [m(hkl, recalc=recalc) for m in self._base_models]

    def __repr__(self):
        fracs = self.fractions.detach().tolist()
        frac_str = ", ".join(f"{f:.3f}" for f in fracs)
        frozen_str = "frozen" if not self.fraction_params.requires_grad else "learnable"
        return f"_SharedMixedModel({len(self._base_models)} models, fractions=[{frac_str}], {frozen_str})"
class ModelCollection(DeviceMovementMixin, nn.Module):
    """
    Named dictionary of MixedModel instances at different timepoints.

    All timepoint models share the same base structural models (ModelFT
    objects stored once in an nn.ModuleList). Each timepoint gets its own
    independent fraction parameters via _SharedMixedModel.

    Keys should match DatasetCollection keys so that collection-aware
    targets can automatically pair datasets with models.

    Parameters
    ----------
    base_models : List[ModelFT]
        The K shared structural models (e.g., ground state + intermediates).
    dark_key : str
        Key for the dark / reference entry. Default ``"dark"``.
    verbose : int
        Verbosity level.

    Raises
    ------
    ValueError
        If ``base_models`` is empty.

    Examples
    --------
    ::

        models = ModelCollection([model_dark, model_light])
        models.add_dark()                       # fractions=[1, 0]
        models.add_timepoint("1ps", [0.9, 0.1])
        models.add_timepoint("5ps", [0.7, 0.3])

        # Access
        mixed = models["1ps"]
        fcalc = mixed(hkl)
        print(mixed.fractions)
    """

    def __init__(
        self,
        base_models: List["ModelFT"],
        dark_key: str = "dark",
        verbose: int = 0,
    ):
        super().__init__()
        if not base_models:
            raise ValueError("At least one base model is required.")

        self._dark_key = dark_key
        self.verbose = verbose

        # Register base models as owned submodules (single source of truth)
        self._base_models = nn.ModuleList(base_models)

        # Per-timepoint mixed models (own only fraction params)
        self._timepoints = nn.ModuleDict()
        # Insertion order of timepoint keys; drives iteration and matrix rows.
        self._order: List[str] = []

        if self.verbose > 0:
            print(
                f"ModelCollection initialized with {len(base_models)} base models"
            )

    # ------------------------------------------------------------------
    # Add timepoints
    # ------------------------------------------------------------------

    def add_timepoint(
        self,
        name: str,
        fractions: Optional[List[float]] = None,
        frozen_fractions: bool = False,
    ) -> "ModelCollection":
        """
        Add a timepoint with given initial fractions.

        Parameters
        ----------
        name : str
            Timepoint identifier (should match DatasetCollection key).
        fractions : List[float], optional
            Initial population fractions. If None, uses equal fractions.
        frozen_fractions : bool
            If True, fractions are not updated during optimization.

        Returns
        -------
        ModelCollection
            Self, for method chaining.

        Raises
        ------
        ValueError
            If ``name`` already exists, or the fractions are rejected by
            ``_SharedMixedModel`` (wrong length / not summing to 1).
        """
        if name in self._timepoints:
            raise ValueError(f"Timepoint '{name}' already exists.")

        n = len(self._base_models)
        if fractions is None:
            fractions = [1.0 / n] * n

        # Validation happens here, before any state is mutated, so a bad
        # timepoint leaves the collection unchanged.
        mixed = _SharedMixedModel(
            base_models=list(self._base_models),
            initial_fractions=fractions,
            frozen_fractions=frozen_fractions,
        )
        self._timepoints[name] = mixed
        self._order.append(name)

        if self.verbose > 0:
            frac_str = ", ".join(f"{f:.3f}" for f in fractions)
            print(f" Added timepoint '{name}': fractions=[{frac_str}]")

        return self

    def add_dark(
        self, fractions: Optional[List[float]] = None
    ) -> "ModelCollection":
        """
        Add the dark / reference entry (always with frozen fractions).

        Default fractions: [1, 0, 0, ...] (100 % ground state).

        Parameters
        ----------
        fractions : List[float], optional
            Override dark fractions. Default is pure ground state.

        Returns
        -------
        ModelCollection
            Self, for method chaining.
        """
        if fractions is None:
            n = len(self._base_models)
            fractions = [0.0] * n
            fractions[0] = 1.0
        return self.add_timepoint(self._dark_key, fractions, frozen_fractions=True)

    # ------------------------------------------------------------------
    # Class methods
    # ------------------------------------------------------------------

    @classmethod
    def from_kinetics(
        cls,
        base_models: List["ModelFT"],
        occ_model,
        timepoint_names: List[str],
        dark_key: str = "dark",
        verbose: int = 0,
    ) -> "ModelCollection":
        """
        Create a ModelCollection from a kinetics occupancy model.

        Parameters
        ----------
        base_models : List[ModelFT]
            Shared structural models.
        occ_model : occupancies_kinetics
            Kinetic occupancy model whose forward() returns
            shape [n_states, n_timepoints]. Column 0 is treated as the dark
            timepoint and is not used (the dark entry is always pure ground
            state); column ``t + 1`` seeds ``timepoint_names[t]``.
        timepoint_names : List[str]
            Names for each timepoint column (excluding dark).
        dark_key : str
            Key for the dark entry.
        verbose : int
            Verbosity level.

        Returns
        -------
        ModelCollection

        Raises
        ------
        ValueError
            If ``occ_model()`` is not 2-D or has fewer than
            ``len(timepoint_names) + 1`` columns.
        """
        collection = cls(base_models, dark_key=dark_key, verbose=verbose)
        collection.add_dark()

        with torch.no_grad():
            occ = occ_model()  # [n_states, n_timepoints]

        # +1 because index 0 in occ is the dark timepoint
        n_required = len(timepoint_names) + 1
        if occ.ndim != 2 or occ.shape[1] < n_required:
            raise ValueError(
                f"occ_model() must return a 2-D array with at least "
                f"{n_required} columns (dark + {len(timepoint_names)} "
                f"timepoints), got shape {tuple(occ.shape)}."
            )

        for t_idx, name in enumerate(timepoint_names):
            fracs = occ[:, t_idx + 1].tolist()
            collection.add_timepoint(name, fracs)

        return collection

    # ------------------------------------------------------------------
    # IHM I/O
    # ------------------------------------------------------------------

    @classmethod
    def from_ihm(
        cls,
        filepath: str,
        max_res: float = 1.5,
        radius_angstrom: float = 4.0,
        device=None,
        verbose: int = 0,
    ) -> tuple:
        """
        Load a ModelCollection from an IHM mmCIF file.

        Requires the optional ``python-ihm`` dependency.

        Parameters
        ----------
        filepath : str
            Path to IHM mmCIF file.
        max_res : float
            Maximum resolution for FFT grid setup.
        radius_angstrom : float
            Radius for electron density calculation.
        device : torch.device, optional
            Device for model tensors.
        verbose : int
            Verbosity level.

        Returns
        -------
        tuple of (ModelCollection, IHMEnsembleMapping)
        """
        from torchref.io.ihm import IHMReader

        reader = IHMReader(filepath, verbose=verbose)
        return reader(
            max_res=max_res,
            radius_angstrom=radius_angstrom,
            device=device,
        )

    def write_ihm(self, filepath: str, mapping=None, datasets=None) -> None:
        """
        Write this ModelCollection to IHM mmCIF format.

        Requires the optional ``python-ihm`` dependency.

        Parameters
        ----------
        filepath : str
            Output file path.
        mapping : IHMEnsembleMapping, optional
            Mapping with metadata for round-tripping. If ``None``,
            a minimal mapping is created from the collection structure.
        datasets : dict of str -> ReflectionData, optional
            Per-timepoint reflection data to embed in the CIF.
            Each key should match a timepoint name.
        """
        from torchref.io.ihm import IHMWriter

        writer = IHMWriter(
            self, mapping=mapping, datasets=datasets, verbose=self.verbose,
        )
        writer.write(filepath)

    # ------------------------------------------------------------------
    # Dict-like access
    # ------------------------------------------------------------------

    def __getitem__(self, name: str) -> "_SharedMixedModel":
        return self._timepoints[name]

    def __contains__(self, name: str) -> bool:
        return name in self._timepoints

    def __iter__(self) -> Iterator[Tuple[str, "_SharedMixedModel"]]:
        """Yield ``(name, model)`` pairs in insertion order."""
        for name in self._order:
            yield name, self._timepoints[name]

    def __len__(self) -> int:
        return len(self._timepoints)

    def keys(self) -> List[str]:
        """Timepoint names in insertion order (including dark)."""
        return list(self._order)

    def values(self) -> List["_SharedMixedModel"]:
        """Timepoint models in insertion order."""
        return [self._timepoints[n] for n in self._order]

    def items(self) -> List[Tuple[str, "_SharedMixedModel"]]:
        """``(name, model)`` pairs in insertion order."""
        return [(n, self._timepoints[n]) for n in self._order]

    def get(self, name: str, default=None):
        """Return the timepoint model for *name*, or *default* if absent.

        ``nn.ModuleDict`` does not implement ``dict.get``, so the lookup is
        done explicitly via membership test + indexing.
        """
        if name in self._timepoints:
            return self._timepoints[name]
        return default

    # ------------------------------------------------------------------
    # Convenience properties
    # ------------------------------------------------------------------

    @property
    def dark_key(self) -> str:
        """Key of the dark / reference entry."""
        return self._dark_key

    @property
    def dark_model(self) -> "_SharedMixedModel":
        """Shortcut for ``self[dark_key]``."""
        return self._timepoints[self._dark_key]

    @property
    def base_models(self) -> nn.ModuleList:
        """The shared structural models (owned by this collection)."""
        return self._base_models

    @property
    def n_base_models(self) -> int:
        """Number of shared structural models."""
        return len(self._base_models)

    @property
    def timepoint_names(self) -> List[str]:
        """All keys except the dark key."""
        return [n for n in self._order if n != self._dark_key]

    @property
    def cell(self):
        """Unit cell of the first base model."""
        return self._base_models[0].cell

    @property
    def spacegroup(self):
        """Space group of the first base model."""
        return self._base_models[0].spacegroup

    @property
    def device(self):
        """Device reported by the first base model."""
        return self._base_models[0].device

    # ------------------------------------------------------------------
    # Fractions inspection
    # ------------------------------------------------------------------

    def get_all_fractions(self) -> Dict[str, torch.Tensor]:
        """Current fractions for each timepoint (including dark)."""
        return {name: self._timepoints[name].fractions for name in self._order}

    def get_fractions_matrix(self) -> torch.Tensor:
        """
        All fractions as a matrix [n_timepoints, n_models].

        Rows are ordered by ``self._order`` (i.e. insertion order).
        """
        return torch.stack(
            [self._timepoints[n].fractions for n in self._order], dim=0
        )

    # ------------------------------------------------------------------
    # Freeze / unfreeze helpers
    # ------------------------------------------------------------------

    def freeze_all_fractions(self):
        """Freeze fractions at all timepoints."""
        for _, mixed in self:
            mixed.freeze_fractions()

    def unfreeze_all_fractions(self):
        """Unfreeze fractions at all timepoints (except dark)."""
        for name, mixed in self:
            if name != self._dark_key:
                mixed.unfreeze_fractions()

    def freeze_structures(self):
        """Call ``freeze("xyz")`` and ``freeze("b")`` on all base models."""
        for model in self._base_models:
            model.freeze("xyz")
            model.freeze("b")

    def unfreeze_structures(self):
        """Call ``unfreeze("xyz")`` and ``unfreeze("b")`` on all base models."""
        for model in self._base_models:
            model.unfreeze("xyz")
            model.unfreeze("b")

    # ------------------------------------------------------------------
    # I/O
    # ------------------------------------------------------------------

    def write_pdbs(self, outdir: str):
        """
        Write each base model to a PDB file in *outdir*.

        Files are named ``base_model_0.pdb``, ``base_model_1.pdb``, etc.

        Parameters
        ----------
        outdir : str
            Directory to write PDB files into; created if it does not exist.
        """
        import os

        os.makedirs(outdir, exist_ok=True)
        for i, model in enumerate(self._base_models):
            path = os.path.join(outdir, f"base_model_{i}.pdb")
            model.write_pdb(path)
            if self.verbose > 0:
                print(f" Wrote {path}")

    # ------------------------------------------------------------------
    # Repr
    # ------------------------------------------------------------------

    def __repr__(self):
        tp_names = ", ".join(self._order[:4])
        if len(self._order) > 4:
            tp_names += f", ... ({len(self._order)} total)"
        return (
            f"ModelCollection({self.n_base_models} base models, "
            f"timepoints=[{tp_names}])"
        )