"""
Restraint Builder Classes for Crystallographic Refinement

This module provides a hierarchy of builder classes for constructing
geometry restraints efficiently. The base class handles common functionality
while specialized child classes implement restraint-type-specific logic.

The builder pattern allows:

- Single-pass iteration over residues
- Easy unit testing of individual components
- Memory-efficient accumulation and finalization
- Sorted indices for cache-friendly access

Classes
-------
ResidueIterator
    Efficient iterator over residues with pre-grouped data.
RestraintBuilder
    Abstract base class for all restraint builders.
BondRestraintBuilder
    Builder for bond length restraints.
AngleRestraintBuilder
    Builder for angle restraints.
TorsionRestraintBuilder
    Builder for torsion angle restraints.
PlaneRestraintBuilder
    Builder for planarity restraints.
ChiralRestraintBuilder
    Builder for chiral volume restraints.
InterResidueBondBuilder
    Builder for inter-residue bond restraints (peptide bonds, disulfide bonds).
InterResidueAngleBuilder
    Builder for inter-residue angle restraints (peptide angles).
InterResidueTorsionBuilder
    Builder for inter-residue torsion restraints (phi, psi, omega).
"""
from abc import ABC, abstractmethod
from typing import Dict, Iterator, List, Optional, Tuple
import numpy as np
import pandas as pd
import torch
from torchref.config import get_float_dtype
[docs]
class ResidueIterator:
    """
    Iterate over the residues of a PDB table using a grouping built once.

    The table is grouped by ``(chainid, resseq)`` in the constructor; every
    residue handed out afterwards is fetched from that prepared grouping
    instead of re-filtering the whole DataFrame for each residue.

    Parameters
    ----------
    pdb : pd.DataFrame
        Atom table. It must provide the 'chainid' and 'resseq' columns, plus
        an 'ATOM' column whenever ``filter_atom_type`` is given.
    filter_atom_type : str, optional
        When not None, only rows whose 'ATOM' column equals this value are
        kept (e.g. 'ATOM').

    Attributes
    ----------
    pdb : pd.DataFrame
        The table that was grouped, i.e. the input after optional filtering.
    groups : list of tuple
        The ``(chainid, resseq)`` keys of the grouping.

    Examples
    --------
    ::

        iterator = ResidueIterator(model.pdb)
        for chain_id, resseq, residue_df in iterator:
            print(f"Processing {chain_id}:{resseq}")
    """

    def __init__(self, pdb: pd.DataFrame, filter_atom_type: Optional[str] = None):
        """Optionally filter the table by record type, then group it."""
        table = pdb
        if filter_atom_type is not None:
            keep_rows = table["ATOM"] == filter_atom_type
            table = table[keep_rows]
        self.pdb = table
        # The grouping is created a single time and reused by every lookup.
        self._grouped = table.groupby(["chainid", "resseq"], sort=False)
        self.groups = list(self._grouped.groups.keys())

    def __iter__(self) -> Iterator[Tuple[str, int, pd.DataFrame]]:
        """
        Iterate over residues.

        Yields
        ------
        chain_id : str
            Chain identifier.
        resseq : int
            Residue sequence number.
        residue : pd.DataFrame
            Rows of the table belonging to this residue.
        """
        for key in self.groups:
            chain_id, resseq = key
            yield chain_id, resseq, self._grouped.get_group((chain_id, resseq))

    def __len__(self) -> int:
        """Return the number of ``(chainid, resseq)`` groups."""
        return len(self.groups)

    def get_consecutive_pairs(self) -> Iterator[Tuple[pd.DataFrame, pd.DataFrame]]:
        """
        Iterate over pairs of residues with adjacent numbers in one chain.

        Residue numbers are sorted per chain; a pair is produced only when
        the second number is exactly the first plus one, so gaps in the
        numbering yield no pair.

        Yields
        ------
        residue_i : pd.DataFrame
            First residue in the pair.
        residue_next : pd.DataFrame
            The residue numbered ``resseq + 1`` in the same chain.
        """
        # Collect the residue numbers seen for every chain.
        numbers_per_chain = {}
        for chain_id, resseq in self.groups:
            numbers_per_chain.setdefault(chain_id, []).append(resseq)

        for chain_id, numbers in numbers_per_chain.items():
            ordered = sorted(numbers)
            for current, following in zip(ordered, ordered[1:]):
                if following == current + 1:
                    first = self._grouped.get_group((chain_id, current))
                    second = self._grouped.get_group((chain_id, following))
                    yield first, second
