torchref.model package

Atomic structure model module for TorchRef.

This module provides PyTorch nn.Module-based representations of crystallographic atomic models, including coordinates, B-factors, occupancies, and anisotropic displacement parameters.

Classes

Model

Base atomic model storing xyz coordinates, B-factors, occupancies.

ModelFT

Fourier Transform model for FFT-based structure factor calculation.

SfFFT

Structure Factor calculator using FFT (Fast Fourier Transform).

SfDS

Structure Factor calculator using Direct Summation.

FFT

Backward compatibility alias for SfFFT.

MixedTensor

Hybrid tensor allowing partial freezing of parameters.

PositiveMixedTensor

MixedTensor with positivity constraint.

PassThroughTensor

Direct parameter access wrapper.

OccupancyTensor

Tensor constrained to [0, 1] range for occupancies.

Example

from torchref.model import Model, ModelFT, MixedTensor, SfFFT, SfDS

# Load model from PDB
model = Model()
model.load_pdb('structure.pdb')

# Access coordinates and B-factors
xyz = model.xyz  # (N, 3) tensor
b = model.b      # (N,) tensor

# Use ModelFT for FFT-based structure factors
model_ft = ModelFT(data, device='cuda')
F_calc = model_ft.get_F_calc()

# Use SfFFT standalone for custom workflows
sf_fft = SfFFT(cell, spacegroup, max_res=1.5)
sf_fft.setup_grid()
sf = sf_fft.map_to_structure_factors(density_map, hkl)

# Use SfDS for direct summation
sf_ds = SfDS(cell, spacegroup)
sf, _ = sf_ds.compute_structure_factors(hkl, xyz, adp, occ, A, B)

torchref.model.FFT(*args, **kwargs)[source]

Deprecated: use SfFFT instead.

class torchref.model.SfFFT(cell=None, spacegroup=None, max_res=1.5, radius_angstrom=3.0, dtype_float=torch.float32, device=None, verbose=0, use_late_symmetry=True)[source]

Bases: DeviceMixin, Module

Structure Factor calculator using FFT (Fast Fourier Transform).

This module encapsulates all FFT-related functionality for computing electron density maps and structure factors. It is initialized with a Cell and optionally a SpaceGroup, which are used for grid calculations.

Parameters:
  • cell (Cell, optional) – Unit cell object containing cell parameters.

  • spacegroup (SpaceGroupLike, optional) – Space group specification (string, int, or gemmi.SpaceGroup). If None, defaults to P1.

  • max_res (float, optional) – Maximum resolution for grid spacing in Angstroms. Default is 1.5.

  • radius_angstrom (float, optional) – Radius in Angstroms for density calculation around each atom. Default is 3.0.

  • dtype_float (torch.dtype, optional) – Data type for floating point tensors. Default is dtypes.float.

  • device (torch.device, optional) – Computation device. Defaults to the configured device.current.

  • verbose (int, optional) – Verbosity level for logging. Default is 0.

cell

Unit cell object.

Type:

Cell

spacegroup

Space group object (SpaceGroup nn.Module with matrices and translations).

Type:

SpaceGroup

symmetry

Alias for spacegroup (backward compatibility).

Type:

SpaceGroup

max_res

Maximum resolution for grid spacing.

Type:

float

radius_angstrom

Radius for density calculation around each atom.

Type:

float

gridsize

Grid dimensions (nx, ny, nz) when grid is set up.

Type:

torch.Tensor or None

real_space_grid

Real-space coordinate grid with shape (nx, ny, nz, 3).

Type:

torch.Tensor or None

voxel_size

Voxel dimensions.

Type:

torch.Tensor or None

map_symmetry

Symmetry operator for map calculations.

Type:

MapSymmetry or None

Examples

Standalone usage:

from torchref.symmetry import Cell
cell = Cell([50, 60, 70, 90, 90, 90])
sf_fft = SfFFT(cell, spacegroup='P212121', max_res=1.5)
sf_fft.setup_grid()
density_map = sf_fft.build_density_map(xyz, b, occ, A, B)
sf = sf_fft.map_to_structure_factors(density_map, hkl)

With ModelFT (composition):

model = ModelFT()
model.load_pdb('structure.pdb')
sf = model.get_structure_factor(hkl)  # Uses internal SfFFT instance

__init__(cell=None, spacegroup=None, max_res=1.5, radius_angstrom=3.0, dtype_float=torch.float32, device=None, verbose=0, use_late_symmetry=True)[source]

Initialize the SfFFT module with cell and spacegroup.

Parameters:
  • cell (Cell, optional) – Unit cell object. If None, must be set later via set_cell_and_spacegroup().

  • spacegroup (SpaceGroupLike, optional) – Space group specification. If None, defaults to P1.

  • max_res (float, optional) – Maximum resolution for grid spacing in Angstroms. Default is 1.5.

  • radius_angstrom (float, optional) – Radius in Angstroms for density calculation. Default is 3.0.

  • dtype_float (torch.dtype, optional) – Data type for floating point tensors. Default is dtypes.float.

  • device (torch.device, optional) – Computation device. Default is None (uses cell’s device). If Cell is also None, defaults to CPU.

  • verbose (int, optional) – Verbosity level for logging. Default is 0.

  • use_late_symmetry (bool, optional) – If True (default), apply symmetry in reciprocal space after FFT (“late symmetry”) for faster structure factor calculation (~5x speedup). If False, apply symmetry to density map before FFT (“early symmetry”).

property cell: Cell | None

Unit cell object.

property spacegroup: SpaceGroup | None

Space group object (SpaceGroup nn.Module).

property symmetry: SpaceGroup | None

Symmetry operations handler (alias for spacegroup).

property fractional_matrix: Tensor | None

Get fractionalization matrix from cell, on this module’s device/dtype.

property inv_fractional_matrix: Tensor | None

Get orthogonalization matrix from cell, on this module’s device/dtype.
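
As a rough illustration of what these two matrices represent, the sketch below builds an orthogonalization matrix from cell parameters using the common PDB axis convention (a along x, b in the xy plane) and inverts it to obtain the fractionalization matrix. The function names are hypothetical and plain Python lists stand in for tensors; torchref's Cell may use a different convention or memory layout.

```python
import math

def orthogonalization_matrix(a, b, c, alpha, beta, gamma):
    """Fractional -> Cartesian matrix (angles in degrees, PDB convention)."""
    ca, cb, cg = (math.cos(math.radians(x)) for x in (alpha, beta, gamma))
    sg = math.sin(math.radians(gamma))
    # Unit-cell volume divided by a*b*c
    v = math.sqrt(1 - ca * ca - cb * cb - cg * cg + 2 * ca * cb * cg)
    return [
        [a, b * cg, c * cb],
        [0.0, b * sg, c * (ca - cb * cg) / sg],
        [0.0, 0.0, c * v / sg],
    ]

def fractionalization_matrix(m):
    """Closed-form inverse of the upper-triangular orthogonalization matrix."""
    (m00, m01, m02), (_, m11, m12), (_, _, m22) = m
    return [
        [1 / m00, -m01 / (m00 * m11), (m01 * m12 - m02 * m11) / (m00 * m11 * m22)],
        [0.0, 1 / m11, -m12 / (m11 * m22)],
        [0.0, 0.0, 1 / m22],
    ]

def matvec(m, v):
    """Multiply a 3x3 matrix by a 3-vector."""
    return [sum(m[i][j] * v[j] for j in range(3)) for i in range(3)]

orth = orthogonalization_matrix(50, 60, 70, 90, 100, 90)
frac = fractionalization_matrix(orth)
xyz_cart = matvec(orth, [0.5, 0.25, 0.1])   # fractional -> Cartesian
xyz_frac = matvec(frac, xyz_cart)           # Cartesian -> fractional (round trip)
```

Applying one matrix and then the other should return the starting coordinates, which is a quick sanity check when it is unclear which direction a given matrix maps.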

set_cell_and_spacegroup(cell, spacegroup=None)[source]

Set cell and spacegroup for this SfFFT instance.

Parameters:
  • cell (Cell) – Unit cell object.

  • spacegroup (SpaceGroupLike, optional) – Space group specification.

compute_optimal_gridsize(max_res=None)[source]

Compute optimal grid dimensions using the stored cell and spacegroup.

Uses Cell.compute_grid_size() for base calculation and Symmetry.suggest_grid_size() for symmetry optimization.

Parameters:

max_res (float, optional) – Maximum resolution in Angstroms. If None, uses self.max_res.

Returns:

Optimal grid dimensions (nx, ny, nz).

Return type:

tuple of int

Raises:

RuntimeError – If cell has not been set.
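
A minimal sketch of how a grid size might be derived from the cell and resolution, assuming the grid is sampled at max_res / 3 and each dimension is rounded up to a size with only 2, 3 and 5 as prime factors. Both the sampling factor and the smoothness rule are assumptions for illustration; Cell.compute_grid_size() and the symmetry adjustment may apply different rules.

```python
import math

def is_fft_friendly(n):
    """True if n has no prime factors other than 2, 3 and 5."""
    for p in (2, 3, 5):
        while n % p == 0:
            n //= p
    return n == 1

def suggest_gridsize(cell_lengths, max_res, sampling=3.0):
    """Hypothetical helper: smallest FFT-friendly grid with spacing <= max_res / sampling."""
    spacing = max_res / sampling
    dims = []
    for length in cell_lengths:
        n = math.ceil(length / spacing)
        while not is_fft_friendly(n):
            n += 1
        dims.append(n)
    return tuple(dims)

gridsize = suggest_gridsize((50.0, 60.0, 70.0), max_res=1.5)
print(gridsize)  # (100, 120, 144)
```

The third axis grows from 140 to 144 because 140 contains a factor of 7.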

static compute_real_space_grid(fractional_matrix, gridsize, device=device(type='cpu'))[source]

Generate the real-space coordinate grid.

Parameters:
  • fractional_matrix (torch.Tensor) – Fractionalization matrix of the unit cell with shape (3, 3).

  • gridsize (torch.Tensor) – Grid dimensions (nx, ny, nz).

  • device (torch.device, optional) – Target device. Default is CPU.

Returns:

Real-space grid with shape (nx, ny, nz, 3).

Return type:

torch.Tensor

setup_grid(gridsize=None, max_res=None)[source]

Setup the real-space grid for electron density calculation.

This method initializes and stores the grid state for subsequent density map calculations. Uses the stored cell and spacegroup.

Parameters:
  • gridsize (tuple of int, optional) – Explicit grid size (nx, ny, nz). If None, computed automatically using Cell.compute_grid_size() and Symmetry.suggest_grid_size().

  • max_res (float, optional) – Maximum resolution in Angstroms. If None, uses self.max_res.

Raises:

RuntimeError – If cell has not been set.

build_density_map(xyz_iso, adp_iso, occ_iso, A_iso, B_iso, xyz_aniso=None, u_aniso=None, occ_aniso=None, A_aniso=None, B_aniso=None, apply_symmetry=True)[source]

Build electron density map from atomic parameters.

This method requires setup_grid() to have been called first.

Parameters:
  • xyz_iso (torch.Tensor) – Isotropic atom coordinates with shape (n_iso, 3).

  • adp_iso (torch.Tensor) – Isotropic ADPs (atomic displacement parameters) with shape (n_iso,).

  • occ_iso (torch.Tensor) – Isotropic occupancies with shape (n_iso,).

  • A_iso (torch.Tensor) – ITC92 A parameters for isotropic atoms with shape (n_iso, 5).

  • B_iso (torch.Tensor) – ITC92 B parameters for isotropic atoms with shape (n_iso, 5).

  • xyz_aniso (torch.Tensor, optional) – Anisotropic atom coordinates with shape (n_aniso, 3).

  • u_aniso (torch.Tensor, optional) – Anisotropic U parameters with shape (n_aniso, 6).

  • occ_aniso (torch.Tensor, optional) – Anisotropic occupancies with shape (n_aniso,).

  • A_aniso (torch.Tensor, optional) – ITC92 A parameters for anisotropic atoms with shape (n_aniso, 5).

  • B_aniso (torch.Tensor, optional) – ITC92 B parameters for anisotropic atoms with shape (n_aniso, 5).

  • apply_symmetry (bool, optional) – If True, apply crystallographic symmetry to the map. Default is True.

Returns:

Electron density map with shape (nx, ny, nz).

Return type:

torch.Tensor

Raises:

RuntimeError – If setup_grid() has not been called.
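
To make the role of the A/B coefficients and the ADP concrete, the sketch below evaluates the density of one isotropic atom at a distance r using the standard Fourier transform of a sum of Gaussians. Whether build_density_map() uses exactly this expression is an assumption, and atom_density is a hypothetical helper, not part of torchref.

```python
import math

def atom_density(r, occ, b_iso, A, B):
    """Density at distance r (Angstrom) from an isotropic atom.

    Each scattering term a_k * exp(-b_k * s^2 / 4), damped by the B-factor,
    transforms to a real-space Gaussian with width parameter b_k + b_iso.
    """
    rho = 0.0
    for a_k, b_k in zip(A, B):
        b_tot = b_k + b_iso
        rho += a_k * (4 * math.pi / b_tot) ** 1.5 * math.exp(-4 * math.pi ** 2 * r * r / b_tot)
    return occ * rho

# Illustrative coefficients only, not tabulated ITC92 values
A = [2.0, 2.0, 1.0, 0.5, 0.5]
B = [20.0, 10.0, 1.0, 0.5, 0.0]
profile = [atom_density(r, occ=1.0, b_iso=15.0, A=A, B=B) for r in (0.0, 0.5, 1.0, 2.0, 3.0)]
```

The density falls off quickly with r, which is why contributions are typically evaluated only within radius_angstrom of each atom.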

map_to_structure_factors(density_map, hkl, apply_symmetry=True)[source]

Convert density map to structure factors via FFT.

Parameters:
  • density_map (torch.Tensor) – Electron density map with shape (nx, ny, nz). If apply_symmetry=True, this should be a P1 density map.

  • hkl (torch.Tensor) – Miller indices with shape (n_reflections, 3).

  • apply_symmetry (bool, optional) – If True and late symmetry is enabled/compatible, apply symmetry in reciprocal space. If False, assume the map already has symmetry applied (early symmetry path). Default is True.

Returns:

Complex structure factors with shape (n_reflections,).

Return type:

torch.Tensor
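
The quantity extracted here can be illustrated with a brute-force discrete Fourier sum over the grid. This is only a sketch: a real implementation would run a single FFT and index the result at the requested hkl, and the sign convention and the missing voxel-volume scale factor are assumptions. dft_at_hkl is a hypothetical name.

```python
import cmath

def dft_at_hkl(density, hkl):
    """Evaluate sum_xyz rho[x][y][z] * exp(+2*pi*i*(h*x/nx + k*y/ny + l*z/nz))."""
    nx, ny, nz = len(density), len(density[0]), len(density[0][0])
    out = []
    for h, k, l in hkl:
        total = 0j
        for ix in range(nx):
            for iy in range(ny):
                for iz in range(nz):
                    phase = 2 * cmath.pi * (h * ix / nx + k * iy / ny + l * iz / nz)
                    total += density[ix][iy][iz] * cmath.exp(1j * phase)
        out.append(total)
    return out

# 4x4x4 grid with a single unit of density one voxel along x
n = 4
density = [[[0.0] * n for _ in range(n)] for _ in range(n)]
density[1][0][0] = 1.0
sf = dft_at_hkl(density, [(0, 0, 0), (1, 0, 0), (-1, 0, 0)])
```

A point of density displaced by a quarter of the cell along x gives a phase of +90 degrees for (1, 0, 0) and -90 degrees for (-1, 0, 0), as expected for a Friedel pair.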

compute_structure_factors(hkl, xyz_iso, adp_iso, occ_iso, A_iso, B_iso, xyz_aniso=None, u_aniso=None, occ_aniso=None, A_aniso=None, B_aniso=None, apply_symmetry=True)[source]

Compute structure factors from atomic parameters (end-to-end).

This is a convenience method that builds the density map and computes structure factors in one call.

When use_late_symmetry=True (default) and the grid is compatible, symmetry is applied in reciprocal space after FFT for ~5x speedup. Otherwise, symmetry is applied to the density map before FFT.

Parameters:
  • hkl (torch.Tensor) – Miller indices with shape (n_reflections, 3).

  • xyz_iso (torch.Tensor) – Isotropic atom coordinates with shape (n_iso, 3).

  • adp_iso (torch.Tensor) – Isotropic ADPs (atomic displacement parameters) with shape (n_iso,).

  • occ_iso (torch.Tensor) – Isotropic occupancies with shape (n_iso,).

  • A_iso (torch.Tensor) – ITC92 A parameters for isotropic atoms.

  • B_iso (torch.Tensor) – ITC92 B parameters for isotropic atoms.

  • xyz_aniso (torch.Tensor, optional) – Anisotropic atom coordinates.

  • u_aniso (torch.Tensor, optional) – Anisotropic U parameters.

  • occ_aniso (torch.Tensor, optional) – Anisotropic occupancies.

  • A_aniso (torch.Tensor, optional) – ITC92 A parameters for anisotropic atoms.

  • B_aniso (torch.Tensor, optional) – ITC92 B parameters for anisotropic atoms.

  • apply_symmetry (bool, optional) – If True, apply crystallographic symmetry. Default is True.

Returns:

  • sf (torch.Tensor) – Complex structure factors with shape (n_reflections,).

  • density_map (torch.Tensor) – Electron density map with shape (nx, ny, nz). Note: When using late symmetry, this is the P1 map (without symmetry).

Return type:

Tuple[Tensor, Tensor]

reset_cache()[source]

Drop the cached symmetry extractor; recomputed on next use.

copy()[source]

Create a deep copy of this SfFFT module.

Returns:

A new SfFFT instance with cloned cell, spacegroup, and buffers.

Return type:

SfFFT

class torchref.model.SfDS(cell=None, spacegroup=None, dtype_float=torch.float32, device=device(type='cpu'), verbose=0, max_memory_gb=2.0)[source]

Bases: DeviceMixin, Module

Structure Factor calculator using Direct Summation.

This module computes structure factors by directly summing atomic contributions without building an intermediate electron density map. It is initialized with a Cell and optionally a SpaceGroup.

Includes automatic batching to handle memory constraints for large structures or high-resolution data.
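
A minimal sketch of direct summation for isotropic atoms in an orthorhombic cell, written in plain Python. It uses fractional coordinates for brevity (the method documented below takes Cartesian coordinates), ignores symmetry, and uses made-up A/B coefficients; direct_sum_sf is a hypothetical function, not the library's implementation.

```python
import cmath
import math

def direct_sum_sf(hkl, cell_abc, xyz_frac, b_iso, occ, A, B):
    """F(h) = sum_j occ_j * f_j(s) * exp(-B_j s^2 / 4) * exp(2*pi*i h.x_j)."""
    a, b, c = cell_abc
    result = []
    for h, k, l in hkl:
        s2 = (h / a) ** 2 + (k / b) ** 2 + (l / c) ** 2   # 1/d^2 for an orthorhombic cell
        total = 0j
        for x, b_j, occ_j, A_j, B_j in zip(xyz_frac, b_iso, occ, A, B):
            # Scattering factor from the Gaussian coefficients
            f = sum(a_k * math.exp(-b_k * s2 / 4) for a_k, b_k in zip(A_j, B_j))
            debye_waller = math.exp(-b_j * s2 / 4)
            phase = 2 * math.pi * (h * x[0] + k * x[1] + l * x[2])
            total += occ_j * f * debye_waller * cmath.exp(1j * phase)
        result.append(total)
    return result

# Illustrative coefficients only, not tabulated ITC92 values
coeff_A = [2.0, 2.0, 1.0, 0.5, 0.5]
coeff_B = [20.0, 10.0, 1.0, 0.5, 0.0]
sf = direct_sum_sf(
    hkl=[(0, 0, 0), (1, 2, 3)],
    cell_abc=(50.0, 60.0, 70.0),
    xyz_frac=[(0.1, 0.2, 0.3), (0.4, 0.1, 0.7)],
    b_iso=[15.0, 25.0],
    occ=[1.0, 0.5],
    A=[coeff_A, coeff_A],
    B=[coeff_B, coeff_B],
)
```

At (0, 0, 0) every exponential is 1, so the result reduces to the occupancy-weighted sum of the A coefficients.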

Parameters:
  • cell (Cell, optional) – Unit cell object containing cell parameters.

  • spacegroup (SpaceGroupLike, optional) – Space group specification (string, int, or gemmi.SpaceGroup). If None, defaults to P1.

  • dtype_float (torch.dtype, optional) – Data type for floating point tensors. Default is dtypes.float.

  • device (torch.device, optional) – Computation device. Defaults to the configured device.current.

  • verbose (int, optional) – Verbosity level for logging. Default is 0.

  • max_memory_gb (float, optional) – Maximum memory to use for intermediate tensors in GB. Default is 2.0. Set to None to disable batching.

cell

Unit cell object.

Type:

Cell

spacegroup

Space group object (SpaceGroup nn.Module with matrices and translations).

Type:

SpaceGroup

Examples

Standalone usage:

from torchref.symmetry import Cell
cell = Cell([50, 60, 70, 90, 90, 90])
sf_ds = SfDS(cell, spacegroup='P212121')
sf, _ = sf_ds.compute_structure_factors(
    hkl, xyz_iso, adp_iso, occ_iso, A_iso, B_iso
)

With memory limit for large structures:

sf_ds = SfDS(cell, spacegroup='P212121', max_memory_gb=4.0)
sf, _ = sf_ds.compute_structure_factors(...)  # Auto-batches if needed

Notes

Key differences from SfFFT:

  • No grid setup required

  • No build_density_map() or map_to_structure_factors() methods

  • Computes scattering factors internally from A/B ITC92 coefficients

  • Returns (sf, None) instead of (sf, density_map)

  • Automatic batching for memory management

__init__(cell=None, spacegroup=None, dtype_float=torch.float32, device=device(type='cpu'), verbose=0, max_memory_gb=2.0)[source]

Initialize the SfDS module with cell and spacegroup.

Parameters:
  • cell (Cell, optional) – Unit cell object. If None, must be set later.

  • spacegroup (SpaceGroupLike, optional) – Space group specification. If None, defaults to P1.

  • dtype_float (torch.dtype, optional) – Data type for floating point tensors. Default is dtypes.float.

  • device (torch.device, optional) – Computation device. Defaults to the configured device.current.

  • verbose (int, optional) – Verbosity level for logging. Default is 0.

  • max_memory_gb (float, optional) – Maximum memory for intermediate tensors in GB. Default is 2.0.
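
One way a memory budget could translate into batches is sketched below, assuming the dominant intermediate is an (n_reflections x n_atoms) array of 8-byte complex values. The element size, the helper names and the batching over reflections are all assumptions; the actual strategy in SfDS is not documented here.

```python
def reflections_per_batch(n_atoms, max_memory_gb, bytes_per_element=8):
    """Largest reflection count whose (n_refl x n_atoms) intermediate fits the budget.

    Returns None when max_memory_gb is None, meaning batching is disabled.
    """
    if max_memory_gb is None:
        return None
    budget_bytes = max_memory_gb * 1024 ** 3
    return max(1, int(budget_bytes // (n_atoms * bytes_per_element)))

def batch_slices(n_reflections, batch_size):
    """Split range(n_reflections) into consecutive (start, stop) pairs."""
    if batch_size is None:
        return [(0, n_reflections)]
    return [(start, min(start + batch_size, n_reflections))
            for start in range(0, n_reflections, batch_size)]

batch = reflections_per_batch(n_atoms=1000, max_memory_gb=2.0)
slices = batch_slices(n_reflections=600_000, batch_size=batch)
```

Each slice would then be processed independently and the partial results concatenated.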

property cell: Cell | None

Unit cell object.

property spacegroup: SpaceGroup | None

Space group object (SpaceGroup nn.Module).

property fractional_matrix: Tensor | None

Get fractionalization matrix from cell.

property inv_fractional_matrix: Tensor | None

Get orthogonalization matrix from cell.

set_cell_and_spacegroup(cell, spacegroup=None)[source]

Set cell and spacegroup for this SfDS instance.

Parameters:
  • cell (Cell) – Unit cell object.

  • spacegroup (SpaceGroupLike, optional) – Space group specification.

compute_structure_factors(hkl, xyz_iso, adp_iso, occ_iso, A_iso, B_iso, xyz_aniso=None, u_aniso=None, occ_aniso=None, A_aniso=None, B_aniso=None, apply_symmetry=True)[source]

Compute structure factors from atomic parameters using direct summation.

Uses “late symmetry” approach (same as SfFFT): first computes P1 structure factors at symmetry-equivalent HKLs, then combines them with phase shifts.

The symmetry formula is:

F_sym(h) = Σ_ops exp(2πi h.t) * F_P1(R^T @ h)
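
The formula can be checked numerically on a toy case. The sketch below uses point atoms with unit scattering factors in fractional coordinates and the two operators of P2₁ (unique axis b), and compares the late-symmetry sum against explicitly expanding the atoms first. It is written in plain Python with hypothetical helper names and is not the library's code path.

```python
import cmath

def f_p1(h, atoms):
    """P1 structure factor for unit point scatterers at fractional positions."""
    return sum(cmath.exp(2j * cmath.pi * sum(hi * xi for hi, xi in zip(h, x))) for x in atoms)

def apply_op(R, t, x):
    """x' = R x + t"""
    return tuple(sum(R[i][j] * x[j] for j in range(3)) + t[i] for i in range(3))

def rt_times_h(R, h):
    """R^T h"""
    return tuple(sum(R[j][i] * h[j] for j in range(3)) for i in range(3))

# P2_1, unique axis b: (x, y, z) and (-x, y + 1/2, -z)
OPS = [
    (((1, 0, 0), (0, 1, 0), (0, 0, 1)), (0.0, 0.0, 0.0)),
    (((-1, 0, 0), (0, 1, 0), (0, 0, -1)), (0.0, 0.5, 0.0)),
]

def f_early(h, atoms):
    """Expand atoms by symmetry, then take the P1 sum."""
    expanded = [apply_op(R, t, x) for R, t in OPS for x in atoms]
    return f_p1(h, expanded)

def f_late(h, atoms):
    """F_sym(h) = sum_ops exp(2*pi*i h.t) * F_P1(R^T h)"""
    total = 0j
    for R, t in OPS:
        shift = cmath.exp(2j * cmath.pi * sum(hi * ti for hi, ti in zip(h, t)))
        total += shift * f_p1(rt_times_h(R, h), atoms)
    return total

atoms = [(0.10, 0.20, 0.30), (0.45, 0.05, 0.80)]
h = (1, 2, 3)
difference = abs(f_early(h, atoms) - f_late(h, atoms))
```

The two routes agree to rounding error because h · (R x + t) = (Rᵀ h) · x + h · t.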

Parameters:
  • hkl (torch.Tensor) – Miller indices with shape (n_reflections, 3).

  • xyz_iso (torch.Tensor) – Isotropic atom coordinates (Cartesian) with shape (n_iso, 3).

  • adp_iso (torch.Tensor) – Isotropic ADPs (atomic displacement parameters) with shape (n_iso,).

  • occ_iso (torch.Tensor) – Isotropic occupancies with shape (n_iso,).

  • A_iso (torch.Tensor) – ITC92 A parameters for isotropic atoms with shape (n_iso, 5).

  • B_iso (torch.Tensor) – ITC92 B parameters for isotropic atoms with shape (n_iso, 5).

  • xyz_aniso (torch.Tensor, optional) – Anisotropic atom coordinates (Cartesian) with shape (n_aniso, 3).

  • u_aniso (torch.Tensor, optional) – Anisotropic U parameters with shape (n_aniso, 6).

  • occ_aniso (torch.Tensor, optional) – Anisotropic occupancies with shape (n_aniso,).

  • A_aniso (torch.Tensor, optional) – ITC92 A parameters for anisotropic atoms with shape (n_aniso, 5).

  • B_aniso (torch.Tensor, optional) – ITC92 B parameters for anisotropic atoms with shape (n_aniso, 5).

  • apply_symmetry (bool, optional) – If True, apply crystallographic symmetry. Default is True.

Returns:

  • sf (torch.Tensor) – Complex structure factors with shape (n_reflections,).

  • None – Second return value is None (for API compatibility with SfFFT).

Return type:

Tuple[Tensor, None]

reset_cache()[source]

Drop the cached reciprocal-basis matrix; recomputed on next use.

copy()[source]

Create a deep copy of this SfDS module.

Returns:

A new SfDS instance with cloned cell and spacegroup.

Return type:

SfDS

class torchref.model.InternalCoordinateTensor(initial_xyz, bond_cutoff=2.0, requires_grad=True, dtype=None, device=None)[source]

Bases: DeviceMixin, Module

Parameter wrapper using internal coordinates (Z-matrix style).

Stores: bond_lengths, angles, torsions, chain_positions, chain_orientations

Reconstructs: Cartesian xyz on forward()

This provides a physically meaningful parametrization of atomic coordinates where perturbations correspond to changes in bond lengths, angles, and torsion angles rather than arbitrary Cartesian displacements.
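
The core step of any Z-matrix reconstruction is placing one atom from a bond length, an angle and a torsion relative to three already-placed atoms. The sketch below shows that step in plain Python using the widely used NeRF construction; the function names are hypothetical and the class's own vectorized implementation may differ.

```python
import math

def sub(u, v):
    return [a - b for a, b in zip(u, v)]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def cross(u, v):
    return [u[1] * v[2] - u[2] * v[1], u[2] * v[0] - u[0] * v[2], u[0] * v[1] - u[1] * v[0]]

def unit(u):
    norm = math.sqrt(dot(u, u))
    return [a / norm for a in u]

def place_atom(a, b, c, bond, angle, torsion):
    """Place d with |cd| = bond, angle(b, c, d) = angle, dihedral(a, b, c, d) = torsion."""
    bc = unit(sub(c, b))
    n = unit(cross(sub(b, a), bc))   # normal of the a-b-c plane
    m = cross(n, bc)                 # completes a right-handed frame (bc, m, n)
    local = (-bond * math.cos(angle),
             bond * math.sin(angle) * math.cos(torsion),
             bond * math.sin(angle) * math.sin(torsion))
    return [c[i] + local[0] * bc[i] + local[1] * m[i] + local[2] * n[i] for i in range(3)]

def dihedral(a, b, c, d):
    """Signed dihedral angle a-b-c-d in radians."""
    u1, u2, u3 = sub(b, a), sub(c, b), sub(d, c)
    y = math.sqrt(dot(u2, u2)) * dot(u1, cross(u2, u3))
    x = dot(cross(u1, u2), cross(u2, u3))
    return math.atan2(y, x)

a, b, c = [0.0, 1.4, 0.3], [0.0, 0.0, 0.0], [1.5, 0.2, 0.0]
d = place_atom(a, b, c, bond=1.53, angle=math.radians(111.0), torsion=math.radians(-60.0))
```

Measuring the bond, angle and dihedral of the placed atom recovers the input values, so perturbing a torsion moves the atom without changing its bond length or angle.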

Parameters:
  • initial_xyz (torch.Tensor) – Initial Cartesian coordinates of shape (N, 3).

  • bond_cutoff (float, optional) – Distance cutoff for bond detection in Angstroms. Default is 2.0.

  • requires_grad (bool, optional) – Whether parameters should have gradients. Default is True.

  • dtype (torch.dtype, optional) – Data type for tensors. Default is same as initial_xyz.

  • device (torch.device, optional) – Device for tensors. Default is same as initial_xyz.

n_atoms

Number of atoms.

Type:

int

n_chains

Number of disconnected chains.

Type:

int

max_depth

Maximum depth in the spanning tree.

Type:

int

bond_lengths

Bond length parameters in Angstroms.

Type:

nn.Parameter

angles

Angle parameters in radians.

Type:

nn.Parameter

torsions

Torsion angle parameters in radians.

Type:

nn.Parameter

chain_positions

Absolute positions of chain root atoms.

Type:

nn.Parameter

chain_orientations

Axis-angle orientations for each chain.

Type:

nn.Parameter

__init__(initial_xyz, bond_cutoff=2.0, requires_grad=True, dtype=None, device=None)[source]

Initialize InternalCoordinateTensor from Cartesian coordinates.

Parameters:
  • initial_xyz (torch.Tensor) – Initial Cartesian coordinates of shape (N, 3).

  • bond_cutoff (float, optional) – Distance cutoff for bond detection in Angstroms. Default is 2.0.

  • requires_grad (bool, optional) – Whether parameters should have gradients. Default is True.

  • dtype (torch.dtype, optional) – Data type for tensors. Default is same as initial_xyz.

  • device (torch.device, optional) – Device for tensors. Default is same as initial_xyz.

property dtype

Return the dtype of tensors.

property device

Logical device — where forward()’s result is delivered.

Internal parameters/buffers stay on CPU regardless; this is the device requested by the caller (e.g. via .to('mps')) and is the device the forward output is migrated to.

to(*args, **kwargs)[source]

Update output device and optionally cast dtype.

Unlike DeviceMixin.to, this does not move internal parameters/buffers to device — they stay on CPU to avoid the per-op dispatch overhead of MPS/CUDA on the sequential spanning-tree + parallel-scan code. The device argument only updates _output_device; dtype still propagates normally and recasts all CPU tensors.

cuda(device=None)[source]

Move all model parameters and buffers to the GPU.

This also makes associated parameters and buffers different objects. So it should be called before constructing the optimizer if the module will live on GPU while being optimized.

Note

This method modifies the module in-place.

Parameters:

device (int, optional) – If specified, all parameters will be copied to that device.

Returns:

self

Return type:

Module

cpu()[source]

Move all model parameters and buffers to the CPU.

Note

This method modifies the module in-place.

Returns:

self

Return type:

Module

forward_slow()[source]

Reconstruct Cartesian xyz from internal coordinates.

Fully vectorized: processes each depth level in parallel, so max_depth sequential steps are required.

Returns:

Reconstructed Cartesian coordinates of shape (N, 3).

Return type:

torch.Tensor

forward()[source]

Reconstruct Cartesian xyz from internal coordinates.

Uses optimized parallel scan method for efficiency.

Returns:

Reconstructed Cartesian coordinates of shape (N, 3), on the configured output device.

Return type:

torch.Tensor

shake(magnitude=0.1)[source]

Add Gaussian noise to internal parameters (fully vectorized).

All operations are batched tensor ops - no loops.

Parameters:

magnitude (float, optional) – Standard deviation of Gaussian noise. Default is 0.1. For bond lengths, this is in Angstroms. For angles and torsions, this is in radians.

Returns:

New Cartesian coordinates after perturbation.

Return type:

torch.Tensor

fix(selection=None, freeze_at_current=True)[source]

Fix (freeze) atoms to use fixed xyz coordinates instead of internal coordinates.

Fixed atoms will not be updated during reconstruction from internal coordinates. Their positions will remain at the stored fixed_xyz values.

Parameters:
  • selection (torch.Tensor, slice, or None) – Boolean mask (shape n_atoms) or indices of atoms to fix. If None, fixes all atoms.

  • freeze_at_current (bool, optional) – If True (default), store current reconstructed xyz for the selected atoms. If False, use the existing fixed_xyz values.
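
Conceptually, fixing atoms amounts to choosing per atom between a stored position and the position rebuilt from internal coordinates. The sketch below shows that selection with a boolean mask in plain Python; merge_fixed and the variable names are hypothetical and are not the class's internals.

```python
def merge_fixed(reconstructed, fixed_xyz, fixed_mask):
    """Take the stored position for fixed atoms, the rebuilt one otherwise."""
    return [fixed if is_fixed else rebuilt
            for rebuilt, fixed, is_fixed in zip(reconstructed, fixed_xyz, fixed_mask)]

reconstructed = [(0.0, 0.0, 0.0), (1.5, 0.1, 0.0), (2.1, 1.4, 0.2)]   # from internal coordinates
fixed_xyz = [(0.0, 0.0, 0.0), (1.5, 0.0, 0.0), (2.0, 1.4, 0.0)]       # stored positions
fixed_mask = [True, True, False]                                       # fix the first two atoms

xyz = merge_fixed(reconstructed, fixed_xyz, fixed_mask)
n_fixed = sum(fixed_mask)
n_refinable = len(fixed_mask) - n_fixed
```

Only the third atom follows the internal coordinates here; the first two stay at their stored positions regardless of parameter updates.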

freeze(selection=None, freeze_at_current=True)[source]

Alias for fix(). Freeze atoms to use fixed xyz coordinates.

See fix() for full documentation.

refine(selection=None, rebuild=True)[source]

Make atoms refinable by computing their positions from internal coordinates.

This unfreezes atoms, meaning their positions will be computed from bond lengths, angles, and torsions during forward pass.

Parameters:
  • selection (torch.Tensor, slice, or None) – Boolean mask (shape n_atoms) or indices of atoms to make refinable. If None, makes all atoms refinable.

  • rebuild (bool, optional) – If True (default), rebuild internal coordinates from current fixed_xyz for the selected atoms. This ensures the internal coordinates match the current atom positions before unfreezing.

unfreeze(selection=None, rebuild=True)[source]

Alias for refine(). Unfreeze atoms to use internal coordinates.

See refine() for full documentation.

fix_all(freeze_at_current=True)[source]

Fix (freeze) all atoms.

Parameters:

freeze_at_current (bool, optional) – If True (default), store current reconstructed xyz for all atoms.

freeze_all(freeze_at_current=True)[source]

Alias for fix_all(). Freeze all atoms.

refine_all(rebuild=True)[source]

Make all atoms refinable.

Parameters:

rebuild (bool, optional) – If True (default), rebuild internal coordinates from current fixed_xyz.

unfreeze_all(rebuild=True)[source]

Alias for refine_all(). Unfreeze all atoms.

property n_refinable: int

Return the number of refinable (unfrozen) atoms.

property n_fixed: int

Return the number of fixed (frozen) atoms.

forward_parallel()[source]

Reconstruct Cartesian xyz using parallel scan for backbone.

This is an optimized forward pass that:

  1. Places backbone atoms using parallel prefix scan (O(log N) steps)

  2. Places side chain atoms using depth iterations (O(max_sc_depth) steps)

For deep trees where the backbone is long but side chains are short, this can be significantly faster than the depth-by-depth forward_slow().
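
The idea behind a parallel prefix scan is that cumulative compositions along a chain can be formed in about log2(N) rounds when the operation is associative. The sketch below demonstrates this with a Hillis-Steele scan over simple 1D affine maps standing in for the rigid-body transforms along a backbone; the choice of scan variant and all names are assumptions for illustration.

```python
def compose(f, g):
    """Return the affine map x -> f(g(x)), with maps stored as (scale, shift)."""
    return (f[0] * g[0], f[0] * g[1] + f[1])

def sequential_prefix(maps):
    """Reference: N - 1 dependent steps."""
    out = [maps[0]]
    for m in maps[1:]:
        out.append(compose(out[-1], m))
    return out

def scan_prefix(maps):
    """Hillis-Steele inclusive scan: ceil(log2(N)) rounds.

    Within a round every element is updated independently,
    so each round could run as one batched tensor operation.
    """
    out = list(maps)
    offset, rounds = 1, 0
    while offset < len(out):
        out = [compose(out[i - offset], out[i]) if i >= offset else out[i]
               for i in range(len(out))]
        offset *= 2
        rounds += 1
    return out, rounds

maps = [(2, 1), (1, 3), (3, 0), (1, -2), (2, 2), (1, 1), (1, 4), (2, 0)]
scanned, rounds = scan_prefix(maps)
```

For eight maps the scan finishes in three rounds and gives the same cumulative maps as the sequential loop, even though composition here is not commutative.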

Returns:

Reconstructed Cartesian coordinates of shape (N, 3).

Return type:

torch.Tensor

class torchref.model.MixedModel(models, initial_fractions=None, frozen_fractions=False, verbose=0, device=None)[source]

Bases: DeviceMixin, Module

Model wrapper combining N ModelFT objects with learnable fractions.

Computes: F_mixed = Σ w_i * F_i where w_i are learnable weights constrained to sum to 1 via softmax.

This is useful for time-resolved crystallography where the crystal contains a mixture of conformational states (e.g., dark and light states) with unknown or refinable population fractions.
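
A small sketch of the softmax parametrization, in plain Python. It assumes the raw parameters are initialized to the logarithm of the requested fractions, which reproduces them exactly after softmax; the source does not state how fraction_params is initialized, and the structure factor values are made up.

```python
import math

def softmax(raw):
    """Numerically stable softmax over a list of floats."""
    peak = max(raw)
    exps = [math.exp(r - peak) for r in raw]
    total = sum(exps)
    return [e / total for e in exps]

initial_fractions = [0.7, 0.3]
raw_params = [math.log(f) for f in initial_fractions]   # assumed initialization
fractions = softmax(raw_params)                          # positive and sums to 1

# Made-up structure factors for two states at two reflections
f_dark = [1 + 2j, 3 - 1j]
f_light = [2 + 0j, 1 + 1j]
f_mixed = [fractions[0] * d + fractions[1] * l for d, l in zip(f_dark, f_light)]
```

Because any real-valued raw parameters map to valid fractions, an optimizer can update them freely without an explicit sum-to-one constraint.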

Parameters:
  • models (List[ModelFT]) – List of ModelFT objects to combine. All models must have compatible cell parameters and space groups.

  • initial_fractions (List[float], optional) – Initial population fractions for each model. Must sum to 1.0. If None, equal fractions are used (1/N for each model).

  • frozen_fractions (bool, optional) – If True, fractions are not updated during optimization. Default is False.

  • verbose (int, optional) – Verbosity level. Default is 0.

models

Constituent ModelFT objects (proper submodule registration).

Type:

nn.ModuleList

fraction_params

Raw parameters for fraction computation (softmax applied).

Type:

nn.Parameter

Examples

Create a mixed model with two states:

model_dark = ModelFT().load_pdb('dark.pdb')
model_light = ModelFT().load_pdb('light.pdb')

# 70% dark, 30% light
mixed = MixedModel([model_dark, model_light], initial_fractions=[0.7, 0.3])

# Compute mixed structure factors
F_mixed = mixed(hkl)

# Get current fractions
print(mixed.fractions)  # tensor([0.7, 0.3])

__init__(models, initial_fractions=None, frozen_fractions=False, verbose=0, device=None)[source]

Initialize MixedModel.

Parameters:
  • models (List[ModelFT]) – List of ModelFT objects to combine.

  • initial_fractions (List[float], optional) – Initial population fractions. Must sum to 1.0.

  • frozen_fractions (bool, optional) – If True, fractions are frozen. Default is False.

  • verbose (int, optional) – Verbosity level. Default is 0.

  • device (torch.device, optional) – Device to place the model and parameters on. If None, infers from the first model’s device.

Raises:

ValueError – If models list is empty, fractions don’t match model count, fractions don’t sum to 1, or models have incompatible parameters.

property fractions: Tensor

Get normalized population fractions.

Returns:

Population fractions that sum to 1.0, shape (n_models,).

Return type:

torch.Tensor

property cell

Unit cell from first model (for compatibility).

property spacegroup

Space group from first model (for compatibility).

property device

Device from first model (for compatibility).

property dtype_float

Float dtype from first model (for compatibility).

property real_space_grid: Tensor | None

Real-space coordinate grid from first model (shared cell → same grid).

property fft

SfFFT submodule from first model (for gridsize access).

property gridsize: Tensor | None

Grid dimensions (nx, ny, nz) from first model.

property map_symmetry

Map symmetry operator from first model.

property inv_fractional_matrix: Tensor

Inverse fractionalization (orthogonalization) matrix.

property fractional_matrix: Tensor

Fractionalization matrix.

setup_grid(max_res=None, gridsize=None)[source]

Setup real-space grid on all constituent models.

All models share the same cell/spacegroup, so the grid is identical across all of them. This ensures each model’s SfFFT is ready for density map calculations.

Parameters:
  • max_res (float, optional) – Maximum resolution for grid spacing in Angstroms.

  • gridsize (tuple of int, optional) – Explicit grid size (nx, ny, nz).

get_radius(min_radius_Angstrom=4.0)[source]

Get the radius in voxels for density calculation.

Delegates to first model (same grid → same voxel size).

Parameters:

min_radius_Angstrom (float, optional) – Minimum radius in Angstroms. Default is 4.0.

Returns:

Radius in voxels.

Return type:

int

build_complete_map()[source]

Build the mixed electron density map as the weighted sum of constituent model density maps.

density_mixed = Σ w_i * density_i

Each constituent model builds its own density map on the shared grid, and the results are combined using the current population fractions.

Returns:

Electron density map with shape (nx, ny, nz).

Return type:

torch.Tensor

freeze_fractions()[source]

Exclude fractions from optimization.

This prevents the population fractions from being updated during training while still allowing the constituent models to be refined.

unfreeze_fractions()[source]

Include fractions in optimization.

This allows the population fractions to be updated during training.

forward(hkl, recalc=False)[source]

Compute weighted sum of structure factors from all models.

f_mixed = Σ w_i * f_i

Parameters:
  • hkl (torch.Tensor) – Miller indices with shape (n_reflections, 3).

  • recalc (bool, optional) – If True, force recalculation of structure factors. Default is False.

Returns:

Mixed complex structure factors with shape (n_reflections,).

Return type:

torch.Tensor

get_individual_fcalc(hkl, recalc=True)[source]

Get structure factors from each model individually.

Parameters:
  • hkl (torch.Tensor) – Miller indices with shape (n_reflections, 3).

  • recalc (bool, optional) – If True, force recalculation. Default is True.

Returns:

List of structure factor tensors, one per model.

Return type:

List[torch.Tensor]

copy()[source]

Create a deep copy of the MixedModel.

Returns:

A new MixedModel instance with copied models and parameters.

Return type:

MixedModel

__repr__()[source]

String representation.

write_ihm(filepath, state_names=None, group_name='ensemble', datasets=None)[source]

Write this MixedModel to an IHM mmCIF file.

Creates a single model group with the current population fractions over the constituent structural states.

Requires the optional python-ihm dependency.

Parameters:
  • filepath (str) – Output file path.

  • state_names (list of str, optional) – Names for each constituent model / state. Default: state_1, state_2, …

  • group_name (str) – Name for the model group. Default "ensemble".

  • datasets (dict of str -> ReflectionData, optional) – Per-timepoint reflection data to embed in the CIF.

get_vdw_radii()[source]

Get van der Waals radii from the first model.

Returns:

Van der Waals radii tensor.

Return type:

torch.Tensor

xyz()[source]

Get atomic coordinates from the first model.

Returns:

Atomic coordinates tensor.

Return type:

torch.Tensor

class torchref.model.Model(dtype_float=torch.float32, verbose=1, device=device(type='cpu'), strip_H=True)[source]

Bases: DeviceMixin, DebugMixin, Module

Base model class for atomic structure models using PyTorch.

This class provides the foundation for managing atomic structure data including coordinates, atomic displacement parameters (ADPs), and occupancies. It supports both empty initialization for state_dict loading and file-based initialization from PDB/CIF files.

Parameters:
  • dtype_float (torch.dtype, optional) – Data type for floating point tensors. Defaults to the configured dtypes.float.

  • verbose (int, optional) – Verbosity level for logging. Default is 1.

  • device (torch.device, optional) – Computation device. Defaults to the configured device.current.

  • strip_H (bool, optional) – Whether to strip hydrogen atoms when loading. Default is True.

xyz

Atomic coordinates tensor with shape (n_atoms, 3).

Type:

MixedTensor

adp

Atomic displacement parameters (isotropic B-factors) with shape (n_atoms,).

Type:

PositiveMixedTensor

u

Anisotropic displacement parameters with shape (n_atoms, 6).

Type:

MixedTensor

occupancy

Atomic occupancies with values in [0, 1].

Type:

OccupancyTensor

pdb

DataFrame containing atomic model data.

Type:

pandas.DataFrame

cell

Unit cell object with parameters [a, b, c, alpha, beta, gamma].

Type:

Cell

spacegroup

Space group object.

Type:

gemmi.SpaceGroup

symmetry

Symmetry operations handler for this space group.

Type:

Symmetry

initialized

Whether the model has been initialized with data.

Type:

bool

Examples

Empty initialization for state_dict loading:

model = Model()
model.load_state_dict(torch.load('model.pt'))

File-based initialization:

model = Model()
model.load_pdb('structure.pdb')

__init__(dtype_float=torch.float32, verbose=1, device=device(type='cpu'), strip_H=True)[source]

Initialize an empty Model shell.

Creates a model shell ready for file loading via load_pdb()/load_cif() or state restoration via load_state_dict().

Parameters:
  • dtype_float (torch.dtype, optional) – Data type for floating point tensors. Defaults to the configured dtypes.float.

  • verbose (int, optional) – Verbosity level for logging. Default is 1.

  • device (torch.device, optional) – Computation device. Defaults to the configured device.current.

  • strip_H (bool, optional) – Whether to strip hydrogen atoms when loading. Default is True.

__bool__()[source]

Return the initialization status when used in boolean context.

property exclude_H_from_sf: bool

Whether to exclude hydrogen atoms from structure factor calculation.

When True, H atoms are excluded from get_iso() / get_aniso() so they do not contribute to Fcalc. They still participate in geometry and VDW restraints. Default is False.

property cell: Cell | None

Unit cell object with parameters [a, b, c, alpha, beta, gamma].

Returns:

The unit cell object, or None if not set.

Return type:

Cell or None

property spacegroup: SpaceGroup | None

Space group object.

Returns:

The space group object, or None if not set.

Return type:

gemmi.SpaceGroup or None

property symmetry: SpaceGroup | None

Symmetry operations handler for this space group.

Returns the same SpaceGroup object as self.spacegroup. Symmetry is an alias for SpaceGroup, so no separate wrapper object is kept.

Returns:

The space group object, or None if not set.

Return type:

SpaceGroup or None

property inv_fractional_matrix: Tensor

Fractionalization matrix B^-1 (Cartesian -> fractional).

Delegates to Cell for automatic caching and device/dtype handling.

Returns:

Shape (3, 3) fractionalization matrix.

Return type:

torch.Tensor

property fractional_matrix: Tensor

Orthogonalization matrix B (fractional -> Cartesian).

Delegates to Cell for automatic caching and device/dtype handling.

Returns:

Shape (3, 3) orthogonalization matrix.

Return type:

torch.Tensor
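
To make the roles of the two matrices concrete, here is a minimal dependency-free sketch of the fractional/Cartesian mapping. It assumes the common PDB convention (a along x, b in the xy plane); the convention used by Cell is not stated in this reference, and the function names below are hypothetical, not part of torchref.

```python
import math

def orthogonalization_matrix(a, b, c, alpha, beta, gamma):
    """Fractional -> Cartesian matrix (angles in degrees), PDB convention (assumed)."""
    ca, cb, cg = (math.cos(math.radians(x)) for x in (alpha, beta, gamma))
    sg = math.sin(math.radians(gamma))
    # Cell volume divided by a*b*c
    v = math.sqrt(1.0 - ca * ca - cb * cb - cg * cg + 2.0 * ca * cb * cg)
    return [
        [a, b * cg, c * cb],
        [0.0, b * sg, c * (ca - cb * cg) / sg],
        [0.0, 0.0, c * v / sg],
    ]

def frac_to_cart(m, frac):
    """Matrix-vector product m @ frac."""
    return [sum(m[i][j] * frac[j] for j in range(3)) for i in range(3)]

def cart_to_frac(m, cart):
    """Invert the upper-triangular matrix m by back-substitution."""
    z = cart[2] / m[2][2]
    y = (cart[1] - m[1][2] * z) / m[1][1]
    x = (cart[0] - m[0][1] * y - m[0][2] * z) / m[0][0]
    return [x, y, z]

# Monoclinic example cell (illustrative values)
M = orthogonalization_matrix(10.0, 20.0, 30.0, 90.0, 100.0, 90.0)
cart = frac_to_cart(M, [0.5, 0.5, 0.5])
frac = cart_to_frac(M, cart)  # round trip back to fractional
```

In this sketch fractional_matrix corresponds to M and inv_fractional_matrix to its inverse; the round trip recovers the input fractional coordinates up to floating-point error.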

property recB: Tensor

Reciprocal basis matrix with [a*, b*, c*] as rows.

Delegates to Cell for automatic caching and device/dtype handling.

Returns:

Shape (3, 3) matrix where rows are the reciprocal basis vectors.

Return type:

torch.Tensor

property Z: Tensor

Atomic numbers for all atoms.

Returns:

Tensor of atomic numbers with shape (n_atoms,).

Return type:

torch.Tensor

get_P1_parameters_iso()[source]

Get model parameters transformed to P1 space for optimization.

This is useful for molecular dynamics (MD) and for optimizers that do not handle symmetry directly.

Returns:

  • xyz_p1 (torch.Tensor) – Fractional coordinates expanded to P1 space.

  • adp_p1 (torch.Tensor) – Isotropic ADPs expanded to P1 space.

  • occupancy_p1 (torch.Tensor) – Occupancies expanded to P1 space.

  • A (torch.Tensor) – Scattering factor A coefficients expanded to P1 space.

  • B (torch.Tensor) – Scattering factor B coefficients expanded to P1 space.

Return type:

tuple[Tensor, Tensor, Tensor, Tensor, Tensor]
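
The idea of P1 expansion can be sketched without any library: every symmetry operator is applied to every atom, and the per-atom scalars are replicated once per operator. This is only an illustration of the concept under stated assumptions; the helper name, the wrapping into the unit cell, and the output ordering are choices made here, not documented torchref behaviour, and atoms on special positions are not treated.

```python
def expand_to_p1(frac_xyz, adp, occ, sym_ops):
    """Apply every (rotation, translation) operator to every atom.

    Coordinates are fractional; per-atom scalars are copied for each operator.
    """
    xyz_p1, adp_p1, occ_p1 = [], [], []
    for rot, trans in sym_ops:
        for atom, b, q in zip(frac_xyz, adp, occ):
            moved = [
                (sum(rot[i][j] * atom[j] for j in range(3)) + trans[i]) % 1.0
                for i in range(3)
            ]
            xyz_p1.append(moved)
            adp_p1.append(b)
            occ_p1.append(q)
    return xyz_p1, adp_p1, occ_p1

# P2_1 (unique axis b): (x, y, z) and (-x, y + 1/2, -z)
P21_OPS = [
    ([[1, 0, 0], [0, 1, 0], [0, 0, 1]], [0.0, 0.0, 0.0]),
    ([[-1, 0, 0], [0, 1, 0], [0, 0, -1]], [0.0, 0.5, 0.0]),
]

xyz_p1, adp_p1, occ_p1 = expand_to_p1([[0.1, 0.2, 0.3]], [20.0], [1.0], P21_OPS)
```

One atom in the asymmetric unit becomes two atoms in P1 here, which is why the expanded tensors are larger than the model's own by a factor equal to the number of symmetry operators.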

get_MD_parameters()[source]

Get model parameters prepared for molecular dynamics simulation.

Returns all P1-expanded parameters plus atomic numbers for MD engines.

Returns:

  • xyz_p1 (torch.Tensor) – Fractional coordinates expanded to P1 space.

  • adp_p1 (torch.Tensor) – Isotropic ADPs expanded to P1 space.

  • occupancy_p1 (torch.Tensor) – Occupancies expanded to P1 space.

  • A (torch.Tensor) – Scattering factor A coefficients expanded to P1 space.

  • B (torch.Tensor) – Scattering factor B coefficients expanded to P1 space.

  • Z_p1 (torch.Tensor) – Atomic numbers expanded to P1 space.

Return type:

tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]

property parametrization

ITC92 parametrization dictionary {element: (A, B)}.

The parametrization is built lazily on first access.

Returns:

Dictionary mapping element symbols to tuples of (A, B) tensors.

Return type:

dict

get_scattering_params_iso()[source]

Get ITC92 scattering parameters (A, B) for isotropic atoms.

Returns:

  • A (torch.Tensor) – ITC92 A parameters (amplitudes) with shape (n_iso_atoms, 5).

  • B (torch.Tensor) – ITC92 B parameters (widths) with shape (n_iso_atoms, 5).
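
As a rough illustration of how such (A, B) rows are typically used, the sketch below evaluates a sum-of-Gaussians form factor for one atom. Two things are assumptions: that the fifth column stores the constant term as an amplitude with width 0, and that s means sin(theta)/lambda. The carbon coefficients are the commonly tabulated values and should be checked against the library's own table before being relied on.

```python
import math

# Carbon coefficients as commonly tabulated: (a1..a4, c) and (b1..b4).
# The fifth B value of 0 encodes the constant term (assumed layout).
A_CARBON = [2.3100, 1.0200, 1.5886, 0.8650, 0.2156]
B_CARBON = [20.8439, 10.2075, 0.5687, 51.6512, 0.0]

def form_factor(A, B, s, b_iso=0.0):
    """f(s) = sum_i A_i * exp(-B_i * s^2), damped by exp(-b_iso * s^2).

    s is sin(theta)/lambda in 1/Angstrom, i.e. 1 / (2 d).
    """
    s2 = s * s
    f0 = sum(a * math.exp(-b * s2) for a, b in zip(A, B))
    return f0 * math.exp(-b_iso * s2)

f_at_origin = form_factor(A_CARBON, B_CARBON, 0.0)
f_at_2A = form_factor(A_CARBON, B_CARBON, 1.0 / (2 * 2.0), b_iso=20.0)
```

At s = 0 the amplitudes sum to approximately the number of electrons (about 6 for carbon), which is a quick sanity check on any coefficient table.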

get_scattering_params_aniso()[source]

Get ITC92 scattering parameters (A, B) for anisotropic atoms.

Returns:

  • A (torch.Tensor) – ITC92 A parameters (amplitudes) with shape (n_aniso_atoms, 5).

  • B (torch.Tensor) – ITC92 B parameters (widths) with shape (n_aniso_atoms, 5).

set_restraints_cif(cif_path)[source]

Set CIF path for lazy restraint building.

Parameters:

cif_path (str or list of str) – Path(s) to CIF restraints dictionary file(s).

Returns:

Self, for method chaining.

property restraints

Lazy restraints property.

The restraints are built on first access using the model’s pdb DataFrame and the CIF path set via set_restraints_cif().

Returns:

The restraints object containing bond, angle, torsion, etc. restraints.

Return type:

RestraintsNew

bond_deviations()[source]

Compute bond length deviations using current xyz coordinates.

Returns:

  • deviations (torch.Tensor) – Calculated minus expected bond lengths in Angstroms.

  • sigmas (torch.Tensor) – Standard deviations from CIF library in Angstroms.
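
A small sketch of how a (deviations, sigmas) pair is commonly turned into a restraint target: each deviation is divided by its sigma and the squares are summed. The harmonic form and the helper names are assumptions for illustration; the loss torchref actually builds from these values is not described here.

```python
import math

def bond_deviation(p1, p2, ideal):
    """Calculated minus expected bond length (Angstroms)."""
    return math.dist(p1, p2) - ideal

def restraint_energy(deviations, sigmas):
    """Sum of squared sigma-normalised deviations (a common harmonic target)."""
    return sum((d / s) ** 2 for d, s in zip(deviations, sigmas))

# A bond that is 0.03 Angstroms too long, with sigma = 0.02 Angstroms
dev = bond_deviation([0.0, 0.0, 0.0], [1.56, 0.0, 0.0], ideal=1.53)
energy = restraint_energy([dev], [0.02])
```

A deviation of 1.5 sigma contributes 2.25 to the target in this form, so tight sigmas penalise the same absolute error more strongly.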

angle_deviations()[source]

Compute angle deviations using current xyz coordinates.

Returns:

  • deviations (torch.Tensor) – Calculated minus expected angles in radians.

  • sigmas (torch.Tensor) – Standard deviations in radians.

torsion_deviations_with_sigmas()[source]

Compute torsion deviations (wrapped for periodicity) and sigmas.

Returns:

  • deviations_rad (torch.Tensor) – Wrapped deviations in radians.

  • sigmas_deg (torch.Tensor) – Standard deviations in degrees (for von Mises NLL).
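
"Wrapped for periodicity" means the angular difference is taken the short way round the circle. A minimal sketch of that step, assuming a full 2*pi period (restraints with an n-fold periodicity would wrap on 2*pi/n instead, which is not shown); wrap_angle is a hypothetical helper name.

```python
import math

def wrap_angle(delta):
    """Map an angular difference in radians onto [-pi, pi)."""
    return (delta + math.pi) % (2.0 * math.pi) - math.pi

# 350 degrees observed vs 10 degrees ideal: the short way round is -20 degrees
raw = math.radians(350.0) - math.radians(10.0)
wrapped_deg = math.degrees(wrap_angle(raw))
```

Without wrapping, the raw difference of 340 degrees would be treated as a large violation even though the two angles are only 20 degrees apart.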

load(reader)[source]
load_pdb(file)[source]

Load atomic model from PDB file.

Parameters:

file (str) – Path to PDB file.

Returns:

Self, for method chaining.

Return type:

Model

load_cif(file)[source]

Load atomic model from mmCIF file.

Parameters:

file (str) – Path to CIF/mmCIF file.

Returns:

Self, for method chaining.

Return type:

Model

property chain_sequences: List[Tuple[str, str]]

Per-chain amino acid sequences as single-letter codes.

Excludes HETATM records. Gaps in residue numbering are filled with ?. Non-standard residues are mapped to X.

Returns:

Ordered list of (chain_id, sequence_string). E.g. [("A", "MKVL??GAST"), ("B", "ACDEFG")].

Return type:

list of (str, str)
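
The gap-filling rule described above can be sketched in a few lines. This is an illustration only: the input format (a sorted list of (resseq, resname) pairs per chain) and the helper name are assumptions, and insertion codes are ignored.

```python
THREE_TO_ONE = {
    "ALA": "A", "ARG": "R", "ASN": "N", "ASP": "D", "CYS": "C",
    "GLN": "Q", "GLU": "E", "GLY": "G", "HIS": "H", "ILE": "I",
    "LEU": "L", "LYS": "K", "MET": "M", "PHE": "F", "PRO": "P",
    "SER": "S", "THR": "T", "TRP": "W", "TYR": "Y", "VAL": "V",
}

def sequence_with_gaps(residues):
    """residues: list of (resseq, resname) for one chain, sorted by resseq."""
    letters = []
    previous = None
    for resseq, resname in residues:
        if previous is not None:
            # One "?" per residue number missing between neighbours
            letters.append("?" * (resseq - previous - 1))
        # Non-standard residues map to "X"
        letters.append(THREE_TO_ONE.get(resname, "X"))
        previous = resseq
    return "".join(letters)

chain_a = [(1, "MET"), (2, "LYS"), (3, "VAL"), (4, "LEU"), (7, "GLY"), (8, "MSE")]
seq = sequence_with_gaps(chain_a)
```

Residues 5 and 6 are missing in the example, so two ? characters appear, and the non-standard MSE becomes X.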

get_chain_residues()[source]

Per-chain residue names as 3-letter codes (for IHM/CIF writing).

Excludes HETATM records. Unlike chain_sequences, returns the raw 3-letter codes without gap filling.

Returns:

Ordered list of (chain_id, [resname, ...]).

Return type:

list of (str, list of str)

update_pdb()[source]
get_vdw_radii()[source]

Get van der Waals radii for all atoms based on their elements.

Caches the result in self.vdw_radii for future calls.

Returns:

Van der Waals radii for each atom with shape (n_atoms,).

Return type:

torch.Tensor

to(*args, **kwargs)[source]

Move Model and rebuild device-specific SF indices.

Delegates to DeviceMixin, which walks self.__dict__ (picking up self.cell, self.altloc_pairs, self._restraints and all registered parameters / buffers), refreshes the self.device tracker, and invalidates caches. Afterwards this override rebuilds the precomputed SF indices on the new device.

copy()[source]

Create a deep copy of the Model.

Creates a complete independent copy including all registered buffers, module parameters, PDB DataFrame, and spacegroup information.

Returns:

A new Model instance with copied data.

Return type:

Model

Examples

model = Model().load_pdb('structure.pdb')
model_copy = model.copy()
# model_copy is independent, changes won't affect model
write_pdb(filename, metadata=None)[source]

Write model to PDB file with optional metadata header.

Parameters:
  • filename (str) – Output PDB file path.

  • metadata (RefinementMetadata, optional) – Metadata to render as PDB header (REMARK 3, TITLE, etc.).

write_cif(filename, metadata=None)[source]

Write model to mmCIF file with optional metadata.

Parameters:
  • filename (str) – Output mmCIF file path.

  • metadata (RefinementMetadata, optional) – Metadata to include (refinement statistics, title, etc.).

get_iso()[source]

Return per-atom parameters for the isotropic atom subset.

Selects atoms whose ADP is a single scalar b (i.e. not anisotropic). The subset is defined by ~self.aniso_flag — intersected with self._heavy_atom_mask when _exclude_H_from_sf is enabled — and is precomputed as self._iso_indices at init / whenever the mask changes.

Returns:

  • xyz (torch.Tensor, shape (n_iso, 3)) – Cartesian coordinates of the isotropic atoms (Å).

  • adp (torch.Tensor, shape (n_iso,)) – Isotropic B-factors (Ų).

  • occupancy (torch.Tensor, shape (n_iso,)) – Occupancies in [0, 1].

Notes

When every atom is isotropic and no H exclusion is active — self._iso_covers_all is True, the common protein-refinement case — the per-atom indexing is skipped and self.xyz(), self.adp(), self.occupancy() are returned directly.

Motivation: self.xyz()[idx] is a no-op forward when idx = arange(N), but its backward routes through PyTorch’s aten::_index_put_impl_(accumulate=True), which performs a cub::DeviceRadixSortOnesweepKernel over len(idx) indices followed by a deduplicated scatter (~50-150 µs/iter per gather on A100 / 1DAW). Skipping the gather avoids that cost.

set_default_masks()[source]
PARAM_TYPES: Tuple[str, ...] = ('xyz', 'adp', 'u', 'occupancy')
parameters_of_types(types)[source]

Return the leaf nn.Parameter objects for the named parameter types.

Used by refinement entry points (refine_xyz, refine_adp, …) to construct an optimizer over only the leaves the caller intends to update. LossState.step then treats the optimizer’s param groups as the statement of intent and disables requires_grad on any other leaves the loss also touches.

Parameters:

types (Iterable[str]) – Subset of Model.PARAM_TYPES: "xyz", "adp", "u", "occupancy". Unknown names are silently skipped.

Returns:

The refinable_params leaf for each requested type, in the order the types were given.

Return type:

list of nn.Parameter

freeze(target)[source]
freeze_all()[source]
unfreeze_all()[source]
unfreeze(target)[source]
update_mask_from_selection(selection_string, target, mode='set', freeze=True)[source]

Update the refinable mask for a parameter using Phenix-style selection syntax.

This method updates the internal mask buffer (xyz_mask, adp_mask, u_mask, or occupancy_mask) based on the selection. The updated mask is NOT automatically applied to the parameter tensors; call apply_mask_to_parameter() to apply it.

Parameters:
  • selection_string (str) – Phenix-style selection string (see parse_phenix_selection docs).

  • target (str) – Parameter to update: ‘xyz’, ‘adp’, ‘u’, or ‘occupancy’.

  • mode (str, optional) – How to combine with the current mask:

      - ‘set’: Replace mask with selection (default)
      - ‘add’: Add selection to current mask
      - ‘remove’: Remove selection from current mask

  • freeze (bool, optional) – If True (default), selected atoms will be frozen (mask=False). If False, selected atoms will be unfrozen (mask=True).

Raises:

ValueError – If target is not recognized or selection syntax is invalid.

Examples

# Freeze chain A coordinates
model.update_mask_from_selection("chain A", "xyz", mode='set', freeze=True)
model.apply_mask_to_parameter("xyz")

# Unfreeze backbone atoms
model.update_mask_from_selection("name CA or name C or name N", "xyz", freeze=False)
model.apply_mask_to_parameter("xyz")
apply_mask_to_parameter(target)[source]

Apply the current mask buffer to the parameter tensor.

Takes the current state of the mask buffer (xyz_mask, adp_mask, etc.) and applies it to the corresponding parameter tensor’s refinable mask.

Parameters:

target (str) – Parameter to update: ‘xyz’, ‘adp’, ‘u’, or ‘occupancy’.

Raises:

ValueError – If target is not recognized.

Examples

model.update_mask_from_selection("chain A", "xyz", freeze=True)
model.apply_mask_to_parameter("xyz")
freeze_selection(selection_string, targets='all')[source]

Freeze atoms matching a Phenix-style selection for specified parameters.

Convenience method that combines update_mask_from_selection() and apply_mask_to_parameter() into a single call.

Parameters:
  • selection_string (str) – Phenix-style selection string.

  • targets (str or list of str, optional) – Parameter(s) to freeze. Can be:

      - ‘all’: Freeze xyz, adp, u, and occupancy (default)
      - str: Single parameter (‘xyz’, ‘adp’, ‘u’, ‘occupancy’)
      - list: List of parameters, e.g., [‘xyz’, ‘adp’]

Examples

# Freeze all parameters for chain A
model.freeze_selection("chain A", targets='all')

# Freeze only coordinates for residues 10-20
model.freeze_selection("resseq 10:20", targets='xyz')
unfreeze_selection(selection_string, targets='all')[source]

Unfreeze atoms matching a Phenix-style selection for specified parameters.

Convenience method that combines update_mask_from_selection() and apply_mask_to_parameter() into a single call.

Parameters:
  • selection_string (str) – Phenix-style selection string.

  • targets (str or list of str, optional) – Parameter(s) to unfreeze. Can be:

      - ‘all’: Unfreeze xyz, adp, u, and occupancy (default)
      - str: Single parameter (‘xyz’, ‘adp’, ‘u’, ‘occupancy’)
      - list: List of parameters, e.g., [‘xyz’, ‘adp’]

Examples

# Unfreeze all parameters for chain A
model.unfreeze_selection("chain A", targets='all')

# Unfreeze only coordinates for backbone atoms
model.unfreeze_selection("name CA or name C or name N", targets='xyz')
get_aniso()[source]

Return per-atom parameters for the anisotropic atom subset.

Selects atoms whose ADP is the 6-element anisotropic tensor u = (u11, u22, u33, u12, u13, u23). The subset is defined by self.aniso_flag — intersected with self._heavy_atom_mask when _exclude_H_from_sf is enabled — and is precomputed as self._aniso_indices at init / whenever the mask changes.

Returns:

  • xyz (torch.Tensor, shape (n_aniso, 3)) – Cartesian coordinates of the anisotropic atoms (Å). Empty tensor when there are no anisotropic atoms.

  • u (torch.Tensor, shape (n_aniso, 6)) – Anisotropic U components (Ų) in the order (u11, u22, u33, u12, u13, u23). Empty when n_aniso == 0.

  • occupancy (torch.Tensor, shape (n_aniso,)) – Occupancies in [0, 1]. Empty when n_aniso == 0.

Notes

When there are no anisotropic atoms — self._aniso_is_empty is True, the common protein-refinement case — three empty placeholder tensors are returned without calling the MixedTensors at all. This avoids both the wrapped forward .clone() and the slow aten::_index_put_impl_ backward path that the self.xyz()[idx] gather would otherwise generate (see get_iso() for the same rationale).

parameters(recurse=True)[source]

Return an iterator over module parameters.

This is typically passed to an optimizer.

Parameters:

recurse (bool) – If True, yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

Yields:

Parameter – module parameter

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for param in model.parameters():
>>>     print(type(param), param.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
named_mixed_tensors()[source]

Iterate over all MixedTensor attributes with their names.

Yields:

Tuple of (name, MixedTensor)

print_parameters_info()[source]

Print information about all MixedTensor parameters.

register_alternative_conformations()[source]

Identify and register all alternative conformation groups in the structure.

For each residue that has alternative conformations (altloc A, B, C, etc.), this method identifies all atoms belonging to each conformation and stores their indices as tensors in a tuple.

The result is stored in self.altloc_pairs as a list of tuples, where each tuple contains tensors of atom indices for each alternative conformation.

Examples

For a residue with conformations A and B:

# Conformation A has atoms at indices [100, 101, 102, ...]
# Conformation B has atoms at indices [110, 111, 112, ...]
# Result: [(tensor([100, 101, 102, ...]), tensor([110, 111, 112, ...])), ...]

For a residue with conformations A, B, C:

# Result: [(tensor([200, 201, ...]), tensor([210, 211, ...]), tensor([220, 221, ...])), ...]
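
The grouping itself can be sketched with plain Python lists. The input format (one (chain, resseq, altloc) tuple per atom, with an empty altloc for atoms that have no alternates) and the function name are assumptions for illustration; the real method reads the model's DataFrame and stores index tensors rather than lists.

```python
from collections import defaultdict

def group_altlocs(atoms):
    """atoms: list of (chain, resseq, altloc); altloc "" means no alternate."""
    per_residue = defaultdict(lambda: defaultdict(list))
    for index, (chain, resseq, altloc) in enumerate(atoms):
        if altloc:
            per_residue[(chain, resseq)][altloc].append(index)
    groups = []
    for residue in sorted(per_residue):
        conformers = per_residue[residue]
        # Only residues with at least two conformations form a group
        if len(conformers) > 1:
            groups.append(tuple(conformers[k] for k in sorted(conformers)))
    return groups

atoms = [
    ("A", 10, ""), ("A", 11, "A"), ("A", 11, "A"),
    ("A", 11, "B"), ("A", 11, "B"), ("A", 12, ""),
]
pairs = group_altlocs(atoms)
```

Each tuple in the result holds one index list per conformation, in alphabetical altloc order.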
shake_coords(stddev)[source]

Apply random Gaussian noise to atomic coordinates.

Perturbs the atomic coordinates by adding Gaussian noise with a specified standard deviation. The noise is applied to all atoms.

Parameters:

stddev (float) – Standard deviation of the Gaussian noise to be added, in Angstroms.
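
A dependency-free sketch of the perturbation, assuming independent noise on each Cartesian component. Unlike the method documented here, this hypothetical helper returns a new list instead of modifying the coordinates in place, and it takes a seed so the result is reproducible.

```python
import random

def shake(coords, stddev, seed=None):
    """Return a copy of coords with independent N(0, stddev^2) noise per component."""
    rng = random.Random(seed)
    return [[x + rng.gauss(0.0, stddev) for x in atom] for atom in coords]

coords = [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]]
shaken = shake(coords, stddev=0.1, seed=0)
```

With independent noise on x, y and z, the expected root-mean-square displacement per atom is stddev * sqrt(3), so a stddev of 0.1 Å moves atoms by roughly 0.17 Å on average.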

shake_adp(stddev)[source]

Apply random Gaussian noise to ADPs (atomic displacement parameters).

Perturbs the ADPs by adding Gaussian noise with a specified standard deviation. The noise is applied to all atoms.

Parameters:

stddev (float) – Standard deviation of the Gaussian noise to be added, in Angstrom^2.

generate_hydrogens(mon_lib_path=None)[source]

Generate hydrogen atoms for the current model using gemmi.

Places hydrogens at ideal geometry using the CCP4 monomer library and gemmi’s topology engine. Returns a new Model instance with hydrogens added; the original model is not modified.

Parameters:

mon_lib_path (str, optional) – Path to CCP4 monomer library directory. If None, uses the monomer library bundled with torchref (covers standard amino acids and common small molecules).

Returns:

A new Model instance with hydrogen atoms added (strip_H=False). Unknown residues are skipped silently.

Return type:

Model

Notes

Requires gemmi (already a torchref dependency). Heavy-atom coordinates from the current model state are used, so call this after any coordinate changes you want reflected in the H positions.

Examples

>>> model_no_h = Model().load_pdb('structure.pdb')
>>> model_with_h = model_no_h.generate_hydrogens()
>>> print(model_with_h.Z.shape)   # more atoms than model_no_h
strip_altlocs()[source]

Return a new model with alternate conformations removed.

For each residue that has multiple altlocs, the conformer with highest average occupancy is kept (ties broken alphabetically). The altloc column is cleared to "" in the returned model. The original model is not modified.
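
The selection rule (highest average occupancy, ties broken alphabetically) can be written as a single sort key. The dictionary input and the helper name are hypothetical; this only illustrates the rule stated above for one residue.

```python
def pick_conformer(occupancies_by_altloc):
    """occupancies_by_altloc: {altloc: [per-atom occupancy, ...]} for one residue."""
    def mean(values):
        return sum(values) / len(values)
    # Sort key: highest mean occupancy first, then alphabetical altloc
    return min(
        occupancies_by_altloc,
        key=lambda alt: (-mean(occupancies_by_altloc[alt]), alt),
    )

kept = pick_conformer({"A": [0.4, 0.4], "B": [0.6, 0.6]})   # B has higher occupancy
tie = pick_conformer({"B": [0.5, 0.5], "A": [0.5, 0.5]})    # equal, so A wins
```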

strip_hydrogens()[source]

Return a new model with hydrogen atoms removed.

The returned model has consistent DataFrame and tensors (xyz, adp, occupancy) with H atoms excluded. The original model is not modified.

Returns:

New model without hydrogen atoms.

Return type:

Model

hydrogenate(verbose=0, optimize=False, lbfgs_steps=3, max_iter=20)[source]

Return a new model with hydrogen atoms placed via Kabsch alignment.

Uses torchref’s monomer library to identify missing H atoms, places them by SVD-aligning ideal monomer coordinates onto the current model coordinates, then corrects each H to sit at ideal bond length from its parent atom. The original model is not modified.

Parameters:
  • verbose (int, optional) – Verbosity level (0=silent, 1=summary, 2=detailed). Default 0.

  • optimize (bool, optional) – If True, run a short LBFGS geometry optimization on H positions after placement. Default False (Kabsch placement only).

  • lbfgs_steps (int, optional) – Number of LBFGS outer steps (only when optimize=True). Default 3.

  • max_iter (int, optional) – Max line-search iterations per LBFGS step. Default 20.

Returns:

New model with hydrogen atoms added. All parameters are unfrozen in the returned model.

Return type:

Model

adp_loss()[source]

Compute the ADP regularization loss.

This loss encourages ADPs to have similar values across the structure, helping to prevent overfitting during refinement.

Returns:

Scalar tensor representing the ADP loss.

Return type:

torch.Tensor

adp_nll_loss(target_log_std=0.2)[source]

Compute negative log-likelihood of ADPs assuming Gaussian distribution in log-space.

This regularization penalizes ADPs that deviate from a target distribution with a FIXED standard deviation (hyperparameter), avoiding circular dependency on the current distribution’s statistics.

The NLL for a Gaussian distribution in log-space is:

NLL = 0.5 * mean[(log_adp - mu)^2 / sigma^2 + log(2*pi*sigma^2)]

Where mu is the mean of log-space ADPs (computed from current data) and sigma is the FIXED target standard deviation (hyperparameter).

Parameters:

target_log_std (float, optional) – Target standard deviation in log-space. Default is 0.2.

  - 0.1 = very tight (ADPs within ~10% of mean)
  - 0.2 = moderate spread (ADPs within ~20% of mean) [RECOMMENDED]
  - 0.3 = looser spread (ADPs within ~30% of mean)

Returns:

Scalar tensor representing the NLL. Lower values indicate the distribution is closer to the target Gaussian with fixed sigma.

Return type:

torch.Tensor

Examples

# During refinement
structure_factor_loss = compute_structure_factor_loss()
nll_reg = model.adp_nll_loss(target_log_std=0.2)
total_loss = structure_factor_loss + 0.01 * nll_reg
total_loss.backward()

Notes

Uses FIXED sigma (no circular dependency on current distribution). Smaller target_log_std = stronger regularization (tighter distribution).
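
The formula above can be checked by hand with a short plain-Python sketch. It follows the stated expression term by term (mu from the current data, sigma fixed); the function name is hypothetical and gradient handling such as detaching mu is not modelled.

```python
import math

def adp_nll(adps, target_log_std=0.2):
    """Mean Gaussian NLL of log(ADP) about its own mean, with fixed sigma.

    Returns (mean NLL, list of per-atom NLL values).
    """
    logs = [math.log(b) for b in adps]
    mu = sum(logs) / len(logs)
    var = target_log_std ** 2
    const = math.log(2.0 * math.pi * var)
    per_atom = [0.5 * ((x - mu) ** 2 / var + const) for x in logs]
    return sum(per_atom) / len(per_atom), per_atom

uniform_nll, _ = adp_nll([20.0, 20.0, 20.0])          # all ADPs identical
spread_nll, per_atom = adp_nll([10.0, 20.0, 40.0])    # factor-of-two spread
```

When all ADPs are identical only the constant term 0.5 * log(2*pi*sigma^2) remains, which is negative for sigma = 0.2. A negative loss value is therefore expected and not a sign of an error; only differences in the NLL matter for optimization.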

adp_nll_loss_per_atom(target_log_std=0.2)[source]

Compute per-atom negative log-likelihood for ADPs in log-space.

Returns the NLL contribution for each individual atom, useful for identifying outliers or applying atom-specific regularization weights.

The per-atom NLL is:

NLL_i = 0.5 * [(log_adp_i - mu)^2 / sigma^2 + log(2*pi*sigma^2)]

Parameters:

target_log_std (float, optional) – Fixed target standard deviation in log-space. Default is 0.2.

Returns:

Tensor of shape (n_atoms,) with per-atom NLL values. Higher values indicate atoms farther from the mean.

Return type:

torch.Tensor

Examples

# Get per-atom NLL
atom_nll = model.adp_nll_loss_per_atom(target_log_std=0.2)
# Identify outlier atoms (high NLL)
threshold = atom_nll.mean() + 2 * atom_nll.std()
outliers = atom_nll > threshold
adp_kl_divergence_loss(target_log_std=0.2)[source]

Compute KL divergence between log ADP distribution and target Gaussian.

Measures how different the current log ADP distribution is from a target Gaussian distribution with the current mean of log ADPs and a fixed target standard deviation.

KL divergence formula for two Gaussians with same mean:

KL(q || p) = log(sigma_target/sigma_data) + sigma_data^2 / (2*sigma_target^2) - 0.5

Parameters:

target_log_std (float, optional) – Target standard deviation in log-space. Default is 0.2. Controls how tightly ADPs should cluster.

Returns:

Scalar KL divergence value (always >= 0). 0 means distributions match perfectly. Higher values mean more deviation from target.

Return type:

torch.Tensor

Examples

# Use in loss function
loss = xray_loss + w_adp * model.adp_kl_divergence_loss(0.2)

Notes

Lower target_log_std = stronger regularization (tighter distribution). Mean is detached so it adapts to the natural scale of the data.
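
A plain-Python sketch of the equal-mean KL formula given above. It assumes sigma_data is the population standard deviation of log(ADP); whether the library uses the population or the sample estimate is not stated here. The function name is hypothetical, and the degenerate case of identical ADPs (sigma_data = 0) is not handled.

```python
import math

def adp_kl(adps, target_log_std=0.2):
    """KL(q || p) for equal-mean Gaussians fitted to log(ADP).

    Returns (KL value, sigma_data).
    """
    logs = [math.log(b) for b in adps]
    mu = sum(logs) / len(logs)
    # Population standard deviation (assumption; a sample estimate is also plausible)
    sigma_data = math.sqrt(sum((x - mu) ** 2 for x in logs) / len(logs))
    ratio = sigma_data / target_log_std
    # log(sigma_target / sigma_data) + sigma_data^2 / (2 sigma_target^2) - 0.5
    return -math.log(ratio) + 0.5 * ratio ** 2 - 0.5, sigma_data

adps = [15.0, 20.0, 30.0, 45.0]
_, sigma = adp_kl(adps)
kl_matched, _ = adp_kl(adps, target_log_std=sigma)   # target equals data spread
kl_tight, _ = adp_kl(adps, target_log_std=0.1)       # target much tighter
```

The divergence is zero when the target spread equals the data spread and grows as the target is tightened, which matches the note that a lower target_log_std gives stronger regularization.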

state_dict(destination=None, prefix='', keep_vars=False)[source]

Return a dictionary containing the complete state of the Model.

Includes all registered buffers, model parameters (xyz, adp, u, occupancy), PDB DataFrame, and metadata (spacegroup, device, dtype, etc.).

Parameters:
  • destination (dict, optional) – Optional dict to populate with state.

  • prefix (str, optional) – Prefix for parameter names. Default is ‘’.

  • keep_vars (bool, optional) – Whether to keep variables in computational graph. Default is False.

Returns:

Complete state dictionary.

Return type:

dict

save_state(path)[source]

Save the complete state of the model to a file.

Parameters:

path (str) – Path to save the state dictionary to.

load_state(path, strict=True)[source]

Load the complete state of the model from a file.

Parameters:
  • path (str) – Path to load the state dictionary from.

  • strict (bool, optional) – Whether to strictly enforce that keys match. Default is True.

classmethod create_from_state_dict(state_dict, device=device(type='cpu'), verbose=1, dtype_float=torch.float32)[source]

Create a fully initialized Model from a state dictionary.

This is the recommended way to restore a Model from a saved state. Creates an instance with properly initialized submodules, then loads the state.

Parameters:
  • state_dict (dict) – State dictionary from torch.save(model.state_dict(), …).

  • device (torch.device, optional) – Device to place tensors on. Defaults to the configured device.current.

  • verbose (int, optional) – Verbosity level. Default is 1.

  • dtype_float (torch.dtype, optional) – Float dtype for tensors. Defaults to the configured dtypes.float.

Returns:

Fully initialized instance with restored state.

Return type:

Model

get_selection_mask(selection)[source]

Return a boolean mask for atoms matching a Phenix-style selection.

This is a convenience method that wraps parse_phenix_selection() to return a mask that can be used directly with MixedTensor.set() or other operations requiring atom selection.

Parameters:

selection (str) – Phenix-style selection string. Supports:

  - chain <id>: Select by chain (e.g., “chain A”)
  - resseq <num>: Select by residue number (e.g., “resseq 10”)
  - resseq <start>:<end>: Select residue range (e.g., “resseq 10:20”)
  - resname <name>: Select by residue name (e.g., “resname ALA”)
  - name <atom>: Select by atom name (e.g., “name CA”)
  - element <elem>: Select by element (e.g., “element C”)
  - altloc <id>: Select by alternate location (e.g., “altloc A”)
  - all: Select all atoms
  - not <selection>: Negate selection
  - <sel1> and <sel2>: Intersection
  - <sel1> or <sel2>: Union
  - Parentheses for grouping

Returns:

Boolean tensor of shape (n_atoms,) where True indicates selected atoms.

Return type:

torch.Tensor

Raises:

Examples

model = Model().load_pdb('structure.pdb')
# Get mask for chain A
mask = model.get_selection_mask("chain A")
# Use mask to update coordinates (example shift of 1 Angstrom along x)
translation = torch.tensor([1.0, 0.0, 0.0])
new_coords = model.xyz()[mask] + translation
model.xyz.set(new_coords, mask)
# Get mask for backbone atoms
backbone_mask = model.get_selection_mask("name CA or name C or name N or name O")
# Complex selection with parentheses
mask = model.get_selection_mask("chain A and (resname ALA or resname GLY)")
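
The boolean semantics of and, or, not and parentheses can be illustrated with per-atom masks built by hand. This sketch does not parse selection strings; the atom records and helper names are hypothetical and only show how the final mask of a compound selection is composed.

```python
atoms = [
    {"chain": "A", "resname": "ALA", "name": "CA"},
    {"chain": "A", "resname": "LYS", "name": "CA"},
    {"chain": "B", "resname": "GLY", "name": "N"},
    {"chain": "A", "resname": "GLY", "name": "O"},
]

def field_mask(field, value):
    """One boolean per atom: True where the field equals the value."""
    return [atom[field] == value for atom in atoms]

def mask_and(m1, m2):
    return [x and y for x, y in zip(m1, m2)]

def mask_or(m1, m2):
    return [x or y for x, y in zip(m1, m2)]

def mask_not(m):
    return [not x for x in m]

# Equivalent of "chain A and (resname ALA or resname GLY)"
mask = mask_and(
    field_mask("chain", "A"),
    mask_or(field_mask("resname", "ALA"), field_mask("resname", "GLY")),
)
# Equivalent of "not chain A"
not_chain_a = mask_not(field_mask("chain", "A"))
```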
select(selection)[source]

Return a new Model containing only atoms matching the Phenix-style selection.

Creates an independent copy of the model containing only the selected atoms. All tensor data (coordinates, ADPs, occupancies, etc.) and metadata are properly subsetted.

Parameters:

selection (str) – Phenix-style selection string. Supports:

  - chain <id>: Select by chain (e.g., “chain A”)
  - resseq <num>: Select by residue number (e.g., “resseq 10”)
  - resseq <start>:<end>: Select residue range (e.g., “resseq 10:20”)
  - resname <name>: Select by residue name (e.g., “resname ALA”)
  - name <atom>: Select by atom name (e.g., “name CA”)
  - element <elem>: Select by element (e.g., “element C”)
  - altloc <id>: Select by alternate location (e.g., “altloc A”)
  - all: Select all atoms
  - not <selection>: Negate selection
  - <sel1> and <sel2>: Intersection
  - <sel1> or <sel2>: Union
  - Parentheses for grouping

Returns:

New instance of the same class containing only selected atoms. If called on a subclass, returns an instance of that subclass.

Return type:

Model

Raises:
  • RuntimeError – If the model has not been initialized.

  • ValueError – If selection syntax is invalid or no atoms are selected.

Examples

model = Model().load_pdb('structure.pdb')
# Select chain A
chain_a = model.select("chain A")
# Select backbone atoms
backbone = model.select("name CA or name C or name N or name O")
# Select residues 10-50 of chain B
region = model.select("chain B and resseq 10:50")
# Select all except water
no_water = model.select("not resname HOH")
# Complex selection with parentheses
complex_sel = model.select("chain A and (resname ALA or resname GLY)")

Notes

This method preserves the class type, so subclasses will return instances of themselves, not the base Model class.

xyz_fractional()[source]

Return atomic coordinates in fractional space.

Converts Cartesian coordinates to fractional coordinates using the inverse fractional matrix.

Returns:

Tensor of shape (n_atoms, 3) with fractional coordinates.

Return type:

torch.Tensor

rotate(rotation_matrix, center=None)[source]

Apply rotation to atomic coordinates (in-place).

Rotates all atoms around a specified center point. The rotation is applied using the formula: xyz_new = R @ (xyz - center) + center

Parameters:
  • rotation_matrix (torch.Tensor) – 3x3 rotation matrix. Should be orthogonal (R^T @ R = I).

  • center (torch.Tensor, optional) – Center of rotation with shape (3,). If None, uses the centroid of all atomic coordinates.

Returns:

Self, for method chaining.

Return type:

Model

Examples

# Rotate 90 degrees around Z-axis
import math
import torch
angle = math.pi / 2
R = torch.tensor([
    [math.cos(angle), -math.sin(angle), 0],
    [math.sin(angle), math.cos(angle), 0],
    [0, 0, 1]
])
model.rotate(R)

# Rotate around a specific point
center = torch.tensor([10.0, 20.0, 30.0])
model.rotate(R, center=center)
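
The rotation formula can be verified on a tiny example without tensors. The sketch below applies xyz_new = R @ (xyz - center) + center atom by atom and falls back to the centroid when no center is given; the function name is hypothetical, and it returns new coordinates rather than updating anything in place.

```python
import math

def rotate_about(coords, R, center=None):
    """Apply x' = R (x - center) + center to every atom; default center is the centroid."""
    if center is None:
        n = len(coords)
        center = [sum(atom[i] for atom in coords) / n for i in range(3)]
    rotated = []
    for atom in coords:
        shifted = [atom[i] - center[i] for i in range(3)]
        rotated.append(
            [sum(R[i][j] * shifted[j] for j in range(3)) + center[i] for i in range(3)]
        )
    return rotated

angle = math.pi / 2  # 90 degrees about Z
R = [
    [math.cos(angle), -math.sin(angle), 0.0],
    [math.sin(angle), math.cos(angle), 0.0],
    [0.0, 0.0, 1.0],
]
about_origin = rotate_about([[1.0, 0.0, 0.0]], R, center=[0.0, 0.0, 0.0])
about_centroid = rotate_about([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]], R)
```

Rotating about the centroid leaves the centroid itself fixed: the two atoms at x = 0 and x = 2 end up at (1, -1, 0) and (1, 1, 0), still centred on (1, 0, 0).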
translate(translation, fractional=False)[source]

Apply translation to atomic coordinates (in-place).

Translates all atoms by a specified vector. The translation can be given in either Cartesian or fractional coordinates.

Parameters:
  • translation (torch.Tensor) – Translation vector with shape (3,).

  • fractional (bool, optional) – If True, the translation is interpreted as fractional coordinates and converted to Cartesian before applying. Default is False (translation is in Cartesian Angstroms).

Returns:

Self, for method chaining.

Return type:

Model

Examples

# Translate by 5 Angstroms along X
model.translate(torch.tensor([5.0, 0.0, 0.0]))

# Translate by half a unit cell along each axis
model.translate(torch.tensor([0.5, 0.5, 0.5]), fractional=True)
get_centroid()[source]

Compute the centroid (center of mass) of all atoms.

Returns:

Centroid coordinates with shape (3,).

Return type:

torch.Tensor

use_internal_coordinates(n_aa_per_segment=5, bond_cutoff=2.0, cif_dict=None, requires_grad=True)[source]

Switch xyz to segmented internal coordinate parametrization.

Replaces the current xyz MixedTensor with a SegmentedInternalCoordinateTensor that parametrizes atomic positions using bond lengths, angles, torsion angles, and per-segment rigid body parameters. The molecule is broken into independent segments to avoid the “lever arm problem” where small torsion changes near the root cause large displacements at distant atoms.

Parameters:
  • n_aa_per_segment (int, optional) – Number of amino acids per segment. Default is 5.

      - Smaller values (1-2): More segments, shallower trees, less lever arm
      - Larger values (5-10): Fewer segments, deeper trees, more lever arm

  • bond_cutoff (float, optional) – Distance cutoff for bond detection in Angstroms. Default is 2.0. Only used when cif_dict is not provided.

  • cif_dict (dict, optional) – CIF dictionary containing bond definitions per residue type. If provided, bonds are determined from chemical definitions rather than distances, which is more robust for structures with poor geometry. Expected format: cif_dict[resname][‘bonds’] DataFrame with ‘atom1’, ‘atom2’.

  • requires_grad (bool, optional) – Whether internal coordinate parameters should have gradients. Default is True.

Returns:

Self, for method chaining.

Return type:

Model

Examples

model = Model()
model.load_pdb('structure.pdb')
model.use_internal_coordinates(n_aa_per_segment=3)

# Now model.xyz() returns coordinates reconstructed from
# segmented internal coordinates

# Shake the structure using internal coordinates
new_xyz = model.xyz.shake(magnitude=0.1)

# Each segment has independent internal coordinates and
# rigid body parameters (position + orientation)

Notes

After calling this method, model.xyz will be a SegmentedInternalCoordinateTensor instead of a MixedTensor. This provides:

  - Shallow spanning trees within segments (depth ~10-30 vs ~1000)
  - Independent segments that don’t propagate changes to distant atoms
  - Rigid body parameters (position + orientation) per segment
  - forward() / __call__(): Reconstruct Cartesian coordinates
  - shake(magnitude): Add noise to internal parameters
  - Gradient flow through all internal coordinate parameters
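
The lever arm problem mentioned above is simple geometry: when a torsion changes by a small angle, an atom moves along a chord whose length is proportional to its distance from the rotation axis. The sketch below quantifies this with illustrative distances; the function name and the numbers are not taken from torchref.

```python
import math

def displacement(radius, delta_torsion):
    """Chord length moved by an atom at `radius` from the rotation axis."""
    return 2.0 * radius * math.sin(abs(delta_torsion) / 2.0)

delta = math.radians(1.0)          # a one-degree torsion change
near = displacement(5.0, delta)    # atom 5 Angstroms from the axis
far = displacement(50.0, delta)    # atom 50 Angstroms from the axis
```

A one-degree change moves the distant atom ten times further than the near one (close to 0.9 Å versus 0.09 Å), which is why limiting the tree depth per segment keeps small parameter updates from causing large downstream shifts.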

class torchref.model.ModelFT(*args, max_res=1.0, radius_angstrom=4.0, gridsize=None, wavelength=1.0, anomalous_threshold=0.5, **kwargs)[source]

Bases: CachedForwardMixin, Model

Model subclass for Fourier Transform-based electron density and structure factor calculations.

ModelFT extends the base Model class with capabilities for computing electron density maps in real space and structure factors via FFT. Uses ITC92 parametrization for electron density calculations.

Parameters:
  • max_res (float, optional) – Maximum resolution for grid spacing in Angstroms. Default is 1.0.

  • radius_angstrom (float, optional) – Radius in Angstroms for density calculation around each atom. Default is 4.0.

  • gridsize (tuple of int, optional) – Explicit grid size (nx, ny, nz). If None, computed from cell and max_res.

  • wavelength (float or None, optional) – X-ray wavelength in Angstroms for anomalous scattering correction. Default is 1.0 (standard synchrotron, ~12.4 keV). Set to None to disable anomalous corrections entirely.

  • anomalous_threshold (float, optional) – Significance threshold for anomalous scattering in electrons. Atoms with |f'| > threshold or |f''| > threshold will have anomalous corrections applied. Default is 0.5.

  • *args – Additional positional arguments passed to parent Model class.

  • **kwargs – Additional keyword arguments passed to parent Model class.
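The "~12.4 keV" quoted for the default wavelength follows from E = hc / λ. A minimal, library-independent sketch of that conversion is shown here; the function name is hypothetical and not part of TorchRef.

```python
HC_KEV_ANGSTROM = 12.398419843  # Planck constant times speed of light, in keV * Angstrom

def wavelength_to_energy_kev(wavelength_angstrom):
    """Photon energy in keV for a wavelength in Angstroms (E = hc / lambda)."""
    if wavelength_angstrom <= 0:
        raise ValueError("wavelength must be positive")
    return HC_KEV_ANGSTROM / wavelength_angstrom

energy_default = wavelength_to_energy_kev(1.0)   # the default ModelFT wavelength
energy_cu = wavelength_to_energy_kev(1.5418)     # Cu K-alpha, a common home-source wavelength
print(f"{energy_default:.2f} keV, {energy_cu:.2f} keV")
```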

max_res

Maximum resolution for grid spacing.

Type:

float

radius_angstrom

Radius for density calculation.

Type:

float

wavelength

X-ray wavelength for anomalous scattering corrections.

Type:

float or None

anomalous_threshold

Threshold for significant anomalous scattering (electrons).

Type:

float

gridsize

Grid dimensions (nx, ny, nz).

Type:

torch.Tensor

real_space_grid

Real-space coordinate grid with shape (nx, ny, nz, 3).

Type:

torch.Tensor

map

Computed electron density map.

Type:

torch.Tensor or None

parametrization

ITC92 parametrization dictionary {element: (A, B, C)}.

Type:

dict

map_symmetry

Symmetry operator for map calculations.

Type:

MapSymmetry

Examples

Empty initialization for state_dict loading:

model = ModelFT()
model.load_state_dict(torch.load('model.pt'))

File-based initialization:

model = ModelFT(max_res=1.5)
model.load_pdb('structure.pdb')

__init__(*args, max_res=1.0, radius_angstrom=4.0, gridsize=None, wavelength=1.0, anomalous_threshold=0.5, **kwargs)[source]

Initialize an empty ModelFT shell.

Creates a model shell ready for file loading via load_pdb()/load_cif() or state restoration via load_state_dict().

Parameters:
  • max_res (float, optional) – Maximum resolution for grid spacing in Angstroms. Default is 1.0.

  • radius_angstrom (float, optional) – Radius in Angstroms for density calculation. Default is 4.0.

  • gridsize (tuple of int, optional) – Explicit grid size tuple (nx, ny, nz). If None, computed automatically.

  • wavelength (float or None, optional) – X-ray wavelength in Angstroms for anomalous scattering correction. Default is 1.0 (standard synchrotron, ~12.4 keV). Set to None to disable anomalous corrections entirely.

  • anomalous_threshold (float, optional) – Significance threshold for anomalous scattering in electrons. Atoms with |f'| > threshold or |f''| > threshold will have anomalous corrections applied. Default is 0.5.

  • *args – Passed to parent Model class.

  • **kwargs – Passed to parent Model class.

property cell

Unit cell object with parameters [a, b, c, alpha, beta, gamma].

property spacegroup

Space group object.

load_pdb(filename)[source]

Load a PDB file and initialize the model with FT-specific setup.

Parameters:

filename (str) – Path to the PDB file.

Returns:

Self, for method chaining.

Return type:

ModelFT

select(selection)[source]

Return a new Model containing only atoms matching the Phenix-style selection.

Creates an independent copy of the model containing only the selected atoms. All tensor data (coordinates, ADPs, occupancies, etc.) and metadata are properly subsetted.

Parameters:

selection (str) – Phenix-style selection string. Supports:

  - chain <id>: Select by chain (e.g., "chain A")
  - resseq <num>: Select by residue number (e.g., "resseq 10")
  - resseq <start>:<end>: Select residue range (e.g., "resseq 10:20")
  - resname <name>: Select by residue name (e.g., "resname ALA")
  - name <atom>: Select by atom name (e.g., "name CA")
  - element <elem>: Select by element (e.g., "element C")
  - altloc <id>: Select by alternate location (e.g., "altloc A")
  - all: Select all atoms
  - not <selection>: Negate selection
  - <sel1> and <sel2>: Intersection
  - <sel1> or <sel2>: Union
  - Parentheses for grouping

Returns:

New instance of the same class containing only selected atoms. If called on a subclass, returns an instance of that subclass.

Return type:

Model

Raises:
  • RuntimeError – If the model has not been initialized.

  • ValueError – If selection syntax is invalid or no atoms are selected.

Examples

model = Model().load_pdb('structure.pdb')
# Select chain A
chain_a = model.select("chain A")
# Select backbone atoms
backbone = model.select("name CA or name C or name N or name O")
# Select residues 10-50 of chain B
region = model.select("chain B and resseq 10:50")
# Select all except water
no_water = model.select("not resname HOH")
# Complex selection with parentheses
complex_sel = model.select("chain A and (resname ALA or resname GLY)")

Notes

This method preserves the class type, so subclasses will return instances of themselves, not the base Model class.
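Conceptually, each keyword in a selection string yields a per-atom boolean mask, and `and`, `or`, and `not` combine those masks element-wise. The sketch below illustrates that idea with plain Python lists; it is not TorchRef's parser, and the atom records and helper names are made up for the example.

```python
# Hypothetical atom records standing in for a loaded model.
atoms = [
    {"chain": "A", "resname": "ALA", "name": "CA"},
    {"chain": "A", "resname": "HOH", "name": "O"},
    {"chain": "B", "resname": "GLY", "name": "CA"},
    {"chain": "A", "resname": "GLY", "name": "N"},
]

def mask_for(field, value):
    """One boolean per atom: True where the field equals the value."""
    return [atom[field] == value for atom in atoms]

def mask_and(a, b):
    return [x and y for x, y in zip(a, b)]

def mask_or(a, b):
    return [x or y for x, y in zip(a, b)]

def mask_not(a):
    return [not x for x in a]

# "chain A and (resname ALA or resname GLY)"
selected = mask_and(
    mask_for("chain", "A"),
    mask_or(mask_for("resname", "ALA"), mask_for("resname", "GLY")),
)

# "not resname HOH"
no_water = mask_not(mask_for("resname", "HOH"))
print(selected, no_water)
```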

load_cif(filename)[source]

Load a CIF file and initialize the model with FT-specific setup.

Parameters:

filename (str) – Path to the CIF/mmCIF file.

Returns:

Self, for method chaining.

Return type:

ModelFT

setup_gridsize(max_res=None)[source]

Compute optimal grid dimensions.

Delegates to FFT.compute_grid_size().

Parameters:

max_res (float, optional) – Maximum resolution in Angstroms. If None, uses self.max_res.

Returns:

Grid dimensions (nx, ny, nz) as int32 tensor.

Return type:

torch.Tensor
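How the grid dimensions are actually chosen is not spelled out here. As a rough illustration only, the sketch below uses two common conventions that are assumptions rather than documented TorchRef behaviour: a grid spacing of at most max_res / 3, and dimensions rounded up to integers whose prime factors are 2, 3 and 5 (which FFT libraries handle efficiently). Space-group constraints on the grid are ignored, and both function names are hypothetical.

```python
import math

def next_fft_friendly(n):
    """Smallest integer >= n whose only prime factors are 2, 3 and 5."""
    candidate = max(int(n), 1)
    while True:
        remainder = candidate
        for prime in (2, 3, 5):
            while remainder % prime == 0:
                remainder //= prime
        if remainder == 1:
            return candidate
        candidate += 1

def suggest_grid_size(cell_lengths, max_res, oversampling=3.0):
    """One grid dimension per cell edge, with spacing <= max_res / oversampling."""
    spacing = max_res / oversampling
    return tuple(
        next_fft_friendly(math.ceil(length / spacing)) for length in cell_lengths
    )

# Cell edges in Angstroms (illustrative values).
grid = suggest_grid_size((50.0, 61.3, 80.0), max_res=1.5)
print(grid)
```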

property A: Tensor

ITC92 A parameters (amplitudes) for all atoms.

Returns:

A parameters with shape (n_atoms, 5).

Return type:

torch.Tensor

property B: Tensor

ITC92 B parameters (widths) for all atoms.

Returns:

B parameters with shape (n_atoms, 5).

Return type:

torch.Tensor
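The A and B arrays describe a sum-of-Gaussians scattering factor, f0(s) = Σ A_i · exp(-B_i · s²) with s = sin(θ)/λ. The sketch below evaluates that sum for one atom. Two assumptions are made: that the fifth column holds the constant term as a Gaussian of zero width, which would explain the (n_atoms, 5) shape, and that the carbon coefficients shown are the commonly tabulated ones. Check both against the tables TorchRef actually ships.

```python
import math

# Commonly tabulated four-Gaussian-plus-constant coefficients for neutral carbon.
# Treat them as illustrative and verify them against the table you actually use.
CARBON_A = [2.3100, 1.0200, 1.5886, 0.8650, 0.2156]  # last entry is the constant term
CARBON_B = [20.8439, 10.2075, 0.5687, 51.6512, 0.0]  # width 0 turns a Gaussian into a constant

def form_factor(s, a_coeffs, b_coeffs):
    """f0(s) = sum_i a_i * exp(-b_i * s^2), with s = sin(theta) / lambda in 1/Angstrom."""
    return sum(a * math.exp(-b * s * s) for a, b in zip(a_coeffs, b_coeffs))

f_at_zero = form_factor(0.0, CARBON_A, CARBON_B)  # close to the 6 electrons of carbon
f_at_high = form_factor(0.5, CARBON_A, CARBON_B)  # falls off with scattering angle
print(f"{f_at_zero:.3f} {f_at_high:.3f}")
```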

property gridsize: Tensor | None

Grid dimensions (nx, ny, nz).

property real_space_grid: Tensor | None

Real-space coordinate grid with shape (nx, ny, nz, 3).

property voxel_size: Tensor | None

Voxel dimensions.

property map_symmetry: MapSymmetry | None

Symmetry operator for map calculations.

get_iso()[source]

Get isotropic atoms with their ITC92 parameters.

Returns:

  • xyz (torch.Tensor) – Atomic coordinates with shape (n_atoms, 3).

  • adp (torch.Tensor) – Atomic displacement parameters (isotropic) with shape (n_atoms,).

  • occupancy (torch.Tensor) – Occupancies with shape (n_atoms,).

  • A (torch.Tensor) – ITC92 A parameters (amplitudes) with shape (n_atoms, 5).

  • B (torch.Tensor) – ITC92 B parameters (widths) with shape (n_atoms, 5).

get_aniso()[source]

Get anisotropic atoms with their ITC92 parameters.

Returns:

  • xyz (torch.Tensor) – Atomic coordinates with shape (n_atoms, 3).

  • u (torch.Tensor) – Anisotropic U parameters with shape (n_atoms, 6).

  • occupancy (torch.Tensor) – Occupancies with shape (n_atoms,).

  • A (torch.Tensor) – ITC92 A parameters (amplitudes) with shape (n_atoms, 5).

  • B (torch.Tensor) – ITC92 B parameters (widths) with shape (n_atoms, 5).

setup_grid(max_res=None, gridsize=None)[source]

Setup real-space grid for electron density calculation.

Delegates to FFT.setup_grid() using the stored cell and spacegroup.

Parameters:
  • max_res (float, optional) – Maximum resolution for grid spacing in Angstroms. If None, uses self.max_res.

  • gridsize (tuple of int, optional) – Explicit grid size (nx, ny, nz). If None, computed automatically using Cell.compute_grid_size() and SpaceGroup.suggest_grid_size().

get_radius(min_radius_Angstrom=4.0)[source]

Get the radius in voxels used for density calculation around each atom.

Parameters:

min_radius_Angstrom (float, optional) – Minimum radius in Angstroms. Default is 4.0.

Returns:

Radius in voxels.

Return type:

int

build_complete_map(radius=None, apply_symmetry=True)[source]

Build electron density map from all atoms.

Uses get_iso() and get_aniso() to get atom data and constructs the complete electron density map.

Parameters:
  • radius (int, optional) – Radius in voxels around each atom to compute density. If None, uses self.radius.

  • apply_symmetry (bool, optional) – If True and space group is not P1, apply symmetry operations to the map. Default is True.

Returns:

Electron density map with symmetry applied if requested.

Return type:

torch.Tensor
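To see why a finite radius around each atom is sufficient, consider the textbook real-space form of one Gaussian scattering term a·exp(-b·s²) with an isotropic B-factor: ρ(r) = a · (4π / (b + B))^(3/2) · exp(-4π² r² / (b + B)). The sketch below integrates that density numerically; whether TorchRef evaluates exactly this expression is an assumption, and the helper names are hypothetical.

```python
import math

def gaussian_density(r, a, b, b_iso):
    """Density at distance r (Angstrom) from one Gaussian form-factor term
    a * exp(-b * s^2) smeared by an isotropic B-factor b_iso."""
    width = b + b_iso
    return a * (4.0 * math.pi / width) ** 1.5 * math.exp(
        -4.0 * math.pi ** 2 * r * r / width
    )

def integrated_electrons(a, b, b_iso, r_max, steps=20000):
    """Integrate the density over a sphere of radius r_max (midpoint rule)."""
    dr = r_max / steps
    total = 0.0
    for i in range(steps):
        r = (i + 0.5) * dr
        total += gaussian_density(r, a, b, b_iso) * 4.0 * math.pi * r * r * dr
    return total

# One carbon-like term with B = 20: nearly all of it lies within 4 Angstrom.
inside_4 = integrated_electrons(a=2.31, b=20.8439, b_iso=20.0, r_max=4.0)
inside_1 = integrated_electrons(a=2.31, b=20.8439, b_iso=20.0, r_max=1.0)
print(f"within 4 Angstrom: {inside_4:.4f}, within 1 Angstrom: {inside_1:.4f}")
```

Broader terms (large b or large B) spread further, so a radius that is adequate for well-ordered atoms may truncate density for very mobile ones.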

build_initial_map(apply_symmetry=True)[source]

Build electron density map from atomic parameters.

Delegates to FFT.build_density_map() using the model’s stored parameters.

Parameters:

apply_symmetry (bool, optional) – If True, apply crystallographic symmetry to the map. Default is True.

Returns:

Electron density map with shape (nx, ny, nz).

Return type:

torch.Tensor

save_map(filename)[source]

Save the electron density map to a CCP4 format file.

Parameters:

filename (str) – Output filename for the map.

Raises:

ValueError – If no map has been computed yet.

get_map_statistics()[source]

Get statistics about the current density map.

rebuild_map(radius=None)[source]

Rebuild the density map from scratch.

Convenience method that clears and rebuilds everything.

Parameters:

radius (int, optional) – Radius in voxels around each atom. If None, uses self.radius. If specified, overrides self.radius.

Returns:

Rebuilt electron density map.

Return type:

torch.Tensor

update_pdb()[source]

Update PDB with current atomic parameters.

reset_cache()[source]

Reset SF cache, anomalous cache, and all wrapper forward caches.

invalidate_cache()[source]

Alias for reset_cache().

get_structure_factor(hkl, recalc=False, apply_anomalous=True)[source]

Get structure factors for given hkl reflections.

Uses CachedForwardMixin to cache the result and auto-invalidate when parameters change or a backward pass propagates through.

Parameters:
  • hkl (torch.Tensor) – Miller indices with shape (n_reflections, 3).

  • recalc (bool, optional) – If True, forces recalculation bypassing the cache. Default is False.

  • apply_anomalous (bool, optional) – If True and wavelength is set, apply anomalous scattering corrections (f' and f'') for heavy atoms. Default is True.

Returns:

Complex structure factors with shape (n_reflections,).

Return type:

torch.Tensor

Notes

The complete scattering factor is:

f(s, λ) = f₀(s) + f'(λ) + i·f''(λ)

where f₀ is the normal (Thomson) scattering factor computed via FFT, and f' and f'' are the wavelength-dependent anomalous corrections.

Anomalous corrections are only computed for atoms where |f'| > anomalous_threshold or |f''| > anomalous_threshold.
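The threshold rule and the complex scattering factor can be written out for a single atom as follows. This is a sketch of the stated rule only, not of how ModelFT applies it across reflections; the function name and the numeric inputs are illustrative, not tabulated values.

```python
def total_scattering_factor(f0, f_prime, f_double_prime, threshold=0.5):
    """Return f0 + f' + i*f'' when either correction exceeds the threshold,
    otherwise return f0 unchanged (as a complex number)."""
    significant = abs(f_prime) > threshold or abs(f_double_prime) > threshold
    if not significant:
        return complex(f0, 0.0)
    return complex(f0 + f_prime, f_double_prime)

# Illustrative numbers only.
light_atom = total_scattering_factor(6.0, 0.01, 0.005)   # below threshold: stays real
heavy_atom = total_scattering_factor(34.0, -1.2, 3.1)    # above threshold: corrected
print(light_atom, heavy_atom)
```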

property fft

Access the SfFFT submodule.

forward(hkl, apply_anomalous=True)[source]

Compute structure factors for given hkl.

This is called by the mixin’s __call__ which handles caching, backward-hook registration, and auto-invalidation.

Parameters:
  • hkl (torch.Tensor) – Miller indices with shape (n_reflections, 3).

  • apply_anomalous (bool, optional) – If True and wavelength is set, apply anomalous scattering corrections (f' and f'') for heavy atoms. Default is True.

Returns:

Calculated complex structure factors with shape (n_reflections,).

Return type:

torch.Tensor

copy(detach=True)[source]

Create a deep copy of the ModelFT.

Creates a complete independent copy including all Model base class data, FFT submodule state (gridsize, real_space_grid, voxel_size, map_symmetry), ITC92 parametrization, and scalar attributes. Cache is reset to empty.

Parameters:

detach (bool, optional) – If True, the copy’s parameters will be detached from the computation graph (default: True).

Returns:

A new ModelFT instance with copied data.

Return type:

ModelFT

Examples

model = ModelFT().load_pdb('structure.pdb')
model_copy = model.copy()
# model_copy is independent, changes won't affect model

state_dict(destination=None, prefix='', keep_vars=False)[source]

Return a dictionary containing the complete state of the ModelFT.

Extends the parent Model.state_dict() with FT-specific parameters, including max_res and radius_angstrom. Grid state is handled by the FFT submodule.

Parameters:
  • destination (dict, optional) – Optional dict to populate.

  • prefix (str, optional) – Prefix for parameter names. Default is '' (empty string).

  • keep_vars (bool, optional) – Whether to keep variables in computational graph. Default is False.

Returns:

Complete state dictionary.

Return type:

dict

classmethod create_from_state_dict(state_dict, device=device(type='cpu'), verbose=1, dtype_float=torch.float32)[source]

Create a fully initialized ModelFT from a state dictionary.

This is the recommended way to restore a ModelFT from a saved state. Creates an instance with properly initialized submodules, then loads the state.

Parameters:
  • state_dict (dict) – State dictionary from torch.save(model.state_dict(), …).

  • device (torch.device, optional) – Device to place tensors on. Defaults to the configured device.current.

  • verbose (int, optional) – Verbosity level. Default is 1.

  • dtype_float (torch.dtype, optional) – Float dtype for tensors. Default is dtypes.float.

Returns:

Fully initialized instance with restored state.

Return type:

ModelFT

class torchref.model.MixedTensor(initial_values=None, refinable_mask=None, requires_grad=True, dtype=None, device=None, name=None)[source]

Bases: DeviceMixin, CachedForwardMixin, Module

A wrapper class for tensors with mixed fixed and refinable elements.

Stores a mask indicating which elements can be refined and maintains both fixed and refinable components separately. The full tensor is reconstructed on-the-fly when accessed.

Parameters:
  • initial_values (torch.Tensor, optional) – Initial tensor values for all elements. Optional for empty init.

  • refinable_mask (torch.Tensor, optional) – Boolean mask indicating which elements can be refined. If None, all elements are refinable.

  • requires_grad (bool, optional) – Whether refinable parameters should have gradients. Default is True.

  • dtype (torch.dtype, optional) – Data type for the tensor. Default is same as initial_values.

  • device (torch.device, optional) – Device for the tensor. Default is same as initial_values.

  • name (str, optional) – Optional name for this parameter (useful for debugging/logging).

refinable_mask

Boolean mask indicating refinable elements.

Type:

torch.Tensor

fixed_mask

Boolean mask indicating fixed elements (inverse of refinable_mask).

Type:

torch.Tensor

fixed_values

Buffer containing fixed values.

Type:

torch.Tensor

refinable_params

Parameter containing refinable values.

Type:

nn.Parameter

Examples

Empty initialization for state_dict loading:

mixed = MixedTensor()
mixed.load_state_dict(torch.load('mixed.pt'))

Full initialization with values:

mask = torch.zeros(100, dtype=torch.bool)
mask[20:30] = True
initial_values = torch.randn(100)
mixed = MixedTensor(initial_values, refinable_mask=mask, requires_grad=True)
optimizer = torch.optim.Adam([mixed.refinable_params], lr=0.01)

__init__(initial_values=None, refinable_mask=None, requires_grad=True, dtype=None, device=None, name=None)[source]

Initialize a MixedTensor.

If initial_values is provided, fully initializes the tensor. If not provided (empty init), creates a shell ready for load_state_dict().

Parameters:
  • initial_values (torch.Tensor, optional) – Initial tensor values for all elements. Optional for empty init.

  • refinable_mask (torch.Tensor, optional) – Boolean mask indicating which elements can be refined. If None, all elements are refinable.

  • requires_grad (bool, optional) – Whether refinable parameters should have gradients. Default is True.

  • dtype (torch.dtype, optional) – Data type for the tensor. Default is same as initial_values.

  • device (torch.device, optional) – Device for the tensor. Default is same as initial_values.

  • name (str, optional) – Optional name for this parameter (useful for debugging/logging).

forward()[source]

Reconstruct and return the full tensor.

Three fast paths, in priority order:

  1. All atoms refinable — return refinable_params directly (no clone, no scatter). Common in standard refinement where no atoms are frozen. Saves a full-tensor clone + an index_put_ per call and replaces the index_put backward (sort + atomic scatter) with a no-op identity.

  2. No refinable atoms — return fixed_values.clone() (the clone preserves the caller-must-not-mutate contract even though the result is detached from autograd).

  3. Mixed — go through _AssembleMixedTensor, whose backward is a single index_select (gather) instead of PyTorch’s default index_put_ backward (radix-sort + scatter).
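The three paths above can be mirrored with plain Python lists standing in for tensors. This sketch shows only the control flow and the copy-versus-no-copy behaviour; the custom autograd function and its backward pass are not modelled, and `assemble` is a hypothetical name.

```python
def assemble(fixed_values, refinable_mask, refinable_params):
    """Rebuild the full list from fixed values plus refinable values.

    refinable_params holds one entry per True in refinable_mask, in order.
    """
    if all(refinable_mask):        # path 1: everything is refinable
        return refinable_params    # returned as-is, no copy
    if not any(refinable_mask):    # path 2: nothing is refinable
        return list(fixed_values)  # copy, so callers cannot mutate the buffer
    full = list(fixed_values)      # path 3: scatter refinable values into a copy
    params = iter(refinable_params)
    for index, is_refinable in enumerate(refinable_mask):
        if is_refinable:
            full[index] = next(params)
    return full

fixed = [10.0, 20.0, 30.0, 40.0]
mixed = assemble(fixed, [False, True, False, True], [21.5, 41.5])
print(mixed)
```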

__getitem__(key)[source]

Get values at specified indices/mask from the full tensor.

Parameters:

key (int, slice, torch.Tensor, or tuple) – Index specification. Can be:

  - int: Single element
  - slice: Range of elements (e.g., 5:10, :, ::2)
  - torch.Tensor: Boolean mask or integer indices
  - tuple: Multi-dimensional indexing

Returns:

Selected values from the full tensor.

Return type:

torch.Tensor

Examples

model.b[5]           # Get B-factor for atom 5
model.b[5:10]        # Get B-factors for atoms 5-9
model.b[mask]        # Get B-factors where mask is True
model.xyz[:, 0]      # Get all x-coordinates

Notes

Subclasses may override _get_values() to customize value retrieval.

__setitem__(key, value)[source]

Set values at specified indices/mask.

This method updates both fixed_values and refinable_params at the specified positions. Supports various indexing styles including slices, boolean masks, and integer indices.

Parameters:
  • key (int, slice, torch.Tensor, or tuple) – Index specification. Can be:

      - int: Single element
      - slice: Range of elements (e.g., 5:10, :, ::2)
      - torch.Tensor: Boolean mask or integer indices
      - tuple: Multi-dimensional indexing

  • value (torch.Tensor, float, or int) – Values to assign. Can be:

      - Scalar: Broadcast to all selected positions
      - Tensor: Must match the shape of the selected region

Examples

model.b[:] = 30.0           # Set all B-factors to 30
model.b[5:10] = 25.0        # Set B-factors 5-9 to 25
model.b[mask] = new_values  # Set B-factors where mask is True
model.xyz[mask] = new_coords  # Set coordinates for masked atoms
model.xyz[:, 0] += 1.0      # Shift all x-coordinates (read-modify-write)

Notes

This method modifies the tensor in-place. The refinable_params parameter is replaced with a new Parameter containing the updated values, which may affect optimizer state.

Subclasses may override _set_values() to customize value handling (e.g., PositiveMixedTensor converts to log-space).
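The remark about optimizer state comes down to object identity: anything that captured the old parameter object keeps pointing at it after a replacement. The sketch below shows this with a plain Python stand-in (no PyTorch involved; `Holder` and `set_all` are hypothetical). Assuming the behaviour described in the note, a reasonable precaution is to construct the optimizer after the last assignment, or to rebuild it afterwards.

```python
class Holder:
    """Stand-in for a module whose parameter object is replaced on assignment."""

    def __init__(self, values):
        self.refinable_params = list(values)

    def set_all(self, values):
        # A new object is bound to the attribute; the old one is left untouched.
        self.refinable_params = list(values)

holder = Holder([1.0, 2.0])
tracked_by_optimizer = [holder.refinable_params]  # what an optimizer would hold on to

holder.set_all([5.0, 6.0])
still_tracked = tracked_by_optimizer[0] is holder.refinable_params
print(still_tracked, tracked_by_optimizer[0], holder.refinable_params)
```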

set(values, mask)[source]

Set values at positions specified by a boolean mask.

Updates both fixed_values and refinable_params at the positions specified by the mask. This is useful for applying coordinate shifts, B-factor corrections, or any other updates to specific atoms.

Parameters:
  • values (torch.Tensor) – New values to assign. Shape must match:

      - For 1D tensors: (n_selected,) where n_selected = mask.sum()
      - For 2D tensors (e.g., xyz): (n_selected, d) where d is the second dimension size (e.g., 3 for coordinates)

  • mask (torch.Tensor) – Boolean mask of shape (n_atoms,) indicating which elements to update. True positions will receive the new values.

Raises:

ValueError – If mask shape doesn’t match tensor’s first dimension, or if values shape doesn’t match the number of selected elements.

Examples

# Update coordinates for selected atoms
mask = model.get_selection_mask("chain A")
new_coords = original_coords[mask] + shift
model.xyz.set(new_coords, mask)

# Update B-factors for specific residues
mask = model.get_selection_mask("resseq 10:20")
new_b = torch.ones(mask.sum()) * 30.0
model.b.set(new_b, mask)

Notes

This method modifies the tensor in-place. The refinable_params parameter is replaced with a new Parameter containing the updated values, which may affect optimizer state.

property shape

Return the shape of the full tensor.

property dtype

Return the dtype of the tensor.

property device

Return the device of the tensor.

get_refinable_count()[source]

Return the number of refinable parameters.

get_fixed_count()[source]

Return the number of fixed parameters.

update_fixed_values(new_values)[source]

Update the fixed values (does not affect refinable parameters).

Parameters:

new_values (torch.Tensor) – New tensor values. Only fixed positions will be updated.

Raises:

ValueError – If new_values shape doesn’t match tensor shape.

update_refinable_mask(new_mask, reset_refinable=False)[source]

Update which elements are refinable.

This is an advanced operation that modifies the refinable/fixed split.

Parameters:
  • new_mask (torch.Tensor) – New boolean mask indicating refinable elements.

  • reset_refinable (bool, optional) – If True, reset refinable parameters to current fixed values. If False, keep existing refinable parameter values where possible. Default is False.

detach()[source]

Return a detached copy of the full tensor.

clone()[source]

Create a deep copy of this MixedTensor.

copy()[source]

Create a deep copy of this MixedTensor.

Creates a complete independent copy with all buffers and parameters. Alias for clone().

Returns:

New MixedTensor instance with copied data.

Return type:

MixedTensor

clip(min_value=None, max_value=None)[source]

Clip the full tensor values between min_value and max_value.

to(*args, **kwargs)[source]

Move via DeviceMixin and rebuild the index cache.

refine(selection, reset_values=False)[source]

Make a selection of the tensor refinable.

Parameters:
  • selection (slice, torch.Tensor, or tuple) – Selection indicating which elements should become refinable. Can be:

      - Boolean tensor of same shape as the full tensor
      - Slice object (e.g., slice(10, 20))
      - Tuple of indices for multidimensional tensors
      - Integer indices

  • reset_values (bool, optional) – If True, reset the selected elements to their current fixed values before making them refinable. Default is False.

Examples

mixed.refine(slice(10, 20))  # Make elements 10-19 refinable
mixed.refine(mask)  # Make elements where mask is True refinable

fix(selection, freeze_at_current=True)[source]

Make a selection of the tensor fixed (non-refinable).

Parameters:
  • selection (slice, torch.Tensor, or tuple) – Selection indicating which elements should become fixed. Can be:

      - Boolean tensor of same shape as the full tensor
      - Slice object (e.g., slice(10, 20))
      - Tuple of indices for multidimensional tensors
      - Integer indices

  • freeze_at_current (bool, optional) – If True (default), freeze the selected elements at their current values. If False, they revert to the original fixed values.

Examples

mixed.fix(slice(10, 20))  # Fix elements 10-19
mixed.fix(mask)  # Fix elements where mask is True

refine_all()[source]

Make all elements refinable.

fix_all(freeze_at_current=True)[source]

Make all elements fixed.

property name: str | None

Return the name of this parameter.

__str__()[source]

More detailed string representation.

parameters()[source]

Return refinable parameters for optimizer.

class torchref.model.PositiveMixedTensor(initial_values=None, refinable_mask=None, requires_grad=True, dtype=None, device=None, name=None, epsilon=0.1)[source]

Bases: MixedTensor

A MixedTensor subclass ensuring all values are positive via log-space parametrization.

Useful for parameters that must be strictly positive (e.g., B-factors, scale factors, sigma values). Values are stored as logarithms internally and converted to normal space via exp() when accessed.

Reparametrization:

internal_value = log(desired_value)
output_value = exp(internal_value)

This ensures output_value > 0 always, with smooth gradient flow.
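A scalar version of this reparametrization shows why positivity survives arbitrary optimizer steps: the step is taken on the unconstrained log value, and exp() maps any real number back to a positive one. The epsilon argument is left out because this page does not say how it enters the transform; the function names are hypothetical.

```python
import math

def to_internal(value):
    """Map a strictly positive value to its unconstrained log-space form."""
    if value <= 0:
        raise ValueError("value must be strictly positive")
    return math.log(value)

def to_output(internal):
    """Map any real number back to a strictly positive value."""
    return math.exp(internal)

internal_b = to_internal(20.0)
stepped = internal_b - 5.0          # a large optimizer step in log space
b_after_step = to_output(stepped)   # smaller, but still positive
print(f"{b_after_step:.4f}")
```

Since d(exp(x))/dx = exp(x), the gradient with respect to the internal value is scaled by the output value itself, so small values receive proportionally small updates.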

Parameters:
  • initial_values (torch.Tensor, optional) – Initial tensor values in NORMAL space. Optional for empty init.

  • refinable_mask (torch.Tensor, optional) – Boolean mask indicating which elements can be refined.

  • requires_grad (bool, optional) – Whether refinable parameters should have gradients. Default is True.

  • dtype (torch.dtype, optional) – Data type for the tensor.

  • device (torch.device, optional) – Device for the tensor.

  • name (str, optional) – Optional name for this parameter.

  • epsilon (float, optional) – Small value to add before taking log to avoid log(0). Default is 1e-1.

Examples

Empty initialization for state_dict loading:

b = PositiveMixedTensor()
b.load_state_dict(torch.load('b_factors.pt'))

Full initialization with values:

initial_b = torch.tensor([20.0, 30.0, 15.0])
b = PositiveMixedTensor(initial_b)
output = b()  # Returns exp(log_b) = positive values
assert (b() > 0).all()

__init__(initial_values=None, refinable_mask=None, requires_grad=True, dtype=None, device=None, name=None, epsilon=0.1)[source]

Initialize a PositiveMixedTensor.

If initial_values is provided, fully initializes the tensor. If not provided (empty init), creates a shell ready for load_state_dict().

Parameters:
  • initial_values (torch.Tensor, optional) – Initial tensor values in NORMAL space. Optional for empty init.

  • refinable_mask (torch.Tensor, optional) – Boolean mask indicating which elements can be refined.

  • requires_grad (bool, optional) – Whether refinable parameters should have gradients. Default is True.

  • dtype (torch.dtype, optional) – Data type for the tensor.

  • device (torch.device, optional) – Device for the tensor.

  • name (str, optional) – Optional name for this parameter.

  • epsilon (float, optional) – Small value to add before taking log to avoid log(0). Default is 1e-1.

Raises:

ValueError – If any initial values are not positive.

forward()[source]

Return the full tensor in NORMAL space.

Applies exponential transformation to the log-space values.

Returns:

Tensor with positive values.

Return type:

torch.Tensor

fix(mask, freeze_at_current=True)[source]

Fix (freeze) specific elements.

Converts current normal-space values to log space for storage.

Parameters:
  • mask (torch.Tensor) – Boolean mask indicating which elements to fix.

  • freeze_at_current (bool, optional) – If True, freeze at current values. Default is True.

refine(mask)[source]

Make specific elements refinable.

Preserves current log-space values.

Parameters:

mask (torch.Tensor) – Boolean mask indicating which elements to make refinable.

set(values, mask)[source]

Set values at positions specified by a boolean mask.

Values are provided in NORMAL space (e.g., actual B-factors) and automatically converted to log-space for internal storage.

Parameters:
  • values (torch.Tensor) – New values to assign in NORMAL space (positive values). Shape must be (n_selected,) where n_selected = mask.sum().

  • mask (torch.Tensor) – Boolean mask of shape (n_atoms,) indicating which elements to update. True positions will receive the new values.

Raises:

ValueError – If mask shape doesn’t match tensor’s first dimension, if values shape doesn’t match the number of selected elements, or if any values are not positive.

Examples

# Update B-factors for selected atoms
mask = model.get_selection_mask("name CA")
new_b = torch.ones(mask.sum()) * 30.0  # Set CA B-factors to 30
model.b.set(new_b, mask)

Notes

This method modifies the tensor in-place. Values are automatically converted to log-space internally to maintain the positivity constraint.

get_log_values()[source]

Return the internal log-space representation.

Useful for debugging or when direct access to the parametrization space is needed.

Returns:

Tensor with log-space values.

Return type:

torch.Tensor

update_refinable_mask(new_mask, reset_refinable=False)[source]

Update which elements are refinable.

Properly handles log-space conversion.

Parameters:
  • new_mask (torch.Tensor) – New boolean mask indicating refinable elements.

  • reset_refinable (bool, optional) – If True, reset refinable parameters to current fixed values. If False, keep existing refinable parameter values where possible. Default is False.

copy()[source]

Create a deep copy of this PositiveMixedTensor.

Properly handles the log-space reparametrization.

Returns:

New PositiveMixedTensor instance with copied data.

Return type:

PositiveMixedTensor

__str__()[source]

More detailed string representation.

class torchref.model.PassThroughTensor(initial_values, requires_grad=True, dtype=None, device=None, name=None)[source]

Bases: DeviceMixin, Module

A simple parameter wrapper that passes the parameter through unchanged.

Useful as a placeholder or for parameters that do not require any special handling.

Parameters:
  • initial_values (torch.Tensor) – Initial tensor values.

  • requires_grad (bool, optional) – Whether the parameter requires gradients. Default is True.

  • dtype (torch.dtype, optional) – Data type of the tensor.

  • device (torch.device, optional) – Device to place the tensor on.

  • name (str, optional) – Optional name for the parameter.

__init__(initial_values, requires_grad=True, dtype=None, device=None, name=None)[source]

Initialize the PassThroughTensor.

Parameters:
  • initial_values (torch.Tensor) – Initial tensor values.

  • requires_grad (bool, optional) – Whether the parameter requires gradients. Default is True.

  • dtype (torch.dtype, optional) – Data type of the tensor.

  • device (torch.device, optional) – Device to place the tensor on.

  • name (str, optional) – Optional name for the parameter.

forward()[source]

Return the parameter value unchanged.

Returns:

The parameter tensor.

Return type:

torch.Tensor

class torchref.model.OccupancyTensor(initial_values=None, sharing_groups=None, altloc_groups=None, refinable_mask=None, requires_grad=True, dtype=None, device=None, name=None, use_sigmoid=True)[source]

Bases: MixedTensor

A specialized MixedTensor for handling occupancy parameters in crystallographic refinement.

Handles specific constraints and requirements for occupancy including value bounds [0, 1] via sigmoid reparameterization, atom sharing, and alternative conformation sum-to-1 constraints.

Features:

  - Values bounded between 0 and 1 using sigmoid reparameterization
  - Atoms can share occupancies (e.g., all atoms in a residue)
  - Alternative conformations automatically sum to 1.0 via normalization
  - Memory-efficient collapsed storage (one parameter per sharing group)
  - Fully vectorized collapse/expand operations
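The two constraints can be sketched for a single pair of alternative conformations: a sigmoid bounds each unconstrained value to (0, 1), and dividing by the sum makes the pair add up to 1. The order of operations (sigmoid first, then normalization) is an assumption made for this illustration, and the helper names are hypothetical.

```python
import math

def sigmoid(x):
    """Map any real number into the open interval (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))

def normalize_altlocs(raw_occupancies):
    """Scale the occupancies of one set of alternative conformations to sum to 1."""
    total = sum(raw_occupancies)
    return [occ / total for occ in raw_occupancies]

logits = [0.8, -0.4]                      # unconstrained values, one per conformation
bounded = [sigmoid(x) for x in logits]    # each strictly between 0 and 1
occupancies = normalize_altlocs(bounded)  # conformations A and B now sum to 1
print([round(o, 3) for o in occupancies])
```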

Parameters:
  • initial_values (torch.Tensor, optional) – Initial occupancy values for ALL atoms (should be in [0, 1]). Optional for empty init.

  • sharing_groups (torch.Tensor, optional) – Tensor of shape (n_atoms,) where each value is the collapsed index for that atom. If None, each atom has independent occupancy.

  • altloc_groups (list of tuple, optional) – List of tuples of atom index lists representing alternative conformations. Each tuple contains the atom indices for each conformation.

  • refinable_mask (torch.Tensor, optional) – Boolean mask for which ATOMS can be refined (in full tensor space).

  • requires_grad (bool, optional) – Whether refinable parameters should have gradients. Default is True.

  • dtype (torch.dtype, optional) – Data type for the tensor.

  • device (torch.device, optional) – Device for the tensor.

  • name (str, optional) – Optional name for this parameter.

  • use_sigmoid (bool, optional) – If True, use sigmoid parameterization to bound values to [0,1]. Default is True.

expansion_mask

Maps atoms to collapsed indices.

Type:

torch.Tensor

linked_occ_sizes

List of altloc group sizes present.

Type:

list

collapse_counts

Count of atoms per collapsed index.

Type:

torch.Tensor

Examples

sharing_groups = torch.tensor([0, 0, 1, 1, 2, 2])
occ = OccupancyTensor(
    initial_values=torch.tensor([1.0, 1.0, 0.7, 0.7, 0.3, 0.3]),
    sharing_groups=sharing_groups,
    altloc_groups=[([2, 3], [4, 5])],
)
result = occ()  # Atoms 2-3 and 4-5 will sum to 1.0
__init__(initial_values=None, sharing_groups=None, altloc_groups=None, refinable_mask=None, requires_grad=True, dtype=None, device=None, name=None, use_sigmoid=True)[source]

Initialize an OccupancyTensor with collapsed storage and altloc support.

If initial_values is provided, fully initializes the tensor. If not provided (empty init), creates a shell ready for load_state_dict().

Parameters:
  • initial_values (torch.Tensor, optional) – Initial occupancy values for ALL atoms (should be in [0, 1]). Optional for empty init.

  • sharing_groups (torch.Tensor, optional) – Tensor of shape (n_atoms,) where each value is the collapsed index for that atom. If None, each atom has independent occupancy. Example: tensor([0, 0, 0, 1, 1, 2]) means atoms 0,1,2 share one occupancy, atoms 3,4 share another, and atom 5 is independent.

  • altloc_groups (list of tuple, optional) – List of tuples of atom index lists representing alternative conformations. Example: [([10,11], [12,13])] means atoms 10,11 (conf A) and 12,13 (conf B) are altlocs that sum to 1.0.

  • refinable_mask (torch.Tensor, optional) – Boolean mask for which ATOMS can be refined (in full tensor space). If any atom in a group is refinable, the entire group becomes refinable.

  • requires_grad (bool, optional) – Whether refinable parameters should have gradients. Default is True.

  • dtype (torch.dtype, optional) – Data type for the tensor.

  • device (torch.device, optional) – Device for the tensor.

  • name (str, optional) – Optional name for this parameter.

  • use_sigmoid (bool, optional) – If True, use sigmoid parameterization to bound values to [0,1]. Default is True.
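The collapsed storage described for sharing_groups can be illustrated with a minimal sketch. Plain Python lists stand in for tensors here, and taking the first atom's value as the group value is an assumption for illustration, not a statement about how OccupancyTensor collapses its input.

```python
# Plain-Python stand-in for the collapse/expand bookkeeping (the library works on tensors).
sharing_groups = [0, 0, 0, 1, 1, 2]          # collapsed index for each atom
initial_values = [1.0, 1.0, 1.0, 0.6, 0.6, 0.4]

n_groups = max(sharing_groups) + 1

# Collapse: keep one value per group.
# Assumption: the first atom of each group supplies the group value.
collapsed = [None] * n_groups
collapse_counts = [0] * n_groups
for atom_idx, group_idx in enumerate(sharing_groups):
    if collapsed[group_idx] is None:
        collapsed[group_idx] = initial_values[atom_idx]
    collapse_counts[group_idx] += 1

# Expand: look up each atom's group to rebuild the full per-atom list.
expanded = [collapsed[group_idx] for group_idx in sharing_groups]
```

Six atoms are stored as three values, and the lookup through sharing_groups restores the per-atom layout.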

forward()[source]

Reconstruct full occupancy tensor with sigmoid and altloc constraints.

For alternative conformations, applies sigmoid then normalizes within each group to enforce sum-to-1 constraint.

Returns:

Full occupancy tensor with values in [0, 1] and shape (n_atoms,).

Return type:

torch.Tensor
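A minimal sketch of the sigmoid-then-normalize step, written in plain Python rather than torch. The helper names and the choice to read each conformation's occupancy from its first atom are assumptions made for illustration; the actual forward() may differ in detail.

```python
import math

def sigmoid(x):
    """Map an unbounded raw parameter into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))

def expand_occupancies(raw, sharing_groups, altloc_groups):
    """Hypothetical helper: raw group parameters -> per-atom occupancies."""
    # Expand collapsed parameters to atoms and bound them with the sigmoid.
    occ = [sigmoid(raw[g]) for g in sharing_groups]
    for conformations in altloc_groups:
        # Assumption: one occupancy per conformation, read from its first atom.
        conf_occ = [occ[atoms[0]] for atoms in conformations]
        total = sum(conf_occ)
        # Normalize so the conformations of this altloc group sum to 1.
        for atoms, value in zip(conformations, conf_occ):
            for atom_idx in atoms:
                occ[atom_idx] = value / total
    return occ

raw = [4.0, 0.3, -1.2]                       # one raw parameter per sharing group
sharing_groups = [0, 0, 1, 1, 2, 2]
altloc_groups = [([2, 3], [4, 5])]           # conformation A and conformation B
occ = expand_occupancies(raw, sharing_groups, altloc_groups)
```

Atoms outside any altloc group keep their plain sigmoid value, while atoms 2-3 and 4-5 end up with occupancies that add to 1.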

property shape

Return the shape of the FULL tensor (not collapsed).

property collapsed_shape

Return the shape of the collapsed internal storage.

clamp(min_value=0.0, max_value=1.0)[source]

Clamp occupancy values to the specified range and return a new OccupancyTensor.

Parameters:
  • min_value (float, optional) – Minimum occupancy value. Default is 0.0.

  • max_value (float, optional) – Maximum occupancy value. Default is 1.0.

Returns:

New OccupancyTensor with clamped values.

Return type:

OccupancyTensor

set_group_occupancy(group_idx, value)[source]

Set the occupancy for all atoms in a specific collapsed group.

Parameters:
  • group_idx (int) – Collapsed index of the group.

  • value (float) – Occupancy value to set (must be in [0, 1]).

Raises:

ValueError – If group_idx is out of range or value is not in [0, 1].
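With a sigmoid parameterization, a target occupancy has to be mapped back to the unbounded raw space before it can be stored. The sketch below shows one way to do that with a logit; the function names and the eps guard are hypothetical and are not taken from the torchref source.

```python
import math

def occupancy_to_raw(value, eps=1e-6):
    """Hypothetical inverse of the sigmoid: occupancy in [0, 1] -> raw parameter."""
    if not 0.0 <= value <= 1.0:
        raise ValueError(f"occupancy must be in [0, 1], got {value}")
    # Keep away from exactly 0 and 1, where the logit is infinite.
    clipped = min(max(value, eps), 1.0 - eps)
    return math.log(clipped / (1.0 - clipped))

def raw_to_occupancy(raw):
    """Sigmoid: raw parameter -> occupancy in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-raw))

raw = occupancy_to_raw(0.7)
roundtrip = raw_to_occupancy(raw)
```

The round trip recovers 0.7 up to floating-point error, and out-of-range inputs raise ValueError in the same spirit as the documented behaviour.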

get_group_occupancy(group_idx)[source]

Get the current occupancy value for a collapsed group.

Parameters:

group_idx (int) – Collapsed index of the group.

Returns:

Current occupancy value for the group.

Return type:

float

Raises:

ValueError – If group_idx is out of range.

freeze(mask=None)[source]

Freeze occupancy parameters, making them non-refinable.

The mask is supplied in UNCOMPRESSED (full atom) form but freezing operates on the COMPRESSED data structure. This method handles the conversion.

Parameters:

mask (torch.Tensor, optional) – Boolean mask in FULL (uncompressed) atom space indicating which atoms to freeze. If None, freeze all parameters. Shape must be (n_atoms,).

Notes

If ANY atom in a sharing group is frozen, the ENTIRE group is frozen because all atoms in a group share the same compressed parameter.
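The any-atom rule from the note above can be sketched as a mask collapse. This is a plain-Python illustration with a hypothetical helper name, not the library's vectorized implementation.

```python
sharing_groups = [0, 0, 1, 1, 2, 2]          # collapsed index for each atom
n_groups = max(sharing_groups) + 1

def collapse_mask_any(atom_mask, sharing_groups, n_groups):
    """A group is flagged if ANY of its atoms is flagged."""
    group_mask = [False] * n_groups
    for flag, group_idx in zip(atom_mask, sharing_groups):
        group_mask[group_idx] = group_mask[group_idx] or flag
    return group_mask

refinable_groups = [True, True, True]        # state before freezing
freeze_mask = [False, True, False, False, False, False]   # only atom 1 requested

frozen_groups = collapse_mask_any(freeze_mask, sharing_groups, n_groups)
refinable_groups = [r and not f for r, f in zip(refinable_groups, frozen_groups)]

# Expanding back to atom space shows that atom 0 was frozen along with atom 1.
frozen_atoms = [not refinable_groups[g] for g in sharing_groups]
```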

Examples

# Freeze atoms 0-10 (in full atom space)
freeze_mask = torch.zeros(n_atoms, dtype=torch.bool)
freeze_mask[0:11] = True
occ.freeze(freeze_mask)
# Freeze all atoms
occ.freeze()
unfreeze(mask=None)[source]

Unfreeze occupancy parameters, making them refinable.

The mask is supplied in UNCOMPRESSED (full atom) form but unfreezing operates on the COMPRESSED data structure. This method handles the conversion.

Parameters:

mask (torch.Tensor, optional) – Boolean mask in FULL (uncompressed) atom space indicating which atoms to unfreeze. If None, unfreeze all parameters. Shape must be (n_atoms,).

Notes

If ANY atom in a sharing group is unfrozen, the ENTIRE group becomes refinable because all atoms in a group share the same compressed parameter.

Examples

# Unfreeze atoms 100-200 (in full atom space)
unfreeze_mask = torch.zeros(n_atoms, dtype=torch.bool)
unfreeze_mask[100:201] = True
occ.unfreeze(unfreeze_mask)
# Unfreeze all atoms
occ.unfreeze()
freeze_all()[source]

Freeze all occupancy parameters.

Convenience method equivalent to freeze(None).

unfreeze_all()[source]

Unfreeze all occupancy parameters.

Convenience method equivalent to unfreeze(None).

get_refinable_atoms()[source]

Get a boolean mask in FULL atom space indicating refinable atoms.

Returns:

Boolean tensor of shape (n_atoms,) where True indicates the atom’s occupancy is refinable. Atoms in the same sharing group share one parameter, so they always have the same value in this mask.

Return type:

torch.Tensor

get_frozen_atoms()[source]

Get a boolean mask in FULL atom space indicating frozen atoms.

Returns:

Boolean tensor of shape (n_atoms,) where True indicates the atom’s occupancy is frozen.

Return type:

torch.Tensor

get_refinable_count()[source]

Get the number of refinable parameters in COMPRESSED space.

This is the number of refinable groups, not atoms. Use get_refinable_atoms().sum() to get the number of refinable atoms.

Returns:

Number of refinable compressed parameters.

Return type:

int

get_fixed_count()[source]

Get the number of fixed parameters in COMPRESSED space.

This is the number of fixed groups, not atoms. Use get_frozen_atoms().sum() to get the number of frozen atoms.

Returns:

Number of fixed compressed parameters.

Return type:

int
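The difference between counting in compressed space and counting atoms can be seen in a small plain-Python sketch. The variable names are illustrative only.

```python
sharing_groups = [0, 0, 0, 1, 1, 2]          # six atoms, three groups
refinable_groups = [True, False, True]       # per-group refinable flags

# Counts in COMPRESSED space: one entry per group.
refinable_count = sum(refinable_groups)
fixed_count = len(refinable_groups) - refinable_count

# Counts in FULL atom space: expand the group flags to atoms first.
refinable_atoms = [refinable_groups[g] for g in sharing_groups]
n_refinable_atoms = sum(refinable_atoms)
n_frozen_atoms = len(refinable_atoms) - n_refinable_atoms
```

Here two refinable groups cover four atoms, and one fixed group covers the remaining two.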

update_refinable_mask(new_mask, in_compressed_space=False)[source]

Directly update the refinable mask with a new mask.

Allows more direct control over which parameters are refinable, compared to freeze/unfreeze which modify the existing state.

Parameters:
  • new_mask (torch.Tensor) – Boolean tensor indicating which parameters should be refinable. If in_compressed_space=False: shape (n_atoms,) in full atom space. If in_compressed_space=True: shape (n_groups,) in compressed space.

  • in_compressed_space (bool, optional) – If True, new_mask is in compressed space. If False (default), new_mask is in full atom space and will be collapsed.

Examples

Full atom space:

atom_mask = torch.zeros(n_atoms, dtype=torch.bool)
atom_mask[:100] = True
occ.update_refinable_mask(atom_mask, in_compressed_space=False)

Compressed space:

group_mask = torch.zeros(n_groups, dtype=torch.bool)
group_mask[::2] = True
occ.update_refinable_mask(group_mask, in_compressed_space=True)
static from_residue_groups(initial_values, pdb_dataframe, refinable_mask=None, **kwargs)[source]

Create an OccupancyTensor where all atoms in each residue share occupancy.

Common use case where all atoms in a residue should have the same occupancy.

Parameters:
  • initial_values (torch.Tensor) – Initial occupancy values for all atoms.

  • pdb_dataframe (pandas.DataFrame) – DataFrame with PDB data (must have ‘resname’, ‘resseq’, ‘chainid’).

  • refinable_mask (torch.Tensor, optional) – Mask for refinable atoms.

  • **kwargs – Additional arguments passed to OccupancyTensor constructor.

Returns:

OccupancyTensor with residue-based sharing groups.

Return type:

OccupancyTensor
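One way residue-based sharing groups could be derived is sketched below. A list of dicts stands in for the pandas DataFrame, and the helper name and the grouping key of (chainid, resseq, resname) are assumptions based on the required columns, not the confirmed implementation.

```python
def residue_sharing_groups(records):
    """Hypothetical helper: assign one collapsed index per residue."""
    group_of = {}
    sharing_groups = []
    for rec in records:
        key = (rec["chainid"], rec["resseq"], rec["resname"])
        # First time a residue is seen, give it the next free group index.
        if key not in group_of:
            group_of[key] = len(group_of)
        sharing_groups.append(group_of[key])
    return sharing_groups

atoms = [
    {"chainid": "A", "resseq": 1, "resname": "ALA"},
    {"chainid": "A", "resseq": 1, "resname": "ALA"},
    {"chainid": "A", "resseq": 2, "resname": "GLY"},
    {"chainid": "B", "resseq": 1, "resname": "ALA"},
    {"chainid": "A", "resseq": 2, "resname": "GLY"},
]
groups = residue_sharing_groups(atoms)
```

The result has the same form as the sharing_groups argument: one collapsed index per atom, with atoms of the same residue mapped to the same index.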

copy()[source]

Create a deep copy of this OccupancyTensor.

Creates a complete independent copy with all buffers and parameters, including sharing groups, altloc groups, and collapsed storage structures.

Returns:

New OccupancyTensor instance with copied data.

Return type:

OccupancyTensor

class torchref.model.SegmentedInternalCoordinateTensor(initial_xyz, pdb, n_aa_per_segment=3, bond_cutoff=2.0, cif_dict=None, requires_grad=True, dtype=None, device=None)[source]

Bases: DeviceMixin, CachedForwardMixin, Module

Parameter wrapper using segmented internal coordinates.

Stores: per-segment bond_lengths, angles, torsions, segment_positions, segment_orientations.

Reconstructs: Cartesian xyz on forward().

This provides a physically meaningful parametrization that avoids the lever arm problem by breaking the molecule into independent segments, each with shallow spanning trees and rigid body parameters.

Parameters:
  • initial_xyz (torch.Tensor) – Initial Cartesian coordinates of shape (N, 3).

  • pdb (pd.DataFrame) – PDB DataFrame with columns ‘chainid’, ‘resseq’, ‘name’, ‘index’, ‘resname’.

  • n_aa_per_segment (int, optional) – Number of amino acids per segment. Default is 3.

  • bond_cutoff (float, optional) – Distance cutoff for bond detection in Angstroms. Default is 2.0.

  • requires_grad (bool, optional) – Whether parameters should have gradients. Default is True.

  • dtype (torch.dtype, optional) – Data type for tensors. Default is same as initial_xyz.

  • device (torch.device, optional) – Device for tensors. Default is same as initial_xyz.

n_atoms

Number of atoms.

Type:

int

n_segments

Number of segments.

Type:

int

max_depth

Maximum depth in any segment’s spanning tree.

Type:

int

bond_lengths

Bond length parameters in Angstroms.

Type:

nn.Parameter

angles

Angle parameters in radians.

Type:

nn.Parameter

torsions

Torsion angle parameters in radians.

Type:

nn.Parameter

segment_positions

Absolute positions of segment root atoms.

Type:

nn.Parameter

segment_orientations

ZYZ Euler angle orientations for each segment.

Type:

nn.Parameter
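For readers unfamiliar with ZYZ Euler angles, the sketch below builds a rotation matrix as Rz(alpha) · Ry(beta) · Rz(gamma) in plain Python. The composition order and the function names are assumptions; the library may use a different convention for segment_orientations.

```python
import math

def rot_z(t):
    c, s = math.cos(t), math.sin(t)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]

def rot_y(t):
    c, s = math.cos(t), math.sin(t)
    return [[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]]

def matmul(a, b):
    """3x3 matrix product."""
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]

def zyz_to_matrix(alpha, beta, gamma):
    """Assumed convention: R = Rz(alpha) @ Ry(beta) @ Rz(gamma)."""
    return matmul(rot_z(alpha), matmul(rot_y(beta), rot_z(gamma)))

def rotate(R, v):
    """Apply a 3x3 rotation matrix to a 3-vector."""
    return [sum(R[i][k] * v[k] for k in range(3)) for i in range(3)]

R = zyz_to_matrix(0.3, 1.1, -0.7)
```

Three angles per segment are enough to orient a rigid body, which is why each segment carries one such triple alongside its root position.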

AA_NAMES = frozenset({'ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE', 'LEU', 'LYS', 'MET', 'MSE', 'PHE', 'PRO', 'SEC', 'SER', 'THR', 'TRP', 'TYR', 'VAL'})
__init__(initial_xyz, pdb, n_aa_per_segment=3, bond_cutoff=2.0, cif_dict=None, requires_grad=True, dtype=None, device=None)[source]

Initialize SegmentedInternalCoordinateTensor.

Parameters:
  • initial_xyz (torch.Tensor) – Initial Cartesian coordinates of shape (N, 3).

  • pdb (pd.DataFrame) – PDB DataFrame with columns ‘chainid’, ‘resseq’, ‘name’, ‘index’, ‘resname’.

  • n_aa_per_segment (int, optional) – Number of amino acids per segment. Default is 3.

  • bond_cutoff (float, optional) – Distance cutoff for bond detection in Angstroms (used as fallback). Default is 2.0.

  • cif_dict (dict, optional) – CIF dictionary containing bond definitions per residue type. If provided, bonds are determined from chemical definitions rather than distances, which is more robust for structures with poor geometry. Expected format: cif_dict[resname][‘bonds’] is a DataFrame with ‘atom1’ and ‘atom2’ columns.

  • requires_grad (bool, optional) – Whether parameters should have gradients. Default is True.

  • dtype (torch.dtype, optional) – Data type for tensors. Default is same as initial_xyz.

  • device (torch.device, optional) – Device for tensors. Default is same as initial_xyz.
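To make n_aa_per_segment concrete, here is a hedged sketch of how residues might be chunked into segments. It assumes consecutive residues are grouped in order and that segments never span two chains; the real segmentation logic is not shown in this documentation and may differ.

```python
def segment_residues(residues, n_aa_per_segment=3):
    """Hypothetical helper: chunk (chainid, resseq) pairs into segments."""
    segments = []
    current = []
    for chain, resseq in residues:
        # Start a new segment at a chain break or when the segment is full.
        if current and (current[-1][0] != chain or len(current) == n_aa_per_segment):
            segments.append(current)
            current = []
        current.append((chain, resseq))
    if current:
        segments.append(current)
    return segments

residues = [("A", i) for i in range(1, 8)] + [("B", 1), ("B", 2)]
segments = segment_residues(residues, n_aa_per_segment=3)
segment_sizes = [len(s) for s in segments]
```

Short segments keep each spanning tree shallow, which is the property the class description relies on to limit lever-arm effects.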

property dtype

Return the dtype of tensors.

property device

Return the device of tensors.

forward()[source]

Reconstruct Cartesian xyz from internal coordinates.

Uses fully vectorized operations for maximum performance.

Returns:

Reconstructed Cartesian coordinates of shape (N, 3).

Return type:

torch.Tensor
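Reconstruction from internal coordinates typically places each atom from three already-placed ancestors plus a bond length, bond angle, and torsion. The sketch below shows that single placement step (a NeRF-style construction) in plain Python; it is an illustration under that assumption, not the vectorized code used by forward().

```python
import math

def sub(u, v):
    return [u[i] - v[i] for i in range(3)]

def dot(u, v):
    return sum(u[i] * v[i] for i in range(3))

def cross(u, v):
    return [u[1] * v[2] - u[2] * v[1],
            u[2] * v[0] - u[0] * v[2],
            u[0] * v[1] - u[1] * v[0]]

def unit(u):
    n = math.sqrt(dot(u, u))
    return [x / n for x in u]

def place_atom(a, b, c, bond_length, angle, torsion):
    """Place atom d bonded to c, given angle b-c-d and torsion a-b-c-d (radians)."""
    bc = unit(sub(c, b))
    n = unit(cross(sub(b, a), bc))   # normal of the a-b-c plane
    m = cross(n, bc)                 # completes the local orthonormal frame
    x = -bond_length * math.cos(angle)
    y = bond_length * math.sin(angle) * math.cos(torsion)
    z = bond_length * math.sin(angle) * math.sin(torsion)
    return [c[i] + x * bc[i] + y * m[i] + z * n[i] for i in range(3)]

def dihedral(a, b, c, d):
    """Signed torsion a-b-c-d in radians, used here to check the placement."""
    b1, b2, b3 = sub(b, a), sub(c, b), sub(d, c)
    n1, n2 = cross(b1, b2), cross(b2, b3)
    return math.atan2(math.sqrt(dot(b2, b2)) * dot(b1, n2), dot(n1, n2))

a, b, c = [0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [1.5, 0.0, 0.0]
d = place_atom(a, b, c, 1.33, math.radians(110.0), math.radians(60.0))
```

Because each atom depends on its ancestors, errors in early torsions propagate down the tree; keeping the trees shallow per segment limits that propagation.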

shake(magnitude=0.1)[source]

Add Gaussian noise to internal parameters.

Parameters:

magnitude (float, optional) – Standard deviation of Gaussian noise. Default is 0.1.

Returns:

New Cartesian coordinates after perturbation.

Return type:

torch.Tensor

fix(selection=None, freeze_at_current=True)[source]

Fix (freeze) atoms to use fixed xyz coordinates.

Parameters:
  • selection (torch.Tensor, slice, or None) – Boolean mask or indices of atoms to fix.

  • freeze_at_current (bool, optional) – If True, store current coordinates for selected atoms.
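Conceptually, fixing atoms means the output uses stored coordinates for the selected atoms and reconstructed coordinates for the rest. The sketch below illustrates that overlay with plain lists; the helper and variable names are hypothetical and only mirror the documented idea of fixed xyz coordinates.

```python
def apply_frozen_overlay(reconstructed_xyz, fixed_xyz, fixed_mask):
    """Use stored coordinates where fixed_mask is True, rebuilt ones elsewhere."""
    return [
        list(fixed) if is_fixed else list(rebuilt)
        for rebuilt, fixed, is_fixed in zip(reconstructed_xyz, fixed_xyz, fixed_mask)
    ]

reconstructed_xyz = [[0.1, 0.0, 0.0], [1.6, 0.1, 0.0], [2.9, 0.2, 0.1]]
fixed_xyz = [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [3.0, 0.0, 0.0]]
fixed_mask = [True, False, True]             # atoms 0 and 2 are fixed

xyz = apply_frozen_overlay(reconstructed_xyz, fixed_xyz, fixed_mask)
n_fixed = sum(fixed_mask)
n_refinable = len(fixed_mask) - n_fixed
```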

freeze(selection=None, freeze_at_current=True)[source]

Alias for fix().

refine(selection=None, rebuild=True)[source]

Make atoms refinable.

Parameters:
  • selection (torch.Tensor, slice, or None) – Boolean mask or indices of atoms to make refinable.

  • rebuild (bool, optional) – If True, rebuild internal coordinates from fixed_xyz.

unfreeze(selection=None, rebuild=True)[source]

Alias for refine().

fix_all(freeze_at_current=True)[source]

Fix all atoms.

freeze_all(freeze_at_current=True)[source]

Alias for fix_all().

refine_all(rebuild=True)[source]

Make all atoms refinable.

unfreeze_all(rebuild=True)[source]

Alias for refine_all().

property n_refinable: int

Return the number of refinable atoms.

property n_fixed: int

Return the number of fixed atoms.

class torchref.model.ClosedSegmentedInternalCoordinateTensor(initial_xyz, pdb, n_aa_per_segment=18, junction_size=3, bond_cutoff=2.0, cif_dict=None, prefer_loops=True, requires_grad=True, dtype=None, device=None)[source]

Bases: DeviceMixin, CachedForwardMixin, Module

Parameter wrapper using segmented internal coordinates with chain closure.

Stores: per-segment bond_lengths, angles, torsions, segment_positions, segment_orientations. Junction backbone torsions are slave DOFs solved by Newton’s method with IFT gradients.

Parameters:
  • initial_xyz (torch.Tensor) – Initial Cartesian coordinates of shape (N, 3).

  • pdb (pd.DataFrame) – PDB DataFrame with columns ‘chainid’, ‘resseq’, ‘name’, ‘index’, ‘resname’.

  • n_aa_per_segment (int, optional) – Number of amino acids per segment. Default is 18.

  • junction_size (int, optional) – Number of residues per junction. Default is 3.

  • bond_cutoff (float, optional) – Distance cutoff for bond detection in Angstroms. Default is 2.0.

  • cif_dict (dict, optional) – CIF dictionary containing bond definitions per residue type.

  • prefer_loops (bool, optional) – If True, slide junctions to prefer loop regions. Default is True.

  • requires_grad (bool, optional) – Whether parameters should have gradients. Default is True.

  • dtype (torch.dtype, optional) – Data type for tensors.

  • device (torch.device, optional) – Device for tensors.

__init__(initial_xyz, pdb, n_aa_per_segment=18, junction_size=3, bond_cutoff=2.0, cif_dict=None, prefer_loops=True, requires_grad=True, dtype=None, device=None)[source]
property dtype
property device
forward()[source]

Reconstruct Cartesian xyz from internal coordinates with chain closure.

Steps:

1. Place segment atoms (existing NeRF pipeline)
2. Solve junction closures (Newton + IFT)
3. Place junction backbone atoms
4. Place junction sidechain atoms
5. Apply frozen overlay

Returns:

Reconstructed Cartesian coordinates of shape (N, 3).

Return type:

torch.Tensor
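The combination of Newton's method with implicit function theorem (IFT) gradients can be illustrated on a scalar toy problem. Here a dependent variable x is solved from a free parameter p through the residual g(x, p) = x³ + x − p, and dx/dp is obtained as −(∂g/∂p)/(∂g/∂x) without differentiating through the Newton iterations. The residual and all names are made up for illustration; the real junction solver works on torsion vectors and closure residuals.

```python
def residual(x, p):
    """Toy closure residual; the solver drives this to zero."""
    return x**3 + x - p

def d_residual_dx(x, p):
    return 3.0 * x * x + 1.0

def d_residual_dp(x, p):
    return -1.0

def solve_slave(p, x0=0.0, tol=1e-12, max_iter=50):
    """Newton's method for the dependent variable x given the free parameter p."""
    x = x0
    for _ in range(max_iter):
        g = residual(x, p)
        if abs(g) < tol:
            break
        x -= g / d_residual_dx(x, p)
    return x

def slave_gradient(x, p):
    """IFT: dx/dp = -(dg/dx)^-1 * dg/dp, evaluated at the converged solution."""
    return -d_residual_dp(x, p) / d_residual_dx(x, p)

p = 2.0
x = solve_slave(p)
grad = slave_gradient(x, p)
```

The gradient only needs derivatives at the converged point, so its cost does not grow with the number of Newton iterations.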

shake(magnitude=0.1)[source]

Add Gaussian noise to internal parameters.

Perturbs torsions, bond lengths, and bond angles. Segment positions and orientations are NOT perturbed because random independent translation of segments creates gaps that exceed the junction chain’s reach. During optimization, the optimizer adjusts these rigid-body DOFs smoothly.
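A minimal sketch of this selective perturbation, using plain Python lists and the standard random module. The helper name and the dictionary layout are hypothetical, and applying a single magnitude to all internal parameters is an assumption.

```python
import random

def shake_internal(params, magnitude=0.1, rng=None,
                   skip=("segment_positions", "segment_orientations")):
    """Add Gaussian noise to internal parameters, leaving rigid-body DOFs alone."""
    rng = rng or random.Random()
    shaken = {}
    for name, values in params.items():
        if name in skip:
            shaken[name] = list(values)      # copied unchanged
        else:
            shaken[name] = [v + rng.gauss(0.0, magnitude) for v in values]
    return shaken

params = {
    "bond_lengths": [1.33, 1.46, 1.52],
    "angles": [1.92, 2.03],
    "torsions": [3.14, -1.05],
    "segment_positions": [0.0, 0.0, 0.0],
    "segment_orientations": [0.1, 0.2, 0.3],
}
shaken = shake_internal(params, magnitude=0.1, rng=random.Random(0))
```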

fix(selection=None, freeze_at_current=True)[source]

Fix (freeze) atoms.

freeze(selection=None, freeze_at_current=True)[source]

Alias for fix().

refine(selection=None, rebuild=True)[source]

Make atoms refinable.

unfreeze(selection=None, rebuild=True)[source]

Alias for refine().

fix_all(freeze_at_current=True)[source]

Fix all atoms.

freeze_all(freeze_at_current=True)[source]

Alias for fix_all().

refine_all(rebuild=True)[source]

Make all atoms refinable.

unfreeze_all(rebuild=True)[source]

Alias for refine_all().

property n_refinable: int
property n_fixed: int
property closure_residuals: Tensor | None

Get last closure residuals from the junction solver.

property max_closure_gap: float

Maximum closure gap in Angstroms across all junctions.

class torchref.model.ModelCollection(base_models, dark_key='dark', verbose=0)[source]

Bases: DeviceMixin, Module

Named dictionary of MixedModel instances at different timepoints.

All timepoint models share the same base structural models (ModelFT objects stored once in an nn.ModuleList). Each timepoint gets its own independent fraction parameters via _SharedMixedModel.

Keys should match DatasetCollection keys so that collection-aware targets can automatically pair datasets with models.
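The sharing scheme can be pictured as one set of base structure factors combined with per-timepoint fraction vectors. The sketch below assumes the mixed structure factor is a fraction-weighted sum of the base models' complex structure factors, which is a common convention but is not confirmed by this documentation; plain Python complex numbers stand in for tensors.

```python
# Structure factors of the K shared base models, computed once (made-up numbers).
base_fcalc = [
    [10 + 0j, 4 + 2j, -3 + 1j],    # model 0: ground state
    [8 + 1j, 5 - 1j, -2 + 4j],     # model 1: intermediate
]

# Each timepoint owns its own fraction vector over the same base models.
fractions = {
    "dark": [1.0, 0.0],
    "1ps": [0.9, 0.1],
    "5ps": [0.7, 0.3],
}

def mix(base_fcalc, weights):
    """Assumed mixing rule: F_mix(hkl) = sum_k fraction_k * F_k(hkl)."""
    n_reflections = len(base_fcalc[0])
    return [sum(w * f[h] for w, f in zip(weights, base_fcalc))
            for h in range(n_reflections)]

mixed = {name: mix(base_fcalc, w) for name, w in fractions.items()}
```

Only the small fraction vectors differ between timepoints, so adding a timepoint does not duplicate any structural parameters.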

Parameters:
  • base_models (List[ModelFT]) – The K shared structural models (e.g., ground state + intermediates).

  • dark_key (str) – Key for the dark / reference entry. Default "dark".

  • verbose (int) – Verbosity level.

Examples

models = ModelCollection([model_dark, model_light])
models.add_dark()                             # fractions=[1, 0]
models.add_timepoint("1ps", [0.9, 0.1])
models.add_timepoint("5ps", [0.7, 0.3])

# Access
mixed = models["1ps"]
fcalc = mixed(hkl)
print(mixed.fractions)
__init__(base_models, dark_key='dark', verbose=0)[source]
add_timepoint(name, fractions=None, frozen_fractions=False)[source]

Add a timepoint with given initial fractions.

Parameters:
  • name (str) – Timepoint identifier (should match DatasetCollection key).

  • fractions (List[float], optional) – Initial population fractions. If None, uses equal fractions.

  • frozen_fractions (bool) – If True, fractions are not updated during optimization.

Returns:

Self, for method chaining.

Return type:

ModelCollection

add_dark(fractions=None)[source]

Add the dark / reference entry.

Default fractions: [1, 0, 0, …] (100% ground state).

Parameters:

fractions (List[float], optional) – Override dark fractions. Default is pure ground state.

Returns:

Self, for method chaining.

Return type:

ModelCollection

classmethod from_kinetics(base_models, occ_model, timepoint_names, dark_key='dark', verbose=0)[source]

Create a ModelCollection from a kinetics occupancy model.

Parameters:
  • base_models (List[ModelFT]) – Shared structural models.

  • occ_model (occupancies_kinetics) – Kinetic occupancy model whose forward() returns shape [n_states, n_timepoints].

  • timepoint_names (List[str]) – Names for each timepoint column (excluding dark).

  • dark_key (str) – Key for the dark entry.

  • verbose (int) – Verbosity level.

Return type:

ModelCollection
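Given an occupancy array of shape [n_states, n_timepoints], each column supplies the fraction vector for one timepoint. The sketch below shows that slicing with nested lists and made-up numbers; how from_kinetics treats the dark entry is not specified here and is left out.

```python
# occ[state][timepoint], i.e. shape [n_states, n_timepoints] (made-up values).
occ = [
    [0.9, 0.7, 0.4],    # state 0 population over time
    [0.1, 0.3, 0.6],    # state 1 population over time
]
timepoint_names = ["1ps", "5ps", "20ps"]

# One fraction vector per timepoint: take column t across all states.
per_timepoint = {
    name: [row[t] for row in occ]
    for t, name in enumerate(timepoint_names)
}
```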

classmethod from_ihm(filepath, max_res=1.5, radius_angstrom=4.0, device=None, verbose=0)[source]

Load a ModelCollection from an IHM mmCIF file.

Requires the optional python-ihm dependency.

Parameters:
  • filepath (str) – Path to IHM mmCIF file.

  • max_res (float) – Maximum resolution for FFT grid setup.

  • radius_angstrom (float) – Radius for electron density calculation.

  • device (torch.device, optional) – Device for model tensors.

  • verbose (int) – Verbosity level.

Return type:

tuple of (ModelCollection, IHMEnsembleMapping)

write_ihm(filepath, mapping=None, datasets=None)[source]

Write this ModelCollection to IHM mmCIF format.

Requires the optional python-ihm dependency.

Parameters:
  • filepath (str) – Output file path.

  • mapping (IHMEnsembleMapping, optional) – Mapping with metadata for round-tripping. If None, a minimal mapping is created from the collection structure.

  • datasets (dict of str -> ReflectionData, optional) – Per-timepoint reflection data to embed in the CIF. Each key should match a timepoint name.

keys()[source]
values()[source]
items()[source]
get(name, default=None)[source]
property dark_key: str
property dark_model: _SharedMixedModel

Shortcut for self[dark_key].

property base_models: ModuleList

The shared structural models (owned by this collection).

property n_base_models: int
property timepoint_names: List[str]

All keys except the dark key.

property cell
property spacegroup
property device
get_all_fractions()[source]

Current fractions for each timepoint (including dark).

get_fractions_matrix()[source]

All fractions as a matrix [n_timepoints, n_models].

Rows are ordered by self._order (i.e. insertion order).
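The row ordering can be illustrated with a plain dictionary, which preserves insertion order in Python 3.7+. The names and values below are illustrative and do not come from the library.

```python
fractions = {}                     # dict keeps insertion order
fractions["dark"] = [1.0, 0.0]     # pure ground state
fractions["1ps"] = [0.9, 0.1]
fractions["5ps"] = [0.7, 0.3]

order = list(fractions)            # row order of the matrix
matrix = [fractions[name] for name in order]   # [n_timepoints][n_models]

n_timepoints = len(matrix)
n_models = len(matrix[0])
```

Row i of the matrix belongs to order[i], so the dark entry is row 0 only if it was added first.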

freeze_all_fractions()[source]

Freeze fractions at all timepoints.

unfreeze_all_fractions()[source]

Unfreeze fractions at all timepoints (except dark).

freeze_structures()[source]

Freeze xyz and adp on all base models.

unfreeze_structures()[source]

Unfreeze xyz and adp on all base models.

write_pdbs(outdir)[source]

Write each base model to a PDB file in outdir.

Files are named base_model_0.pdb, base_model_1.pdb, etc.

Parameters:

outdir (str) – Directory to write PDB files into (must exist).

Submodules