torchref.model.mixed_model module
Mixed Model for Time-Resolved Crystallography.
This module provides a MixedModel class that combines multiple ModelFT objects with learnable population fractions, enabling refinement of time-resolved crystallographic data with multiple conformational states.
- class torchref.model.mixed_model.MixedModel(models, initial_fractions=None, frozen_fractions=False, verbose=0, device=None)
Bases: DeviceMixin, Module
Model wrapper combining N ModelFT objects with learnable population fractions.
Computes F_mixed = Σ w_i * F_i, where the w_i are learnable weights constrained to sum to 1 via a softmax.
This is useful for time-resolved crystallography where the crystal contains a mixture of conformational states (e.g., dark and light states) with unknown or refinable population fractions.
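A minimal sketch of the mixing rule in plain Python (no torch), assuming the weights are obtained as a softmax over raw parameters as described above. The helper names (softmax, mix_structure_factors) and the structure-factor values are hypothetical, chosen only for illustration; setting the raw parameters to log-fractions is likewise an assumption, not necessarily how MixedModel initializes them.

```python
import cmath
import math


def softmax(raw):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(raw)
    exps = [math.exp(r - m) for r in raw]
    total = sum(exps)
    return [e / total for e in exps]


def mix_structure_factors(raw_params, per_model_f):
    """F_mixed[h] = sum_i w_i * F_i[h], with w = softmax(raw_params)."""
    weights = softmax(raw_params)
    n_refl = len(per_model_f[0])
    return [
        sum(w * f_model[h] for w, f_model in zip(weights, per_model_f))
        for h in range(n_refl)
    ]


# Two hypothetical states, three reflections each (complex structure factors).
f_dark = [cmath.rect(10.0, 0.0), cmath.rect(5.0, 1.0), cmath.rect(2.0, -0.5)]
f_light = [cmath.rect(8.0, 0.3), cmath.rect(6.0, 1.2), cmath.rect(3.0, 0.4)]

# Assumed initialization: log-fractions, so that softmax(raw) gives [0.7, 0.3].
raw = [math.log(0.7), math.log(0.3)]
f_mixed = mix_structure_factors(raw, [f_dark, f_light])
```

The softmax keeps the optimization unconstrained: any real-valued raw parameters map to positive weights that sum to 1, so no projection step is needed after each update.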
- Parameters:
models (List[ModelFT]) – List of ModelFT objects to combine. All models must have compatible cell parameters and space groups.
initial_fractions (List[float], optional) – Initial population fractions for each model. Must sum to 1.0. If None, equal fractions are used (1/N for each model).
frozen_fractions (bool, optional) – If True, fractions are not updated during optimization. Default is False.
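verbose (int, optional) – Verbosity level. Default is 0.
device (torch.device, optional) – Device to place the model and parameters on. If None, the device is inferred from the first model.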
- models
Constituent ModelFT objects, registered as proper submodules.
- Type:
nn.ModuleList
- fraction_params
Unnormalized parameters from which the fractions are computed by applying a softmax.
- Type:
nn.Parameter
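The relationship between fraction_params and the normalized fractions can be sketched in plain Python. This assumes the raw parameters are set to log-fractions at initialization (a common choice, not confirmed by this page) and shows two properties that follow from the softmax itself: the raw values are only defined up to an additive constant, and any raw values yield valid fractions.

```python
import math


def softmax(raw):
    """Numerically stable softmax over a list of floats."""
    m = max(raw)
    exps = [math.exp(r - m) for r in raw]
    total = sum(exps)
    return [e / total for e in exps]


def raw_from_fractions(fractions):
    """Hypothetical helper: log-fractions as one valid set of raw parameters.

    Requires every fraction to be strictly positive (log(0) is undefined).
    """
    return [math.log(f) for f in fractions]


initial = [0.7, 0.2, 0.1]
raw = raw_from_fractions(initial)
recovered = softmax(raw)  # reproduces `initial` up to rounding

# Adding the same constant to every raw parameter leaves the fractions unchanged.
shifted = softmax([r + 5.0 for r in raw])

# Arbitrary, unconstrained raw values still give positive fractions summing to 1.
arbitrary = softmax([3.2, -1.0, 0.4])
```

Since a softmax output is never exactly 0, a fraction can approach zero during refinement but cannot reach it.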
Examples
Create a mixed model with two states:
    model_dark = ModelFT().load_pdb('dark.pdb')
    model_light = ModelFT().load_pdb('light.pdb')

    # 70% dark, 30% light
    mixed = MixedModel([model_dark, model_light], initial_fractions=[0.7, 0.3])

    # Compute mixed structure factors (hkl: Miller indices, shape (n_reflections, 3))
    F_mixed = mixed(hkl)

    # Get current fractions
    print(mixed.fractions)  # tensor([0.7, 0.3])
- __init__(models, initial_fractions=None, frozen_fractions=False, verbose=0, device=None)
Initialize MixedModel.
- Parameters:
models (List[ModelFT]) – List of ModelFT objects to combine.
initial_fractions (List[float], optional) – Initial population fractions. Must sum to 1.0.
frozen_fractions (bool, optional) – If True, fractions are frozen. Default is False.
verbose (int, optional) – Verbosity level. Default is 0.
device (torch.device, optional) – Device to place the model and parameters on. If None, the device is inferred from the first model.
- Raises:
ValueError – If models list is empty, fractions don’t match model count, fractions don’t sum to 1, or models have incompatible parameters.
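The documented ValueError cases can be sketched as a standalone check. This is a hypothetical helper, not the actual __init__ code: models are stood in for by dicts with "cell" and "spacegroup" keys, and the tolerance tol is an assumed value.

```python
import math


def check_mixed_inputs(models, initial_fractions=None, tol=1e-6):
    """Hypothetical validation mirroring the documented ValueError cases.

    `models` is a list of dicts with "cell" and "spacegroup" keys (a stand-in for
    ModelFT objects); `tol` is an assumed tolerance. Returns the fractions to use.
    """
    if not models:
        raise ValueError("models list is empty")
    # Compatibility: every model must match the first one's cell and space group.
    ref = models[0]
    for i, model in enumerate(models[1:], start=1):
        if model["spacegroup"] != ref["spacegroup"]:
            raise ValueError(f"model {i} has a different space group")
        if not all(math.isclose(a, b, abs_tol=tol) for a, b in zip(model["cell"], ref["cell"])):
            raise ValueError(f"model {i} has a different unit cell")
    n = len(models)
    if initial_fractions is None:
        return [1.0 / n] * n  # equal fractions, 1/N each
    if len(initial_fractions) != n:
        raise ValueError(f"expected {n} fractions, got {len(initial_fractions)}")
    total = sum(initial_fractions)
    if not math.isclose(total, 1.0, abs_tol=tol):
        raise ValueError(f"fractions sum to {total}, not 1.0")
    return list(initial_fractions)
```

Checking compatibility before the fractions means an incompatible model list is rejected even when initial_fractions is None.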
- property fractions: Tensor
Get normalized population fractions.
- Returns:
Population fractions that sum to 1.0, shape (n_models,).
- Return type:
torch.Tensor
- property cell
Unit cell from first model (for compatibility).
- property spacegroup
Space group from first model (for compatibility).
- property device
Device from first model (for compatibility).
- property dtype_float
Float dtype from first model (for compatibility).
- property real_space_grid: Tensor | None
Real-space coordinate grid from first model (shared cell → same grid).
- property fft
SfFFT submodule from first model (for gridsize access).
- property map_symmetry
Map symmetry operator from first model.
- setup_grid(max_res=None, gridsize=None)
Set up the real-space grid on all constituent models.
All models share the same cell/spacegroup, so the grid is identical across all of them. This ensures each model’s SfFFT is ready for density map calculations.
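The delegation can be sketched with stand-in objects. StubModel, its setup_grid method and the sampling rule inside it (about max_res / 3 Å per voxel) are hypothetical and only serve to show why a shared cell yields a shared grid, which is what allows the per-model maps to be combined voxel by voxel.

```python
import math


class StubModel:
    """Hypothetical stand-in for a ModelFT: holds a unit cell and a grid size."""

    def __init__(self, cell):
        self.cell = cell  # (a, b, c, alpha, beta, gamma), lengths in Å
        self.gridsize = None

    def setup_grid(self, max_res=None, gridsize=None):
        if gridsize is None:
            if max_res is None:
                raise ValueError("need max_res or an explicit gridsize")
            # Assumed sampling rule for illustration only: about max_res / 3 Å per voxel.
            spacing = max_res / 3.0
            gridsize = tuple(math.ceil(length / spacing) for length in self.cell[:3])
        self.gridsize = tuple(gridsize)


def setup_grid_on_all(models, max_res=None, gridsize=None):
    """Run setup_grid on every constituent model and check that the grids agree."""
    for model in models:
        model.setup_grid(max_res=max_res, gridsize=gridsize)
    sizes = {model.gridsize for model in models}
    if len(sizes) != 1:
        raise ValueError(f"constituent grids differ: {sorted(sizes)}")
    return sizes.pop()


cell = (40.0, 50.0, 60.0, 90.0, 90.0, 90.0)
models = [StubModel(cell), StubModel(cell)]
shared = setup_grid_on_all(models, max_res=2.0)
```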
- get_radius(min_radius_Angstrom=4.0)
Get the radius in voxels for density calculation.
Delegates to first model (same grid → same voxel size).
- build_complete_map()
Build the mixed electron density map as the weighted sum of constituent model density maps.
density_mixed = Σ w_i * density_i
Each constituent model builds its own density map on the shared grid, and the results are combined using the current population fractions.
- Returns:
Electron density map with shape (nx, ny, nz).
- Return type:
torch.Tensor
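The weighted map sum can be sketched in plain Python, with each map flattened to a list of voxel values on the shared (nx, ny, nz) grid. The helper name and the tiny 2×2×1 maps are hypothetical.

```python
def mix_density_maps(fractions, maps):
    """density_mixed = sum_i w_i * density_i, evaluated voxel by voxel.

    Each map is a flat list of voxel values on the same (nx, ny, nz) grid.
    """
    if len(fractions) != len(maps):
        raise ValueError("need one fraction per map")
    n_vox = len(maps[0])
    if any(len(m) != n_vox for m in maps):
        raise ValueError("maps must share the same grid")
    return [sum(w * m[v] for w, m in zip(fractions, maps)) for v in range(n_vox)]


# Hypothetical 2x2x1 grid (4 voxels) for a dark and a light state.
dark = [1.0, 0.0, 2.0, 0.5]
light = [0.0, 1.0, 2.0, 0.5]
mixed = mix_density_maps([0.7, 0.3], [dark, light])
```

Voxels where both states agree (the last two here) keep their value, because the fractions sum to 1; voxels where the states differ show partial occupancy.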
- freeze_fractions()
Exclude fractions from optimization.
This prevents the population fractions from being updated during training while still allowing the constituent models to be refined.
- unfreeze_fractions()
Include fractions in optimization.
This allows the population fractions to be updated during training.
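To make concrete what freezing changes, here is a plain-Python sketch of one gradient step on the raw fraction parameters, with the gradient propagated through the softmax. It assumes plain gradient descent and a hypothetical toy loss; the real optimizer and the mechanism MixedModel uses to exclude the parameters are not shown here.

```python
import math


def softmax(raw):
    """Numerically stable softmax over a list of floats."""
    m = max(raw)
    exps = [math.exp(r - m) for r in raw]
    total = sum(exps)
    return [e / total for e in exps]


def grad_raw_from_grad_w(raw, grad_w):
    """Chain rule through softmax: dL/dθ_j = w_j * (g_j - Σ_k w_k g_k)."""
    w = softmax(raw)
    mean_g = sum(wk * gk for wk, gk in zip(w, grad_w))
    return [wj * (gj - mean_g) for wj, gj in zip(w, grad_w)]


def sgd_step(raw, grad_w, lr=0.1, frozen=False):
    """One plain gradient-descent step on the raw fraction parameters.

    When frozen, the raw parameters (and hence the fractions) are returned unchanged.
    """
    if frozen:
        return list(raw)
    g = grad_raw_from_grad_w(raw, grad_w)
    return [r - lr * gj for r, gj in zip(raw, g)]


# Hypothetical loss pulling the first fraction toward 0.5: L(w) = (w_0 - 0.5)**2.
raw = [math.log(0.7), math.log(0.3)]
w = softmax(raw)
grad_w = [2.0 * (w[0] - 0.5), 0.0]
updated = sgd_step(raw, grad_w, lr=0.5)
unchanged = sgd_step(raw, grad_w, lr=0.5, frozen=True)
```

The gradient with respect to the raw parameters always sums to zero, reflecting the shift invariance of the softmax.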
- forward(hkl, recalc=False)
Compute weighted sum of structure factors from all models.
F_mixed = Σ w_i * F_i
- Parameters:
hkl (torch.Tensor) – Miller indices with shape (n_reflections, 3).
recalc (bool, optional) – If True, force recalculation of structure factors. Default is False.
- Returns:
Mixed complex structure factors with shape (n_reflections,).
- Return type:
torch.Tensor
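Because the weighted sum is taken over complex structure factors, amplitudes from states whose phases differ can partially cancel, so |F_mixed| is generally not the weighted average of the individual amplitudes. A plain-Python illustration with one hypothetical reflection:

```python
import cmath


def mixed_amplitude(fractions, f_per_model):
    """|Σ w_i F_i| for one reflection: states are summed as complex numbers."""
    return abs(sum(w * f for w, f in zip(fractions, f_per_model)))


def averaged_amplitude(fractions, f_per_model):
    """Σ w_i |F_i|: the weighted average of amplitudes, ignoring phases."""
    return sum(w * abs(f) for w, f in zip(fractions, f_per_model))


# One hypothetical reflection whose phase flips between the two states.
f_dark = cmath.rect(10.0, 0.0)        # amplitude 10, phase 0
f_light = cmath.rect(10.0, cmath.pi)  # amplitude 10, phase π
w = [0.5, 0.5]
```

Here mixed_amplitude is essentially zero while averaged_amplitude is 10, which is why the fractions have to be applied to complex F values rather than to |F|.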
- get_individual_fcalc(hkl, recalc=True)
Get structure factors from each model individually.
- Parameters:
hkl (torch.Tensor) – Miller indices with shape (n_reflections, 3).
recalc (bool, optional) – If True, force recalculation. Default is True.
- Returns:
List of structure factor tensors, one per model.
- Return type:
List[torch.Tensor]
- copy()
Create a deep copy of the MixedModel.
- Returns:
A new MixedModel instance with copied models and parameters.
- Return type:
MixedModel
- write_ihm(filepath, state_names=None, group_name='ensemble', datasets=None)
Write this MixedModel to an IHM mmCIF file.
Creates a single model group with the current population fractions over the constituent structural states.
Requires the optional python-ihm dependency.
- Parameters:
filepath (str) – Output file path.
state_names (list of str, optional) – Names for each constituent model / state. Default: state_1, state_2, …
group_name (str) – Name for the model group. Default: "ensemble".
datasets (dict of str -> ReflectionData, optional) – Per-timepoint reflection data to embed in the CIF.
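The default naming and the state/fraction pairing can be sketched without python-ihm. Both helpers are hypothetical: resolve_state_names reproduces the documented default (state_1, state_2, …), its length check is an assumption, and describe_group only shows which information ends up in the single model group, not the IHM data structures themselves.

```python
def resolve_state_names(n_models, state_names=None):
    """Return one name per state, defaulting to state_1, state_2, ... (1-based)."""
    if state_names is None:
        return [f"state_{i}" for i in range(1, n_models + 1)]
    # Assumed check: one name per constituent model.
    if len(state_names) != n_models:
        raise ValueError(f"expected {n_models} state names, got {len(state_names)}")
    return list(state_names)


def describe_group(fractions, state_names=None, group_name="ensemble"):
    """Hypothetical summary of the model group: each state paired with its fraction."""
    names = resolve_state_names(len(fractions), state_names)
    return {"group": group_name, "states": list(zip(names, fractions))}
```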
- get_vdw_radii()
Get van der Waals radii from the first model.
- Returns:
Van der Waals radii tensor.
- Return type:
torch.Tensor