[docs]
class RestraintBuilder(ABC):
    """
    Abstract base class shared by the per-residue restraint builders.

    Subclasses feed per-residue arrays into the accumulator lists while the
    residues are being visited; ``finalize`` then concatenates everything
    and converts it to tensors.

    Parameters
    ----------
    verbose : int, default 0
        Verbosity level; warnings are printed only for values above 2.

    Attributes
    ----------
    _indices : list of np.ndarray
        Accumulated index arrays.
    _references : list of np.ndarray
        Accumulated reference value arrays.
    _sigmas : list of np.ndarray
        Accumulated sigma arrays.
    _count : int
        Total number of restraints added.
    """

    # Subclasses override these three descriptors.
    restraint_type: str = "base"
    atom_columns: List[str] = []
    n_atoms: int = 0

    def __init__(self, verbose: int = 0):
        """Start with empty accumulators and a zero restraint count."""
        self.verbose = verbose
        self._count: int = 0
        self._indices: List[np.ndarray] = []
        self._references: List[np.ndarray] = []
        self._sigmas: List[np.ndarray] = []

    def reset(self):
        """Empty every accumulator list and set the count back to zero."""
        for accumulator in (self._indices, self._references, self._sigmas):
            accumulator.clear()
        self._count = 0

    @abstractmethod
    def process_residue(
        self, residue: pd.DataFrame, cif_restraints: pd.DataFrame
    ) -> int:
        """
        Match one residue against its CIF restraints and accumulate them.

        Parameters
        ----------
        residue : pd.DataFrame
            Residue atoms with 'name' and 'index' columns.
        cif_restraints : pd.DataFrame
            Restraints from the CIF dictionary for this residue type.

        Returns
        -------
        int
            Number of restraints added.
        """

    def _filter_usable_restraints(
        self, cif_restraints: pd.DataFrame, atom_names: set
    ) -> pd.DataFrame:
        """
        Keep only the restraints whose atoms are all present.

        Parameters
        ----------
        cif_restraints : pd.DataFrame
            All restraints from the CIF for this residue type.
        atom_names : set
            Atom names present in the residue.

        Returns
        -------
        pd.DataFrame
            Rows for which every column listed in ``atom_columns`` holds a
            name contained in ``atom_names``.
        """
        per_column_hits = []
        for column in self.atom_columns:
            per_column_hits.append(cif_restraints[column].isin(atom_names).values)
        # A row survives only if it is a hit in every atom column.
        all_present = np.all(per_column_hits, axis=0)
        return cif_restraints.loc[all_present]

    def _build_name_to_index_map(
        self, residue: pd.DataFrame
    ) -> Optional[Dict[str, int]]:
        """
        Build the atom-name -> atom-index lookup for one residue.

        Parameters
        ----------
        residue : pd.DataFrame
            Residue atoms with 'name' and 'index' columns.

        Returns
        -------
        dict or None
            The lookup, or None when an atom name occurs more than once
            (a warning is printed in that case if ``verbose > 2``).
        """
        names_are_unique = not residue.set_index("name").index.has_duplicates
        if names_are_unique:
            return {
                name: index for name, index in zip(residue["name"], residue["index"])
            }

        if self.verbose > 2:

            def first_or_unknown(column):
                return residue[column].iloc[0] if len(residue) > 0 else "UNKNOWN"

            resname = first_or_unknown("resname")
            chain_id = first_or_unknown("chainid")
            resseq = first_or_unknown("resseq")
            print(
                f"WARNING: Skipping {self.restraint_type} restraints for "
                f"{resname} {chain_id} {resseq} (duplicate atom names)"
            )
        return None

    def _map_atoms_to_indices(
        self, usable_restraints: pd.DataFrame, name_to_idx: Dict[str, int]
    ) -> np.ndarray:
        """
        Translate the atom names of each restraint into atom indices.

        Parameters
        ----------
        usable_restraints : pd.DataFrame
            Filtered restraints.
        name_to_idx : dict
            Mapping from atom name to index.

        Returns
        -------
        np.ndarray
            One row per restraint and one column per entry of
            ``atom_columns``.
        """
        per_column = []
        for column in self.atom_columns:
            names = usable_restraints[column].values
            per_column.append([name_to_idx[name] for name in names])
        # Built column-major, so transpose to get one row per restraint.
        return np.array(per_column).T

    def _accumulate(
        self, indices: np.ndarray, references: np.ndarray, sigmas: np.ndarray
    ):
        """
        Append one batch of restraint data to the accumulators.

        Parameters
        ----------
        indices : np.ndarray
            Atom indices, one row per restraint.
        references : np.ndarray
            Reference values, one per restraint.
        sigmas : np.ndarray
            Sigma values, one per restraint.
        """
        self._count += len(indices)
        self._indices.append(indices)
        self._references.append(references)
        self._sigmas.append(sigmas)

    def finalize(
        self, device: torch.device, sort_indices: bool = True, min_sigma: float = 1e-4
    ) -> Optional[Dict[str, torch.Tensor]]:
        """
        Concatenate the accumulated batches and convert them to tensors.

        Parameters
        ----------
        device : torch.device
            Target device for tensors.
        sort_indices : bool, default True
            Whether to order the restraints by their first atom index.
        min_sigma : float, default 1e-4
            Value substituted for sigmas that are exactly zero.

        Returns
        -------
        dict or None
            Dictionary with 'indices', 'references', 'sigmas' tensors,
            or None if no restraints were accumulated.
        """
        if len(self._indices) == 0:
            return None

        all_indices = np.concatenate(self._indices, axis=0)
        all_references = np.concatenate(self._references)
        all_sigmas = np.concatenate(self._sigmas)

        if sort_indices and len(all_indices) > 0:
            # Order by the first atom of each restraint.
            order = np.argsort(all_indices[:, 0])
            all_indices = all_indices[order]
            all_references = all_references[order]
            all_sigmas = all_sigmas[order]

        # Only exact zeros are replaced; other values pass through unchanged.
        all_sigmas = np.where(all_sigmas == 0, min_sigma, all_sigmas)

        tensors = {}
        tensors["indices"] = torch.tensor(all_indices, dtype=torch.long, device=device)
        tensors["references"] = torch.tensor(
            all_references, dtype=get_float_dtype(), device=device
        )
        tensors["sigmas"] = torch.tensor(
            all_sigmas, dtype=get_float_dtype(), device=device
        )
        return tensors

    @property
    def count(self) -> int:
        """Return total number of restraints accumulated."""
        return self._count
[docs]
class BondRestraintBuilder(RestraintBuilder):
    """
    Builder for bond length restraints.

    Each restraint couples two atoms ('atom1', 'atom2') with a target
    value and a sigma taken from the CIF restraint table.

    Examples
    --------
    ::

        builder = BondRestraintBuilder()
        for residue in residues:
            builder.process_residue(residue, cif_bonds)
        result = builder.finalize(device)
        print(result['indices'].shape)  # (n_bonds, 2)
    """

    restraint_type = "bond"
    atom_columns = ["atom1", "atom2"]
    n_atoms = 2

    def process_residue(
        self, residue: pd.DataFrame, cif_restraints: pd.DataFrame
    ) -> int:
        """
        Accumulate the bond restraints that apply to one residue.

        Parameters
        ----------
        residue : pd.DataFrame
            Residue atoms with 'name' and 'index' columns.
        cif_restraints : pd.DataFrame
            Bond restraints with 'atom1', 'atom2', 'value', 'sigma' columns.

        Returns
        -------
        int
            Number of bond restraints added; 0 when no restraint has both
            atoms present or when the residue has duplicate atom names.
        """
        present_names = set(residue["name"].values)
        matched = self._filter_usable_restraints(cif_restraints, present_names)
        if not len(matched):
            return 0

        lookup = self._build_name_to_index_map(residue)
        if lookup is None:
            return 0

        # Values are stored in single precision before accumulation.
        targets = matched["value"].values.astype(np.float32)
        widths = matched["sigma"].values.astype(np.float32)
        self._accumulate(self._map_atoms_to_indices(matched, lookup), targets, widths)
        return len(matched)
[docs]
class AngleRestraintBuilder(RestraintBuilder):
    """
    Builder for angle restraints.

    Each restraint names three atoms ('atom1', 'atom2', 'atom3') together
    with a target value and a sigma from the CIF restraint table; the angle
    is the one at the middle atom ('atom2').

    Examples
    --------
    ::

        builder = AngleRestraintBuilder()
        for residue in residues:
            builder.process_residue(residue, cif_angles)
        result = builder.finalize(device)
        print(result['indices'].shape)  # (n_angles, 3)
    """

    restraint_type = "angle"
    atom_columns = ["atom1", "atom2", "atom3"]
    n_atoms = 3

    def process_residue(
        self, residue: pd.DataFrame, cif_restraints: pd.DataFrame
    ) -> int:
        """
        Accumulate the angle restraints that apply to one residue.

        Parameters
        ----------
        residue : pd.DataFrame
            Residue atoms with 'name' and 'index' columns.
        cif_restraints : pd.DataFrame
            Angle restraints with 'atom1', 'atom2', 'atom3', 'value',
            'sigma' columns.

        Returns
        -------
        int
            Number of angle restraints added; 0 when no restraint has all
            three atoms present or when the residue has duplicate atom names.
        """
        present_names = set(residue["name"].values)
        matched = self._filter_usable_restraints(cif_restraints, present_names)
        if not len(matched):
            return 0

        lookup = self._build_name_to_index_map(residue)
        if lookup is None:
            return 0

        # Values are stored in single precision before accumulation.
        targets = matched["value"].values.astype(np.float32)
        widths = matched["sigma"].values.astype(np.float32)
        self._accumulate(self._map_atoms_to_indices(matched, lookup), targets, widths)
        return len(matched)
