"""
Riding hydrogen topology and vectorized placement for VDW restraints.
Builds a static topology map at restraints-construction time that describes
how to generate transient hydrogen atom positions from heavy-atom coordinates.
At each VDW evaluation the ``place_riding_hydrogens`` function produces H
positions in a single vectorized pass (no Python loops over atoms).
Hydrogen positions are fully determined by the parent heavy atom and its
bonded heavy-atom neighbours, so gradients flow from the VDW loss through
the H positions back to the heavy-atom coordinates via standard autograd.
"""
from typing import Dict, Optional
import numpy as np
import torch
from torch import nn
from torchref.config import dtypes, get_default_device
from torchref.utils.device_mixin import DeviceMixin
# ---------------------------------------------------------------------------
# Placement-type constants
# ---------------------------------------------------------------------------
# Integer codes stored per H in ``HydrogenTopology.h_placement_type``.
ANTI_SUM = 0      # 1 H, >= 2 heavy neighbours -> opposite the sum of bond vectors
CH2_A = 1         # 2 H on a 2-neighbour parent, slot 0 -> +perp component
CH2_B = 2         # 2 H on a 2-neighbour parent, slot 1 -> -perp component
METHYL = 3        # 3 H on a 1-neighbour parent -> 120° staggered around the axis
OPPOSITE_1NB = 4  # 1 H, 1 heavy neighbour -> directly opposite the neighbour
NH2_A = 5         # 2 H on a 1-neighbour parent, slot 0
NH2_B = 6         # 2 H on a 1-neighbour parent, slot 1

# Pre-computed tetrahedral / trigonal geometry constants.
# _COS_TET is cos(109.47°) = -1/3, the cosine of the ideal tetrahedral bond
# angle itself (its supplement, 70.53°, has cosine +1/3). _SIN_TET is the
# matching sine, sqrt(8/9), which is identical for both angles.
_COS_TET = -1.0 / 3.0
_SIN_TET = np.sqrt(8.0 / 9.0)
_COS_120, _SIN_120 = np.cos(2.0 * np.pi / 3.0), np.sin(2.0 * np.pi / 3.0)  # -1/2, +sqrt(3)/2
_COS_240, _SIN_240 = np.cos(4.0 * np.pi / 3.0), np.sin(4.0 * np.pi / 3.0)  # -1/2, -sqrt(3)/2

# Maximum number of heavy-atom neighbours stored per parent.
MAX_HEAVY_NB = 4
# ---------------------------------------------------------------------------
# HydrogenTopology
# ---------------------------------------------------------------------------
class HydrogenTopology(DeviceMixin, nn.Module):
    """Static topology describing riding hydrogens for VDW evaluation.

    All data are stored as registered buffers so they move automatically
    with ``.to(device)`` and appear in ``state_dict``. A freshly constructed
    instance holds no buffers. They are attached by
    ``build_hydrogen_topology``, and the ``cand_*`` buffers by
    ``build_h_candidate_pairs``.

    Attributes
    ----------
    h_parent_idx : (N_h,) long
        Index into heavy-atom array for each riding H.
    h_bond_length : (N_h,) float
        Ideal H–parent bond length (Å).
    h_vdw_radius : (N_h,) float
        Van der Waals radius for each H (1.20 Å).
    h_placement_type : (N_h,) long
        Placement-geometry enum (see module-level constants).
    h_slot_in_parent : (N_h,) long
        Ordinal within sibling H atoms on the same parent (0, 1, 2).
    parent_neighbor_idx : (N_h, MAX_HEAVY_NB) long
        Heavy-atom neighbour indices of the parent (-1 = padding).
    parent_neighbor_count : (N_h,) long
        Actual number of heavy-atom neighbours for the parent.
    h_chainid_enc : (N_h,) long
        Encoded chain ID (for same-residue filtering).
    h_resseq : (N_h,) long
        Residue sequence number (for same-residue filtering).
    type_bounds : dict[int, tuple[int, int]]
        Plain attribute (NOT a buffer, so not in ``state_dict``) mapping a
        placement code to the ``(start, end)`` slice of the type-sorted
        arrays. It may be absent if the builder did not set it.
    cand_idx_i, cand_idx_j, cand_symop_idx : (C,) long
        Candidate H-involving VDW pairs (combined heavy+H indices) and
        the symmetry operator index applied to the partner.
    cand_cell_offset : (C, 3) long
        Cell translation applied to the partner.
    cand_min_dist : (C,) float
        Per-candidate minimum-distance buffer.
    n_asu_candidates : int
        Plain attribute: number of leading non-symmetry candidates.
    """

    def __init__(self):
        super().__init__()
        # Buffers are registered later by the build_* functions.

    @property
    def n_hydrogens(self) -> int:
        """Number of riding hydrogens (0 before ``h_parent_idx`` exists)."""
        if not hasattr(self, "h_parent_idx"):
            return 0
        return self.h_parent_idx.shape[0]

    @property
    def has_candidates(self) -> bool:
        """Whether precomputed H candidate pairs are available."""
        if not hasattr(self, "cand_idx_i"):
            return False
        return self.cand_idx_i.shape[0] > 0
