"""
Fast Restraint Builder Classes using NumPy and Numba
This module provides optimized builder classes that avoid Pandas operations
in the hot loop.
"""
from typing import Any, Dict, Iterator, List, Optional, Tuple
import numpy as np
import pandas as pd
import torch
try:
    import numba
    from numba import njit, prange

    HAS_NUMBA = True
except ImportError:
    HAS_NUMBA = False

    def njit(*args, **kwargs):
        """
        No-op stand-in for ``numba.njit`` used when Numba is not installed.

        Supports both decorator spellings:

        * ``@njit`` -- called with the function itself, which is returned
          unchanged.
        * ``@njit(cache=True)`` (any options) -- returns an identity decorator.
        """
        bare_usage = len(args) == 1 and callable(args[0])
        if bare_usage:
            return args[0]

        def passthrough(func):
            return func

        return passthrough

    # Without Numba, parallel ranges degrade to the builtin range.
    prange = range
# =============================================================================
# Numba-accelerated helper functions
# =============================================================================
@njit(cache=True)
def find_atom_index(atom_names: np.ndarray, target: str) -> int:
    """
    Return the position of ``target`` within ``atom_names``.

    The array is scanned front to back and the first matching position is
    reported. ``-1`` means ``target`` does not occur.
    """
    n_names = len(atom_names)
    pos = 0
    while pos < n_names:
        if atom_names[pos] == target:
            return pos
        pos += 1
    return -1
@njit(cache=True)
def match_bonds_numba(
    residue_atom_names: np.ndarray,  # atom names for this residue
    residue_atom_indices: np.ndarray,  # global atom indices
    bond_atom1: np.ndarray,  # CIF bond atom1 names
    bond_atom2: np.ndarray,  # CIF bond atom2 names
    bond_values: np.ndarray,  # CIF bond reference values
    bond_sigmas: np.ndarray,  # CIF bond sigmas
    out_idx1: np.ndarray,  # output arrays (pre-allocated)
    out_idx2: np.ndarray,
    out_refs: np.ndarray,
    out_sigmas: np.ndarray,
) -> int:
    """
    Match bond restraints for a single residue.

    Each CIF bond is considered in order. When both of its atom names occur
    in ``residue_atom_names`` (first occurrence is used), the corresponding
    global atom indices, reference value and sigma are written to the next
    free slot of the ``out_*`` buffers. The buffers must have room for
    ``len(bond_atom1)`` entries.

    Returns the number of bonds written.
    """
    n_names = len(residue_atom_names)
    n_out = 0
    for b in range(len(bond_atom1)):
        # Locate the first atom; no need to look for the second if absent.
        pos_a = -1
        for j in range(n_names):
            if residue_atom_names[j] == bond_atom1[b]:
                pos_a = j
                break
        if pos_a == -1:
            continue

        # Locate the second atom.
        pos_b = -1
        for j in range(n_names):
            if residue_atom_names[j] == bond_atom2[b]:
                pos_b = j
                break
        if pos_b == -1:
            continue

        out_idx1[n_out] = residue_atom_indices[pos_a]
        out_idx2[n_out] = residue_atom_indices[pos_b]
        out_refs[n_out] = bond_values[b]
        out_sigmas[n_out] = bond_sigmas[b]
        n_out += 1
    return n_out
@njit(cache=True)
def match_angles_numba(
    residue_atom_names: np.ndarray,
    residue_atom_indices: np.ndarray,
    angle_atom1: np.ndarray,
    angle_atom2: np.ndarray,
    angle_atom3: np.ndarray,
    angle_values: np.ndarray,
    angle_sigmas: np.ndarray,
    out_idx1: np.ndarray,
    out_idx2: np.ndarray,
    out_idx3: np.ndarray,
    out_refs: np.ndarray,
    out_sigmas: np.ndarray,
) -> int:
    """
    Match angle restraints for a single residue.

    For every CIF angle whose three atom names all occur in
    ``residue_atom_names``, the global atom indices, reference value and
    sigma are written to the next free slot of the ``out_*`` buffers.

    Each residue atom name is compared with atom1, then atom2, then atom3,
    and is assigned to the first of these it equals. If a name occurs more
    than once in the residue, the last occurrence wins.

    Returns the number of angles written.
    """
    n_names = len(residue_atom_names)
    n_out = 0
    for a in range(len(angle_atom1)):
        want1 = angle_atom1[a]
        want2 = angle_atom2[a]
        want3 = angle_atom3[a]
        pos1, pos2, pos3 = -1, -1, -1
        for j in range(n_names):
            nm = residue_atom_names[j]
            if nm == want1:
                pos1 = j
            elif nm == want2:
                pos2 = j
            elif nm == want3:
                pos3 = j
        if pos1 >= 0 and pos2 >= 0 and pos3 >= 0:
            out_idx1[n_out] = residue_atom_indices[pos1]
            out_idx2[n_out] = residue_atom_indices[pos2]
            out_idx3[n_out] = residue_atom_indices[pos3]
            out_refs[n_out] = angle_values[a]
            out_sigmas[n_out] = angle_sigmas[a]
            n_out += 1
    return n_out