[docs]
class TorsionRestraintBuilder(RestraintBuilder):
    """
    Builder for torsion angle restraints.

    Each restraint names four atoms plus a target value, a sigma and a
    periodicity. Restraints whose sigma is exactly zero are dropped while
    residues are processed.

    Attributes
    ----------
    _periods : list of np.ndarray
        Accumulated periodicity arrays, kept parallel to ``_indices``.

    Examples
    --------
    ::

        builder = TorsionRestraintBuilder()
        for residue in residues:
            builder.process_residue(residue, cif_torsions)
        result = builder.finalize(device)
        print(result['indices'].shape)  # (n_torsions, 4)
        print(result['periods'].shape)  # (n_torsions,)
    """

    restraint_type = "torsion"
    atom_columns = ["atom1", "atom2", "atom3", "atom4"]
    n_atoms = 4

    def __init__(self, verbose: int = 0):
        """Create the base accumulators plus the periodicity accumulator."""
        super().__init__(verbose)
        self._periods: List[np.ndarray] = []

    def reset(self):
        """Clear the base accumulators and the periodicity accumulator."""
        super().reset()
        self._periods.clear()

    def process_residue(
        self, residue: pd.DataFrame, cif_restraints: pd.DataFrame
    ) -> int:
        """
        Accumulate the torsion restraints that apply to one residue.

        Restraints with ``sigma == 0`` are discarded.

        Parameters
        ----------
        residue : pd.DataFrame
            Residue atoms with 'name' and 'index' columns.
        cif_restraints : pd.DataFrame
            Torsion restraints with 'atom1'-'atom4', 'value', 'sigma',
            and optionally 'periodicity' columns. Without a 'periodicity'
            column every restraint gets periodicity 1.

        Returns
        -------
        int
            Number of torsion restraints added.
        """
        present_names = set(residue["name"].values)
        matched = self._filter_usable_restraints(cif_restraints, present_names)
        if not len(matched):
            return 0

        lookup = self._build_name_to_index_map(residue)
        if lookup is None:
            return 0

        atom_indices = self._map_atoms_to_indices(matched, lookup)
        targets = matched["value"].values.astype(np.float32)
        widths = matched["sigma"].values.astype(np.float32)

        has_periodicity = "periodicity" in matched.columns
        periodicity = (
            matched["periodicity"].values.astype(np.int64)
            if has_periodicity
            else np.ones(len(matched), dtype=np.int64)
        )

        # Zero-sigma torsions are excluded from the restraint set.
        keep = widths != 0
        if not keep.any():
            return 0

        atom_indices = atom_indices[keep]
        self._accumulate(atom_indices, targets[keep], widths[keep])
        self._periods.append(periodicity[keep])
        return len(atom_indices)

    def finalize(
        self, device: torch.device, sort_indices: bool = True, min_sigma: float = 1e-4
    ) -> Optional[Dict[str, torch.Tensor]]:
        """
        Concatenate the accumulated batches, periods included, into tensors.

        Parameters
        ----------
        device : torch.device
            Target device for tensors.
        sort_indices : bool, default True
            Whether to order the restraints by their first atom index.
        min_sigma : float, default 1e-4
            Value substituted for sigmas that are exactly zero.

        Returns
        -------
        dict or None
            Dictionary with 'indices', 'references', 'sigmas', 'periods'
            tensors, or None if nothing was accumulated.
        """
        if len(self._indices) == 0:
            return None

        all_indices = np.concatenate(self._indices, axis=0)
        all_references = np.concatenate(self._references)
        all_sigmas = np.concatenate(self._sigmas)
        all_periods = np.concatenate(self._periods)

        if sort_indices and len(all_indices) > 0:
            # The same permutation is applied to all four arrays.
            order = np.argsort(all_indices[:, 0])
            all_indices = all_indices[order]
            all_references = all_references[order]
            all_sigmas = all_sigmas[order]
            all_periods = all_periods[order]

        # Safety net: zero sigmas are already removed in process_residue.
        all_sigmas = np.where(all_sigmas == 0, min_sigma, all_sigmas)

        tensors = {}
        tensors["indices"] = torch.tensor(all_indices, dtype=torch.long, device=device)
        tensors["references"] = torch.tensor(
            all_references, dtype=get_float_dtype(), device=device
        )
        tensors["sigmas"] = torch.tensor(
            all_sigmas, dtype=get_float_dtype(), device=device
        )
        tensors["periods"] = torch.tensor(all_periods, dtype=torch.long, device=device)
        return tensors
[docs]
class PlaneRestraintBuilder(RestraintBuilder):
    """
    Builder for planarity restraints.

    A plane is the set of rows sharing one 'plane_id'. Planes do not have a
    fixed number of atoms, so they are stored in buckets keyed by atom count
    and each bucket is turned into its own pair of tensors.

    Attributes
    ----------
    _planes_by_count : dict
        Dictionary mapping atom_count -> {'indices': [...], 'sigmas': [...]}.

    Examples
    --------
    ::

        builder = PlaneRestraintBuilder()
        for residue in residues:
            builder.process_residue(residue, cif_planes)
        result = builder.finalize(device)
        # Returns dict like {'4_atoms': {...}, '6_atoms': {...}}
    """

    restraint_type = "plane"
    atom_columns = ["atom"]  # One atom per row; rows are grouped by plane_id
    n_atoms = None  # Variable

    def __init__(self, verbose: int = 0):
        """Create the base accumulators plus the per-atom-count buckets."""
        super().__init__(verbose)
        self._planes_by_count: Dict[int, Dict[str, List]] = {}

    def reset(self):
        """Clear the base accumulators and the per-atom-count buckets."""
        super().reset()
        self._planes_by_count.clear()

    def process_residue(
        self, residue: pd.DataFrame, cif_restraints: pd.DataFrame
    ) -> int:
        """
        Accumulate the plane restraints that apply to one residue.

        Rows whose atom is absent from the residue are dropped first; the
        remaining rows of each plane_id form one plane, which is kept only
        if at least three atoms remain.

        Parameters
        ----------
        residue : pd.DataFrame
            Residue atoms with 'name' and 'index' columns.
        cif_restraints : pd.DataFrame
            Plane restraints with 'plane_id', 'atom', 'sigma' columns.

        Returns
        -------
        int
            Number of plane restraints added.
        """
        present_names = set(residue["name"].values)
        matched = cif_restraints[cif_restraints["atom"].isin(present_names)]
        if not len(matched):
            return 0

        lookup = self._build_name_to_index_map(residue)
        if lookup is None:
            return 0

        added = 0
        for plane_id in matched["plane_id"].unique():
            members = matched[matched["plane_id"] == plane_id]
            size = len(members)
            if size >= 3:
                member_indices = np.array(
                    [lookup[name] for name in members["atom"].values]
                )
                member_sigmas = members["sigma"].values.astype(np.float32)
                # Planes of equal size share a bucket so they can be stacked.
                bucket = self._planes_by_count.setdefault(
                    size, {"indices": [], "sigmas": []}
                )
                bucket["indices"].append(member_indices)
                bucket["sigmas"].append(member_sigmas)
                added += 1

        self._count += added
        return added

    def finalize(
        self, device: torch.device, sort_indices: bool = True, min_sigma: float = 1e-4
    ) -> Optional[Dict[str, Dict[str, torch.Tensor]]]:
        """
        Convert every atom-count bucket into a pair of tensors.

        Parameters
        ----------
        device : torch.device
            Target device for tensors.
        sort_indices : bool, default True
            Whether to order the planes of a bucket by their first atom index.
        min_sigma : float, default 1e-4
            Value substituted for sigmas that are exactly zero.

        Returns
        -------
        dict or None
            Dictionary with keys like '4_atoms', '6_atoms', each containing
            'indices' and 'sigmas' tensors; None when there is nothing to
            return.
        """
        if len(self._planes_by_count) == 0:
            return None

        tensors_by_size = {}
        for atom_count, bucket in self._planes_by_count.items():
            if len(bucket["indices"]) == 0:
                continue

            # One row per plane; all planes in a bucket have equal length.
            plane_indices = np.stack(bucket["indices"], axis=0)
            plane_sigmas = np.stack(bucket["sigmas"], axis=0)

            if sort_indices and len(plane_indices) > 0:
                order = np.argsort(plane_indices[:, 0])
                plane_indices = plane_indices[order]
                plane_sigmas = plane_sigmas[order]

            plane_sigmas = np.where(plane_sigmas == 0, min_sigma, plane_sigmas)

            tensors_by_size[f"{atom_count}_atoms"] = {
                "indices": torch.tensor(
                    plane_indices, dtype=torch.long, device=device
                ),
                "sigmas": torch.tensor(
                    plane_sigmas, dtype=get_float_dtype(), device=device
                ),
            }

        if tensors_by_size:
            return tensors_by_size
        return None