# ---------------------------------------------------------------------------
# Build-time topology construction
# ---------------------------------------------------------------------------
def _load_cif_hydrogen_info(pdb, verbose: int = 0) -> Dict:
    """Fill the shared monomer-library hydrogen cache for residues in ``pdb``.

    The cache object is ``Model._hydrogenate_cif_cache``, the same dict used
    by ``Model.hydrogenate()``, and it is updated in place. For each distinct,
    non-blank residue name the component CIF is located with
    ``MonomerLibraryManager`` and parsed with ``RestraintCIFReader``.

    A residue is cached as ``None`` when any of these holds:

    * no CIF path is found;
    * reading it raises;
    * the component is missing;
    * its atom table is absent, empty, or lacks an ``x`` column.

    Existing entries are kept, unless they lack the ``"heavy_neighbor_map"``
    key, in which case they are rebuilt.

    Parameters
    ----------
    pdb : pd.DataFrame
        Model atoms; only the ``resname`` column is read.
    verbose : int
        Not used by this function.

    Returns
    -------
    cache : dict
        ``{resname: {ids, elems, coords, is_h, id_to_idx, heavy_names,
        heavy_coords, h_names, h_coords, parent_map, ideal_bl,
        heavy_neighbor_map} | None}``
    """
    from torchref.restraints.library import MonomerLibraryManager
    from torchref.model.model import Model

    library = MonomerLibraryManager(verbose=0)
    cache = Model._hydrogenate_cif_cache

    for raw_resname in pdb["resname"].unique():
        resname = str(raw_resname).strip()
        if not resname:
            continue
        if resname in cache:
            entry = cache[resname]
            if entry is None or "heavy_neighbor_map" in entry:
                continue
            # Entry predates the heavy-atom bond graph: rebuild it.
            del cache[resname]

        cif_path = library.get_cif_file(resname)
        if cif_path is None:
            cache[resname] = None
            continue

        # Best effort: any failure while reading the CIF marks the residue as
        # having no hydrogen information instead of aborting the build.
        try:
            from torchref.io.cif_readers import RestraintCIFReader
            import pandas as pd

            restraints = RestraintCIFReader(str(cif_path)).get_all_restraints()
            component = restraints.get(resname) or restraints.get(resname.upper())
            if component is None:
                cache[resname] = None
                continue
            atom_df = component.get("atoms", component.get("atom"))
            bond_df = component.get("bonds", component.get("bond"))
            if atom_df is None or atom_df.empty or "x" not in atom_df.columns:
                cache[resname] = None
                continue
        except Exception:
            cache[resname] = None
            continue

        ids = atom_df["atom_id"].astype(str).str.strip().values
        elems = atom_df["type_symbol"].astype(str).str.strip().values
        coords = atom_df[["x", "y", "z"]].values.astype(np.float64)
        is_h = np.array([sym.upper() == "H" for sym in elems])
        id_to_idx = {atom_id: k for k, atom_id in enumerate(ids)}

        def is_heavy(atom_id):
            # Known atom of this component that is not a hydrogen.
            k = id_to_idx.get(atom_id)
            return k is not None and not is_h[k]

        parent_map, ideal_bl, heavy_neighbor_map = {}, {}, {}
        if bond_df is not None and not bond_df.empty:
            first_atoms = bond_df["atom1"].astype(str).str.strip().values
            second_atoms = bond_df["atom2"].astype(str).str.strip().values
            lengths = pd.to_numeric(bond_df["value"], errors="coerce").values
            h_set = set(ids[is_h])
            for b1, b2, length in zip(first_atoms, second_atoms, lengths):
                # H–heavy bond: record the heavy atom as the H's parent.
                if b1 in h_set and is_heavy(b2):
                    h_atom, parent = b1, b2
                elif b2 in h_set and is_heavy(b1):
                    h_atom, parent = b2, b1
                else:
                    h_atom = None
                if h_atom is not None:
                    parent_map[h_atom] = parent
                    if np.isfinite(length):
                        ideal_bl[h_atom] = float(length)
                # Heavy–heavy bond: record in both directions.
                if is_heavy(b1) and is_heavy(b2):
                    heavy_neighbor_map.setdefault(b1, []).append(b2)
                    heavy_neighbor_map.setdefault(b2, []).append(b1)

        cache[resname] = {
            "ids": ids, "elems": elems, "coords": coords,
            "is_h": is_h, "id_to_idx": id_to_idx,
            "heavy_names": ids[~is_h], "heavy_coords": coords[~is_h],
            "h_names": ids[is_h], "h_coords": coords[is_h],
            "parent_map": parent_map, "ideal_bl": ideal_bl,
            "heavy_neighbor_map": heavy_neighbor_map,
        }
    return cache
