torchref.model.mixed_model module

Mixed Model for Time-Resolved Crystallography.

This module provides a MixedModel class that combines multiple ModelFT objects with learnable population fractions, enabling refinement of time-resolved crystallographic data with multiple conformational states.

class torchref.model.mixed_model.MixedModel(models, initial_fractions=None, frozen_fractions=False, verbose=0, device=None)

Bases: DeviceMixin, Module

Model wrapper combining N ModelFT objects with learnable fractions.

Computes: F_mixed = Σ_i w_i * F_i, where the learnable weights w_i are obtained by applying a softmax to the raw parameters (fraction_params), so they are non-negative and sum to 1.

This is useful for time-resolved crystallography where the crystal contains a mixture of conformational states (e.g., dark and light states) with unknown or refinable population fractions.
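The mixing rule can be illustrated without torchref. The dependency-free sketch below assumes, for illustration only, that the raw parameters start as the logarithms of initial_fractions, so the softmax reproduces them exactly. The helpers softmax and mix_structure_factors are hypothetical and are not part of the torchref API.

```python
import math


def softmax(logits):
    """Numerically stable softmax: non-negative weights that sum to 1."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mix_structure_factors(fcalcs, logits):
    """Return F_mixed[j] = sum_i w_i * F_i[j] with w = softmax(logits).

    fcalcs: one list of complex structure factors per model, all of equal length.
    """
    weights = softmax(logits)
    n_refl = len(fcalcs[0])
    return [sum(w * f[j] for w, f in zip(weights, fcalcs)) for j in range(n_refl)]


# Assumption (not confirmed by the docs): raw parameters = log(initial fractions).
initial_fractions = [0.7, 0.3]
logits = [math.log(p) for p in initial_fractions]
weights = softmax(logits)  # approximately [0.7, 0.3]

# Toy structure factors for two reflections.
f_dark = [10 + 0j, 5 - 2j]
f_light = [4 + 3j, 5 + 2j]
f_mixed = mix_structure_factors([f_dark, f_light], logits)
```

Because the softmax is unchanged when the same constant is added to every raw parameter, only differences between the raw parameters matter, and the weights sum to 1 whatever values the optimizer reaches.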

Parameters:
  • models (List[ModelFT]) – List of ModelFT objects to combine. All models must have compatible cell parameters and space groups.

  • initial_fractions (List[float], optional) – Initial population fractions for each model. Must sum to 1.0. If None, equal fractions are used (1/N for each model).

  • frozen_fractions (bool, optional) – If True, fractions are not updated during optimization. Default is False.

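  • verbose (int, optional) – Verbosity level. Default is 0.

  • device (torch.device, optional) – Device to place the model and parameters on. If None, the device is inferred from the first model.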

models

Constituent ModelFT objects, held in an nn.ModuleList so that they are registered as submodules.

Type:

nn.ModuleList

fraction_params

Unnormalized parameters from which the fractions are computed by applying a softmax.

Type:

nn.Parameter

Examples

Create a mixed model with two states:

from torchref.model.mixed_model import MixedModel

model_dark = ModelFT().load_pdb('dark.pdb')
model_light = ModelFT().load_pdb('light.pdb')

# 70% dark, 30% light
mixed = MixedModel([model_dark, model_light], initial_fractions=[0.7, 0.3])

# Compute mixed structure factors; hkl holds Miller indices with shape (n_reflections, 3)
F_mixed = mixed(hkl)

# Get current fractions
print(mixed.fractions)  # approximately tensor([0.7, 0.3])

__init__(models, initial_fractions=None, frozen_fractions=False, verbose=0, device=None)

Initialize MixedModel.

Parameters:
  • models (List[ModelFT]) – List of ModelFT objects to combine.

  • initial_fractions (List[float], optional) – Initial population fractions. Must sum to 1.0.

  • frozen_fractions (bool, optional) – If True, fractions are frozen. Default is False.

  • verbose (int, optional) – Verbosity level. Default is 0.

  • device (torch.device, optional) – Device to place the model and parameters on. If None, the device is inferred from the first model.

Raises:

ValueError – If models list is empty, fractions don’t match model count, fractions don’t sum to 1, or models have incompatible parameters.
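The documented defaults for initial_fractions and the fraction-related ValueError cases can be summarized in a small helper. This is a hypothetical sketch, not torchref's implementation: the name resolve_fractions and the 1e-6 tolerance are assumptions, and the model-compatibility check is not covered.

```python
import math


def resolve_fractions(n_models, initial_fractions=None, tol=1e-6):
    """Hypothetical helper mirroring the documented rules for initial_fractions.

    Returns n_models fractions or raises ValueError. The tolerance `tol` is an
    assumption; the documentation only says the fractions must sum to 1.0.
    """
    if n_models == 0:
        raise ValueError("models list is empty")
    if initial_fractions is None:
        # Documented default: equal fractions, 1/N for each model.
        return [1.0 / n_models] * n_models
    if len(initial_fractions) != n_models:
        raise ValueError(
            f"got {len(initial_fractions)} fractions for {n_models} models"
        )
    total = sum(initial_fractions)
    if not math.isclose(total, 1.0, abs_tol=tol):
        raise ValueError(f"fractions sum to {total}, not 1.0")
    return [float(p) for p in initial_fractions]
```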

property fractions: Tensor

Get normalized population fractions.

Returns:

Population fractions that sum to 1.0, shape (n_models,).

Return type:

torch.Tensor

property cell

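Unit cell of the first model (exposed for interface compatibility with ModelFT).

property spacegroup

Space group of the first model (exposed for interface compatibility with ModelFT).

property device

Device of the first model (exposed for interface compatibility with ModelFT).

property dtype_float

Floating-point dtype of the first model (exposed for interface compatibility with ModelFT).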

property real_space_grid: Tensor | None

Real-space coordinate grid from first model (shared cell → same grid).

property fft

SfFFT submodule from first model (for gridsize access).

property gridsize: Tensor | None

Grid dimensions (nx, ny, nz) from first model.

property map_symmetry

Map symmetry operator from first model.

property inv_fractional_matrix: Tensor

Inverse fractionalization (orthogonalization) matrix.

property fractional_matrix: Tensor

Fractionalization matrix.

setup_grid(max_res=None, gridsize=None)

Set up the real-space grid on all constituent models.

All models share the same cell and space group, so the grid is identical across all of them. Setting it up on every model ensures that each model’s SfFFT is ready for density map calculations.

Parameters:
  • max_res (float, optional) – Maximum resolution for grid spacing in Angstroms.

  • gridsize (tuple of int, optional) – Explicit grid size (nx, ny, nz).
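The docstring does not say how max_res is turned into a grid size. A common convention in crystallographic FFT codes, assumed here and not confirmed for torchref, is to sample at a spacing of about max_res/3 and round each dimension up to a size whose only prime factors are 2, 3 and 5. The helper names are hypothetical, and the sketch ignores extra constraints a real implementation may impose, such as divisibility required by space-group symmetry.

```python
import math


def fft_friendly(n):
    """Smallest integer >= n whose only prime factors are 2, 3 and 5."""
    n = max(1, int(n))
    while True:
        m = n
        for p in (2, 3, 5):
            while m % p == 0:
                m //= p
        if m == 1:
            return n
        n += 1


def grid_from_resolution(cell_lengths, max_res, oversampling=3.0):
    """Hypothetical: choose (nx, ny, nz) so the spacing is at most max_res / oversampling.

    Computes a * oversampling / max_res instead of a / spacing to avoid
    rounding surprises before math.ceil.
    """
    return tuple(
        fft_friendly(math.ceil(a * oversampling / max_res)) for a in cell_lengths
    )


print(grid_from_resolution((50.0, 61.0, 70.0), max_res=2.0))
```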

get_radius(min_radius_Angstrom=4.0)

Get the radius in voxels for density calculation.

Delegates to first model (same grid → same voxel size).

Parameters:

min_radius_Angstrom (float, optional) – Minimum radius in Angstroms. Default is 4.0.

Returns:

Radius in voxels.

Return type:

int
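The docstring only states that the radius is returned in voxels. One plausible conversion is sketched below under two assumptions that are not documented torchref behaviour: the voxel edge along each axis is the cell length divided by the number of grid points, and the count is rounded up using the smallest edge so the radius covers at least min_radius_Angstrom along every axis. The function name is hypothetical.

```python
import math


def radius_in_voxels(cell_lengths, gridsize, min_radius_angstrom=4.0):
    """Hypothetical conversion of a radius in Angstroms to a whole number of voxels."""
    # Voxel edge length along each cell axis, in Angstroms.
    voxel_edges = [a / n for a, n in zip(cell_lengths, gridsize)]
    # The smallest edge needs the most voxels, so it bounds the radius from above.
    return math.ceil(min_radius_angstrom / min(voxel_edges))
```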

build_complete_map()

Build the mixed electron density map as the weighted sum of constituent model density maps.

density_mixed = Σ w_i * density_i

Each constituent model builds its own density map on the shared grid, and the results are combined using the current population fractions.

Returns:

Electron density map with shape (nx, ny, nz).

Return type:

torch.Tensor
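Because the Fourier transform is linear, mixing the density maps with weights w_i gives the same result as transforming the mixed structure factors, provided each model's density and its structure factors form a Fourier pair. The self-contained check below illustrates this with a naive 1D inverse DFT on toy data; it is a sketch of the underlying mathematics, not torchref code.

```python
import cmath


def inverse_dft(coeffs):
    """Naive 1D inverse DFT: rho[x] = (1/N) * sum_h F[h] * exp(2*pi*i*h*x/N)."""
    n = len(coeffs)
    return [
        sum(f * cmath.exp(2j * cmath.pi * h * x / n) for h, f in enumerate(coeffs)) / n
        for x in range(n)
    ]


weights = [0.7, 0.3]
# Toy coefficients with F[3] = conj(F[1]) and F[0], F[2] real, so the maps are real.
f_dark = [10 + 0j, 2 - 1j, 0.5 + 0j, 2 + 1j]
f_light = [8 + 0j, -1 + 2j, 1 + 0j, -1 - 2j]

# Route 1: build each map, then mix the maps.
rho_dark = inverse_dft(f_dark)
rho_light = inverse_dft(f_light)
rho_mix_maps = [weights[0] * a + weights[1] * b for a, b in zip(rho_dark, rho_light)]

# Route 2: mix the structure factors first, then transform once.
f_mix = [weights[0] * a + weights[1] * b for a, b in zip(f_dark, f_light)]
rho_mix_sf = inverse_dft(f_mix)

max_diff = max(abs(a - b) for a, b in zip(rho_mix_maps, rho_mix_sf))
print(max_diff)
```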

freeze_fractions()

Exclude fractions from optimization.

This prevents the population fractions from being updated during training while still allowing the constituent models to be refined.

unfreeze_fractions()

Include fractions in optimization.

This allows the population fractions to be updated during training.

forward(hkl, recalc=False)

Compute the weighted sum of structure factors from all models.

F_mixed = Σ w_i * F_i

Parameters:
  • hkl (torch.Tensor) – Miller indices with shape (n_reflections, 3).

  • recalc (bool, optional) – If True, force recalculation of structure factors. Default is False.

Returns:

Mixed complex structure factors with shape (n_reflections,).

Return type:

torch.Tensor

get_individual_fcalc(hkl, recalc=True)

Get structure factors from each model individually.

Parameters:
  • hkl (torch.Tensor) – Miller indices with shape (n_reflections, 3).

  • recalc (bool, optional) – If True, force recalculation. Default is True.

Returns:

List of structure factor tensors, one per model.

Return type:

List[torch.Tensor]

copy()

Create a deep copy of the MixedModel.

Returns:

A new MixedModel instance with copied models and parameters.

Return type:

MixedModel

__repr__()

String representation.

write_ihm(filepath, state_names=None, group_name='ensemble', datasets=None)

Write this MixedModel to an IHM mmCIF file.

Creates a single model group with the current population fractions over the constituent structural states.

Requires the optional python-ihm dependency.

Parameters:
  • filepath (str) – Output file path.

  • state_names (list of str, optional) – Names for each constituent model / state. Default: state_1, state_2, …

  • group_name (str, optional) – Name for the model group. Default is "ensemble".

  • datasets (dict of str -> ReflectionData, optional) – Per-timepoint reflection data to embed in the CIF.
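How write_ihm lays out the model group is not specified beyond recording the current population fractions over the constituent states. The hypothetical helper below only illustrates the documented default naming (state_1, state_2, and so on) and the pairing of one name with each fraction; it does not use python-ihm, and the ValueError on a length mismatch is an assumption.

```python
def ensemble_members(fractions, state_names=None):
    """Hypothetical: pair each state name with its population fraction."""
    if state_names is None:
        # Documented default names: state_1, state_2, ...
        state_names = [f"state_{i + 1}" for i in range(len(fractions))]
    if len(state_names) != len(fractions):
        # Assumed behaviour: one name is required per constituent model.
        raise ValueError("need one state name per constituent model")
    return list(zip(state_names, fractions))


print(ensemble_members([0.7, 0.3]))
```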

get_vdw_radii()

Get van der Waals radii from the first model.

Returns:

Van der Waals radii tensor.

Return type:

torch.Tensor

xyz()

Get atomic coordinates from the first model.

Returns:

Atomic coordinates tensor.

Return type:

torch.Tensor