[docs]
class ChiralRestraintBuilder(RestraintBuilder):
    """
    Builder for chiral volume restraints.

    Each restraint names a centre atom and three neighbours. The target
    volume is ``+ideal_volume`` or ``-ideal_volume`` depending on the
    'volume_sign' entry, and every restraint shares the builder's sigma.

    Attributes
    ----------
    _ideal_volumes : list of np.ndarray
        Accumulated signed ideal volume arrays.
    ideal_volume : float
        Magnitude of ideal chiral volume in Å³.
    sigma : float
        Standard deviation for restraint in Å³.

    Examples
    --------
    ::

        builder = ChiralRestraintBuilder(ideal_volume=2.5, sigma=0.2)
        for residue in residues:
            builder.process_residue(residue, cif_chirals)
        result = builder.finalize(device)
        print(result['indices'].shape)  # (n_chirals, 4)
        print(result['ideal_volumes'].shape)  # (n_chirals,)
    """

    restraint_type = "chiral"
    atom_columns = ["atom_centre", "atom1", "atom2", "atom3"]
    n_atoms = 4

    def __init__(self, ideal_volume: float = 2.5, sigma: float = 0.2, verbose: int = 0):
        """
        Initialize with chiral volume parameters.

        Parameters
        ----------
        ideal_volume : float, default 2.5
            Magnitude of ideal chiral volume in Å³.
        sigma : float, default 0.2
            Standard deviation for restraint in Å³.
        verbose : int, default 0
            Verbosity level.
        """
        super().__init__(verbose)
        self.ideal_volume = ideal_volume
        self.sigma = sigma
        self._ideal_volumes: List[np.ndarray] = []

    def reset(self):
        """Clear the base accumulators and the ideal-volume accumulator."""
        super().reset()
        self._ideal_volumes.clear()

    def process_residue(
        self, residue: pd.DataFrame, cif_restraints: pd.DataFrame
    ) -> int:
        """
        Accumulate the chiral restraints that apply to one residue.

        Parameters
        ----------
        residue : pd.DataFrame
            Residue atoms with 'name' and 'index' columns.
        cif_restraints : pd.DataFrame
            Chiral restraints with 'atom_centre', 'atom1', 'atom2', 'atom3',
            'volume_sign' columns.

        Returns
        -------
        int
            Number of chiral restraints added. Restraints with a missing
            atom or an unrecognised 'volume_sign' are skipped.
        """
        if cif_restraints.empty:
            return 0

        name_to_idx = self._build_name_to_index_map(residue)
        if name_to_idx is None:
            return 0

        # NOTE(review): 'both'/'either' receive a target volume of 0.0 —
        # confirm the downstream loss interprets that as intended.
        target_by_sign = {
            "positive": self.ideal_volume,
            "negative": -self.ideal_volume,
            "both": 0.0,
            "either": 0.0,
        }

        atom_name_columns = [cif_restraints[col].values for col in self.atom_columns]
        volume_signs = cif_restraints["volume_sign"].values

        indices_list = []
        volumes_list = []
        for *atom_quad, volume_sign in zip(*atom_name_columns, volume_signs):
            # name_to_idx holds exactly the residue's atom names, so a single
            # membership test covers both presence and lookup.
            if not all(name in name_to_idx for name in atom_quad):
                continue

            signed_volume = target_by_sign.get(volume_sign)
            if signed_volume is None:
                if self.verbose > 2:
                    print(f"WARNING: Unknown chiral sign '{volume_sign}'")
                continue

            indices_list.append([name_to_idx[name] for name in atom_quad])
            volumes_list.append(signed_volume)

        if not indices_list:
            return 0

        n_added = len(indices_list)
        self._indices.append(np.array(indices_list, dtype=np.int64))
        self._ideal_volumes.append(np.array(volumes_list, dtype=np.float32))
        self._sigmas.append(np.full(n_added, self.sigma, dtype=np.float32))
        self._count += n_added
        return n_added

    def finalize(
        self, device: torch.device, sort_indices: bool = True, min_sigma: float = 1e-4
    ) -> Optional[Dict[str, torch.Tensor]]:
        """
        Concatenate the accumulated batches and convert them to tensors.

        Parameters
        ----------
        device : torch.device
            Target device for tensors.
        sort_indices : bool, default True
            Whether to order the restraints by their centre atom index.
        min_sigma : float, default 1e-4
            Value substituted for sigmas that are exactly zero.

        Returns
        -------
        dict or None
            Dictionary with 'indices', 'ideal_volumes', 'sigmas' tensors,
            or None if no restraints were accumulated.
        """
        if not self._indices:
            return None

        indices = np.concatenate(self._indices, axis=0)
        ideal_volumes = np.concatenate(self._ideal_volumes)
        sigmas = np.concatenate(self._sigmas)

        # Column 0 holds the centre atom of each restraint.
        if sort_indices and len(indices) > 0:
            sort_order = np.argsort(indices[:, 0])
            indices = indices[sort_order]
            ideal_volumes = ideal_volumes[sort_order]
            sigmas = sigmas[sort_order]

        # Honour min_sigma like the other builders do, so a builder created
        # with sigma=0 does not hand out zero sigmas.
        sigmas = np.where(sigmas == 0, min_sigma, sigmas)

        return {
            "indices": torch.tensor(indices, dtype=torch.long, device=device),
            "ideal_volumes": torch.tensor(
                ideal_volumes, dtype=get_float_dtype(), device=device
            ),
            "sigmas": torch.tensor(sigmas, dtype=get_float_dtype(), device=device),
        }