def _classify_placement(n_h_on_parent: int, n_heavy_nb: int, slot: int) -> int:
    """Map a hydrogen's local environment to a placement-type code.

    Parameters
    ----------
    n_h_on_parent : int
        Number of riding H on the same parent.
    n_heavy_nb : int
        Number of heavy-atom neighbours of the parent.
    slot : int
        Ordinal of this H among its siblings.

    Returns
    -------
    int
        The placement code for this hydrogen:

        * -1 when the parent has no heavy neighbours (geometry undetermined;
          the caller skips the H).
        * 1 H: ``ANTI_SUM`` (>= 2 neighbours) or ``OPPOSITE_1NB`` (1 neighbour).
        * 2 H: ``CH2_A``/``CH2_B`` (>= 2 neighbours) or ``NH2_A``/``NH2_B``
          (1 neighbour); the ``_A`` variant for slot 0, ``_B`` otherwise.
        * 3 H: ``METHYL``.
        * Any other count: ``ANTI_SUM`` as a fallback.
    """
    if n_heavy_nb == 0:
        return -1
    branched = n_heavy_nb >= 2
    first = slot == 0
    if n_h_on_parent == 1:
        return ANTI_SUM if branched else OPPOSITE_1NB
    if n_h_on_parent == 2:
        if branched:
            return CH2_A if first else CH2_B
        return NH2_A if first else NH2_B
    if n_h_on_parent == 3:
        return METHYL
    return ANTI_SUM
def _collect_riding_hydrogens(pdb, verbose: int = 0) -> Dict[str, list]:
    """Collect one record per riding H for ``pdb``, in discovery order.

    Returns a dict of equally long lists. The keys are ``parent_idx``,
    ``bond_length``, ``placement_type``, ``slot``, ``nb_idx``, ``nb_count``,
    ``chainid_enc`` and ``resseq``. Each ``nb_idx`` entry is a
    ``(MAX_HEAVY_NB,)`` int64 array padded with -1. All lists are empty
    when ``pdb`` is empty.
    """
    records = {
        key: []
        for key in (
            "parent_idx", "bond_length", "placement_type", "slot",
            "nb_idx", "nb_count", "chainid_enc", "resseq",
        )
    }
    if len(pdb) == 0:
        return records

    cache = _load_cif_hydrogen_info(pdb, verbose)
    model_names = pdb["name"].astype(str).str.strip().values
    model_xyz = pdb[["x", "y", "z"]].values.astype(np.float64)
    model_elem = pdb["element"].astype(str).str.strip().values
    model_heavy_mask = np.array([e.upper() != "H" for e in model_elem], dtype=bool)

    # Encode chain IDs as integers for fast same-residue comparison.
    chain_vals = pdb["chainid"].values.astype(str)
    chain_to_int = {c: i for i, c in enumerate(np.unique(chain_vals))}
    model_chainid_enc = np.array([chain_to_int[c] for c in chain_vals], dtype=np.int64)
    model_resseq = pdb["resseq"].values.astype(np.int64)

    # Residue boundaries: a new residue starts wherever any key changes.
    # Keys are compared as strings so a NaN insertion code (NaN != NaN)
    # does not split every atom into its own "residue".
    group_keys = pdb[["chainid", "resseq", "icode", "resname"]].astype(str).values
    starts_new = np.ones(len(group_keys), dtype=bool)
    starts_new[1:] = (group_keys[1:] != group_keys[:-1]).any(axis=1)
    group_starts = np.flatnonzero(starts_new)
    group_ends = np.append(group_starts[1:], len(group_keys))

    # Standard valence for capping how many H a parent can carry.
    std_valence = {"C": 4, "N": 3, "O": 2, "S": 2}

    for s, e in zip(group_starts, group_ends):
        info = cache.get(group_keys[s, 3].strip())
        if info is None:
            continue
        present = set(model_names[s:e])
        missing_h = [h for h in info["h_names"] if h not in present]
        if not missing_h:
            continue

        # Residue-local name -> global index (first occurrence wins).
        name_to_global = {}
        for j in range(s, e):
            name_to_global.setdefault(model_names[j], j)

        # Group the H to add by their (present) parent atom.
        parent_to_h = {}
        for h_name in missing_h:
            parent = info["parent_map"].get(h_name)
            if parent is not None and parent in name_to_global:
                parent_to_h.setdefault(parent, []).append(h_name)

        id_to_idx = info["id_to_idx"]
        chainid_enc = model_chainid_enc[s]
        resseq = model_resseq[s]
        for par_name, h_list in parent_to_h.items():
            pidx = name_to_global[par_name]
            # Heavy-atom neighbours by distance: 0.3 Å < d < 1.9 Å. The lower
            # bound also excludes the parent itself (d = 0).
            d2 = ((model_xyz - model_xyz[pidx]) ** 2).sum(axis=1)
            bonded = np.flatnonzero((d2 > 0.09) & (d2 < 3.61) & model_heavy_mask)
            n_model_heavy = len(bonded)

            # Cap the H count by the parent's expected valence.
            par_elem = info["elems"][id_to_idx[par_name]].upper()
            expected_h = max(0, std_valence.get(par_elem, 4) - n_model_heavy)
            h_capped = sorted(h_list)[:expected_h]
            if not h_capped:
                continue
            n_h = len(h_capped)

            nb_count = min(n_model_heavy, MAX_HEAVY_NB)
            nb_arr = np.full(MAX_HEAVY_NB, -1, dtype=np.int64)
            nb_arr[:nb_count] = bonded[:nb_count]

            for slot, h_name in enumerate(h_capped):
                ptype = _classify_placement(n_h, n_model_heavy, slot)
                if ptype < 0:
                    continue  # skip — cannot determine geometry
                records["parent_idx"].append(pidx)
                records["bond_length"].append(info["ideal_bl"].get(h_name, 0.97))
                records["placement_type"].append(ptype)
                records["slot"].append(slot)
                records["nb_idx"].append(nb_arr.copy())
                records["nb_count"].append(nb_count)
                records["chainid_enc"].append(chainid_enc)
                records["resseq"].append(resseq)
    return records


