torchref.model package
Atomic structure model module for TorchRef.
This module provides PyTorch nn.Module-based representations of crystallographic atomic models, including coordinates, B-factors, occupancies, and anisotropic displacement parameters.
Classes
- Model
Base atomic model storing xyz coordinates, B-factors, occupancies.
- ModelFT
Fourier Transform model for FFT-based structure factor calculation.
- SfFFT
Structure Factor calculator using FFT (Fast Fourier Transform).
- SfDS
Structure Factor calculator using Direct Summation.
- FFT
Backward compatibility alias for SfFFT.
- MixedTensor
Hybrid tensor allowing partial freezing of parameters.
- PositiveMixedTensor
MixedTensor with positivity constraint.
- PassThroughTensor
Direct parameter access wrapper.
- OccupancyTensor
Tensor constrained to [0, 1] range for occupancies.
Example
from torchref.model import Model, ModelFT, MixedTensor, SfFFT, SfDS
# Load model from PDB
model = Model()
model.load_pdb('structure.pdb')
# Access coordinates and B-factors
xyz = model.xyz # (N, 3) tensor
adp = model.adp # (N,) tensor of isotropic B-factors
# Use ModelFT for FFT-based structure factors
model_ft = ModelFT(data, device='cuda')
F_calc = model_ft.get_F_calc()
# Use SfFFT standalone for custom workflows
sf_fft = SfFFT(cell, spacegroup, max_res=1.5)
sf_fft.setup_grid()
sf = sf_fft.map_to_structure_factors(density_map, hkl)
# Use SfDS for direct summation
sf_ds = SfDS(cell, spacegroup)
sf, _ = sf_ds.compute_structure_factors(hkl, xyz, adp, occ, A, B)
- class torchref.model.SfFFT(cell=None, spacegroup=None, max_res=1.5, radius_angstrom=3.0, dtype_float=torch.float32, device=None, verbose=0, use_late_symmetry=True)[source]
Bases:
DeviceMixin, Module
Structure Factor calculator using FFT (Fast Fourier Transform).
This module encapsulates all FFT-related functionality for computing electron density maps and structure factors. It is initialized with a Cell and optionally a SpaceGroup, which are used for grid calculations.
- Parameters:
cell (Cell) – Unit cell object containing cell parameters.
spacegroup (SpaceGroupLike, optional) – Space group specification (string, int, or gemmi.SpaceGroup). If None, defaults to P1.
max_res (float, optional) – Maximum resolution for grid spacing in Angstroms. Default is 1.5.
radius_angstrom (float, optional) – Radius in Angstroms for density calculation around each atom. Default is 3.0.
dtype_float (torch.dtype, optional) – Data type for floating point tensors. Default is dtypes.float.
device (torch.device, optional) – Computation device. Defaults to the configured device.current.
verbose (int, optional) – Verbosity level for logging. Default is 0.
- spacegroup
Space group object (SpaceGroup nn.Module with matrices and translations).
- Type:
SpaceGroup or None
- symmetry
Alias for spacegroup (backward compatibility).
- Type:
SpaceGroup or None
- gridsize
Grid dimensions (nx, ny, nz) when grid is set up.
- Type:
torch.Tensor or None
- real_space_grid
Real-space coordinate grid with shape (nx, ny, nz, 3).
- Type:
torch.Tensor or None
- voxel_size
Voxel dimensions.
- Type:
torch.Tensor or None
- map_symmetry
Symmetry operator for map calculations.
- Type:
MapSymmetry or None
Examples
Standalone usage:
from torchref.model import SfFFT
from torchref.symmetry import Cell
cell = Cell([50, 60, 70, 90, 90, 90])
sf_fft = SfFFT(cell, spacegroup='P212121', max_res=1.5)
sf_fft.setup_grid()
density_map = sf_fft.build_density_map(xyz, b, occ, A, B)
sf = sf_fft.map_to_structure_factors(density_map, hkl)
With ModelFT (composition):
model = ModelFT()
model.load_pdb('structure.pdb')
sf = model.get_structure_factor(hkl)  # Uses internal SfFFT instance
- __init__(cell=None, spacegroup=None, max_res=1.5, radius_angstrom=3.0, dtype_float=torch.float32, device=None, verbose=0, use_late_symmetry=True)[source]
Initialize the SfFFT module with cell and spacegroup.
- Parameters:
cell (Cell, optional) – Unit cell object. If None, must be set later via set_cell().
spacegroup (SpaceGroupLike, optional) – Space group specification. If None, defaults to P1.
max_res (float, optional) – Maximum resolution for grid spacing in Angstroms. Default is 1.5.
radius_angstrom (float, optional) – Radius in Angstroms for density calculation. Default is 3.0.
dtype_float (torch.dtype, optional) – Data type for floating point tensors. Default is dtypes.float.
device (torch.device, optional) – Computation device. Default is None (uses cell’s device). If Cell is also None, defaults to CPU.
verbose (int, optional) – Verbosity level for logging. Default is 0.
use_late_symmetry (bool, optional) – If True (default), apply symmetry in reciprocal space after FFT (“late symmetry”) for faster structure factor calculation (~5x speedup). If False, apply symmetry to density map before FFT (“early symmetry”).
- property spacegroup: SpaceGroup | None
Space group object (SpaceGroup nn.Module).
- property symmetry: SpaceGroup | None
Symmetry operations handler (alias for spacegroup).
- property fractional_matrix: Tensor | None
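Get the orthogonalization matrix (fractional -> Cartesian, as in Model.fractional_matrix) from cell, on this module’s device/dtype.
- property inv_fractional_matrix: Tensor | None
Get the fractionalization matrix (Cartesian -> fractional, as in Model.inv_fractional_matrix) from cell, on this module’s device/dtype.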
- set_cell_and_spacegroup(cell, spacegroup=None)[source]
Set cell and spacegroup for this SfFFT instance.
- Parameters:
cell (Cell) – Unit cell object.
spacegroup (SpaceGroupLike, optional) – Space group specification.
- compute_optimal_gridsize(max_res=None)[source]
Compute optimal grid dimensions using the stored cell and spacegroup.
Uses Cell.compute_grid_size() for base calculation and Symmetry.suggest_grid_size() for symmetry optimization.
- Parameters:
max_res (float, optional) – Maximum resolution in Angstroms. If None, uses self.max_res.
- Returns:
Optimal grid dimensions (nx, ny, nz).
- Return type:
- Raises:
RuntimeError – If cell has not been set.
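As a rough illustration of what a grid-size calculation involves, the sketch below picks the number of grid points along each axis so that the spacing is at most max_res divided by a sampling factor, then rounds up to a size with only small prime factors, which FFT libraries handle efficiently. The function names, the sampling factor of 3, and the prime set (2, 3, 5) are assumptions made for this example. They are not taken from Cell.compute_grid_size() or Symmetry.suggest_grid_size(), and symmetry constraints on the grid are ignored.

```python
import math


def is_fft_friendly(n, primes=(2, 3, 5)):
    """Return True if n has no prime factors outside `primes`."""
    for p in primes:
        while n % p == 0:
            n //= p
    return n == 1


def sketch_grid_size(cell_lengths, max_res, sampling=3.0):
    """Hypothetical helper: grid points per axis for a target resolution."""
    spacing = max_res / sampling  # target grid spacing in Angstroms
    grid = []
    for length in cell_lengths:
        n = max(1, math.ceil(length / spacing))
        while not is_fft_friendly(n):  # round up to an FFT-friendly size
            n += 1
        grid.append(n)
    return tuple(grid)


gridsize = sketch_grid_size((50.0, 60.0, 70.0), max_res=1.5)
print(gridsize)
```

With these assumed settings the c axis is rounded up from 140 points, because 140 contains the prime factor 7.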
- static compute_real_space_grid(fractional_matrix, gridsize, device=device(type='cpu'))[source]
Generate the real-space coordinate grid.
- Parameters:
fractional_matrix (torch.Tensor) – Cell matrix used to place the grid points in real space (see the fractional_matrix property).
gridsize (torch.Tensor) – Grid dimensions (nx, ny, nz).
device (torch.device, optional) – Target device. Default is CPU.
- Returns:
Real-space grid with shape (nx, ny, nz, 3).
- Return type:
torch.Tensor
- setup_grid(gridsize=None, max_res=None)[source]
Setup the real-space grid for electron density calculation.
This method initializes and stores the grid state for subsequent density map calculations. Uses the stored cell and spacegroup.
- Parameters:
- Raises:
RuntimeError – If cell has not been set.
- build_density_map(xyz_iso, adp_iso, occ_iso, A_iso, B_iso, xyz_aniso=None, u_aniso=None, occ_aniso=None, A_aniso=None, B_aniso=None, apply_symmetry=True)[source]
Build electron density map from atomic parameters.
This method requires setup_grid() to have been called first.
- Parameters:
xyz_iso (torch.Tensor) – Isotropic atom coordinates with shape (n_iso, 3).
adp_iso (torch.Tensor) – Isotropic ADPs (atomic displacement parameters) with shape (n_iso,).
occ_iso (torch.Tensor) – Isotropic occupancies with shape (n_iso,).
A_iso (torch.Tensor) – ITC92 A parameters for isotropic atoms with shape (n_iso, 5).
B_iso (torch.Tensor) – ITC92 B parameters for isotropic atoms with shape (n_iso, 5).
xyz_aniso (torch.Tensor, optional) – Anisotropic atom coordinates with shape (n_aniso, 3).
u_aniso (torch.Tensor, optional) – Anisotropic U parameters with shape (n_aniso, 6).
occ_aniso (torch.Tensor, optional) – Anisotropic occupancies with shape (n_aniso,).
A_aniso (torch.Tensor, optional) – ITC92 A parameters for anisotropic atoms with shape (n_aniso, 5).
B_aniso (torch.Tensor, optional) – ITC92 B parameters for anisotropic atoms with shape (n_aniso, 5).
apply_symmetry (bool, optional) – If True, apply crystallographic symmetry to the map. Default is True.
- Returns:
Electron density map with shape (nx, ny, nz).
- Return type:
torch.Tensor
- Raises:
RuntimeError – If setup_grid() has not been called.
- map_to_structure_factors(density_map, hkl, apply_symmetry=True)[source]
Convert density map to structure factors via FFT.
- Parameters:
density_map (torch.Tensor) – Electron density map with shape (nx, ny, nz). If apply_symmetry=True, this should be a P1 density map.
hkl (torch.Tensor) – Miller indices with shape (n_reflections, 3).
apply_symmetry (bool, optional) – If True (default) and late symmetry is enabled/compatible, apply symmetry in reciprocal space. If False, assume the map already has symmetry applied (early symmetry path).
- Returns:
Complex structure factors with shape (n_reflections,).
- Return type:
torch.Tensor
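To make the map-to-structure-factor step concrete, here is a minimal one-dimensional sketch in plain Python. It uses a naive discrete Fourier transform instead of an FFT, and assumes the crystallographic sign convention with a positive exponent. The function names are hypothetical and this is not the SfFFT implementation. The point it illustrates is that a negative Miller index is looked up modulo the grid size.

```python
import cmath


def dft_crystallographic(density):
    """Naive 1D transform F[k] = sum_j rho[j] * exp(+2*pi*i*k*j/n)."""
    n = len(density)
    return [
        sum(rho * cmath.exp(2j * cmath.pi * k * j / n)
            for j, rho in enumerate(density))
        for k in range(n)
    ]


def sample_at_miller(transform, h):
    """Look up index h; negative indices wrap around modulo the grid size."""
    return transform[h % len(transform)]


# Point-like density at grid point 2 of an 8-point grid (fractional x = 0.25).
density = [0.0] * 8
density[2] = 1.0
transform = dft_crystallographic(density)

f_plus = sample_at_miller(transform, 1)    # exp(+2*pi*i * 0.25) = i
f_minus = sample_at_miller(transform, -1)  # exp(-2*pi*i * 0.25) = -i
```

For a real-valued density the two values are complex conjugates of each other (Friedel's law).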
- compute_structure_factors(hkl, xyz_iso, adp_iso, occ_iso, A_iso, B_iso, xyz_aniso=None, u_aniso=None, occ_aniso=None, A_aniso=None, B_aniso=None, apply_symmetry=True)[source]
Compute structure factors from atomic parameters (end-to-end).
This is a convenience method that builds the density map and computes structure factors in one call.
When use_late_symmetry=True (default) and the grid is compatible, symmetry is applied in reciprocal space after FFT for ~5x speedup. Otherwise, symmetry is applied to the density map before FFT.
- Parameters:
hkl (torch.Tensor) – Miller indices with shape (n_reflections, 3).
xyz_iso (torch.Tensor) – Isotropic atom coordinates with shape (n_iso, 3).
adp_iso (torch.Tensor) – Isotropic ADPs (atomic displacement parameters) with shape (n_iso,).
occ_iso (torch.Tensor) – Isotropic occupancies with shape (n_iso,).
A_iso (torch.Tensor) – ITC92 A parameters for isotropic atoms.
B_iso (torch.Tensor) – ITC92 B parameters for isotropic atoms.
xyz_aniso (torch.Tensor, optional) – Anisotropic atom coordinates.
u_aniso (torch.Tensor, optional) – Anisotropic U parameters.
occ_aniso (torch.Tensor, optional) – Anisotropic occupancies.
A_aniso (torch.Tensor, optional) – ITC92 A parameters for anisotropic atoms.
B_aniso (torch.Tensor, optional) – ITC92 B parameters for anisotropic atoms.
apply_symmetry (bool, optional) – If True, apply crystallographic symmetry. Default is True.
- Returns:
sf (torch.Tensor) – Complex structure factors with shape (n_reflections,).
density_map (torch.Tensor) – Electron density map with shape (nx, ny, nz). Note: When using late symmetry, this is the P1 map (without symmetry).
- Return type:
- class torchref.model.SfDS(cell=None, spacegroup=None, dtype_float=torch.float32, device=device(type='cpu'), verbose=0, max_memory_gb=2.0)[source]
Bases:
DeviceMixin, Module
Structure Factor calculator using Direct Summation.
This module computes structure factors by directly summing atomic contributions without building an intermediate electron density map. It is initialized with a Cell and optionally a SpaceGroup.
Includes automatic batching to handle memory constraints for large structures or high-resolution data.
- Parameters:
cell (Cell, optional) – Unit cell object containing cell parameters.
spacegroup (SpaceGroupLike, optional) – Space group specification (string, int, or gemmi.SpaceGroup). If None, defaults to P1.
dtype_float (torch.dtype, optional) – Data type for floating point tensors. Default is dtypes.float.
device (torch.device, optional) – Computation device. Defaults to the configured device.current.
verbose (int, optional) – Verbosity level for logging. Default is 0.
max_memory_gb (float, optional) – Maximum memory to use for intermediate tensors in GB. Default is 2.0. Set to None to disable batching.
- spacegroup
Space group object (SpaceGroup nn.Module with matrices and translations).
- Type:
SpaceGroup or None
Examples
Standalone usage:
from torchref.symmetry import Cell
cell = Cell([50, 60, 70, 90, 90, 90])
sf_ds = SfDS(cell, spacegroup='P212121')
sf, _ = sf_ds.compute_structure_factors(
    hkl, xyz_iso, adp_iso, occ_iso, A_iso, B_iso
)
With memory limit for large structures:
sf_ds = SfDS(cell, spacegroup='P212121', max_memory_gb=4.0)
sf, _ = sf_ds.compute_structure_factors(...)  # Auto-batches if needed
Notes
Key differences from SfFFT:
- No grid setup required
- No build_density_map() or map_to_structure_factors() methods
- Computes scattering factors internally from A/B ITC92 coefficients
- Returns (sf, None) instead of (sf, density_map)
- Automatic batching for memory management
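The idea of direct summation can be sketched in a few lines of plain Python. The example assumes an orthogonal cell, fractional coordinates, isotropic B-factors, and a form factor written as a sum of five Gaussians built from A (amplitudes) and B (widths). The function name, the dictionary layout, and the coefficient values are all made up for illustration. SfDS itself takes Cartesian coordinates and tensors, and its internals may differ.

```python
import cmath
import math


def direct_sum_sf(hkl, cell_lengths, atoms):
    """F(h) = sum_j occ_j * f_j(s) * exp(-B_j * s^2 / 4) * exp(2*pi*i * h.x_j).

    Assumes an orthogonal cell, fractional coordinates and isotropic B-factors.
    """
    a, b, c = cell_lengths
    result = []
    for h, k, l in hkl:
        s2 = (h / a) ** 2 + (k / b) ** 2 + (l / c) ** 2  # s = 1/d
        total = 0j
        for atom in atoms:
            # Gaussian form factor from A (amplitudes) and B (widths).
            f = sum(A * math.exp(-B * s2 / 4.0)
                    for A, B in zip(atom["A"], atom["B"]))
            debye_waller = math.exp(-atom["b_iso"] * s2 / 4.0)
            x, y, z = atom["frac"]
            phase = cmath.exp(2j * math.pi * (h * x + k * y + l * z))
            total += atom["occ"] * f * debye_waller * phase
        result.append(total)
    return result


# Made-up coefficients for illustration only, not real ITC92 values.
atom = {"frac": (0.1, 0.2, 0.3), "occ": 0.5, "b_iso": 20.0,
        "A": (2.0, 1.5, 1.0, 1.0, 0.5), "B": (10.0, 5.0, 1.0, 30.0, 0.0)}
sf = direct_sum_sf([(0, 0, 0), (1, 2, 3)], (50.0, 60.0, 70.0), [atom])
```

At hkl = (0, 0, 0) every exponential equals 1, so the result reduces to the occupancy times the sum of the A coefficients. The cost grows with the number of reflections times the number of atoms, which is why batching matters for large structures.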
- __init__(cell=None, spacegroup=None, dtype_float=torch.float32, device=device(type='cpu'), verbose=0, max_memory_gb=2.0)[source]
Initialize the SfDS module with cell and spacegroup.
- Parameters:
cell (Cell, optional) – Unit cell object. If None, must be set later.
spacegroup (SpaceGroupLike, optional) – Space group specification. If None, defaults to P1.
dtype_float (torch.dtype, optional) – Data type for floating point tensors. Default is dtypes.float.
device (torch.device, optional) – Computation device. Defaults to the configured device.current.
verbose (int, optional) – Verbosity level for logging. Default is 0.
max_memory_gb (float, optional) – Maximum memory for intermediate tensors in GB. Default is 2.0.
- property spacegroup: SpaceGroup | None
Space group object (SpaceGroup nn.Module).
- set_cell_and_spacegroup(cell, spacegroup=None)[source]
Set cell and spacegroup for this SfDS instance.
- Parameters:
cell (Cell) – Unit cell object.
spacegroup (SpaceGroupLike, optional) – Space group specification.
- compute_structure_factors(hkl, xyz_iso, adp_iso, occ_iso, A_iso, B_iso, xyz_aniso=None, u_aniso=None, occ_aniso=None, A_aniso=None, B_aniso=None, apply_symmetry=True)[source]
Compute structure factors from atomic parameters using direct summation.
Uses “late symmetry” approach (same as SfFFT): first computes P1 structure factors at symmetry-equivalent HKLs, then combines them with phase shifts.
- The symmetry formula is:
F_sym(h) = Σ_ops exp(2πi h.t) * F_P1(R^T @ h)
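The formula can be checked numerically with a small plain-Python sketch. It uses unit point scatterers at fractional positions and the two operators of P2_1 (unique axis b), then compares the late-symmetry result with a P1 sum over the symmetry-expanded atoms. All function names are hypothetical and the code is only meant to illustrate the identity, not to mirror the SfDS implementation.

```python
import cmath
import math


def f_p1(h, atoms):
    """P1 structure factor of unit point scatterers at fractional positions."""
    return sum(
        cmath.exp(2j * math.pi * sum(hi * xi for hi, xi in zip(h, x)))
        for x in atoms
    )


def apply_op(rot, trans, x):
    """Symmetry mate x' = R x + t."""
    return tuple(
        sum(rot[i][j] * x[j] for j in range(3)) + trans[i] for i in range(3)
    )


def rotate_hkl(rot, h):
    """Return R^T h."""
    return tuple(sum(rot[j][i] * h[j] for j in range(3)) for i in range(3))


def f_late_symmetry(h, atoms, ops):
    """F_sym(h) = sum over ops of exp(2*pi*i h.t) * F_P1(R^T h)."""
    total = 0j
    for rot, trans in ops:
        phase = cmath.exp(
            2j * math.pi * sum(hi * ti for hi, ti in zip(h, trans))
        )
        total += phase * f_p1(rotate_hkl(rot, h), atoms)
    return total


# Two operators of P2_1 (unique axis b): identity and the 2_1 screw axis.
ops = [
    (((1, 0, 0), (0, 1, 0), (0, 0, 1)), (0.0, 0.0, 0.0)),
    (((-1, 0, 0), (0, 1, 0), (0, 0, -1)), (0.0, 0.5, 0.0)),
]
asu_atoms = [(0.10, 0.20, 0.30), (0.40, 0.15, 0.75)]
h = (1, 2, -3)

late = f_late_symmetry(h, asu_atoms, ops)
# Reference: expand the atoms to the full cell first, then sum in P1.
expanded = [apply_op(rot, trans, x) for rot, trans in ops for x in asu_atoms]
early = f_p1(h, expanded)
```

Both routes give the same complex number up to rounding, because h.(R x + t) = (R^T h).x + h.t.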
- Parameters:
hkl (torch.Tensor) – Miller indices with shape (n_reflections, 3).
xyz_iso (torch.Tensor) – Isotropic atom coordinates (Cartesian) with shape (n_iso, 3).
adp_iso (torch.Tensor) – Isotropic ADPs (atomic displacement parameters) with shape (n_iso,).
occ_iso (torch.Tensor) – Isotropic occupancies with shape (n_iso,).
A_iso (torch.Tensor) – ITC92 A parameters for isotropic atoms with shape (n_iso, 5).
B_iso (torch.Tensor) – ITC92 B parameters for isotropic atoms with shape (n_iso, 5).
xyz_aniso (torch.Tensor, optional) – Anisotropic atom coordinates (Cartesian) with shape (n_aniso, 3).
u_aniso (torch.Tensor, optional) – Anisotropic U parameters with shape (n_aniso, 6).
occ_aniso (torch.Tensor, optional) – Anisotropic occupancies with shape (n_aniso,).
A_aniso (torch.Tensor, optional) – ITC92 A parameters for anisotropic atoms with shape (n_aniso, 5).
B_aniso (torch.Tensor, optional) – ITC92 B parameters for anisotropic atoms with shape (n_aniso, 5).
apply_symmetry (bool, optional) – If True, apply crystallographic symmetry. Default is True.
- Returns:
sf (torch.Tensor) – Complex structure factors with shape (n_reflections,).
None – Second return value is None (for API compatibility with SfFFT).
- Return type:
- class torchref.model.InternalCoordinateTensor(initial_xyz, bond_cutoff=2.0, requires_grad=True, dtype=None, device=None)[source]
Bases:
DeviceMixin, Module
Parameter wrapper using internal coordinates (Z-matrix style).
Stores: bond_lengths, angles, torsions, chain_positions, chain_orientations
Reconstructs: Cartesian xyz on forward()
This provides a physically meaningful parametrization of atomic coordinates where perturbations correspond to changes in bond lengths, angles, and torsion angles rather than arbitrary Cartesian displacements.
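For readers unfamiliar with Z-matrix style coordinates, the sketch below places one atom from a bond length, a bond angle, and a torsion angle relative to three previously placed atoms, then measures the torsion back from the result. It is written in plain Python with hypothetical helper names and follows the common "natural extension reference frame" construction. It is not the reconstruction code used by InternalCoordinateTensor.

```python
import math


def sub(u, v):
    return tuple(a - b for a, b in zip(u, v))


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def cross(u, v):
    return (u[1] * v[2] - u[2] * v[1],
            u[2] * v[0] - u[0] * v[2],
            u[0] * v[1] - u[1] * v[0])


def unit(u):
    norm = math.sqrt(dot(u, u))
    return tuple(a / norm for a in u)


def place_atom(a, b, c, bond, angle, torsion):
    """Place atom d from |c-d|, angle b-c-d and torsion a-b-c-d (radians)."""
    bc = unit(sub(c, b))
    n = unit(cross(sub(b, a), bc))  # normal of the a-b-c plane
    m = cross(n, bc)                # completes the local frame (bc, m, n)
    local = (-bond * math.cos(angle),
             bond * math.sin(angle) * math.cos(torsion),
             bond * math.sin(angle) * math.sin(torsion))
    return tuple(
        c[i] + local[0] * bc[i] + local[1] * m[i] + local[2] * n[i]
        for i in range(3)
    )


def dihedral(a, b, c, d):
    """Signed torsion angle a-b-c-d in radians."""
    b1, b2, b3 = sub(b, a), sub(c, b), sub(d, c)
    n1, n2 = cross(b1, b2), cross(b2, b3)
    return math.atan2(math.sqrt(dot(b2, b2)) * dot(b1, n2), dot(n1, n2))


a, b, c = (0.0, 1.0, 0.0), (0.0, 0.0, 0.0), (1.5, 0.0, 0.0)
d = place_atom(a, b, c, bond=1.5,
               angle=math.radians(110.0), torsion=math.radians(60.0))
```

Changing only the torsion argument rotates d about the b-c bond while the bond length and angle stay fixed, which is the sense in which perturbations of internal coordinates are physically meaningful.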
- Parameters:
initial_xyz (torch.Tensor) – Initial Cartesian coordinates of shape (N, 3).
bond_cutoff (float, optional) – Distance cutoff for bond detection in Angstroms. Default is 2.0.
requires_grad (bool, optional) – Whether parameters should have gradients. Default is True.
dtype (torch.dtype, optional) – Data type for tensors. Default is same as initial_xyz.
device (torch.device, optional) – Device for tensors. Default is same as initial_xyz.
- bond_lengths
Bond length parameters in Angstroms.
- Type:
nn.Parameter
- angles
Angle parameters in radians.
- Type:
nn.Parameter
- torsions
Torsion angle parameters in radians.
- Type:
nn.Parameter
- chain_positions
Absolute positions of chain root atoms.
- Type:
nn.Parameter
- chain_orientations
Axis-angle orientations for each chain.
- Type:
nn.Parameter
- __init__(initial_xyz, bond_cutoff=2.0, requires_grad=True, dtype=None, device=None)[source]
Initialize InternalCoordinateTensor from Cartesian coordinates.
- Parameters:
initial_xyz (torch.Tensor) – Initial Cartesian coordinates of shape (N, 3).
bond_cutoff (float, optional) – Distance cutoff for bond detection in Angstroms. Default is 2.0.
requires_grad (bool, optional) – Whether parameters should have gradients. Default is True.
dtype (torch.dtype, optional) – Data type for tensors. Default is same as initial_xyz.
device (torch.device, optional) – Device for tensors. Default is same as initial_xyz.
- property dtype
Return the dtype of tensors.
- property device
Logical device — where forward()’s result is delivered.
Internal parameters/buffers stay on CPU regardless; this is the device requested by the caller (e.g. via .to('mps')) and is the device the forward output is migrated to.
- to(*args, **kwargs)[source]
Update output device and optionally cast dtype.
Unlike DeviceMixin.to, this does not move internal parameters/buffers to device. They stay on CPU to avoid the per-op dispatch overhead of MPS/CUDA on the sequential spanning-tree + parallel-scan code. The device argument only updates _output_device; dtype still propagates normally and recasts all CPU tensors.
- cuda(device=None)[source]
Move all model parameters and buffers to the GPU.
This also makes associated parameters and buffers different objects. So it should be called before constructing the optimizer if the module will live on GPU while being optimized.
Note
This method modifies the module in-place.
- Parameters:
device (int, optional) – If specified, all parameters will be copied to that device.
- Returns:
Module: self
- cpu()[source]
Move all model parameters and buffers to the CPU.
Note
This method modifies the module in-place.
- Returns:
Module: self
- forward_slow()[source]
Reconstruct Cartesian xyz from internal coordinates.
Fully vectorized - processes each depth level in parallel. Only log(max_depth) sequential steps required.
- Returns:
Reconstructed Cartesian coordinates of shape (N, 3).
- Return type:
torch.Tensor
- forward()[source]
Reconstruct Cartesian xyz from internal coordinates.
Uses optimized parallel scan method for efficiency.
- Returns:
Reconstructed Cartesian coordinates of shape (N, 3), on the configured output device.
- Return type:
torch.Tensor
- shake(magnitude=0.1)[source]
Add Gaussian noise to internal parameters (fully vectorized).
All operations are batched tensor ops - no loops.
- Parameters:
magnitude (float, optional) – Standard deviation of Gaussian noise. Default is 0.1. For bond lengths, this is in Angstroms. For angles and torsions, this is in radians.
- Returns:
New Cartesian coordinates after perturbation.
- Return type:
torch.Tensor
- fix(selection=None, freeze_at_current=True)[source]
Fix (freeze) atoms to use fixed xyz coordinates instead of internal coordinates.
Fixed atoms will not be updated during reconstruction from internal coordinates. Their positions will remain at the stored fixed_xyz values.
- Parameters:
selection (torch.Tensor, slice, or None) – Boolean mask (shape n_atoms) or indices of atoms to fix. If None, fixes all atoms.
freeze_at_current (bool, optional) – If True (default), store current reconstructed xyz for the selected atoms. If False, use the existing fixed_xyz values.
- freeze(selection=None, freeze_at_current=True)[source]
Alias for fix(). Freeze atoms to use fixed xyz coordinates.
See fix() for full documentation.
- refine(selection=None, rebuild=True)[source]
Make atoms refinable by computing their positions from internal coordinates.
This unfreezes atoms, meaning their positions will be computed from bond lengths, angles, and torsions during forward pass.
- Parameters:
selection (torch.Tensor, slice, or None) – Boolean mask (shape n_atoms) or indices of atoms to make refinable. If None, makes all atoms refinable.
rebuild (bool, optional) – If True (default), rebuild internal coordinates from current fixed_xyz for the selected atoms. This ensures the internal coordinates match the current atom positions before unfreezing.
- unfreeze(selection=None, rebuild=True)[source]
Alias for refine(). Unfreeze atoms to use internal coordinates.
See refine() for full documentation.
- fix_all(freeze_at_current=True)[source]
Fix (freeze) all atoms.
- Parameters:
freeze_at_current (bool, optional) – If True (default), store current reconstructed xyz for all atoms.
- refine_all(rebuild=True)[source]
Make all atoms refinable.
- Parameters:
rebuild (bool, optional) – If True (default), rebuild internal coordinates from current fixed_xyz.
- forward_parallel()[source]
Reconstruct Cartesian xyz using parallel scan for backbone.
This is an optimized forward pass that:
1. Places backbone atoms using parallel prefix scan (O(log N) steps)
2. Places side chain atoms using depth iterations (O(max_sc_depth) steps)
For deep trees where the backbone is long but side chains are short, this can be significantly faster than processing one depth level at a time, as forward_slow() does.
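The O(log N) claim rests on the fact that composing transforms along a chain is associative, so all prefix compositions can be computed with a parallel scan. The plain-Python sketch below shows a Hillis-Steele inclusive scan over simple affine maps x -> s*x + t, standing in for the rigid-body transforms a backbone would use. The names are hypothetical and the code is an illustration of the technique, not the library's implementation.

```python
def compose(first, second):
    """Compose two affine maps x -> s*x + t, applying `first` then `second`."""
    s1, t1 = first
    s2, t2 = second
    return (s2 * s1, s2 * t1 + t2)


def inclusive_scan(items, op):
    """Hillis-Steele inclusive scan; returns (prefixes, number of passes)."""
    result = list(items)
    offset, passes = 1, 0
    while offset < len(result):
        # Every position is independent within a pass, so on a GPU the
        # whole pass would run as one batched operation.
        result = [
            op(result[i - offset], result[i]) if i >= offset else result[i]
            for i in range(len(result))
        ]
        offset *= 2
        passes += 1
    return result, passes


steps = [(1, 2), (3, -1), (2, 0), (1, 5), (-1, 4), (2, 2), (1, 1), (3, 0)]
prefixes, passes = inclusive_scan(steps, compose)

# Sequential reference: one step per element.
reference, running = [], (1, 0)
for step in steps:
    running = compose(running, step)
    reference.append(running)
```

For the eight steps above the scan needs three passes, against eight sequential compositions in the reference loop. The operator is not commutative, so each pass must keep the earlier block on the left.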
- Returns:
Reconstructed Cartesian coordinates of shape (N, 3).
- Return type:
torch.Tensor
- class torchref.model.MixedModel(models, initial_fractions=None, frozen_fractions=False, verbose=0, device=None)[source]
Bases:
DeviceMixin, Module
Model wrapper combining N ModelFT objects with learnable fractions.
Computes: F_mixed = Σ w_i * F_i where w_i are learnable weights constrained to sum to 1 via softmax.
This is useful for time-resolved crystallography where the crystal contains a mixture of conformational states (e.g., dark and light states) with unknown or refinable population fractions.
- Parameters:
models (List[ModelFT]) – List of ModelFT objects to combine. All models must have compatible cell parameters and space groups.
initial_fractions (List[float], optional) – Initial population fractions for each model. Must sum to 1.0. If None, equal fractions are used (1/N for each model).
frozen_fractions (bool, optional) – If True, fractions are not updated during optimization. Default is False.
verbose (int, optional) – Verbosity level. Default is 0.
- models
Constituent ModelFT objects (proper submodule registration).
- Type:
nn.ModuleList
- fraction_params
Raw parameters for fraction computation (softmax applied).
- Type:
nn.Parameter
Examples
Create a mixed model with two states:
model_dark = ModelFT().load_pdb('dark.pdb')
model_light = ModelFT().load_pdb('light.pdb')
# 70% dark, 30% light
mixed = MixedModel([model_dark, model_light], initial_fractions=[0.7, 0.3])
# Compute mixed structure factors
F_mixed = mixed(hkl)
# Get current fractions
print(mixed.fractions)  # tensor([0.7, 0.3])
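The softmax constraint on the fractions can be illustrated without PyTorch. The sketch below assumes the raw parameters are initialised as the logarithms of the requested fractions, which the documentation does not state, and uses made-up structure factor values. The function names are hypothetical.

```python
import math


def softmax(params):
    """Numerically stable softmax over a list of floats."""
    peak = max(params)
    exps = [math.exp(p - peak) for p in params]
    total = sum(exps)
    return [e / total for e in exps]


def mix(fractions, structure_factors):
    """F_mixed[k] = sum_i w_i * F_i[k] for each reflection k."""
    n_refl = len(structure_factors[0])
    return [
        sum(w * sf[k] for w, sf in zip(fractions, structure_factors))
        for k in range(n_refl)
    ]


# Assumed initialisation: raw parameters are the logs of the target fractions.
raw_params = [math.log(0.7), math.log(0.3)]
fractions = softmax(raw_params)

f_dark = [10 + 2j, 4 - 1j]   # made-up values for two reflections
f_light = [8 + 0j, 6 + 3j]
f_mixed = mix(fractions, [f_dark, f_light])
```

Because the softmax output is always positive and sums to 1, an optimizer can update the raw parameters freely without ever producing invalid fractions.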
- __init__(models, initial_fractions=None, frozen_fractions=False, verbose=0, device=None)[source]
Initialize MixedModel.
- Parameters:
models (List[ModelFT]) – List of ModelFT objects to combine.
initial_fractions (List[float], optional) – Initial population fractions. Must sum to 1.0.
frozen_fractions (bool, optional) – If True, fractions are frozen. Default is False.
verbose (int, optional) – Verbosity level. Default is 0.
device (torch.device, optional) – Device to place the model and parameters on. If None, infers from the first model’s device.
- Raises:
ValueError – If models list is empty, fractions don’t match model count, fractions don’t sum to 1, or models have incompatible parameters.
- property fractions: Tensor
Get normalized population fractions.
- Returns:
Population fractions that sum to 1.0, shape (n_models,).
- Return type:
torch.Tensor
- property cell
Unit cell from first model (for compatibility).
- property spacegroup
Space group from first model (for compatibility).
- property device
Device from first model (for compatibility).
- property dtype_float
Float dtype from first model (for compatibility).
- property real_space_grid: Tensor | None
Real-space coordinate grid from first model (shared cell → same grid).
- property fft
SfFFT submodule from first model (for gridsize access).
- property map_symmetry
Map symmetry operator from first model.
- setup_grid(max_res=None, gridsize=None)[source]
Setup real-space grid on all constituent models.
All models share the same cell/spacegroup, so the grid is identical across all of them. This ensures each model’s SfFFT is ready for density map calculations.
- get_radius(min_radius_Angstrom=4.0)[source]
Get the radius in voxels for density calculation.
Delegates to first model (same grid → same voxel size).
- build_complete_map()[source]
Build the mixed electron density map as the weighted sum of constituent model density maps.
density_mixed = Σ w_i * density_i
Each constituent model builds its own density map on the shared grid, and the results are combined using the current population fractions.
- Returns:
Electron density map with shape (nx, ny, nz).
- Return type:
torch.Tensor
- freeze_fractions()[source]
Exclude fractions from optimization.
This prevents the population fractions from being updated during training while still allowing the constituent models to be refined.
- unfreeze_fractions()[source]
Include fractions in optimization.
This allows the population fractions to be updated during training.
- forward(hkl, recalc=False)[source]
Compute weighted sum of structure factors from all models.
f_mixed = Σ w_i * f_i
- Parameters:
hkl (torch.Tensor) – Miller indices with shape (n_reflections, 3).
recalc (bool, optional) – If True, force recalculation of structure factors. Default is False.
- Returns:
Mixed complex structure factors with shape (n_reflections,).
- Return type:
torch.Tensor
- get_individual_fcalc(hkl, recalc=True)[source]
Get structure factors from each model individually.
- Parameters:
hkl (torch.Tensor) – Miller indices with shape (n_reflections, 3).
recalc (bool, optional) – If True, force recalculation. Default is True.
- Returns:
List of structure factor tensors, one per model.
- Return type:
List[torch.Tensor]
- copy()[source]
Create a deep copy of the MixedModel.
- Returns:
A new MixedModel instance with copied models and parameters.
- Return type:
MixedModel
- write_ihm(filepath, state_names=None, group_name='ensemble', datasets=None)[source]
Write this MixedModel to an IHM mmCIF file.
Creates a single model group with the current population fractions over the constituent structural states.
Requires the optional python-ihm dependency.
- Parameters:
filepath (str) – Output file path.
state_names (list of str, optional) – Names for each constituent model / state. Default: state_1, state_2, …
group_name (str) – Name for the model group. Default "ensemble".
datasets (dict of str -> ReflectionData, optional) – Per-timepoint reflection data to embed in the CIF.
- get_vdw_radii()[source]
Get van der Waals radii from the first model.
- Returns:
Van der Waals radii tensor.
- Return type:
torch.Tensor
- class torchref.model.Model(dtype_float=torch.float32, verbose=1, device=device(type='cpu'), strip_H=True)[source]
Bases:
DeviceMixin, DebugMixin, Module
Base model class for atomic structure models using PyTorch.
This class provides the foundation for managing atomic structure data including coordinates, atomic displacement parameters (ADPs), and occupancies. It supports both empty initialization for state_dict loading and file-based initialization from PDB/CIF files.
- Parameters:
dtype_float (torch.dtype, optional) – Data type for floating point tensors. Defaults to the configured dtypes.float.
verbose (int, optional) – Verbosity level for logging. Default is 1.
device (torch.device, optional) – Computation device. Defaults to the configured device.current.
strip_H (bool, optional) – Whether to strip hydrogen atoms when loading. Default is True.
- xyz
Atomic coordinates tensor with shape (n_atoms, 3).
- Type:
- adp
Atomic displacement parameters (isotropic B-factors) with shape (n_atoms,).
- Type:
- u
Anisotropic displacement parameters with shape (n_atoms, 6).
- Type:
- occupancy
Atomic occupancies with values in [0, 1].
- Type:
- pdb
DataFrame containing atomic model data.
- Type:
- spacegroup
Space group object.
- Type:
gemmi.SpaceGroup
Examples
Empty initialization for state_dict loading:
model = Model()
model.load_state_dict(torch.load('model.pt'))
File-based initialization:
model = Model()
model.load_pdb('structure.pdb')
- __init__(dtype_float=torch.float32, verbose=1, device=device(type='cpu'), strip_H=True)[source]
Initialize an empty Model shell.
Creates a model shell ready for file loading via load_pdb()/load_cif() or state restoration via load_state_dict().
- Parameters:
dtype_float (torch.dtype, optional) – Data type for floating point tensors. Defaults to the configured dtypes.float.
verbose (int, optional) – Verbosity level for logging. Default is 1.
device (torch.device, optional) – Computation device. Defaults to the configured device.current.
strip_H (bool, optional) – Whether to strip hydrogen atoms when loading. Default is True.
- property exclude_H_from_sf: bool
Whether to exclude hydrogen atoms from structure factor calculation.
When True, H atoms are excluded from get_iso()/get_aniso() so they do not contribute to Fcalc. They still participate in geometry and VDW restraints. Default is False.
- property cell: Cell | None
Unit cell object with parameters [a, b, c, alpha, beta, gamma].
- Returns:
The unit cell object, or None if not set.
- Return type:
Cell or None
- property spacegroup: SpaceGroup | None
Space group object.
- Returns:
The space group object, or None if not set.
- Return type:
gemmi.SpaceGroup or None
- property symmetry: SpaceGroup | None
Symmetry operations handler for this space group.
Returns the same SpaceGroup object as self.spacegroup — the separate Symmetry wrapper was redundant since Symmetry is just an alias.
- Returns:
The space group object, or None if not set.
- Return type:
SpaceGroup or None
- property inv_fractional_matrix: Tensor
Fractionalization matrix B^-1 (Cartesian -> fractional).
Delegates to Cell for automatic caching and device/dtype handling.
- Returns:
Shape (3, 3) fractionalization matrix.
- Return type:
torch.Tensor
- property fractional_matrix: Tensor
Orthogonalization matrix B (fractional -> Cartesian).
Delegates to Cell for automatic caching and device/dtype handling.
- Returns:
Shape (3, 3) orthogonalization matrix.
- Return type:
torch.Tensor
- property recB: Tensor
Reciprocal basis matrix with [a*, b*, c*] as rows.
Delegates to Cell for automatic caching and device/dtype handling.
- Returns:
Shape (3, 3) matrix where rows are the reciprocal basis vectors.
- Return type:
torch.Tensor
- property Z: Tensor
Atomic numbers for all atoms.
- Returns:
Tensor of atomic numbers with shape (n_atoms,).
- Return type:
torch.Tensor
- get_P1_parameters_iso()[source]
Get model parameters transformed to P1 space for optimization.
This is useful for optimizers that do not handle symmetry directly, or for MD.
- Returns:
xyz_p1 (torch.Tensor) – Fractional coordinates expanded to P1 space.
adp_p1 (torch.Tensor) – Isotropic ADPs expanded to P1 space.
occupancy_p1 (torch.Tensor) – Occupancies expanded to P1 space.
A (torch.Tensor) – Scattering factor A coefficients expanded to P1 space.
B (torch.Tensor) – Scattering factor B coefficients expanded to P1 space.
- Return type:
- get_MD_parameters()[source]
Get model parameters prepared for molecular dynamics simulation.
Returns all P1-expanded parameters plus atomic numbers for MD engines.
- Returns:
xyz_p1 (torch.Tensor) – Fractional coordinates expanded to P1 space.
adp_p1 (torch.Tensor) – Isotropic ADPs expanded to P1 space.
occupancy_p1 (torch.Tensor) – Occupancies expanded to P1 space.
A (torch.Tensor) – Scattering factor A coefficients expanded to P1 space.
B (torch.Tensor) – Scattering factor B coefficients expanded to P1 space.
Z_p1 (torch.Tensor) – Atomic numbers expanded to P1 space.
- Return type:
- property parametrization
ITC92 parametrization dictionary {element: (A, B)}.
The parametrization is built lazily on first access.
- Returns:
Dictionary mapping element symbols to tuples of (A, B) tensors.
- Return type:
dict
- get_scattering_params_iso()[source]
Get ITC92 scattering parameters (A, B) for isotropic atoms.
- Returns:
A (torch.Tensor) – ITC92 A parameters (amplitudes) with shape (n_iso_atoms, 5).
B (torch.Tensor) – ITC92 B parameters (widths) with shape (n_iso_atoms, 5).
- get_scattering_params_aniso()[source]
Get ITC92 scattering parameters (A, B) for anisotropic atoms.
- Returns:
A (torch.Tensor) – ITC92 A parameters (amplitudes) with shape (n_aniso_atoms, 5).
B (torch.Tensor) – ITC92 B parameters (widths) with shape (n_aniso_atoms, 5).
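To show how an (A, B) pair of this shape is typically used, here is a hedged sketch of the Gaussian-sum form factor in plain Python. The carbon coefficients are quoted from the standard International Tables tabulation and should be checked against the library's own table. Storing the constant term as a fifth Gaussian with B = 0 is an assumption made to match the five-column shape, and form_factor is a hypothetical helper.

```python
import math

# Carbon coefficients (four Gaussians plus a constant term).
# Assumption: the constant is packed as a fifth term with B = 0.
A_CARBON = [2.3100, 1.0200, 1.5886, 0.8650, 0.2156]
B_CARBON = [20.8439, 10.2075, 0.5687, 51.6512, 0.0]


def form_factor(a, b, s):
    """f0(s) = sum_i a_i * exp(-b_i * s^2), with s = sin(theta) / lambda."""
    return sum(ai * math.exp(-bi * s * s) for ai, bi in zip(a, b))


f_zero = form_factor(A_CARBON, B_CARBON, 0.0)   # close to Z = 6 electrons
f_2A = form_factor(A_CARBON, B_CARBON, 0.25)    # s = 1 / (2 d) at d = 2 Angstrom
```

At s = 0 the sum of the amplitudes approximates the number of electrons, which is a quick sanity check on any coefficient table.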
- property restraints
Lazy restraints property.
The restraints are built on first access using the model’s pdb DataFrame and the CIF path set via set_restraints_cif().
- Returns:
The restraints object containing bond, angle, torsion, etc. restraints.
- Return type:
- bond_deviations()[source]
Compute bond length deviations using current xyz coordinates.
- Returns:
deviations (torch.Tensor) – Calculated minus expected bond lengths in Angstroms.
sigmas (torch.Tensor) – Standard deviations from CIF library in Angstroms.
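A (deviations, sigmas) pair like this is commonly summarised as an RMSD and a mean squared z-score. The sketch below uses plain Python lists instead of tensors. restraint_summary is a hypothetical helper, not part of torchref.

```python
import math


def restraint_summary(deviations, sigmas):
    """Return (rmsd, mean squared z-score) for paired deviations and sigmas."""
    n = len(deviations)
    rmsd = math.sqrt(sum(d * d for d in deviations) / n)
    mean_z2 = sum((d / s) ** 2 for d, s in zip(deviations, sigmas)) / n
    return rmsd, mean_z2


# Three bond-length deviations in Angstroms with a common sigma of 0.02.
rmsd, mean_z2 = restraint_summary([0.01, -0.02, 0.03], [0.02, 0.02, 0.02])
```

A mean squared z-score near 1 indicates deviations comparable to the library sigmas.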
- angle_deviations()[source]
Compute angle deviations using current xyz coordinates.
- Returns:
deviations (torch.Tensor) – Calculated minus expected angles in radians.
sigmas (torch.Tensor) – Standard deviations in radians.
- torsion_deviations_with_sigmas()[source]
Compute torsion deviations (wrapped for periodicity) and sigmas.
- Returns:
deviations_rad (torch.Tensor) – Wrapped deviations in radians.
sigmas_deg (torch.Tensor) – Standard deviations in degrees (for von Mises NLL).
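Wrapping for periodicity can be sketched as follows. This is a generic formula in plain Python, not necessarily the exact expression used internally. wrap_angle is a hypothetical name.

```python
import math


def wrap_angle(delta):
    """Map an angular difference in radians onto [-pi, pi)."""
    return (delta + math.pi) % (2.0 * math.pi) - math.pi


# 350 degrees observed vs 10 degrees ideal is a -20 degree deviation, not +340.
dev = wrap_angle(math.radians(350.0) - math.radians(10.0))
```

Note that the deviations are returned in radians while the sigmas are in degrees, so one of them has to be converted before they are combined.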
- property chain_sequences: List[Tuple[str, str]]
Per-chain amino acid sequences as single-letter codes.
Excludes HETATM records. Gaps in residue numbering are filled with ?. Non-standard residues are mapped to X.
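The mapping described above can be sketched in plain Python. The THREE_TO_ONE table and chain_sequence helper are hypothetical names, and emitting one ? per missing residue number is an assumption about how gaps are filled.

```python
THREE_TO_ONE = {
    "ALA": "A", "ARG": "R", "ASN": "N", "ASP": "D", "CYS": "C",
    "GLN": "Q", "GLU": "E", "GLY": "G", "HIS": "H", "ILE": "I",
    "LEU": "L", "LYS": "K", "MET": "M", "PHE": "F", "PRO": "P",
    "SER": "S", "THR": "T", "TRP": "W", "TYR": "Y", "VAL": "V",
}


def chain_sequence(residues):
    """residues: list of (resseq, resname) for one chain, sorted by resseq."""
    letters = []
    previous = None
    for resseq, resname in residues:
        if previous is not None:
            # Assumption: one '?' per missing residue number.
            letters.append("?" * (resseq - previous - 1))
        letters.append(THREE_TO_ONE.get(resname, "X"))  # non-standard -> X
        previous = resseq
    return "".join(letters)


seq = chain_sequence([(1, "MET"), (2, "ALA"), (5, "UNK"), (6, "GLY")])
```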
- get_chain_residues()[source]
Per-chain residue names as 3-letter codes (for IHM/CIF writing).
Excludes HETATM records. Unlike chain_sequences, returns the raw 3-letter codes without gap filling.
- get_vdw_radii()[source]
Get van der Waals radii for all atoms based on their elements.
Caches the result in self.vdw_radii for future calls.
- Returns:
Van der Waals radii for each atom with shape (n_atoms,).
- Return type:
- to(*args, **kwargs)[source]
Move Model and rebuild device-specific SF indices.
Delegates to DeviceMixin, which walks self.__dict__ (picking up self.cell, self.altloc_pairs, self._restraints and all registered parameters / buffers), refreshes the self.device tracker, and invalidates caches. Afterwards this override rebuilds the precomputed SF indices on the new device.
- copy()[source]
Create a deep copy of the Model.
Creates a complete independent copy including all registered buffers, module parameters, PDB DataFrame, and spacegroup information.
- Returns:
A new Model instance with copied data.
- Return type:
Examples
model = Model().load_pdb('structure.pdb')
model_copy = model.copy()
# model_copy is independent, changes won't affect model
- write_pdb(filename, metadata=None)[source]
Write model to PDB file with optional metadata header.
- Parameters:
filename (str) – Output PDB file path.
metadata (RefinementMetadata, optional) – Metadata to render as PDB header (REMARK 3, TITLE, etc.).
- write_cif(filename, metadata=None)[source]
Write model to mmCIF file with optional metadata.
- Parameters:
filename (str) – Output mmCIF file path.
metadata (RefinementMetadata, optional) – Metadata to include (refinement statistics, title, etc.).
- get_iso()[source]
Return per-atom parameters for the isotropic atom subset.
Selects atoms whose ADP is a single scalar b (i.e. not anisotropic). The subset is defined by ~self.aniso_flag (intersected with self._heavy_atom_mask when _exclude_H_from_sf is enabled) and is precomputed as self._iso_indices at init and whenever the mask changes.
- Returns:
xyz (torch.Tensor, shape (n_iso, 3)) – Cartesian coordinates of the isotropic atoms (Å).
adp (torch.Tensor, shape (n_iso,)) – Isotropic B-factors (Ų).
occupancy (torch.Tensor, shape (n_iso,)) – Occupancies in [0, 1].
Notes
When every atom is isotropic and no H exclusion is active (self._iso_covers_all is True, the common protein-refinement case), the per-atom indexing is skipped and self.xyz(), self.adp(), self.occupancy() are returned directly.
Motivation: self.xyz()[idx] is a no-op forward when idx = arange(N), but its backward routes through PyTorch’s aten::_index_put_impl_(accumulate=True), which performs a cub::DeviceRadixSortOnesweepKernel over len(idx) indices followed by a deduplicated scatter (~50-150 µs/iter per gather on A100 / 1DAW). Skipping the gather avoids that cost.
- parameters_of_types(types)[source]
Return the leaf nn.Parameter objects for the named parameter types.
Used by refinement entry points (refine_xyz, refine_adp, …) to construct an optimizer over only the leaves the caller intends to update. LossState.step then uses the optimizer’s param groups as intent and disables requires_grad on any other leaves the loss also touches.
- update_mask_from_selection(selection_string, target, mode='set', freeze=True)[source]
Update the refinable mask for a parameter using Phenix-style selection syntax.
This method updates the internal mask buffer (xyz_mask, adp_mask, u_mask, or occupancy_mask) based on the selection. The updated mask is NOT automatically applied to the parameter tensors - use apply_mask_to_parameter() to apply it.
- Parameters:
selection_string (str) – Phenix-style selection string (see parse_phenix_selection docs).
target (str) – Parameter to update: ‘xyz’, ‘adp’, ‘u’, or ‘occupancy’.
mode (str, optional) – How to combine with current mask:
- ‘set’: Replace mask with selection (default)
- ‘add’: Add selection to current mask
- ‘remove’: Remove selection from current mask
freeze (bool, optional) – If True (default), selected atoms will be frozen (mask=False). If False, selected atoms will be unfrozen (mask=True).
- Raises:
ValueError – If target is not recognized or selection syntax is invalid.
Examples
# Freeze chain A coordinates
model.update_mask_from_selection("chain A", "xyz", mode='set', freeze=True)
model.apply_mask_to_parameter("xyz")
# Unfreeze backbone atoms
model.update_mask_from_selection("name CA or name C or name N", "xyz", freeze=False)
model.apply_mask_to_parameter("xyz")
- apply_mask_to_parameter(target)[source]
Apply the current mask buffer to the parameter tensor.
Takes the current state of the mask buffer (xyz_mask, adp_mask, etc.) and applies it to the corresponding parameter tensor’s refinable mask.
- Parameters:
target (str) – Parameter to update: ‘xyz’, ‘adp’, ‘u’, or ‘occupancy’.
- Raises:
ValueError – If target is not recognized.
Examples
model.update_mask_from_selection("chain A", "xyz", freeze=True)
model.apply_mask_to_parameter("xyz")
- freeze_selection(selection_string, targets='all')[source]
Freeze atoms matching a Phenix-style selection for specified parameters.
Convenience method that combines update_mask_from_selection() and apply_mask_to_parameter() into a single call.
- Parameters:
Examples
# Freeze all parameters for chain A
model.freeze_selection("chain A", targets='all')
# Freeze only coordinates for residues 10-20
model.freeze_selection("resseq 10:20", targets='xyz')
- unfreeze_selection(selection_string, targets='all')[source]
Unfreeze atoms matching a Phenix-style selection for specified parameters.
Convenience method that combines update_mask_from_selection() and apply_mask_to_parameter() into a single call.
- Parameters:
Examples
# Unfreeze all parameters for chain A
model.unfreeze_selection("chain A", targets='all')
# Unfreeze only coordinates for backbone atoms
model.unfreeze_selection("name CA or name C or name N", targets='xyz')
- get_aniso()[source]
Return per-atom parameters for the anisotropic atom subset.
Selects atoms whose ADP is the 6-element anisotropic tensor u = (u11, u22, u33, u12, u13, u23). The subset is defined by self.aniso_flag (intersected with self._heavy_atom_mask when _exclude_H_from_sf is enabled) and is precomputed as self._aniso_indices at init and whenever the mask changes.
- Returns:
xyz (torch.Tensor, shape (n_aniso, 3)) – Cartesian coordinates of the anisotropic atoms (Å). Empty tensor when there are no anisotropic atoms.
u (torch.Tensor, shape (n_aniso, 6)) – Anisotropic U components (Ų) in the order (u11, u22, u33, u12, u13, u23). Empty when n_aniso == 0.
occupancy (torch.Tensor, shape (n_aniso,)) – Occupancies in [0, 1]. Empty when n_aniso == 0.
Notes
When there are no anisotropic atoms (self._aniso_is_empty is True, the common protein-refinement case), three empty placeholder tensors are returned without calling the MixedTensors at all. This avoids both the wrapped forward .clone() and the slow aten::_index_put_impl_ backward path that the self.xyz()[idx] gather would otherwise generate (see get_iso() for the same rationale).
- parameters(recurse=True)[source]
Return an iterator over module parameters.
This is typically passed to an optimizer.
- Parameters:
recurse (bool) – If True, yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
- Yields:
Parameter – module parameter
Examples
>>> # xdoctest: +SKIP("undefined vars")
>>> for param in model.parameters():
>>>     print(type(param), param.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
- named_mixed_tensors()[source]
Iterate over all MixedTensor attributes with their names.
- Yields:
Tuple of (name, MixedTensor)
- register_alternative_conformations()[source]
Identify and register all alternative conformation groups in the structure.
For each residue that has alternative conformations (altloc A, B, C, etc.), this method identifies all atoms belonging to each conformation and stores their indices as tensors in a tuple.
The result is stored in self.altloc_pairs as a list of tuples, where each tuple contains tensors of atom indices for each alternative conformation.
Examples
For a residue with conformations A and B:
# Conformation A has atoms at indices [100, 101, 102, ...]
# Conformation B has atoms at indices [110, 111, 112, ...]
# Result: [(tensor([100, 101, 102, ...]), tensor([110, 111, 112, ...])), ...]
For a residue with conformations A, B, C:
# Result: [(tensor([200, 201, ...]), tensor([210, 211, ...]), tensor([220, 221, ...])), ...]
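The grouping can be sketched in plain Python, with index lists standing in for the tensors. group_altlocs and the dictionary keys are hypothetical. The real method reads these fields from the model's PDB DataFrame.

```python
from collections import defaultdict


def group_altlocs(atoms):
    """atoms: list of dicts with 'chain', 'resseq' and 'altloc' keys.

    Returns one tuple per residue with alternate conformations; each tuple
    holds one list of atom indices per altloc label, in sorted label order.
    """
    per_residue = defaultdict(lambda: defaultdict(list))
    for index, atom in enumerate(atoms):
        if atom["altloc"]:  # an empty string means no alternate location
            key = (atom["chain"], atom["resseq"])
            per_residue[key][atom["altloc"]].append(index)
    return [
        tuple(confs[label] for label in sorted(confs))
        for _, confs in sorted(per_residue.items())
        if len(confs) > 1
    ]


atoms = [
    {"chain": "A", "resseq": 10, "altloc": ""},
    {"chain": "A", "resseq": 11, "altloc": "A"},
    {"chain": "A", "resseq": 11, "altloc": "A"},
    {"chain": "A", "resseq": 11, "altloc": "B"},
    {"chain": "A", "resseq": 11, "altloc": "B"},
]
groups = group_altlocs(atoms)
```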
- shake_coords(stddev)[source]
Apply random Gaussian noise to atomic coordinates.
Perturbs the atomic coordinates by adding Gaussian noise with a specified standard deviation. The noise is applied to all atoms.
- Parameters:
stddev (float) – Standard deviation of the Gaussian noise to be added, in Angstroms.
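As an illustration of the perturbation, the following plain-Python sketch adds independent Gaussian noise to each coordinate component. The shake helper and its seed argument are hypothetical. The library method operates on the model's tensors in place.

```python
import random


def shake(coords, stddev, seed=None):
    """Return a copy of coords with N(0, stddev^2) noise on each component."""
    rng = random.Random(seed)
    return [[c + rng.gauss(0.0, stddev) for c in atom] for atom in coords]


coords = [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]]
shaken = shake(coords, stddev=0.1, seed=0)   # perturbed copy
unchanged = shake(coords, stddev=0.0)        # zero noise leaves coords as-is
```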
- shake_adp(stddev)[source]
Apply random Gaussian noise to ADPs (atomic displacement parameters).
Perturbs the ADPs by adding Gaussian noise with a specified standard deviation. The noise is applied to all atoms.
- Parameters:
stddev (float) – Standard deviation of the Gaussian noise to be added, in Angstrom^2.
- generate_hydrogens(mon_lib_path=None)[source]
Generate hydrogen atoms for the current model using gemmi.
Places hydrogens at ideal geometry using the CCP4 monomer library and gemmi’s topology engine. Returns a new Model instance with hydrogens added; the original model is not modified.
- Parameters:
mon_lib_path (str, optional) – Path to CCP4 monomer library directory. If None, uses the monomer library bundled with torchref (covers standard amino acids and common small molecules).
- Returns:
A new Model instance with hydrogen atoms added (strip_H=False). Unknown residues are skipped silently.
- Return type:
Notes
Requires gemmi (already a torchref dependency). Heavy-atom coordinates from the current model state are used, so call this after any coordinate changes you want reflected in the H positions.
Examples
>>> model_no_h = Model().load_pdb('structure.pdb')
>>> model_with_h = model_no_h.generate_hydrogens()
>>> print(model_with_h.Z.shape) # more atoms than model_no_h
- strip_altlocs()[source]
Return a new model with alternate conformations removed.
For each residue that has multiple altlocs, the conformer with highest average occupancy is kept (ties broken alphabetically). The altloc column is cleared to "" in the returned model. The original model is not modified.
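The selection rule can be sketched as a single sort key. pick_conformer is a hypothetical helper operating on plain Python data, not the library's code.

```python
def pick_conformer(occupancies_by_altloc):
    """occupancies_by_altloc: dict mapping altloc label -> list of occupancies."""
    def mean(values):
        return sum(values) / len(values)

    # Highest mean occupancy wins; the label breaks ties alphabetically.
    return min(
        occupancies_by_altloc,
        key=lambda label: (-mean(occupancies_by_altloc[label]), label),
    )


kept = pick_conformer({"A": [0.4, 0.4], "B": [0.6, 0.6]})
tie = pick_conformer({"B": [0.5], "A": [0.5]})
```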
- strip_hydrogens()[source]
Return a new model with hydrogen atoms removed.
The returned model has consistent DataFrame and tensors (xyz, adp, occupancy) with H atoms excluded. The original model is not modified.
- Returns:
New model without hydrogen atoms.
- Return type:
- hydrogenate(verbose=0, optimize=False, lbfgs_steps=3, max_iter=20)[source]
Return a new model with hydrogen atoms placed via Kabsch alignment.
Uses torchref’s monomer library to identify missing H atoms, places them by SVD-aligning ideal monomer coordinates onto the current model coordinates, then corrects each H to sit at ideal bond length from its parent atom. The original model is not modified.
- Parameters:
verbose (int, optional) – Verbosity level (0=silent, 1=summary, 2=detailed). Default 0.
optimize (bool, optional) – If True, run a short LBFGS geometry optimization on H positions after placement. Default False (Kabsch placement only).
lbfgs_steps (int, optional) – Number of LBFGS outer steps (only when optimize=True). Default 3.
max_iter (int, optional) – Max line-search iterations per LBFGS step. Default 20.
- Returns:
New model with hydrogen atoms added. All parameters are unfrozen in the returned model.
- Return type:
- adp_loss()[source]
Compute the ADP regularization loss.
This loss encourages ADPs to have similar values across the structure, helping to prevent overfitting during refinement.
- Returns:
Scalar tensor representing the ADP loss.
- Return type:
- adp_nll_loss(target_log_std=0.2)[source]
Compute negative log-likelihood of ADPs assuming Gaussian distribution in log-space.
This regularization penalizes ADPs that deviate from a target distribution with a FIXED standard deviation (hyperparameter), avoiding circular dependency on the current distribution’s statistics.
The NLL for a Gaussian distribution in log-space is:
NLL = 0.5 * mean[(log_adp - mu)^2 / sigma^2 + log(2*pi*sigma^2)]
Where mu is the mean of log-space ADPs (computed from current data) and sigma is the FIXED target standard deviation (hyperparameter).
- Parameters:
target_log_std (float, optional) – Target standard deviation in log-space. Default is 0.2.
- 0.1 = very tight (ADPs within ~10% of mean)
- 0.2 = moderate spread (ADPs within ~20% of mean) [RECOMMENDED]
- 0.3 = looser spread (ADPs within ~30% of mean)
- Returns:
Scalar tensor representing the NLL. Lower values indicate the distribution is closer to the target Gaussian with fixed sigma.
- Return type:
Examples
# During refinement
structure_factor_loss = compute_structure_factor_loss()
nll_reg = model.adp_nll_loss(target_log_std=0.2)
total_loss = structure_factor_loss + 0.01 * nll_reg
total_loss.backward()
Notes
Uses FIXED sigma (no circular dependency on current distribution). Smaller target_log_std = stronger regularization (tighter distribution).
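The formula above can be checked numerically with a small plain-Python sketch. adp_nll is a hypothetical stand-in that also returns the per-atom terms. The library version works on tensors and keeps gradients.

```python
import math


def adp_nll(adps, target_log_std=0.2):
    """Mean Gaussian NLL of log(ADP) around the current log-space mean."""
    logs = [math.log(b) for b in adps]
    mu = sum(logs) / len(logs)          # mean taken from the current data
    var = target_log_std ** 2           # fixed hyperparameter, not fitted
    per_atom = [
        0.5 * ((x - mu) ** 2 / var + math.log(2.0 * math.pi * var))
        for x in logs
    ]
    return sum(per_atom) / len(per_atom), per_atom


uniform, _ = adp_nll([20.0, 20.0, 20.0])
spread, per_atom = adp_nll([10.0, 20.0, 40.0])
```

When all ADPs are identical only the constant term remains, so that value is the floor of the loss for a given target_log_std.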
- adp_nll_loss_per_atom(target_log_std=0.2)[source]
Compute per-atom negative log-likelihood for ADPs in log-space.
Returns the NLL contribution for each individual atom, useful for identifying outliers or applying atom-specific regularization weights.
The per-atom NLL is:
NLL_i = 0.5 * [(log_adp_i - mu)^2 / sigma^2 + log(2*pi*sigma^2)]
- Parameters:
target_log_std (float, optional) – Fixed target standard deviation in log-space. Default is 0.2.
- Returns:
Tensor of shape (n_atoms,) with per-atom NLL values. Higher values indicate atoms farther from the mean.
- Return type:
Examples
# Get per-atom NLL
atom_nll = model.adp_nll_loss_per_atom(target_log_std=0.2)
# Identify outlier atoms (high NLL)
threshold = atom_nll.mean() + 2 * atom_nll.std()
outliers = atom_nll > threshold
- adp_kl_divergence_loss(target_log_std=0.2)[source]
Compute KL divergence between log ADP distribution and target Gaussian.
Measures how different the current log ADP distribution is from a target Gaussian distribution with the current mean of log ADPs and a fixed target standard deviation.
KL divergence formula for two Gaussians with same mean:
KL(q || p) = log(sigma_target/sigma_data) + sigma_data^2 / (2*sigma_target^2) - 0.5
- Parameters:
target_log_std (float, optional) – Target standard deviation in log-space. Default is 0.2. Controls how tightly ADPs should cluster.
- Returns:
Scalar KL divergence value (always >= 0). 0 means distributions match perfectly. Higher values mean more deviation from target.
- Return type:
Examples
# Use in loss function
loss = xray_loss + w_adp * model.adp_kl_divergence_loss(0.2)
Notes
Lower target_log_std = stronger regularization (tighter distribution). Mean is detached so it adapts to the natural scale of the data.
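A plain-Python sketch of the same-mean KL formula follows. adp_kl is a hypothetical helper, and using the population standard deviation for sigma_data is an assumption. The library may use a different estimator.

```python
import math


def adp_kl(adps, target_log_std=0.2):
    """KL(q || p) for same-mean Gaussians, with q fitted to log(ADP)."""
    logs = [math.log(b) for b in adps]
    mu = sum(logs) / len(logs)
    # Assumption: population (not sample) standard deviation.
    sigma_data = math.sqrt(sum((x - mu) ** 2 for x in logs) / len(logs))
    ratio = sigma_data / target_log_std
    # log(sigma_target / sigma_data) + sigma_data^2 / (2 sigma_target^2) - 0.5
    return -math.log(ratio) + 0.5 * ratio ** 2 - 0.5


narrow = adp_kl([18.0, 20.0, 22.0])
wide = adp_kl([5.0, 20.0, 80.0])
matched = adp_kl([20.0 * math.exp(-0.2), 20.0 * math.exp(0.2)])
```

The matched case has a log-space spread of exactly 0.2, so the divergence vanishes. Both narrower and wider distributions are penalised.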
- state_dict(destination=None, prefix='', keep_vars=False)[source]
Return a dictionary containing the complete state of the Model.
Includes all registered buffers, model parameters (xyz, b, u, occupancy), PDB DataFrame, and metadata (spacegroup, device, dtype, etc.).
- save_state(path)[source]
Save the complete state of the model to a file.
- Parameters:
path (str) – Path to save the state dictionary to.
- classmethod create_from_state_dict(state_dict, device=device(type='cpu'), verbose=1, dtype_float=torch.float32)[source]
Create a fully initialized Model from a state dictionary.
This is the recommended way to restore a Model from a saved state. Creates an instance with properly initialized submodules, then loads the state.
- Parameters:
state_dict (dict) – State dictionary from torch.save(model.state_dict(), …).
device (torch.device, optional) – Device to place tensors on. Defaults to the configured device.current.
verbose (int, optional) – Verbosity level. Default is 1.
dtype_float (torch.dtype, optional) – Float dtype for tensors. Defaults to the configured dtypes.float.
- Returns:
Fully initialized instance with restored state.
- Return type:
- get_selection_mask(selection)[source]
Return a boolean mask for atoms matching a Phenix-style selection.
This is a convenience method that wraps parse_phenix_selection() to return a mask that can be used directly with MixedTensor.set() or other operations requiring atom selection.
- Parameters:
selection (str) – Phenix-style selection string. Supports:
- chain <id>: Select by chain (e.g., “chain A”)
- resseq <num>: Select by residue number (e.g., “resseq 10”)
- resseq <start>:<end>: Select residue range (e.g., “resseq 10:20”)
- resname <name>: Select by residue name (e.g., “resname ALA”)
- name <atom>: Select by atom name (e.g., “name CA”)
- element <elem>: Select by element (e.g., “element C”)
- altloc <id>: Select by alternate location (e.g., “altloc A”)
- all: Select all atoms
- not <selection>: Negate selection
- <sel1> and <sel2>: Intersection
- <sel1> or <sel2>: Union
- Parentheses for grouping
- Returns:
Boolean tensor of shape (n_atoms,) where True indicates selected atoms.
- Return type:
- Raises:
RuntimeError – If the model has not been initialized.
ValueError – If selection syntax is invalid.
Examples
model = Model().load_pdb('structure.pdb')
# Get mask for chain A
mask = model.get_selection_mask("chain A")
# Use mask to update coordinates
new_coords = model.xyz()[mask] + translation
model.xyz.set(new_coords, mask)
# Get mask for backbone atoms
backbone_mask = model.get_selection_mask("name CA or name C or name N or name O")
# Complex selection with parentheses
mask = model.get_selection_mask("chain A and (resname ALA or resname GLY)")
- select(selection)[source]
Return a new Model containing only atoms matching the Phenix-style selection.
Creates an independent copy of the model containing only the selected atoms. All tensor data (coordinates, ADPs, occupancies, etc.) and metadata are properly subsetted.
- Parameters:
selection (str) – Phenix-style selection string. Supports:
- chain <id>: Select by chain (e.g., “chain A”)
- resseq <num>: Select by residue number (e.g., “resseq 10”)
- resseq <start>:<end>: Select residue range (e.g., “resseq 10:20”)
- resname <name>: Select by residue name (e.g., “resname ALA”)
- name <atom>: Select by atom name (e.g., “name CA”)
- element <elem>: Select by element (e.g., “element C”)
- altloc <id>: Select by alternate location (e.g., “altloc A”)
- all: Select all atoms
- not <selection>: Negate selection
- <sel1> and <sel2>: Intersection
- <sel1> or <sel2>: Union
- Parentheses for grouping
- Returns:
New instance of the same class containing only selected atoms. If called on a subclass, returns an instance of that subclass.
- Return type:
- Raises:
RuntimeError – If the model has not been initialized.
ValueError – If selection syntax is invalid or no atoms are selected.
Examples
model = Model().load_pdb('structure.pdb')
# Select chain A
chain_a = model.select("chain A")
# Select backbone atoms
backbone = model.select("name CA or name C or name N or name O")
# Select residues 10-50 of chain B
region = model.select("chain B and resseq 10:50")
# Select all except water
no_water = model.select("not resname HOH")
# Complex selection with parentheses
complex_sel = model.select("chain A and (resname ALA or resname GLY)")
Notes
This method preserves the class type, so subclasses will return instances of themselves, not the base Model class.
- xyz_fractional()[source]
Return atomic coordinates in fractional space.
Converts Cartesian coordinates to fractional coordinates using the inverse fractional matrix.
- Returns:
Tensor of shape (n_atoms, 3) with fractional coordinates.
- Return type:
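The conversion can be illustrated with a plain-Python sketch that builds an orthogonalization matrix from cell parameters and inverts it. The axis convention (a along x, b in the xy plane) is the common PDB one and is an assumption here. torchref's Cell may use a different convention, and the helper names are hypothetical.

```python
import math


def orthogonalization_matrix(a, b, c, alpha, beta, gamma):
    """Fractional -> Cartesian matrix, a along x and b in the xy plane."""
    ca, cb, cg = (math.cos(math.radians(x)) for x in (alpha, beta, gamma))
    sg = math.sin(math.radians(gamma))
    volume_term = math.sqrt(
        1.0 - ca * ca - cb * cb - cg * cg + 2.0 * ca * cb * cg
    )
    return [
        [a, b * cg, c * cb],
        [0.0, b * sg, c * (ca - cb * cg) / sg],
        [0.0, 0.0, c * volume_term / sg],
    ]


def to_cartesian(m, frac):
    return [sum(m[i][j] * frac[j] for j in range(3)) for i in range(3)]


def to_fractional(m, cart):
    """Invert the upper-triangular matrix by back substitution."""
    w = cart[2] / m[2][2]
    v = (cart[1] - m[1][2] * w) / m[1][1]
    u = (cart[0] - m[0][1] * v - m[0][2] * w) / m[0][0]
    return [u, v, w]


# Monoclinic cell: round trip fractional -> Cartesian -> fractional.
cell = orthogonalization_matrix(50.0, 60.0, 70.0, 90.0, 105.0, 90.0)
frac = to_fractional(cell, to_cartesian(cell, [0.25, 0.5, 0.75]))
```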
- rotate(rotation_matrix, center=None)[source]
Apply rotation to atomic coordinates (in-place).
Rotates all atoms around a specified center point. The rotation is applied using the formula: xyz_new = R @ (xyz - center) + center
- Parameters:
rotation_matrix (torch.Tensor) – 3x3 rotation matrix. Should be orthogonal (R^T @ R = I).
center (torch.Tensor, optional) – Center of rotation with shape (3,). If None, uses the centroid of all atomic coordinates.
- Returns:
Self, for method chaining.
- Return type:
Examples
# Rotate 90 degrees around Z-axis
import math
import torch
angle = math.pi / 2
R = torch.tensor([
    [math.cos(angle), -math.sin(angle), 0],
    [math.sin(angle), math.cos(angle), 0],
    [0, 0, 1]
])
model.rotate(R)
# Rotate around a specific point
center = torch.tensor([10.0, 20.0, 30.0])
model.rotate(R, center=center)
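The formula xyz_new = R @ (xyz - center) + center can be verified by hand with a plain-Python sketch. rotate_about is a hypothetical helper that returns a new list, whereas the method above modifies the model in place.

```python
import math


def rotate_about(coords, rot, center=None):
    """Apply xyz_new = R @ (xyz - center) + center to each atom."""
    if center is None:  # default to the unweighted centroid
        center = [sum(atom[i] for atom in coords) / len(coords)
                  for i in range(3)]
    out = []
    for atom in coords:
        shifted = [atom[i] - center[i] for i in range(3)]
        out.append([
            sum(rot[i][j] * shifted[j] for j in range(3)) + center[i]
            for i in range(3)
        ])
    return out


angle = math.pi / 2  # 90 degrees about the z-axis
R = [[math.cos(angle), -math.sin(angle), 0.0],
     [math.sin(angle), math.cos(angle), 0.0],
     [0.0, 0.0, 1.0]]
rotated = rotate_about([[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]], R)
```

With the centroid at the origin, the atom at (1, 0, 0) moves to (0, 1, 0) and its partner to (0, -1, 0).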
- translate(translation, fractional=False)[source]
Apply translation to atomic coordinates (in-place).
Translates all atoms by a specified vector. The translation can be given in either Cartesian or fractional coordinates.
- Parameters:
translation (torch.Tensor) – Translation vector with shape (3,).
fractional (bool, optional) – If True, the translation is interpreted as fractional coordinates and converted to Cartesian before applying. Default is False (translation is in Cartesian Angstroms).
- Returns:
Self, for method chaining.
- Return type:
Examples
# Translate by 5 Angstroms along X
model.translate(torch.tensor([5.0, 0.0, 0.0]))
# Translate by half a unit cell along each axis
model.translate(torch.tensor([0.5, 0.5, 0.5]), fractional=True)
- get_centroid()[source]
Compute the centroid (center of mass) of all atoms.
- Returns:
Centroid coordinates with shape (3,).
- Return type:
- use_internal_coordinates(n_aa_per_segment=5, bond_cutoff=2.0, cif_dict=None, requires_grad=True)[source]
Switch xyz to segmented internal coordinate parametrization.
Replaces the current xyz MixedTensor with a SegmentedInternalCoordinateTensor that parametrizes atomic positions using bond lengths, angles, torsion angles, and per-segment rigid body parameters. The molecule is broken into independent segments to avoid the “lever arm problem” where small torsion changes near the root cause large displacements at distant atoms.
- Parameters:
n_aa_per_segment (int, optional) – Number of amino acids per segment. Default is 5.
- Smaller values (1-2): More segments, shallower trees, less lever arm
- Larger values (5-10): Fewer segments, deeper trees, more lever arm
bond_cutoff (float, optional) – Distance cutoff for bond detection in Angstroms. Default is 2.0. Only used when cif_dict is not provided.
cif_dict (dict, optional) – CIF dictionary containing bond definitions per residue type. If provided, bonds are determined from chemical definitions rather than distances, which is more robust for structures with poor geometry. Expected format: cif_dict[resname][‘bonds’] DataFrame with ‘atom1’, ‘atom2’.
requires_grad (bool, optional) – Whether internal coordinate parameters should have gradients. Default is True.
- Returns:
Self, for method chaining.
- Return type:
Examples
model = Model()
model.load_pdb('structure.pdb')
model.use_internal_coordinates(n_aa_per_segment=3)
# Now model.xyz() returns coordinates reconstructed from
# segmented internal coordinates
# Shake the structure using internal coordinates
new_xyz = model.xyz.shake(magnitude=0.1)
# Each segment has independent internal coordinates and
# rigid body parameters (position + orientation)
Notes
After calling this method, model.xyz will be a SegmentedInternalCoordinateTensor instead of a MixedTensor. This provides:
- Shallow spanning trees within segments (depth ~10-30 vs ~1000)
- Independent segments that don’t propagate changes to distant atoms
- Rigid body parameters (position + orientation) per segment
- forward() / __call__(): Reconstruct Cartesian coordinates
- shake(magnitude): Add noise to internal parameters
- Gradient flow through all internal coordinate parameters
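The lever arm problem mentioned above can be made concrete with a short calculation. This is a geometric illustration only, not code from the library, and lever_arm_displacement is a hypothetical name.

```python
import math


def lever_arm_displacement(distance, angle_deg):
    """Chord length moved by a point at `distance` from the rotation axis."""
    return 2.0 * distance * math.sin(math.radians(angle_deg) / 2.0)


near = lever_arm_displacement(5.0, 1.0)     # atom 5 Angstrom from the axis
far = lever_arm_displacement(100.0, 1.0)    # atom 100 Angstrom from the axis
```

A one-degree torsion change moves the distant atom by about 1.7 Å but the nearby atom by under 0.1 Å, which is why short segments keep torsion updates local.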
- class torchref.model.ModelFT(*args, max_res=1.0, radius_angstrom=4.0, gridsize=None, wavelength=1.0, anomalous_threshold=0.5, **kwargs)[source]
Bases: CachedForwardMixin, Model
Model subclass for Fourier Transform-based electron density and structure factor calculations.
ModelFT extends the base Model class with capabilities for computing electron density maps in real space and structure factors via FFT. Uses ITC92 parametrization for electron density calculations.
- Parameters:
max_res (float, optional) – Maximum resolution for grid spacing in Angstroms. Default is 1.0.
radius_angstrom (float, optional) – Radius in Angstroms for density calculation around each atom. Default is 4.0.
gridsize (tuple of int, optional) – Explicit grid size (nx, ny, nz). If None, computed from cell and max_res.
wavelength (float or None, optional) – X-ray wavelength in Angstroms for anomalous scattering correction. Default is 1.0 (standard synchrotron, ~12.4 keV). Set to None to disable anomalous corrections entirely.
anomalous_threshold (float, optional) – Significance threshold for anomalous scattering in electrons. Atoms with |f'| > threshold or |f''| > threshold will have anomalous corrections applied. Default is 0.5.
*args – Additional positional arguments passed to parent Model class.
**kwargs – Additional keyword arguments passed to parent Model class.
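The anomalous_threshold rule stated above can be sketched as a simple predicate. needs_anomalous is a hypothetical helper, and the correction values below are made-up placeholders for illustration, not tabulated f' / f'' values.

```python
def needs_anomalous(f_prime, f_double_prime, threshold=0.5):
    """True when either correction exceeds the threshold in magnitude."""
    return abs(f_prime) > threshold or abs(f_double_prime) > threshold


# Placeholder (f', f'') pairs in electrons; NOT real tabulated values.
corrections = {
    "light": (0.01, 0.005),
    "medium": (0.2, 0.3),
    "heavy": (-2.5, 0.6),
}
flagged = sorted(
    name for name, (fp, fpp) in corrections.items()
    if needs_anomalous(fp, fpp)
)
```

Because the test uses absolute values, a large negative f' triggers the correction just as a large positive one would.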
- gridsize
Grid dimensions (nx, ny, nz).
- Type:
- real_space_grid
Real-space coordinate grid with shape (nx, ny, nz, 3).
- Type:
- map
Computed electron density map.
- Type:
torch.Tensor or None
- map_symmetry
Symmetry operator for map calculations.
- Type:
Examples
Empty initialization for state_dict loading:
model = ModelFT()
model.load_state_dict(torch.load('model.pt'))
File-based initialization:
model = ModelFT(max_res=1.5)
model.load_pdb('structure.pdb')
- __init__(*args, max_res=1.0, radius_angstrom=4.0, gridsize=None, wavelength=1.0, anomalous_threshold=0.5, **kwargs)[source]
Initialize an empty ModelFT shell.
Creates a model shell ready for file loading via load_pdb()/load_cif() or state restoration via load_state_dict().
- Parameters:
max_res (float, optional) – Maximum resolution for grid spacing in Angstroms. Default is 1.0.
radius_angstrom (float, optional) – Radius in Angstroms for density calculation. Default is 4.0.
gridsize (tuple of int, optional) – Explicit grid size tuple (nx, ny, nz). If None, computed automatically.
wavelength (float or None, optional) – X-ray wavelength in Angstroms for anomalous scattering correction. Default is 1.0 (standard synchrotron, ~12.4 keV). Set to None to disable anomalous corrections entirely.
anomalous_threshold (float, optional) – Significance threshold for anomalous scattering in electrons. Atoms with |f'| > threshold or |f''| > threshold will have anomalous corrections applied. Default is 0.5.
*args – Passed to parent Model class.
**kwargs – Passed to parent Model class.
- property cell
Unit cell object with parameters [a, b, c, alpha, beta, gamma].
- property spacegroup
Space group object.
- select(selection)[source]
Return a new Model containing only atoms matching the Phenix-style selection.
Creates an independent copy of the model containing only the selected atoms. All tensor data (coordinates, ADPs, occupancies, etc.) and metadata are properly subsetted.
- Parameters:
selection (str) – Phenix-style selection string. Supports:
- chain <id>: Select by chain (e.g., “chain A”)
- resseq <num>: Select by residue number (e.g., “resseq 10”)
- resseq <start>:<end>: Select residue range (e.g., “resseq 10:20”)
- resname <name>: Select by residue name (e.g., “resname ALA”)
- name <atom>: Select by atom name (e.g., “name CA”)
- element <elem>: Select by element (e.g., “element C”)
- altloc <id>: Select by alternate location (e.g., “altloc A”)
- all: Select all atoms
- not <selection>: Negate selection
- <sel1> and <sel2>: Intersection
- <sel1> or <sel2>: Union
- Parentheses for grouping
- Returns:
New instance of the same class containing only selected atoms. If called on a subclass, returns an instance of that subclass.
- Return type:
- Raises:
RuntimeError – If the model has not been initialized.
ValueError – If selection syntax is invalid or no atoms are selected.
Examples
model = Model().load_pdb('structure.pdb')
# Select chain A
chain_a = model.select("chain A")
# Select backbone atoms
backbone = model.select("name CA or name C or name N or name O")
# Select residues 10-50 of chain B
region = model.select("chain B and resseq 10:50")
# Select all except water
no_water = model.select("not resname HOH")
# Complex selection with parentheses
complex_sel = model.select("chain A and (resname ALA or resname GLY)")
Notes
This method preserves the class type, so subclasses will return instances of themselves, not the base Model class.
- setup_gridsize(max_res=None)[source]
Compute optimal grid dimensions.
Delegates to FFT.compute_grid_size().
- Parameters:
max_res (float, optional) – Maximum resolution in Angstroms. If None, uses self.max_res.
- Returns:
Grid dimensions (nx, ny, nz) as int32 tensor.
- Return type:
- property A: Tensor
ITC92 A parameters (amplitudes) for all atoms.
- Returns:
A parameters with shape (n_atoms, 5).
- Return type:
- property B: Tensor
ITC92 B parameters (widths) for all atoms.
- Returns:
B parameters with shape (n_atoms, 5).
- Return type:
- get_iso()[source]
Get isotropic atoms with their ITC92 parameters.
- Returns:
xyz (torch.Tensor) – Atomic coordinates with shape (n_atoms, 3).
adp (torch.Tensor) – Atomic displacement parameters (isotropic) with shape (n_atoms,).
occupancy (torch.Tensor) – Occupancies with shape (n_atoms,).
A (torch.Tensor) – ITC92 A parameters (amplitudes) with shape (n_atoms, 5).
B (torch.Tensor) – ITC92 B parameters (widths) with shape (n_atoms, 5).
- get_aniso()[source]
Get anisotropic atoms with their ITC92 parameters.
- Returns:
xyz (torch.Tensor) – Atomic coordinates with shape (n_atoms, 3).
u (torch.Tensor) – Anisotropic U parameters with shape (n_atoms, 6).
occupancy (torch.Tensor) – Occupancies with shape (n_atoms,).
A (torch.Tensor) – ITC92 A parameters (amplitudes) with shape (n_atoms, 5).
B (torch.Tensor) – ITC92 B parameters (widths) with shape (n_atoms, 5).
- setup_grid(max_res=None, gridsize=None)[source]
Setup real-space grid for electron density calculation.
Delegates to FFT.setup_grid() using the stored cell and spacegroup.
- get_radius(min_radius_Angstrom=4.0)[source]
Get the radius in voxels used for density calculation around each atom.
- build_complete_map(radius=None, apply_symmetry=True)[source]
Build electron density map from all atoms.
Uses get_iso() and get_aniso() to get atom data and constructs the complete electron density map.
- Parameters:
- Returns:
Electron density map with symmetry applied if requested.
- Return type:
- build_initial_map(apply_symmetry=True)[source]
Build electron density map from atomic parameters.
Delegates to FFT.build_density_map() using the model’s stored parameters.
- Parameters:
apply_symmetry (bool, optional) – If True, apply crystallographic symmetry to the map. Default is True.
- Returns:
Electron density map with shape (nx, ny, nz).
- Return type:
- save_map(filename)[source]
Save the electron density map to a CCP4 format file.
- Parameters:
filename (str) – Output filename for the map.
- Raises:
ValueError – If no map has been computed yet.
- rebuild_map(radius=None)[source]
Rebuild the density map from scratch.
Convenience method that clears and rebuilds everything.
- Parameters:
radius (int, optional) – Radius in voxels around each atom. If None, uses self.radius. If specified, overrides self.radius.
- Returns:
Rebuilt electron density map.
- Return type:
- get_structure_factor(hkl, recalc=False, apply_anomalous=True)[source]
Get structure factors for given hkl reflections.
Uses CachedForwardMixin to cache the result and auto-invalidate when parameters change or a backward pass propagates through.
- Parameters:
hkl (torch.Tensor) – Miller indices with shape (n_reflections, 3).
recalc (bool, optional) – If True, forces recalculation bypassing the cache. Default is False.
apply_anomalous (bool, optional) – If True and wavelength is set, apply anomalous scattering corrections (f’ and f’’) for heavy atoms. Default is True.
- Returns:
Complex structure factors with shape (n_reflections,).
- Return type:
Notes
- The complete scattering factor is:
f(s, λ) = f₀(s) + f’(λ) + i·f’’(λ)
where f₀ is the normal (Thomson) scattering factor computed via FFT, and f’/f’’ are the wavelength-dependent anomalous corrections.
Anomalous corrections are only computed for atoms where |f'| > anomalous_threshold or |f''| > anomalous_threshold.
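The formula and threshold rule above can be sketched for a single atom in plain Python. This is only an illustration, not TorchRef's implementation: the Gaussian-sum convention for f₀, the coefficient values, the helper names, and the default threshold of 0.1 are all assumptions made for the example.

```python
import math

def f0_gaussians(s, a, b):
    """Normal scattering factor as a sum of Gaussians in s = 1/d (1/Angstrom).

    Assumed convention: f0(s) = sum_i a_i * exp(-b_i * s**2 / 4).
    """
    return sum(ai * math.exp(-bi * s * s / 4.0) for ai, bi in zip(a, b))

def scattering_factor(s, a, b, f_prime, f_double_prime, threshold=0.1):
    """Return f = f0 + f' + i*f'', skipping negligible anomalous terms."""
    f0 = f0_gaussians(s, a, b)
    # Only atoms with a significant f' or f'' get the anomalous correction.
    if abs(f_prime) > threshold or abs(f_double_prime) > threshold:
        return complex(f0 + f_prime, f_double_prime)
    return complex(f0, 0.0)

# Illustrative (not tabulated) five-term coefficients for one atom type.
a = [10.0, 8.0, 6.0, 4.0, 2.0]
b = [0.5, 2.0, 8.0, 20.0, 60.0]

# Heavy atom near an absorption edge: correction is applied.
f_heavy = scattering_factor(0.0, a, b, f_prime=-1.5, f_double_prime=3.0)
# Light atom: both terms are below the threshold, so f stays real.
f_light = scattering_factor(0.0, a, b, f_prime=0.01, f_double_prime=0.02)
```

At s = 0 every Gaussian equals its amplitude, so f₀ is simply the sum of the A terms; the heavy atom gains a real shift from f' and an imaginary part from f''.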
- property fft
Access the SfFFT submodule.
- forward(hkl, apply_anomalous=True)[source]
Compute structure factors for given hkl.
This is called by the mixin’s __call__, which handles caching, backward-hook registration, and auto-invalidation.
- Parameters:
hkl (torch.Tensor) – Miller indices with shape (n_reflections, 3).
apply_anomalous (bool, optional) – If True and wavelength is set, apply anomalous scattering corrections (f’ and f’’) for heavy atoms. Default is True.
- Returns:
Calculated complex structure factors with shape (n_reflections,).
- Return type:
- copy(detach=True)[source]
Create a deep copy of the ModelFT.
Creates a complete independent copy including all Model base class data, FFT submodule state (gridsize, real_space_grid, voxel_size, map_symmetry), ITC92 parametrization, and scalar attributes. Cache is reset to empty.
- Parameters:
detach (bool, optional) – If True, the copy’s parameters will be detached from the computation graph (default: True).
- Returns:
A new ModelFT instance with copied data.
- Return type:
Examples
model = ModelFT().load_pdb('structure.pdb')
model_copy = model.copy()
# model_copy is independent, changes won't affect model
- state_dict(destination=None, prefix='', keep_vars=False)[source]
Return a dictionary containing the complete state of the ModelFT.
Extends parent Model.state_dict() with FT-specific parameters including max_res, radius_angstrom. Grid state is handled by the FFT submodule.
- classmethod create_from_state_dict(state_dict, device=device(type='cpu'), verbose=1, dtype_float=torch.float32)[source]
Create a fully initialized ModelFT from a state dictionary.
This is the recommended way to restore a ModelFT from a saved state. Creates an instance with properly initialized submodules, then loads the state.
- Parameters:
state_dict (dict) – State dictionary from torch.save(model.state_dict(), …).
device (torch.device, optional) – Device to place tensors on. Defaults to the configured device.current.
verbose (int, optional) – Verbosity level. Default is 1.
dtype_float (torch.dtype, optional) – Float dtype for tensors. Default is dtypes.float.
- Returns:
Fully initialized instance with restored state.
- Return type:
- class torchref.model.MixedTensor(initial_values=None, refinable_mask=None, requires_grad=True, dtype=None, device=None, name=None)[source]
Bases: DeviceMixin, CachedForwardMixin, Module
A wrapper class for tensors with mixed fixed and refinable elements.
Stores a mask indicating which elements can be refined and maintains both fixed and refinable components separately. The full tensor is reconstructed on-the-fly when accessed.
- Parameters:
initial_values (torch.Tensor, optional) – Initial tensor values for all elements. Optional for empty init.
refinable_mask (torch.Tensor, optional) – Boolean mask indicating which elements can be refined. If None, all elements are refinable.
requires_grad (bool, optional) – Whether refinable parameters should have gradients. Default is True.
dtype (torch.dtype, optional) – Data type for the tensor. Default is same as initial_values.
device (torch.device, optional) – Device for the tensor. Default is same as initial_values.
name (str, optional) – Optional name for this parameter (useful for debugging/logging).
- refinable_mask
Boolean mask indicating refinable elements.
- Type:
- fixed_mask
Boolean mask indicating fixed elements (inverse of refinable_mask).
- Type:
- fixed_values
Buffer containing fixed values.
- Type:
- refinable_params
Parameter containing refinable values.
- Type:
nn.Parameter
Examples
Empty initialization for state_dict loading:
mixed = MixedTensor()
mixed.load_state_dict(torch.load('mixed.pt'))
Full initialization with values:
mask = torch.zeros(100, dtype=torch.bool)
mask[20:30] = True
initial_values = torch.randn(100)
mixed = MixedTensor(initial_values, refinable_mask=mask, requires_grad=True)
optimizer = torch.optim.Adam([mixed.refinable_params], lr=0.01)
- __init__(initial_values=None, refinable_mask=None, requires_grad=True, dtype=None, device=None, name=None)[source]
Initialize a MixedTensor.
If initial_values is provided, fully initializes the tensor. If not provided (empty init), creates a shell ready for load_state_dict().
- Parameters:
initial_values (torch.Tensor, optional) – Initial tensor values for all elements. Optional for empty init.
refinable_mask (torch.Tensor, optional) – Boolean mask indicating which elements can be refined. If None, all elements are refinable.
requires_grad (bool, optional) – Whether refinable parameters should have gradients. Default is True.
dtype (torch.dtype, optional) – Data type for the tensor. Default is same as initial_values.
device (torch.device, optional) – Device for the tensor. Default is same as initial_values.
name (str, optional) – Optional name for this parameter (useful for debugging/logging).
- forward()[source]
Reconstruct and return the full tensor.
Three fast paths, in priority order:
1. All atoms refinable: return refinable_params directly (no clone, no scatter). Common in standard refinement where no atoms are frozen. Saves a full-tensor clone plus an index_put_ per call and replaces the index_put backward (sort + atomic scatter) with a no-op identity.
2. No refinable atoms: return fixed_values.clone() (the clone preserves the caller-must-not-mutate contract even though the result is detached from autograd).
3. Mixed: go through _AssembleMixedTensor, whose backward is a single index_select (gather) instead of PyTorch’s default index_put_ backward (radix-sort + scatter).
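The dispatch between the three paths can be illustrated with plain Python lists. This is a simplified sketch of the control flow only: the function name `assemble` is hypothetical, and autograd, the custom backward, and tensor storage are deliberately left out.

```python
def assemble(fixed_values, refinable_params, refinable_mask):
    """Rebuild the full sequence from its fixed and refinable parts."""
    if all(refinable_mask):
        # Path 1: everything is refinable, hand back the parameters as-is.
        return refinable_params
    if not any(refinable_mask):
        # Path 2: nothing is refinable; return a copy so callers cannot
        # mutate the stored fixed values.
        return list(fixed_values)
    # Path 3: mixed. Start from the fixed values and scatter the refinable
    # parameters, in order, into the masked positions.
    full = list(fixed_values)
    params = iter(refinable_params)
    for i, is_refinable in enumerate(refinable_mask):
        if is_refinable:
            full[i] = next(params)
    return full

fixed = [1.0, 2.0, 3.0, 4.0]
mixed_result = assemble(fixed, [20.0, 40.0], [False, True, False, True])
all_params = [9.0, 9.0, 9.0, 9.0]
all_result = assemble(fixed, all_params, [True, True, True, True])
none_result = assemble(fixed, [], [False, False, False, False])
```

In the all-refinable case the very same object is returned, which is what makes that path free of copies; the fixed-only case returns an equal but distinct copy.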
- __getitem__(key)[source]
Get values at specified indices/mask from the full tensor.
- Parameters:
key (int, slice, torch.Tensor, or tuple) – Index specification. Can be an int (single element), a slice (range of elements, e.g., 5:10, :, ::2), a torch.Tensor (boolean mask or integer indices), or a tuple (multi-dimensional indexing).
- Returns:
Selected values from the full tensor.
- Return type:
Examples
model.b[5]       # Get B-factor for atom 5
model.b[5:10]    # Get B-factors for atoms 5-9
model.b[mask]    # Get B-factors where mask is True
model.xyz[:, 0]  # Get all x-coordinates
Notes
Subclasses may override _get_values() to customize value retrieval.
- __setitem__(key, value)[source]
Set values at specified indices/mask.
This method updates both fixed_values and refinable_params at the specified positions. Supports various indexing styles including slices, boolean masks, and integer indices.
- Parameters:
key (int, slice, torch.Tensor, or tuple) – Index specification. Can be an int (single element), a slice (range of elements, e.g., 5:10, :, ::2), a torch.Tensor (boolean mask or integer indices), or a tuple (multi-dimensional indexing).
value (torch.Tensor, float, or int) – Values to assign. A scalar is broadcast to all selected positions; a tensor must match the shape of the selected region.
Examples
model.b[:] = 30.0             # Set all B-factors to 30
model.b[5:10] = 25.0          # Set B-factors 5-9 to 25
model.b[mask] = new_values    # Set B-factors where mask is True
model.xyz[mask] = new_coords  # Set coordinates for masked atoms
model.xyz[:, 0] += 1.0        # Shift all x-coordinates (read-modify-write)
Notes
This method modifies the tensor in-place. The refinable_params parameter is replaced with a new Parameter containing the updated values, which may affect optimizer state.
Subclasses may override _set_values() to customize value handling (e.g., PositiveMixedTensor converts to log-space).
- set(values, mask)[source]
Set values at positions specified by a boolean mask.
Updates both fixed_values and refinable_params at the positions specified by the mask. This is useful for applying coordinate shifts, B-factor corrections, or any other updates to specific atoms.
- Parameters:
values (torch.Tensor) – New values to assign. Shape must be (n_selected,) for 1D tensors, where n_selected = mask.sum(), or (n_selected, d) for 2D tensors such as xyz, where d is the second dimension size (e.g., 3 for coordinates).
mask (torch.Tensor) – Boolean mask of shape (n_atoms,) indicating which elements to update. True positions will receive the new values.
- Raises:
ValueError – If mask shape doesn’t match tensor’s first dimension, or if values shape doesn’t match the number of selected elements.
Examples
# Update coordinates for selected atoms
mask = model.get_selection_mask("chain A")
new_coords = original_coords[mask] + shift
model.xyz.set(new_coords, mask)
# Update B-factors for specific residues
mask = model.get_selection_mask("resseq 10:20")
new_b = torch.ones(mask.sum()) * 30.0
model.b.set(new_b, mask)
Notes
This method modifies the tensor in-place. The refinable_params parameter is replaced with a new Parameter containing the updated values, which may affect optimizer state.
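The shape contract of set() (one value per True entry of the mask, ValueError otherwise) can be illustrated with a small stand-alone sketch on Python lists. The helper name `masked_set` is hypothetical and the code does not touch refinable_params or any optimizer state.

```python
def masked_set(full, values, mask):
    """Overwrite full[i] in place wherever mask[i] is True."""
    if len(mask) != len(full):
        raise ValueError("mask length must match the first dimension")
    n_selected = sum(1 for flag in mask if flag)
    if len(values) != n_selected:
        raise ValueError(
            "expected %d values, got %d" % (n_selected, len(values))
        )
    new_values = iter(values)
    for i, flag in enumerate(mask):
        if flag:
            full[i] = next(new_values)

# Set the last two B-factors to 30 and leave the others untouched.
b_factors = [20.0, 22.0, 35.0, 40.0]
masked_set(b_factors, [30.0, 30.0], [False, False, True, True])
```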
- property shape
Return the shape of the full tensor.
- property dtype
Return the dtype of the tensor.
- property device
Return the device of the tensor.
- update_fixed_values(new_values)[source]
Update the fixed values (does not affect refinable parameters).
- Parameters:
new_values (torch.Tensor) – New tensor values. Only fixed positions will be updated.
- Raises:
ValueError – If new_values shape doesn’t match tensor shape.
- update_refinable_mask(new_mask, reset_refinable=False)[source]
Update which elements are refinable.
This is an advanced operation that modifies the refinable/fixed split.
- Parameters:
new_mask (torch.Tensor) – New boolean mask indicating refinable elements.
reset_refinable (bool, optional) – If True, reset refinable parameters to current fixed values. If False, keep existing refinable parameter values where possible. Default is False.
- copy()[source]
Create a deep copy of this MixedTensor.
Creates a complete independent copy with all buffers and parameters. Alias for clone().
- Returns:
New MixedTensor instance with copied data.
- Return type:
- clip(min_value=None, max_value=None)[source]
Clip the full tensor values between min_value and max_value.
- refine(selection, reset_values=False)[source]
Make a selection of the tensor refinable.
- Parameters:
selection (slice, torch.Tensor, or tuple) – Selection indicating which elements should become refinable. Can be a boolean tensor of the same shape as the full tensor, a slice object (e.g., slice(10, 20)), a tuple of indices for multidimensional tensors, or integer indices.
reset_values (bool, optional) – If True, reset the selected elements to their current fixed values before making them refinable. Default is False.
Examples
mixed.refine(slice(10, 20))  # Make elements 10-19 refinable
mixed.refine(mask)           # Make elements where mask is True refinable
- fix(selection, freeze_at_current=True)[source]
Make a selection of the tensor fixed (non-refinable).
- Parameters:
selection (slice, torch.Tensor, or tuple) – Selection indicating which elements should become fixed. Can be a boolean tensor of the same shape as the full tensor, a slice object (e.g., slice(10, 20)), a tuple of indices for multidimensional tensors, or integer indices.
freeze_at_current (bool, optional) – If True (default), freeze the selected elements at their current values. If False, they revert to the original fixed values.
Examples
mixed.fix(slice(10, 20))  # Fix elements 10-19
mixed.fix(mask)           # Fix elements where mask is True
- class torchref.model.PositiveMixedTensor(initial_values=None, refinable_mask=None, requires_grad=True, dtype=None, device=None, name=None, epsilon=0.1)[source]
Bases: MixedTensor
A MixedTensor subclass ensuring all values are positive via log-space parametrization.
Useful for parameters that must be strictly positive (e.g., B-factors, scale factors, sigma values). Values are stored as logarithms internally and converted to normal space via exp() when accessed.
Reparametrization:
internal_value = log(desired_value)
output_value = exp(internal_value)
This ensures output_value > 0 always, with smooth gradient flow.
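A scalar sketch shows why the reparametrization keeps values positive even under a large gradient step. It uses plain Python rather than TorchRef; the helper names, the learning rate, and the gradient value are made up for the example, and the epsilon offset mentioned in the parameters is left out.

```python
import math

def to_internal(value):
    """Map a strictly positive value to log space."""
    if value <= 0:
        raise ValueError("value must be positive")
    return math.log(value)

def to_output(internal):
    """Map a log-space value back to normal space; the result is positive."""
    return math.exp(internal)

b = 20.0            # current B-factor in normal space
grad_output = 25.0  # hypothetical dLoss/dB
lr = 1.0            # hypothetical learning rate

# A plain gradient step in normal space overshoots below zero.
naive_b = b - lr * grad_output

# The same step in log space: by the chain rule,
# dLoss/d(internal) = dLoss/dB * dB/d(internal) = grad_output * B.
internal = to_internal(b)
internal -= lr * grad_output * to_output(internal)
new_b = to_output(internal)
```

The log-space update can shrink the value drastically, but exp() never returns a negative number, so the positivity constraint needs no clamping.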
- Parameters:
initial_values (torch.Tensor, optional) – Initial tensor values in NORMAL space. Optional for empty init.
refinable_mask (torch.Tensor, optional) – Boolean mask indicating which elements can be refined.
requires_grad (bool, optional) – Whether refinable parameters should have gradients. Default is True.
dtype (torch.dtype, optional) – Data type for the tensor.
device (torch.device, optional) – Device for the tensor.
name (str, optional) – Optional name for this parameter.
epsilon (float, optional) – Small value to add before taking log to avoid log(0). Default is 1e-1.
Examples
Empty initialization for state_dict loading:
b = PositiveMixedTensor()
b.load_state_dict(torch.load('b_factors.pt'))
Full initialization with values:
initial_b = torch.tensor([20.0, 30.0, 15.0])
b = PositiveMixedTensor(initial_b)
output = b()  # Returns exp(log_b) = positive values
assert (b() > 0).all()
- __init__(initial_values=None, refinable_mask=None, requires_grad=True, dtype=None, device=None, name=None, epsilon=0.1)[source]
Initialize a PositiveMixedTensor.
If initial_values is provided, fully initializes the tensor. If not provided (empty init), creates a shell ready for load_state_dict().
- Parameters:
initial_values (torch.Tensor, optional) – Initial tensor values in NORMAL space. Optional for empty init.
refinable_mask (torch.Tensor, optional) – Boolean mask indicating which elements can be refined.
requires_grad (bool, optional) – Whether refinable parameters should have gradients. Default is True.
dtype (torch.dtype, optional) – Data type for the tensor.
device (torch.device, optional) – Device for the tensor.
name (str, optional) – Optional name for this parameter.
epsilon (float, optional) – Small value to add before taking log to avoid log(0). Default is 1e-1.
- Raises:
ValueError – If any initial values are not positive.
- forward()[source]
Return the full tensor in NORMAL space.
Applies exponential transformation to the log-space values.
- Returns:
Tensor with positive values.
- Return type:
- fix(mask, freeze_at_current=True)[source]
Fix (freeze) specific elements.
Converts current normal-space values to log space for storage.
- Parameters:
mask (torch.Tensor) – Boolean mask indicating which elements to fix.
freeze_at_current (bool, optional) – If True, freeze at current values. Default is True.
- refine(mask)[source]
Make specific elements refinable.
Preserves current log-space values.
- Parameters:
mask (torch.Tensor) – Boolean mask indicating which elements to make refinable.
- set(values, mask)[source]
Set values at positions specified by a boolean mask.
Values are provided in NORMAL space (e.g., actual B-factors) and automatically converted to log-space for internal storage.
- Parameters:
values (torch.Tensor) – New values to assign in NORMAL space (positive values). Shape must be (n_selected,) where n_selected = mask.sum().
mask (torch.Tensor) – Boolean mask of shape (n_atoms,) indicating which elements to update. True positions will receive the new values.
- Raises:
ValueError – If mask shape doesn’t match tensor’s first dimension, if values shape doesn’t match the number of selected elements, or if any values are not positive.
Examples
# Update B-factors for selected atoms
mask = model.get_selection_mask("name CA")
new_b = torch.ones(mask.sum()) * 30.0  # Set CA B-factors to 30
model.b.set(new_b, mask)
Notes
This method modifies the tensor in-place. Values are automatically converted to log-space internally to maintain the positivity constraint.
- get_log_values()[source]
Return the internal log-space representation.
Useful for debugging or when direct access to the parametrization space is needed.
- Returns:
Tensor with log-space values.
- Return type:
- update_refinable_mask(new_mask, reset_refinable=False)[source]
Update which elements are refinable.
Properly handles log-space conversion.
- Parameters:
new_mask (torch.Tensor) – New boolean mask indicating refinable elements.
reset_refinable (bool, optional) – If True, reset refinable parameters to current fixed values. If False, keep existing refinable parameter values where possible. Default is False.
- class torchref.model.PassThroughTensor(initial_values, requires_grad=True, dtype=None, device=None, name=None)[source]
Bases: DeviceMixin, Module
A simple parameter wrapper that passes the parameter through unchanged.
Useful as a placeholder or for parameters that do not require any special handling.
- Parameters:
initial_values (torch.Tensor) – Initial tensor values.
requires_grad (bool, optional) – Whether the parameter requires gradients. Default is True.
dtype (torch.dtype, optional) – Data type of the tensor.
device (torch.device, optional) – Device to place the tensor on.
name (str, optional) – Optional name for the parameter.
- __init__(initial_values, requires_grad=True, dtype=None, device=None, name=None)[source]
Initialize the PassThroughTensor.
- Parameters:
initial_values (torch.Tensor) – Initial tensor values.
requires_grad (bool, optional) – Whether the parameter requires gradients. Default is True.
dtype (torch.dtype, optional) – Data type of the tensor.
device (torch.device, optional) – Device to place the tensor on.
name (str, optional) – Optional name for the parameter.
- class torchref.model.OccupancyTensor(initial_values=None, sharing_groups=None, altloc_groups=None, refinable_mask=None, requires_grad=True, dtype=None, device=None, name=None, use_sigmoid=True)[source]
Bases: MixedTensor
A specialized MixedTensor for handling occupancy parameters in crystallographic refinement.
Handles specific constraints and requirements for occupancy including value bounds [0, 1] via sigmoid reparameterization, atom sharing, and alternative conformation sum-to-1 constraints.
Features:
- Values bounded between 0 and 1 using sigmoid reparameterization
- Atoms can share occupancies (e.g., all atoms in a residue)
- Alternative conformations automatically sum to 1.0 via normalization
- Memory-efficient collapsed storage (one parameter per sharing group)
- Fully vectorized collapse/expand operations
- Parameters:
initial_values (torch.Tensor, optional) – Initial occupancy values for ALL atoms (should be in [0, 1]). Optional for empty init.
sharing_groups (torch.Tensor, optional) – Tensor of shape (n_atoms,) where each value is the collapsed index for that atom. If None, each atom has independent occupancy.
altloc_groups (list of tuple, optional) – List of tuples of atom index lists representing alternative conformations. Each tuple contains the atom indices for each conformation.
refinable_mask (torch.Tensor, optional) – Boolean mask for which ATOMS can be refined (in full tensor space).
requires_grad (bool, optional) – Whether refinable parameters should have gradients. Default is True.
dtype (torch.dtype, optional) – Data type for the tensor.
device (torch.device, optional) – Device for the tensor.
name (str, optional) – Optional name for this parameter.
use_sigmoid (bool, optional) – If True, use sigmoid parameterization to bound values to [0,1]. Default is True.
- expansion_mask
Maps atoms to collapsed indices.
- Type:
- collapse_counts
Count of atoms per collapsed index.
- Type:
Examples
sharing_groups = torch.tensor([0, 0, 1, 1, 2, 2])
occ = OccupancyTensor(
    initial_values=torch.tensor([1.0, 1.0, 0.7, 0.7, 0.3, 0.3]),
    sharing_groups=sharing_groups,
    altloc_groups=[([2, 3], [4, 5])],
)
result = occ()
# Atoms 2-3 and 4-5 will sum to 1.0
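How collapsed storage, the sigmoid bound, and altloc normalization fit together can be sketched without TorchRef. The following plain-Python version is an assumption-laden illustration, not the library's vectorized code: the function names are hypothetical, and each conformation is assumed to share a single occupancy, so its first atom is used as the representative.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def logit(p):
    return math.log(p / (1.0 - p))

def expand_occupancy(raw_params, sharing_groups, altloc_groups):
    """Expand one raw parameter per group into per-atom occupancies."""
    # Sigmoid bounds every group value to the open interval (0, 1).
    group_occ = [sigmoid(x) for x in raw_params]
    # Expansion: each atom looks up the value of its sharing group.
    occ = [group_occ[g] for g in sharing_groups]
    # Normalize each set of alternative conformations so they sum to 1.
    for conformations in altloc_groups:
        reps = [occ[atoms[0]] for atoms in conformations]
        total = sum(reps)
        for atoms, rep in zip(conformations, reps):
            for i in atoms:
                occ[i] = rep / total
    return occ

sharing_groups = [0, 0, 1, 1, 2, 2]
raw = [logit(0.99), logit(0.8), logit(0.4)]  # three stored parameters
occ = expand_occupancy(raw, sharing_groups, [([2, 3], [4, 5])])
```

Six atoms are driven by three stored parameters; the two conformations start at 0.8 and 0.4 and are rescaled to 2/3 and 1/3.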
- __init__(initial_values=None, sharing_groups=None, altloc_groups=None, refinable_mask=None, requires_grad=True, dtype=None, device=None, name=None, use_sigmoid=True)[source]
Initialize an OccupancyTensor with collapsed storage and altloc support.
If initial_values is provided, fully initializes the tensor. If not provided (empty init), creates a shell ready for load_state_dict().
- Parameters:
initial_values (torch.Tensor, optional) – Initial occupancy values for ALL atoms (should be in [0, 1]). Optional for empty init.
sharing_groups (torch.Tensor, optional) – Tensor of shape (n_atoms,) where each value is the collapsed index for that atom. If None, each atom has independent occupancy. Example: tensor([0, 0, 0, 1, 1, 2]) means atoms 0,1,2 share one occupancy, atoms 3,4 share another, and atom 5 is independent.
altloc_groups (list of tuple, optional) – List of tuples of atom index lists representing alternative conformations. Example: [([10,11], [12,13])] means atoms 10,11 (conf A) and 12,13 (conf B) are altlocs that sum to 1.0.
refinable_mask (torch.Tensor, optional) – Boolean mask for which ATOMS can be refined (in full tensor space). If any atom in a group is refinable, the entire group becomes refinable.
requires_grad (bool, optional) – Whether refinable parameters should have gradients. Default is True.
dtype (torch.dtype, optional) – Data type for the tensor.
device (torch.device, optional) – Device for the tensor.
name (str, optional) – Optional name for this parameter.
use_sigmoid (bool, optional) – If True, use sigmoid parameterization to bound values to [0,1]. Default is True.
- forward()[source]
Reconstruct full occupancy tensor with sigmoid and altloc constraints.
For alternative conformations, applies sigmoid then normalizes within each group to enforce sum-to-1 constraint.
- Returns:
Full occupancy tensor with values in [0, 1] and shape (n_atoms,).
- Return type:
- property shape
Return the shape of the FULL tensor (not collapsed).
- property collapsed_shape
Return the shape of the collapsed internal storage.
- clamp(min_value=0.0, max_value=1.0)[source]
Clamp occupancy values to specified range and return a new OccupancyTensor.
- Parameters:
- Returns:
New OccupancyTensor with clamped values.
- Return type:
- set_group_occupancy(group_idx, value)[source]
Set the occupancy for all atoms in a specific collapsed group.
- Parameters:
- Raises:
ValueError – If group_idx is out of range or value is not in [0, 1].
- get_group_occupancy(group_idx)[source]
Get the current occupancy value for a collapsed group.
- Parameters:
group_idx (int) – Collapsed index of the group.
- Returns:
Current occupancy value for the group.
- Return type:
- Raises:
ValueError – If group_idx is out of range.
- freeze(mask=None)[source]
Freeze occupancy parameters, making them non-refinable.
The mask is supplied in UNCOMPRESSED (full atom) form but freezing operates on the COMPRESSED data structure. This method handles the conversion.
- Parameters:
mask (torch.Tensor, optional) – Boolean mask in FULL (uncompressed) atom space indicating which atoms to freeze. If None, freeze all parameters. Shape must be (n_atoms,).
Notes
If ANY atom in a sharing group is frozen, the ENTIRE group is frozen because all atoms in a group share the same compressed parameter.
Examples
# Freeze atoms 0-10 (in full atom space)
freeze_mask = torch.zeros(n_atoms, dtype=torch.bool)
freeze_mask[0:11] = True
occ.freeze(freeze_mask)
# Freeze all atoms
occ.freeze()
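The conversion between full atom space and compressed group space described in the Notes can be sketched as follows. This is a plain-Python illustration with hypothetical helper names, assuming sharing_groups holds one group index per atom; it is not the library's internal routine.

```python
def collapse_any(atom_mask, sharing_groups, n_groups):
    """A group is selected if ANY of its atoms is selected."""
    group_mask = [False] * n_groups
    for flag, group in zip(atom_mask, sharing_groups):
        if flag:
            group_mask[group] = True
    return group_mask

def expand(group_mask, sharing_groups):
    """Every atom inherits the state of its group."""
    return [group_mask[group] for group in sharing_groups]

# Atoms 0-2 share group 0, atoms 3-4 share group 1, atom 5 is alone.
sharing_groups = [0, 0, 0, 1, 1, 2]

# Ask to freeze only atom 1.
freeze_request = [False, True, False, False, False, False]
frozen_groups = collapse_any(freeze_request, sharing_groups, n_groups=3)
frozen_atoms = expand(frozen_groups, sharing_groups)

n_frozen_groups = sum(frozen_groups)  # count in compressed space
n_frozen_atoms = sum(frozen_atoms)    # count in full atom space
```

Freezing a single atom freezes its whole group, so the per-group and per-atom counts differ (here one group, three atoms). This is the same distinction drawn between get_fixed_count() and get_frozen_atoms().sum().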
- unfreeze(mask=None)[source]
Unfreeze occupancy parameters, making them refinable.
The mask is supplied in UNCOMPRESSED (full atom) form but unfreezing operates on the COMPRESSED data structure. This method handles the conversion.
- Parameters:
mask (torch.Tensor, optional) – Boolean mask in FULL (uncompressed) atom space indicating which atoms to unfreeze. If None, unfreeze all parameters. Shape must be (n_atoms,).
Notes
If ANY atom in a sharing group is unfrozen, the ENTIRE group becomes refinable because all atoms in a group share the same compressed parameter.
Examples
# Unfreeze atoms 100-200 (in full atom space)
unfreeze_mask = torch.zeros(n_atoms, dtype=torch.bool)
unfreeze_mask[100:201] = True
occ.unfreeze(unfreeze_mask)
# Unfreeze all atoms
occ.unfreeze()
- freeze_all()[source]
Freeze all occupancy parameters.
Convenience method equivalent to freeze(None).
- unfreeze_all()[source]
Unfreeze all occupancy parameters.
Convenience method equivalent to unfreeze(None).
- get_refinable_atoms()[source]
Get a boolean mask in FULL atom space indicating refinable atoms.
- Returns:
Boolean tensor of shape (n_atoms,) where True indicates the atom’s occupancy is refinable (though it shares with others in its group).
- Return type:
- get_frozen_atoms()[source]
Get a boolean mask in FULL atom space indicating frozen atoms.
- Returns:
Boolean tensor of shape (n_atoms,) where True indicates the atom’s occupancy is frozen.
- Return type:
- get_refinable_count()[source]
Get the number of refinable parameters in COMPRESSED space.
This is the number of refinable groups, not atoms. Use get_refinable_atoms().sum() to get the number of refinable atoms.
- Returns:
Number of refinable compressed parameters.
- Return type:
- get_fixed_count()[source]
Get the number of fixed parameters in COMPRESSED space.
This is the number of fixed groups, not atoms. Use get_frozen_atoms().sum() to get the number of frozen atoms.
- Returns:
Number of fixed compressed parameters.
- Return type:
- update_refinable_mask(new_mask, in_compressed_space=False)[source]
Directly update the refinable mask with a new mask.
Allows more direct control over which parameters are refinable, compared to freeze/unfreeze which modify the existing state.
- Parameters:
new_mask (torch.Tensor) – Boolean tensor indicating which parameters should be refinable. If in_compressed_space=False: shape (n_atoms,) in full atom space. If in_compressed_space=True: shape (n_groups,) in compressed space.
in_compressed_space (bool, optional) – If True, new_mask is in compressed space. If False (default), new_mask is in full atom space and will be collapsed.
Examples
Full atom space:
atom_mask = torch.zeros(n_atoms, dtype=torch.bool)
atom_mask[:100] = True
occ.update_refinable_mask(atom_mask, in_compressed_space=False)
Compressed space:
group_mask = torch.zeros(n_groups, dtype=torch.bool)
group_mask[::2] = True
occ.update_refinable_mask(group_mask, in_compressed_space=True)
- static from_residue_groups(initial_values, pdb_dataframe, refinable_mask=None, **kwargs)[source]
Create an OccupancyTensor where all atoms in each residue share occupancy.
Common use case where all atoms in a residue should have the same occupancy.
- Parameters:
initial_values (torch.Tensor) – Initial occupancy values for all atoms.
pdb_dataframe (pandas.DataFrame) – DataFrame with PDB data (must have ‘resname’, ‘resseq’, ‘chainid’).
refinable_mask (torch.Tensor, optional) – Mask for refinable atoms.
**kwargs – Additional arguments passed to OccupancyTensor constructor.
- Returns:
OccupancyTensor with residue-based sharing groups.
- Return type:
- class torchref.model.SegmentedInternalCoordinateTensor(initial_xyz, pdb, n_aa_per_segment=3, bond_cutoff=2.0, cif_dict=None, requires_grad=True, dtype=None, device=None)[source]
Bases: DeviceMixin, CachedForwardMixin, Module
Parameter wrapper using segmented internal coordinates.
Stores: per-segment bond_lengths, angles, torsions, segment_positions, segment_orientations.
Reconstructs: Cartesian xyz on forward().
This provides a physically meaningful parametrization that avoids the lever arm problem by breaking the molecule into independent segments, each with shallow spanning trees and rigid body parameters.
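To make the internal-coordinate idea concrete, here is a minimal sketch of how one atom can be placed from a bond length, a bond angle, and a torsion relative to three already placed atoms (the widely used NeRF construction). It is a generic illustration in plain Python, not the class's vectorized forward(); all names and the numeric geometry are assumptions.

```python
import math

def sub(u, v):
    return [a - b for a, b in zip(u, v)]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def cross(u, v):
    return [u[1] * v[2] - u[2] * v[1],
            u[2] * v[0] - u[0] * v[2],
            u[0] * v[1] - u[1] * v[0]]

def unit(u):
    norm = math.sqrt(dot(u, u))
    return [a / norm for a in u]

def place_atom(a, b, c, bond, angle, torsion):
    """Place atom d so that |cd| = bond, angle(b, c, d) = angle and
    dihedral(a, b, c, d) = torsion (angles in radians)."""
    bc = unit(sub(c, b))
    n = unit(cross(sub(b, a), bc))  # normal of the a-b-c plane
    m = cross(n, bc)                # completes the local frame
    local = [-bond * math.cos(angle),
             bond * math.sin(angle) * math.cos(torsion),
             bond * math.sin(angle) * math.sin(torsion)]
    return [c[i] + local[0] * bc[i] + local[1] * m[i] + local[2] * n[i]
            for i in range(3)]

def dihedral(a, b, c, d):
    """Signed torsion angle a-b-c-d in radians."""
    b1, b2, b3 = sub(b, a), sub(c, b), sub(d, c)
    n1, n2 = cross(b1, b2), cross(b2, b3)
    return math.atan2(math.sqrt(dot(b2, b2)) * dot(b1, n2), dot(n1, n2))

# Three reference atoms (illustrative coordinates in Angstroms).
p_a = [0.0, 0.0, 0.0]
p_b = [1.46, 0.0, 0.0]
p_c = [2.0, 1.4, 0.0]
p_d = place_atom(p_a, p_b, p_c, bond=1.33,
                 angle=math.radians(116.0), torsion=math.radians(-60.0))
```

Because each atom depends on its ancestors in the tree, a small torsion change near the root moves every downstream atom. Keeping the trees shallow and giving each segment its own position and orientation limits that lever-arm effect.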
- Parameters:
initial_xyz (torch.Tensor) – Initial Cartesian coordinates of shape (N, 3).
pdb (pd.DataFrame) – PDB DataFrame with columns ‘chainid’, ‘resseq’, ‘name’, ‘index’, ‘resname’.
n_aa_per_segment (int, optional) – Number of amino acids per segment. Default is 3.
bond_cutoff (float, optional) – Distance cutoff for bond detection in Angstroms. Default is 2.0.
requires_grad (bool, optional) – Whether parameters should have gradients. Default is True.
dtype (torch.dtype, optional) – Data type for tensors. Default is same as initial_xyz.
device (torch.device, optional) – Device for tensors. Default is same as initial_xyz.
- bond_lengths
Bond length parameters in Angstroms.
- Type:
nn.Parameter
- angles
Angle parameters in radians.
- Type:
nn.Parameter
- torsions
Torsion angle parameters in radians.
- Type:
nn.Parameter
- segment_positions
Absolute positions of segment root atoms.
- Type:
nn.Parameter
- segment_orientations
ZYZ Euler angle orientations for each segment.
- Type:
nn.Parameter
- AA_NAMES = frozenset({'ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE', 'LEU', 'LYS', 'MET', 'MSE', 'PHE', 'PRO', 'SEC', 'SER', 'THR', 'TRP', 'TYR', 'VAL'})
- __init__(initial_xyz, pdb, n_aa_per_segment=3, bond_cutoff=2.0, cif_dict=None, requires_grad=True, dtype=None, device=None)[source]
Initialize SegmentedInternalCoordinateTensor.
- Parameters:
initial_xyz (torch.Tensor) – Initial Cartesian coordinates of shape (N, 3).
pdb (pd.DataFrame) – PDB DataFrame with columns ‘chainid’, ‘resseq’, ‘name’, ‘index’, ‘resname’.
n_aa_per_segment (int, optional) – Number of amino acids per segment. Default is 3.
bond_cutoff (float, optional) – Distance cutoff for bond detection in Angstroms (used as fallback). Default is 2.0.
cif_dict (dict, optional) – CIF dictionary containing bond definitions per residue type. If provided, bonds are determined from chemical definitions rather than distances, which is more robust for structures with poor geometry. Expected format: cif_dict[resname][‘bonds’] is a DataFrame with ‘atom1’ and ‘atom2’ columns.
requires_grad (bool, optional) – Whether parameters should have gradients. Default is True.
dtype (torch.dtype, optional) – Data type for tensors. Default is same as initial_xyz.
device (torch.device, optional) – Device for tensors. Default is same as initial_xyz.
- property dtype
Return the dtype of tensors.
- property device
Return the device of tensors.
- forward()[source]
Reconstruct Cartesian xyz from internal coordinates.
Uses fully vectorized operations for maximum performance.
- Returns:
Reconstructed Cartesian coordinates of shape (N, 3).
- Return type:
- shake(magnitude=0.1)[source]
Add Gaussian noise to internal parameters.
- Parameters:
magnitude (float, optional) – Standard deviation of Gaussian noise. Default is 0.1.
- Returns:
New Cartesian coordinates after perturbation.
- Return type:
- fix(selection=None, freeze_at_current=True)[source]
Fix (freeze) atoms to use fixed xyz coordinates.
- Parameters:
selection (torch.Tensor, slice, or None) – Boolean mask or indices of atoms to fix.
freeze_at_current (bool, optional) – If True, store current coordinates for selected atoms.
- refine(selection=None, rebuild=True)[source]
Make atoms refinable.
- Parameters:
selection (torch.Tensor, slice, or None) – Boolean mask or indices of atoms to make refinable.
rebuild (bool, optional) – If True, rebuild internal coordinates from fixed_xyz.
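The interplay of fix() and refine() can be pictured as a boolean mask plus a store of frozen coordinates that is overlaid on the reconstructed coordinates. The class below is a minimal sketch of that idea in plain Python; `FrozenOverlay` and its attributes are hypothetical, the behaviour for freeze_at_current=False is an assumption, and the rebuild option of refine() is not modelled.

```python
class FrozenOverlay:
    """Track which atoms are fixed and overlay their stored coordinates."""

    def __init__(self, n_atoms):
        self.fixed_mask = [False] * n_atoms
        self.fixed_xyz = [None] * n_atoms

    def fix(self, selection, current_xyz, freeze_at_current=True):
        """Mark atoms in selection (indices) as fixed.

        Assumption: with freeze_at_current=False, previously stored
        coordinates are kept when they exist.
        """
        for i in selection:
            self.fixed_mask[i] = True
            if freeze_at_current or self.fixed_xyz[i] is None:
                self.fixed_xyz[i] = tuple(current_xyz[i])

    def refine(self, selection):
        """Make atoms in selection refinable again."""
        for i in selection:
            self.fixed_mask[i] = False

    def apply(self, reconstructed):
        """Return coordinates with fixed atoms replaced by their stored values."""
        return [self.fixed_xyz[i] if self.fixed_mask[i] else tuple(p)
                for i, p in enumerate(reconstructed)]

overlay = FrozenOverlay(3)
start = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (2.0, 0.0, 0.0)]
overlay.fix([1], start)

# After an optimizer step moved every atom, atom 1 stays where it was frozen.
moved = [(0.5, 0.0, 0.0), (1.5, 0.0, 0.0), (2.5, 0.0, 0.0)]
result = overlay.apply(moved)
```

In a differentiable setting the same overlay means that no gradient flows to the internal parameters through the fixed atoms, since their output no longer depends on those parameters.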
- class torchref.model.ClosedSegmentedInternalCoordinateTensor(initial_xyz, pdb, n_aa_per_segment=18, junction_size=3, bond_cutoff=2.0, cif_dict=None, prefer_loops=True, requires_grad=True, dtype=None, device=None)[source]
Bases:
DeviceMixin, CachedForwardMixin, Module
Parameter wrapper using segmented internal coordinates with chain closure.
Stores: per-segment bond_lengths, angles, torsions, segment_positions, segment_orientations. Junction backbone torsions are dependent (slave) degrees of freedom: they are solved by Newton’s method, and their gradients are obtained via the implicit function theorem (IFT).
- Parameters:
initial_xyz (torch.Tensor) – Initial Cartesian coordinates of shape (N, 3).
pdb (pd.DataFrame) – PDB DataFrame with columns ‘chainid’, ‘resseq’, ‘name’, ‘index’, ‘resname’.
n_aa_per_segment (int, optional) – Number of amino acids per segment. Default is 18.
junction_size (int, optional) – Number of residues per junction. Default is 3.
bond_cutoff (float, optional) – Distance cutoff for bond detection in Angstroms. Default is 2.0.
cif_dict (dict, optional) – CIF dictionary containing bond definitions per residue type.
prefer_loops (bool, optional) – If True, slide junctions to prefer loop regions. Default is True.
requires_grad (bool, optional) – Whether parameters should have gradients. Default is True.
dtype (torch.dtype, optional) – Data type for tensors.
device (torch.device, optional) – Device for tensors.
- __init__(initial_xyz, pdb, n_aa_per_segment=18, junction_size=3, bond_cutoff=2.0, cif_dict=None, prefer_loops=True, requires_grad=True, dtype=None, device=None)[source]
- property dtype
- property device
- forward()[source]
Reconstruct Cartesian xyz from internal coordinates with chain closure.
Steps:
1. Place segment atoms (existing NeRF pipeline).
2. Solve junction closures (Newton + IFT).
3. Place junction backbone atoms.
4. Place junction sidechain atoms.
5. Apply frozen overlay.
- Returns:
Reconstructed Cartesian coordinates of shape (N, 3).
- Return type:
torch.Tensor
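The "Newton + IFT" closure step can be illustrated on a one-dimensional toy problem. Here a planar two-link chain with unit link lengths must span a gap of length p; the bend angle t is the dependent variable, found by Newton's method from the residual g(t, p) = 2·cos(t/2) − p. The implicit function theorem then gives dt/dp = −(∂g/∂p)/(∂g/∂t) without differentiating through the Newton iterations. This toy residual and all function names are assumptions for illustration, not the junction geometry used by TorchRef.

```python
import math

def residual(t, p):
    """Closure residual: end-to-end distance of the two-link chain minus the gap."""
    return 2.0 * math.cos(t / 2.0) - p

def d_residual_dt(t, p):
    return -math.sin(t / 2.0)

def d_residual_dp(t, p):
    return -1.0

def solve_slave(p, t0=1.0, tol=1e-12, max_iter=50):
    """Newton iteration for the dependent angle t such that residual(t, p) = 0."""
    t = t0
    for _ in range(max_iter):
        g = residual(t, p)
        if abs(g) < tol:
            break
        t -= g / d_residual_dt(t, p)
    return t

def slave_gradient(t, p):
    """dt/dp from the implicit function theorem, evaluated at a solved (t, p)."""
    return -d_residual_dp(t, p) / d_residual_dt(t, p)

gap = 1.2
t_star = solve_slave(gap)
dt_dgap = slave_gradient(t_star, gap)
```

The design point is that the gradient only needs the partial derivatives at the converged solution, so the cost of the backward pass does not grow with the number of Newton iterations.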
- shake(magnitude=0.1)[source]
Add Gaussian noise to internal parameters.
Perturbs torsions, bond lengths, and bond angles. Segment positions and orientations are NOT perturbed because random independent translation of segments creates gaps that exceed the junction chain’s reach. During optimization, the optimizer adjusts these rigid-body DOFs smoothly.
- class torchref.model.ModelCollection(base_models, dark_key='dark', verbose=0)[source]
Bases:
DeviceMixin, Module
Named dictionary of MixedModel instances at different timepoints.
All timepoint models share the same base structural models (ModelFT objects stored once in an nn.ModuleList). Each timepoint gets its own independent fraction parameters via _SharedMixedModel.
Keys should match DatasetCollection keys so that collection-aware targets can automatically pair datasets with models.
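- Parameters:
base_models (List[ModelFT]) – Shared structural models.
dark_key (str) – Key for the dark entry. Default is 'dark'.
verbose (int) – Verbosity level. Default is 0.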
Examples
models = ModelCollection([model_dark, model_light])
models.add_dark()  # fractions=[1, 0]
models.add_timepoint("1ps", [0.9, 0.1])
models.add_timepoint("5ps", [0.7, 0.3])
# Access
mixed = models["1ps"]
fcalc = mixed(hkl)
print(mixed.fractions)
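What a per-timepoint mixed model computes can be sketched as a fraction-weighted sum of the complex structure factors of the shared base models. The snippet below assumes linear mixing, F_mix(hkl) = Σ_k f_k · F_k(hkl), which is an assumption about the model rather than something this page states; `mix_structure_factors` is a hypothetical helper and the numbers are made up.

```python
def mix_structure_factors(fractions, fcalc_per_model):
    """Fraction-weighted sum of per-model complex structure factors.

    fractions: one weight per base model.
    fcalc_per_model: one list of complex F values per base model,
    all evaluated on the same reflections.
    """
    if len(fractions) != len(fcalc_per_model):
        raise ValueError("need exactly one fraction per base model")
    n_refl = len(fcalc_per_model[0])
    return [sum(f * fc[i] for f, fc in zip(fractions, fcalc_per_model))
            for i in range(n_refl)]

# Two base models (dark and light), two reflections each (illustrative values).
f_dark = [10 + 0j, 2 + 2j]
f_light = [8 + 1j, 4 - 2j]

f_dark_only = mix_structure_factors([1.0, 0.0], [f_dark, f_light])
f_1ps = mix_structure_factors([0.9, 0.1], [f_dark, f_light])
```

Because the base models are shared, changing the fractions of one timepoint changes only that timepoint's mixture, while a change to a base structure affects every timepoint that uses it.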
- add_timepoint(name, fractions=None, frozen_fractions=False)[source]
Add a timepoint with given initial fractions.
- Parameters:
- Returns:
Self, for method chaining.
- Return type:
ModelCollection
- add_dark(fractions=None)[source]
Add the dark / reference entry.
Default fractions: [1, 0, 0, …] (100 % ground state).
- Parameters:
fractions (List[float], optional) – Override dark fractions. Default is pure ground state.
- Returns:
Self, for method chaining.
- Return type:
ModelCollection
- classmethod from_kinetics(base_models, occ_model, timepoint_names, dark_key='dark', verbose=0)[source]
Create a ModelCollection from a kinetics occupancy model.
- Parameters:
base_models (List[ModelFT]) – Shared structural models.
occ_model (occupancies_kinetics) – Kinetic occupancy model whose forward() returns shape [n_states, n_timepoints].
timepoint_names (List[str]) – Names for each timepoint column (excluding dark).
dark_key (str) – Key for the dark entry.
verbose (int) – Verbosity level.
- Return type:
ModelCollection
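The [n_states, n_timepoints] layout expected from occ_model can be illustrated with a toy two-state, first-order process A → B, where the occupancy of A decays as exp(−k·t). The sketch below builds such a matrix and maps each column to a named timepoint plus a dark entry. Both functions are hypothetical stand-ins: the internals of the occupancies_kinetics model and of from_kinetics() are not shown on this page.

```python
import math

def two_state_occupancies(rate, times):
    """Occupancies for A -> B with first-order rate constant `rate`.

    Returns a nested list of shape [n_states, n_timepoints].
    """
    a = [math.exp(-rate * t) for t in times]
    b = [1.0 - x for x in a]
    return [a, b]

def fractions_by_timepoint(occ, timepoint_names, dark_key="dark"):
    """Map each column of occ to a timepoint name; add a pure ground-state dark entry."""
    n_states = len(occ)
    if any(len(row) != len(timepoint_names) for row in occ):
        raise ValueError("need one name per timepoint column")
    out = {dark_key: [1.0] + [0.0] * (n_states - 1)}
    for j, name in enumerate(timepoint_names):
        out[name] = [occ[s][j] for s in range(n_states)]
    return out

# Rate in 1/ps and times in ps (illustrative values only).
occ = two_state_occupancies(rate=0.2, times=[1.0, 5.0])
fractions = fractions_by_timepoint(occ, ["1ps", "5ps"])
```

Tying the fractions to a kinetic model reduces the number of free parameters from one set of fractions per timepoint to a handful of rate constants.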
- classmethod from_ihm(filepath, max_res=1.5, radius_angstrom=4.0, device=None, verbose=0)[source]
Load a ModelCollection from an IHM mmCIF file.
Requires the optional python-ihm dependency.
- Parameters:
filepath (str) – Path to IHM mmCIF file.
max_res (float) – Maximum resolution for FFT grid setup.
radius_angstrom (float) – Radius for electron density calculation.
device (torch.device, optional) – Device for model tensors.
verbose (int) – Verbosity level.
- Return type:
ModelCollection
- write_ihm(filepath, mapping=None, datasets=None)[source]
Write this ModelCollection to IHM mmCIF format.
Requires the optional python-ihm dependency.
- Parameters:
filepath (str) – Output file path.
mapping (IHMEnsembleMapping, optional) – Mapping with metadata for round-tripping. If None, a minimal mapping is created from the collection structure.
datasets (dict of str -> ReflectionData, optional) – Per-timepoint reflection data to embed in the CIF. Each key should match a timepoint name.
- property dark_model: _SharedMixedModel
Shortcut for self[dark_key].
- property base_models: ModuleList
The shared structural models (owned by this collection).
- property cell
- property spacegroup
- property device
Submodules
- torchref.model.closed_segmented_internal_coordinates module
ClosedSegmentedInternalCoordinateTensor, ClosedSegmentedInternalCoordinateTensor.__init__(), ClosedSegmentedInternalCoordinateTensor.dtype, ClosedSegmentedInternalCoordinateTensor.device, ClosedSegmentedInternalCoordinateTensor.forward(), ClosedSegmentedInternalCoordinateTensor.shake(), ClosedSegmentedInternalCoordinateTensor.fix(), ClosedSegmentedInternalCoordinateTensor.freeze(), ClosedSegmentedInternalCoordinateTensor.refine(), ClosedSegmentedInternalCoordinateTensor.unfreeze(), ClosedSegmentedInternalCoordinateTensor.fix_all(), ClosedSegmentedInternalCoordinateTensor.freeze_all(), ClosedSegmentedInternalCoordinateTensor.refine_all(), ClosedSegmentedInternalCoordinateTensor.unfreeze_all(), ClosedSegmentedInternalCoordinateTensor.n_refinable, ClosedSegmentedInternalCoordinateTensor.n_fixed, ClosedSegmentedInternalCoordinateTensor.closure_residuals, ClosedSegmentedInternalCoordinateTensor.max_closure_gap
- torchref.model.internal_coordinates module
InternalCoordinateTensor, InternalCoordinateTensor.n_atoms, InternalCoordinateTensor.n_chains, InternalCoordinateTensor.max_depth, InternalCoordinateTensor.bond_lengths, InternalCoordinateTensor.angles, InternalCoordinateTensor.torsions, InternalCoordinateTensor.chain_positions, InternalCoordinateTensor.chain_orientations, InternalCoordinateTensor.__init__(), InternalCoordinateTensor.dtype, InternalCoordinateTensor.device, InternalCoordinateTensor.to(), InternalCoordinateTensor.cuda(), InternalCoordinateTensor.cpu(), InternalCoordinateTensor.forward_slow(), InternalCoordinateTensor.forward(), InternalCoordinateTensor.shake(), InternalCoordinateTensor.fix(), InternalCoordinateTensor.freeze(), InternalCoordinateTensor.refine(), InternalCoordinateTensor.unfreeze(), InternalCoordinateTensor.fix_all(), InternalCoordinateTensor.freeze_all(), InternalCoordinateTensor.refine_all(), InternalCoordinateTensor.unfreeze_all(), InternalCoordinateTensor.n_refinable, InternalCoordinateTensor.n_fixed, InternalCoordinateTensor.forward_parallel()
- torchref.model.mixed_model module
MixedModel, MixedModel.models, MixedModel.fraction_params, MixedModel.__init__(), MixedModel.fractions, MixedModel.cell, MixedModel.spacegroup, MixedModel.device, MixedModel.dtype_float, MixedModel.real_space_grid, MixedModel.fft, MixedModel.gridsize, MixedModel.map_symmetry, MixedModel.inv_fractional_matrix, MixedModel.fractional_matrix, MixedModel.setup_grid(), MixedModel.get_radius(), MixedModel.build_complete_map(), MixedModel.freeze_fractions(), MixedModel.unfreeze_fractions(), MixedModel.forward(), MixedModel.get_individual_fcalc(), MixedModel.copy(), MixedModel.__repr__(), MixedModel.write_ihm(), MixedModel.get_vdw_radii(), MixedModel.xyz()
- torchref.model.model module
Model, Model.xyz, Model.adp, Model.u, Model.occupancy, Model.pdb, Model.cell, Model.spacegroup, Model.symmetry, Model.initialized, Model.__init__(), Model.__bool__(), Model.exclude_H_from_sf, Model.cell, Model.spacegroup, Model.symmetry, Model.inv_fractional_matrix, Model.fractional_matrix, Model.recB, Model.Z, Model.get_P1_parameters_iso(), Model.get_MD_parameters(), Model.parametrization, Model.get_scattering_params_iso(), Model.get_scattering_params_aniso(), Model.set_restraints_cif(), Model.restraints, Model.bond_deviations(), Model.angle_deviations(), Model.torsion_deviations_with_sigmas(), Model.load(), Model.load_pdb(), Model.load_cif(), Model.chain_sequences, Model.get_chain_residues(), Model.update_pdb(), Model.get_vdw_radii(), Model.to(), Model.copy(), Model.write_pdb(), Model.write_cif(), Model.get_iso(), Model.set_default_masks(), Model.PARAM_TYPES, Model.parameters_of_types(), Model.freeze(), Model.freeze_all(), Model.unfreeze_all(), Model.unfreeze(), Model.update_mask_from_selection(), Model.apply_mask_to_parameter(), Model.freeze_selection(), Model.unfreeze_selection(), Model.get_aniso(), Model.parameters(), Model.named_mixed_tensors(), Model.print_parameters_info(), Model.register_alternative_conformations(), Model.shake_coords(), Model.shake_adp(), Model.generate_hydrogens(), Model.strip_altlocs(), Model.strip_hydrogens(), Model.hydrogenate(), Model.adp_loss(), Model.adp_nll_loss(), Model.adp_nll_loss_per_atom(), Model.adp_kl_divergence_loss(), Model.state_dict(), Model.save_state(), Model.load_state(), Model.create_from_state_dict(), Model.get_selection_mask(), Model.select(), Model.xyz_fractional(), Model.rotate(), Model.translate(), Model.get_centroid(), Model.use_internal_coordinates()
- torchref.model.model_collection module
ModelCollection, ModelCollection.__init__(), ModelCollection.add_timepoint(), ModelCollection.add_dark(), ModelCollection.from_kinetics(), ModelCollection.from_ihm(), ModelCollection.write_ihm(), ModelCollection.keys(), ModelCollection.values(), ModelCollection.items(), ModelCollection.get(), ModelCollection.dark_key, ModelCollection.dark_model, ModelCollection.base_models, ModelCollection.n_base_models, ModelCollection.timepoint_names, ModelCollection.cell, ModelCollection.spacegroup, ModelCollection.device, ModelCollection.get_all_fractions(), ModelCollection.get_fractions_matrix(), ModelCollection.freeze_all_fractions(), ModelCollection.unfreeze_all_fractions(), ModelCollection.freeze_structures(), ModelCollection.unfreeze_structures(), ModelCollection.write_pdbs()
- torchref.model.model_ft module
ModelFT, ModelFT.max_res, ModelFT.radius_angstrom, ModelFT.wavelength, ModelFT.anomalous_threshold, ModelFT.gridsize, ModelFT.real_space_grid, ModelFT.map, ModelFT.parametrization, ModelFT.map_symmetry, ModelFT.__init__(), ModelFT.cell, ModelFT.spacegroup, ModelFT.load_pdb(), ModelFT.select(), ModelFT.load_cif(), ModelFT.setup_gridsize(), ModelFT.A, ModelFT.B, ModelFT.gridsize, ModelFT.real_space_grid, ModelFT.voxel_size, ModelFT.map_symmetry, ModelFT.get_iso(), ModelFT.get_aniso(), ModelFT.setup_grid(), ModelFT.get_radius(), ModelFT.build_complete_map(), ModelFT.build_initial_map(), ModelFT.save_map(), ModelFT.get_map_statistics(), ModelFT.rebuild_map(), ModelFT.update_pdb(), ModelFT.reset_cache(), ModelFT.invalidate_cache(), ModelFT.get_structure_factor(), ModelFT.fft, ModelFT.forward(), ModelFT.copy(), ModelFT.state_dict(), ModelFT.create_from_state_dict()
- torchref.model.parameter_wrappers module
MixedTensor, MixedTensor.refinable_mask, MixedTensor.fixed_mask, MixedTensor.fixed_values, MixedTensor.refinable_params, MixedTensor.__init__(), MixedTensor.forward(), MixedTensor.__getitem__(), MixedTensor.__setitem__(), MixedTensor.set(), MixedTensor.shape, MixedTensor.dtype, MixedTensor.device, MixedTensor.get_refinable_count(), MixedTensor.get_fixed_count(), MixedTensor.update_fixed_values(), MixedTensor.update_refinable_mask(), MixedTensor.detach(), MixedTensor.clone(), MixedTensor.copy(), MixedTensor.clip(), MixedTensor.to(), MixedTensor.refine(), MixedTensor.fix(), MixedTensor.refine_all(), MixedTensor.fix_all(), MixedTensor.name, MixedTensor.__str__(), MixedTensor.parameters()
PositiveMixedTensor
CholeskyMixedTensor
OccupancyTensor, OccupancyTensor.expansion_mask, OccupancyTensor.linked_occ_sizes, OccupancyTensor.collapse_counts, OccupancyTensor.__init__(), OccupancyTensor.forward(), OccupancyTensor.shape, OccupancyTensor.collapsed_shape, OccupancyTensor.clamp(), OccupancyTensor.set_group_occupancy(), OccupancyTensor.get_group_occupancy(), OccupancyTensor.freeze(), OccupancyTensor.unfreeze(), OccupancyTensor.freeze_all(), OccupancyTensor.unfreeze_all(), OccupancyTensor.get_refinable_atoms(), OccupancyTensor.get_frozen_atoms(), OccupancyTensor.get_refinable_count(), OccupancyTensor.get_fixed_count(), OccupancyTensor.update_refinable_mask(), OccupancyTensor.from_residue_groups(), OccupancyTensor.copy()
PassThroughTensor
- torchref.model.segmented_internal_coordinates module
SegmentedInternalCoordinateTensor, SegmentedInternalCoordinateTensor.n_atoms, SegmentedInternalCoordinateTensor.n_segments, SegmentedInternalCoordinateTensor.max_depth, SegmentedInternalCoordinateTensor.bond_lengths, SegmentedInternalCoordinateTensor.angles, SegmentedInternalCoordinateTensor.torsions, SegmentedInternalCoordinateTensor.segment_positions, SegmentedInternalCoordinateTensor.segment_orientations, SegmentedInternalCoordinateTensor.AA_NAMES, SegmentedInternalCoordinateTensor.__init__(), SegmentedInternalCoordinateTensor.dtype, SegmentedInternalCoordinateTensor.device, SegmentedInternalCoordinateTensor.forward(), SegmentedInternalCoordinateTensor.shake(), SegmentedInternalCoordinateTensor.fix(), SegmentedInternalCoordinateTensor.freeze(), SegmentedInternalCoordinateTensor.refine(), SegmentedInternalCoordinateTensor.unfreeze(), SegmentedInternalCoordinateTensor.fix_all(), SegmentedInternalCoordinateTensor.freeze_all(), SegmentedInternalCoordinateTensor.refine_all(), SegmentedInternalCoordinateTensor.unfreeze_all(), SegmentedInternalCoordinateTensor.n_refinable, SegmentedInternalCoordinateTensor.n_fixed
- torchref.model.sf_ds module
- torchref.model.sf_fft module
SfFFT, SfFFT.cell, SfFFT.spacegroup, SfFFT.symmetry, SfFFT.max_res, SfFFT.radius_angstrom, SfFFT.gridsize, SfFFT.real_space_grid, SfFFT.voxel_size, SfFFT.map_symmetry, SfFFT.__init__(), SfFFT.cell, SfFFT.spacegroup, SfFFT.symmetry, SfFFT.fractional_matrix, SfFFT.inv_fractional_matrix, SfFFT.set_cell_and_spacegroup(), SfFFT.compute_optimal_gridsize(), SfFFT.compute_real_space_grid(), SfFFT.setup_grid(), SfFFT.build_density_map(), SfFFT.map_to_structure_factors(), SfFFT.compute_structure_factors(), SfFFT.reset_cache(), SfFFT.copy()
FFT()