# =============================================================================
# Inter-Residue Restraint Builders
# =============================================================================
[docs]
class InterResidueBondBuilder:
    """
    Builder for inter-residue bond restraints (peptide bonds, disulfide bonds).

    Unlike the intra-residue builders, the two atoms of a bond may belong to
    different residues. Bonds come either from a link definition applied to
    a residue pair or from explicitly supplied atom indices.

    Parameters
    ----------
    verbose : int, default 0
        Verbosity level for debug output.

    Attributes
    ----------
    _indices : list of np.ndarray
        Accumulated bond index arrays.
    _references : list of np.ndarray
        Accumulated reference distance arrays.
    _sigmas : list of np.ndarray
        Accumulated sigma arrays.

    Examples
    --------
    ::

        builder = InterResidueBondBuilder()
        for res_i, res_next in iterator.get_consecutive_pairs():
            builder.process_peptide_bond(res_i, res_next, trans_link)
        result = builder.finalize(device)
    """

    def __init__(self, verbose: int = 0):
        """Initialize empty accumulator lists."""
        self.verbose = verbose
        self._indices: List[np.ndarray] = []
        self._references: List[np.ndarray] = []
        self._sigmas: List[np.ndarray] = []
        self._count: int = 0

    def reset(self):
        """Clear all accumulated data."""
        self._indices.clear()
        self._references.clear()
        self._sigmas.clear()
        self._count = 0

    @staticmethod
    def _is_first_residue(comp_id) -> bool:
        """
        Tell whether a link comp_id designates the first residue of a pair.

        Accepts the string '1' as well as numeric forms such as the integer
        1, so a comp_id column parsed as numbers is treated the same way as
        one kept as text. Every other value designates the second residue.
        """
        try:
            return int(comp_id) == 1
        except (TypeError, ValueError):
            return str(comp_id).strip() == "1"

    @staticmethod
    def _get_atom_index(residue: pd.DataFrame, atom_name: str) -> Optional[int]:
        """
        Get atom index from residue, handling alternate conformations.

        Prefers atoms with no altloc (' '), then 'A', then first available.

        Parameters
        ----------
        residue : pd.DataFrame
            Residue atoms with 'name', 'altloc' and 'index' columns.
        atom_name : str
            Name of atom to find.

        Returns
        -------
        int or None
            Atom index, or None if not found.
        """
        atoms = residue[residue["name"] == atom_name]
        if len(atoms) == 0:
            return None
        for preferred_altloc in (" ", "A"):
            candidates = atoms[atoms["altloc"] == preferred_altloc]
            if len(candidates) > 0:
                # Read the column directly so the value keeps its own dtype.
                return int(candidates["index"].iloc[0])
        return int(atoms["index"].iloc[0])

    def process_peptide_bond(
        self,
        residue_i: pd.DataFrame,
        residue_next: pd.DataFrame,
        link_bonds: pd.DataFrame,
    ) -> int:
        """
        Process peptide bond restraints between consecutive residues.

        Parameters
        ----------
        residue_i : pd.DataFrame
            First residue of the pair (comp_id 1 in the link definition).
        residue_next : pd.DataFrame
            Second residue of the pair (any other comp_id).
        link_bonds : pd.DataFrame
            Bond definitions from link dictionary with columns:
            'atom_1_comp_id', 'atom1', 'atom_2_comp_id', 'atom2', 'value',
            'sigma'.

        Returns
        -------
        int
            Number of bond restraints added. A bond is skipped when either
            of its atoms is absent from the designated residue.
        """
        count = 0
        for _, bond_row in link_bonds.iterrows():
            first_owner = (
                residue_i
                if self._is_first_residue(bond_row["atom_1_comp_id"])
                else residue_next
            )
            second_owner = (
                residue_i
                if self._is_first_residue(bond_row["atom_2_comp_id"])
                else residue_next
            )

            idx1 = self._get_atom_index(first_owner, bond_row["atom1"])
            idx2 = self._get_atom_index(second_owner, bond_row["atom2"])
            if idx1 is None or idx2 is None:
                continue

            self._indices.append(np.array([[idx1, idx2]], dtype=np.int64))
            self._references.append(
                np.array([float(bond_row["value"])], dtype=np.float32)
            )
            self._sigmas.append(
                np.array([float(bond_row["sigma"])], dtype=np.float32)
            )
            count += 1

        self._count += count
        return count

    def process_disulfide_bond(
        self, sg1_idx: int, sg2_idx: int, bond_length: float, bond_sigma: float
    ) -> int:
        """
        Process a single disulfide bond restraint.

        Parameters
        ----------
        sg1_idx : int
            Index of first SG atom.
        sg2_idx : int
            Index of second SG atom.
        bond_length : float
            Target bond length in Å.
        bond_sigma : float
            Sigma for restraint in Å.

        Returns
        -------
        int
            Always returns 1.
        """
        self._indices.append(np.array([[sg1_idx, sg2_idx]], dtype=np.int64))
        self._references.append(np.array([bond_length], dtype=np.float32))
        self._sigmas.append(np.array([bond_sigma], dtype=np.float32))
        self._count += 1
        return 1

    def finalize(
        self, device: torch.device, sort_indices: bool = True, min_sigma: float = 1e-4
    ) -> Optional[Dict[str, torch.Tensor]]:
        """
        Convert accumulated data to sorted tensors.

        Parameters
        ----------
        device : torch.device
            Target device for tensors.
        sort_indices : bool, default True
            Whether to sort by first atom index.
        min_sigma : float, default 1e-4
            Value substituted for sigmas that are exactly zero.

        Returns
        -------
        dict or None
            Dictionary with 'indices', 'references', 'sigmas' tensors,
            or None if no restraints were accumulated.
        """
        if not self._indices:
            return None

        indices = np.concatenate(self._indices, axis=0)
        references = np.concatenate(self._references)
        sigmas = np.concatenate(self._sigmas)

        if sort_indices and len(indices) > 0:
            sort_order = np.argsort(indices[:, 0])
            indices = indices[sort_order]
            references = references[sort_order]
            sigmas = sigmas[sort_order]

        sigmas = np.where(sigmas == 0, min_sigma, sigmas)

        return {
            "indices": torch.tensor(indices, dtype=torch.long, device=device),
            "references": torch.tensor(references, dtype=get_float_dtype(), device=device),
            "sigmas": torch.tensor(sigmas, dtype=get_float_dtype(), device=device),
        }

    @property
    def count(self) -> int:
        """Return total number of restraints accumulated."""
        return self._count
