torchref.model.model module

A base model class for atomic structure models using PyTorch.

Space groups are stored as gemmi.SpaceGroup objects for consistency and direct access to symmetry operations.

Variable naming conventions:

  • adp: Atomic displacement parameters (model-level, replaces b_factor)

  • xyz: Cartesian coordinates

  • xyz_fractional: Fractional coordinates

  • F_calc/F_obs: Structure factor amplitudes (uppercase = amplitudes)

  • f_calc/f_obs: Complex structure factors (lowercase = complex)

class torchref.model.model.Model(dtype_float=torch.float32, verbose=1, device=device(type='cpu'), strip_H=True)

Bases: DeviceMixin, DebugMixin, Module

Base model class for atomic structure models using PyTorch.

This class provides the foundation for managing atomic structure data including coordinates, atomic displacement parameters (ADPs), and occupancies. It supports both empty initialization for state_dict loading and file-based initialization from PDB/CIF files.

Parameters:
  • dtype_float (torch.dtype, optional) – Data type for floating point tensors. Defaults to the configured dtypes.float.

  • verbose (int, optional) – Verbosity level for logging. Default is 1.

  • device (torch.device, optional) – Computation device. Defaults to the configured device.current.

  • strip_H (bool, optional) – Whether to strip hydrogen atoms when loading. Default is True.

xyz

Atomic coordinates tensor with shape (n_atoms, 3).

Type:

MixedTensor

adp

Atomic displacement parameters (isotropic B-factors) with shape (n_atoms,).

Type:

PositiveMixedTensor

u

Anisotropic displacement parameters with shape (n_atoms, 6).

Type:

MixedTensor

occupancy

Atomic occupancies with values in [0, 1].

Type:

OccupancyTensor

pdb

DataFrame containing atomic model data.

Type:

pandas.DataFrame

cell

Unit cell object with parameters [a, b, c, alpha, beta, gamma].

Type:

Cell

spacegroup

Space group object.

Type:

gemmi.SpaceGroup

symmetry

Symmetry operations handler for this space group.

Type:

Symmetry

initialized

Whether the model has been initialized with data.

Type:

bool

Examples

Empty initialization for state_dict loading:

import torch
model = Model()
model.load_state_dict(torch.load('model.pt'))

File-based initialization:

model = Model()
model.load_pdb('structure.pdb')

__init__(dtype_float=torch.float32, verbose=1, device=device(type='cpu'), strip_H=True)

Initialize an empty Model shell.

Creates a model shell ready for file loading via load_pdb()/load_cif() or state restoration via load_state_dict().

Parameters:
  • dtype_float (torch.dtype, optional) – Data type for floating point tensors. Defaults to the configured dtypes.float.

  • verbose (int, optional) – Verbosity level for logging. Default is 1.

  • device (torch.device, optional) – Computation device. Defaults to the configured device.current.

  • strip_H (bool, optional) – Whether to strip hydrogen atoms when loading. Default is True.

__bool__()

Return the initialization status when used in boolean context.

property exclude_H_from_sf: bool

Whether to exclude hydrogen atoms from structure factor calculation.

When True, H atoms are excluded from get_iso() / get_aniso() so they do not contribute to Fcalc. They still participate in geometry and VDW restraints. Default is False.

property cell: Cell | None

Unit cell object with parameters [a, b, c, alpha, beta, gamma].

Returns:

The unit cell object, or None if not set.

Return type:

Cell or None

property spacegroup: SpaceGroup | None

Space group object.

Returns:

The space group object, or None if not set.

Return type:

gemmi.SpaceGroup or None

property symmetry: SpaceGroup | None

Symmetry operations handler for this space group.

Returns the same gemmi.SpaceGroup object as self.spacegroup. Symmetry is an alias, so no separate wrapper object is created.

Returns:

The space group object, or None if not set.

Return type:

gemmi.SpaceGroup or None

property inv_fractional_matrix: Tensor

Fractionalization matrix B^-1 (Cartesian -> fractional).

Delegates to Cell for automatic caching and device/dtype handling.

Returns:

Shape (3, 3) fractionalization matrix.

Return type:

torch.Tensor

property fractional_matrix: Tensor

Orthogonalization matrix B (fractional -> Cartesian).

Delegates to Cell for automatic caching and device/dtype handling.

Returns:

Shape (3, 3) orthogonalization matrix.

Return type:

torch.Tensor

property recB: Tensor

Reciprocal basis matrix with [a*, b*, c*] as rows.

Delegates to Cell for automatic caching and device/dtype handling.

Returns:

Shape (3, 3) matrix where rows are the reciprocal basis vectors.

Return type:

torch.Tensor
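The three matrix properties above are tied together by a few identities. The sketch below assumes the common PDB/CCP4 orthogonalization convention (a along x, b in the xy plane), which this page does not state for Cell. Under that convention the columns of B are the direct cell vectors, so the rows of B^-1 satisfy a*_i · a_j = δ_ij and B^-1 doubles as the reciprocal basis. The helper names are hypothetical, not torchref API.

```python
import math

import torch


def orthogonalization_matrix(a, b, c, alpha, beta, gamma):
    """Hypothetical helper: B (fractional -> Cartesian), lengths in Å, angles in degrees.

    Assumes the PDB/CCP4 convention (a along x, b in the xy plane);
    the columns of the returned matrix are the direct cell vectors.
    """
    ca, cb, cg = (math.cos(math.radians(x)) for x in (alpha, beta, gamma))
    sg = math.sin(math.radians(gamma))
    volume = a * b * c * math.sqrt(1 - ca**2 - cb**2 - cg**2 + 2 * ca * cb * cg)
    return torch.tensor(
        [
            [a, b * cg, c * cb],
            [0.0, b * sg, c * (ca - cb * cg) / sg],
            [0.0, 0.0, volume / (a * b * sg)],
        ],
        dtype=torch.float64,
    )


def d_spacing(hkl, rec_basis):
    """d = 1 / |h a* + k b* + l c*| with [a*, b*, c*] stored as the rows of rec_basis."""
    q = hkl.to(rec_basis.dtype) @ rec_basis
    return 1.0 / torch.linalg.norm(q, dim=-1)


B = orthogonalization_matrix(50.0, 60.0, 70.0, 90.0, 105.0, 90.0)
B_inv = torch.linalg.inv(B)   # plays the role of inv_fractional_matrix
rec_basis = B_inv             # rows are a*, b*, c* under this convention

xyz = torch.tensor([[12.0, -3.5, 40.2], [0.0, 0.0, 0.0]], dtype=torch.float64)
frac = xyz @ B_inv.T          # Cartesian -> fractional for (n_atoms, 3) row vectors
xyz_back = frac @ B.T         # fractional -> Cartesian
```

Because coordinates are stored as (n_atoms, 3) rows, the column-vector products B^-1 x and B f become right-multiplications by the transposes.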

property Z: Tensor

Atomic numbers for all atoms.

Returns:

Tensor of atomic numbers with shape (n_atoms,).

Return type:

torch.Tensor

get_P1_parameters_iso()

Get model parameters transformed to P1 space for optimization.

This is useful for optimizers that do not handle symmetry directly, and for molecular dynamics (MD).

Returns:

  • xyz_p1 (torch.Tensor) – Fractional coordinates expanded to P1 space.

  • adp_p1 (torch.Tensor) – Isotropic ADPs expanded to P1 space.

  • occupancy_p1 (torch.Tensor) – Occupancies expanded to P1 space.

  • A (torch.Tensor) – Scattering factor A coefficients expanded to P1 space.

  • B (torch.Tensor) – Scattering factor B coefficients expanded to P1 space.

Return type:

tuple[Tensor, Tensor, Tensor, Tensor, Tensor]
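P1 expansion applies every symmetry operation x' = R x + t to the fractional coordinates and copies the per-atom quantities, which are unchanged by symmetry for isotropic atoms. The sketch below assumes the operations of the gemmi.SpaceGroup have already been converted to float tensors (R, t) in fractional coordinates; that conversion and the function name are hypothetical and not shown on this page. Anisotropic U tensors would additionally need R U R^T, which is omitted.

```python
import torch


def expand_to_p1(xyz_frac, rotations, translations, *per_atom):
    """Hypothetical helper: apply every symmetry op (R, t) to fractional coordinates.

    xyz_frac: (n_atoms, 3); rotations: (n_ops, 3, 3); translations: (n_ops, 3).
    Extra per-atom tensors (isotropic ADP, occupancy, A, B) are tiled once per op.
    Output order is op-major: all atoms for op 0, then all atoms for op 1, ...
    """
    n_ops = rotations.shape[0]
    # x' = R x + t for every op and atom -> (n_ops, n_atoms, 3)
    xyz_p1 = torch.einsum("oij,nj->oni", rotations, xyz_frac) + translations[:, None, :]
    xyz_p1 = xyz_p1.reshape(-1, 3)
    # Repeat along the atom dimension only, keeping trailing dims (e.g. the 5 A/B terms)
    tiled = [p.repeat(n_ops, *([1] * (p.dim() - 1))) for p in per_atom]
    return (xyz_p1, *tiled)
```

For space group P-1, the two operations are the identity and the inversion (R = -I, t = 0), so every atom gains a partner at -x.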

get_MD_parameters()

Get model parameters prepared for molecular dynamics simulation.

Returns all P1-expanded parameters plus atomic numbers for MD engines.

Returns:

  • xyz_p1 (torch.Tensor) – Fractional coordinates expanded to P1 space.

  • adp_p1 (torch.Tensor) – Isotropic ADPs expanded to P1 space.

  • occupancy_p1 (torch.Tensor) – Occupancies expanded to P1 space.

  • A (torch.Tensor) – Scattering factor A coefficients expanded to P1 space.

  • B (torch.Tensor) – Scattering factor B coefficients expanded to P1 space.

  • Z_p1 (torch.Tensor) – Atomic numbers expanded to P1 space.

Return type:

tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]

property parametrization

ITC92 parametrization dictionary {element: (A, B)}.

The parametrization is built lazily on first access.

Returns:

Dictionary mapping element symbols to tuples of (A, B) tensors.

Return type:

dict

get_scattering_params_iso()

Get ITC92 scattering parameters (A, B) for isotropic atoms.

Returns:

  • A (torch.Tensor) – ITC92 A parameters (amplitudes) with shape (n_iso_atoms, 5).

  • B (torch.Tensor) – ITC92 B parameters (widths) with shape (n_iso_atoms, 5).

get_scattering_params_aniso()

Get ITC92 scattering parameters (A, B) for anisotropic atoms.

Returns:

  • A (torch.Tensor) – ITC92 A parameters (amplitudes) with shape (n_aniso_atoms, 5).

  • B (torch.Tensor) – ITC92 B parameters (widths) with shape (n_aniso_atoms, 5).
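The (A, B) coefficients enter the structure factor through a Gaussian-sum form factor f0(s) = Σ_k A_k exp(-B_k s²) with s = sin θ/λ. The direct-summation sketch below combines them with the output of get_iso() and an isotropic Debye-Waller term. It assumes the fifth coefficient pair encodes the ITC92 constant c (as a term with B = 0), which this page does not state, and it is an illustration of the physics, not torchref's structure factor code. Function names are hypothetical.

```python
import math

import torch


def form_factor(A, B, s2):
    """f0(s) = sum_k A_k exp(-B_k s^2). A, B: (n_atoms, 5); s2: (n_refl,). Returns (n_refl, n_atoms)."""
    return (A[None, :, :] * torch.exp(-B[None, :, :] * s2[:, None, None])).sum(-1)


def structure_factors_iso(hkl, xyz_frac, adp, occ, A, B, recB):
    """Direct sum f_calc(h) = sum_j occ_j f0_j(s) exp(-b_j s^2) exp(2*pi*i h.x_j)."""
    q = hkl.to(recB.dtype) @ recB                    # reciprocal-space vectors, |q| = 1/d
    s2 = (q * q).sum(-1) / 4.0                       # (sin(theta)/lambda)^2 = 1 / (4 d^2)
    f0 = form_factor(A, B, s2)                       # (n_refl, n_atoms)
    debye_waller = torch.exp(-adp[None, :] * s2[:, None])
    phase = 2 * math.pi * (hkl.to(xyz_frac.dtype) @ xyz_frac.T)  # (n_refl, n_atoms)
    amplitude = occ[None, :] * f0 * debye_waller
    # Complex structure factors (lowercase f_calc, per the naming conventions)
    return torch.complex(amplitude * torch.cos(phase), amplitude * torch.sin(phase)).sum(-1)
```

With two identical atoms half a cell apart along x, the (1, 0, 0) contributions cancel and the (2, 0, 0) contributions add, which is a quick sanity check on the phase term.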

set_restraints_cif(cif_path)

Set CIF path for lazy restraint building.

Parameters:

cif_path (str or list of str) – Path(s) to CIF restraints dictionary file(s).

Returns:

Self, for method chaining.

property restraints

Lazy restraints property.

The restraints are built on first access using the model’s pdb DataFrame and the CIF path set via set_restraints_cif().

Returns:

The restraints object containing bond, angle, torsion, etc. restraints.

Return type:

RestraintsNew

bond_deviations()

Compute bond length deviations using current xyz coordinates.

Returns:

  • deviations (torch.Tensor) – Calculated minus expected bond lengths in Angstroms.

  • sigmas (torch.Tensor) – Standard deviations from CIF library in Angstroms.
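Bond deviations and their sigmas are typically combined into a weighted least-squares restraint, the sum of squared z-scores (d/σ)². The sketch below is a hypothetical standalone version that assumes bonds are given as an (n_bonds, 2) index tensor with matching ideal lengths and sigmas; how torchref stores its restraints internally is not shown on this page.

```python
import torch


def bond_deviations(xyz, bond_idx, ideal, sigma):
    """Hypothetical sketch: calculated minus expected bond lengths.

    xyz: (n_atoms, 3) Å; bond_idx: (n_bonds, 2) long; ideal, sigma: (n_bonds,) Å.
    """
    vec = xyz[bond_idx[:, 0]] - xyz[bond_idx[:, 1]]
    length = torch.linalg.norm(vec, dim=-1)
    return length - ideal, sigma


def bond_loss(deviations, sigmas):
    """Weighted least-squares restraint: sum of squared z-scores."""
    return ((deviations / sigmas) ** 2).sum()
```

The RMS bond deviation usually reported alongside refinement statistics is deviations.pow(2).mean().sqrt().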

angle_deviations()

Compute angle deviations using current xyz coordinates.

Returns:

  • deviations (torch.Tensor) – Calculated minus expected angles in radians.

  • sigmas (torch.Tensor) – Standard deviations in radians.

torsion_deviations_with_sigmas()

Compute torsion deviations (wrapped for periodicity) and sigmas.

Returns:

  • deviations_rad (torch.Tensor) – Wrapped deviations in radians.

  • sigmas_deg (torch.Tensor) – Standard deviations in degrees (for von Mises NLL).
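Torsion deviations must be wrapped because angles are periodic, and the returned sigmas (in degrees) feed a von Mises likelihood. The sketch below wraps into the interval of an n-fold torsion and uses the concentration κ = 1/σ_rad², a common approximation; the exact mapping from sigma to κ used by torchref is not stated here and is an assumption, as are the function names. log I0(κ) is computed as log(i0e(κ)) + κ so that narrow sigmas do not overflow.

```python
import math

import torch


def wrap_torsion(delta, period=1):
    """Wrap torsion deviations (radians) into [-pi/period, pi/period)."""
    span = 2 * math.pi / period
    return torch.remainder(delta + span / 2, span) - span / 2


def von_mises_nll(delta_rad, sigma_deg):
    """Von Mises NLL with kappa = 1 / sigma_rad^2 (assumed sigma-to-kappa mapping).

    -kappa*cos(d) + log(2*pi*I0(kappa)) rewritten as
    kappa*(1 - cos(d)) + log(2*pi) + log(i0e(kappa)) to stay finite for large kappa.
    """
    sigma_rad = torch.deg2rad(sigma_deg)
    kappa = 1.0 / sigma_rad**2
    return kappa * (1 - torch.cos(delta_rad)) + math.log(2 * math.pi) + torch.log(torch.special.i0e(kappa))
```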

load(reader)

load_pdb(file)

Load atomic model from PDB file.

Parameters:

file (str) – Path to PDB file.

Returns:

Self, for method chaining.

Return type:

Model

load_cif(file)

Load atomic model from mmCIF file.

Parameters:

file (str) – Path to CIF/mmCIF file.

Returns:

Self, for method chaining.

Return type:

Model

property chain_sequences: List[Tuple[str, str]]

Per-chain amino acid sequences as single-letter codes.

Excludes HETATM records. Gaps in residue numbering are filled with "?", and non-standard residues are mapped to "X".

Returns:

Ordered list of (chain_id, sequence_string). E.g. [("A", "MKVL??GAST"), ("B", "ACDEFG")].

Return type:

list of (str, str)
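The gap-filling rule can be illustrated with a small standalone function. This is a hypothetical sketch, not torchref's implementation: it takes one (resseq, resname) entry per residue of a single chain in file order, ignores insertion codes, and assumes only the 20 standard amino acids map to letters.

```python
from typing import List, Tuple

THREE_TO_ONE = {
    "ALA": "A", "ARG": "R", "ASN": "N", "ASP": "D", "CYS": "C",
    "GLN": "Q", "GLU": "E", "GLY": "G", "HIS": "H", "ILE": "I",
    "LEU": "L", "LYS": "K", "MET": "M", "PHE": "F", "PRO": "P",
    "SER": "S", "THR": "T", "TRP": "W", "TYR": "Y", "VAL": "V",
}


def chain_sequence(residues: List[Tuple[int, str]]) -> str:
    """Hypothetical sketch: one-letter sequence with '?' for numbering gaps, 'X' for non-standard."""
    seq = []
    prev = None
    for resseq, resname in residues:
        if prev is not None and resseq > prev + 1:
            seq.append("?" * (resseq - prev - 1))   # one '?' per missing residue number
        seq.append(THREE_TO_ONE.get(resname, "X"))  # e.g. MSE -> X
        prev = resseq
    return "".join(seq)
```

Residues 1-4 followed by 7-10 reproduce the "MKVL??GAST" pattern from the example above.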

get_chain_residues()

Per-chain residue names as 3-letter codes (for IHM/CIF writing).

Excludes HETATM records. Unlike chain_sequences, returns the raw 3-letter codes without gap filling.

Returns:

Ordered list of (chain_id, [resname, ...]).

Return type:

list of (str, list of str)

update_pdb()

get_vdw_radii()

Get van der Waals radii for all atoms based on their elements.

Caches the result in self.vdw_radii for future calls.

Returns:

Van der Waals radii for each atom with shape (n_atoms,).

Return type:

torch.Tensor

to(*args, **kwargs)

Move Model and rebuild device-specific SF indices.

Delegates to DeviceMixin, which walks self.__dict__ (picking up self.cell, self.altloc_pairs, self._restraints and all registered parameters / buffers), refreshes the self.device tracker, and invalidates caches. Afterwards this override rebuilds the precomputed SF indices on the new device.

copy()

Create a deep copy of the Model.

Creates a complete independent copy including all registered buffers, module parameters, PDB DataFrame, and spacegroup information.

Returns:

A new Model instance with copied data.

Return type:

Model

Examples

model = Model().load_pdb('structure.pdb')
model_copy = model.copy()
# model_copy is independent, changes won't affect model

write_pdb(filename, metadata=None)

Write model to PDB file with optional metadata header.

Parameters:
  • filename (str) – Output PDB file path.

  • metadata (RefinementMetadata, optional) – Metadata to render as PDB header (REMARK 3, TITLE, etc.).

write_cif(filename, metadata=None)

Write model to mmCIF file with optional metadata.

Parameters:
  • filename (str) – Output mmCIF file path.

  • metadata (RefinementMetadata, optional) – Metadata to include (refinement statistics, title, etc.).

get_iso()

Return per-atom parameters for the isotropic atom subset.

Selects atoms whose ADP is a single scalar b (i.e., not anisotropic). The subset is defined by ~self.aniso_flag (intersected with self._heavy_atom_mask when _exclude_H_from_sf is enabled) and is precomputed as self._iso_indices at init and whenever the mask changes.

Returns:

  • xyz (torch.Tensor, shape (n_iso, 3)) – Cartesian coordinates of the isotropic atoms (Å).

  • adp (torch.Tensor, shape (n_iso,)) – Isotropic B-factors (Ų).

  • occupancy (torch.Tensor, shape (n_iso,)) – Occupancies in [0, 1].

Notes

When every atom is isotropic and no H exclusion is active (self._iso_covers_all is True, the common protein-refinement case), the per-atom indexing is skipped and self.xyz(), self.adp(), and self.occupancy() are returned directly.

Motivation: self.xyz()[idx] is a no-op forward when idx = arange(N), but its backward routes through PyTorch’s aten::_index_put_impl_(accumulate=True), which performs a cub::DeviceRadixSortOnesweepKernel over len(idx) indices followed by a deduplicated scatter (~50-150 µs/iter per gather on A100 / 1DAW). Skipping the gather avoids that cost.
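The branching described above can be sketched in a few lines. The function names and argument layout are hypothetical stand-ins for the precomputed self._iso_indices and self._iso_covers_all; the timing figures quoted above come from the source, and this sketch only shows the control flow that avoids the gather.

```python
import torch


def iso_subset(aniso_flag, heavy_atom_mask=None):
    """Hypothetical sketch: precompute iso indices and the 'covers all atoms' flag."""
    mask = ~aniso_flag
    if heavy_atom_mask is not None:      # H atoms excluded from structure factors
        mask = mask & heavy_atom_mask
    indices = torch.nonzero(mask, as_tuple=True)[0]
    covers_all = indices.numel() == aniso_flag.numel()
    return indices, covers_all


def gather_iso(xyz, adp, occupancy, indices, covers_all):
    if covers_all:
        # Fast path: return the full tensors; no gather, so no index_put_ in the backward
        return xyz, adp, occupancy
    return xyz[indices], adp[indices], occupancy[indices]
```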

set_default_masks()

PARAM_TYPES: Tuple[str, ...] = ('xyz', 'adp', 'u', 'occupancy')

parameters_of_types(types)

Return the leaf nn.Parameter objects for the named parameter types.
Return the leaf ``nn.Parameter``s for the named parameter types.

Used by refinement entry points (refine_xyz, refine_adp, …) to construct an optimizer over only the leaves the caller intends to update. LossState.step then treats the optimizer’s param groups as the statement of intent and disables requires_grad on any other leaves the loss also touches.

Parameters:

types (Iterable[str]) – Subset of Model.PARAM_TYPES: "xyz", "adp", "u", "occupancy". Unknown names are silently skipped.

Returns:

The refinable_params leaf for each requested type, in the order the types were given.

Return type:

list of nn.Parameter
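The "optimizer as intent" pattern can be sketched with plain torch.optim objects. The helper below is hypothetical and only mirrors the behaviour described for LossState.step; the two nn.Parameter objects stand in for what parameters_of_types(["xyz"]) and parameters_of_types(["adp"]) might return.

```python
import torch
from torch import nn


def freeze_non_optimized(optimizer, leaves):
    """Hypothetical helper: disable grads on leaves the optimizer does not update."""
    intended = {id(p) for group in optimizer.param_groups for p in group["params"]}
    for p in leaves:
        if id(p) not in intended:
            p.requires_grad_(False)   # autograd will skip this leaf entirely


# Stand-ins for refinable leaves of two parameter types
xyz = nn.Parameter(torch.randn(10, 3))
adp = nn.Parameter(torch.full((10,), 20.0))
optimizer = torch.optim.Adam([xyz], lr=1e-2)   # intent: refine coordinates only
freeze_non_optimized(optimizer, [xyz, adp])
```

Disabling requires_grad, rather than merely leaving the leaf out of the optimizer, also saves the backward computation for that leaf.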

freeze(target)

freeze_all()

unfreeze_all()

unfreeze(target)

update_mask_from_selection(selection_string, target, mode='set', freeze=True)

Update the refinable mask for a parameter using Phenix-style selection syntax.

This method updates the internal mask buffer (xyz_mask, adp_mask, u_mask, or occupancy_mask) based on the selection. The updated mask is NOT automatically applied to the parameter tensors - use apply_mask_to_parameter() to apply it.

Parameters:
  • selection_string (str) – Phenix-style selection string (see parse_phenix_selection docs).

  • target (str) – Parameter to update: ‘xyz’, ‘adp’, ‘u’, or ‘occupancy’.

  • mode (str, optional) – How to combine the selection with the current mask: ‘set’ replaces the mask with the selection (default), ‘add’ adds the selection to the current mask, and ‘remove’ removes the selection from the current mask.

  • freeze (bool, optional) – If True (default), selected atoms will be frozen (mask=False). If False, selected atoms will be unfrozen (mask=True).

Raises:

ValueError – If target is not recognized or selection syntax is invalid.

Examples

# Freeze chain A coordinates
model.update_mask_from_selection("chain A", "xyz", mode='set', freeze=True)
model.apply_mask_to_parameter("xyz")

# Unfreeze backbone atoms
model.update_mask_from_selection("name CA or name C or name N", "xyz", freeze=False)
model.apply_mask_to_parameter("xyz")

apply_mask_to_parameter(target)

Apply the current mask buffer to the parameter tensor.

Takes the current state of the mask buffer (xyz_mask, adp_mask, etc.) and applies it to the corresponding parameter tensor’s refinable mask.

Parameters:

target (str) – Parameter to update: ‘xyz’, ‘adp’, ‘u’, or ‘occupancy’.

Raises:

ValueError – If target is not recognized.

Examples

model.update_mask_from_selection("chain A", "xyz", freeze=True)
model.apply_mask_to_parameter("xyz")

freeze_selection(selection_string, targets='all')

Freeze atoms matching a Phenix-style selection for specified parameters.

Convenience method that combines update_mask_from_selection() and apply_mask_to_parameter() into a single call.

Parameters:
  • selection_string (str) – Phenix-style selection string.

  • targets (str or list of str, optional) – Parameter(s) to freeze: ‘all’ freezes xyz, adp, u, and occupancy (default); a str names a single parameter (‘xyz’, ‘adp’, ‘u’, ‘occupancy’); a list names several, e.g., [‘xyz’, ‘adp’].

Examples

# Freeze all parameters for chain A
model.freeze_selection("chain A", targets='all')

# Freeze only coordinates for residues 10-20
model.freeze_selection("resseq 10:20", targets='xyz')

unfreeze_selection(selection_string, targets='all')

Unfreeze atoms matching a Phenix-style selection for specified parameters.

Convenience method that combines update_mask_from_selection() and apply_mask_to_parameter() into a single call.

Parameters:
  • selection_string (str) – Phenix-style selection string.

  • targets (str or list of str, optional) – Parameter(s) to unfreeze: ‘all’ unfreezes xyz, adp, u, and occupancy (default); a str names a single parameter (‘xyz’, ‘adp’, ‘u’, ‘occupancy’); a list names several, e.g., [‘xyz’, ‘adp’].

Examples

# Unfreeze all parameters for chain A
model.unfreeze_selection("chain A", targets='all')

# Unfreeze only coordinates for backbone atoms
model.unfreeze_selection("name CA or name C or name N", targets='xyz')

get_aniso()

Return per-atom parameters for the anisotropic atom subset.

Selects atoms whose ADP is the 6-element anisotropic tensor u = (u11, u22, u33, u12, u13, u23). The subset is defined by self.aniso_flag (intersected with self._heavy_atom_mask when _exclude_H_from_sf is enabled) and is precomputed as self._aniso_indices at init and whenever the mask changes.

Returns:

  • xyz (torch.Tensor, shape (n_aniso, 3)) – Cartesian coordinates of the anisotropic atoms (Å). Empty tensor when there are no anisotropic atoms.

  • u (torch.Tensor, shape (n_aniso, 6)) – Anisotropic U components (Ų) in the order (u11, u22, u33, u12, u13, u23). Empty when n_aniso == 0.

  • occupancy (torch.Tensor, shape (n_aniso,)) – Occupancies in [0, 1]. Empty when n_aniso == 0.

Notes

When there are no anisotropic atoms (self._aniso_is_empty is True, the common protein-refinement case), three empty placeholder tensors are returned without calling the MixedTensors at all. This avoids both the .clone() in the wrapped forward and the slow aten::_index_put_impl_ backward path that the self.xyz()[idx] gather would otherwise generate (see get_iso() for the same rationale).

parameters(recurse=True)

Return an iterator over module parameters.

This is typically passed to an optimizer.

Args:

recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

Yields:

Parameter: module parameter

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for param in model.parameters():
...     print(type(param), param.size())
<class 'torch.nn.parameter.Parameter'> torch.Size([20])
<class 'torch.nn.parameter.Parameter'> torch.Size([20, 1, 5, 5])

named_mixed_tensors()

Iterate over all MixedTensor attributes with their names.

Yields:

Tuple of (name, MixedTensor)

print_parameters_info()

Print information about all MixedTensor parameters.

register_alternative_conformations()

Identify and register all alternative conformation groups in the structure.

For each residue that has alternative conformations (altloc A, B, C, etc.), this method identifies all atoms belonging to each conformation and stores their indices as tensors in a tuple.

The result is stored in self.altloc_pairs as a list of tuples, where each tuple contains tensors of atom indices for each alternative conformation.

Examples

For a residue with conformations A and B:

# Conformation A has atoms at indices [100, 101, 102, ...]
# Conformation B has atoms at indices [110, 111, 112, ...]
# Result: [(tensor([100, 101, 102, ...]), tensor([110, 111, 112, ...])), ...]

For a residue with conformations A, B, C:

# Result: [(tensor([200, 201, ...]), tensor([210, 211, ...]), tensor([220, 221, ...])), ...]
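The grouping shown in the examples can be reproduced from an atom table with a pandas groupby. This is a hypothetical sketch that assumes columns named 'chainid', 'resseq', and 'altloc' (with "" meaning no altloc); torchref's actual column names and its handling of residues with only one labelled conformer may differ.

```python
import pandas as pd
import torch


def altloc_groups(df):
    """Hypothetical sketch. Assumes columns 'chainid', 'resseq', 'altloc' ('' = no altloc)."""
    df = df.reset_index(drop=True)          # index == atom position
    with_alt = df[df["altloc"] != ""]
    groups = []
    for _, residue in with_alt.groupby(["chainid", "resseq"], sort=False):
        confs = sorted(residue["altloc"].unique())
        if len(confs) < 2:
            continue                        # a single labelled conformer has no partner
        groups.append(
            tuple(
                torch.as_tensor(residue.index[(residue["altloc"] == conf).to_numpy()].to_numpy())
                for conf in confs
            )
        )
    return groups
```

Atoms without an altloc label (for example a shared backbone) are left out of the groups.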

shake_coords(stddev)

Apply random Gaussian noise to atomic coordinates.

Perturbs the atomic coordinates by adding Gaussian noise with a specified standard deviation. The noise is applied to all atoms.

Parameters:

stddev (float) – Standard deviation of the Gaussian noise to be added, in Angstroms.
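A useful consequence of per-component Gaussian noise: if x, y, and z each receive independent N(0, stddev²) noise (implied but not stated above), the RMS atomic displacement is √3 · stddev, not stddev. The standalone function below is a hypothetical sketch, not the method itself, and checks that relationship numerically.

```python
import math

import torch


def add_coordinate_noise(xyz, stddev, generator=None):
    """Hypothetical sketch: add N(0, stddev^2) noise (Å) independently to x, y and z."""
    noise = torch.randn(xyz.shape, generator=generator, dtype=xyz.dtype, device=xyz.device)
    return xyz + stddev * noise


gen = torch.Generator().manual_seed(0)
xyz = torch.zeros(100_000, 3, dtype=torch.float64)
shaken = add_coordinate_noise(xyz, 0.3, generator=gen)
rms_displacement = shaken.sub(xyz).pow(2).sum(dim=1).mean().sqrt()
# The three components add in quadrature, so the expected RMS displacement is sqrt(3) * stddev
expected = math.sqrt(3) * 0.3
```

So a "0.3 Å shake" moves atoms by roughly 0.52 Å RMS.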

shake_adp(stddev)

Apply random Gaussian noise to ADPs (atomic displacement parameters).

Perturbs the ADPs by adding Gaussian noise with a specified standard deviation. The noise is applied to all atoms.

Parameters:

stddev (float) – Standard deviation of the Gaussian noise to be added, in Angstrom^2.

generate_hydrogens(mon_lib_path=None)

Generate hydrogen atoms for the current model using gemmi.

Places hydrogens at ideal geometry using the CCP4 monomer library and gemmi’s topology engine. Returns a new Model instance with hydrogens added; the original model is not modified.

Parameters:

mon_lib_path (str, optional) – Path to CCP4 monomer library directory. If None, uses the monomer library bundled with torchref (covers standard amino acids and common small molecules).

Returns:

A new Model instance with hydrogen atoms added (strip_H=False). Unknown residues are skipped silently.

Return type:

Model

Notes

Requires gemmi (already a torchref dependency). Heavy-atom coordinates from the current model state are used, so call this after any coordinate changes you want reflected in the H positions.

Examples

>>> model_no_h = Model().load_pdb('structure.pdb')
>>> model_with_h = model_no_h.generate_hydrogens()
>>> print(model_with_h.Z.shape)   # more atoms than model_no_h

strip_altlocs()

Return a new model with alternate conformations removed.

For each residue that has multiple altlocs, the conformer with highest average occupancy is kept (ties broken alphabetically). The altloc column is cleared to "" in the returned model. The original model is not modified.
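The selection rule (highest mean occupancy, ties broken alphabetically) can be expressed as a sort followed by drop_duplicates. This pandas sketch is hypothetical: it assumes columns 'chainid', 'resseq', 'altloc', and 'occupancy', identifies residues by chain and number only, and leaves occupancies unchanged because the text does not say whether they are reset.

```python
import numpy as np
import pandas as pd


def keep_best_altloc(df):
    """Hypothetical sketch. Assumes columns 'chainid', 'resseq', 'altloc', 'occupancy'."""
    df = df.reset_index(drop=True)
    has_alt = (df["altloc"] != "").to_numpy()
    mean_occ = (
        df[has_alt]
        .groupby(["chainid", "resseq", "altloc"], as_index=False)["occupancy"]
        .mean()
    )
    # Highest mean occupancy first; ties resolved alphabetically by altloc label
    best = mean_occ.sort_values(
        ["chainid", "resseq", "occupancy", "altloc"],
        ascending=[True, True, False, True],
    ).drop_duplicates(["chainid", "resseq"])
    keep = set(zip(best["chainid"], best["resseq"], best["altloc"]))
    in_best = np.array(
        [key in keep for key in zip(df["chainid"], df["resseq"], df["altloc"])], dtype=bool
    )
    out = df[~has_alt | in_best].copy()
    out["altloc"] = ""                      # altloc column cleared in the result
    return out.reset_index(drop=True)
```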

strip_hydrogens()

Return a new model with hydrogen atoms removed.

The returned model has consistent DataFrame and tensors (xyz, adp, occupancy) with H atoms excluded. The original model is not modified.

Returns:

New model without hydrogen atoms.

Return type:

Model

hydrogenate(verbose=0, optimize=False, lbfgs_steps=3, max_iter=20)

Return a new model with hydrogen atoms placed via Kabsch alignment.

Uses torchref’s monomer library to identify missing H atoms, places them by SVD-aligning ideal monomer coordinates onto the current model coordinates, then corrects each H to sit at ideal bond length from its parent atom. The original model is not modified.

Parameters:
  • verbose (int, optional) – Verbosity level (0=silent, 1=summary, 2=detailed). Default 0.

  • optimize (bool, optional) – If True, run a short LBFGS geometry optimization on H positions after placement. Default False (Kabsch placement only).

  • lbfgs_steps (int, optional) – Number of LBFGS outer steps (only when optimize=True). Default 3.

  • max_iter (int, optional) – Max line-search iterations per LBFGS step. Default 20.

Returns:

New model with hydrogen atoms added. All parameters are unfrozen in the returned model.

Return type:

Model
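The two placement steps described for hydrogenate (SVD-based Kabsch superposition, then a bond-length correction) can be sketched as follows. Which heavy atoms are used in the fit and how parent atoms are chosen are assumptions; the function names are hypothetical and this is not torchref's code.

```python
import math

import torch


def kabsch(mobile, target):
    """R, t minimizing sum ||mobile_i @ R.T + t - target_i||^2 (paired (n, 3) points, n >= 3)."""
    mu_m, mu_t = mobile.mean(dim=0), target.mean(dim=0)
    H = (mobile - mu_m).T @ (target - mu_t)          # 3x3 cross-covariance
    U, _, Vh = torch.linalg.svd(H)
    # Flip the last axis if needed so R is a proper rotation (det = +1), not a reflection
    d = 1.0 if torch.linalg.det(Vh.T @ U.T) >= 0 else -1.0
    D = torch.diag(torch.tensor([1.0, 1.0, d], dtype=H.dtype, device=H.device))
    R = Vh.T @ D @ U.T
    t = mu_t - mu_m @ R.T
    return R, t


def place_hydrogens(ideal_heavy, ideal_h, model_heavy, parent_idx, bond_length):
    """Hypothetical sketch: superpose ideal H atoms, then set each X-H distance exactly.

    ideal_heavy / model_heavy: (n_heavy, 3) matched heavy atoms; ideal_h: (n_h, 3);
    parent_idx: (n_h,) parent index into model_heavy; bond_length: (n_h,) in Å.
    """
    R, t = kabsch(ideal_heavy, model_heavy)
    h = ideal_h @ R.T + t
    parent = model_heavy[parent_idx]
    direction = h - parent
    direction = direction / torch.linalg.norm(direction, dim=-1, keepdim=True)
    return parent + bond_length[:, None] * direction
```

The final correction matters because a rigid superposition onto slightly distorted model geometry generally leaves X-H distances off their ideal values.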

adp_loss()

Compute the ADP regularization loss.

This loss encourages ADPs to have similar values across the structure, helping to prevent overfitting during refinement.

Returns:

Scalar tensor representing the ADP loss.

Return type:

torch.Tensor

adp_nll_loss(target_log_std=0.2)

Compute negative log-likelihood of ADPs assuming Gaussian distribution in log-space.

This regularization penalizes ADPs that deviate from a target distribution with a FIXED standard deviation (hyperparameter), avoiding circular dependency on the current distribution’s statistics.

The NLL for a Gaussian distribution in log-space is:

NLL = 0.5 * mean[(log_adp - mu)^2 / sigma^2 + log(2*pi*sigma^2)]

Where mu is the mean of log-space ADPs (computed from current data) and sigma is the FIXED target standard deviation (hyperparameter).

Parameters:

target_log_std (float, optional) – Target standard deviation in log-space. Default is 0.2. Typical values:

  • 0.1: very tight (ADPs within ~10% of the mean)

  • 0.2: moderate spread (ADPs within ~20% of the mean) [RECOMMENDED]

  • 0.3: looser spread (ADPs within ~30% of the mean)

Returns:

Scalar tensor representing the NLL. Lower values indicate the distribution is closer to the target Gaussian with fixed sigma.

Return type:

torch.Tensor

Examples

# During refinement
structure_factor_loss = compute_structure_factor_loss()
nll_reg = model.adp_nll_loss(target_log_std=0.2)
total_loss = structure_factor_loss + 0.01 * nll_reg
total_loss.backward()

Notes

Uses FIXED sigma (no circular dependency on current distribution). Smaller target_log_std = stronger regularization (tighter distribution).
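The NLL formula above translates directly into torch. This is a hypothetical standalone sketch; whether torchref detaches mu for this loss is not stated (the KL variant below says its mean is detached), so here mu is left attached. The per-atom version covers adp_nll_loss_per_atom, and the scalar loss is its mean.

```python
import math

import torch


def adp_nll_per_atom(adp, target_log_std=0.2):
    """NLL_i = 0.5 * [(log_adp_i - mu)^2 / sigma^2 + log(2*pi*sigma^2)] with fixed sigma."""
    log_adp = torch.log(adp)
    mu = log_adp.mean()           # mean of the current log-ADPs
    var = target_log_std**2       # FIXED hyperparameter, not estimated from the data
    return 0.5 * ((log_adp - mu) ** 2 / var + math.log(2 * math.pi * var))


def adp_nll(adp, target_log_std=0.2):
    return adp_nll_per_atom(adp, target_log_std).mean()
```

When all ADPs are equal, the loss reduces to its constant floor 0.5 · log(2πσ²), which is negative for σ = 0.2; only differences in the loss are meaningful.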

adp_nll_loss_per_atom(target_log_std=0.2)

Compute per-atom negative log-likelihood for ADPs in log-space.

Returns the NLL contribution for each individual atom, useful for identifying outliers or applying atom-specific regularization weights.

The per-atom NLL is:

NLL_i = 0.5 * [(log_adp_i - mu)^2 / sigma^2 + log(2*pi*sigma^2)]

Parameters:

target_log_std (float, optional) – Fixed target standard deviation in log-space. Default is 0.2.

Returns:

Tensor of shape (n_atoms,) with per-atom NLL values. Higher values indicate atoms farther from the mean.

Return type:

torch.Tensor

Examples

# Get per-atom NLL
atom_nll = model.adp_nll_loss_per_atom(target_log_std=0.2)
# Identify outlier atoms (high NLL)
threshold = atom_nll.mean() + 2 * atom_nll.std()
outliers = atom_nll > threshold

adp_kl_divergence_loss(target_log_std=0.2)

Compute KL divergence between log ADP distribution and target Gaussian.

Measures how different the current log ADP distribution is from a target Gaussian distribution with the current mean of log ADPs and a fixed target standard deviation.

KL divergence formula for two Gaussians with same mean:

KL(q || p) = log(sigma_target/sigma_data) + sigma_data^2 / (2*sigma_target^2) - 0.5

Parameters:

target_log_std (float, optional) – Target standard deviation in log-space. Default is 0.2. Controls how tightly ADPs should cluster.

Returns:

Scalar KL divergence value (always >= 0). 0 means distributions match perfectly. Higher values mean more deviation from target.

Return type:

torch.Tensor

Examples

# Use in loss function
loss = xray_loss + w_adp * model.adp_kl_divergence_loss(0.2)

Notes

Lower target_log_std = stronger regularization (tighter distribution). Mean is detached so it adapts to the natural scale of the data.
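A direct transcription of the KL formula, with the mean detached as described. The use of the population standard deviation (dividing by N) and the variance floor eps are assumptions of this hypothetical sketch; with the floor, identical ADPs give a large finite value instead of infinity.

```python
import math

import torch


def adp_kl(adp, target_log_std=0.2, eps=1e-6):
    """KL(q || p) for N(mu, sigma_data^2) vs N(mu, sigma_target^2) in log-ADP space."""
    log_adp = torch.log(adp)
    centered = log_adp - log_adp.mean().detach()        # mean detached
    var = centered.pow(2).mean().clamp_min(eps**2)      # population variance, floored
    sigma_data = var.sqrt()
    return (
        math.log(target_log_std)
        - torch.log(sigma_data)
        + var / (2 * target_log_std**2)
        - 0.5
    )
```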

state_dict(destination=None, prefix='', keep_vars=False)

Return a dictionary containing the complete state of the Model.

Includes all registered buffers, model parameters (xyz, adp, u, occupancy), the PDB DataFrame, and metadata (spacegroup, device, dtype, etc.).

Parameters:
  • destination (dict, optional) – Optional dict to populate with state.

  • prefix (str, optional) – Prefix for parameter names. Default is ‘’.

  • keep_vars (bool, optional) – If True, returned tensors are not detached from autograd. Default is False.

Returns:

Complete state dictionary.

Return type:

dict

save_state(path)

Save the complete state of the model to a file.

Parameters:

path (str) – Path to save the state dictionary to.

load_state(path, strict=True)

Load the complete state of the model from a file.

Parameters:
  • path (str) – Path to load the state dictionary from.

  • strict (bool, optional) – Whether to strictly enforce that keys match. Default is True.

classmethod create_from_state_dict(state_dict, device=device(type='cpu'), verbose=1, dtype_float=torch.float32)

Create a fully initialized Model from a state dictionary.

This is the recommended way to restore a Model from a saved state. Creates an instance with properly initialized submodules, then loads the state.

Parameters:
  • state_dict (dict) – State dictionary from torch.save(model.state_dict(), …).

  • device (torch.device, optional) – Device to place tensors on. Defaults to the configured device.current.

  • verbose (int, optional) – Verbosity level. Default is 1.

  • dtype_float (torch.dtype, optional) – Float dtype for tensors. Defaults to the configured dtypes.float.

Returns:

Fully initialized instance with restored state.

Return type:

Model

get_selection_mask(selection)

Return a boolean mask for atoms matching a Phenix-style selection.

This is a convenience method that wraps parse_phenix_selection() to return a mask that can be used directly with MixedTensor.set() or other operations requiring atom selection.

Parameters:

selection (str) – Phenix-style selection string. Supports:

    • chain <id>: Select by chain (e.g., “chain A”)
    • resseq <num>: Select by residue number (e.g., “resseq 10”)
    • resseq <start>:<end>: Select residue range (e.g., “resseq 10:20”)
    • resname <name>: Select by residue name (e.g., “resname ALA”)
    • name <atom>: Select by atom name (e.g., “name CA”)
    • element <elem>: Select by element (e.g., “element C”)
    • altloc <id>: Select by alternate location (e.g., “altloc A”)
    • all: Select all atoms
    • not <selection>: Negate selection
    • <sel1> and <sel2>: Intersection
    • <sel1> or <sel2>: Union
    • Parentheses for grouping

Returns:

Boolean tensor of shape (n_atoms,) where True indicates selected atoms.

Return type:

torch.Tensor

Examples

import torch
model = Model().load_pdb('structure.pdb')
# Get mask for chain A
mask = model.get_selection_mask("chain A")
# Use mask to shift the selected atoms by 1 Å along x
translation = torch.tensor([1.0, 0.0, 0.0])
new_coords = model.xyz()[mask] + translation
model.xyz.set(new_coords, mask)
# Get mask for backbone atoms
backbone_mask = model.get_selection_mask("name CA or name C or name N or name O")
# Complex selection with parentheses
mask = model.get_selection_mask("chain A and (resname ALA or resname GLY)")
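Each selection keyword amounts to a column comparison, and and/or/not map onto the boolean operators &, |, and ~. The pandas sketch below illustrates those semantics on a hypothetical atom table; it is not how parse_phenix_selection is implemented, the column names are assumptions, and the resseq range is treated as inclusive (also an assumption).

```python
import pandas as pd
import torch

# Hypothetical atom table; torchref's real column names may differ
atoms = pd.DataFrame({
    "chainid": ["A", "A", "A", "B", "A"],
    "resseq": [10, 11, 25, 12, 30],
    "resname": ["ALA", "GLY", "SER", "ALA", "HOH"],
    "name": ["CA", "CA", "CA", "CA", "O"],
})

# "chain A and (resname ALA or resname GLY)"
sel = (atoms["chainid"] == "A") & ((atoms["resname"] == "ALA") | (atoms["resname"] == "GLY"))
# "chain A and resseq 10:20"
sel_range = (atoms["chainid"] == "A") & atoms["resseq"].between(10, 20)
# "not resname HOH"
no_water = ~(atoms["resname"] == "HOH")

mask = torch.from_numpy(sel.to_numpy())   # boolean tensor of shape (n_atoms,)
```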

select(selection)

Return a new Model containing only atoms matching the Phenix-style selection.

Creates an independent copy of the model containing only the selected atoms. All tensor data (coordinates, ADPs, occupancies, etc.) and metadata are properly subsetted.

Parameters:

selection (str) – Phenix-style selection string. Supports:

    • chain <id>: Select by chain (e.g., “chain A”)
    • resseq <num>: Select by residue number (e.g., “resseq 10”)
    • resseq <start>:<end>: Select residue range (e.g., “resseq 10:20”)
    • resname <name>: Select by residue name (e.g., “resname ALA”)
    • name <atom>: Select by atom name (e.g., “name CA”)
    • element <elem>: Select by element (e.g., “element C”)
    • altloc <id>: Select by alternate location (e.g., “altloc A”)
    • all: Select all atoms
    • not <selection>: Negate selection
    • <sel1> and <sel2>: Intersection
    • <sel1> or <sel2>: Union
    • Parentheses for grouping

Returns:

New instance of the same class containing only selected atoms. If called on a subclass, returns an instance of that subclass.

Return type:

Model

Raises:
  • RuntimeError – If the model has not been initialized.

  • ValueError – If selection syntax is invalid or no atoms are selected.

Examples

model = Model().load_pdb('structure.pdb')
# Select chain A
chain_a = model.select("chain A")
# Select backbone atoms
backbone = model.select("name CA or name C or name N or name O")
# Select residues 10-50 of chain B
region = model.select("chain B and resseq 10:50")
# Select all except water
no_water = model.select("not resname HOH")
# Complex selection with parentheses
complex_sel = model.select("chain A and (resname ALA or resname GLY)")

Notes

This method preserves the class type, so subclasses will return instances of themselves, not the base Model class.

xyz_fractional()

Return atomic coordinates in fractional space.

Converts Cartesian coordinates to fractional coordinates using the inverse fractional matrix.

Returns:

Tensor of shape (n_atoms, 3) with fractional coordinates.

Return type:

torch.Tensor

rotate(rotation_matrix, center=None)

Apply rotation to atomic coordinates (in-place).

Rotates all atoms around a specified center point. The rotation is applied using the formula: xyz_new = R @ (xyz - center) + center

Parameters:
  • rotation_matrix (torch.Tensor) – 3x3 rotation matrix. Should be a proper rotation: orthogonal (R^T @ R = I) with determinant +1.

  • center (torch.Tensor, optional) – Center of rotation with shape (3,). If None, uses the centroid of all atomic coordinates.

Returns:

Self, for method chaining.

Return type:

Model

Examples

# Rotate 90 degrees around Z-axis
import math
import torch
angle = math.pi / 2
R = torch.tensor([
    [math.cos(angle), -math.sin(angle), 0],
    [math.sin(angle), math.cos(angle), 0],
    [0, 0, 1]
])
model.rotate(R)

# Rotate around a specific point
center = torch.tensor([10.0, 20.0, 30.0])
model.rotate(R, center=center)
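The formula xyz_new = R @ (xyz - center) + center is written for column vectors, while coordinates are stored as (n_atoms, 3) rows, so the equivalent row form is (xyz - center) @ R.T + center. The hypothetical helpers below show that form plus a check for a proper rotation; defaulting to the unweighted centroid is an assumption consistent with the center parameter description.

```python
import torch


def rotate_rows(xyz, R, center=None):
    """Hypothetical helper: xyz_new = R @ (xyz - center) + center for (n_atoms, 3) rows."""
    if center is None:
        center = xyz.mean(dim=0)            # unweighted centroid (assumption)
    return (xyz - center) @ R.T + center    # row form of R @ v for each atom v


def is_proper_rotation(R, atol=1e-6):
    """True if R is orthogonal (R^T R = I) and has determinant +1 (no reflection)."""
    eye = torch.eye(3, dtype=R.dtype)
    return torch.allclose(R.T @ R, eye, atol=atol) and abs(torch.linalg.det(R).item() - 1.0) < atol
```

A reflection such as diag(1, 1, -1) passes the orthogonality test but fails the determinant test, which is why both conditions are checked.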

translate(translation, fractional=False)

Apply translation to atomic coordinates (in-place).

Translates all atoms by a specified vector. The translation can be given in either Cartesian or fractional coordinates.

Parameters:
  • translation (torch.Tensor) – Translation vector with shape (3,).

  • fractional (bool, optional) – If True, the translation is interpreted as fractional coordinates and converted to Cartesian before applying. Default is False (translation is in Cartesian Angstroms).

Returns:

Self, for method chaining.

Return type:

Model

Examples

# Translate by 5 Angstroms along X
model.translate(torch.tensor([5.0, 0.0, 0.0]))

# Translate by half a unit cell along each axis
model.translate(torch.tensor([0.5, 0.5, 0.5]), fractional=True)

get_centroid()

Compute the centroid (unweighted mean position) of all atoms.

Returns:

Centroid coordinates with shape (3,).

Return type:

torch.Tensor

use_internal_coordinates(n_aa_per_segment=5, bond_cutoff=2.0, cif_dict=None, requires_grad=True)

Switch xyz to segmented internal coordinate parametrization.

Replaces the current xyz MixedTensor with a SegmentedInternalCoordinateTensor that parametrizes atomic positions using bond lengths, angles, torsion angles, and per-segment rigid body parameters. The molecule is broken into independent segments to avoid the “lever arm problem” where small torsion changes near the root cause large displacements at distant atoms.

Parameters:
  • n_aa_per_segment (int, optional) – Number of amino acids per segment. Default is 5. Smaller values (1-2) give more segments, shallower trees, and less lever arm; larger values (5-10) give fewer segments, deeper trees, and more lever arm.

  • bond_cutoff (float, optional) – Distance cutoff for bond detection in Angstroms. Default is 2.0. Only used when cif_dict is not provided.

  • cif_dict (dict, optional) – CIF dictionary containing bond definitions per residue type. If provided, bonds are determined from chemical definitions rather than distances, which is more robust for structures with poor geometry. Expected format: cif_dict[resname][‘bonds’] is a DataFrame with columns ‘atom1’ and ‘atom2’.

  • requires_grad (bool, optional) – Whether internal coordinate parameters should have gradients. Default is True.

Returns:

Self, for method chaining.

Return type:

Model

Examples

model = Model()
model.load_pdb('structure.pdb')
model.use_internal_coordinates(n_aa_per_segment=3)

# Now model.xyz() returns coordinates reconstructed from
# segmented internal coordinates

# Shake the structure using internal coordinates
new_xyz = model.xyz.shake(magnitude=0.1)

# Each segment has independent internal coordinates and
# rigid body parameters (position + orientation)
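Reconstructing Cartesian coordinates from internal coordinates comes down to placing each atom from three already-placed ancestors plus a bond length, bond angle, and torsion. The sketch below uses the NeRF (natural extension reference frame) construction, one standard way to do this; whether SegmentedInternalCoordinateTensor uses NeRF or another scheme is not stated on this page, and place_atom is a hypothetical name.

```python
import math

import torch


def place_atom(a, b, c, bond, angle, torsion):
    """NeRF: position of D from A, B, C and (|CD|, angle BCD, torsion ABCD), angles in radians."""
    bc = c - b
    bc = bc / torch.linalg.norm(bc)
    n = torch.linalg.cross(b - a, bc)
    n = n / torch.linalg.norm(n)            # normal of the ABC plane
    m = torch.linalg.cross(n, bc)           # completes the local frame (bc, m, n)
    d_local = torch.stack([
        -bond * torch.cos(angle),
        bond * torch.sin(angle) * torch.cos(torsion),
        bond * torch.sin(angle) * torch.sin(torsion),
    ])
    return c + d_local[0] * bc + d_local[1] * m + d_local[2] * n
```

Chaining place_atom along a spanning tree rebuilds every position, and since each step is differentiable, gradients reach the internal parameters. Short trees within each segment limit how far one torsion change propagates, which is the lever-arm argument made above.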

Notes

After calling this method, model.xyz will be a SegmentedInternalCoordinateTensor instead of a MixedTensor. This provides:

  • Shallow spanning trees within segments (depth ~10-30 vs ~1000)

  • Independent segments that don’t propagate changes to distant atoms

  • Rigid body parameters (position + orientation) per segment

  • forward() / __call__(): Reconstruct Cartesian coordinates

  • shake(magnitude): Add noise to internal parameters

  • Gradient flow through all internal coordinate parameters