def _register_topology_buffers(topo, records, device) -> None:
    """Register the per-H ``records`` lists as buffers on ``topo``.

    Works for empty records, which yields zero-length buffers with the
    correct dtypes and trailing shapes.
    """
    fdtype = dtypes.float
    n_h = len(records["parent_idx"])
    nb_idx = (
        np.stack(records["nb_idx"])
        if n_h
        else np.full((0, MAX_HEAVY_NB), -1, dtype=np.int64)
    )
    specs = (
        ("h_parent_idx", records["parent_idx"], torch.long),
        ("h_bond_length", records["bond_length"], fdtype),
        ("h_vdw_radius", np.full(n_h, 1.20), fdtype),
        ("h_placement_type", records["placement_type"], torch.long),
        ("h_slot_in_parent", records["slot"], torch.long),
        ("parent_neighbor_idx", nb_idx, torch.long),
        ("parent_neighbor_count", records["nb_count"], torch.long),
        ("h_chainid_enc", records["chainid_enc"], torch.long),
        ("h_resseq", records["resseq"], torch.long),
    )
    for name, values, dtype in specs:
        topo.register_buffer(
            name, torch.tensor(np.asarray(values), dtype=dtype, device=device)
        )


def build_hydrogen_topology(
    pdb,
    device: Optional[torch.device] = None,
    verbose: int = 0,
) -> HydrogenTopology:
    """Build riding-hydrogen topology from the model's heavy-atom DataFrame.

    Parameters
    ----------
    pdb : pd.DataFrame
        Heavy-atom DataFrame (``strip_H=True``). May be empty.
    device : torch.device, optional
        Target device for tensors. Defaults to ``get_default_device()``
        evaluated at call time (not at import time).
    verbose : int
        Verbosity level.

    Returns
    -------
    HydrogenTopology
        Module with registered buffer tensors, sorted by placement type. It
        also carries ``type_bounds`` (``{code: (start, end)}``), which is an
        empty dict when there are no riding hydrogens.
    """
    if device is None:
        device = get_default_device()

    records = _collect_riding_hydrogens(pdb, verbose)

    # Sort all per-H arrays by placement type (stable) so each type occupies
    # a contiguous slice.
    ptype = np.asarray(records["placement_type"], dtype=np.int64)
    order = np.argsort(ptype, kind="stable")
    records = {key: [values[i] for i in order] for key, values in records.items()}
    sorted_types = ptype[order]
    type_bounds = {}
    for code in np.unique(sorted_types):
        lo = np.searchsorted(sorted_types, code, side="left")
        hi = np.searchsorted(sorted_types, code, side="right")
        type_bounds[int(code)] = (int(lo), int(hi))

    topo = HydrogenTopology()
    _register_topology_buffers(topo, records, device)
    topo.type_bounds = type_bounds  # dict: type_code -> (start, end)

    n_h_total = len(records["parent_idx"])
    if verbose > 0 and n_h_total > 0:
        print(f" Hydrogen topology: {n_h_total} riding H atoms")
    return topo