@njit(cache=True)
def match_torsions_numba(
    residue_atom_names: np.ndarray,
    residue_atom_indices: np.ndarray,
    torsion_atom1: np.ndarray,
    torsion_atom2: np.ndarray,
    torsion_atom3: np.ndarray,
    torsion_atom4: np.ndarray,
    torsion_values: np.ndarray,
    torsion_sigmas: np.ndarray,
    torsion_periods: np.ndarray,
    out_idx1: np.ndarray,
    out_idx2: np.ndarray,
    out_idx3: np.ndarray,
    out_idx4: np.ndarray,
    out_refs: np.ndarray,
    out_sigmas: np.ndarray,
    out_periods: np.ndarray,
) -> int:
    """
    Match torsion restraints for a single residue.

    CIF torsions whose sigma is exactly zero are ignored. For every other
    torsion whose four atom names all occur in ``residue_atom_names``, the
    global atom indices, reference value, sigma and period are written to the
    next free slot of the ``out_*`` buffers.

    Each residue atom name is compared with atom1..atom4 in that order and
    assigned to the first one it equals. With repeated names, the last
    occurrence wins.

    Returns the number of torsions written.
    """
    n_names = len(residue_atom_names)
    n_out = 0
    for t in range(len(torsion_atom1)):
        # A zero sigma carries no usable weight, so the torsion is dropped.
        if torsion_sigmas[t] == 0:
            continue
        want1 = torsion_atom1[t]
        want2 = torsion_atom2[t]
        want3 = torsion_atom3[t]
        want4 = torsion_atom4[t]
        p1, p2, p3, p4 = -1, -1, -1, -1
        for j in range(n_names):
            nm = residue_atom_names[j]
            if nm == want1:
                p1 = j
            elif nm == want2:
                p2 = j
            elif nm == want3:
                p3 = j
            elif nm == want4:
                p4 = j
        if p1 >= 0 and p2 >= 0 and p3 >= 0 and p4 >= 0:
            out_idx1[n_out] = residue_atom_indices[p1]
            out_idx2[n_out] = residue_atom_indices[p2]
            out_idx3[n_out] = residue_atom_indices[p3]
            out_idx4[n_out] = residue_atom_indices[p4]
            out_refs[n_out] = torsion_values[t]
            out_sigmas[n_out] = torsion_sigmas[t]
            out_periods[n_out] = torsion_periods[t]
            n_out += 1
    return n_out
@njit(cache=True)
def match_chirals_numba(
    residue_atom_names: np.ndarray,
    residue_atom_indices: np.ndarray,
    chiral_center: np.ndarray,
    chiral_atom1: np.ndarray,
    chiral_atom2: np.ndarray,
    chiral_atom3: np.ndarray,
    chiral_volume_signs: np.ndarray,  # +1, -1, 0, or NaN
    chiral_sigmas: np.ndarray,
    out_center: np.ndarray,
    out_idx1: np.ndarray,
    out_idx2: np.ndarray,
    out_idx3: np.ndarray,
    out_signs: np.ndarray,
    out_sigmas: np.ndarray,
) -> int:
    """
    Match chiral restraints for a single residue.

    Entries whose volume sign is NaN (unknown) are ignored. For every other
    entry whose centre and three neighbour names all occur in
    ``residue_atom_names``, the global atom indices, volume sign and sigma
    are written to the next free slot of the ``out_*`` buffers.

    Each residue atom name is compared with centre, atom1, atom2 and atom3 in
    that order and assigned to the first one it equals. With repeated names,
    the last occurrence wins.

    Returns the number of chiral centres written.
    """
    n_names = len(residue_atom_names)
    n_out = 0
    for c in range(len(chiral_center)):
        sign = chiral_volume_signs[c]
        if np.isnan(sign):
            continue
        want_c = chiral_center[c]
        want1 = chiral_atom1[c]
        want2 = chiral_atom2[c]
        want3 = chiral_atom3[c]
        pc, p1, p2, p3 = -1, -1, -1, -1
        for j in range(n_names):
            nm = residue_atom_names[j]
            if nm == want_c:
                pc = j
            elif nm == want1:
                p1 = j
            elif nm == want2:
                p2 = j
            elif nm == want3:
                p3 = j
        if pc >= 0 and p1 >= 0 and p2 >= 0 and p3 >= 0:
            out_center[n_out] = residue_atom_indices[pc]
            out_idx1[n_out] = residue_atom_indices[p1]
            out_idx2[n_out] = residue_atom_indices[p2]
            out_idx3[n_out] = residue_atom_indices[p3]
            out_signs[n_out] = sign
            out_sigmas[n_out] = chiral_sigmas[c]
            n_out += 1
    return n_out
