torchref.restraints.restraints module
Restraints Class (Refactored) for Crystallographic Model Refinement
This module provides a refactored restraints handler using the builder pattern. It maintains the same interface as the original Restraints class but uses the more efficient and testable builder classes internally.
Key improvements:
- Single-pass iteration over residues (vs. multiple passes in the original)
- Pre-grouped residue data for O(N log N) instead of O(N×R) complexity
- Sorted indices for cache-friendly tensor access
- Separated builder classes for easier testing and maintenance
- Decoupled from Model: accepts a pdb DataFrame and callable functions for xyz/adp
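The pre-grouping idea can be illustrated with a minimal NumPy sketch: one stable sort of a per-atom residue key (O(N log N)) followed by contiguous slicing, instead of scanning every atom once per residue. The function name and the integer `residue_keys` input are hypothetical; the actual builder classes may group atoms differently.

```python
import numpy as np


def group_atoms_by_residue(residue_keys: np.ndarray) -> dict:
    """Map each residue key to the atom indices that belong to it.

    `residue_keys` is a hypothetical per-atom integer residue id.
    """
    order = np.argsort(residue_keys, kind="stable")  # atom indices sorted by residue
    sorted_keys = residue_keys[order]
    # Offset of the first atom of each distinct residue in the sorted order
    unique_keys, starts = np.unique(sorted_keys, return_index=True)
    ends = np.append(starts[1:], len(sorted_keys))
    return {int(k): order[s:e] for k, s, e in zip(unique_keys, starts, ends)}


keys = np.array([2, 1, 2, 1, 3])
groups = group_atoms_by_residue(keys)
```

Because the sort is stable, atom indices within each residue stay in ascending order, which keeps later tensor gathers cache-friendly.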
- class torchref.restraints.restraints.RestraintsNew(pdb=None, cif_path=None, xyz_fn=None, adp_fn=None, vdw_radii_fn=None, cell=None, spacegroup=None, links=None, verbose=1)[source]
Bases: DeviceMixin, DebugMixin, Module
Refactored restraints handler for crystallographic model refinement.
This class uses the builder pattern internally for efficient construction of restraint tensors. It is decoupled from Model and accepts a pdb DataFrame with callable functions for accessing coordinates and ADPs.
- Parameters:
pdb (pd.DataFrame, optional) – DataFrame containing atomic structure data. If None, creates empty shell.
cif_path (str or list of str, optional) – Path to the CIF restraints dictionary file(s).
xyz_fn (callable, optional) – Function returning current xyz coordinates as torch.Tensor. Required for building and evaluation if pdb is provided.
adp_fn (callable, optional) – Function returning current ADP values as torch.Tensor. Required for ADP-based restraints.
vdw_radii_fn (callable, optional) – Function returning VDW radii as torch.Tensor. Required for VDW restraints.
verbose (int, default 1) – Verbosity level (0=silent, 1=normal, 2=detailed).
- pdb
DataFrame containing atomic structure data.
- Type:
pd.DataFrame
- xyz_fn
Function returning current xyz coordinates.
- Type:
callable
- adp_fn
Function returning current ADP values.
- Type:
callable
- vdw_radii_fn
Function returning VDW radii.
- Type:
callable
- __init__(pdb=None, cif_path=None, xyz_fn=None, adp_fn=None, vdw_radii_fn=None, cell=None, spacegroup=None, links=None, verbose=1)[source]
Initialize the Restraints handler.
- xyz(xyz=None)[source]
Get current xyz coordinates.
- Parameters:
xyz (torch.Tensor, optional) – If provided, returns this tensor directly. Otherwise calls the stored xyz_fn callable.
- Returns:
Current xyz coordinates of shape (n_atoms, 3).
- Return type:
torch.Tensor
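The override-or-callable behaviour of xyz() can be mimicked by a small stand-in class. `CoordinateSource` is hypothetical and only illustrates the pattern; it is not part of torchref, and the error raised when no callable is set is an assumption.

```python
from typing import Callable, Optional

import torch


class CoordinateSource:
    """Hypothetical stand-in illustrating the xyz(xyz=None) fallback."""

    def __init__(self, xyz_fn: Optional[Callable[[], torch.Tensor]] = None):
        self.xyz_fn = xyz_fn

    def xyz(self, xyz: Optional[torch.Tensor] = None) -> torch.Tensor:
        if xyz is not None:
            return xyz  # explicit override, e.g. trial coordinates
        if self.xyz_fn is None:
            raise RuntimeError("no xyz_fn was provided")
        return self.xyz_fn()  # re-evaluated on every call, so it tracks the model


params = torch.zeros(4, 3, requires_grad=True)
src = CoordinateSource(xyz_fn=lambda: params)
```

Since the callable is invoked on every access, restraint terms always see the model's current parameter tensor, and gradients flow back to it.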
- adp(adp=None)[source]
Get current ADP values.
- Parameters:
adp (torch.Tensor, optional) – If provided, returns this tensor directly. Otherwise calls the stored adp_fn callable.
- Returns:
Current ADP values of shape (n_atoms,).
- Return type:
torch.Tensor
- get_vdw_radii()[source]
Get VDW radii for all atoms.
- Returns:
VDW radii of shape (n_atoms,).
- Return type:
torch.Tensor
- property restraints: _RestraintsAccessor
Provide dict-like access to restraints for backward compatibility.
Returns an accessor object that mimics the old nested dict interface.
- expand_altloc(residue)[source]
Expand a residue that has alternative conformations (altlocs) into separate conformations.
Yields one DataFrame per altloc (with common atoms included in each).
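A minimal pandas sketch of the altloc expansion, assuming the DataFrame has a column named `altloc` in which a blank value marks atoms shared by all conformations. The standalone `expand_altloc` function below is hypothetical and does not reproduce the method's actual implementation.

```python
import pandas as pd


def expand_altloc(residue: pd.DataFrame, altloc_col: str = "altloc"):
    """Yield one DataFrame per alternative conformation of a residue.

    Assumes a blank altloc marks atoms shared by every conformation.
    """
    labels_per_atom = residue[altloc_col].fillna("").astype(str).str.strip()
    common = labels_per_atom == ""
    labels = sorted(set(labels_per_atom) - {""})
    if not labels:
        yield residue  # single conformation: nothing to expand
        return
    for label in labels:
        # Keep the original atom order: shared atoms plus this altloc's atoms
        yield residue[common | (labels_per_atom == label)]
```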
- build_restraints()[source]
Build all restraints using the fast builder API.
This method uses optimized builders that process all residues internally with Numba-accelerated matching (~10x faster than the original implementation).
- property h_topo
Access riding hydrogen topology (None if not built).
- property h_excl_hash
Access H-specific exclusion hash tensor (None if not built).
- rebuild_vdw_restraints()[source]
Refresh the VDW pair list using the cached build kwargs.
Called by NonBondedTarget.maintenance() after it detects that the maximum atomic displacement since the last build has exceeded the rebuild threshold. Uses the same cutoff, sigma, inter_residue_only and use_spatial_hash that the initial build was given, so behaviour is stable across the run.
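The displacement check that triggers a rebuild can be sketched as a Verlet-list style test. `needs_rebuild`, `xyz_at_build` and `threshold` are hypothetical names; the actual logic lives in NonBondedTarget.maintenance() and may differ.

```python
import torch


def needs_rebuild(xyz: torch.Tensor, xyz_at_build: torch.Tensor, threshold: float) -> bool:
    """Return True once any atom has moved more than `threshold` Angstroms
    since the pair list was last built."""
    displacement = torch.linalg.norm(xyz - xyz_at_build, dim=-1)  # per-atom shift
    return bool(displacement.max() > threshold)
```

In classic Verlet-list schemes the threshold is half the buffer (skin) added beyond the interaction cutoff: two atoms each moving at most that far toward one another cannot close more than the full buffer, so no pair can enter the cutoff without already being in the list.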
- bond_lengths(idx, xyz=None)[source]
Compute current bond lengths from atomic coordinates.
- Parameters:
idx (torch.Tensor) – Bond indices tensor of shape (N, 2).
xyz (torch.Tensor, optional) – Coordinates tensor of shape (n_atoms, 3). If None, uses the stored xyz_fn callable.
- Returns:
Tensor of bond lengths of shape (N,).
- Return type:
torch.Tensor
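A minimal sketch of the bond-length computation, assuming `idx` is an integer tensor of shape (N, 2) and `xyz` has shape (n_atoms, 3). This standalone function is illustrative, not the method's source.

```python
import torch


def bond_lengths(idx: torch.Tensor, xyz: torch.Tensor) -> torch.Tensor:
    # Gather both endpoints of every bond and take the Euclidean distance
    diff = xyz[idx[:, 0]] - xyz[idx[:, 1]]
    return torch.linalg.norm(diff, dim=-1)  # shape (N,), same units as xyz
```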
- copy()[source]
Create a deep copy of the Restraints object.
- Returns:
A deep copy of this Restraints instance.
- Return type:
RestraintsNew
- bond_deviations(xyz=None)[source]
Compute bond length deviations and sigmas.
- Parameters:
xyz (torch.Tensor, optional) – Coordinates tensor. If None, uses the stored xyz_fn callable.
- Returns:
deviations (torch.Tensor) – Calculated minus expected bond lengths in Angstroms.
sigmas (torch.Tensor) – Standard deviations from CIF library in Angstroms.
- nll_bonds(xyz=None)[source]
Compute negative log-likelihood for bond length restraints.
For a Gaussian distribution:
NLL = -log P(x | μ, σ) = 0.5 * ((x - μ) / σ)^2 + log(σ) + 0.5 * log(2π)
This is the true NLL where exp(-NLL) = probability density.
- Parameters:
xyz (torch.Tensor, optional) – Coordinates tensor. If None, uses the stored xyz_fn callable.
- Returns:
Tensor of shape (n_bonds,) with negative log-likelihood values.
- Return type:
torch.Tensor
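The Gaussian NLL formula above can be written as a standalone function (hypothetical name `gaussian_nll`) taking the deviation x - μ and σ, such as the two tensors returned by bond_deviations(). It agrees term by term with -torch.distributions.Normal(0, σ).log_prob(x - μ).

```python
import math

import torch


def gaussian_nll(deviation: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    # Per-restraint NLL = 0.5 * (dev / sigma)^2 + log(sigma) + 0.5 * log(2*pi)
    return 0.5 * (deviation / sigma) ** 2 + torch.log(sigma) + 0.5 * math.log(2 * math.pi)
```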
- angles(idx, xyz=None)[source]
Compute current angle values for all angle restraints.
- Parameters:
idx (torch.Tensor) – Angle indices tensor of shape (N, 3).
xyz (torch.Tensor, optional) – Coordinates tensor. If None, uses the stored xyz_fn callable.
- Returns:
Tensor of shape (n_angles,) with current angle values in degrees.
- Return type:
torch.Tensor
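A sketch of the angle computation, assuming the middle column of `idx` is the vertex atom (the usual convention for angle restraints, but an assumption here). The function name is hypothetical.

```python
import torch


def angles_deg(idx: torch.Tensor, xyz: torch.Tensor) -> torch.Tensor:
    # idx: (N, 3) atom indices, middle column taken as the vertex
    a, b, c = xyz[idx[:, 0]], xyz[idx[:, 1]], xyz[idx[:, 2]]
    u, v = a - b, c - b
    cos = (u * v).sum(-1) / (torch.linalg.norm(u, dim=-1) * torch.linalg.norm(v, dim=-1))
    # Clamp guards acos against rounding slightly outside [-1, 1]
    return torch.rad2deg(torch.acos(cos.clamp(-1.0, 1.0)))
```

For gradient-based refinement, note that the derivative of acos diverges at ±1, so near-linear angles may need a slightly tighter clamp.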
- angle_deviations(xyz=None)[source]
Compute angle deviations and sigmas.
- Parameters:
xyz (torch.Tensor, optional) – Coordinates tensor. If None, uses the stored xyz_fn callable.
- Returns:
deviations (torch.Tensor) – Calculated minus expected angles in radians.
sigmas (torch.Tensor) – Standard deviations in radians.
- nll_angles(xyz=None)[source]
Compute negative log-likelihood for angle restraints.
For a Gaussian distribution:
NLL = -log P(x | μ, σ) = 0.5 * ((x - μ) / σ)^2 + log(σ) + 0.5 * log(2π)
This is the true NLL where exp(-NLL) = probability density.
- Parameters:
xyz (torch.Tensor, optional) – Coordinates tensor. If None, uses the stored xyz_fn callable.
- Returns:
Tensor of shape (n_angles,) with negative log-likelihood values.
- Return type:
torch.Tensor
- cat_dict()[source]
Concatenate all restraint dictionaries into ‘all’ keys.
Creates restraints[‘bond’][‘all’], restraints[‘angle’][‘all’], and restraints[‘torsion’][‘all’] by concatenating all origins.
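A sketch of the concatenation, assuming a hypothetical nested layout in which restraints[kind][origin] maps field names (e.g. 'idx', 'sigma') to tensors whose first dimension indexes individual restraints. Both the layout and the function name `cat_origins` are assumptions, not the class's internal structure.

```python
import torch


def cat_origins(restraints: dict) -> None:
    """Add an 'all' entry per restraint kind by concatenating every origin."""
    for kind, by_origin in restraints.items():
        origins = [o for o in by_origin if o != "all"]
        if not origins:
            continue
        fields = by_origin[origins[0]].keys()
        # Concatenate each field along the restraint dimension
        by_origin["all"] = {
            f: torch.cat([by_origin[o][f] for o in origins], dim=0) for f in fields
        }
```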
- torsions(idx, xyz=None)[source]
Compute current torsion angle values for all torsion restraints.
- Parameters:
idx (torch.Tensor) – Torsion indices tensor of shape (N, 4).
xyz (torch.Tensor, optional) – Coordinates tensor. If None, uses the stored xyz_fn callable.
- Returns:
Tensor of shape (n_torsions,) with current torsion values in degrees.
- Return type:
torch.Tensor
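A standard dihedral-angle sketch (the atan2 form) for torsion indices of shape (N, 4). It returns IUPAC-signed angles in degrees; whether torchref uses the same formulation is an assumption, and the function name is hypothetical.

```python
import torch


def torsions_deg(idx: torch.Tensor, xyz: torch.Tensor) -> torch.Tensor:
    # idx: (N, 4) atom indices; returns dihedrals in [-180, 180] degrees
    p0, p1, p2, p3 = (xyz[idx[:, k]] for k in range(4))
    b1, b2, b3 = p1 - p0, p2 - p1, p3 - p2
    n1 = torch.cross(b1, b2, dim=-1)  # normal of plane (p0, p1, p2)
    n2 = torch.cross(b2, b3, dim=-1)  # normal of plane (p1, p2, p3)
    y = torch.linalg.norm(b2, dim=-1) * (b1 * n2).sum(-1)
    x = (n1 * n2).sum(-1)
    return torch.rad2deg(torch.atan2(y, x))
```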
- torsion_deviations(xyz=None, wrapped=True)[source]
Compute deviations between calculated and expected torsion angles.
- Parameters:
xyz (torch.Tensor, optional) – Coordinates tensor. If None, uses the stored xyz_fn callable.
wrapped (bool, default True) – If True, wrap deviations accounting for periodicity. If False, return raw deviations (calculated - expected).
- Returns:
Tensor of shape (n_torsions,) with deviations in degrees. For wrapped=True, deviations lie within ±180°/period, i.e. the minimum angular distance given the restraint's periodicity.
- Return type:
torch.Tensor
Notes
Expected values from the CIF library are discrete (typically -60°, 0°, 60°, 90°, 180°), whereas values calculated from the structure are continuous. This mismatch is expected; use wrapped=True for meaningful comparison and visualization.
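The periodic wrapping can be sketched as follows: with step = 360°/period, the raw deviation is mapped into [-step/2, step/2). The function name is hypothetical, and `period` is assumed to be a float tensor aligned with the restraints.

```python
import torch


def wrap_torsion_deviation(calc_deg: torch.Tensor, expected_deg: torch.Tensor,
                           period: torch.Tensor) -> torch.Tensor:
    """Smallest signed deviation, treating angles 360/period apart as equivalent."""
    step = 360.0 / period
    raw = calc_deg - expected_deg
    # remainder() takes the sign of the divisor, so this lands in [-step/2, step/2)
    return torch.remainder(raw + step / 2, step) - step / 2
```

For example, a calculated 175° against an expected -180° with period 1 gives -5° rather than 355°.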
- torsion_deviations_with_sigmas(xyz=None)[source]
Compute torsion deviations (wrapped for periodicity) and sigmas.
- Parameters:
xyz (torch.Tensor, optional) – Coordinates tensor. If None, uses the stored xyz_fn callable.
- Returns:
deviations_rad (torch.Tensor) – Wrapped deviations in radians.
sigmas_deg (torch.Tensor) – Standard deviations in degrees (for von Mises NLL).
- nll_torsions(xyz=None)[source]
Compute negative log-likelihood for torsion angle restraints.
For a von Mises distribution:
NLL = -log P(θ | μ, κ) = -κ * cos(θ - μ) + log(I₀(κ)) + log(2π)
where κ = 1/σ² is the concentration parameter and I₀ is the modified Bessel function of the first kind of order zero.
Notes
Period indicates n-fold rotational symmetry (e.g., period=6 for benzene). We handle this by finding the minimum angular distance considering periodicity. For period=n, angles differing by 360°/n are equivalent.
This is the true NLL where exp(-NLL) = probability density.
- Parameters:
xyz (torch.Tensor, optional) – Coordinates tensor. If None, uses the stored xyz_fn callable.
- Returns:
Tensor of shape (n_torsions,) with negative log-likelihood values.
- Return type:
torch.Tensor
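A numerically stable sketch of the von Mises NLL above. It assumes the deviation θ - μ and σ are both in radians (torsion_deviations_with_sigmas() reports σ in degrees, so a conversion would be needed first) and uses log I₀(κ) = log(i0e(κ)) + κ, where `torch.special.i0e` is the exponentially scaled Bessel function, to avoid overflow for sharp restraints. The function name is hypothetical. The result matches -torch.distributions.VonMises(μ, κ).log_prob(θ).

```python
import math

import torch


def von_mises_nll(dev_rad: torch.Tensor, sigma_rad: torch.Tensor) -> torch.Tensor:
    # Concentration kappa = 1 / sigma^2 (sigma assumed to be in radians)
    kappa = 1.0 / sigma_rad**2
    # log I0(kappa) = log(i0e(kappa)) + kappa, stable for large kappa
    log_i0 = torch.log(torch.special.i0e(kappa)) + kappa
    return -kappa * torch.cos(dev_rad) + log_i0 + math.log(2 * math.pi)
```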
- nll_planes(xyz=None)[source]
Compute negative log-likelihood for plane restraints.
For each plane, computes the RMSD of atom deviations from the best-fit plane. Uses Gaussian NLL: NLL = 0.5 * (deviation / σ)² + log(σ) + 0.5 * log(2π)
- Parameters:
xyz (torch.Tensor, optional) – Coordinates tensor. If None, uses the stored xyz_fn callable.
- Returns:
Tensor of shape (n_planes,) with negative log-likelihood values.
- Return type:
torch.Tensor
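A sketch of the per-plane RMSD, fitting the least-squares plane by SVD of the centred coordinates: the right singular vector with the smallest singular value is the plane normal. `plane_rmsd` is a hypothetical name and handles one plane at a time; a batched implementation would look different.

```python
import torch


def plane_rmsd(points: torch.Tensor) -> torch.Tensor:
    """RMS distance of points (n, 3), n >= 3, from their least-squares plane."""
    centered = points - points.mean(dim=0)
    # Singular values come back in descending order, so vh[-1] is the normal
    _, _, vh = torch.linalg.svd(centered, full_matrices=False)
    normal = vh[-1]
    dist = centered @ normal  # signed out-of-plane distances
    return torch.sqrt((dist**2).mean())
```

The RMSD would then enter the Gaussian NLL as the deviation. SVD gradients can be unstable when singular values are nearly degenerate, as happens for nearly collinear atom groups.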
- nll_vdw(xyz=None)[source]
Compute negative log-likelihood for VDW (non-bonded) restraints.
Uses a soft-repulsive potential based on distance violations. NLL = 0.5 * (max(0, min_dist - actual_dist) / σ)² + log(σ) + 0.5 * log(2π)
Only violations (distances shorter than the minimum) contribute a distance-dependent term; for all other pairs the NLL reduces to the constant log(σ) + 0.5 * log(2π).
- Parameters:
xyz (torch.Tensor, optional) – Coordinates tensor. If None, uses the stored xyz_fn callable.
- Returns:
Tensor of shape (n_pairs,) with negative log-likelihood values.
- Return type:
torch.Tensor
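A sketch of the one-sided VDW penalty, assuming per-pair distances, minimum distances and σ are already available as tensors (hypothetical function `vdw_nll`). `torch.clamp` zeroes both the penalty and its gradient for pairs that are far enough apart.

```python
import math

import torch


def vdw_nll(dist: torch.Tensor, min_dist: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    # Violation is zero when atoms are at least min_dist apart
    violation = torch.clamp(min_dist - dist, min=0.0)
    return 0.5 * (violation / sigma) ** 2 + torch.log(sigma) + 0.5 * math.log(2 * math.pi)
```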
- adp_b_differences(adp=None)[source]
Compute B-factor differences between bonded atoms.
- Parameters:
adp (torch.Tensor, optional) – ADP values. If None, uses the stored adp_fn callable.
- Returns:
Tensor of B-factor differences (B_i - B_j) for all bonds.
- Return type:
torch.Tensor
- adp_similarity_loss(adp=None, sigma=2.0)[source]
Compute ADP similarity loss (SIMU in Phenix/SHELX).
This restrains the B-factors of bonded atoms to be similar. Loss = Σ ((B_i - B_j) / sigma)^2
- Parameters:
adp (torch.Tensor, optional) – ADP values. If None, uses the stored adp_fn callable.
sigma (float, default 2.0) – Target standard deviation for B-factor differences in Ų.
- Returns:
Mean similarity loss.
- Return type:
torch.Tensor
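A sketch of the SIMU-style loss for isotropic B-factors, assuming `bond_idx` holds bonded atom pairs as an (N, 2) tensor. It averages the per-bond terms, following the Returns entry (mean), rather than summing as the Σ in the formula suggests; the function signature is hypothetical.

```python
import torch


def adp_similarity_loss(b: torch.Tensor, bond_idx: torch.Tensor, sigma: float = 2.0) -> torch.Tensor:
    # b: (n_atoms,) isotropic B-factors in A^2; bond_idx: (N, 2)
    diff = b[bond_idx[:, 0]] - b[bond_idx[:, 1]]  # B_i - B_j per bond
    return ((diff / sigma) ** 2).mean()  # averaged over bonds (assumption)
```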