# ---------------------------------------------------------------------------
# Vectorized H placement (forward-time, differentiable)
# ---------------------------------------------------------------------------
def _safe_normalize(v: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Scale ``v`` to unit length along its last dimension.

    ``eps`` is added to the norm, so zero vectors map to zero instead of NaN.
    """
    length = v.norm(dim=-1, keepdim=True)
    return v / (length + eps)
def _orthonormal_basis(axis: torch.Tensor) -> tuple:
    """Return two unit vectors perpendicular to each row of ``axis``.

    Parameters
    ----------
    axis : (B, 3) unit vectors

    Returns
    -------
    perp1, perp2 : each (B, 3)
        ``perp1`` is ``axis × e``, normalised. Here ``e`` is the cardinal unit
        vector along the smallest-magnitude component of ``axis`` (the first
        such index on ties). ``perp2 = axis × perp1``.
    """
    # One-hot cardinal direction least aligned with each axis.
    pick = axis.abs().argmin(dim=-1, keepdim=True)
    helper = torch.zeros_like(axis).scatter_(-1, pick, 1.0)
    raw = torch.linalg.cross(axis, helper, dim=-1)
    perp1 = raw / (raw.norm(dim=-1, keepdim=True) + 1e-8)
    perp2 = torch.linalg.cross(axis, perp1, dim=-1)
    return perp1, perp2
def _precompute_direction_coefficients(topo: HydrogenTopology) -> torch.Tensor:
    """Precompute ``(c_base, c_perp1, c_perp2)`` for every riding H.

    Every riding H direction is ``c0*base + c1*perp1 + c2*perp2``. In the
    placement kernel, ``base`` is the unit vector opposite the sum of the
    parent→neighbour bond vectors, so it points AWAY from the neighbours.
    ``perp1`` and ``perp2`` complete a frame around it. The coefficients
    depend only on placement type and sibling slot, so they are constant
    across refinement steps.

    Rows are selected by masking ``topo.h_placement_type`` directly. The
    result therefore does not depend on the arrays being sorted by type, or
    on the non-buffer ``type_bounds`` attribute being present (it is absent,
    for example, on a topology rebuilt from a ``state_dict``).

    Returns
    -------
    torch.Tensor
        (N_h, 3), with the dtype of ``h_bond_length`` and the device of
        ``h_placement_type``. Rows with unknown type codes stay zero.
    """
    ptype = topo.h_placement_type
    fdtype = topo.h_bond_length.dtype
    device = ptype.device
    coeffs = torch.zeros(ptype.shape[0], 3, dtype=fdtype, device=device)

    # CH2 on a 2-neighbour parent: both H at cos/sin of half the 109.47°
    # H-C-H angle (54.74°) from base, split along the neighbour-plane normal.
    a_ch2 = 1.0 / np.sqrt(3.0)
    b_ch2 = np.sqrt(2.0 / 3.0)
    # sp2 NH2 on a 1-neighbour parent: H at ±60° from base (H-N-H = 120°).
    half_sqrt3 = np.sqrt(3.0) / 2.0

    fixed_rows = {
        ANTI_SUM: (1.0, 0.0, 0.0),
        OPPOSITE_1NB: (1.0, 0.0, 0.0),
        CH2_A: (a_ch2, b_ch2, 0.0),
        CH2_B: (a_ch2, -b_ch2, 0.0),
        NH2_A: (0.5, half_sqrt3, 0.0),
        NH2_B: (0.5, -half_sqrt3, 0.0),
    }
    for code, row in fixed_rows.items():
        coeffs[ptype == code] = torch.tensor(row, dtype=fdtype, device=device)

    # Methyl: three H staggered by 120° (slot 0, 1, 2) around base. Since
    # base points away from the single heavy neighbour X, an ideal 109.47°
    # H-C-X angle needs a base component of cos(180° - 109.47°) = +1/3,
    # i.e. -_COS_TET. Using _COS_TET (-1/3) would tilt the H toward X (70.5°).
    methyl = ptype == METHYL
    angle = topo.h_slot_in_parent[methyl].to(fdtype) * (2.0 * np.pi / 3.0)
    coeffs[methyl, 0] = -_COS_TET
    coeffs[methyl, 1] = _SIN_TET * torch.cos(angle)
    coeffs[methyl, 2] = _SIN_TET * torch.sin(angle)
    return coeffs
@torch.jit.script
def _place_h_jit(
    xyz_heavy: torch.Tensor,
    h_parent_idx: torch.Tensor,
    nb_idx_clamped: torch.Tensor,
    nb_valid: torch.Tensor,
    coeffs: torch.Tensor,
    bond_length: torch.Tensor,
) -> torch.Tensor:
    """TorchScript kernel that places every riding H in one batched pass.

    Each H gets a local frame built from its parent's bond vectors:

    * ``base``: unit vector opposite the masked sum of parent→neighbour vectors.
    * ``perp1``: unit normal of the plane through the first two neighbours.
      When that cross product is degenerate (norm <= 1e-6, e.g. fewer than
      two neighbours), it is instead ``base × e`` normalised, where ``e`` is
      the cardinal axis of the smallest ``|base|`` component.
    * ``perp2``: ``base × perp1``.

    The position is ``parent + bond_length * (c0*base + c1*perp1 + c2*perp2)``.

    Parameters
    ----------
    xyz_heavy : (N_heavy, 3)
    h_parent_idx : (N_h,) long
    nb_idx_clamped : (N_h, 4) long — neighbour indices (clamped, -1 → 0)
    nb_valid : (N_h, 4, 1) float — 1.0 where neighbour is valid, 0.0 where pad
    coeffs : (N_h, 3) float — [c_base, c_perp1, c_perp2]
    bond_length : (N_h, 1) float

    Returns
    -------
    (N_h, 3) H positions
    """
    tiny = 1e-8
    anchor = xyz_heavy[h_parent_idx]
    # Parent -> neighbour vectors; padded slots are zeroed by nb_valid.
    bonds = (xyz_heavy[nb_idx_clamped] - anchor.unsqueeze(1)) * nb_valid

    away = -(bonds.sum(dim=1))
    base = away / (torch.norm(away, 2, -1, True) + tiny)

    normal = torch.linalg.cross(bonds[:, 0, :], bonds[:, 1, :])
    normal_len = torch.norm(normal, 2, -1, True)
    plane_perp = normal / (normal_len + tiny)

    axis_pick = base.abs().argmin(dim=-1)
    unit_axis = torch.zeros_like(base)
    unit_axis.scatter_(1, axis_pick.unsqueeze(1), 1.0)
    fallback_raw = torch.linalg.cross(base, unit_axis)
    fallback_perp = fallback_raw / (torch.norm(fallback_raw, 2, -1, True) + tiny)

    perp1 = torch.where(normal_len > 1e-6, plane_perp, fallback_perp)
    perp2 = torch.linalg.cross(base, perp1)

    direction = (
        coeffs[:, 0:1] * base
        + coeffs[:, 1:2] * perp1
        + coeffs[:, 2:3] * perp2
    )
    return anchor + bond_length * direction
def place_riding_hydrogens(
    xyz_heavy: torch.Tensor,
    topo: HydrogenTopology,
) -> torch.Tensor:
    """Generate riding hydrogen positions from heavy-atom coordinates.

    Kernel selection:

    * CUDA float32 input: the fused Triton kernel from
      ``torchref.base.targets.triton.place_hydrogens``, if it can be imported.
    * Otherwise: the TorchScript kernel ``_place_h_jit``.

    Static per-topology tensors are computed on the first call and cached as
    plain attributes on ``topo``. They are the direction coefficients, the
    clamped neighbour indices, the neighbour-validity mask and the
    bond-length column. Plain attributes are not buffers, so
    ``topo.to(...)`` does not move them. The cache is therefore keyed on the
    device and float dtype of the topology buffers (and on N_h) and rebuilt
    when those change, rather than feeding stale-device tensors to the
    kernel. Setting ``topo._dir_coeffs = None`` forces the coefficients to
    be recomputed.

    Parameters
    ----------
    xyz_heavy : (N_heavy, 3) float tensor (requires_grad typically True)
    topo : HydrogenTopology

    Returns
    -------
    xyz_h : (N_h, 3) float tensor, differentiable w.r.t. xyz_heavy
    """
    n_h = topo.h_parent_idx.shape[0]
    if n_h == 0:
        return torch.zeros(0, 3, dtype=xyz_heavy.dtype, device=xyz_heavy.device)

    cache_key = (topo.parent_neighbor_idx.device, topo.h_bond_length.dtype, n_h)
    stale = getattr(topo, "_placement_cache_key", None) != cache_key
    if stale or getattr(topo, "_dir_coeffs", None) is None:
        topo._dir_coeffs = _precompute_direction_coefficients(topo)
    if stale or not hasattr(topo, "_nb_idx_clamped"):
        topo._nb_idx_clamped = topo.parent_neighbor_idx.clamp(min=0)
        topo._nb_valid = (
            (topo.parent_neighbor_idx >= 0)
            .unsqueeze(-1)
            .to(topo.h_bond_length.dtype)
        )
        topo._bond_len_col = topo.h_bond_length.unsqueeze(-1)
    topo._placement_cache_key = cache_key

    # On CUDA fp32 use the fused Triton forward + analytic Triton
    # backward (saves ~1.5–1.9 ms on the H-VDW backward — that path is
    # where ~40 % of the non-bonded fwd+bw cost lives). Falls back to
    # the JIT-scripted eager helper otherwise.
    if xyz_heavy.is_cuda and xyz_heavy.dtype == torch.float32:
        try:
            from torchref.base.targets.triton.place_hydrogens import (
                place_riding_hydrogens_triton,
            )
            return place_riding_hydrogens_triton(
                xyz_heavy,
                topo.h_parent_idx,
                topo._nb_idx_clamped,
                topo._nb_valid,
                topo._dir_coeffs,
                topo._bond_len_col,
            )
        except ImportError:
            pass
    return _place_h_jit(
        xyz_heavy,
        topo.h_parent_idx,
        topo._nb_idx_clamped,
        topo._nb_valid,
        topo._dir_coeffs,
        topo._bond_len_col,
    )
# ---------------------------------------------------------------------------
# Build-time H candidate pair precomputation
# ---------------------------------------------------------------------------
def _register_candidate_buffers(h_topo, cand_i, cand_j, cand_sym, cand_off, device) -> None:
    """Register candidate-pair NumPy arrays (possibly empty) as buffers."""
    for name, values in (
        ("cand_idx_i", cand_i),
        ("cand_idx_j", cand_j),
        ("cand_symop_idx", cand_sym),
        ("cand_cell_offset", cand_off),
    ):
        h_topo.register_buffer(
            name, torch.tensor(values, dtype=torch.long, device=device)
        )
    # Zero placeholder for the H + partner VDW radius sums; per-atom radii
    # are not available here and the caller overrides this buffer.
    h_topo.register_buffer(
        "cand_min_dist",
        torch.zeros(len(cand_i), dtype=dtypes.float, device=device),
    )


def build_h_candidate_pairs(
    h_topo: HydrogenTopology,
    vdw_data: dict,
    pdb,
    h_excl_hash: torch.Tensor,
    device: Optional[torch.device] = None,
    verbose: int = 0,
) -> None:
    """Precompute candidate H-involving VDW pairs from the heavy-atom pair list.

    Each heavy-heavy VDW pair ``(A, B, symop, offset)`` has ``A`` in the ASU
    and ``B`` transformed by ``symop``/``offset``. It yields these candidates:

    * H on A vs heavy B: ``(H_A, B, symop, offset)``.
    * H on B vs heavy A, intra-ASU pair: ``(H_B, A, 0, 0)``.
    * H on B vs heavy A, symmetry pair: ``(A, H_B, symop, offset)``. The H
      rides on the transformed B, so it is the partner that receives the
      symmetry operation.
    * H on A vs H on B: ``(H_A, H_B, symop, offset)``. For intra-ASU pairs
      the two H indices are stored in ascending order.

    The following are dropped:

    * intra-ASU candidates between atoms of the same residue;
    * intra-ASU pairs listed in ``h_excl_hash`` (1-2 / 1-3);
    * exact duplicates (the first occurrence is kept).

    ASU candidates are ordered before symmetry candidates. All indices
    refer to the combined array ``[heavy atoms (n_heavy) | riding H (N_h)]``.

    Results are stored as registered buffers on ``h_topo``:

    * ``cand_idx_i`` (C,) long — ASU atom (usually an H; a heavy atom in
      the symmetry "H on B" case)
    * ``cand_idx_j`` (C,) long — partner atom, transformed by symop/offset
    * ``cand_symop_idx`` (C,) long — symop index for the partner
    * ``cand_cell_offset`` (C, 3) long — cell translation for the partner
    * ``cand_min_dist`` (C,) float — zero placeholder for the VDW radius sum

    ``h_topo.n_asu_candidates`` is set to the number of ASU candidates.

    Parameters
    ----------
    h_topo : HydrogenTopology
    vdw_data : dict
        Output of ``build_vdw_restraints_gpu`` (keys: indices, symop_indices,
        cell_offsets, etc.).
    pdb : DataFrame
        Heavy-atom DataFrame.
    h_excl_hash : (E,) long
        Sorted exclusion hash tensor for H-specific 1-2/1-3 pairs, hashed as
        ``min(i, j) * (n_heavy + N_h) + max(i, j)``.
    device : torch.device, optional
        Defaults to ``get_default_device()`` evaluated at call time.
    verbose : int
    """
    if device is None:
        device = get_default_device()

    n_h = h_topo.n_hydrogens
    n_heavy = len(pdb)
    empty_idx = np.zeros(0, dtype=np.int64)
    empty_off = np.zeros((0, 3), dtype=np.int64)
    if n_h == 0:
        _register_candidate_buffers(
            h_topo, empty_idx, empty_idx, empty_idx, empty_off, device
        )
        h_topo.n_asu_candidates = 0
        return

    idx_A = vdw_data["indices"][:, 0].cpu().numpy()
    idx_B = vdw_data["indices"][:, 1].cpu().numpy()
    symop_np = vdw_data["symop_indices"].cpu().numpy()
    offsets_np = vdw_data["cell_offsets"].cpu().numpy()

    parent_idx_np = h_topo.h_parent_idx.cpu().numpy()
    h_chain_np = h_topo.h_chainid_enc.cpu().numpy()
    h_resseq_np = h_topo.h_resseq.cpu().numpy()

    # Heavy atom index -> indices of the riding H attached to it.
    parent_to_h = {}
    for hi, parent in enumerate(parent_idx_np.tolist()):
        parent_to_h.setdefault(parent, []).append(hi)

    # Heavy-atom chain/resseq codes for the same-residue filter.
    chain_vals = pdb["chainid"].values.astype(str)
    chain_to_int = {c: i for i, c in enumerate(np.unique(chain_vals))}
    heavy_chain_np = np.array(
        [chain_to_int.get(c, -1) for c in chain_vals], dtype=np.int64
    )
    heavy_resseq_np = pdb["resseq"].values.astype(np.int64)

    acc_i, acc_j, acc_sym, acc_off = [], [], [], []

    def add(i, j, sym, off):
        acc_i.append(i)
        acc_j.append(j)
        acc_sym.append(sym)
        acc_off.append(off)

    def same_residue(chain_a, res_a, chain_b, res_b):
        return chain_a == chain_b and res_a == res_b

    for A, B, sym, off in zip(
        idx_A.tolist(), idx_B.tolist(), symop_np.tolist(), offsets_np
    ):
        intra = sym == 0 and not off.any()
        h_on_A = parent_to_h.get(A, ())
        h_on_B = parent_to_h.get(B, ())

        # --- H on A ↔ heavy B (B carries the pair's symmetry op) ---
        for ha in h_on_A:
            if intra and same_residue(h_chain_np[ha], h_resseq_np[ha],
                                      heavy_chain_np[B], heavy_resseq_np[B]):
                continue
            add(n_heavy + ha, B, sym, off)

        # --- H on B ↔ heavy A ---
        for hb in h_on_B:
            if not intra:
                # The contact is A (ASU) vs the H on the transformed B, so the
                # H must be the partner that gets the symmetry operation.
                add(A, n_heavy + hb, sym, off)
            elif not same_residue(h_chain_np[hb], h_resseq_np[hb],
                                  heavy_chain_np[A], heavy_resseq_np[A]):
                add(n_heavy + hb, A, sym, off)

        # --- H on A ↔ H on B (H-H contacts) ---
        for ha in h_on_A:
            for hb in h_on_B:
                if not intra:
                    add(n_heavy + ha, n_heavy + hb, sym, off)
                elif not same_residue(h_chain_np[ha], h_resseq_np[ha],
                                      h_chain_np[hb], h_resseq_np[hb]):
                    # Intra-ASU contacts are symmetric: store (low, high) so a
                    # contact seen via both (A, B) and (B, A) dedups to one
                    # entry, without dropping it when only one order exists.
                    lo, hi = (ha, hb) if ha < hb else (hb, ha)
                    add(n_heavy + lo, n_heavy + hi, sym, off)

    if not acc_i:
        _register_candidate_buffers(
            h_topo, empty_idx, empty_idx, empty_idx, empty_off, device
        )
        h_topo.n_asu_candidates = 0
        return

    cand_i = np.asarray(acc_i, dtype=np.int64)
    cand_j = np.asarray(acc_j, dtype=np.int64)
    cand_sym = np.asarray(acc_sym, dtype=np.int64)
    cand_off = np.stack(acc_off).astype(np.int64)

    # Apply 1-2 / 1-3 exclusions for intra-ASU candidates.
    if h_excl_hash is not None and len(h_excl_hash) > 0:
        intra_mask = (cand_sym == 0) & (cand_off == 0).all(axis=1)
        if intra_mask.any():
            excl = h_excl_hash.detach().cpu().numpy()
            max_idx = n_heavy + n_h
            pair_hash = (np.minimum(cand_i, cand_j) * max_idx
                         + np.maximum(cand_i, cand_j))
            pos = np.minimum(np.searchsorted(excl, pair_hash), len(excl) - 1)
            keep = ~((excl[pos] == pair_hash) & intra_mask)
            cand_i, cand_j = cand_i[keep], cand_j[keep]
            cand_sym, cand_off = cand_sym[keep], cand_off[keep]

    # Deduplicate exact (i, j, symop, offset) rows, keeping first occurrences
    # in their original order (no lossy packed key, so no collisions).
    if len(cand_i) > 0:
        rows = np.column_stack([cand_i, cand_j, cand_sym, cand_off])
        _, first = np.unique(rows, axis=0, return_index=True)
        first = np.sort(first)
        cand_i, cand_j = cand_i[first], cand_j[first]
        cand_sym, cand_off = cand_sym[first], cand_off[first]

    # Sort: ASU candidates first, symmetry last (stable).
    is_asu = (cand_sym == 0) & (cand_off == 0).all(axis=1)
    order = np.argsort(~is_asu, kind="stable")
    cand_i, cand_j = cand_i[order], cand_j[order]
    cand_sym, cand_off = cand_sym[order], cand_off[order]
    is_asu = is_asu[order]

    _register_candidate_buffers(h_topo, cand_i, cand_j, cand_sym, cand_off, device)
    h_topo.n_asu_candidates = int(is_asu.sum())

    if verbose > 0:
        n_total = len(cand_i)
        n_hh = int(((cand_i >= n_heavy) & (cand_j >= n_heavy)).sum())
        n_sym = int((~is_asu).sum())
        print(f" H candidate pairs: {n_total} "
              f"({n_hh} H-H, {n_total - n_hh} H-heavy, {n_sym} symmetry)")