# =============================================================================
# Pre-processed data structures
# =============================================================================
[docs]
class PreprocessedPDB:
"""
Pre-processed PDB data as NumPy arrays for fast access.
Converts DataFrame operations to array lookups.
"""
[docs]
def __init__(self, pdb: pd.DataFrame):
"""
Initialize from PDB DataFrame.
Parameters
----------
pdb : pd.DataFrame
PDB DataFrame with standard columns.
"""
self.n_atoms = len(pdb)
# Core arrays
self.atom_names = pdb["name"].values.astype(str)
self.atom_indices = pdb["index"].values.astype(np.int64)
self.chain_ids = pdb["chainid"].values.astype(str)
self.resseqs = pdb["resseq"].values.astype(np.int64)
self.resnames = pdb["resname"].values.astype(str)
# Optional
if "altloc" in pdb.columns:
self.altlocs = pdb["altloc"].values.astype(str)
else:
self.altlocs = np.full(self.n_atoms, " ", dtype=str)
if "ATOM" in pdb.columns:
self.atom_types = pdb["ATOM"].values.astype(str)
else:
self.atom_types = np.full(self.n_atoms, "ATOM", dtype=str)
# Pre-compute residue boundaries
self._compute_residue_boundaries()
def _compute_residue_boundaries(self):
"""Compute start/end indices for each residue."""
# Create residue key for grouping
residue_keys = np.char.add(
np.char.add(self.chain_ids, "_"), self.resseqs.astype(str)
)
# Find unique residues and their boundaries
unique_keys, first_indices = np.unique(residue_keys, return_index=True)
# Sort by first occurrence to maintain order
order = np.argsort(first_indices)
self.residue_keys = unique_keys[order]
# Compute boundaries
sorted_first = first_indices[order]
self.residue_starts = sorted_first
self.residue_ends = np.append(sorted_first[1:], self.n_atoms)
self.n_residues = len(self.residue_keys)
# Store residue metadata
self.residue_chain_ids = self.chain_ids[self.residue_starts]
self.residue_resseqs = self.resseqs[self.residue_starts]
self.residue_resnames = self.resnames[self.residue_starts]
[docs]
def get_residue_atoms(self, residue_idx: int) -> Tuple[np.ndarray, np.ndarray]:
"""
Get atom names and indices for a residue.
Parameters
----------
residue_idx : int
Index into residue arrays.
Returns
-------
atom_names : np.ndarray
Atom names for this residue.
atom_indices : np.ndarray
Global atom indices.
"""
start = self.residue_starts[residue_idx]
end = self.residue_ends[residue_idx]
return self.atom_names[start:end], self.atom_indices[start:end]
class PreprocessedCIF:
    """
    Pre-processed CIF restraints as NumPy arrays.

    Stores restraint data per residue type for fast lookup. Each of the
    ``bonds``, ``angles``, ``torsions``, ``planes`` and ``chirals`` attributes
    maps a residue type to its pre-processed restraints. A residue type that
    lacks a given restraint table is absent from that mapping.
    """

    def __init__(self, cif_dict: Dict):
        """
        Initialize from CIF dictionary.

        Parameters
        ----------
        cif_dict : dict
            Maps residue type to a dict of restraint DataFrames. Each inner
            dict may contain any of ``"bonds"``, ``"angles"``, ``"torsions"``,
            ``"planes"`` and ``"chirals"``.
        """
        self.residue_types = list(cif_dict.keys())
        # Pre-process each restraint type for each residue
        self.bonds = {}
        self.angles = {}
        self.torsions = {}
        self.planes = {}
        self.chirals = {}
        for restype, data in cif_dict.items():
            if "bonds" in data:
                self.bonds[restype] = self._preprocess_bonds(data["bonds"])
            if "angles" in data:
                self.angles[restype] = self._preprocess_angles(data["angles"])
            if "torsions" in data:
                self.torsions[restype] = self._preprocess_torsions(data["torsions"])
            if "planes" in data:
                self.planes[restype] = self._preprocess_planes(data["planes"])
            if "chirals" in data:
                self.chirals[restype] = self._preprocess_chirals(data["chirals"])

    def _preprocess_bonds(self, bonds_df: pd.DataFrame) -> Dict[str, np.ndarray]:
        """Convert bonds DataFrame to NumPy arrays."""
        return {
            "atom1": bonds_df["atom1"].values.astype(str),
            "atom2": bonds_df["atom2"].values.astype(str),
            "value": bonds_df["value"].values.astype(np.float64),
            "sigma": bonds_df["sigma"].values.astype(np.float64),
        }

    def _preprocess_angles(self, angles_df: pd.DataFrame) -> Dict[str, np.ndarray]:
        """Convert angles DataFrame to NumPy arrays."""
        return {
            "atom1": angles_df["atom1"].values.astype(str),
            "atom2": angles_df["atom2"].values.astype(str),
            "atom3": angles_df["atom3"].values.astype(str),
            "value": angles_df["value"].values.astype(np.float64),
            "sigma": angles_df["sigma"].values.astype(np.float64),
        }

    def _preprocess_torsions(self, torsions_df: pd.DataFrame) -> Dict[str, np.ndarray]:
        """
        Convert torsions DataFrame to NumPy arrays.

        The periodicity is read from ``periodicity`` or ``period`` (in that
        order of preference). If neither column exists, it defaults to 1.
        """
        if "periodicity" in torsions_df.columns:
            periods = torsions_df["periodicity"].values.astype(np.int64)
        elif "period" in torsions_df.columns:
            periods = torsions_df["period"].values.astype(np.int64)
        else:
            periods = np.ones(len(torsions_df), dtype=np.int64)
        return {
            "atom1": torsions_df["atom1"].values.astype(str),
            "atom2": torsions_df["atom2"].values.astype(str),
            "atom3": torsions_df["atom3"].values.astype(str),
            "atom4": torsions_df["atom4"].values.astype(str),
            "value": torsions_df["value"].values.astype(np.float64),
            "sigma": torsions_df["sigma"].values.astype(np.float64),
            "period": periods,
        }

    def _preprocess_planes(self, planes_df: pd.DataFrame) -> List[Dict[str, np.ndarray]]:
        """
        Convert planes DataFrame to a list of per-plane dicts.

        Rows are grouped by ``plane_id`` in order of first appearance. Each
        plane yields ``{"atoms": names, "sigmas": per-atom sigmas}``; sigmas
        default to 0.02 when the DataFrame has no ``sigma`` column.
        """
        has_sigma = "sigma" in planes_df.columns
        planes_data = []
        for plane_id in planes_df["plane_id"].unique():
            plane_atoms = planes_df[planes_df["plane_id"] == plane_id]
            planes_data.append(
                {
                    "atoms": plane_atoms["atom"].values.astype(str),
                    "sigmas": (
                        plane_atoms["sigma"].values.astype(np.float64)
                        if has_sigma
                        else np.full(len(plane_atoms), 0.02)
                    ),
                }
            )
        return planes_data

    @staticmethod
    def _volume_sign_value(sign: Any) -> float:
        """
        Map a CIF ``volume_sign`` entry to +1.0, -1.0, 0.0 or NaN.

        Matching ignores case and surrounding whitespace and is prefix-based.
        This accepts both the spelled-out forms (``positive``/``negative``)
        and the truncated ``positiv``/``negativ`` spelling found in CCP4
        monomer-library dictionaries. ``both``/``either`` map to 0.0
        (achiral/racemic). Anything else, including non-string values, maps
        to NaN so the entry is filtered out later.
        """
        if not isinstance(sign, str):
            return np.nan
        token = sign.strip().lower()
        if token.startswith("pos"):
            return 1.0
        if token.startswith("neg"):
            return -1.0
        if token in ("both", "either"):
            return 0.0
        return np.nan

    def _preprocess_chirals(self, chirals_df: pd.DataFrame) -> Dict[str, np.ndarray]:
        """
        Convert chirals DataFrame to NumPy arrays.

        ``volume_sign`` strings are mapped to floats by ``_volume_sign_value``.
        Sigmas default to 0.2 when the DataFrame has no ``sigma`` column.
        """
        volume_signs = np.array(
            [self._volume_sign_value(s) for s in chirals_df["volume_sign"].values],
            dtype=np.float64,
        )
        return {
            "center": chirals_df["atom_centre"].values.astype(str),
            "atom1": chirals_df["atom1"].values.astype(str),
            "atom2": chirals_df["atom2"].values.astype(str),
            "atom3": chirals_df["atom3"].values.astype(str),
            "volume_sign": volume_signs,
            "sigma": (
                chirals_df["sigma"].values.astype(np.float64)
                if "sigma" in chirals_df.columns
                else np.full(len(chirals_df), 0.2)
            ),
        }
