torchref.restraints.restraints module

Restraints Class (Refactored) for Crystallographic Model Refinement

This module provides a refactored restraints handler using the builder pattern. It maintains the same interface as the original Restraints class but uses the more efficient and testable builder classes internally.

Key improvements:

  • Single-pass iteration over residues (vs. multiple passes in the original)

  • Pre-grouped residue data for O(N log N) vs. O(N×R) complexity

  • Sorted indices for cache-friendly tensor access

  • Separate builder classes for easier testing and maintenance

  • Decoupled from Model: accepts a pdb DataFrame and callable functions for xyz/ADP access

class torchref.restraints.restraints.RestraintsNew(pdb=None, cif_path=None, xyz_fn=None, adp_fn=None, vdw_radii_fn=None, cell=None, spacegroup=None, links=None, verbose=1)

Bases: DeviceMixin, DebugMixin, Module

Refactored restraints handler for crystallographic model refinement.

This class uses the builder pattern internally for efficient construction of restraint tensors. It is decoupled from Model and accepts a pdb DataFrame with callable functions for accessing coordinates and ADPs.

Parameters:
  • pdb (pd.DataFrame, optional) – DataFrame containing atomic structure data. If None, creates empty shell.

  • cif_path (str or list of str, optional) – Path to the CIF restraints dictionary file(s).

  • xyz_fn (callable, optional) – Function returning current xyz coordinates as torch.Tensor. Required for building and evaluation if pdb is provided.

  • adp_fn (callable, optional) – Function returning current ADP values as torch.Tensor. Required for ADP-based restraints.

  • vdw_radii_fn (callable, optional) – Function returning VDW radii as torch.Tensor. Required for VDW restraints.

  • verbose (int, default 1) – Verbosity level (0=silent, 1=normal, 2=detailed).

pdb

DataFrame containing atomic structure data.

Type:

pd.DataFrame

xyz_fn

Function returning current xyz coordinates.

Type:

callable

adp_fn

Function returning current ADP values.

Type:

callable

vdw_radii_fn

Function returning VDW radii.

Type:

callable

cif_dict

Parsed CIF dictionary with restraints for each residue type.

Type:

dict

restraints

Hierarchical dictionary containing all restraints.

Type:

dict

__init__(pdb=None, cif_path=None, xyz_fn=None, adp_fn=None, vdw_radii_fn=None, cell=None, spacegroup=None, links=None, verbose=1)

Initialize the Restraints handler.

xyz(xyz=None)

Get current xyz coordinates.

Parameters:

xyz (torch.Tensor, optional) – If provided, returns this tensor directly. Otherwise calls the stored xyz_fn callable.

Returns:

Current xyz coordinates of shape (n_atoms, 3).

Return type:

torch.Tensor

adp(adp=None)

Get current ADP values.

Parameters:

adp (torch.Tensor, optional) – If provided, returns this tensor directly. Otherwise calls the stored adp_fn callable.

Returns:

Current ADP values of shape (n_atoms,).

Return type:

torch.Tensor
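A minimal sketch of the override-or-callable pattern that xyz() and adp() describe. The class CallableAccessSketch is hypothetical and is not the torchref implementation; it only illustrates how the restraints object can stay decoupled from Model by storing callables and reading the current tensors at call time.

```python
from typing import Callable, Optional

import torch


class CallableAccessSketch:
    """Hypothetical stand-in for the xyz()/adp() accessors (not the torchref API)."""

    def __init__(self, xyz_fn: Optional[Callable[[], torch.Tensor]] = None,
                 adp_fn: Optional[Callable[[], torch.Tensor]] = None):
        self.xyz_fn = xyz_fn
        self.adp_fn = adp_fn

    def xyz(self, xyz: Optional[torch.Tensor] = None) -> torch.Tensor:
        # An explicitly passed tensor wins; otherwise ask the owner via the callable.
        return xyz if xyz is not None else self.xyz_fn()

    def adp(self, adp: Optional[torch.Tensor] = None) -> torch.Tensor:
        return adp if adp is not None else self.adp_fn()


# The "model" owns the tensors; the accessor only holds callables that read them.
coords = torch.zeros(4, 3)
bfac = torch.full((4,), 20.0)
r = CallableAccessSketch(xyz_fn=lambda: coords, adp_fn=lambda: bfac)
```

Because the callable is evaluated on every call, in-place updates to the model's tensors are picked up without re-registering anything, while passing a tensor explicitly lets a caller evaluate restraints on trial coordinates.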

get_vdw_radii()

Get VDW radii for all atoms.

Returns:

VDW radii of shape (n_atoms,).

Return type:

torch.Tensor

property restraints: _RestraintsAccessor

Provide dict-like access to restraints for backward compatibility.

Returns an accessor object that mimics the old nested dict interface.

expand_altloc(residue)

Expand a residue containing alternative conformations (altlocs) into separate conformers.

Yields one DataFrame per altloc (with common atoms included in each).
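A sketch of altloc expansion with pandas, assuming a hypothetical layout in which the residue DataFrame has an altloc column and common atoms carry an empty (or single-space) label. The function name and column name are illustrative assumptions; the real method may treat blank labels or occupancies differently.

```python
import pandas as pd


def expand_altloc_sketch(residue: pd.DataFrame, altloc_col: str = "altloc"):
    """Hypothetical helper: yield one DataFrame per alternate conformation."""
    # Rows without an altloc label are shared by every conformer (assumed convention).
    is_common = residue[altloc_col].isin(["", " "])
    common = residue[is_common]
    labels = sorted(residue.loc[~is_common, altloc_col].unique())
    if not labels:
        # No alternate conformations: the residue is its own single conformer.
        yield residue
        return
    for label in labels:
        alt = residue[residue[altloc_col] == label]
        # Add the common atoms to each conformer and restore the original row order.
        yield pd.concat([common, alt]).sort_index()


res = pd.DataFrame({
    "name": ["N", "CA", "CB", "CB", "C"],
    "altloc": ["", "", "A", "B", ""],
})
confs = list(expand_altloc_sketch(res))
```

Here confs holds two DataFrames, one with CB from conformer A and one with CB from conformer B, each also containing N, CA, and C.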

build_restraints()

Build all restraints using the fast builder API.

This method uses the optimized builders, which process all residues internally with Numba-accelerated matching (roughly 10x faster than the original Restraints implementation).

property h_topo

Access riding hydrogen topology (None if not built).

property h_excl_hash

Access H-specific exclusion hash tensor (None if not built).

rebuild_vdw_restraints()

Refresh the VDW pair list using the cached build kwargs.

Called by NonBondedTarget.maintenance() once it detects that the maximum atomic displacement since the last build exceeds the rebuild threshold. The rebuild reuses the cutoff, sigma, inter_residue_only, and use_spatial_hash values given to the initial build, so the pair-selection criteria stay consistent across the run.
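A sketch of the displacement test that triggers a rebuild. In a common Verlet-list scheme the pair list is built with the cutoff plus a buffer, and a rebuild becomes necessary once some atom has moved more than half that buffer. How NonBondedTarget chooses its threshold is not described here; needs_rebuild and its threshold argument are hypothetical.

```python
import torch


def needs_rebuild(xyz_now: torch.Tensor, xyz_at_build: torch.Tensor,
                  threshold: float) -> bool:
    """Hypothetical check: True if any atom moved more than `threshold` since the build."""
    # Per-atom displacement magnitude since the last pair-list build, shape (n_atoms,).
    disp = torch.linalg.norm(xyz_now - xyz_at_build, dim=-1)
    return bool(disp.max() > threshold)


xyz_build = torch.zeros(3, 3)
xyz_now = xyz_build.clone()
xyz_now[1, 0] = 0.4  # one atom drifts 0.4 Å along x
```

Only the coordinates at the last build need to be cached; the check itself is a single reduction, which keeps it cheap enough to run every maintenance step.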

summary()

Print a detailed summary of all restraints.

__repr__()

Return string representation.

bond_lengths(idx, xyz=None)

Compute current bond lengths from atomic coordinates.

Parameters:
  • idx (torch.Tensor) – Bond indices tensor of shape (N, 2).

  • xyz (torch.Tensor, optional) – Coordinates tensor of shape (n_atoms, 3). If None, uses the stored xyz_fn callable.

Returns:

Tensor of bond lengths of shape (N,).

Return type:

torch.Tensor
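A minimal sketch of a vectorized bond-length computation from an (N, 2) index tensor. The helper name is hypothetical and is not necessarily how torchref implements bond_lengths; it shows the gather-then-norm approach that keeps the result differentiable with respect to xyz.

```python
import torch


def bond_lengths_sketch(idx: torch.Tensor, xyz: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: |x_i - x_j| for each (i, j) row of idx."""
    # Gather both endpoint coordinates for every bond: each of shape (N, 3).
    a = xyz[idx[:, 0]]
    b = xyz[idx[:, 1]]
    return torch.linalg.norm(a - b, dim=-1)


xyz = torch.tensor([[0.0, 0.0, 0.0],
                    [1.5, 0.0, 0.0],
                    [1.5, 2.0, 0.0]])
idx = torch.tensor([[0, 1], [1, 2], [0, 2]])
lengths = bond_lengths_sketch(idx, xyz)
```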

copy()

Create a deep copy of the Restraints object.

Returns:

A deep copy of this Restraints instance.

Return type:

Restraints

bond_deviations(xyz=None)

Compute bond length deviations and sigmas.

Parameters:

xyz (torch.Tensor, optional) – Coordinates tensor. If None, uses the stored xyz_fn callable.

Returns:

  • deviations (torch.Tensor) – Calculated minus expected bond lengths in Angstroms.

  • sigmas (torch.Tensor) – Standard deviations from CIF library in Angstroms.

nll_bonds(xyz=None)

Compute negative log-likelihood for bond length restraints.

For a Gaussian distribution:

NLL = -log P(x | μ, σ) = 0.5 * ((x - μ) / σ)² + log(σ) + 0.5 * log(2π)

This is the full NLL, including the normalization terms, so exp(-NLL) equals the probability density.

Parameters:

xyz (torch.Tensor, optional) – Coordinates tensor. If None, uses the stored xyz_fn callable.

Returns:

Tensor of shape (n_bonds,) with negative log-likelihood values.

Return type:

torch.Tensor
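A sketch of the Gaussian NLL above, cross-checked against torch.distributions.Normal. The function name gaussian_nll and the example values are illustrative, not taken from torchref; the same expression appears in nll_angles.

```python
import math

import torch
from torch.distributions import Normal


def gaussian_nll(x: torch.Tensor, mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: 0.5*((x - mu)/sigma)**2 + log(sigma) + 0.5*log(2*pi)."""
    z = (x - mu) / sigma
    return 0.5 * z**2 + torch.log(sigma) + 0.5 * math.log(2 * math.pi)


x = torch.tensor([1.53, 1.49, 1.33])      # observed bond lengths (Å), illustrative
mu = torch.tensor([1.52, 1.52, 1.33])     # ideal lengths from a restraint library (Å)
sigma = torch.tensor([0.02, 0.02, 0.01])
nll = gaussian_nll(x, mu, sigma)
```

The log(σ) and log(2π) terms do not depend on the coordinates, so they leave gradients unchanged, but they make the values proper densities. As a consequence the NLL can be negative: an ideal bond with σ = 0.01 Å scores log(0.01) + 0.5 * log(2π) ≈ -3.69.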

angles(idx, xyz=None)

Compute current angle values for all angle restraints.

Parameters:
  • idx (torch.Tensor) – Angle indices tensor of shape (N, 3).

  • xyz (torch.Tensor, optional) – Coordinates tensor. If None, uses the stored xyz_fn callable.

Returns:

Tensor of shape (n_angles,) with current angle values in degrees.

Return type:

torch.Tensor
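A sketch of a vectorized bond-angle computation, assuming the middle index of each (i, j, k) row is the vertex atom, as in CIF angle records. The helper name is hypothetical. It uses atan2 of the cross- and dot-product magnitudes, which is better conditioned than acos near 0° and 180°.

```python
import torch


def angles_deg_sketch(idx: torch.Tensor, xyz: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: angle at atom j of each (i, j, k) triple, in degrees."""
    a, b, c = xyz[idx[:, 0]], xyz[idx[:, 1]], xyz[idx[:, 2]]
    u = a - b  # vertex -> atom i
    v = c - b  # vertex -> atom k
    cross = torch.linalg.norm(torch.linalg.cross(u, v, dim=-1), dim=-1)
    dot = (u * v).sum(dim=-1)
    # atan2(|u x v|, u . v) returns the angle in [0, pi].
    return torch.rad2deg(torch.atan2(cross, dot))


xyz = torch.tensor([[1.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0],
                    [0.0, 1.0, 0.0],
                    [-1.0, 0.0, 0.0]])
idx = torch.tensor([[0, 1, 2], [0, 1, 3]])
ang = angles_deg_sketch(idx, xyz)
```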

angle_deviations(xyz=None)

Compute angle deviations and sigmas.

Parameters:

xyz (torch.Tensor, optional) – Coordinates tensor. If None, uses the stored xyz_fn callable.

Returns:

  • deviations (torch.Tensor) – Calculated minus expected angles in radians.

  • sigmas (torch.Tensor) – Standard deviations in radians.

nll_angles(xyz=None)

Compute negative log-likelihood for angle restraints.

For a Gaussian distribution:

NLL = -log P(x | μ, σ) = 0.5 * ((x - μ) / σ)² + log(σ) + 0.5 * log(2π)

This is the full NLL, including the normalization terms, so exp(-NLL) equals the probability density.

Parameters:

xyz (torch.Tensor, optional) – Coordinates tensor. If None, uses the stored xyz_fn callable.

Returns:

Tensor of shape (n_angles,) with negative log-likelihood values.

Return type:

torch.Tensor

cat_dict()

Concatenate all restraint dictionaries into ‘all’ keys.

Creates restraints[‘bond’][‘all’], restraints[‘angle’][‘all’], and restraints[‘torsion’][‘all’] by concatenating all origins.
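A sketch of the concatenation step, assuming a hypothetical layout in which restraints[kind][origin] is a dict of tensors sharing their first dimension (one row per restraint). The origin names and field names below are invented for illustration and are not torchref's actual keys.

```python
import torch


def cat_dict_sketch(restraints: dict, kinds=("bond", "angle", "torsion")) -> None:
    """Hypothetical helper: add restraints[kind]['all'] by concatenating every origin."""
    for kind in kinds:
        # Skip an existing 'all' entry so repeated calls do not double-count rows.
        origins = [k for k in restraints.get(kind, {}) if k != "all"]
        if not origins:
            continue
        fields = restraints[kind][origins[0]].keys()
        restraints[kind]["all"] = {
            f: torch.cat([restraints[kind][o][f] for o in origins], dim=0)
            for f in fields
        }


r = {
    "bond": {
        "intra": {"idx": torch.tensor([[0, 1], [1, 2]]), "sigma": torch.tensor([0.02, 0.02])},
        "peptide": {"idx": torch.tensor([[2, 3]]), "sigma": torch.tensor([0.01])},
    }
}
cat_dict_sketch(r)
```

Keeping a single concatenated tensor per field lets the NLL methods evaluate all restraints of a kind in one vectorized call.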

torsions(idx, xyz=None)

Compute current torsion angle values for all torsion restraints.

Parameters:
  • idx (torch.Tensor) – Torsion indices tensor of shape (N, 4).

  • xyz (torch.Tensor, optional) – Coordinates tensor. If None, uses the stored xyz_fn callable.

Returns:

Tensor of shape (n_torsions,) with current torsion values in degrees.

Return type:

torch.Tensor
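A sketch of a vectorized dihedral-angle computation using the standard atan2 formulation φ = atan2(|b2| b1 · (b2 × b3), (b1 × b2) · (b2 × b3)), which follows the IUPAC sign convention and returns values in (-180°, 180°]. The helper name is hypothetical and may differ from torchref's implementation.

```python
import torch


def torsions_deg_sketch(idx: torch.Tensor, xyz: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: dihedral angle of each (i, j, k, l) row, in degrees."""
    p0, p1, p2, p3 = (xyz[idx[:, k]] for k in range(4))
    b1 = p1 - p0
    b2 = p2 - p1  # central bond
    b3 = p3 - p2
    n1 = torch.linalg.cross(b1, b2, dim=-1)  # normal of plane (p0, p1, p2)
    n2 = torch.linalg.cross(b2, b3, dim=-1)  # normal of plane (p1, p2, p3)
    y = torch.linalg.norm(b2, dim=-1) * (b1 * n2).sum(dim=-1)
    x = (n1 * n2).sum(dim=-1)
    return torch.rad2deg(torch.atan2(y, x))


xyz = torch.tensor([[0.0, 1.0, 0.0],    # p0
                    [0.0, 0.0, 0.0],    # p1
                    [1.0, 0.0, 0.0],    # p2
                    [1.0, 1.0, 0.0],    # p3, cis to p0
                    [1.0, -1.0, 0.0],   # p3, trans to p0
                    [1.0, 0.0, 1.0]])   # p3, perpendicular
idx = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 4], [0, 1, 2, 5]])
tors = torsions_deg_sketch(idx, xyz)
```

The three rows give the cis (0°), trans (±180°), and perpendicular (90°) cases.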

torsion_deviations(xyz=None, wrapped=True)

Compute deviations between calculated and expected torsion angles.

Parameters:
  • xyz (torch.Tensor, optional) – Coordinates tensor. If None, uses the stored xyz_fn callable.

  • wrapped (bool, default True) – If True, wrap deviations accounting for periodicity. If False, return raw deviations (calculated - expected).

Returns:

Tensor of shape (n_torsions,) with deviations in degrees. For wrapped=True, each deviation is reduced to the smallest equivalent value for its period, so its magnitude is at most 180°/period.

Return type:

torch.Tensor

Notes

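Expected values from the CIF library are discrete (typically -60°, 0°, 60°, 90°, or 180°), while values calculated from the structure are continuous; this mismatch is expected. Raw differences are also misleading near the ±180° boundary (e.g., a calculated -179° against an expected 180° gives a raw deviation of -359°), so use wrapped=True for meaningful comparison and visualization.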
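A sketch of period-aware wrapping, assuming that for period n the deviation is mapped into [-180°/n, 180°/n). The function name is hypothetical; torchref's exact interval convention at the boundary may differ.

```python
import torch


def wrap_torsion_deviation(calc_deg: torch.Tensor, expected_deg: torch.Tensor,
                           period: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: wrap calc - expected into [-180/period, 180/period)."""
    span = 360.0 / period  # angles differing by 360/n degrees are equivalent
    d = calc_deg - expected_deg
    # torch.remainder takes the sign of the divisor, so the result lies in [0, span).
    return torch.remainder(d + span / 2, span) - span / 2


calc = torch.tensor([-179.0, -60.0, 170.0])
expected = torch.tensor([180.0, 60.0, -60.0])
period = torch.tensor([1.0, 3.0, 2.0])
wrapped = wrap_torsion_deviation(calc, expected, period)
```

With period 1 the raw -359° becomes 1°; with period 3, -60° and 60° are equivalent (0°); with period 2, 170° is 50° from the equivalent target 120°.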

torsion_deviations_with_sigmas(xyz=None)

Compute torsion deviations (wrapped for periodicity) and sigmas.

Parameters:

xyz (torch.Tensor, optional) – Coordinates tensor. If None, uses the stored xyz_fn callable.

Returns:

  • deviations_rad (torch.Tensor) – Wrapped deviations in radians.

  • sigmas_deg (torch.Tensor) – Standard deviations in degrees (for von Mises NLL).

nll_torsions(xyz=None)

Compute negative log-likelihood for torsion angle restraints.

For a von Mises distribution:

NLL = -log P(θ | μ, κ) = -κ * cos(θ - μ) + log(I₀(κ)) + log(2π)

where κ = 1/σ² is the concentration parameter and I₀ is the modified Bessel function of the first kind of order zero.

Notes

Period indicates n-fold rotational symmetry about the central bond (e.g., period=3 for a methyl group). We handle this by finding the minimum angular distance considering periodicity: for period=n, angles differing by 360°/n are equivalent.

This is the true NLL where exp(-NLL) = probability density.

Parameters:

xyz (torch.Tensor, optional) – Coordinates tensor. If None, uses the stored xyz_fn callable.

Returns:

Tensor of shape (n_torsions,) with negative log-likelihood values.

Return type:

torch.Tensor
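A sketch of the von Mises NLL above, cross-checked against torch.distributions.VonMises. It assumes θ, μ, and σ are all in radians before forming κ = 1/σ²; whether torchref converts the degree-valued sigmas first is not stated here. The function name is hypothetical, and periodicity handling is omitted.

```python
import math

import torch
from torch.distributions import VonMises


def von_mises_nll(theta: torch.Tensor, mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: -kappa*cos(theta - mu) + log(I0(kappa)) + log(2*pi)."""
    kappa = 1.0 / sigma**2
    # log I0(k) = log(i0e(k)) + k; the exponentially scaled i0e avoids overflow
    # of I0 at the large kappa values produced by tight restraints.
    log_i0 = torch.log(torch.special.i0e(kappa)) + kappa
    return -kappa * torch.cos(theta - mu) + log_i0 + math.log(2 * math.pi)


theta = torch.deg2rad(torch.tensor([175.0, -178.0, 60.0], dtype=torch.float64))
mu = torch.deg2rad(torch.tensor([180.0, 180.0, 60.0], dtype=torch.float64))
sigma = torch.deg2rad(torch.tensor([5.0, 5.0, 15.0], dtype=torch.float64))
nll = von_mises_nll(theta, mu, sigma)
```

The cosine makes the distribution naturally periodic over 360°, so -178° against 180° is scored as a 2° deviation without any explicit wrapping.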

nll_planes(xyz=None)

Compute negative log-likelihood for plane restraints.

For each plane, computes the RMSD of atom deviations from the best-fit plane and scores it with a Gaussian NLL:

NLL = 0.5 * (deviation / σ)² + log(σ) + 0.5 * log(2π)

Parameters:

xyz (torch.Tensor, optional) – Coordinates tensor. If None, uses the stored xyz_fn callable.

Returns:

Tensor of shape (n_planes,) with negative log-likelihood values.

Return type:

torch.Tensor
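A sketch of the best-fit-plane RMSD for one plane via SVD of the centered coordinates: the right singular vector with the smallest singular value is the least-squares plane normal. The helper names are hypothetical and the per-plane scoring simply plugs the RMSD into the Gaussian NLL above; torchref's batching over planes is not shown.

```python
import math

import torch


def plane_rmsd(points: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: RMS distance of (n, 3) points from their least-squares plane."""
    centered = points - points.mean(dim=0, keepdim=True)
    _, _, vh = torch.linalg.svd(centered, full_matrices=False)
    normal = vh[-1]  # direction of least variance = plane normal (sign is arbitrary)
    dist = centered @ normal  # signed point-to-plane distances, shape (n,)
    return torch.sqrt((dist**2).mean())


def plane_nll(points: torch.Tensor, sigma: float) -> torch.Tensor:
    """Hypothetical helper: Gaussian NLL of the plane RMSD."""
    rmsd = plane_rmsd(points)
    return 0.5 * (rmsd / sigma) ** 2 + math.log(sigma) + 0.5 * math.log(2 * math.pi)


ang = torch.arange(6, dtype=torch.float32) * (math.pi / 3)
flat = torch.stack([torch.cos(ang), torch.sin(ang), torch.zeros(6)], dim=1)  # planar ring
pucker = flat.clone()
pucker[:, 2] = torch.tensor([0.1, -0.1] * 3)  # alternate atoms 0.1 Å above/below
```

Because only squared distances enter the RMSD, the arbitrary sign of the SVD normal does not matter.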

nll_vdw(xyz=None)

Compute negative log-likelihood for VDW (non-bonded) restraints.

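Uses a soft-repulsive potential based on distance violations:

NLL = 0.5 * (max(0, min_dist - actual_dist) / σ)² + log(σ) + 0.5 * log(2π)

Only violations (distances shorter than the minimum) contribute a distance-dependent term; pairs at or beyond min_dist add only the constant log(σ) + 0.5 * log(2π) and produce no gradient.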

Parameters:

xyz (torch.Tensor, optional) – Coordinates tensor. If None, uses the stored xyz_fn callable.

Returns:

Tensor of shape (n_pairs,) with negative log-likelihood values.

Return type:

torch.Tensor
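A sketch of the soft-repulsive VDW NLL operating on precomputed pair distances. The function name and example values are hypothetical; the point is that clamping the violation at zero leaves non-violating pairs with a constant value and zero gradient.

```python
import math

import torch


def vdw_nll(dist: torch.Tensor, min_dist: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: 0.5*(max(0, min_dist - dist)/sigma)**2 + log(sigma) + 0.5*log(2*pi)."""
    violation = torch.clamp(min_dist - dist, min=0.0)  # max(0, min_dist - dist)
    return 0.5 * (violation / sigma) ** 2 + torch.log(sigma) + 0.5 * math.log(2 * math.pi)


dist = torch.tensor([2.8, 3.6, 4.5], requires_grad=True)  # pair distances (Å)
min_dist = torch.tensor([3.2, 3.4, 3.4])
sigma = torch.tensor([0.2, 0.2, 0.2])
nll = vdw_nll(dist, min_dist, sigma)
nll.sum().backward()
```

Only the first pair (0.4 Å too close) receives a gradient, about -0.4 / 0.2² = -10, which pushes the atoms apart; the other two pairs get zero gradient.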

adp_b_differences(adp=None)

Compute B-factor differences between bonded atoms.

Parameters:

adp (torch.Tensor, optional) – ADP values. If None, uses the stored adp_fn callable.

Returns:

Tensor of B-factor differences (B_i - B_j) for all bonds.

Return type:

torch.Tensor

adp_similarity_loss(adp=None, sigma=2.0)

Compute the ADP similarity loss (analogous to SIMU in SHELXL and to the ADP similarity restraints in Phenix).

This restrains the B-factors of bonded atoms to be similar: Loss = mean over bonded pairs of ((B_i - B_j) / σ)²

Parameters:
  • adp (torch.Tensor, optional) – ADP values. If None, uses the stored adp_fn callable.

  • sigma (float, default 2.0) – Target standard deviation for B-factor differences in Å².

Returns:

Mean similarity loss.

Return type:

torch.Tensor
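A sketch of the similarity loss, combining the B_i - B_j differences from adp_b_differences with averaging over bonds, as the Returns section indicates. The helper name and bond indices are hypothetical.

```python
import torch


def adp_similarity_sketch(b: torch.Tensor, bond_idx: torch.Tensor,
                          sigma: float = 2.0) -> torch.Tensor:
    """Hypothetical helper: mean over bonds of ((B_i - B_j) / sigma)**2, sigma in Å²."""
    diff = b[bond_idx[:, 0]] - b[bond_idx[:, 1]]  # B_i - B_j per bonded pair
    return ((diff / sigma) ** 2).mean()


b = torch.tensor([20.0, 22.0, 30.0])          # isotropic B-factors (Å²)
bond_idx = torch.tensor([[0, 1], [1, 2]])
loss = adp_similarity_sketch(b, bond_idx)
```

The differences are -2 and -8 Å², giving squared scaled terms 1 and 16 and a mean of 8.5; the second bond dominates because its B-factor jump is four times larger than σ.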