[docs]
class InterResidueAngleBuilder:
"""
Builder for inter-residue angle restraints (peptide angles).
Handles angles that span two consecutive residues.
Parameters
----------
verbose : int, default 0
Verbosity level for debug output.
"""
[docs]
def __init__(self, verbose: int = 0):
"""Initialize empty accumulator lists."""
self.verbose = verbose
self._indices: List[np.ndarray] = []
self._references: List[np.ndarray] = []
self._sigmas: List[np.ndarray] = []
self._count: int = 0
[docs]
def reset(self):
"""Clear all accumulated data."""
self._indices.clear()
self._references.clear()
self._sigmas.clear()
self._count = 0
@staticmethod
def _get_atom_index(residue: pd.DataFrame, atom_name: str) -> Optional[int]:
"""Get atom index from residue, handling alternate conformations."""
atoms = residue[residue["name"] == atom_name]
if len(atoms) == 0:
return None
if " " in atoms["altloc"].values:
return int(atoms[atoms["altloc"] == " "].iloc[0]["index"])
elif "A" in atoms["altloc"].values:
return int(atoms[atoms["altloc"] == "A"].iloc[0]["index"])
else:
return int(atoms.iloc[0]["index"])
[docs]
def process_peptide_angles(
self,
residue_i: pd.DataFrame,
residue_next: pd.DataFrame,
link_angles: pd.DataFrame,
) -> int:
"""
Process peptide angle restraints between consecutive residues.
Parameters
----------
residue_i : pd.DataFrame
First residue.
residue_next : pd.DataFrame
Second residue.
link_angles : pd.DataFrame
Angle definitions from link dictionary.
Returns
-------
int
Number of angle restraints added.
"""
count = 0
for _, angle_row in link_angles.iterrows():
comp1 = angle_row["atom_1_comp_id"]
comp2 = angle_row["atom_2_comp_id"]
comp3 = angle_row["atom_3_comp_id"]
atom1_name = angle_row["atom1"]
atom2_name = angle_row["atom2"]
atom3_name = angle_row["atom3"]
res1 = residue_i if comp1 == "1" else residue_next
res2 = residue_i if comp2 == "1" else residue_next
res3 = residue_i if comp3 == "1" else residue_next
idx1 = self._get_atom_index(res1, atom1_name)
idx2 = self._get_atom_index(res2, atom2_name)
idx3 = self._get_atom_index(res3, atom3_name)
if idx1 is not None and idx2 is not None and idx3 is not None:
self._indices.append(np.array([[idx1, idx2, idx3]], dtype=np.int64))
self._references.append(
np.array([float(angle_row["value"])], dtype=np.float32)
)
self._sigmas.append(
np.array([float(angle_row["sigma"])], dtype=np.float32)
)
count += 1
self._count += count
return count
[docs]
def process_disulfide_angles(
self,
res1_atoms: pd.DataFrame,
res2_atoms: pd.DataFrame,
link_angles: pd.DataFrame,
) -> int:
"""
Process disulfide angle restraints.
Parameters
----------
res1_atoms : pd.DataFrame
First cysteine residue atoms.
res2_atoms : pd.DataFrame
Second cysteine residue atoms.
link_angles : pd.DataFrame
Angle definitions from disulfide link.
Returns
-------
int
Number of angle restraints added.
"""
return self.process_peptide_angles(res1_atoms, res2_atoms, link_angles)
[docs]
def finalize(
    self, device: torch.device, sort_indices: bool = True, min_sigma: float = 1e-4
) -> Optional[Dict[str, torch.Tensor]]:
    """
    Convert the accumulated angle restraints to tensors.

    Parameters
    ----------
    device : torch.device
        Device on which the tensors are created.
    sort_indices : bool, default True
        If True, order the restraints by the index of their first atom.
    min_sigma : float, default 1e-4
        Replacement for sigmas that are exactly zero.

    Returns
    -------
    dict or None
        Dictionary with 'indices', 'references' and 'sigmas' tensors, or
        None when nothing has been accumulated.
    """
    if not self._indices:
        return None

    stacked = np.concatenate(self._indices, axis=0)
    targets = np.concatenate(self._references)
    spreads = np.concatenate(self._sigmas)

    if sort_indices and len(stacked) > 0:
        # Reorder all three arrays together by first-atom index.
        order = stacked[:, 0].argsort()
        stacked, targets, spreads = stacked[order], targets[order], spreads[order]

    # Only exact zeros are replaced; other small sigmas are left as they are.
    spreads = np.where(spreads == 0, min_sigma, spreads)

    float_dtype = get_float_dtype()
    return {
        "indices": torch.tensor(stacked, dtype=torch.long, device=device),
        "references": torch.tensor(targets, dtype=float_dtype, device=device),
        "sigmas": torch.tensor(spreads, dtype=float_dtype, device=device),
    }
@property
def count(self) -> int:
    """Running total of restraints added since construction or the last reset."""
    return self._count