# =============================================================================
# Fast Builder Classes
# =============================================================================
class FastResidueIterator:
    """
    Fast residue iterator using pre-processed PDB data.

    Iteration yields one tuple per residue. ``len()`` always reports the
    total residue count of the underlying structure, even when
    ``filter_protein`` causes residues without ATOM records to be skipped.
    """

    def __init__(self, preprocessed: PreprocessedPDB, filter_protein: bool = False):
        """
        Initialize from preprocessed PDB.

        Parameters
        ----------
        preprocessed : PreprocessedPDB
            Pre-processed PDB data.
        filter_protein : bool
            If True, only iterate over ATOM records.
        """
        self.pdb = preprocessed
        self.filter_protein = filter_protein
        # Per-atom boolean mask of ATOM records; only built when filtering.
        self.protein_mask = (
            preprocessed.atom_types == "ATOM" if filter_protein else None
        )

    def __iter__(self) -> Iterator[Tuple[int, str, int, str, np.ndarray, np.ndarray]]:
        """
        Iterate over residues.

        Yields
        ------
        residue_idx : int
            Index of residue.
        chain_id : str
            Chain identifier.
        resseq : int
            Residue sequence number.
        resname : str
            Residue name.
        atom_names : np.ndarray
            Atom names in this residue.
        atom_indices : np.ndarray
            Global atom indices.

        Residues left with no atoms after filtering are skipped.
        """
        data = self.pdb
        for res_i in range(data.n_residues):
            names, indices = data.get_residue_atoms(res_i)
            if self.filter_protein and self.protein_mask is not None:
                keep = self.protein_mask[
                    data.residue_starts[res_i]:data.residue_ends[res_i]
                ]
                names, indices = names[keep], indices[keep]
            if not len(names):
                continue
            yield (
                res_i,
                data.residue_chain_ids[res_i],
                data.residue_resseqs[res_i],
                data.residue_resnames[res_i],
                names,
                indices,
            )

    def __len__(self) -> int:
        """Total number of residues (not reduced by ``filter_protein``)."""
        return self.pdb.n_residues
