"""
Restraints Class (Refactored) for Crystallographic Model Refinement
This module provides a refactored restraints handler using the builder pattern.
It maintains the same interface as the original Restraints class but uses
the more efficient and testable builder classes internally.
Key improvements:
- Single-pass iteration over residues (vs multiple passes in original)
- Pre-grouped residue data for O(N log N) vs O(N×R) complexity
- Sorted indices for cache-friendly tensor access
- Separated builder classes for easier testing and maintenance
- Decoupled from Model: accepts pdb DataFrame and callable functions for xyz/adp
"""
from typing import Callable, Optional
import numpy as np
import pandas as pd
import torch
from torch.nn import Module
from torchref.restraints.builders_fast import (
AngleRestraintBuilder,
BondRestraintBuilder,
ChiralRestraintBuilder,
InterResidueAngleBuilder,
InterResidueBondBuilder,
InterResiduePlaneBuilder,
InterResidueTorsionBuilder,
PlaneRestraintBuilder,
TorsionRestraintBuilder,
)
from torchref.restraints.restraints_helper import (
find_cif_file_in_library,
read_cif,
read_link_definitions,
)
from torchref.config import get_default_device, get_float_dtype
from torchref.utils.debug_utils import DebugMixin
from torchref.utils.utils import TensorDict
from torchref.utils.device_mixin import DeviceMixin
class _RestraintsAccessor:
    """
    Backward-compatible, dict-like view over the restraints held by a parent.

    The old interface was a nested dict::

        restraints["bond"]["intra"]["indices"]

    whereas the parent stores tensors under flattened keys such as
    ``bond_intra_indices``.  This object only forwards to the parent's
    ``_restraint_groups``, ``_has_restraint`` and ``_set_restraint_group``;
    it holds no restraint data itself.
    """

    # Types that don't have origin level (assigned directly as dicts)
    _FLAT_TYPES = {"vdw", "chiral"}
    # Types that are subdivided by origin, in the order keys() reports them.
    _NESTED_TYPES = ("bond", "angle", "torsion", "plane")

    def __init__(self, parent: "RestraintsNew"):
        self._parent = parent

    def __getitem__(self, rtype: str) -> "_RestraintTypeAccessor":
        """Return the per-type accessor; existence is not checked here."""
        return _RestraintTypeAccessor(self._parent, rtype)

    def __setitem__(self, rtype: str, value):
        """Handle direct assignment for flat types like vdw and chiral.

        Raises
        ------
        TypeError
            If ``rtype`` is not a flat type or ``value`` is not a dict.
        """
        if rtype in self._FLAT_TYPES and isinstance(value, dict):
            # Store all tensors with empty origin
            self._parent._set_restraint_group(rtype, "", value)
        else:
            raise TypeError(
                f"Cannot assign directly to restraints['{rtype}']. "
                f"Use restraints['{rtype}'][origin] = data for nested types."
            )

    def __contains__(self, rtype: str) -> bool:
        """Return True if any data is stored for ``rtype``."""
        if self._parent._restraint_groups.get(rtype):
            return True
        # Flat types are not tracked in _restraint_groups; probe storage.
        return rtype in self._FLAT_TYPES and self._parent._has_restraint(rtype, "")

    def get(self, rtype: str, default=None):
        """Return the accessor for ``rtype`` if it has data, else ``default``."""
        if rtype in self:
            return self[rtype]
        return default

    def keys(self):
        """Return all restraint types that have data.

        Nested types come first in a fixed order, followed by the flat types
        in sorted order.  Iterating the ``_FLAT_TYPES`` set directly would
        make the order depend on string hash randomisation.
        """
        result = [
            rtype
            for rtype in self._NESTED_TYPES
            if self._parent._restraint_groups.get(rtype)
        ]
        # Check for special types (vdw, chiral) which don't have origins
        result.extend(
            rtype
            for rtype in sorted(self._FLAT_TYPES)
            if self._parent._has_restraint(rtype, "")
        )
        return result
class _RestraintTypeAccessor:
    """
    View onto a single restraint type of the parent object.

    For origin-based types (bond, angle, torsion, plane) the key is an origin
    name and the value is a dict of tensors::

        restraints["bond"]["intra"] -> dict with indices, references, sigmas

    For flat types (vdw, chiral) this object plays the role of that dict, so
    the key is a property name and the value is a tensor::

        restraints["vdw"]["indices"] -> tensor
    """

    # Types that don't have origin level (accessed directly as dicts)
    _FLAT_TYPES = {"vdw", "chiral"}

    def __init__(self, parent: "RestraintsNew", rtype: str):
        self._parent = parent
        self._rtype = rtype

    def __getitem__(self, key: str):
        """Return the tensor (flat types) or tensor dict (origin types).

        Raises
        ------
        KeyError
            If the parent has nothing stored under ``key``.
        """
        owner, rtype = self._parent, self._rtype
        if rtype not in self._FLAT_TYPES:
            # bond/angle/torsion/plane: key names an origin
            group = owner._get_restraint_group(rtype, key)
            if group is None:
                raise KeyError(f"No restraints for {rtype}/{key}")
            return group
        # vdw/chiral: key names a property (indices, sigmas, ...)
        value = owner._get_restraint_tensor(rtype, "", key)
        if value is None:
            raise KeyError(f"No {key} for {rtype}")
        return value

    def __setitem__(self, key: str, value):
        """Store ``value`` under ``key``.

        For flat types only ``torch.Tensor`` values are stored; any other
        value (dict, None, ...) is silently ignored.  Whole-dict assignment
        for flat types goes through the parent accessor instead.
        """
        if self._rtype not in self._FLAT_TYPES:
            # key is the origin, value the dict of tensors
            self._parent._set_restraint_group(self._rtype, key, value)
            return
        if isinstance(value, torch.Tensor):
            self._parent._set_restraint_tensor(self._rtype, "", key, value)

    def __contains__(self, key: str) -> bool:
        if self._rtype not in self._FLAT_TYPES:
            return self._parent._has_restraint(self._rtype, key)
        stored = self._parent._get_restraint_tensor(self._rtype, "", key)
        return stored is not None

    def get(self, key: str, default=None):
        """Like ``self[key]`` but return ``default`` instead of raising."""
        try:
            value = self[key]
        except KeyError:
            return default
        return value

    def keys(self):
        """Return property names (flat types) or origin names (other types)."""
        if self._rtype not in self._FLAT_TYPES:
            return self._parent._get_origins_for_type(self._rtype)
        candidates = (
            "indices",
            "references",
            "sigmas",
            "periods",
            "min_distances",
            "symop_indices",
            "cell_offsets",
        )
        lookup = self._parent._get_restraint_tensor
        return [
            name for name in candidates if lookup(self._rtype, "", name) is not None
        ]

    def items(self):
        """Lazily yield ``(key, value)`` pairs matching :meth:`keys`."""
        flat = self._rtype in self._FLAT_TYPES
        for name in self.keys():
            if flat:
                yield name, self._parent._get_restraint_tensor(self._rtype, "", name)
            else:
                yield name, self._parent._get_restraint_group(self._rtype, name)

    def __iter__(self):
        return iter(self.keys())
[docs]
class RestraintsNew(DeviceMixin, DebugMixin, Module):
"""
Refactored restraints handler for crystallographic model refinement.
This class uses the builder pattern internally for efficient construction
of restraint tensors. It is decoupled from Model and accepts a pdb DataFrame
with callable functions for accessing coordinates and ADPs.
Parameters
----------
pdb : pd.DataFrame, optional
DataFrame containing atomic structure data. If None, creates empty shell.
cif_path : str or list of str, optional
Path to the CIF restraints dictionary file(s).
xyz_fn : callable, optional
Function returning current xyz coordinates as torch.Tensor.
Required for building and evaluation if pdb is provided.
adp_fn : callable, optional
Function returning current ADP values as torch.Tensor.
Required for ADP-based restraints.
vdw_radii_fn : callable, optional
Function returning VDW radii as torch.Tensor.
Required for VDW restraints.
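cell : object, optional
Unit cell; stored as crystallographic info for symmetry VDW restraints.
spacegroup : object, optional
Space group; stored as crystallographic info for symmetry VDW restraints.
links : pd.DataFrame, optional
PDB LINK records from which additional bond restraints are built.
verbose : int, default 1
Verbosity level (0=silent, 1=normal, 2=detailed).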
Attributes
----------
pdb : pd.DataFrame
DataFrame containing atomic structure data.
xyz_fn : callable
Function returning current xyz coordinates.
adp_fn : callable
Function returning current ADP values.
vdw_radii_fn : callable
Function returning VDW radii.
cif_dict : dict
Parsed CIF dictionary with restraints for each residue type.
restraints : _RestraintsAccessor
Dict-like accessor exposing all restraints through the old nested-dict interface.
"""
[docs]
def __init__(
    self,
    pdb: Optional[pd.DataFrame] = None,
    cif_path=None,
    xyz_fn: Optional[Callable[[], torch.Tensor]] = None,
    adp_fn: Optional[Callable[[], torch.Tensor]] = None,
    vdw_radii_fn: Optional[Callable[[], torch.Tensor]] = None,
    cell=None,
    spacegroup=None,
    links: Optional[pd.DataFrame] = None,
    verbose: int = 1,
):
    """Initialize the Restraints handler.

    Parameters
    ----------
    pdb : pd.DataFrame, optional
        Atomic structure table.  If None an empty shell is created and no
        restraints are built.
    cif_path : str or list of str, optional
        Path(s) to CIF restraint dictionaries.
    xyz_fn, adp_fn, vdw_radii_fn : callable, optional
        Zero-argument callables returning the current coordinates, ADPs and
        VDW radii.
    cell, spacegroup : optional
        Crystallographic info stored for symmetry VDW restraints.
    links : pd.DataFrame, optional
        LINK records consumed by ``_build_link_restraints``.
    verbose : int, default 1
        Verbosity level (0=silent, 1=normal, 2=detailed).
    """
    super().__init__()
    self.cif_path = cif_path
    self.verbose = verbose
    self.links = links
    # Store callable functions for coordinate/ADP access
    self._xyz_fn = xyz_fn
    self._adp_fn = adp_fn
    self._vdw_radii_fn = vdw_radii_fn
    # Store crystallographic info for symmetry VDW restraints
    self._cell = cell
    self._spacegroup = spacegroup
    # Initialize TensorDict for restraint storage (registered as submodule)
    self._tensor_storage = TensorDict()
    # Track which restraint groups exist (for iteration)
    self._restraint_groups = {
        "bond": set(),
        "angle": set(),
        "torsion": set(),
        "plane": set(),
    }
    # Defaults shared by both paths, so an empty shell exposes every
    # attribute the build/evaluation methods read.
    self.pdb = pdb
    self.cif_dict = {}
    self.unique_residues = []
    self.missing_residues = []
    self.link_dict = {}
    # NOTE(review): container type of link_list is assumed; it is normally
    # whatever read_link_definitions() returns.
    self.link_list = []
    # Empty initialization
    if pdb is None:
        return
    # Full initialization with pdb.  Keep only residue types with more than
    # one distinct atom name.  One grouped pass replaces a full-frame scan
    # per residue name; first-appearance order is preserved.
    names_per_residue = pdb.groupby("resname")["name"].nunique().to_dict()
    self.unique_residues = [
        residue
        for residue in pdb["resname"].unique()
        if names_per_residue.get(residue, 0) > 1
    ]
    # Parse CIF files
    self._load_cif_dictionaries(cif_path)
    # Load link definitions for inter-residue restraints
    if verbose > 1:
        print("Loading link definitions from monomer library...")
    self.link_dict, self.link_list = read_link_definitions()
    if verbose > 1:
        print(f"Loaded {len(self.link_dict)} link types")
    # Build restraints using the new builder pattern
    self.build_restraints()
    if self.verbose > 0:
        self.summary()
[docs]
def xyz(self, xyz: torch.Tensor = None) -> torch.Tensor:
    """
    Resolve the coordinates to use.

    Parameters
    ----------
    xyz : torch.Tensor, optional
        Explicit coordinates.  When given they are returned unchanged and
        the stored callable is not consulted.

    Returns
    -------
    torch.Tensor
        ``xyz`` if supplied, otherwise the result of the stored ``xyz_fn``.

    Raises
    ------
    RuntimeError
        If no tensor is passed and no ``xyz_fn`` was provided.
    """
    if xyz is None:
        provider = self._xyz_fn
        if provider is None:
            raise RuntimeError(
                "No xyz callable provided. Initialize with xyz_fn or pass xyz argument."
            )
        return provider()
    return xyz
[docs]
def adp(self, adp: torch.Tensor = None) -> torch.Tensor:
    """
    Resolve the ADP values to use.

    Parameters
    ----------
    adp : torch.Tensor, optional
        Explicit ADP values.  When given they are returned unchanged and
        the stored callable is not consulted.

    Returns
    -------
    torch.Tensor
        ``adp`` if supplied, otherwise the result of the stored ``adp_fn``.

    Raises
    ------
    RuntimeError
        If no tensor is passed and no ``adp_fn`` was provided.
    """
    if adp is None:
        provider = self._adp_fn
        if provider is None:
            raise RuntimeError(
                "No adp callable provided. Initialize with adp_fn or pass adp argument."
            )
        return provider()
    return adp
[docs]
def get_vdw_radii(self) -> torch.Tensor:
    """
    Return the VDW radii produced by the stored ``vdw_radii_fn`` callable.

    Returns
    -------
    torch.Tensor
        Whatever the callable returns.

    Raises
    ------
    RuntimeError
        If no ``vdw_radii_fn`` was provided.
    """
    provider = self._vdw_radii_fn
    if provider is not None:
        return provider()
    raise RuntimeError(
        "No vdw_radii callable provided. Initialize with vdw_radii_fn."
    )
# =========================================================================
# TensorDict Helper Methods for Restraint Storage
# =========================================================================
def _make_key(self, rtype: str, origin: str, prop: str) -> str:
    """Build the flattened storage key for a restraint tensor.

    Gives ``"{rtype}_{origin}_{prop}"``, or ``"{rtype}_{prop}"`` when
    ``origin`` is empty (flat types such as vdw and chiral).
    """
    return f"{rtype}_{origin}_{prop}" if origin else f"{rtype}_{prop}"
def _set_restraint_tensor(
    self, rtype: str, origin: str, prop: str, tensor: torch.Tensor
):
    """Store ``tensor`` under its flattened key and record the origin.

    The origin is only recorded for restraint types that already have an
    entry in ``_restraint_groups``; other types are stored without tracking.
    """
    self._tensor_storage[self._make_key(rtype, origin, prop)] = tensor
    if rtype in self._restraint_groups:
        known_origins = self._restraint_groups[rtype]
        known_origins.add(origin)
def _get_restraint_tensor(
    self, rtype: str, origin: str, prop: str
) -> Optional[torch.Tensor]:
    """Look up one restraint tensor; return None when it is not stored."""
    storage = self._tensor_storage
    key = self._make_key(rtype, origin, prop)
    return storage[key] if key in storage else None
def _has_restraint(self, rtype: str, origin: str) -> bool:
    """Return True when an ``indices`` tensor is stored for this group."""
    return self._make_key(rtype, origin, "indices") in self._tensor_storage
def _set_restraint_group(self, rtype: str, origin: str, data: dict):
    """Store every tensor-valued entry of ``data`` for the given group.

    Entries whose value is not a ``torch.Tensor`` (including None) are
    skipped.
    """
    tensor_items = (
        (name, value)
        for name, value in data.items()
        if isinstance(value, torch.Tensor)
    )
    for name, value in tensor_items:
        self._set_restraint_tensor(rtype, origin, name, value)
def _get_restraint_group(self, rtype: str, origin: str) -> Optional[dict]:
    """Collect the stored tensors of one restraint group into a dict.

    Returns None when the group has no ``indices`` tensor.  Only the
    property names listed below are looked up.
    """
    if not self._has_restraint(rtype, origin):
        return None
    # Common properties for different restraint types
    property_names = (
        "indices",
        "references",
        "sigmas",
        "periods",
        "min_distances",
        "is_proline",
    )
    looked_up = (
        (name, self._get_restraint_tensor(rtype, origin, name))
        for name in property_names
    )
    group = {name: tensor for name, tensor in looked_up if tensor is not None}
    return group or None
def _get_origins_for_type(self, rtype: str) -> list:
    """Get all origins (e.g., 'intra', 'peptide') for a restraint type.

    Origins are tracked in a set, whose iteration order for strings varies
    between interpreter runs.  The result is therefore sorted so that
    anything iterating over origins sees a reproducible order.  Sorting by
    ``str`` keeps this safe if origin keys are not all of one type.

    Returns
    -------
    list
        Sorted origin names; empty if the type is unknown or has no data.
    """
    return sorted(self._restraint_groups.get(rtype, ()), key=str)
@property
def restraints(self) -> "_RestraintsAccessor":
    """
    Dict-like view of all restraints, kept for backward compatibility.

    A new accessor bound to this object is created on every access; it
    mimics the old nested dict interface.
    """
    accessor = _RestraintsAccessor(self)
    return accessor
def _load_cif_dictionaries(self, cif_path):
    """Load CIF dictionaries from provided paths and monomer library.

    Populates ``self.cif_dict`` from ``cif_path`` (a single path or a list
    of paths).  For every entry of ``self.unique_residues`` still missing,
    it then looks the residue up in the monomer library.  Residues that
    remain unresolved are kept in ``self.missing_residues``.

    Parameters
    ----------
    cif_path : str, list of str, or None
        User-supplied CIF file(s).  Falsy values mean "library only".

    Raises
    ------
    ValueError
        If ``cif_path`` is neither a string nor a list, or if ``read_cif``
        raises ``ValueError`` for a user-supplied file.  Other read errors
        are reported and skipped.
    """
    if cif_path:
        if isinstance(cif_path, str):
            try:
                self.cif_dict = read_cif(cif_path)
            except ValueError as e:
                print("Error reading CIF file:", e)
                raise
            except Exception as e:
                # Best effort: fall back to the monomer library only.
                print("Error reading CIF file:", e)
                self.cif_dict = {}
        elif isinstance(cif_path, list):
            self.cif_dict = {}
            for cif_file in cif_path:
                try:
                    cif_dict_part = read_cif(cif_file)
                    self.cif_dict.update(cif_dict_part)
                except ValueError as e:
                    print("Error reading CIF file:", e)
                    raise
                except Exception as e:
                    # Skip the unreadable file, keep what was loaded so far.
                    print("Error reading CIF file:", e)
        else:
            raise ValueError("cif_path must be a string or a list of strings")
    else:
        self.cif_dict = {}
    # Load missing residues from monomer library
    self.missing_residues = [
        res for res in self.unique_residues if res not in self.cif_dict
    ]
    additional_files = [
        find_cif_file_in_library(res) for res in self.missing_residues
    ]
    for cif_file in additional_files:
        if cif_file is None:
            continue
        if self.verbose > 1:
            print(cif_file)
        try:
            additional_cif_dict = read_cif(cif_file)
            self.cif_dict.update(additional_cif_dict)
        except Exception as e:
            print("Error reading CIF file:", e)
            print("This residue will have no restraints applied.")
    self.missing_residues = [
        res for res in self.unique_residues if res not in self.cif_dict
    ]
    # Warn as soon as at least one residue is unresolved.  The previous
    # "> 1" check stayed silent when exactly one residue was missing.
    if self.missing_residues and self.verbose > 0:
        print(
            f"Warning: The following residues are missing from the CIF dictionary "
            f"and will have no restraints applied: {self.missing_residues}"
        )
[docs]
def expand_altloc(self, residue):
    """
    Split a residue with alternative conformations into one frame each.

    Works on a copy; ``""`` and ``" "`` altloc labels are both treated as
    blank.  If blank-altloc atoms exist, each yielded frame is those common
    atoms followed by the atoms of one labelled altloc, with a fresh index.
    A residue that has only blank altlocs therefore yields nothing.  If no
    atom is blank, one frame per altloc label is yielded as-is.

    Parameters
    ----------
    residue : pd.DataFrame
        Atoms of one residue; must have an ``altloc`` column.

    Yields
    ------
    pd.DataFrame
        One conformation per labelled altloc.
    """
    atoms = residue.copy()
    is_blank = atoms["altloc"].isin(["", " "])
    atoms.loc[is_blank, "altloc"] = " "
    labels = atoms["altloc"].unique()
    if " " not in labels:
        # No shared atoms: every altloc stands on its own.
        for label in labels:
            yield atoms.loc[atoms["altloc"] == label]
        return
    shared = atoms.loc[atoms["altloc"] == " "]
    for label in labels:
        if label != " ":
            specific = atoms.loc[atoms["altloc"] == label]
            yield pd.concat([shared, specific], ignore_index=True)
def _load_rama_surfaces(self, device: torch.device):
    """Register the pre-computed Ramachandran NLL surfaces as a buffer.

    The surfaces are obtained from ``load_nll_surfaces(device)`` and stored
    under the buffer name ``_rama_surfaces``.
    """
    from torchref.restraints.ramachandran import load_nll_surfaces

    self.register_buffer("_rama_surfaces", load_nll_surfaces(device))
[docs]
def build_restraints(self):
    """
    Build every restraint group with the fast builder API.

    Order: intra-residue bond/angle/torsion/plane/chiral, then peptide,
    disulfide and LINK restraints, then VDW, then the concatenated groups
    via ``cat_dict()``.  Everything is built on CPU and moved to the
    device of the current coordinates at the end.  Any exception is passed
    to ``debug_on_error`` and re-raised.
    """
    try:
        target_device = self.xyz().device
        device = torch.device("cpu")
        pdb = self.pdb
        # Intra-residue restraints stored under the "intra" origin.  Each
        # builder.build() handles all residues internally.
        intra_builders = (
            ("bond", BondRestraintBuilder),
            ("angle", AngleRestraintBuilder),
            ("torsion", TorsionRestraintBuilder),
        )
        for rtype, builder_cls in intra_builders:
            built = builder_cls(verbose=self.verbose).build(
                pdb, self.cif_dict, device
            )
            if built:
                self.restraints[rtype]["intra"] = built
        # Planes come back keyed by group; store each under its own key.
        planes = PlaneRestraintBuilder(verbose=self.verbose).build(
            pdb, self.cif_dict, device
        )
        if planes:
            for plane_key, plane_data in planes.items():
                self.restraints["plane"][plane_key] = plane_data
        # Chirals are a flat type: the whole dict is assigned at once.
        chirals = ChiralRestraintBuilder(verbose=self.verbose).build(
            pdb, self.cif_dict, device
        )
        if chirals:
            self.restraints["chiral"] = chirals
        # Build inter-residue restraints
        for build_inter in (
            self._build_peptide_restraints,
            self._build_disulfide_restraints,
            self._build_link_restraints,
        ):
            build_inter(device)
        # Build VDW restraints. Cutoff is held ~1 Å wider than the
        # maximum heavy-atom VDW sum (~3.6 Å) plus the expected inter-
        # build drift, so the maintenance-triggered rebuild can be
        # driven by a displacement threshold well inside the cutoff
        # margin without missing newly-formed contacts.
        self._build_vdw_restraints(
            cutoff=6.0, sigma=0.05, inter_residue_only=False, use_spatial_hash=True
        )
        # Pre-compute concatenated 'all' groups so every buffer is registered
        # on the correct device at build time. This prevents register_buffer()
        # being called during a forward pass (which would break CUDA-graph
        # capture) and ensures model.to(device) moves ALL restraint tensors.
        self.cat_dict()
        # Restraint construction (pair searches, CIF-driven topology) is
        # built on CPU for predictability; move buffers to the model's
        # device now so forward passes don't trigger H2D copies.
        if target_device.type != "cpu":
            self.to(target_device)
    except Exception as err:
        self.debug_on_error(err, context="RestraintsNew.build_restraints")
        raise
def _build_peptide_restraints(self, device: torch.device):
    """Build peptide bond restraints using fast inter-residue builders.

    Uses TRANS/CIS links for standard peptide bonds and PTRANS/PCIS
    links for peptide bonds to proline. The proline-specific links
    include the C(i-1)-N-CD angle that constrains the pyrrolidine
    ring orientation, and use proline-specific angle target values.

    Results are stored under the ``peptide`` origin for bonds and angles,
    under ``phi``/``psi``/``omega`` for torsions, and merged into the
    existing ``plane`` groups.  Ramachandran index tensors, when returned by
    the torsion builder, are registered as buffers.

    Parameters
    ----------
    device : torch.device
        Device passed to every builder.
    """
    # Without the TRANS definition nothing below can be built.
    if "TRANS" not in self.link_dict:
        if self.verbose > 0:
            print(
                "Warning: TRANS link not found in link dictionary, skipping peptide bonds"
            )
        return
    trans_link = self.link_dict["TRANS"]
    # PTRANS is optional; None switches the angle code to the TRANS-only path.
    ptrans_link = self.link_dict.get("PTRANS")
    pdb = self.pdb
    # Build peptide bonds using fast builder
    bond_result = InterResidueBondBuilder(verbose=self.verbose).build(
        pdb, trans_link, device, filter_atom_type="ATOM"
    )
    if bond_result:
        self.restraints["bond"]["peptide"] = bond_result
        if self.verbose > 0:
            print(
                f"Built {bond_result['indices'].shape[0]} peptide bond restraints"
            )
    # Build peptide angles.
    # If PTRANS is available, use it for proline pairs (excludes PRO
    # from TRANS to avoid duplicate/conflicting restraints) and TRANS
    # for non-proline pairs. Otherwise fall back to TRANS for all.
    angle_builder = InterResidueAngleBuilder(verbose=self.verbose)
    if ptrans_link is not None:
        # Non-proline pairs: TRANS angles
        angle_result = angle_builder.build(
            pdb, trans_link, device, filter_atom_type="ATOM",
            exclude_next_resname="PRO",
        )
        # Proline pairs: PTRANS angles (includes C-N-CD)
        pro_angle_result = angle_builder.build(
            pdb, ptrans_link, device, filter_atom_type="ATOM",
            next_resname_filter="PRO",
        )
        # Merge results; only indices, references and sigmas are carried
        # over into the merged dict.
        if angle_result and pro_angle_result:
            angle_result = {
                "indices": torch.cat([angle_result["indices"], pro_angle_result["indices"]]),
                "references": torch.cat([angle_result["references"], pro_angle_result["references"]]),
                "sigmas": torch.cat([angle_result["sigmas"], pro_angle_result["sigmas"]]),
            }
        elif pro_angle_result:
            # Only proline pairs produced angles.
            angle_result = pro_angle_result
    else:
        angle_result = angle_builder.build(
            pdb, trans_link, device, filter_atom_type="ATOM"
        )
    if angle_result:
        self.restraints["angle"]["peptide"] = angle_result
        if self.verbose > 0:
            print(
                f"Built {angle_result['indices'].shape[0]} peptide angle restraints"
            )
    # Build backbone torsions (phi, psi, omega)
    torsion_result = InterResidueTorsionBuilder(verbose=self.verbose).build(
        pdb, trans_link, device, filter_atom_type="ATOM"
    )
    if torsion_result:
        # Each backbone torsion family is stored under its own origin.
        if "phi" in torsion_result:
            self.restraints["torsion"]["phi"] = torsion_result["phi"]
        if "psi" in torsion_result:
            self.restraints["torsion"]["psi"] = torsion_result["psi"]
        if "omega" in torsion_result:
            self.restraints["torsion"]["omega"] = torsion_result["omega"]
        if "ramachandran" in torsion_result:
            # Ramachandran data lives in module buffers rather than in the
            # restraint storage.
            rama = torsion_result["ramachandran"]
            self.register_buffer("_rama_phi_indices", rama["phi_indices"])
            self.register_buffer("_rama_psi_indices", rama["psi_indices"])
            self.register_buffer("_rama_surface_type", rama["surface_type"])
            self._load_rama_surfaces(device)
    # Build peptide planes
    plane_result = InterResiduePlaneBuilder(verbose=self.verbose).build(
        pdb, trans_link, device, filter_atom_type="ATOM"
    )
    if plane_result:
        n_planes = 0
        for key, data in plane_result.items():
            n_planes += data["indices"].shape[0]
            if self._has_restraint("plane", key):
                # Append to existing planes of same atom count
                existing = self.restraints["plane"][key]
                self.restraints["plane"][key] = {
                    "indices": torch.cat(
                        [existing["indices"], data["indices"]], dim=0
                    ),
                    "sigmas": torch.cat(
                        [existing["sigmas"], data["sigmas"]], dim=0
                    ),
                }
            else:
                self.restraints["plane"][key] = data
        if self.verbose > 0:
            print(f"Built {n_planes} peptide plane restraints")
def _build_disulfide_restraints(self, device: torch.device):
    """Build disulfide bond restraints.

    SG atoms from ``ATOM`` records are paired when they lie between 0.1 and
    4.0 distance units apart and belong to different (chainid, resseq)
    residues.  Each pair gets a bond restraint with the SG-SG target and
    sigma from the ``disulf`` link.  Angle and torsion restraints are added
    when the link defines them.  Results are stored under the ``disulfide``
    origin.

    Parameters
    ----------
    device : torch.device
        Device passed to the builders' finalize calls.
    """
    if "disulf" not in self.link_dict:
        if self.verbose > 1:
            print(
                "Warning: disulf link not found in link dictionary, skipping disulfide bonds"
            )
        return
    disulf_link = self.link_dict["disulf"]
    disulf_bonds = disulf_link.get("bonds")
    disulf_angles = disulf_link.get("angles")
    disulf_torsions = disulf_link.get("torsions")
    if disulf_bonds is None:
        return
    # Get SG-SG bond parameters
    sg_sg_bond = disulf_bonds[
        (disulf_bonds["atom1"] == "SG") & (disulf_bonds["atom2"] == "SG")
    ]
    if len(sg_sg_bond) == 0:
        return
    # Only the first matching row supplies the target and sigma.
    bond_length = float(sg_sg_bond["value"].values[0])
    bond_sigma = float(sg_sg_bond["sigma"].values[0])
    # Find all SG atoms
    pdb = self.pdb
    sg_atoms = pdb[(pdb["name"] == "SG") & (pdb["ATOM"] == "ATOM")]
    if len(sg_atoms) == 0:
        return
    # Get coordinates and find close pairs
    xyz = self.xyz()
    sg_indices = sg_atoms["index"].values
    sg_coords = xyz[sg_indices]
    # Residue identity string used to reject SG pairs within one residue.
    sg_residues = (
        sg_atoms["chainid"].astype(str) + "_" + sg_atoms["resseq"].astype(str)
    ).values
    distances = torch.cdist(sg_coords, sg_coords)
    threshold = 4.0
    # The lower bound of 0.1 drops the zero self-distances on the diagonal.
    close_pairs = torch.where((distances < threshold) & (distances > 0.1))
    valid_pairs = []
    for i, j in zip(close_pairs[0].cpu().numpy(), close_pairs[1].cpu().numpy()):
        # i < j keeps each unordered pair exactly once.
        if i < j and sg_residues[i] != sg_residues[j]:
            valid_pairs.append((i, j))
    if len(valid_pairs) == 0:
        return
    # Create builders
    bond_builder = InterResidueBondBuilder(verbose=self.verbose)
    angle_builder = InterResidueAngleBuilder(verbose=self.verbose)
    torsion_builder = InterResidueTorsionBuilder(verbose=self.verbose)
    # Process each disulfide bond
    for i_local, j_local in valid_pairs:
        # Map positions within the SG subset back to pdb "index" values.
        sg1_idx = int(sg_indices[i_local])
        sg2_idx = int(sg_indices[j_local])
        # Add bond
        bond_builder.process_disulfide_bond(
            sg1_idx, sg2_idx, bond_length, bond_sigma
        )
        # Get residues for angle/torsion restraints
        residue1 = pdb[pdb["index"] == sg1_idx].iloc[0]
        residue2 = pdb[pdb["index"] == sg2_idx].iloc[0]
        res1_atoms = pdb[
            (pdb["chainid"] == residue1["chainid"])
            & (pdb["resseq"] == residue1["resseq"])
        ]
        res2_atoms = pdb[
            (pdb["chainid"] == residue2["chainid"])
            & (pdb["resseq"] == residue2["resseq"])
        ]
        if disulf_angles is not None:
            angle_builder.process_disulfide_angles(
                res1_atoms, res2_atoms, disulf_angles
            )
        if disulf_torsions is not None:
            torsion_builder.process_disulfide_torsions(
                res1_atoms, res2_atoms, disulf_torsions
            )
    # Finalize
    bond_result = bond_builder.finalize(device)
    if bond_result:
        self.restraints["bond"]["disulfide"] = bond_result
        if self.verbose > 0:
            print(
                f"Built {bond_result['indices'].shape[0]} disulfide bond restraints"
            )
    angle_result = angle_builder.finalize(device)
    if angle_result:
        self.restraints["angle"]["disulfide"] = angle_result
        if self.verbose > 0:
            print(
                f"Built {angle_result['indices'].shape[0]} disulfide angle restraints"
            )
    torsion_result = torsion_builder.finalize_disulfide(device)
    if torsion_result:
        self.restraints["torsion"]["disulfide"] = torsion_result
        if self.verbose > 0:
            print(
                f"Built {torsion_result['indices'].shape[0]} disulfide torsion restraints"
            )
def _build_link_restraints(self, device: torch.device):
    """Build bond restraints from PDB LINK records.

    Each accepted LINK contributes one bond restraint between the two
    named atoms. The bond automatically becomes part of the VDW
    exclusion set (via ``_build_exclusion_set``), preventing the
    non-bonded term from pushing the linked atoms apart.

    Behaviour:

    - Distance/sigma source: ``length`` from the LINK record is used as
      target distance with sigma=0.02 Å. Any value that ``float()`` can
      convert is accepted, provided it is finite and positive. Blank,
      missing, NaN, infinite or non-positive values fall back to a generic
      1.5 Å bond.
    - Symmetry-mate links are filtered out earlier in
      ``extract_link_records``.
    - LINKs that duplicate an auto-detected disulfide (CYS SG-SG pair)
      are skipped, because the disulfide builder has already added a
      bond + angles + torsions for that pair.
    - LINKs whose atoms cannot be resolved, or that resolve to the same
      atom twice, are skipped and counted as unresolved.

    Parameters
    ----------
    device : torch.device
        Device passed to the bond builder's ``finalize``.
    """
    if self.links is None or len(self.links) == 0:
        return
    pdb = self.pdb
    default_length = 1.5
    link_sigma = 0.02
    # Already-bonded SG-SG pairs from auto-disulfide detection.
    disulf = self.restraints.get("bond", {}).get("disulfide")
    existing_disulf_pairs = set()
    if disulf is not None and "indices" in disulf:
        for i, j in disulf["indices"].cpu().numpy():
            existing_disulf_pairs.add((int(min(i, j)), int(max(i, j))))
    bond_builder = InterResidueBondBuilder(verbose=self.verbose)
    n_skipped_unresolved = 0
    n_skipped_dedup = 0
    for _, link in self.links.iterrows():
        idx1 = self._lookup_link_atom(
            pdb,
            chainid=link["chainid1"],
            resseq=int(link["resseq1"]),
            icode=link["icode1"],
            resname=link["resname1"],
            name=link["name1"],
            altloc=link["altloc1"],
        )
        idx2 = self._lookup_link_atom(
            pdb,
            chainid=link["chainid2"],
            resseq=int(link["resseq2"]),
            icode=link["icode2"],
            resname=link["resname2"],
            name=link["name2"],
            altloc=link["altloc2"],
        )
        if idx1 is None or idx2 is None:
            n_skipped_unresolved += 1
            if self.verbose > 1:
                print(
                    f"Warning: LINK atom not found "
                    f"({link['chainid1']}/{link['resname1']}{link['resseq1']}/"
                    f"{link['name1']} -- "
                    f"{link['chainid2']}/{link['resname2']}{link['resseq2']}/"
                    f"{link['name2']}); skipping."
                )
            continue
        if idx1 == idx2:
            # Both ends resolved to the same atom: not a usable bond.
            n_skipped_unresolved += 1
            continue
        pair = (min(idx1, idx2), max(idx1, idx2))
        if pair in existing_disulf_pairs:
            n_skipped_dedup += 1
            continue
        # Coerce rather than isinstance-check, so numeric scalars of any
        # type and numeric strings are honoured, while inf is rejected.
        try:
            length = float(link["length"])
        except (TypeError, ValueError):
            length = float("nan")
        if not (np.isfinite(length) and length > 0):
            length = default_length
        bond_builder.process_disulfide_bond(idx1, idx2, length, link_sigma)
    bond_result = bond_builder.finalize(device)
    if bond_result:
        self.restraints["bond"]["link"] = bond_result
        if self.verbose > 0:
            print(
                f"Built {bond_result['indices'].shape[0]} LINK bond restraints"
                + (
                    f" (skipped {n_skipped_dedup} disulfide-dup,"
                    f" {n_skipped_unresolved} unresolved)"
                    if (n_skipped_dedup or n_skipped_unresolved)
                    else ""
                )
            )
@staticmethod
def _lookup_link_atom(
    pdb: pd.DataFrame,
    chainid: str,
    resseq: int,
    icode: str,
    resname: str,
    name: str,
    altloc: str,
):
    """Resolve a LINK atom record to the ``index`` value of a pdb row.

    Rows must match on (chainid, resseq, icode, name); atom names are
    compared after stripping whitespace.  ``resname`` only narrows the
    selection when at least one row matches it.  Altloc preference:
    requested altloc first, then blank (``""``), then ``"A"``, then the
    first remaining row.

    Returns
    -------
    int or None
        The ``index`` column value of the chosen row, or None if no row
        matches the four key fields.
    """
    matches_chain = pdb["chainid"].astype(str) == str(chainid)
    matches_resseq = pdb["resseq"].astype(int) == int(resseq)
    matches_icode = pdb["icode"].astype(str) == str(icode)
    matches_name = pdb["name"].astype(str).str.strip() == str(name).strip()
    candidates = pdb[matches_chain & matches_resseq & matches_icode & matches_name]
    if len(candidates) == 0:
        return None
    if resname:
        # Tie-breaker only: keep the wider selection if nothing matches.
        wanted = str(resname).strip()
        narrowed = candidates[candidates["resname"].astype(str).str.strip() == wanted]
        if len(narrowed) > 0:
            candidates = narrowed
    # Altloc labels to try, in order of preference.
    preference = [altloc, ""] if altloc else []
    preference += ["", "A"]
    for label in preference:
        chosen = candidates[candidates["altloc"].astype(str) == label]
        if len(chosen) > 0:
            return int(chosen.iloc[0]["index"])
    return int(candidates.iloc[0]["index"])
def _build_exclusion_set(self):
    """Build set of atom pairs to exclude from VDW calculations.

    For every stored bond (1-2), angle (1-3) and torsion (1-4) restraint,
    the pair formed by its first and last atom is excluded.  Pairs are
    stored as ``(low, high)`` tuples of Python ints.

    Returns
    -------
    set of tuple of int
        Excluded atom index pairs.

    Raises
    ------
    ValueError
        If an ``indices`` tensor does not have the column count expected
        for its restraint type.
    """
    exclusions = set()
    restraints = self.restraints
    # Restraint type -> atoms per restraint; the excluded pair is always
    # (first atom, last atom).
    atoms_per_restraint = (("bond", 2), ("angle", 3), ("torsion", 4))
    for rtype, n_cols in atoms_per_restraint:
        groups = restraints.get(rtype, {})
        for origin in groups.keys():
            indices = groups[origin].get("indices")
            if indices is None or len(indices) == 0:
                continue
            idx_np = indices.cpu().numpy()
            if idx_np.ndim != 2 or idx_np.shape[1] != n_cols:
                raise ValueError(
                    f"{rtype}/{origin} indices must have shape (N, {n_cols}), "
                    f"got {idx_np.shape}"
                )
            first, last = idx_np[:, 0], idx_np[:, -1]
            # Vectorised min/max instead of a Python-level loop per row.
            low = np.minimum(first, last).tolist()
            high = np.maximum(first, last).tolist()
            exclusions.update(zip(low, high))
    return exclusions
def _find_nearby_pairs_spatial_hash(self, xyz, cutoff=6.0):
    """
    Find all atom pairs within cutoff distance using spatial cell lists.

    Divides space into cubic cells of side length = cutoff and only checks
    atom pairs in the same or adjacent cells (14 unique offsets: self + 13
    forward neighbours). Each unordered pair of cells is therefore visited
    once. The search runs on CPU with NumPy; the result is returned on the
    device of ``xyz``.

    Parameters
    ----------
    xyz : torch.Tensor
        Atom coordinates of shape (N, 3).
    cutoff : float
        Distance cutoff in Angstroms. Must be positive.

    Returns
    -------
    torch.Tensor
        Pairs of atom indices, shape (M, 2), each row (i, j) with i < j and
        a distance strictly below ``cutoff``.

    Raises
    ------
    ValueError
        If ``cutoff`` is not positive (only checked for non-empty input).
    """
    device = xyz.device
    n_atoms = xyz.shape[0]
    if n_atoms == 0:
        return torch.tensor([], dtype=torch.long, device=device).reshape(0, 2)
    if not cutoff > 0:
        # A zero or negative cell size would divide by zero below.
        raise ValueError(f"cutoff must be positive, got {cutoff!r}")
    # Work on CPU to avoid per-iteration GPU kernel launch overhead
    coords = xyz.detach().cpu()
    cell_size = cutoff
    # Assign each atom to a cubic cell
    xyz_min = coords.min(dim=0).values
    cell_idx = ((coords - xyz_min) / cell_size).long()  # (N, 3)
    grid_dims = cell_idx.max(dim=0).values + 1
    gx, gy, gz = (int(extent) for extent in grid_dims.tolist())
    gyz = gy * gz
    # Flat cell index per atom
    flat = cell_idx[:, 0] * gyz + cell_idx[:, 1] * gz + cell_idx[:, 2]
    # Sort atoms by cell so each cell's atoms are contiguous
    order = flat.argsort()
    sorted_flat = flat[order]
    unique_cells, counts = torch.unique_consecutive(
        sorted_flat, return_counts=True
    )
    n_unique = len(unique_cells)
    starts = torch.zeros(n_unique + 1, dtype=torch.long)
    starts[1:] = counts.cumsum(0)
    # Move to numpy for tight loop (faster item access than torch on CPU)
    unique_np = unique_cells.numpy()
    starts_np = starts.numpy()
    order_np = order.numpy()
    coords_np = coords.numpy()
    # Lookup: flat_cell -> index in unique_cells (-1 if empty).  Kept as a
    # NumPy array so the inner loop never indexes a torch tensor.
    cell_lookup = np.full(gx * gyz, -1, dtype=np.int64)
    cell_lookup[unique_np] = np.arange(n_unique)
    # 14 unique neighbour offsets: self (0,0,0) + 13 forward neighbours.
    # "Forward" = first non-zero component is positive, avoiding double counting.
    offsets_list = [
        (dx, dy, dz)
        for dx in range(-1, 2)
        for dy in range(-1, 2)
        for dz in range(-1, 2)
        if dx > 0 or (dx == 0 and dy > 0) or (dx == 0 and dy == 0 and dz >= 0)
    ]
    cutoff_sq = cutoff * cutoff
    pair_chunks = []
    for ci in range(n_unique):
        cell_flat = int(unique_np[ci])
        sa, ea = int(starts_np[ci]), int(starts_np[ci + 1])
        atoms_a = order_np[sa:ea]
        xyz_a = coords_np[atoms_a]  # (na, 3)
        # Recover the 3-D cell coordinates from the flat index.
        cx = cell_flat // gyz
        cy = (cell_flat % gyz) // gz
        cz = cell_flat % gz
        for dx, dy, dz in offsets_list:
            ncx, ncy, ncz = cx + dx, cy + dy, cz + dz
            if not (0 <= ncx < gx and 0 <= ncy < gy and 0 <= ncz < gz):
                continue
            nb_ci = int(cell_lookup[ncx * gyz + ncy * gz + ncz])
            if nb_ci < 0:
                continue
            if dx == 0 and dy == 0 and dz == 0:
                # Self-cell: upper triangle only
                na = len(atoms_a)
                if na < 2:
                    continue
                diff = xyz_a[:, None, :] - xyz_a[None, :, :]
                dist_sq = (diff * diff).sum(axis=-1)  # (na, na)
                ii, jj = np.triu_indices(na, k=1)
                mask = dist_sq[ii, jj] < cutoff_sq
                if mask.any():
                    ai = atoms_a[ii[mask]]
                    aj = atoms_a[jj[mask]]
                    pair_chunks.append(
                        np.stack([np.minimum(ai, aj), np.maximum(ai, aj)], axis=1)
                    )
            else:
                # Inter-cell: all pairs
                sb, eb = int(starts_np[nb_ci]), int(starts_np[nb_ci + 1])
                atoms_b = order_np[sb:eb]
                xyz_b = coords_np[atoms_b]  # (nb, 3)
                # Vectorised distance² via broadcasting: (na, nb, 3)
                diff = xyz_a[:, None, :] - xyz_b[None, :, :]
                dist_sq = (diff * diff).sum(axis=-1)  # (na, nb)
                ii, jj = np.where(dist_sq < cutoff_sq)
                if len(ii) > 0:
                    ai = atoms_a[ii]
                    bj = atoms_b[jj]
                    pair_chunks.append(
                        np.stack([np.minimum(ai, bj), np.maximum(ai, bj)], axis=1)
                    )
    if pair_chunks:
        all_pairs = np.concatenate(pair_chunks, axis=0)
        return torch.from_numpy(all_pairs).to(dtype=torch.long, device=device)
    return torch.tensor([], dtype=torch.long, device=device).reshape(0, 2)
def _expand_with_symmetry_mates(self, xyz, cutoff):
    """
    Assemble the coordinate set used for symmetry-aware neighbor search.

    The detached ASU coordinates form the leading block.  Every
    (symmetry operation, unit-cell shift) combination with shifts in
    {-1, 0, 1} per axis is screened by how far it displaces the ASU
    centroid; combinations displacing it by at most
    ``2 * molecule_radius + cutoff`` contribute one further block holding
    all N transformed atoms.  The identity operation with zero shift is
    never added as a mate.

    Parameters
    ----------
    xyz : torch.Tensor
        ASU Cartesian coordinates of shape (N, 3).
    cutoff : float
        Distance cutoff in Angstroms for contact search.

    Returns
    -------
    combined_xyz : torch.Tensor
        Concatenated coordinates (N_asu + N_mates, 3).
    provenance : dict
        Dictionary with arrays describing the origin of each atom:
        - 'asu_source_indices': (N_total,) int array, ASU atom index
        - 'symop_indices': (N_total,) int array, symmetry operation index
        - 'cell_offsets': (N_total, 3) int array, unit cell offset
    """
    from torchref.config import dtypes
    from torchref.symmetry import SpaceGroup

    unit_cell = self._cell
    space_group = self._spacegroup
    if not isinstance(space_group, SpaceGroup):
        space_group = SpaceGroup(space_group)

    n_asu = xyz.shape[0]
    device = xyz.device
    fdtype = dtypes.float

    # Detached copy in the configured float dtype; stays on xyz's device.
    asu_cart = xyz.detach().to(fdtype)
    asu_frac = unit_cell.cartesian_to_fractional(asu_cart)

    # Centroid and bounding radius drive the cheap pre-filter below.
    centroid_frac = asu_frac.mean(dim=0)
    centroid_cart = asu_cart.mean(dim=0)
    molecule_radius = (asu_cart - centroid_cart).norm(dim=1).max().item()
    reach = 2 * molecule_radius + cutoff

    frac_matrix = unit_cell.fractional_matrix.to(device=device, dtype=fdtype)
    identity = torch.eye(3, dtype=fdtype, device=device)

    n_ops = space_group.n_ops
    rotations = space_group.matrices.to(device=device, dtype=fdtype)
    op_translations = space_group.translations.to(device=device, dtype=fdtype)

    # Phase 1: screen every (operation, lattice shift) by centroid displacement.
    lattice_shifts = [
        (dx, dy, dz)
        for dx in (-1, 0, 1)
        for dy in (-1, 0, 1)
        for dz in (-1, 0, 1)
    ]
    accepted = []  # entries are (op_idx, dx, dy, dz)
    for op_idx in range(n_ops):
        # Shift-independent part of the centroid displacement.
        centroid_shift = (
            (rotations[op_idx] - identity) @ centroid_frac + op_translations[op_idx]
        )
        for shift in lattice_shifts:
            if op_idx == 0 and shift == (0, 0, 0):
                # Identity with no lattice shift is the ASU itself.
                continue
            offset = torch.tensor(shift, dtype=fdtype, device=device)
            displacement = frac_matrix @ (centroid_shift + offset)
            if displacement.norm().item() <= reach:
                accepted.append((op_idx, *shift))

    if not accepted:
        provenance = {
            "asu_source_indices": np.arange(n_asu, dtype=np.int64),
            "symop_indices": np.zeros(n_asu, dtype=np.int64),
            "cell_offsets": np.zeros((n_asu, 3), dtype=np.int64),
        }
        if self.verbose > 0:
            print(" Symmetry expansion: 0 mate(s) within range "
                  f"({n_asu} total atoms for neighbor search)")
        return asu_cart, provenance

    # Phase 2: transform all atoms for every accepted combination at once.
    n_valid = len(accepted)
    op_indices = [entry[0] for entry in accepted]
    lattice = [list(entry[1:]) for entry in accepted]
    cell_offs = torch.tensor(lattice, dtype=fdtype, device=device)  # (n_valid, 3)

    rot_batch = rotations[op_indices]  # (n_valid, 3, 3)
    trans_batch = op_translations[op_indices]  # (n_valid, 3)

    # (n_valid, 3, N) -> batched R @ x -> (n_valid, N, 3)
    frac_columns = asu_frac.T.unsqueeze(0).expand(n_valid, -1, -1)
    mate_frac = torch.bmm(rot_batch, frac_columns).permute(0, 2, 1)
    mate_frac = mate_frac + trans_batch.unsqueeze(1) + cell_offs.unsqueeze(1)

    mate_cart = unit_cell.fractional_to_cartesian(mate_frac.reshape(-1, 3))
    combined_xyz = torch.cat([asu_cart, mate_cart], dim=0)

    # Provenance: one ASU block followed by one block of n_asu atoms per mate.
    n_blocks = n_valid + 1
    provenance = {
        "asu_source_indices": np.tile(
            np.arange(n_asu, dtype=np.int64), n_blocks
        ),
        "symop_indices": np.repeat(
            np.array([0] + op_indices, dtype=np.int64), n_asu
        ),
        "cell_offsets": np.repeat(
            np.array([[0, 0, 0]] + lattice, dtype=np.int64), n_asu, axis=0
        ),
    }

    if self.verbose > 0:
        print(f" Symmetry expansion: {n_valid} mate(s) within range "
              f"({combined_xyz.shape[0]} total atoms for neighbor search)")
    return combined_xyz, provenance
@property
def h_topo(self):
    """Riding hydrogen topology, or ``None`` when it has not been built."""
    try:
        return self._h_topo
    except AttributeError:
        return None
@property
def h_excl_hash(self):
    """H-specific exclusion hash tensor, or ``None`` when it has not been built."""
    try:
        return self._h_excl_hash
    except AttributeError:
        return None
def _build_h_exclusion_hash(self, h_topo, device):
    """Hash the 1-2 and 1-3 partners of every riding hydrogen.

    Hydrogen ``hi`` is addressed by its position in the combined atom
    array, ``len(self.pdb) + hi``.  It is paired with its parent (1-2) and
    with each non-negative entry among the first
    ``parent_neighbor_count[hi]`` columns of ``parent_neighbor_idx[hi]``
    (1-3).  Every pair is encoded as ``min(i, j) * max_idx + max(i, j)``
    and the codes are returned in ascending order, ready for
    ``torch.searchsorted`` lookups.

    Parameters
    ----------
    h_topo : HydrogenTopology
        Topology providing ``n_hydrogens``, ``h_parent_idx``,
        ``parent_neighbor_idx`` and ``parent_neighbor_count``; may be None.
    device : torch.device
        Device on which the returned tensor is created.

    Returns
    -------
    torch.Tensor
        Sorted 1-D long tensor of exclusion hashes (empty when there are
        no hydrogens).
    """
    if h_topo is None or h_topo.n_hydrogens == 0:
        return torch.tensor([], dtype=torch.long, device=device)

    heavy_total = len(self.pdb)
    hydrogen_total = h_topo.n_hydrogens
    excluded_pairs = set()

    parents = h_topo.h_parent_idx.cpu().numpy()
    neighbour_table = h_topo.parent_neighbor_idx.cpu().numpy()
    neighbour_counts = h_topo.parent_neighbor_count.cpu().numpy()

    def exclude(atom_a, atom_b):
        # Store pairs with the smaller index first so (i, j) == (j, i).
        if atom_a <= atom_b:
            excluded_pairs.add((atom_a, atom_b))
        else:
            excluded_pairs.add((atom_b, atom_a))

    for h_local in range(hydrogen_total):
        h_global = heavy_total + h_local
        # 1-2: hydrogen and its parent.
        exclude(h_global, int(parents[h_local]))
        # 1-3: hydrogen and the parent's heavy neighbours (negative = unused slot).
        for slot in range(int(neighbour_counts[h_local])):
            neighbour = int(neighbour_table[h_local, slot])
            if neighbour < 0:
                continue
            exclude(h_global, neighbour)

    if not excluded_pairs:
        return torch.tensor([], dtype=torch.long, device=device)

    pair_array = np.array(list(excluded_pairs), dtype=np.int64)
    # The hash base must exceed every index that appears in a pair.
    base = max(heavy_total + hydrogen_total, int(pair_array.max()) + 1)
    encoded = np.sort(pair_array[:, 0] * base + pair_array[:, 1])
    return torch.tensor(encoded, dtype=torch.long, device=device)
def _build_vdw_restraints(
    self, cutoff=6.0, sigma=0.2, inter_residue_only=True, use_spatial_hash=True
):
    """Build van der Waals (non-bonded contact) restraints.

    When cell and spacegroup are available, also includes contacts between
    ASU atoms and symmetry-related copies in neighboring molecules.

    When both ``_cell`` and ``_spacegroup`` are set, the pair list comes
    from ``build_vdw_restraints_gpu``; note that, despite that name, it is
    called here with CPU copies of the coordinates, radii, cell and space
    group. Otherwise the legacy builder
    (:meth:`_build_vdw_restraints_legacy`) is used.

    Also builds the riding hydrogen topology for H-VDW evaluation.

    Caches the build kwargs and a detached snapshot of the ASU
    coordinates at build time in ``_vdw_build_kwargs`` and
    ``_last_vdw_build_xyz``. :meth:`rebuild_vdw_restraints` consults
    those to refresh the pair list with the same parameters, and
    ``NonBondedTarget.maintenance`` uses the snapshot to decide
    whether a rebuild is needed.

    Parameters
    ----------
    cutoff : float, default 6.0
        Forwarded to the selected pair-list builder as its ``cutoff``.
    sigma : float, default 0.2
        Forwarded to the selected pair-list builder as its ``sigma``.
    inter_residue_only : bool, default True
        Forwarded to the selected pair-list builder.
    use_spatial_hash : bool, default True
        Forwarded only to the legacy builder; the symmetry-aware path
        does not receive it.

    Notes
    -----
    Side effects: sets ``self.restraints["vdw"]``, ``self._h_topo``,
    ``self._h_excl_hash``, ``self._vdw_build_kwargs`` and (when an xyz
    callable is present) ``self._last_vdw_build_xyz``; finally calls
    ``self.to(target_device)`` if the model is not on CPU.
    """
    # Remember how we built so rebuild can call back with the same
    # parameters without re-plumbing them through every caller.
    self._vdw_build_kwargs = dict(
        cutoff=cutoff,
        sigma=sigma,
        inter_residue_only=inter_residue_only,
        use_spatial_hash=use_spatial_hash,
    )
    if self.verbose > 0:
        print("\nBuilding VDW (non-bonded) restraints...")
    has_symmetry = (
        self._cell is not None
        and self._spacegroup is not None
    )
    # Restraint build (neighbor search, H topology, exclusion hashing)
    # runs on CPU: pair lists are O(N) integers with launch overhead
    # that dominates any GPU benefit, and the underlying searches are
    # only called at build time. Newly-registered buffers, h_topo, and
    # h_excl_hash are migrated to the model device at the end of this
    # function so both the initial build and the maintenance-triggered
    # rebuild path land on the right device.
    cpu = torch.device("cpu")
    # Device to migrate back to once the CPU-side build has finished.
    target_device = self.xyz().device if self._xyz_fn is not None else cpu

    # Zero-argument callables handed to the builder so it always sees
    # detached CPU copies of the current coordinates / radii.
    def xyz_cpu():
        return self.xyz().detach().to(cpu)

    def vdw_radii_cpu():
        return self.get_vdw_radii().detach().to(cpu)

    # Construct fresh CPU copies — Cell/SpaceGroup ``.to()`` mutates
    # in place, which would silently relocate the model's own Cell/SG.
    if self._cell is not None:
        from torchref.symmetry.cell import Cell
        cell_cpu = Cell(self._cell._data.detach(), device=cpu,
                        dtype=self._cell.dtype)
    else:
        cell_cpu = None
    if self._spacegroup is not None:
        from torchref.symmetry.spacegroup import SpaceGroup
        sg_cpu = SpaceGroup(self._spacegroup, device=cpu,
                            dtype=self._spacegroup._dtype)
    else:
        sg_cpu = None
    if has_symmetry:
        from torchref.restraints.neighbor_search import build_vdw_restraints_gpu
        exclusions = self._build_exclusion_set()
        self.restraints["vdw"] = build_vdw_restraints_gpu(
            xyz_fn=xyz_cpu,
            vdw_radii_fn=vdw_radii_cpu,
            cell=cell_cpu,
            sg=sg_cpu,
            pdb=self.pdb,
            exclusion_set=exclusions,
            cutoff=cutoff,
            sigma=sigma,
            inter_residue_only=inter_residue_only,
            verbose=self.verbose,
        )
    else:
        # NOTE(review): the legacy builder reads self.xyz() directly, so its
        # tensors are created on whatever device xyz() returns — confirm this
        # is intended given the CPU-build comment above.
        self._build_vdw_restraints_legacy(
            cutoff=cutoff, sigma=sigma,
            inter_residue_only=inter_residue_only,
            use_spatial_hash=use_spatial_hash,
        )
    # Build riding hydrogen topology and precompute candidate pairs
    from torchref.restraints.hydrogen_topology import (
        build_hydrogen_topology,
        build_h_candidate_pairs,
    )
    self._h_topo = build_hydrogen_topology(
        pdb=self.pdb,
        device=cpu,
        verbose=self.verbose,
    )
    self._h_excl_hash = self._build_h_exclusion_hash(self._h_topo, cpu)
    # Precompute H candidate pairs from heavy-atom VDW pair list
    vdw_data = self.restraints.get("vdw")
    if vdw_data is not None and self._h_topo.n_hydrogens > 0:
        build_h_candidate_pairs(
            h_topo=self._h_topo,
            vdw_data=vdw_data,
            pdb=self.pdb,
            h_excl_hash=self._h_excl_hash,
            device=cpu,
            verbose=self.verbose,
        )
        # Fill in VDW min distances using combined radii array
        if self._h_topo.has_candidates:
            heavy_radii = vdw_radii_cpu()  # (N_heavy,)
            h_radii = self._h_topo.h_vdw_radius  # (N_h,) on CPU
            # Heavy atoms first, hydrogens appended: matches the combined
            # indexing used by the candidate index arrays.
            all_radii = torch.cat([heavy_radii, h_radii])
            self._h_topo.cand_min_dist = (
                all_radii[self._h_topo.cand_idx_i]
                + all_radii[self._h_topo.cand_idx_j]
            )
    # Snapshot the ASU coordinates *at* build time so maintenance()
    # callers can diff current positions against it and decide if a
    # rebuild is needed. Detached clone lives on the model's device so
    # the compare is a single op on whatever device xyz() returns.
    if self._xyz_fn is not None:
        self._last_vdw_build_xyz = self.xyz().detach().clone()
    # Migrate the freshly-built VDW pair list, h_topo, and h_excl_hash
    # from their CPU build device to the model's device. Required for
    # both the initial build (the outer build_restraints also calls
    # .to(target_device) — a no-op when already migrated) and for the
    # maintenance-triggered rebuild, which has no surrounding migration.
    if target_device.type != "cpu":
        self.to(target_device)
[docs]
def rebuild_vdw_restraints(self) -> None:
    """Re-run the VDW build with the arguments cached by the last build.

    Intended for :meth:`NonBondedTarget.maintenance`, which calls this
    once atoms have moved further than its rebuild threshold since the
    pair list was last generated.  The ``cutoff``, ``sigma``,
    ``inter_residue_only`` and ``use_spatial_hash`` values stored in
    ``_vdw_build_kwargs`` are replayed unchanged, keeping behaviour
    stable across the run.

    Raises
    ------
    RuntimeError
        If no build has happened yet (``_vdw_build_kwargs`` is absent).
    """
    if hasattr(self, "_vdw_build_kwargs"):
        self._build_vdw_restraints(**self._vdw_build_kwargs)
        return
    raise RuntimeError(
        "rebuild_vdw_restraints called before initial build "
        "— _vdw_build_kwargs is missing"
    )
def _build_vdw_restraints_legacy(
    self, cutoff=5.0, sigma=0.2, inter_residue_only=True, use_spatial_hash=True
):
    """Legacy VDW restraint builder (spatial hash or brute-force search).

    Finds atom pairs closer than ``cutoff`` and stores them in
    ``self.restraints["vdw"]`` as a dict with keys ``indices``,
    ``min_distances``, ``sigmas``, ``symop_indices`` and ``cell_offsets``.
    An all-empty dict with the same keys is stored when no pair survives.

    Although :meth:`_build_vdw_restraints` only calls this when no
    crystallographic symmetry is available, this method itself still
    expands the coordinates with symmetry mates whenever both ``_cell``
    and ``_spacegroup`` are set. In that case pairs with no ASU atom are
    dropped, the ASU atom is stored first, both stored indices refer to
    ASU source atoms, and ``symop_indices`` / ``cell_offsets`` describe
    the second atom of the pair.

    Parameters
    ----------
    cutoff : float, default 5.0
        Pairs are kept when their squared distance is below ``cutoff**2``.
        Note the default differs from the 6.0 used by
        :meth:`_build_vdw_restraints`.
    sigma : float, default 0.2
        Value written to every entry of ``sigmas``.
    inter_residue_only : bool, default True
        If True, intra-ASU pairs sharing ``chainid`` and ``resseq`` are
        removed.
    use_spatial_hash : bool, default True
        If True, use ``_find_nearby_pairs_spatial_hash``; otherwise fall
        back to a pure-Python O(N^2) loop over all atom pairs.

    Notes
    -----
    Bonded exclusions, the same-residue filter and the altloc filter are
    applied to intra-ASU pairs only. ``min_distances`` is the sum of the
    two atoms' VDW radii.
    """
    exclusions = self._build_exclusion_set()
    vdw_radii = self.get_vdw_radii()
    xyz = self.xyz()
    device = xyz.device
    pdb = self.pdb
    # NOTE(review): n_asu is assigned but not used below.
    n_asu = xyz.shape[0]
    # Expand with symmetry mates if crystallographic info is available
    has_symmetry = (
        self._cell is not None
        and self._spacegroup is not None
    )
    if has_symmetry:
        combined_xyz, provenance = self._expand_with_symmetry_mates(xyz, cutoff)
    else:
        combined_xyz = xyz
        provenance = None
    # Find nearby pairs in the (potentially expanded) coordinate set
    if use_spatial_hash:
        nearby_pairs = self._find_nearby_pairs_spatial_hash(combined_xyz, cutoff)
    else:
        # Brute-force fallback: every i < j pair is tested individually.
        n_total = combined_xyz.shape[0]
        pairs_list = []
        cutoff_sq = cutoff**2
        for i in range(n_total):
            for j in range(i + 1, n_total):
                dist_sq = ((combined_xyz[i] - combined_xyz[j]) ** 2).sum()
                if dist_sq < cutoff_sq:
                    pairs_list.append([i, j])
        nearby_pairs = (
            torch.tensor(pairs_list, dtype=torch.long, device=device)
            if pairs_list
            else torch.tensor([], dtype=torch.long, device=device).reshape(0, 2)
        )
    # Result stored whenever no pair survives; same keys as the full result.
    empty_result = {
        "indices": torch.tensor([], dtype=torch.long, device=device).reshape(0, 2),
        "min_distances": torch.tensor([], dtype=get_float_dtype(), device=device),
        "sigmas": torch.tensor([], dtype=get_float_dtype(), device=device),
        "symop_indices": torch.tensor([], dtype=torch.long, device=device),
        "cell_offsets": torch.tensor([], dtype=torch.long, device=device).reshape(0, 3),
    }
    if len(nearby_pairs) == 0:
        self.restraints["vdw"] = empty_result
        return
    pairs_np = nearby_pairs.cpu().numpy()
    # Map indices through provenance to get ASU source atoms and symop info
    if provenance is not None:
        prov_asu = provenance["asu_source_indices"]
        prov_sym = provenance["symop_indices"]
        prov_off = provenance["cell_offsets"]
        # Get provenance for each atom in each pair
        # (fancy indexing yields fresh arrays, so the in-place swaps below
        # do not touch the provenance arrays themselves)
        idx0 = pairs_np[:, 0]
        idx1 = pairs_np[:, 1]
        asu_src_0 = prov_asu[idx0]
        asu_src_1 = prov_asu[idx1]
        sym_0 = prov_sym[idx0]
        sym_1 = prov_sym[idx1]
        off_0 = prov_off[idx0]
        off_1 = prov_off[idx1]
        # An atom belongs to the ASU when it uses operation 0 with zero offset.
        is_asu_0 = (sym_0 == 0) & (off_0 == 0).all(axis=1)
        is_asu_1 = (sym_1 == 0) & (off_1 == 0).all(axis=1)
        # Keep only pairs where at least one atom is from the ASU
        has_asu = is_asu_0 | is_asu_1
        pairs_np = pairs_np[has_asu]
        asu_src_0 = asu_src_0[has_asu]
        asu_src_1 = asu_src_1[has_asu]
        sym_0 = sym_0[has_asu]
        sym_1 = sym_1[has_asu]
        off_0 = off_0[has_asu]
        off_1 = off_1[has_asu]
        is_asu_0 = is_asu_0[has_asu]
        is_asu_1 = is_asu_1[has_asu]
        # Normalize: put the ASU atom in position 0, mate in position 1
        # For intra-ASU pairs (both ASU), keep as-is (both are ASU anyway)
        # For symmetry pairs: swap so ASU is first
        swap = ~is_asu_0 & is_asu_1
        if swap.any():
            asu_src_0[swap], asu_src_1[swap] = asu_src_1[swap].copy(), asu_src_0[swap].copy()
            sym_0[swap], sym_1[swap] = sym_1[swap].copy(), sym_0[swap].copy()
            off_0[swap], off_1[swap] = off_1[swap].copy(), off_0[swap].copy()
            is_asu_0[swap] = True
            is_asu_1[swap] = False
        # Final indices: ASU atom indices for both atoms in each pair
        final_i1 = asu_src_0
        final_i2 = asu_src_1
        # Symmetry info comes from the mate atom (position 1)
        final_symop = sym_1
        final_offsets = off_1
        is_both_asu = is_asu_0 & is_asu_1
    else:
        # No symmetry: all pairs are intra-ASU
        final_i1 = pairs_np[:, 0]
        final_i2 = pairs_np[:, 1]
        final_symop = np.zeros(len(pairs_np), dtype=np.int64)
        final_offsets = np.zeros((len(pairs_np), 3), dtype=np.int64)
        is_both_asu = np.ones(len(pairs_np), dtype=bool)
    # --- Filtering ---
    # Bonded exclusions, same-residue, and altloc filters apply only to
    # intra-ASU pairs. Symmetry pairs cannot be bonded.
    # Start with all pairs kept
    keep_mask = np.ones(len(final_i1), dtype=bool)
    # Exclusion mask (bonded 1-2, 1-3, 1-4) -- intra-ASU only
    if exclusions and is_both_asu.any():
        exclusion_arr = np.array(list(exclusions), dtype=np.int64)
        # Hash base: larger than any index that can occur in a pair.
        max_idx = max(
            pdb["index"].max() + 1,
            final_i1[is_both_asu].max() + 1,
            final_i2[is_both_asu].max() + 1,
        )
        # Normalize pair order for comparison
        norm_i1 = np.minimum(final_i1, final_i2)
        norm_i2 = np.maximum(final_i1, final_i2)
        pair_hash = norm_i1 * max_idx + norm_i2
        # NOTE(review): assumes each exclusion tuple is already ordered
        # (smaller index first) — confirm against _build_exclusion_set.
        excl_hash = exclusion_arr[:, 0] * max_idx + exclusion_arr[:, 1]
        is_excluded = np.isin(pair_hash, excl_hash)
        # Only apply to intra-ASU pairs
        keep_mask &= ~(is_excluded & is_both_asu)
    # Inter-residue mask -- intra-ASU only
    if inter_residue_only:
        chainid_array = pdb["chainid"].values
        resseq_array = pdb["resseq"].values
        same_residue = (
            (chainid_array[final_i1] == chainid_array[final_i2])
            & (resseq_array[final_i1] == resseq_array[final_i2])
        )
        keep_mask &= ~(same_residue & is_both_asu)
    # Altloc compatibility -- intra-ASU only
    if "altloc" in pdb.columns:
        altloc_array = pdb["altloc"].values.astype(str)
        # Treat the empty string and a single blank as the same "no altloc".
        altloc_array = np.where(
            np.isin(altloc_array, ["", " "]), " ", altloc_array
        )
        altloc_i = altloc_array[final_i1]
        altloc_j = altloc_array[final_i2]
        # Two atoms clash only if both carry an altloc label and they differ.
        incompatible_altloc = (
            (altloc_i != " ") & (altloc_j != " ") & (altloc_i != altloc_j)
        )
        keep_mask &= ~(incompatible_altloc & is_both_asu)
    # Apply filter
    final_i1 = final_i1[keep_mask]
    final_i2 = final_i2[keep_mask]
    final_symop = final_symop[keep_mask]
    final_offsets = final_offsets[keep_mask]
    if len(final_i1) == 0:
        self.restraints["vdw"] = empty_result
        return
    # Compute min distances using VDW radii of ASU source atoms.
    vdw_np = vdw_radii.cpu().numpy()
    min_distances = vdw_np[final_i1] + vdw_np[final_i2]
    # Store results
    final_pairs = np.stack([final_i1, final_i2], axis=1)
    self.restraints["vdw"] = {
        "indices": torch.tensor(final_pairs, dtype=torch.long, device=device),
        "min_distances": torch.tensor(
            min_distances, dtype=get_float_dtype(), device=device
        ),
        "sigmas": torch.full(
            (len(final_pairs),), sigma, dtype=get_float_dtype(), device=device
        ),
        "symop_indices": torch.tensor(
            final_symop, dtype=torch.long, device=device
        ),
        "cell_offsets": torch.tensor(
            final_offsets, dtype=torch.long, device=device
        ),
    }
    if self.verbose > 0:
        scope = "inter-residue" if inter_residue_only else "all"
        msg = f" Built {len(final_pairs)} VDW restraints ({scope} contacts)"
        if has_symmetry:
            # A pair is a symmetry contact when its second atom is not in the ASU.
            is_sym_pair = (final_symop != 0) | (final_offsets != 0).any(axis=1)
            n_sym_count = int(is_sym_pair.sum())
            msg += f", {n_sym_count} symmetry contacts"
        print(msg)
# Device movement is handled automatically by TensorDict (registered as _tensor_storage)
# through PyTorch's Module.to(), cuda(), and cpu() methods
[docs]
def summary(self):
    """Print a detailed summary of all restraints.

    Writes to stdout, in order: the CIF path and dictionary size, the
    intra-residue counts (bonds, angles, torsions, planes, chirals), the
    inter-residue counts, the backbone torsion counts (plus the
    Ramachandran count when non-zero) and the VDW contact count, split
    into intra-ASU and symmetry contacts when symmetry contacts exist.

    A VDW contact counts as a symmetry contact when its ``symop_indices``
    entry is non-zero or any component of its ``cell_offsets`` row is
    non-zero.  A missing ``cell_offsets`` entry is treated as "no
    offsets" rather than raising.
    """
    print("=" * 80)
    print("Restraints Summary (New Implementation)")
    print("=" * 80)
    print(f"CIF file: {self.cif_path}")
    print(f"Residue types in dictionary: {len(self.cif_dict)}")
    print()

    def get_count(rtype, origin):
        # Number of restraints of this type/origin; 0 when absent.
        indices = self.restraints.get(rtype, {}).get(origin, {}).get("indices")
        return 0 if indices is None else indices.shape[0]

    print("INTRA-RESIDUE RESTRAINTS:")
    print("-" * 80)
    print(f" Bonds: {get_count('bond', 'intra')}")
    print(f" Angles: {get_count('angle', 'intra')}")
    print(f" Torsions: {get_count('torsion', 'intra')}")
    # Planes are stored under several origins; add them all up.
    n_planes = sum(
        get_count("plane", key)
        for key in self.restraints.get("plane", {}).keys()
    )
    print(f" Planes: {n_planes}")
    # Chiral restraints have no origin level.
    chiral_count = 0
    if "chiral" in self.restraints:
        indices = self.restraints["chiral"].get("indices")
        chiral_count = 0 if indices is None else indices.shape[0]
    print(f" Chirals: {chiral_count}")
    print()
    print("INTER-RESIDUE RESTRAINTS:")
    print("-" * 80)
    print(f" Peptide bonds: {get_count('bond', 'peptide')}")
    print(f" Peptide angles: {get_count('angle', 'peptide')}")
    print(f" Disulfide bonds: {get_count('bond', 'disulfide')}")
    print(f" Disulfide angles: {get_count('angle', 'disulfide')}")
    print(f" Disulfide torsions: {get_count('torsion', 'disulfide')}")
    print(f" LINK bonds: {get_count('bond', 'link')}")
    print()
    print("BACKBONE TORSIONS:")
    print("-" * 80)
    print(f" Phi: {get_count('torsion', 'phi')}")
    print(f" Psi: {get_count('torsion', 'psi')}")
    print(f" Omega: {get_count('torsion', 'omega')}")
    # Ramachandran
    rama_count = 0
    if hasattr(self, "_rama_phi_indices") and self._rama_phi_indices is not None:
        rama_count = self._rama_phi_indices.shape[0]
    if rama_count > 0:
        print(f" Ramachandran: {rama_count}")
    print()
    print("VDW RESTRAINTS:")
    print("-" * 80)
    vdw_count = 0
    vdw_sym_count = 0
    if "vdw" in self.restraints:
        indices = self.restraints["vdw"].get("indices")
        vdw_count = 0 if indices is None else indices.shape[0]
        symop_indices = self.restraints["vdw"].get("symop_indices")
        cell_offsets = self.restraints["vdw"].get("cell_offsets")
        if symop_indices is not None and len(symop_indices) > 0:
            is_sym = symop_indices != 0
            # cell_offsets may be absent; then only the operation index decides.
            if cell_offsets is not None:
                is_sym = is_sym | (cell_offsets != 0).any(dim=-1)
            vdw_sym_count = int(is_sym.sum().item())
    vdw_asu_count = vdw_count - vdw_sym_count
    if vdw_sym_count > 0:
        print(f" Non-bonded contacts: {vdw_count} ({vdw_asu_count} intra-ASU, {vdw_sym_count} symmetry)")
    else:
        print(f" Non-bonded contacts: {vdw_count}")
    print("=" * 80)
[docs]
def __repr__(self):
    """Return a one-line summary with the main restraint counts."""

    def n_rows(rtype, origin):
        # Row count of the index tensor for this type/origin; 0 if absent.
        idx = self.restraints.get(rtype, {}).get(origin, {}).get("indices")
        if idx is None:
            return 0
        return idx.shape[0]

    intra_bonds, intra_angles, intra_torsions = (
        n_rows(kind, "intra") for kind in ("bond", "angle", "torsion")
    )
    peptide_bonds = n_rows("bond", "peptide")
    return (
        f"RestraintsNew(bonds={intra_bonds}, angles={intra_angles}, "
        f"torsions={intra_torsions}, peptide_bonds={peptide_bonds})"
    )
def _get_all_indices(self, restraint_type, keys_to_merge=None):
    """
    Concatenate the index tensors of one restraint type over its origins.

    Parameters
    ----------
    restraint_type : str
        Type of restraint ('bond', 'angle', or 'torsion').
    keys_to_merge : list of str, optional
        Origins to include. If None, every origin is included.

    Returns
    -------
    torch.Tensor or None
        Index tensors stacked along dim 0, or None when no selected
        origin provides an ``"indices"`` entry.
    """
    per_origin = self.restraints.get(restraint_type, {})
    candidates = (
        (origin, data.get("indices")) for origin, data in per_origin.items()
    )
    selected = [
        tensor
        for origin, tensor in candidates
        if tensor is not None
        and (keys_to_merge is None or origin in keys_to_merge)
    ]
    return torch.cat(selected, dim=0) if selected else None
def _get_all_property(self, restraint_type, property_name, keys_to_merge=None):
    """
    Concatenate one property of a restraint type over its origins.

    Parameters
    ----------
    restraint_type : str
        Type of restraint ('bond', 'angle', or 'torsion').
    property_name : str
        Property to gather ('references', 'sigmas', or 'periods').
    keys_to_merge : list of str, optional
        Origins to include. If None, every origin is included.

    Returns
    -------
    torch.Tensor or None
        Property tensors stacked along dim 0, or None when no selected
        origin provides the property.
    """
    per_origin = self.restraints.get(restraint_type, {})
    candidates = (
        (origin, data.get(property_name)) for origin, data in per_origin.items()
    )
    selected = [
        tensor
        for origin, tensor in candidates
        if tensor is not None
        and (keys_to_merge is None or origin in keys_to_merge)
    ]
    return torch.cat(selected, dim=0) if selected else None
[docs]
def bond_lengths(self, idx, xyz: torch.Tensor = None):
    """
    Measure the distance between the two atoms of every bond.

    Parameters
    ----------
    idx : torch.Tensor
        Bond indices tensor of shape (N, 2); may be None.
    xyz : torch.Tensor, optional
        Coordinates tensor of shape (n_atoms, 3).
        If None, uses the stored xyz_fn callable.

    Returns
    -------
    torch.Tensor
        Tensor of bond lengths of shape (N,); an empty tensor on the
        coordinates' device when ``idx`` is None.
    """
    coords = self.xyz(xyz)
    if idx is None:
        return torch.tensor([], device=coords.device)
    start_atoms = idx[:, 0]
    end_atoms = idx[:, 1]
    bond_vectors = coords[end_atoms, :] - coords[start_atoms, :]
    return torch.linalg.norm(bond_vectors, dim=-1)
[docs]
def copy(self):
    """
    Return an independent deep copy of this Restraints object.

    Returns
    -------
    Restraints
        A deep copy of this Restraints instance.
    """
    from copy import deepcopy

    return deepcopy(self)
[docs]
def bond_deviations(self, xyz: torch.Tensor = None):
    """
    Return bond-length deviations together with their sigmas.

    Uses the merged ``restraints["bond"]["all"]`` entry, building it via
    :meth:`cat_dict` first if it does not exist yet.

    Parameters
    ----------
    xyz : torch.Tensor, optional
        Coordinates tensor. If None, uses the stored xyz_fn callable.

    Returns
    -------
    deviations : torch.Tensor
        Calculated minus expected bond lengths in Angstroms.
    sigmas : torch.Tensor
        Standard deviations from CIF library in Angstroms.
    """
    if "all" not in self.restraints["bond"]:
        self.cat_dict()
    merged = self.restraints["bond"]["all"]
    pair_indices = merged["indices"]
    targets = merged["references"]
    sigmas = merged["sigmas"]
    # Deviation = current length minus library target.
    return self.bond_lengths(pair_indices, xyz) - targets, sigmas
[docs]
def nll_bonds(self, xyz: torch.Tensor = None):
    """
    Per-bond Gaussian negative log-likelihood of the bond restraints.

    For Gaussian distribution: NLL = -log(P(x|μ,σ))
    NLL = 0.5 * ((x - μ) / σ)^2 + log(σ) + 0.5 * log(2π)
    This is the true NLL where exp(-NLL) = probability density.

    The deviations and sigmas from :meth:`bond_deviations` are passed
    straight to ``gaussian_nll``.

    Parameters
    ----------
    xyz : torch.Tensor, optional
        Coordinates tensor. If None, uses the stored xyz_fn callable.

    Returns
    -------
    torch.Tensor
        Tensor of shape (n_bonds,) with negative log-likelihood values.
    """
    from torchref.refinement.targets import gaussian_nll

    return gaussian_nll(*self.bond_deviations(xyz))
[docs]
def angles(self, idx, xyz: torch.Tensor = None):
    """
    Measure the angle at the middle atom of every index triple.

    Parameters
    ----------
    idx : torch.Tensor
        Angle indices tensor of shape (N, 3); column 1 is the vertex.
    xyz : torch.Tensor, optional
        Coordinates tensor. If None, uses the stored xyz_fn callable.

    Returns
    -------
    torch.Tensor
        Tensor of shape (n_angles,) with current angle values in degrees.
    """
    coords = self.xyz(xyz)
    first, vertex, last = (coords[idx[:, col], :] for col in range(3))

    # Arms pointing from the vertex to each outer atom.
    arm_a = first - vertex
    arm_b = last - vertex

    # cos(θ) = (a · b) / (|a| |b|), clamped so rounding cannot leave [-1, 1].
    lengths = torch.linalg.norm(arm_a, dim=-1) * torch.linalg.norm(arm_b, dim=-1)
    cosine = torch.clamp(torch.sum(arm_a * arm_b, dim=-1) / lengths, -1.0, 1.0)

    return torch.rad2deg(torch.acos(cosine))
[docs]
def angle_deviations(self, xyz: torch.Tensor = None):
    """
    Return angle deviations together with their sigmas, both in radians.

    Uses the merged ``restraints["angle"]["all"]`` entry, building it via
    :meth:`cat_dict` first if it does not exist yet.  Stored references
    and sigmas, and the angles returned by :meth:`angles`, are in degrees
    and are converted here.

    Parameters
    ----------
    xyz : torch.Tensor, optional
        Coordinates tensor. If None, uses the stored xyz_fn callable.

    Returns
    -------
    deviations : torch.Tensor
        Calculated minus expected angles in radians.
    sigmas : torch.Tensor
        Standard deviations in radians.
    """
    if "all" not in self.restraints["angle"]:
        self.cat_dict()
    merged = self.restraints["angle"]["all"]
    deg_to_rad = torch.pi / 180.0

    triple_indices = merged["indices"]
    targets_rad = merged["references"] * deg_to_rad
    sigmas_rad = merged["sigmas"] * deg_to_rad
    current_rad = self.angles(triple_indices, xyz) * deg_to_rad
    return current_rad - targets_rad, sigmas_rad
[docs]
def nll_angles(self, xyz: torch.Tensor = None):
    """
    Per-angle Gaussian negative log-likelihood of the angle restraints.

    For Gaussian distribution: NLL = -log(P(x|μ,σ))
    NLL = 0.5 * ((x - μ) / σ)^2 + log(σ) + 0.5 * log(2π)
    This is the true NLL where exp(-NLL) = probability density.

    The deviations and sigmas from :meth:`angle_deviations` are passed
    straight to ``gaussian_nll``.

    Parameters
    ----------
    xyz : torch.Tensor, optional
        Coordinates tensor. If None, uses the stored xyz_fn callable.

    Returns
    -------
    torch.Tensor
        Tensor of shape (n_angles,) with negative log-likelihood values.
    """
    from torchref.refinement.targets import gaussian_nll

    return gaussian_nll(*self.angle_deviations(xyz))
[docs]
def cat_dict(self):
    """
    Concatenate all restraint dictionaries into 'all' keys.

    Creates restraints['bond']['all'], restraints['angle']['all'],
    and restraints['torsion']['all'] by concatenating the per-origin
    entries, and caches the largest torsion period in
    ``_torsion_max_period``.

    The method is idempotent: an ``'all'`` entry left by a previous call
    is never merged into the new one, so calling it again (for example
    after restraints were rebuilt) does not duplicate restraints.
    """

    def source_origins(rtype):
        # Every stored origin except the aggregate this method writes.
        return [
            origin
            for origin in self.restraints.get(rtype, {}).keys()
            if origin != "all"
        ]

    bond_origins = source_origins("bond")
    self.restraints["bond"]["all"] = {
        "indices": self._get_all_indices("bond", bond_origins),
        "references": self._get_all_property("bond", "references", bond_origins),
        "sigmas": self._get_all_property("bond", "sigmas", bond_origins),
    }
    angle_origins = source_origins("angle")
    self.restraints["angle"]["all"] = {
        "indices": self._get_all_indices("angle", angle_origins),
        "references": self._get_all_property("angle", "references", angle_origins),
        "sigmas": self._get_all_property("angle", "sigmas", angle_origins),
    }
    # Note: phi/psi origins are excluded because they have no reference
    # values or sigmas (conformationally free). Omega is excluded here
    # because it is handled by a dedicated OmegaTarget that uses a
    # cis/trans von Mises mixture model.
    _torsion_origins = ["intra", "disulfide"]
    self.restraints["torsion"]["all"] = {
        "indices": self._get_all_indices("torsion", _torsion_origins),
        "references": self._get_all_property(
            "torsion", "references", _torsion_origins
        ),
        "sigmas": self._get_all_property(
            "torsion", "sigmas", _torsion_origins
        ),
        "periods": self._get_all_property(
            "torsion", "periods", _torsion_origins
        ),
    }
    # Cache max period to avoid .item() GPU sync every iteration
    periods = self.restraints["torsion"]["all"]["periods"]
    if periods is not None and periods.numel() > 0:
        self._torsion_max_period = int(periods.max().item())
    else:
        self._torsion_max_period = 1
[docs]
def torsions(self, idx, xyz: torch.Tensor = None):
    """
    Measure the dihedral angle defined by every index quadruple.

    Parameters
    ----------
    idx : torch.Tensor
        Torsion indices tensor of shape (N, 4).
    xyz : torch.Tensor, optional
        Coordinates tensor. If None, uses the stored xyz_fn callable.

    Returns
    -------
    torch.Tensor
        Tensor of shape (n_torsions,) with current torsion values in degrees.
    """
    coords = self.xyz(xyz)
    atom_a, atom_b, atom_c, atom_d = (coords[idx[:, col], :] for col in range(4))

    # Consecutive bond vectors along the four-atom chain.
    bond_ab = atom_b - atom_a
    bond_bc = atom_c - atom_b
    bond_cd = atom_d - atom_c

    # Unit vector along the central bond.
    axis = bond_bc / torch.linalg.norm(bond_bc, dim=-1, keepdim=True)

    # Unit normals of the a-b-c and b-c-d planes.
    normal_abc = torch.cross(bond_ab, bond_bc, dim=-1)
    normal_bcd = torch.cross(bond_bc, bond_cd, dim=-1)
    normal_abc = normal_abc / torch.linalg.norm(normal_abc, dim=-1, keepdim=True)
    normal_bcd = normal_bcd / torch.linalg.norm(normal_bcd, dim=-1, keepdim=True)

    # Signed angle between the normals via atan2 of the two projections.
    in_plane = torch.cross(normal_abc, axis, dim=-1)
    cos_part = torch.sum(normal_abc * normal_bcd, dim=-1)
    sin_part = torch.sum(in_plane * normal_bcd, dim=-1)
    return torch.rad2deg(torch.atan2(sin_part, cos_part))
def _wrap_torsion_periodicity(self, diff_rad, periods):
    """
    Find minimum angular deviation considering n-fold rotational symmetry.

    For period=n, angles differing by 360°/n are equivalent. This function
    finds the equivalent angle with the smallest absolute deviation.

    All arithmetic is carried out in ``diff_rad``'s dtype, so float64
    deviations are not degraded by float32 offsets.

    Parameters
    ----------
    diff_rad : torch.Tensor
        Tensor of angular deviations in radians (any shape).
    periods : torch.Tensor
        Tensor of periodicity values (same shape as diff_rad).
        Period=0 or 1 means no symmetry (simple wrapping).
        Period=n means n-fold rotational symmetry.

    Returns
    -------
    torch.Tensor
        Tensor of minimum wrapped deviations in radians (same shape as input).
        Values are wrapped to [-π, π] and account for rotational symmetry.

    Notes
    -----
    When ``self._torsion_max_period`` is set it is used as the number of
    candidate offsets instead of ``periods.max()``, avoiding a device
    sync. Entries of ``periods`` larger than that cached value are only
    tested against the first ``_torsion_max_period`` offsets.

    Examples
    --------
    For period=6 (e.g., benzene), angles of 10°, 70°, 130°, 190°, 250°, 310°
    are all equivalent. The function returns the one closest to 0°.
    """
    two_pi = 2.0 * torch.pi

    # Clamp periods to minimum of 1 to avoid division by zero
    periods_safe = torch.clamp(periods, min=1)

    # Use cached max_period to avoid .item() GPU sync every iteration
    max_period = getattr(self, "_torsion_max_period", None)
    if max_period is None:
        # .max() raises on an empty tensor; empty input needs no symmetry search.
        if periods_safe.numel() > 0:
            max_period = int(periods_safe.max().item())
        else:
            max_period = 1

    if max_period <= 1:
        # All periods are 0 or 1, simple wrapping
        return torch.remainder(diff_rad + torch.pi, two_pi) - torch.pi

    device = diff_rad.device
    original_shape = diff_rad.shape
    diff_flat = diff_rad.flatten()
    n_angles = len(diff_flat)

    # Periods as a column in the deviations' own dtype (not hard-coded float32).
    periods_col = periods_safe.flatten().unsqueeze(1).to(diff_rad.dtype)  # (n_angles, 1)

    # Candidate offsets k * 2π/period for k in [0, max_period): (n_angles, max_period)
    k_range = torch.arange(max_period, device=device).unsqueeze(0)  # (1, max_period)
    offsets = k_range * (two_pi / periods_col)

    # Every symmetry-equivalent deviation, wrapped to [-π, π)
    equiv = diff_flat.unsqueeze(1) - offsets
    equiv_wrapped = torch.remainder(equiv + torch.pi, two_pi) - torch.pi

    # Offsets with k >= period are not valid for that angle; give them an
    # infinite magnitude so argmin can never select them.
    valid = k_range < periods_col
    magnitude = equiv_wrapped.abs().masked_fill(~valid, float("inf"))

    # Pick, per angle, the equivalent deviation closest to zero.
    best = torch.argmin(magnitude, dim=1)  # (n_angles,)
    picked = equiv_wrapped[torch.arange(n_angles, device=device), best]
    return picked.reshape(original_shape)
[docs]
def torsion_deviations(self, xyz: torch.Tensor = None, wrapped=True):
    """
    Return the deviation of every torsion restraint from its target, in degrees.

    Parameters
    ----------
    xyz : torch.Tensor, optional
        Coordinates, forwarded unchanged to ``self.torsions``. When None the
        stored xyz_fn callable is used instead.
    wrapped : bool, default True
        True  -> fold the deviation through ``_wrap_torsion_periodicity``
        using the per-restraint periods.
        False -> return the plain difference (calculated - expected).

    Returns
    -------
    torch.Tensor
        Tensor of shape (n_torsions,) holding deviations in degrees.

    Notes
    -----
    Targets taken from the CIF library are discrete (typically -60°, 0°,
    60°, 90°, 180°) whereas angles measured on the structure are continuous,
    so raw differences can look large. ``wrapped=True`` is the variant meant
    for comparison and visualization.
    """
    # Build the concatenated "all" group on demand.
    if "all" not in self.restraints["torsion"]:
        self.cat_dict()

    merged = self.restraints["torsion"]["all"]
    atom_idx = merged["indices"]
    target_deg = merged["references"]
    fold = merged["periods"]

    delta_deg = self.torsions(atom_idx, xyz) - target_deg
    if not wrapped:
        return delta_deg

    # The periodicity helper works in radians; convert there and back.
    delta_rad = delta_deg * torch.pi / 180.0
    return torch.rad2deg(self._wrap_torsion_periodicity(delta_rad, fold))
[docs]
def torsion_deviations_with_sigmas(self, xyz: torch.Tensor = None):
    """
    Return periodicity-wrapped torsion deviations together with their sigmas.

    Parameters
    ----------
    xyz : torch.Tensor, optional
        Coordinates, forwarded unchanged to ``self.torsions``. When None the
        stored xyz_fn callable is used instead.

    Returns
    -------
    deviations_rad : torch.Tensor
        Output of ``_wrap_torsion_periodicity`` applied to
        (calculated - expected), expressed in radians.
    sigmas_deg : torch.Tensor
        The stored sigmas, passed through untouched (degrees; consumed by
        the von Mises NLL).
    """
    # Build the concatenated "all" group on demand.
    if "all" not in self.restraints["torsion"]:
        self.cat_dict()

    merged = self.restraints["torsion"]["all"]
    atom_idx = merged["indices"]
    target_deg = merged["references"]
    sigma_deg = merged["sigmas"]
    fold = merged["periods"]

    # Degrees -> radians before handing over to the periodicity helper.
    offset_rad = (self.torsions(atom_idx, xyz) - target_deg) * (torch.pi / 180.0)
    return self._wrap_torsion_periodicity(offset_rad, fold), sigma_deg
[docs]
def nll_torsions(self, xyz: torch.Tensor = None):
    """
    Negative log-likelihood of the torsion angle restraints.

    The wrapped deviations (radians) and sigmas (degrees) obtained from
    ``torsion_deviations_with_sigmas`` are scored by ``von_mises_nll``.

    Parameters
    ----------
    xyz : torch.Tensor, optional
        Coordinates tensor. If None, uses the stored xyz_fn callable.

    Returns
    -------
    torch.Tensor
        Tensor of shape (n_torsions,) with negative log-likelihood values.

    Notes
    -----
    As documented for this target, the von Mises form is
    NLL = -κ*cos(θ-μ) + log(I₀(κ)) + log(2π), with concentration κ = 1/σ²
    and I₀ the modified Bessel function of the first kind, so that
    exp(-NLL) is a probability density. The formula itself lives in
    ``von_mises_nll``, not in this method.

    A period of n expresses n-fold rotational symmetry (e.g. period=6 for
    benzene): angles 360°/n apart are equivalent, which is why the
    deviations are reduced to the minimum angular distance beforehand.
    """
    from torchref.refinement.targets import von_mises_nll

    wrapped_rad, sigma_deg = self.torsion_deviations_with_sigmas(xyz)
    return von_mises_nll(wrapped_rad, sigma_deg)
[docs]
def nll_planes(self, xyz: Optional[torch.Tensor] = None):
    """
    Compute per-atom negative log-likelihoods for plane restraints.

    For every plane the best-fit plane through the atoms' centroid is found
    by SVD, each atom's absolute distance to that plane is taken as its
    deviation, and the deviations are scored with ``gaussian_nll`` against
    the per-atom sigmas. Per the original documentation the Gaussian form is
    NLL = 0.5 * (deviation / σ)² + log(σ) + 0.5 * log(2π).

    All planes of one restraint group are processed in a single batched SVD
    instead of one SVD call per plane.

    Parameters
    ----------
    xyz : torch.Tensor, optional
        Coordinates tensor. If None, uses the stored xyz_fn callable.

    Returns
    -------
    torch.Tensor
        1-D tensor with one NLL value per (plane, atom) entry, ordered group
        by group, plane by plane and atom by atom. If there are no plane
        restraints, the single-element tensor ``[0.0]`` is returned.
    """
    from torchref.refinement.targets import gaussian_nll

    xyz = self.xyz(xyz)
    group_nlls = []

    if "plane" in self.restraints:
        for _origin, plane_data in self.restraints["plane"].items():
            indices = plane_data.get("indices")
            sigmas = plane_data.get("sigmas")
            if indices is None or len(indices) == 0:
                continue

            # indices / sigmas: (n_planes, n_atoms_per_plane)
            positions = xyz[indices]  # (n_planes, n_atoms, 3)
            centered = positions - positions.mean(dim=1, keepdim=True)

            # Batched SVD: the last right-singular vector of each centred
            # coordinate matrix (smallest singular value) is the plane normal.
            # full_matrices is left at its default so Vh is always 3x3.
            _, _, vh = torch.linalg.svd(centered)
            normals = vh[:, -1, :]  # (n_planes, 3)

            # Unsigned distance of every atom to its plane; the sign of the
            # normal is arbitrary, hence the abs().
            deviations = torch.abs(
                torch.matmul(centered, normals.unsqueeze(-1)).squeeze(-1)
            )  # (n_planes, n_atoms)

            # Flatten row-major so the output order matches a plane-by-plane
            # concatenation.
            group_nlls.append(
                gaussian_nll(deviations.reshape(-1), sigmas.reshape(-1))
            )

    if group_nlls:
        return torch.cat(group_nlls)
    return torch.tensor([0.0], device=xyz.device)
[docs]
def nll_vdw(self, xyz: Optional[torch.Tensor] = None):
    """
    Compute negative log-likelihood for VDW (non-bonded) restraints.

    A one-sided (repulsive only) deviation is computed per atom pair,
    deviation = max(0, min_dist - actual_dist), and scored with
    ``gaussian_nll``. Per the original documentation the resulting form is
    NLL = 0.5 * (deviation / σ)² + log(σ) + 0.5 * log(2π).

    Pairs that are farther apart than their minimum distance get a deviation
    of exactly zero, so only violating pairs have a deviation-dependent
    term; every pair still yields one output value.

    Parameters
    ----------
    xyz : torch.Tensor, optional
        Coordinates tensor. If None, uses the stored xyz_fn callable.

    Returns
    -------
    torch.Tensor
        Tensor of shape (n_pairs,) with negative log-likelihood values, or
        the single-element tensor ``[0.0]`` when no VDW pairs are stored.
    """
    from torchref.refinement.targets import gaussian_nll

    xyz = self.xyz(xyz)
    device = xyz.device

    if "vdw" not in self.restraints:
        return torch.tensor([0.0], device=device)

    vdw_data = self.restraints["vdw"]
    indices = vdw_data.get("indices")
    if indices is None or len(indices) == 0:
        return torch.tensor([0.0], device=device)

    min_distances = vdw_data["min_distances"]
    sigmas = vdw_data["sigmas"]

    # Euclidean distance of every pair. torch.linalg.vector_norm replaces the
    # deprecated torch.norm and yields the same 2-norm along the last axis.
    separation = xyz[indices[:, 1]] - xyz[indices[:, 0]]
    actual_distances = torch.linalg.vector_norm(separation, dim=-1)

    # One-sided deviation: zero unless the pair is closer than allowed.
    deviations = torch.clamp(min_distances - actual_distances, min=0.0)

    return gaussian_nll(deviations, sigmas)
[docs]
def adp_b_differences(self, adp: torch.Tensor = None):
    """
    Collect B-factor differences across all bonded atom pairs.

    Parameters
    ----------
    adp : torch.Tensor, optional
        ADP values. If None, uses the stored adp_fn callable.

    Returns
    -------
    torch.Tensor
        Concatenated differences (B_i - B_j), one per bond, where i and j
        are the first and second index column of each bond restraint.
        An empty tensor is returned when no bonds are available.
    """
    b_values = self.adp(adp)

    if "bond" not in self.restraints:
        return torch.tensor([], device=b_values.device)

    per_origin = []
    for origin, group in self.restraints["bond"].items():
        # The "all" entry is skipped; presumably it is the concatenation of
        # the other origins, so including it would count bonds twice.
        if origin == "all":
            continue
        pair_idx = group.get("indices")
        if pair_idx is None or len(pair_idx) == 0:
            continue
        per_origin.append(b_values[pair_idx[:, 0]] - b_values[pair_idx[:, 1]])

    if not per_origin:
        return torch.tensor([], device=b_values.device)
    return torch.cat(per_origin, dim=0)
[docs]
def adp_similarity_loss(self, adp: Optional[torch.Tensor] = None, sigma: float = 2.0):
    """
    Compute ADP similarity loss (SIMU in Phenix/SHELX).

    This restrains the B-factors of bonded atoms to be similar: the
    differences B_i - B_j of all bonded pairs are scored by
    ``adp_similarity_nll`` and the per-bond values are averaged.

    Parameters
    ----------
    adp : torch.Tensor, optional
        ADP values. If None, uses the stored adp_fn callable.
    sigma : float, default 2.0
        Target standard deviation for B-factor differences in Ų.

    Returns
    -------
    torch.Tensor
        Scalar tensor: the mean of the per-bond similarity terms, or 0.0
        when there are no bonded pairs.
    """
    b_diffs = self.adp_b_differences(adp)

    if len(b_diffs) == 0:
        # Take the device from the ADP-derived tensor itself. This avoids
        # evaluating the coordinate callable just to learn a device.
        return torch.tensor(0.0, device=b_diffs.device)

    # Function-scoped import, deferred until there is something to score.
    from torchref.refinement.targets import adp_similarity_nll

    return adp_similarity_nll(b_diffs, sigma).mean()