# Source code for torchref.model.mixed_model

"""
Mixed Model for Time-Resolved Crystallography.

This module provides a MixedModel class that combines multiple ModelFT objects
with learnable population fractions, enabling refinement of time-resolved
crystallographic data with multiple conformational states.
"""

from typing import TYPE_CHECKING, List, Optional

import torch
from torch import nn

from torchref.utils.device_mixin import DeviceMovementMixin

if TYPE_CHECKING:
    from torchref.model.model_ft import ModelFT


class MixedModel(DeviceMovementMixin, nn.Module):
    """
    Model wrapper combining N ModelFT objects with learnable fractions.

    Computes: F_mixed = Σ w_i * F_i

    where w_i are learnable weights constrained to sum to 1 via softmax.

    This is useful for time-resolved crystallography where the crystal
    contains a mixture of conformational states (e.g., dark and light states)
    with unknown or refinable population fractions.

    Parameters
    ----------
    models : List[ModelFT]
        List of ModelFT objects to combine. All models must have compatible
        cell parameters and space groups.
    initial_fractions : List[float], optional
        Initial population fractions for each model. Must be non-negative and
        sum to 1.0. If None, equal fractions are used (1/N for each model).
    frozen_fractions : bool, optional
        If True, fractions are not updated during optimization.
        Default is False.
    verbose : int, optional
        Verbosity level. Default is 0.
    device : torch.device, optional
        Device to place the models and fraction parameters on. If None, the
        device of the first model is used.

    Attributes
    ----------
    models : nn.ModuleList
        Constituent ModelFT objects (proper submodule registration).
    fraction_params : nn.Parameter
        Raw parameters for fraction computation (softmax applied).

    Examples
    --------
    Create a mixed model with two states::

        model_dark = ModelFT().load_pdb('dark.pdb')
        model_light = ModelFT().load_pdb('light.pdb')

        # 70% dark, 30% light
        mixed = MixedModel([model_dark, model_light], initial_fractions=[0.7, 0.3])

        # Compute mixed structure factors
        F_mixed = mixed(hkl)

        # Get current fractions
        print(mixed.fractions)  # tensor([0.7, 0.3])
    """

    # Tolerances passed to torch.allclose when comparing the unit cell of each
    # model against the cell of the first (reference) model.
    _CELL_ATOL = 1.0
    _CELL_RTOL = 0.01

    def __init__(
        self,
        models: List["ModelFT"],
        initial_fractions: Optional[List[float]] = None,
        frozen_fractions: bool = False,
        verbose: int = 0,
        device: Optional[torch.device] = None,
    ):
        """
        Initialize MixedModel.

        Parameters
        ----------
        models : List[ModelFT]
            List of ModelFT objects to combine.
        initial_fractions : List[float], optional
            Initial population fractions. Must be non-negative and sum to 1.0.
        frozen_fractions : bool, optional
            If True, fractions are frozen. Default is False.
        verbose : int, optional
            Verbosity level. Default is 0.
        device : torch.device, optional
            Device to place the model and parameters on. If None, infers from
            the first model's device.

        Raises
        ------
        ValueError
            If models list is empty, fractions don't match model count,
            fractions are negative or NaN, fractions don't sum to 1, or
            models have incompatible parameters.
        """
        super().__init__()

        if not models:
            raise ValueError("At least one model is required.")

        self.verbose = verbose

        # Infer device from first model if not specified
        if device is None:
            device = models[0].device

        # Store models as ModuleList for proper PyTorch handling
        models = [model.to(device=device) for model in models]
        self.models = nn.ModuleList(models)

        # Validate model compatibility
        self._validate_models()

        # Initialize fractions
        n_models = len(models)
        if initial_fractions is None:
            initial_fractions = [1.0 / n_models] * n_models
        else:
            if len(initial_fractions) != n_models:
                raise ValueError(
                    f"Number of fractions ({len(initial_fractions)}) must match "
                    f"number of models ({n_models})."
                )
            # "not (f >= 0)" rather than "f < 0" so that NaN is rejected too.
            # Without this check a negative entry could still pass the sum
            # test below and would then be silently clamped to 1e-6.
            if any(not (frac >= 0.0) for frac in initial_fractions):
                raise ValueError(
                    f"Initial fractions must be non-negative, got "
                    f"{list(initial_fractions)}."
                )
            total = sum(initial_fractions)
            # Written as "not (... <= tol)" so a NaN total fails validation.
            if not (abs(total - 1.0) <= 1e-6):
                raise ValueError(
                    f"Initial fractions must sum to 1.0, got {total:.6f}."
                )

        # Use inverse softmax to initialize parameters
        # softmax(theta) = fractions, so theta = log(fractions)
        # The clamp keeps log() finite for fractions that are exactly zero.
        fractions_tensor = torch.tensor(
            initial_fractions, dtype=torch.float32, device=device
        )
        theta = torch.log(fractions_tensor.clamp(min=1e-6))
        self.fraction_params = nn.Parameter(theta, requires_grad=not frozen_fractions)

        if self.verbose > 0:
            print(f"MixedModel initialized with {n_models} models")
            print(f"  Initial fractions: {initial_fractions}")
            print(f"  Fractions frozen: {frozen_fractions}")

    def _validate_models(self):
        """
        Validate that all models have compatible cell and spacegroup.

        Every model after the first is compared against the first model.
        A check is skipped when either side of the comparison is None.

        Raises
        ------
        ValueError
            If models have incompatible parameters.
        """
        if len(self.models) < 2:
            return  # Single model always compatible with itself

        ref_model = self.models[0]
        ref_cell = ref_model.cell
        ref_sg = ref_model.spacegroup

        for i, model in enumerate(self.models[1:], start=1):
            # Check cell compatibility (allow small tolerance).
            # Compare the *current* model against the reference, and raise
            # instead of asserting so the check survives `python -O`.
            if ref_cell is not None and model.cell is not None:
                if not torch.allclose(
                    model.cell.data,
                    ref_cell.data,
                    atol=self._CELL_ATOL,
                    rtol=self._CELL_RTOL,
                ):
                    raise ValueError(
                        f"Model {i} has incompatible unit cell. "
                        f"Reference: {ref_cell.data.tolist()}, "
                        f"Model {i}: {model.cell.data.tolist()}"
                    )

            # Check spacegroup compatibility
            if ref_sg is not None and model.spacegroup is not None:
                if ref_sg.number != model.spacegroup.number:
                    raise ValueError(
                        f"Model {i} has incompatible spacegroup. "
                        f"Reference: {ref_sg.number}, Model {i}: {model.spacegroup.number}"
                    )

    @staticmethod
    def _weighted_sum(fractions, terms):
        """
        Return Σ fractions[i] * terms[i], accumulated in iteration order.

        ``terms`` may be a lazy iterable; each term is consumed exactly once.
        Returns None when ``terms`` is empty.
        """
        total = None
        for i, term in enumerate(terms):
            weighted = fractions[i] * term
            total = weighted if total is None else total + weighted
        return total

    @property
    def fractions(self) -> torch.Tensor:
        """
        Get normalized population fractions.

        Returns
        -------
        torch.Tensor
            Population fractions that sum to 1.0, shape (n_models,).
        """
        return torch.softmax(self.fraction_params, dim=0)

    @property
    def cell(self):
        """Unit cell from first model (for compatibility)."""
        return self.models[0].cell

    @property
    def spacegroup(self):
        """Space group from first model (for compatibility)."""
        return self.models[0].spacegroup

    @property
    def device(self):
        """Device from first model (for compatibility)."""
        return self.models[0].device

    @property
    def dtype_float(self):
        """Float dtype from first model (for compatibility)."""
        return self.models[0].dtype_float

    # =========================================================================
    # Grid infrastructure (delegates to constituent models)
    # =========================================================================

    @property
    def real_space_grid(self) -> Optional[torch.Tensor]:
        """Real-space coordinate grid from first model (shared cell → same grid)."""
        return self.models[0].real_space_grid

    @property
    def fft(self):
        """SfFFT submodule from first model (for gridsize access)."""
        return self.models[0].fft

    @property
    def gridsize(self) -> Optional[torch.Tensor]:
        """Grid dimensions (nx, ny, nz) from first model."""
        return self.models[0].gridsize

    @property
    def map_symmetry(self):
        """Map symmetry operator from first model."""
        return self.models[0].map_symmetry

    @property
    def inv_fractional_matrix(self) -> torch.Tensor:
        """Inverse fractionalization (orthogonalization) matrix."""
        return self.cell.inv_fractional_matrix.to(dtype=self.dtype_float)

    @property
    def fractional_matrix(self) -> torch.Tensor:
        """Fractionalization matrix."""
        return self.cell.fractional_matrix.to(dtype=self.dtype_float)

    def setup_grid(self, max_res=None, gridsize=None):
        """
        Setup real-space grid on all constituent models.

        All models share the same cell/spacegroup, so the grid is identical
        across all of them. This ensures each model's SfFFT is ready for
        density map calculations.

        Parameters
        ----------
        max_res : float, optional
            Maximum resolution for grid spacing in Angstroms.
        gridsize : tuple of int, optional
            Explicit grid size (nx, ny, nz).
        """
        for model in self.models:
            model.setup_grid(max_res=max_res, gridsize=gridsize)

    def get_radius(self, min_radius_Angstrom: float = 4.0) -> int:
        """
        Get the radius in voxels for density calculation.

        Delegates to first model (same grid → same voxel size).

        Parameters
        ----------
        min_radius_Angstrom : float, optional
            Minimum radius in Angstroms. Default is 4.0.

        Returns
        -------
        int
            Radius in voxels.
        """
        return self.models[0].get_radius(min_radius_Angstrom)

    def build_complete_map(self) -> torch.Tensor:
        """
        Build the mixed electron density map as the weighted sum of
        constituent model density maps.

        density_mixed = Σ w_i * density_i

        Each constituent model builds its own density map on the shared grid,
        and the results are combined using the current population fractions.

        Returns
        -------
        torch.Tensor
            Electron density map with shape (nx, ny, nz).
        """
        return self._weighted_sum(
            self.fractions,
            (model.build_complete_map() for model in self.models),
        )

    def freeze_fractions(self):
        """
        Exclude fractions from optimization.

        This prevents the population fractions from being updated during
        training while still allowing the constituent models to be refined.
        """
        self.fraction_params.requires_grad = False
        if self.verbose > 0:
            print("Fractions frozen")

    def unfreeze_fractions(self):
        """
        Include fractions in optimization.

        This allows the population fractions to be updated during training.
        """
        self.fraction_params.requires_grad = True
        if self.verbose > 0:
            print("Fractions unfrozen")

    def forward(self, hkl: torch.Tensor, recalc: bool = False) -> torch.Tensor:
        """
        Compute weighted sum of structure factors from all models.

        f_mixed = Σ w_i * f_i

        Parameters
        ----------
        hkl : torch.Tensor
            Miller indices with shape (n_reflections, 3).
        recalc : bool, optional
            If True, force recalculation of structure factors.
            Default is False.

        Returns
        -------
        torch.Tensor
            Mixed complex structure factors with shape (n_reflections,).
        """
        fractions = self.fractions

        # Compute structure factors from each model and weight them
        f_mixed = self._weighted_sum(
            fractions,
            (model(hkl, recalc=recalc) for model in self.models),
        )

        if self.verbose > 2:
            print(f"MixedModel forward: fractions = {fractions.detach().tolist()}")

        return f_mixed

    def get_individual_fcalc(
        self, hkl: torch.Tensor, recalc: bool = True
    ) -> List[torch.Tensor]:
        """
        Get structure factors from each model individually.

        Parameters
        ----------
        hkl : torch.Tensor
            Miller indices with shape (n_reflections, 3).
        recalc : bool, optional
            If True, force recalculation. Default is True.

        Returns
        -------
        List[torch.Tensor]
            List of structure factor tensors, one per model.
        """
        return [model(hkl, recalc=recalc) for model in self.models]

    def copy(self) -> "MixedModel":
        """
        Create a deep copy of the MixedModel.

        Returns
        -------
        MixedModel
            A new MixedModel instance with copied models and parameters.
        """
        # Deep copy each constituent model
        copied_models = [model.copy() for model in self.models]

        # Get current fractions
        with torch.no_grad():
            current_fractions = self.fractions.tolist()

        # Create new MixedModel
        copied = MixedModel(
            models=copied_models,
            initial_fractions=current_fractions,
            frozen_fractions=not self.fraction_params.requires_grad,
            verbose=self.verbose,
        )

        # The constructor rebuilds the parameters as log(clamp(fractions)),
        # which changes populations below 1e-6. Copy the raw parameters so the
        # new model reproduces the current fractions exactly.
        with torch.no_grad():
            copied.fraction_params.copy_(self.fraction_params)

        return copied

    def __repr__(self) -> str:
        """String representation."""
        fracs = self.fractions.detach().tolist()
        frac_str = ", ".join([f"{f:.3f}" for f in fracs])
        frozen_str = "frozen" if not self.fraction_params.requires_grad else "learnable"
        return f"MixedModel({len(self.models)} models, fractions=[{frac_str}], {frozen_str})"

    def write_ihm(
        self,
        filepath: str,
        state_names: Optional[List[str]] = None,
        group_name: str = "ensemble",
        datasets: Optional[dict] = None,
    ) -> None:
        """
        Write this MixedModel to an IHM mmCIF file.

        Creates a single model group with the current population fractions
        over the constituent structural states. Requires the optional
        ``python-ihm`` dependency.

        Parameters
        ----------
        filepath : str
            Output file path.
        state_names : list of str, optional
            Names for each constituent model / state. Entries missing from
            the list fall back to ``state_1``, ``state_2``, ...
        group_name : str
            Name for the model group. Default ``"ensemble"``.
        datasets : dict of str -> ReflectionData, optional
            Per-timepoint reflection data to embed in the CIF.
        """
        # Imported lazily so the IHM I/O modules are only needed when writing.
        from torchref.io.ihm import IHMWriter
        from torchref.io.ihm_mapping import (
            IHMEnsembleMapping,
            IHMModelGroupInfo,
            IHMStateInfo,
        )

        n = len(self.models)
        fracs = self.fractions.detach().cpu().tolist()

        # Build states (state ids and model numbers are 1-based)
        states = []
        for i in range(n):
            if state_names and i < len(state_names):
                name = state_names[i]
            else:
                name = f"state_{i + 1}"
            states.append(
                IHMStateInfo(state_id=i + 1, name=name, model_num=i + 1)
            )

        # Build single model group
        state_ids = [s.state_id for s in states]
        state_fractions = dict(zip(state_ids, fracs))
        groups = [
            IHMModelGroupInfo(
                group_id=1,
                name=group_name,
                state_fractions=state_fractions,
            )
        ]

        # Extract cell/spacegroup from first model
        cell = None
        spacegroup = None
        model0 = self.models[0]

        cell_obj = getattr(model0, "cell", None)
        if cell_obj is not None:
            if hasattr(cell_obj, "tolist"):
                cell = cell_obj.tolist()
            elif hasattr(cell_obj, "parameters"):
                cell = cell_obj.parameters.tolist()

        sg = getattr(model0, "spacegroup", None)
        if sg is not None:
            if hasattr(sg, "hm"):
                spacegroup = sg.hm
            elif hasattr(sg, "xhm"):
                spacegroup = sg.xhm()
            else:
                spacegroup = str(sg)

        mapping = IHMEnsembleMapping(
            states=states,
            model_groups=groups,
            cell=cell,
            spacegroup=spacegroup,
        )

        # Create a temporary ModelCollection-like wrapper for the writer
        writer = IHMWriter._from_mixed_model(self, mapping)
        writer.datasets = datasets
        writer.write(filepath)

    def get_vdw_radii(self) -> torch.Tensor:
        """
        Get van der Waals radii from the first model.

        Returns
        -------
        torch.Tensor
            Van der Waals radii tensor.
        """
        return self.models[0].get_vdw_radii()

    def xyz(self) -> torch.Tensor:
        """
        Get atomic coordinates from the first model.

        Returns
        -------
        torch.Tensor
            Atomic coordinates tensor.
        """
        return self.models[0].xyz()