class FastBondBuilder:
    """
    Fast bond restraint builder using Numba.

    Call :meth:`process_residue` once per residue to accumulate matched
    bonds, then :meth:`finalize` to obtain tensors. :meth:`reset` discards
    the accumulated data so the builder can be reused.
    """

    def __init__(self, verbose: int = 0):
        self.verbose = verbose
        # Per-residue result chunks; concatenated in finalize().
        self._indices = []
        self._references = []
        self._sigmas = []
        self._count = 0
        # Scratch output buffers handed to match_bonds_numba. They are reused
        # across residues and regrown when a residue type has more CIF bonds.
        self._max_bonds = 50  # per residue
        self._allocate_work_arrays()

    def _allocate_work_arrays(self):
        """Allocate zeroed scratch buffers with ``self._max_bonds`` slots."""
        self._work_idx1 = np.zeros(self._max_bonds, dtype=np.int64)
        self._work_idx2 = np.zeros(self._max_bonds, dtype=np.int64)
        self._work_refs = np.zeros(self._max_bonds, dtype=np.float64)
        self._work_sigmas = np.zeros(self._max_bonds, dtype=np.float64)

    def process_residue(
        self,
        atom_names: np.ndarray,
        atom_indices: np.ndarray,
        cif_bonds: Dict[str, np.ndarray],
    ) -> int:
        """
        Process bonds for a single residue.

        Parameters
        ----------
        atom_names : np.ndarray
            Atom names in this residue.
        atom_indices : np.ndarray
            Global atom indices.
        cif_bonds : dict
            Pre-processed bond data from CIF (``atom1``, ``atom2``, ``value``
            and ``sigma`` arrays).

        Returns
        -------
        int
            Number of bonds added (0 if the residue has duplicate atom names).
        """
        # Duplicate names (e.g. altlocs not expanded) make name matching
        # ambiguous, so the residue is skipped entirely.
        if len(atom_names) != len(set(atom_names)):
            return 0
        n_cif_bonds = len(cif_bonds["atom1"])
        # Every CIF bond could match, so the buffers need that many slots.
        if n_cif_bonds > self._max_bonds:
            self._max_bonds = n_cif_bonds * 2
            self._allocate_work_arrays()
        count = match_bonds_numba(
            atom_names,
            atom_indices,
            cif_bonds["atom1"],
            cif_bonds["atom2"],
            cif_bonds["value"],
            cif_bonds["sigma"],
            self._work_idx1,
            self._work_idx2,
            self._work_refs,
            self._work_sigmas,
        )
        if count > 0:
            # column_stack allocates a new array, so the scratch buffers can
            # be reused for the next residue.
            indices = np.column_stack(
                [self._work_idx1[:count], self._work_idx2[:count]]
            )
            self._indices.append(indices)
            self._references.append(self._work_refs[:count].copy())
            self._sigmas.append(self._work_sigmas[:count].copy())
            self._count += count
        return count

    def finalize(
        self, device: torch.device, sort_indices: bool = True, min_sigma: float = 1e-4
    ) -> Optional[Dict[str, torch.Tensor]]:
        """
        Convert accumulated data to tensors.

        Parameters
        ----------
        device : torch.device
            Device for the returned tensors.
        sort_indices : bool
            If True, order restraints by their first atom index. The sort is
            stable, so restraints that share a first atom keep their
            accumulation order regardless of NumPy build or CPU.
        min_sigma : float
            Replacement for sigmas that are exactly zero.

        Returns
        -------
        dict or None
            ``indices`` (long, shape (n, 2)), ``references`` and ``sigmas``
            (float32, shape (n,)); None if nothing was accumulated.
        """
        if not self._indices:
            return None
        indices = np.concatenate(self._indices, axis=0)
        references = np.concatenate(self._references)
        sigmas = np.concatenate(self._sigmas)
        if sort_indices and len(indices) > 0:
            sort_order = np.argsort(indices[:, 0], kind="stable")
            indices = indices[sort_order]
            references = references[sort_order]
            sigmas = sigmas[sort_order]
        sigmas = np.where(sigmas == 0, min_sigma, sigmas)
        return {
            "indices": torch.tensor(indices, dtype=torch.long, device=device),
            "references": torch.tensor(references, dtype=torch.float32, device=device),
            "sigmas": torch.tensor(sigmas, dtype=torch.float32, device=device),
        }

    def reset(self):
        """Discard all accumulated bonds (scratch buffers are kept)."""
        self._indices.clear()
        self._references.clear()
        self._sigmas.clear()
        self._count = 0
class FastAngleBuilder:
    """
    Fast angle restraint builder using Numba.

    Call :meth:`process_residue` per residue, then :meth:`finalize` to
    obtain tensors. :meth:`reset` discards accumulated results.
    """

    def __init__(self, verbose: int = 0):
        self.verbose = verbose
        # Per-residue result chunks; concatenated in finalize().
        self._indices = []
        self._references = []
        self._sigmas = []
        self._count = 0
        # Scratch output buffers handed to match_angles_numba; regrown when
        # a residue type has more CIF angles than slots.
        self._max_angles = 100
        self._allocate_work_arrays()

    def _allocate_work_arrays(self):
        """Allocate zeroed scratch buffers with ``self._max_angles`` slots."""
        self._work_idx1 = np.zeros(self._max_angles, dtype=np.int64)
        self._work_idx2 = np.zeros(self._max_angles, dtype=np.int64)
        self._work_idx3 = np.zeros(self._max_angles, dtype=np.int64)
        self._work_refs = np.zeros(self._max_angles, dtype=np.float64)
        self._work_sigmas = np.zeros(self._max_angles, dtype=np.float64)

    def process_residue(
        self,
        atom_names: np.ndarray,
        atom_indices: np.ndarray,
        cif_angles: Dict[str, np.ndarray],
    ) -> int:
        """
        Process angles for a single residue.

        Parameters
        ----------
        atom_names : np.ndarray
            Atom names in this residue.
        atom_indices : np.ndarray
            Global atom indices.
        cif_angles : dict
            Pre-processed angle data from CIF (``atom1``..``atom3``,
            ``value`` and ``sigma`` arrays).

        Returns
        -------
        int
            Number of angles added (0 if the residue has duplicate atom names).
        """
        # Skip if duplicate atom names (altlocs not expanded)
        if len(atom_names) != len(set(atom_names)):
            return 0
        n_cif_angles = len(cif_angles["atom1"])
        if n_cif_angles > self._max_angles:
            self._max_angles = n_cif_angles * 2
            self._allocate_work_arrays()
        count = match_angles_numba(
            atom_names,
            atom_indices,
            cif_angles["atom1"],
            cif_angles["atom2"],
            cif_angles["atom3"],
            cif_angles["value"],
            cif_angles["sigma"],
            self._work_idx1,
            self._work_idx2,
            self._work_idx3,
            self._work_refs,
            self._work_sigmas,
        )
        if count > 0:
            # column_stack allocates, so the scratch buffers stay reusable.
            indices = np.column_stack(
                [
                    self._work_idx1[:count],
                    self._work_idx2[:count],
                    self._work_idx3[:count],
                ]
            )
            self._indices.append(indices)
            self._references.append(self._work_refs[:count].copy())
            self._sigmas.append(self._work_sigmas[:count].copy())
            self._count += count
        return count

    def finalize(
        self, device: torch.device, sort_indices: bool = True, min_sigma: float = 1e-4
    ) -> Optional[Dict[str, torch.Tensor]]:
        """
        Convert accumulated data to tensors.

        Parameters
        ----------
        device : torch.device
            Device for the returned tensors.
        sort_indices : bool
            If True, order restraints by their first atom index using a
            stable sort (ties keep accumulation order).
        min_sigma : float
            Replacement for sigmas that are exactly zero.

        Returns
        -------
        dict or None
            ``indices`` (long, shape (n, 3)), ``references`` and ``sigmas``
            (float32, shape (n,)); None if nothing was accumulated.
        """
        if not self._indices:
            return None
        indices = np.concatenate(self._indices, axis=0)
        references = np.concatenate(self._references)
        sigmas = np.concatenate(self._sigmas)
        if sort_indices and len(indices) > 0:
            sort_order = np.argsort(indices[:, 0], kind="stable")
            indices = indices[sort_order]
            references = references[sort_order]
            sigmas = sigmas[sort_order]
        sigmas = np.where(sigmas == 0, min_sigma, sigmas)
        return {
            "indices": torch.tensor(indices, dtype=torch.long, device=device),
            "references": torch.tensor(references, dtype=torch.float32, device=device),
            "sigmas": torch.tensor(sigmas, dtype=torch.float32, device=device),
        }

    def reset(self):
        """Discard all accumulated angles (scratch buffers are kept)."""
        self._indices.clear()
        self._references.clear()
        self._sigmas.clear()
        self._count = 0