[docs]
class InterResidueTorsionBuilder:
    """
    Builder for inter-residue torsion restraints (phi, psi, omega).

    Handles backbone torsion angles that span consecutive residues and
    keeps phi, psi and omega in separate accumulators; a fourth accumulator
    holds disulfide torsions. Phi and psi record only atom indices and
    periods, while omega and disulfide torsions also record reference
    values and sigmas.

    Parameters
    ----------
    verbose : int, default 0
        Verbosity level. It is stored on the instance; none of the methods
        of this class read it.
    """

    def __init__(self, verbose: int = 0):
        """Initialize empty accumulator lists for each torsion type."""
        self.verbose = verbose
        # Phi / psi: atom indices and periods only.
        self._phi_indices: List[np.ndarray] = []
        self._phi_periods: List[np.ndarray] = []
        self._psi_indices: List[np.ndarray] = []
        self._psi_periods: List[np.ndarray] = []
        # Omega: full restraint data plus a per-restraint proline flag.
        self._omega_indices: List[np.ndarray] = []
        self._omega_references: List[np.ndarray] = []
        self._omega_sigmas: List[np.ndarray] = []
        self._omega_periods: List[np.ndarray] = []
        self._omega_is_proline: List[bool] = []
        # Disulfide torsions.
        self._disulfide_indices: List[np.ndarray] = []
        self._disulfide_references: List[np.ndarray] = []
        self._disulfide_sigmas: List[np.ndarray] = []
        self._disulfide_periods: List[np.ndarray] = []

    def reset(self):
        """Clear all accumulated data."""
        self._phi_indices.clear()
        self._phi_periods.clear()
        self._psi_indices.clear()
        self._psi_periods.clear()
        self._omega_indices.clear()
        self._omega_references.clear()
        self._omega_sigmas.clear()
        self._omega_periods.clear()
        self._omega_is_proline.clear()
        self._disulfide_indices.clear()
        self._disulfide_references.clear()
        self._disulfide_sigmas.clear()
        self._disulfide_periods.clear()

    @staticmethod
    def _get_atom_index(residue: pd.DataFrame, atom_name: str) -> Optional[int]:
        """
        Get atom index from residue, handling alternate conformations.

        Rows whose 'name' equals ``atom_name`` are candidates. A candidate
        with a blank altloc (``" "``) is preferred, then altloc ``"A"``,
        otherwise the first candidate is used.

        Returns
        -------
        int or None
            The 'index' value of the chosen row, or None if no row matches.
        """
        atoms = residue[residue["name"] == atom_name]
        if len(atoms) == 0:
            return None
        altlocs = atoms["altloc"].values
        for preferred in (" ", "A"):
            if preferred in altlocs:
                return int(atoms[atoms["altloc"] == preferred].iloc[0]["index"])
        return int(atoms.iloc[0]["index"])

    def _resolve_torsion_atoms(
        self,
        torsion_row: pd.Series,
        first: pd.DataFrame,
        second: pd.DataFrame,
    ) -> Optional[List[int]]:
        """Look up the four atoms of one torsion row; None if any is missing."""
        resolved: List[int] = []
        for position in (1, 2, 3, 4):
            # A comp id of "1" selects the first residue, anything else the second.
            in_first = torsion_row[f"atom_{position}_comp_id"] == "1"
            residue = first if in_first else second
            atom_index = self._get_atom_index(
                residue, torsion_row[f"atom{position}"]
            )
            if atom_index is None:
                return None
            resolved.append(atom_index)
        return resolved

    @staticmethod
    def _concat_and_sort(
        index_chunks: List[np.ndarray],
        sort_indices: bool,
        **columns: np.ndarray,
    ) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
        """Concatenate index rows and order them, with their columns, by first atom."""
        indices = np.concatenate(index_chunks, axis=0)
        if sort_indices and len(indices) > 0:
            order = np.argsort(indices[:, 0])
            indices = indices[order]
            # Every per-restraint column must follow the same permutation.
            columns = {name: values[order] for name, values in columns.items()}
        return indices, columns

    def process_peptide_torsions(
        self,
        residue_i: pd.DataFrame,
        residue_next: pd.DataFrame,
        link_torsions: pd.DataFrame,
    ) -> Tuple[int, int, int]:
        """
        Process backbone torsion restraints between consecutive residues.

        Rows whose 'id' is 'phi' or 'psi' contribute indices and a period;
        rows whose 'id' is 'omega' additionally contribute 'value', 'sigma'
        and a flag telling whether ``residue_next`` is a proline. Rows with
        any other 'id', or with an atom that cannot be found, are ignored.
        The period is 0 when the 'period' column is absent or null.

        Parameters
        ----------
        residue_i : pd.DataFrame
            First residue.
        residue_next : pd.DataFrame
            Second residue. An empty frame is accepted: it is treated as
            non-proline and every torsion needing one of its atoms is
            skipped.
        link_torsions : pd.DataFrame
            Torsion definitions from link dictionary with 'id' column
            indicating 'phi', 'psi', or 'omega'.

        Returns
        -------
        tuple of (int, int, int)
            Number of (phi, psi, omega) torsions added.
        """
        # Guard the .iloc[0] lookup: an empty residue has no residue name.
        is_proline = bool(
            len(residue_next) > 0 and residue_next["resname"].iloc[0] == "PRO"
        )

        phi_count = 0
        psi_count = 0
        omega_count = 0
        for _, torsion_row in link_torsions.iterrows():
            torsion_id = torsion_row["id"]
            atom_indices = self._resolve_torsion_atoms(
                torsion_row, residue_i, residue_next
            )
            if atom_indices is None:
                continue

            period = (
                int(torsion_row["period"])
                if "period" in torsion_row and pd.notna(torsion_row["period"])
                else 0
            )
            index_row = np.array([atom_indices], dtype=np.int64)
            period_row = np.array([period], dtype=np.int64)

            if torsion_id == "phi":
                self._phi_indices.append(index_row)
                self._phi_periods.append(period_row)
                phi_count += 1
            elif torsion_id == "psi":
                self._psi_indices.append(index_row)
                self._psi_periods.append(period_row)
                psi_count += 1
            elif torsion_id == "omega":
                self._omega_indices.append(index_row)
                self._omega_references.append(
                    np.array([float(torsion_row["value"])], dtype=np.float32)
                )
                self._omega_sigmas.append(
                    np.array([float(torsion_row["sigma"])], dtype=np.float32)
                )
                self._omega_periods.append(period_row)
                self._omega_is_proline.append(is_proline)
                omega_count += 1
        return phi_count, psi_count, omega_count

    def process_disulfide_torsions(
        self,
        res1_atoms: pd.DataFrame,
        res2_atoms: pd.DataFrame,
        link_torsions: pd.DataFrame,
    ) -> int:
        """
        Process disulfide torsion restraints.

        Every row whose four atoms are all found is stored with its 'value'
        and 'sigma' and a fixed period of 2; a 'period' column in
        ``link_torsions`` is not consulted.

        Parameters
        ----------
        res1_atoms : pd.DataFrame
            First cysteine residue atoms.
        res2_atoms : pd.DataFrame
            Second cysteine residue atoms.
        link_torsions : pd.DataFrame
            Torsion definitions from disulfide link.

        Returns
        -------
        int
            Number of torsion restraints added.
        """
        count = 0
        for _, torsion_row in link_torsions.iterrows():
            atom_indices = self._resolve_torsion_atoms(
                torsion_row, res1_atoms, res2_atoms
            )
            if atom_indices is None:
                continue
            self._disulfide_indices.append(np.array([atom_indices], dtype=np.int64))
            self._disulfide_references.append(
                np.array([float(torsion_row["value"])], dtype=np.float32)
            )
            self._disulfide_sigmas.append(
                np.array([float(torsion_row["sigma"])], dtype=np.float32)
            )
            # Period 2 for disulfide
            self._disulfide_periods.append(np.array([2], dtype=np.int64))
            count += 1
        return count

    def finalize_phi(
        self, device: torch.device, sort_indices: bool = True
    ) -> Optional[Dict[str, torch.Tensor]]:
        """
        Finalize phi angle indices and periods.

        Returns a dict with 'indices' and 'periods' tensors (ordered by
        first-atom index when ``sort_indices`` is True), or None when no
        phi torsion has been accumulated.
        """
        if not self._phi_indices:
            return None
        indices, columns = self._concat_and_sort(
            self._phi_indices,
            sort_indices,
            periods=np.concatenate(self._phi_periods),
        )
        return {
            "indices": torch.tensor(indices, dtype=torch.long, device=device),
            "periods": torch.tensor(
                columns["periods"], dtype=torch.long, device=device
            ),
        }

    def finalize_psi(
        self, device: torch.device, sort_indices: bool = True
    ) -> Optional[Dict[str, torch.Tensor]]:
        """
        Finalize psi angle indices and periods.

        Returns a dict with 'indices' and 'periods' tensors (ordered by
        first-atom index when ``sort_indices`` is True), or None when no
        psi torsion has been accumulated.
        """
        if not self._psi_indices:
            return None
        indices, columns = self._concat_and_sort(
            self._psi_indices,
            sort_indices,
            periods=np.concatenate(self._psi_periods),
        )
        return {
            "indices": torch.tensor(indices, dtype=torch.long, device=device),
            "periods": torch.tensor(
                columns["periods"], dtype=torch.long, device=device
            ),
        }

    def finalize_omega(
        self, device: torch.device, sort_indices: bool = True
    ) -> Optional[Dict[str, torch.Tensor]]:
        """
        Finalize omega angle restraints.

        Returns a dict with 'indices', 'references', 'sigmas', 'periods'
        and 'is_proline' tensors (ordered by first-atom index when
        ``sort_indices`` is True), or None when no omega torsion has been
        accumulated.
        """
        if not self._omega_indices:
            return None
        indices, columns = self._concat_and_sort(
            self._omega_indices,
            sort_indices,
            references=np.concatenate(self._omega_references),
            sigmas=np.concatenate(self._omega_sigmas),
            periods=np.concatenate(self._omega_periods),
            is_proline=np.array(self._omega_is_proline, dtype=bool),
        )
        return {
            "indices": torch.tensor(indices, dtype=torch.long, device=device),
            "references": torch.tensor(
                columns["references"], dtype=get_float_dtype(), device=device
            ),
            "sigmas": torch.tensor(
                columns["sigmas"], dtype=get_float_dtype(), device=device
            ),
            "periods": torch.tensor(
                columns["periods"], dtype=torch.long, device=device
            ),
            "is_proline": torch.tensor(
                columns["is_proline"], dtype=torch.bool, device=device
            ),
        }

    def finalize_disulfide(
        self, device: torch.device, sort_indices: bool = True
    ) -> Optional[Dict[str, torch.Tensor]]:
        """
        Finalize disulfide torsion restraints.

        Returns a dict with 'indices', 'references', 'sigmas' and 'periods'
        tensors (ordered by first-atom index when ``sort_indices`` is
        True), or None when no disulfide torsion has been accumulated.
        """
        if not self._disulfide_indices:
            return None
        indices, columns = self._concat_and_sort(
            self._disulfide_indices,
            sort_indices,
            references=np.concatenate(self._disulfide_references),
            sigmas=np.concatenate(self._disulfide_sigmas),
            periods=np.concatenate(self._disulfide_periods),
        )
        return {
            "indices": torch.tensor(indices, dtype=torch.long, device=device),
            "references": torch.tensor(
                columns["references"], dtype=get_float_dtype(), device=device
            ),
            "sigmas": torch.tensor(
                columns["sigmas"], dtype=get_float_dtype(), device=device
            ),
            "periods": torch.tensor(
                columns["periods"], dtype=torch.long, device=device
            ),
        }