class FastTorsionBuilder:
    """
    Fast torsion restraint builder using Numba.

    Call :meth:`process_residue` per residue, then :meth:`finalize` to
    obtain tensors (including per-torsion periodicity). :meth:`reset`
    discards accumulated results.
    """

    def __init__(self, verbose: int = 0):
        self.verbose = verbose
        # Per-residue result chunks; concatenated in finalize().
        self._indices = []
        self._references = []
        self._sigmas = []
        self._periods = []
        self._count = 0
        # Scratch output buffers handed to match_torsions_numba; regrown when
        # a residue type has more CIF torsions than slots.
        self._max_torsions = 50
        self._allocate_work_arrays()

    def _allocate_work_arrays(self):
        """Allocate zeroed scratch buffers with ``self._max_torsions`` slots."""
        self._work_idx1 = np.zeros(self._max_torsions, dtype=np.int64)
        self._work_idx2 = np.zeros(self._max_torsions, dtype=np.int64)
        self._work_idx3 = np.zeros(self._max_torsions, dtype=np.int64)
        self._work_idx4 = np.zeros(self._max_torsions, dtype=np.int64)
        self._work_refs = np.zeros(self._max_torsions, dtype=np.float64)
        self._work_sigmas = np.zeros(self._max_torsions, dtype=np.float64)
        self._work_periods = np.zeros(self._max_torsions, dtype=np.int64)

    def process_residue(
        self,
        atom_names: np.ndarray,
        atom_indices: np.ndarray,
        cif_torsions: Dict[str, np.ndarray],
    ) -> int:
        """
        Process torsions for a single residue.

        Parameters
        ----------
        atom_names : np.ndarray
            Atom names in this residue.
        atom_indices : np.ndarray
            Global atom indices.
        cif_torsions : dict
            Pre-processed torsion data from CIF (``atom1``..``atom4``,
            ``value``, ``sigma`` and ``period`` arrays).

        Returns
        -------
        int
            Number of torsions added (0 if the residue has duplicate atom
            names).
        """
        # Skip if duplicate atom names (altlocs not expanded)
        if len(atom_names) != len(set(atom_names)):
            return 0
        n_cif_torsions = len(cif_torsions["atom1"])
        if n_cif_torsions > self._max_torsions:
            self._max_torsions = n_cif_torsions * 2
            self._allocate_work_arrays()
        count = match_torsions_numba(
            atom_names,
            atom_indices,
            cif_torsions["atom1"],
            cif_torsions["atom2"],
            cif_torsions["atom3"],
            cif_torsions["atom4"],
            cif_torsions["value"],
            cif_torsions["sigma"],
            cif_torsions["period"],
            self._work_idx1,
            self._work_idx2,
            self._work_idx3,
            self._work_idx4,
            self._work_refs,
            self._work_sigmas,
            self._work_periods,
        )
        if count > 0:
            # column_stack allocates, so the scratch buffers stay reusable.
            indices = np.column_stack(
                [
                    self._work_idx1[:count],
                    self._work_idx2[:count],
                    self._work_idx3[:count],
                    self._work_idx4[:count],
                ]
            )
            self._indices.append(indices)
            self._references.append(self._work_refs[:count].copy())
            self._sigmas.append(self._work_sigmas[:count].copy())
            self._periods.append(self._work_periods[:count].copy())
            self._count += count
        return count

    def finalize(
        self, device: torch.device, sort_indices: bool = True, min_sigma: float = 1e-4
    ) -> Optional[Dict[str, torch.Tensor]]:
        """
        Convert accumulated data to tensors.

        Parameters
        ----------
        device : torch.device
            Device for the returned tensors.
        sort_indices : bool
            If True, order restraints by their first atom index using a
            stable sort (ties keep accumulation order).
        min_sigma : float
            Replacement for sigmas that are exactly zero.

        Returns
        -------
        dict or None
            ``indices`` (long, shape (n, 4)), ``references`` and ``sigmas``
            (float32, shape (n,)), ``periods`` (long, shape (n,)); None if
            nothing was accumulated.
        """
        if not self._indices:
            return None
        indices = np.concatenate(self._indices, axis=0)
        references = np.concatenate(self._references)
        sigmas = np.concatenate(self._sigmas)
        periods = np.concatenate(self._periods)
        if sort_indices and len(indices) > 0:
            sort_order = np.argsort(indices[:, 0], kind="stable")
            indices = indices[sort_order]
            references = references[sort_order]
            sigmas = sigmas[sort_order]
            periods = periods[sort_order]
        sigmas = np.where(sigmas == 0, min_sigma, sigmas)
        return {
            "indices": torch.tensor(indices, dtype=torch.long, device=device),
            "references": torch.tensor(references, dtype=torch.float32, device=device),
            "sigmas": torch.tensor(sigmas, dtype=torch.float32, device=device),
            "periods": torch.tensor(periods, dtype=torch.long, device=device),
        }

    def reset(self):
        """Discard all accumulated torsions (scratch buffers are kept)."""
        self._indices.clear()
        self._references.clear()
        self._sigmas.clear()
        self._periods.clear()
        self._count = 0
class FastPlaneBuilder:
    """
    Fast plane restraint builder.

    Planes are grouped by how many of their atoms were found in the residue,
    so each group can be stacked into a rectangular array in
    :meth:`finalize`.
    """

    def __init__(self, verbose: int = 0):
        self.verbose = verbose
        # Group planes by number of atoms - store (indices, sigmas_array) tuples
        self._planes_by_size: Dict[int, List[Tuple[np.ndarray, np.ndarray]]] = {}
        self._count = 0

    def process_residue(
        self, atom_names: np.ndarray, atom_indices: np.ndarray, cif_planes: List[Dict]
    ) -> int:
        """
        Process plane restraints for a residue.

        Parameters
        ----------
        atom_names : np.ndarray
            Atom names in this residue.
        atom_indices : np.ndarray
            Global atom indices.
        cif_planes : list of dict
            One ``{"atoms": names, "sigmas": per-atom sigmas}`` dict per plane.

        Returns
        -------
        int
            Number of planes added. Plane atoms absent from the residue are
            dropped, and a plane is kept only if at least 3 atoms remain.
            Residues with duplicate atom names yield 0.
        """
        # Check for duplicate atom names (skip if found)
        if len(atom_names) != len(set(atom_names)):
            return 0
        name_to_idx = dict(zip(atom_names, atom_indices))
        count = 0
        for plane_data in cif_planes:
            plane_sigmas = plane_data["sigmas"]
            plane_indices = []
            plane_sigma_values = []
            for i, atom_name in enumerate(plane_data["atoms"]):
                if atom_name in name_to_idx:
                    plane_indices.append(name_to_idx[atom_name])
                    plane_sigma_values.append(plane_sigmas[i])
            # Need at least 3 atoms for a plane
            if len(plane_indices) < 3:
                continue
            self._planes_by_size.setdefault(len(plane_indices), []).append(
                (
                    np.array(plane_indices, dtype=np.int64),
                    np.array(plane_sigma_values, dtype=np.float64),
                )
            )
            count += 1
        self._count += count
        return count

    def finalize(
        self, device: torch.device, sort_indices: bool = True, min_sigma: float = 1e-4
    ) -> Optional[Dict[str, Dict[str, torch.Tensor]]]:
        """
        Finalize plane restraints, grouped by atom count.

        Parameters
        ----------
        device : torch.device
            Device for the returned tensors.
        sort_indices : bool
            If True, order planes within each group by their first atom index
            using a stable sort (ties keep accumulation order).
        min_sigma : float
            Replacement for sigmas that are exactly zero.

        Returns
        -------
        dict or None
            Maps ``"{n}_atoms"`` to ``{"indices": long (P, n), "sigmas":
            float32 (P, n)}``; None if no planes were accumulated.
        """
        if not self._planes_by_size:
            return None
        result = {}
        for n_atoms, planes_list in self._planes_by_size.items():
            indices = np.stack([p[0] for p in planes_list], axis=0)
            sigmas = np.stack([p[1] for p in planes_list], axis=0)
            if sort_indices and len(indices) > 0:
                sort_order = np.argsort(indices[:, 0], kind="stable")
                indices = indices[sort_order]
                sigmas = sigmas[sort_order]
            sigmas = np.where(sigmas == 0, min_sigma, sigmas)
            # Use same key format as original: '{n}_atoms'
            result[f"{n_atoms}_atoms"] = {
                "indices": torch.tensor(indices, dtype=torch.long, device=device),
                "sigmas": torch.tensor(sigmas, dtype=torch.float32, device=device),
            }
        return result

    def reset(self):
        """Discard all accumulated planes."""
        self._planes_by_size.clear()
        self._count = 0