[docs]
class InterResiduePlaneBuilder:
    """
    Builder for inter-residue plane restraints (peptide planes).

    Collects planes whose atoms come from two residues (e.g., the planar
    peptide bond group). Planes are bucketed by their number of atoms so
    that each bucket can later be stacked into a rectangular array.

    Parameters
    ----------
    verbose : int, default 0
        Verbosity level for debug output.
    """

    def __init__(self, verbose: int = 0):
        """Start with no planes and a zero restraint count."""
        self.verbose = verbose
        self._count: int = 0
        # atom count -> {"indices": [index arrays], "sigmas": [sigma arrays]}
        self._planes_by_count: Dict[int, Dict[str, List]] = {}

    def reset(self):
        """Drop every accumulated plane and zero the counter."""
        self._count = 0
        self._planes_by_count.clear()

    @staticmethod
    def _get_atom_index(residue: pd.DataFrame, atom_name: str) -> Optional[int]:
        """
        Return the 'index' of the named atom, or None if it is absent.

        With several matching rows, a blank altloc wins, then altloc "A",
        then the first matching row.
        """
        candidates = residue[residue["name"] == atom_name]
        if len(candidates) == 0:
            return None
        altlocs = candidates["altloc"].values
        for preferred in (" ", "A"):
            if preferred in altlocs:
                chosen = candidates[candidates["altloc"] == preferred]
                return int(chosen.iloc[0]["index"])
        return int(candidates.iloc[0]["index"])

    def process_peptide_planes(
        self,
        residue_i: pd.DataFrame,
        residue_next: pd.DataFrame,
        link_planes: pd.DataFrame,
    ) -> int:
        """
        Add the plane restraints linking two consecutive residues.

        A plane is stored only if every one of its atoms is found and it
        has at least three atoms.

        Parameters
        ----------
        residue_i : pd.DataFrame
            First residue.
        residue_next : pd.DataFrame
            Second residue.
        link_planes : pd.DataFrame
            Plane definitions from link dictionary with 'plane_id',
            'atom_comp_id', 'atom', 'sigma' columns.

        Returns
        -------
        int
            Number of plane restraints added by this call.
        """
        added = 0
        plane_column = link_planes["plane_id"]
        for current_id in plane_column.unique():
            members: List[int] = []
            member_sigmas: List[float] = []
            for _, row in link_planes[plane_column == current_id].iterrows():
                # Comp id "1" means the first residue; anything else the second.
                source = residue_i if row["atom_comp_id"] == "1" else residue_next
                wanted = row["atom"]
                sigma_value = float(row["sigma"])
                found = self._get_atom_index(source, wanted)
                if found is None:
                    # One missing atom invalidates the whole plane.
                    break
                members.append(found)
                member_sigmas.append(sigma_value)
            else:
                n_atoms = len(members)
                if n_atoms >= 3:
                    bucket = self._planes_by_count.setdefault(
                        n_atoms, {"indices": [], "sigmas": []}
                    )
                    bucket["indices"].append(np.array(members, dtype=np.int64))
                    bucket["sigmas"].append(
                        np.array(member_sigmas, dtype=np.float32)
                    )
                    added += 1
        self._count += added
        return added

    def finalize(
        self, device: torch.device, sort_indices: bool = True, min_sigma: float = 1e-4
    ) -> Optional[Dict[str, Dict[str, torch.Tensor]]]:
        """
        Convert the accumulated planes to tensors, one group per atom count.

        Parameters
        ----------
        device : torch.device
            Device on which the tensors are created.
        sort_indices : bool, default True
            If True, order the planes of each group by their first atom index.
        min_sigma : float, default 1e-4
            Replacement for sigmas that are exactly zero.

        Returns
        -------
        dict or None
            Maps '<atom count>_atoms' to a dict with 'indices' and 'sigmas'
            tensors; None when there is nothing to return.
        """
        if not self._planes_by_count:
            return None

        tensors = {}
        for n_atoms, bucket in self._planes_by_count.items():
            if not bucket["indices"]:
                continue
            idx = np.stack(bucket["indices"], axis=0)
            sig = np.stack(bucket["sigmas"], axis=0)
            if sort_indices and len(idx) > 0:
                order = np.argsort(idx[:, 0])
                idx, sig = idx[order], sig[order]
            sig = np.where(sig == 0, min_sigma, sig)
            tensors[f"{n_atoms}_atoms"] = {
                "indices": torch.tensor(idx, dtype=torch.long, device=device),
                "sigmas": torch.tensor(sig, dtype=get_float_dtype(), device=device),
            }
        return tensors or None

    @property
    def count(self) -> int:
        """Total number of plane restraints accumulated so far."""
        return self._count