class FastChiralBuilder:
    """
    Fast chiral restraint builder using Numba.

    Each matched chiral centre is stored with an ideal signed volume of
    ``sign * volume_magnitude``, where ``sign`` is the +1/-1/0 value from
    the pre-processed CIF data.
    """

    def __init__(self, verbose: int = 0, volume_magnitude: float = 2.5):
        """
        Parameters
        ----------
        verbose : int
            Verbosity level (stored; not read by this class).
        volume_magnitude : float
            Magnitude of the ideal chiral volume assigned to each matched
            centre. Defaults to 2.5, a typical tetrahedral volume.
        """
        self.verbose = verbose
        self.volume_magnitude = volume_magnitude
        # Per-residue result chunks; concatenated in finalize().
        self._indices = []
        self._ideal_volumes = []
        self._sigmas = []
        self._count = 0
        # Scratch output buffers handed to match_chirals_numba; regrown when
        # a residue type has more CIF chiral entries than slots.
        self._max_chirals = 20
        self._allocate_work_arrays()

    def _allocate_work_arrays(self):
        """Allocate zeroed scratch buffers with ``self._max_chirals`` slots."""
        self._work_center = np.zeros(self._max_chirals, dtype=np.int64)
        self._work_idx1 = np.zeros(self._max_chirals, dtype=np.int64)
        self._work_idx2 = np.zeros(self._max_chirals, dtype=np.int64)
        self._work_idx3 = np.zeros(self._max_chirals, dtype=np.int64)
        self._work_signs = np.zeros(self._max_chirals, dtype=np.float64)
        self._work_sigmas = np.zeros(self._max_chirals, dtype=np.float64)

    def process_residue(
        self,
        atom_names: np.ndarray,
        atom_indices: np.ndarray,
        cif_chirals: Dict[str, np.ndarray],
    ) -> int:
        """
        Process chiral centres for a single residue.

        Parameters
        ----------
        atom_names : np.ndarray
            Atom names in this residue.
        atom_indices : np.ndarray
            Global atom indices.
        cif_chirals : dict
            Pre-processed chiral data from CIF (``center``, ``atom1``..
            ``atom3``, ``volume_sign`` and ``sigma`` arrays).

        Returns
        -------
        int
            Number of chiral centres added (0 if the residue has duplicate
            atom names).
        """
        # Skip if duplicate atom names (altlocs not expanded)
        if len(atom_names) != len(set(atom_names)):
            return 0
        n_cif_chirals = len(cif_chirals["center"])
        if n_cif_chirals > self._max_chirals:
            self._max_chirals = n_cif_chirals * 2
            self._allocate_work_arrays()
        count = match_chirals_numba(
            atom_names,
            atom_indices,
            cif_chirals["center"],
            cif_chirals["atom1"],
            cif_chirals["atom2"],
            cif_chirals["atom3"],
            cif_chirals["volume_sign"],
            cif_chirals["sigma"],
            self._work_center,
            self._work_idx1,
            self._work_idx2,
            self._work_idx3,
            self._work_signs,
            self._work_sigmas,
        )
        if count > 0:
            # column_stack allocates, so the scratch buffers stay reusable.
            indices = np.column_stack(
                [
                    self._work_center[:count],
                    self._work_idx1[:count],
                    self._work_idx2[:count],
                    self._work_idx3[:count],
                ]
            )
            # Multiplying allocates a new array, detaching it from the buffer.
            ideal_volumes = self._work_signs[:count] * self.volume_magnitude
            self._indices.append(indices)
            self._ideal_volumes.append(ideal_volumes)
            self._sigmas.append(self._work_sigmas[:count].copy())
            self._count += count
        return count

    def finalize(
        self, device: torch.device, sort_indices: bool = True, min_sigma: float = 1e-4
    ) -> Optional[Dict[str, torch.Tensor]]:
        """
        Convert accumulated data to tensors.

        Parameters
        ----------
        device : torch.device
            Device for the returned tensors.
        sort_indices : bool
            If True, order restraints by the chiral-centre atom index using a
            stable sort (ties keep accumulation order).
        min_sigma : float
            Replacement for sigmas that are exactly zero.

        Returns
        -------
        dict or None
            ``indices`` (long, shape (n, 4), centre first), ``ideal_volumes``
            and ``sigmas`` (float32, shape (n,)); None if nothing was
            accumulated.
        """
        if not self._indices:
            return None
        indices = np.concatenate(self._indices, axis=0)
        ideal_volumes = np.concatenate(self._ideal_volumes)
        sigmas = np.concatenate(self._sigmas)
        if sort_indices and len(indices) > 0:
            sort_order = np.argsort(indices[:, 0], kind="stable")
            indices = indices[sort_order]
            ideal_volumes = ideal_volumes[sort_order]
            sigmas = sigmas[sort_order]
        sigmas = np.where(sigmas == 0, min_sigma, sigmas)
        return {
            "indices": torch.tensor(indices, dtype=torch.long, device=device),
            "ideal_volumes": torch.tensor(
                ideal_volumes, dtype=torch.float32, device=device
            ),
            "sigmas": torch.tensor(sigmas, dtype=torch.float32, device=device),
        }

    def reset(self):
        """Discard all accumulated chiral centres (scratch buffers are kept)."""
        self._indices.clear()
        self._ideal_volumes.clear()
        self._sigmas.clear()
        self._count = 0
# =============================================================================
# High-level API matching original builders
# =============================================================================
def build_all_restraints_fast(
    pdb: pd.DataFrame, cif_dict: Dict, device: torch.device, verbose: int = 0
) -> Dict[str, Any]:
    """
    Build all restraints using the fast implementation.

    Parameters
    ----------
    pdb : pd.DataFrame
        PDB DataFrame.
    cif_dict : dict
        CIF dictionary with restraints.
    device : torch.device
        Target device.
    verbose : int
        Verbosity level.

    Returns
    -------
    dict
        Restraint tensors under the keys ``"bond"``, ``"angle"``,
        ``"torsion"``, ``"plane"`` and ``"chiral"``. A key is omitted when
        its builder produced nothing.
    """
    pdb_data = PreprocessedPDB(pdb)
    cif_data = PreprocessedCIF(cif_dict)

    # (result key, builder, residue-type -> CIF restraints), in processing order.
    pipeline = [
        ("bond", FastBondBuilder(verbose=verbose), cif_data.bonds),
        ("angle", FastAngleBuilder(verbose=verbose), cif_data.angles),
        ("torsion", FastTorsionBuilder(verbose=verbose), cif_data.torsions),
        ("plane", FastPlaneBuilder(verbose=verbose), cif_data.planes),
        ("chiral", FastChiralBuilder(verbose=verbose), cif_data.chirals),
    ]

    for _, _, _, resname, names, indices in FastResidueIterator(pdb_data):
        for _key, builder, by_restype in pipeline:
            if resname in by_restype:
                builder.process_residue(names, indices, by_restype[resname])

    result: Dict[str, Any] = {}
    for key, builder, _lookup in pipeline:
        finalized = builder.finalize(device)
        if finalized:
            result[key] = finalized
    return result