"""
GPU-native periodic neighbor search for VDW restraints.
Works in fractional space with periodic boundary conditions.
Avoids explicit symmetry expansion by assigning (atom, symop+offset)
entries to grid cells and using padded batched ``torch.cdist``
for distance computation.
All operations run under ``torch.no_grad()`` on whatever device
the input coordinates live on (CPU or GPU).
"""
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple

import numpy as np
import torch

from torchref.config import dtypes, get_float_dtype

if TYPE_CHECKING:
    from torchref.symmetry.cell import Cell
    from torchref.symmetry.spacegroup import SpaceGroup
# ------------------------------------------------------------------ #
# Step 1 – centroid pre-filter
# ------------------------------------------------------------------ #
[docs]
def prefilter_symop_offsets(
    cell: "Cell",
    sg: "SpaceGroup",
    xyz_frac: torch.Tensor,
    cutoff: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Select (symop, cell_offset) combos that could produce contacts.

    Each symmetry operator is combined with the 27 unit-cell translations in
    ``{-1, 0, 1}^3``.  A combination is kept when the image of the ASU
    centroid lies within ``2 * molecule_radius + cutoff`` of the original
    centroid, where ``molecule_radius`` is the largest atom-to-centroid
    distance in Cartesian space.  The identity (op=0, offset=0) has zero
    displacement and therefore always passes for finite input.

    Parameters
    ----------
    cell, sg : Cell, SpaceGroup
    xyz_frac : (N, 3) fractional ASU coordinates
    cutoff : float  Cartesian cutoff in Angstrom

    Returns
    -------
    op_indices : (M,) long – symop indices for each valid combo
    cell_offsets : (M, 3) long – integer cell translations

    Combos are ordered operator-major, then by ``dx``, ``dy``, ``dz``
    ascending (same order as a nested loop over those ranges).
    """
    device = xyz_frac.device
    fdtype = dtypes.float

    # Convert the ASU to Cartesian once and reuse it for centroid and radius.
    xyz_cart = cell.fractional_to_cartesian(xyz_frac)
    centroid_cart = xyz_cart.mean(dim=0)
    molecule_radius = (xyz_cart - centroid_cart).norm(dim=1).max().item()
    threshold = 2.0 * molecule_radius + cutoff

    # Cast so the matmul below cannot fail on a float32/float64 mismatch.
    centroid_frac = xyz_frac.mean(dim=0).to(fdtype)
    # Applied to fractional displacements to obtain Cartesian ones below.
    frac_to_cart = cell.fractional_matrix.to(device=device, dtype=fdtype)
    identity = torch.eye(3, dtype=fdtype, device=device)
    n_ops = sg.n_ops
    matrices = sg.matrices.to(device=device, dtype=fdtype)[:n_ops]
    translations = sg.translations.to(device=device, dtype=fdtype)[:n_ops]

    steps = (-1, 0, 1)
    offsets = torch.tensor(
        [[dx, dy, dz] for dx in steps for dy in steps for dz in steps],
        dtype=torch.long,
        device=device,
    )  # (27, 3)

    # Centroid displacement caused by each operator alone: (n_ops, 3).
    op_shift = (matrices - identity) @ centroid_frac + translations
    # Add every lattice translation: (n_ops, 27, 3).
    d_frac = op_shift[:, None, :] + offsets.to(fdtype)[None, :, :]
    # Row-vector form of ``frac_to_cart @ d`` for every displacement.
    d_cart = d_frac @ frac_to_cart.T
    within = d_cart.norm(dim=-1) <= threshold  # (n_ops, 27)

    # nonzero() returns indices in row-major order, which reproduces the
    # operator-major / offset-minor ordering of a nested Python loop while
    # needing a single device synchronisation instead of 27 * n_ops.
    op_indices, offset_rows = within.nonzero(as_tuple=True)
    cell_offsets = offsets[offset_rows]
    return op_indices, cell_offsets
# ------------------------------------------------------------------ #
# Step 2 – vectorised image positions + grid assignment
# ------------------------------------------------------------------ #
[docs]
def assign_to_grid(
    xyz_frac: torch.Tensor,
    cell: "Cell",
    sg: "SpaceGroup",
    op_indices: torch.Tensor,
    cell_offsets: torch.Tensor,
    grid_dims: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute Cartesian image positions and assign to grid cells.

    Every atom is mapped through every selected (symop, offset) combo.  The
    un-wrapped image is converted to Cartesian coordinates, while the grid
    cell is derived from the image wrapped into ``[0, 1)``.

    Parameters
    ----------
    xyz_frac : (N, 3)
    cell : Cell
    sg : SpaceGroup
    op_indices : (M,) long
    cell_offsets : (M, 3) long
    grid_dims : (3,) long – number of grid cells per axis

    Returns
    -------
    flat_cell : (N*M,) long – flat grid cell index per entry
    atom_idx : (N*M,) long – ASU atom index
    combo_idx : (N*M,) long – index into op_indices / cell_offsets
    cart_pos : (N*M, 3) float – Cartesian positions (reused in step 4)

    Entries are laid out atom-major: entry ``n * M + m`` is atom ``n`` under
    combo ``m``.
    """
    device = xyz_frac.device
    fdtype = dtypes.float
    n_atoms = xyz_frac.shape[0]
    n_combos = op_indices.shape[0]

    # Move the symmetry tensors to the coordinate device before indexing so
    # a space group held on another device cannot cause a device mismatch
    # (same treatment as in prefilter_symop_offsets).
    combo_ops = op_indices.to(device)
    R_sel = sg.matrices.to(device=device, dtype=fdtype)[combo_ops]      # (M, 3, 3)
    t_sel = sg.translations.to(device=device, dtype=fdtype)[combo_ops]  # (M, 3)
    offs = cell_offsets.to(device=device, dtype=fdtype)                 # (M, 3)

    # (N, M, 3): rotate each atom by each selected operator, then translate.
    frac_images = (
        torch.einsum("mij,nj->nmi", R_sel, xyz_frac.to(fdtype))
        + t_sel[None, :, :]
        + offs[None, :, :]
    )

    # Cartesian positions of the un-wrapped images (stored for reuse).
    cart_pos = cell.fractional_to_cartesian(
        frac_images.reshape(-1, 3)
    )  # (N*M, 3)

    # Wrap to [0, 1) for grid assignment.
    frac_wrapped = frac_images % 1.0
    gd = grid_dims.to(device=device, dtype=fdtype)
    cell_ijk = (frac_wrapped * gd[None, None, :]).long()
    # A wrapped value that rounds up to exactly 1.0 would index one past the
    # last cell; the clamp folds it back into range.
    cell_ijk = cell_ijk.clamp(
        min=torch.zeros(3, dtype=torch.long, device=device),
        max=(grid_dims - 1).to(device),
    )

    gy, gz = grid_dims[1].item(), grid_dims[2].item()
    flat_cell = (
        cell_ijk[..., 0] * (gy * gz)
        + cell_ijk[..., 1] * gz
        + cell_ijk[..., 2]
    ).reshape(-1)  # (N*M,)

    atom_idx = (
        torch.arange(n_atoms, device=device)
        .unsqueeze(1)
        .expand(n_atoms, n_combos)
        .reshape(-1)
    )
    combo_idx = (
        torch.arange(n_combos, device=device)
        .unsqueeze(0)
        .expand(n_atoms, n_combos)
        .reshape(-1)
    )
    return flat_cell, atom_idx, combo_idx, cart_pos
# ------------------------------------------------------------------ #
# Step 3 – sort into cell list (CSR)
# ------------------------------------------------------------------ #
[docs]
def build_cell_list(
    flat_cell: torch.Tensor,
    n_grid_total: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Group entries by grid cell and describe the groups in CSR form.

    Returns
    -------
    sort_order : (E,) long – permutation that sorts ``flat_cell``
    unique_cells : (C,) long – flat indices of the occupied cells, ascending
    starts : (C+1,) long – entries of occupied cell ``k`` are
        ``sort_order[starts[k]:starts[k + 1]]``
    cell_lookup : (n_grid_total,) long – position of each flat cell index
        in ``unique_cells``, or -1 when the cell is empty
    """
    device = flat_cell.device

    # One sort yields both the permutation and the sorted cell ids.
    sorted_cells, sort_order = torch.sort(flat_cell)
    unique_cells, per_cell = torch.unique_consecutive(
        sorted_cells, return_counts=True
    )
    n_occupied = unique_cells.shape[0]

    # Prefix sum with a leading zero gives the CSR row boundaries.
    leading_zero = torch.zeros(1, dtype=torch.long, device=device)
    starts = torch.cat([leading_zero, per_cell.cumsum(0)])

    # Dense reverse map; -1 marks cells without any entry.
    cell_lookup = torch.full(
        (n_grid_total,), -1, dtype=torch.long, device=device
    )
    cell_lookup[unique_cells] = torch.arange(
        n_occupied, dtype=torch.long, device=device
    )
    return sort_order, unique_cells, starts, cell_lookup
# ------------------------------------------------------------------ #
# Step 4 – padded batched cdist over 27 neighbor offsets
# ------------------------------------------------------------------ #
# Module-level caches for the neighbour-offset tables.  They start empty and
# are filled by _get_neighbor_offsets / _get_canonical_offsets_14, which
# rebuild the tensor whenever a different device is requested.
_NEIGHBOR_OFFSETS_27: Optional[torch.Tensor] = None  # full 3x3x3 cube, (27, 3)
_NEIGHBOR_OFFSETS_14: Optional[torch.Tensor] = None  # self + lex-positive half, (14, 3)
def _get_neighbor_offsets(device: torch.device) -> torch.Tensor:
    """Return the (27, 3) long tensor of all offsets in ``{-1, 0, 1}^3``.

    Rows are ordered with ``dx`` slowest and ``dz`` fastest, so the self
    offset ``(0, 0, 0)`` sits at row 13.  The tensor is cached at module
    level and rebuilt only when ``device`` differs from the cached one.
    """
    global _NEIGHBOR_OFFSETS_27

    cached = _NEIGHBOR_OFFSETS_27
    if cached is not None and cached.device == device:
        return cached

    steps = (-1, 0, 1)
    cube = [[dx, dy, dz] for dx in steps for dy in steps for dz in steps]
    _NEIGHBOR_OFFSETS_27 = torch.tensor(cube, dtype=torch.long, device=device)
    return _NEIGHBOR_OFFSETS_27
def _get_canonical_offsets_14(device: torch.device) -> torch.Tensor:
    """Return the (14, 3) canonical half of the 27-offset cube.

    Row 0 is the self offset ``(0, 0, 0)``.  The remaining 13 rows are the
    lex-positive offsets: from each mirror pair ``(d, -d)`` of non-zero
    offsets, the one whose first non-zero component is positive.  Sweeping
    only these visits every unordered pair of neighbouring cells once; the
    mirrored direction is obtained by swapping source and target.

    The tensor is cached at module level and rebuilt only when ``device``
    differs from the cached one.
    """
    global _NEIGHBOR_OFFSETS_14

    cached = _NEIGHBOR_OFFSETS_14
    if cached is not None and cached.device == device:
        return cached

    steps = (-1, 0, 1)
    origin = (0, 0, 0)
    # Tuple comparison is lexicographic, so ``> origin`` is exactly
    # "first non-zero component is positive".
    offsets = [list(origin)]
    offsets.extend(
        [dx, dy, dz]
        for dx in steps
        for dy in steps
        for dz in steps
        if (dx, dy, dz) > origin
    )
    assert len(offsets) == 14, f"expected 14 canonical offsets, got {len(offsets)}"

    _NEIGHBOR_OFFSETS_14 = torch.tensor(offsets, dtype=torch.long, device=device)
    return _NEIGHBOR_OFFSETS_14
def _build_padded_cells(
    cart_sorted: torch.Tensor,
    starts: torch.Tensor,
    atom_idx_sorted: torch.Tensor,
    combo_idx_sorted: torch.Tensor,
    identity_combo: int,
    max_per_cell: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Lay every occupied cell out as one fixed-width row of positions.

    Row ``k`` holds the entries ``starts[k]:starts[k + 1]`` of the sorted
    arrays in its leading slots; the remaining slots are padding.
    ``atom_idx_sorted`` is accepted for signature symmetry but not read.

    Returns
    -------
    padded_xyz : (C, max_per_cell, 3) float, padding filled with ``inf``
    valid_mask : (C, max_per_cell) bool, True for real entries
    asu_mask : (C, max_per_cell) bool, True for real entries whose combo
        index equals ``identity_combo``
    """
    device = cart_sorted.device
    n_cells = starts.shape[0] - 1
    row_begin = starts[:-1]            # (C,)
    row_len = starts[1:] - row_begin   # (C,)

    # Slot index within a row, broadcast against the per-row lengths.
    slot = torch.arange(max_per_cell, device=device)
    valid_mask = slot[None, :] < row_len[:, None]

    # Source index into the sorted arrays for every slot.  Padding slots may
    # point past the end, so they are clamped; they are never read because
    # only slots selected by valid_mask are gathered.
    src = (row_begin[:, None] + slot[None, :]).clamp(
        max=cart_sorted.shape[0] - 1
    )
    src_valid = src[valid_mask]

    padded_xyz = cart_sorted.new_full((n_cells, max_per_cell, 3), float("inf"))
    padded_xyz[valid_mask] = cart_sorted[src_valid]

    # Identity-combo flag per real slot; padding stays False.
    asu_mask = torch.zeros_like(valid_mask)
    asu_mask[valid_mask] = (combo_idx_sorted == identity_combo)[src_valid]

    return padded_xyz, valid_mask, asu_mask
def _gather_padded(
    cart_sorted: torch.Tensor,
    starts: torch.Tensor,
    cell_indices: torch.Tensor,
    max_per_cell: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Gather Cartesian positions into a padded (batch, max_per_cell, 3) tensor.

    Parameters
    ----------
    cart_sorted : (E, 3) Cartesian positions sorted by cell
    starts : (C+1,) CSR boundaries
    cell_indices : (B,) which occupied-cell indices to gather
    max_per_cell : padding size; must be at least the largest entry count
        among the requested cells

    Returns
    -------
    padded : (B, max_per_cell, 3) – padded with inf
    counts : (B,) – actual atom count per cell

    An empty ``cell_indices`` yields a ``(0, max_per_cell, 3)`` tensor and an
    empty ``counts`` instead of failing on the reduction over zero cells.
    """
    device = cart_sorted.device
    n_batch = cell_indices.shape[0]
    padded = torch.full(
        (n_batch, max_per_cell, 3),
        float("inf"),
        dtype=cart_sorted.dtype,
        device=device,
    )
    cell_starts = starts[cell_indices]
    counts = starts[cell_indices + 1] - cell_starts

    # Nothing to gather; max() over an empty tensor would raise.
    if n_batch == 0:
        return padded, counts

    # Only the first max_count slots of any row can be occupied.
    max_count = counts.max().item()
    slot = torch.arange(max_count, device=device)
    # (B, max_count) mask of occupied slots.
    valid = slot.unsqueeze(0) < counts.unsqueeze(1)
    # Source indices into cart_sorted; clamped so padding slots stay in
    # range (they are masked out and never read).
    src_idx = cell_starts.unsqueeze(1) + slot.unsqueeze(0)
    src_idx = src_idx.clamp(max=len(cart_sorted) - 1)
    # The slice is a view, so the masked assignment writes into ``padded``.
    padded[:, :max_count, :][valid] = cart_sorted[src_idx[valid]]
    return padded, counts
[docs]
def find_pairs_periodic_grid(
    cart_sorted: torch.Tensor,
    atom_idx_sorted: torch.Tensor,
    combo_idx_sorted: torch.Tensor,
    unique_cells: torch.Tensor,
    starts: torch.Tensor,
    cell_lookup: torch.Tensor,
    grid_dims: torch.Tensor,
    cutoff: float,
    identity_combo: int,
    chunk_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Find all ASU-vs-everything pairs within cutoff using periodic grid.

    For every neighbour offset in the 3x3x3 cube, each occupied cell that
    contains at least one ASU entry is paired with its (periodically
    wrapped) neighbour cell and all entry-entry distances are evaluated with
    a chunked, batched ``torch.cdist``.  There is no Python loop over
    individual pairs.

    Parameters
    ----------
    cart_sorted : (E, 3) sorted Cartesian positions
    atom_idx_sorted : (E,) ASU atom index per entry
    combo_idx_sorted : (E,) combo index per entry
    unique_cells : (C,) occupied cell flat indices
    starts : (C+1,) CSR boundaries
    cell_lookup : (G,) flat cell → occupied index or -1
    grid_dims : (3,) grid dimensions
    cutoff : float
    identity_combo : int – combo index for identity (op=0, offset=0)
    chunk_size : int – cell pairs per cdist batch

    Returns
    -------
    pair_atom_i : (P,) ASU atom index (always from identity/ASU)
    pair_atom_j : (P,) ASU atom index (mate source)
    pair_combo_j : (P,) combo index for atom j

    Three empty long tensors are returned when no cell is occupied or no
    pair is found.  The same pair can be emitted more than once when a grid
    axis has fewer than three cells, because distinct offsets then wrap onto
    the same neighbour cell.
    """
    device = cart_sorted.device

    n_occupied = len(unique_cells)
    if n_occupied == 0:
        # No entries at all; the max() below would fail on an empty tensor.
        empty = torch.tensor([], dtype=torch.long, device=device)
        return empty, empty, empty

    neighbor_offsets = _get_neighbor_offsets(device)
    gy = grid_dims[1].item()
    gz = grid_dims[2].item()

    counts = starts[1:] - starts[:-1]
    max_per_cell = counts.max().item()

    # Decode unique_cells back to (i, j, k).
    cell_ijk = torch.stack([
        unique_cells // (gy * gz),
        (unique_cells % (gy * gz)) // gz,
        unique_cells % gz,
    ], dim=1)  # (C, 3)

    # Which sorted entries belong to the ASU (identity combo)?
    is_asu = combo_idx_sorted == identity_combo

    # Occupied-cell index of every sorted entry, then a scatter-add to flag
    # cells holding at least one ASU entry.
    entry_cell_idx = torch.repeat_interleave(
        torch.arange(n_occupied, device=device), counts
    )
    has_asu_per_cell = torch.zeros(n_occupied, dtype=torch.long, device=device)
    has_asu_per_cell.scatter_add_(0, entry_cell_idx, is_asu.long())
    has_asu_per_cell = has_asu_per_cell > 0

    all_pair_atom_i = []
    all_pair_atom_j = []
    all_pair_combo_j = []

    for d in neighbor_offsets:  # d: (3,) integer offset
        # Neighbour cell indices with periodic wrapping.
        nb_ijk = (cell_ijk + d[None, :]) % grid_dims[None, :]
        nb_flat = nb_ijk[:, 0] * (gy * gz) + nb_ijk[:, 1] * gz + nb_ijk[:, 2]
        nb_occ_idx = cell_lookup[nb_flat]  # (C,) -1 if empty

        # Source cell must hold an ASU entry and the neighbour must exist.
        active = (nb_occ_idx >= 0) & has_asu_per_cell
        active_idx = active.nonzero(as_tuple=True)[0]
        if len(active_idx) == 0:
            continue
        active_nb = nb_occ_idx[active_idx]

        # Process in chunks to bound the size of the distance tensor.
        for cs in range(0, len(active_idx), chunk_size):
            c_batch = active_idx[cs:cs + chunk_size]
            nb_batch = active_nb[cs:cs + chunk_size]

            src_padded, _ = _gather_padded(
                cart_sorted, starts, c_batch, max_per_cell
            )
            nb_padded, _ = _gather_padded(
                cart_sorted, starts, nb_batch, max_per_cell
            )

            # Batched cdist: (B, max_per_cell, max_per_cell).  Padding slots
            # are inf, so their distances are inf/NaN and never < cutoff.
            dists = torch.cdist(src_padded, nb_padded)
            within = dists < cutoff

            # All hit indices: (batch, local_i, local_j).
            b_idx, local_i, local_j = within.nonzero(as_tuple=True)
            if len(b_idx) == 0:
                continue

            # Map local slots to indices into the sorted arrays.
            c_occ = c_batch[b_idx]     # occupied cell idx for source
            nb_occ = nb_batch[b_idx]   # occupied cell idx for neighbour
            global_i = starts[c_occ] + local_i
            global_j = starts[nb_occ] + local_j

            # Bounds check (guards against padding slots).
            valid = (
                (global_i < starts[c_occ + 1])
                & (global_j < starts[nb_occ + 1])
            )
            # Source must be ASU.
            valid = valid & is_asu[global_i]

            ai = atom_idx_sorted[global_i]
            aj = atom_idx_sorted[global_j]
            cj = combo_idx_sorted[global_j]

            # Dedup: for intra-ASU pairs (both identity combo) keep only
            # atom_i < atom_j so (A, B) and (B, A) are not both counted.
            # Symmetry pairs need no dedup since only one direction has an
            # ASU source.
            both_asu = is_asu[global_j]  # the source is already ASU
            valid = valid & (~both_asu | (ai < aj))

            ai = ai[valid]
            aj = aj[valid]
            cj = cj[valid]

            # Remove true self-pairs (same atom, identity combo).
            not_self = ~((ai == aj) & (cj == identity_combo))
            ai = ai[not_self]
            aj = aj[not_self]
            cj = cj[not_self]

            if len(ai) > 0:
                all_pair_atom_i.append(ai)
                all_pair_atom_j.append(aj)
                all_pair_combo_j.append(cj)

    if not all_pair_atom_i:
        empty = torch.tensor([], dtype=torch.long, device=device)
        return empty, empty, empty

    return (
        torch.cat(all_pair_atom_i),
        torch.cat(all_pair_atom_j),
        torch.cat(all_pair_combo_j),
    )
[docs]
def find_pairs_periodic_grid_v2(
    cart_sorted: torch.Tensor,
    atom_idx_sorted: torch.Tensor,
    combo_idx_sorted: torch.Tensor,
    unique_cells: torch.Tensor,
    starts: torch.Tensor,
    cell_lookup: torch.Tensor,
    grid_dims: torch.Tensor,
    cutoff: float,
    identity_combo: int,
    chunk_size: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Optimised replacement for :func:`find_pairs_periodic_grid`.

    Differences from the original:

    1. **14 canonical offsets** (self + 13 lex-positive).  Each unordered
       pair of neighbouring cells is visited once; a cell pair is processed
       when *either* cell holds an ASU entry, and hits whose ASU member is
       on the target side are swapped on emission.
    2. **Precomputed padded tensor.**  ``padded_xyz``, ``valid_mask`` and
       ``asu_mask`` are built once instead of per chunk and per offset.
    3. **Adaptive chunking.**  When ``chunk_size`` is ``None`` it is set to
       128 on non-CUDA devices; on CUDA it is sized so one distance tile
       ``(chunk, max_per_cell, max_per_cell)`` stays near 256 MB for the
       dtype of ``cart_sorted``, with a floor of 64 cell pairs.

    Emission rules: a hit is kept iff at least one member is an ASU entry.
    The ASU member is placed on the ``i`` side; when both are ASU entries
    the pair is ordered so ``atom_i < atom_j``.  Same-atom identity pairs
    are dropped.  The output may contain duplicates (every hit in the self
    offset is seen in both orientations), so callers are expected to
    deduplicate.

    Parameters and return signature match
    :func:`find_pairs_periodic_grid`.  Three empty long tensors are returned
    when no cell is occupied or no pair is found.
    """
    device = cart_sorted.device

    if starts.numel() <= 1:
        # No occupied cell; the max() below would fail on an empty tensor.
        empty = torch.tensor([], dtype=torch.long, device=device)
        return empty, empty, empty

    offsets14 = _get_canonical_offsets_14(device)
    # One host transfer for all 14 flags instead of one sync per offset.
    self_offset_flags = (offsets14 == 0).all(dim=1).tolist()
    gy = grid_dims[1].item()
    gz = grid_dims[2].item()

    # ---- Precompute ----
    counts = starts[1:] - starts[:-1]
    max_per_cell = counts.max().item()

    # Device-adaptive chunk size.  One cdist tile holds
    # chunk_size * max_per_cell**2 elements of cart_sorted's dtype, which
    # grows quadratically with cell occupancy; cap it at ~256 MB on CUDA so
    # densely packed crystals do not dominate the memory budget.
    if chunk_size is None:
        if device.type == "cuda":
            max_tile_bytes = 256 * 1024 * 1024
            per_chunk_bytes = max(
                1, max_per_cell * max_per_cell * cart_sorted.element_size()
            )
            chunk_size = max(64, max_tile_bytes // per_chunk_bytes)
        else:
            chunk_size = 128

    padded_xyz, valid_mask, asu_mask = _build_padded_cells(
        cart_sorted=cart_sorted,
        starts=starts,
        atom_idx_sorted=atom_idx_sorted,
        combo_idx_sorted=combo_idx_sorted,
        identity_combo=identity_combo,
        max_per_cell=max_per_cell,
    )
    cell_has_asu = asu_mask.any(dim=1)  # (C,) bool
    cell_ijk = torch.stack([
        unique_cells // (gy * gz),
        (unique_cells % (gy * gz)) // gz,
        unique_cells % gz,
    ], dim=1)  # (C, 3)
    cell_base = starts[:-1]  # (C,) first sorted index of each cell

    # Off-diagonal mask for the self offset, where an entry would otherwise
    # pair with itself at distance zero.  Shape (1, m, m).
    not_diag = ~torch.eye(
        max_per_cell, dtype=torch.bool, device=device
    ).unsqueeze(0)

    all_pair_atom_i: List[torch.Tensor] = []
    all_pair_atom_j: List[torch.Tensor] = []
    all_pair_combo_j: List[torch.Tensor] = []

    # ---- Offset sweep (14 canonical) ----
    for offset_idx in range(offsets14.shape[0]):
        d = offsets14[offset_idx]
        is_self_offset = self_offset_flags[offset_idx]

        nb_ijk = (cell_ijk + d[None, :]) % grid_dims[None, :]
        nb_flat = (
            nb_ijk[:, 0] * (gy * gz)
            + nb_ijk[:, 1] * gz
            + nb_ijk[:, 2]
        )
        nb_occ_idx = cell_lookup[nb_flat]  # (C,) long, -1 empty
        has_nb = nb_occ_idx >= 0
        # Clamp so the -1 sentinel can be used as an index; masked by has_nb.
        nb_occ_safe = nb_occ_idx.clamp(min=0)
        nb_has_asu = cell_has_asu[nb_occ_safe] & has_nb
        active = has_nb & (cell_has_asu | nb_has_asu)
        if not active.any().item():
            continue

        active_src_idx = active.nonzero(as_tuple=True)[0]  # (B_total,)
        active_nb_idx = nb_occ_idx[active_src_idx]         # (B_total,)
        B_total = active_src_idx.shape[0]

        for cs in range(0, B_total, chunk_size):
            ce = min(cs + chunk_size, B_total)
            src_cells_chunk = active_src_idx[cs:ce]  # (B,)
            nb_cells_chunk = active_nb_idx[cs:ce]

            src_padded = padded_xyz[src_cells_chunk]  # (B, m, 3)
            nb_padded = padded_xyz[nb_cells_chunk]    # (B, m, 3)
            src_valid = valid_mask[src_cells_chunk]   # (B, m)
            nb_valid = valid_mask[nb_cells_chunk]     # (B, m)
            src_asu = asu_mask[src_cells_chunk]       # (B, m)
            nb_asu = asu_mask[nb_cells_chunk]         # (B, m)

            # Padding is ``inf`` so padded slots never compare within
            # cutoff; the validity masks are applied as well.
            dists = torch.cdist(src_padded, nb_padded)  # (B, m, m)
            hits = (
                (dists < cutoff)
                & src_valid.unsqueeze(2)
                & nb_valid.unsqueeze(1)
            )

            # Either side must be ASU.
            hits = hits & (src_asu.unsqueeze(2) | nb_asu.unsqueeze(1))

            if is_self_offset:
                # Kill the diagonal only.  Atom-order canonicalisation is
                # done after emission rather than via local r < c, because
                # local entry order within a cell need not match atom id
                # order.
                hits = hits & not_diag

            b_idx, li, lj = hits.nonzero(as_tuple=True)
            if b_idx.numel() == 0:
                continue

            src_cells_hit = src_cells_chunk[b_idx]  # (P,)
            nb_cells_hit = nb_cells_chunk[b_idx]    # (P,)
            g_src = cell_base[src_cells_hit] + li   # (P,) sorted index
            g_nb = cell_base[nb_cells_hit] + lj     # (P,) sorted index

            src_is_asu_pair = asu_mask[src_cells_hit, li]  # (P,)
            nb_is_asu_pair = asu_mask[nb_cells_hit, lj]    # (P,)

            # Canonicalise: ASU entry on side i.  Swap when only the target
            # is ASU; the both-ASU case is ordered by atom index below.
            swap_target_asu_only = nb_is_asu_pair & ~src_is_asu_pair
            g_i = torch.where(swap_target_asu_only, g_nb, g_src)
            g_j = torch.where(swap_target_asu_only, g_src, g_nb)

            ai = atom_idx_sorted[g_i]
            aj = atom_idx_sorted[g_j]
            ci = combo_idx_sorted[g_i]
            cj = combo_idx_sorted[g_j]

            # Intra-ASU pairs (both identity combo): enforce ai < aj so
            # (a, b) and (b, a) canonicalise to the same output.  cj is the
            # identity on both sides, so it is unaffected by the swap.
            both_asu = (ci == identity_combo) & (cj == identity_combo)
            swap_intra = both_asu & (ai > aj)
            ordered_i = torch.where(swap_intra, aj, ai)
            aj = torch.where(swap_intra, ai, aj)
            ai = ordered_i

            # Drop true self-pairs (same atom, identity combo on the j side).
            not_self = ~((ai == aj) & (cj == identity_combo))
            if not bool(not_self.all().item()):
                ai = ai[not_self]
                aj = aj[not_self]
                cj = cj[not_self]

            if ai.numel() > 0:
                all_pair_atom_i.append(ai)
                all_pair_atom_j.append(aj)
                all_pair_combo_j.append(cj)

    if not all_pair_atom_i:
        empty = torch.tensor([], dtype=torch.long, device=device)
        return empty, empty, empty

    return (
        torch.cat(all_pair_atom_i),
        torch.cat(all_pair_atom_j),
        torch.cat(all_pair_combo_j),
    )
# ------------------------------------------------------------------ #
# Step 5 – filtering
# ------------------------------------------------------------------ #
[docs]
def exclusion_set_to_hash(
    exclusion_set: Set[Tuple[int, int]],
    max_idx: int,
    device: torch.device,
) -> torch.Tensor:
    """Convert Python exclusion set to sorted hash tensor.

    Each pair ``(i, j)`` is hashed as ``min(i, j) * max_idx + max(i, j)``,
    so the orientation in which a pair was stored does not matter.  The
    result is sorted ascending and free of duplicates, ready for
    ``torch.searchsorted``.

    Parameters
    ----------
    exclusion_set : set of (int, int) atom-index pairs
    max_idx : int  hash stride; must exceed every atom index
    device : torch.device  device of the returned tensor

    Returns
    -------
    (K,) long tensor of hashes; empty when ``exclusion_set`` is empty.
    """
    if not exclusion_set:
        return torch.tensor([], dtype=torch.long, device=device)

    pairs = np.array(list(exclusion_set), dtype=np.int64).reshape(-1, 2)
    # Normalise explicitly instead of trusting the caller to supply
    # (min, max): an unnormalised entry would never match a lookup hash.
    lo = pairs.min(axis=1)
    hi = pairs.max(axis=1)
    # np.unique sorts and merges pairs stored in both orientations.
    hashes = np.unique(lo * max_idx + hi)
    return torch.tensor(hashes, dtype=torch.long, device=device)
[docs]
def filter_pairs(
    pair_atom_i: torch.Tensor,
    pair_atom_j: torch.Tensor,
    pair_combo_j: torch.Tensor,
    identity_combo: int,
    excl_hash: torch.Tensor,
    max_idx: int,
    pdb,
    inter_residue_only: bool = True,
) -> torch.Tensor:
    """Apply exclusion, residue, and altloc filters. Returns keep mask.

    All three filters act on intra-ASU pairs only (``pair_combo_j ==
    identity_combo``); symmetry contacts are always kept.

    Parameters
    ----------
    pair_atom_i, pair_atom_j : (P,) long  atom indices of each pair
    pair_combo_j : (P,) long  combo index of atom j
    identity_combo : int  combo index of the identity operation
    excl_hash : (K,) sorted long  hashes ``min * max_idx + max`` of excluded
        pairs
    max_idx : int  stride used to build ``excl_hash``
    pdb : table with ``chainid`` and ``resseq`` columns and optionally an
        ``altloc`` column, indexable by atom index
    inter_residue_only : bool  drop pairs within the same chain and residue
        number

    Returns
    -------
    keep : (P,) bool  True for pairs that survive every filter
    """
    device = pair_atom_i.device
    keep = torch.ones(len(pair_atom_i), dtype=torch.bool, device=device)
    is_intra_asu = pair_combo_j == identity_combo

    # Bonded exclusions (1-2, 1-3, 1-4) – intra-ASU only.
    if len(excl_hash) > 0 and is_intra_asu.any():
        lo = torch.minimum(pair_atom_i, pair_atom_j)
        hi = torch.maximum(pair_atom_i, pair_atom_j)
        pair_hash = lo * max_idx + hi
        # Membership test on the sorted hash list; the insertion point is
        # clamped so hashes beyond the last entry still index safely.
        pos = torch.searchsorted(excl_hash, pair_hash)
        pos = pos.clamp(max=len(excl_hash) - 1)
        is_excluded = excl_hash[pos] == pair_hash
        keep &= ~(is_excluded & is_intra_asu)

    check_altloc = "altloc" in pdb.columns
    if not (inter_residue_only or check_altloc):
        return keep

    # The table lookups run on the host; copy the index tensors once and
    # share them between the residue and altloc filters.
    ai_np = pair_atom_i.cpu().numpy()
    aj_np = pair_atom_j.cpu().numpy()

    # Same-residue filter – intra-ASU only.
    if inter_residue_only:
        chainid = pdb["chainid"].to_numpy()
        resseq = pdb["resseq"].to_numpy()
        same_res = np.asarray(
            (chainid[ai_np] == chainid[aj_np])
            & (resseq[ai_np] == resseq[aj_np]),
            dtype=bool,
        )
        same_res_t = torch.tensor(same_res, dtype=torch.bool, device=device)
        keep &= ~(same_res_t & is_intra_asu)

    # Altloc compatibility – intra-ASU only.
    if check_altloc:
        altloc_col = pdb["altloc"]
        # Missing values would stringify to "None"/"nan" and then look like
        # a real conformer label; treat them as blank instead.
        missing = altloc_col.isna().to_numpy()
        altloc = altloc_col.to_numpy().astype(str)
        altloc = np.where(missing | np.isin(altloc, ["", " "]), " ", altloc)
        alt_i = altloc[ai_np]
        alt_j = altloc[aj_np]
        # Incompatible only when both atoms carry a label and they differ.
        incompat = (alt_i != " ") & (alt_j != " ") & (alt_i != alt_j)
        incompat_t = torch.tensor(incompat, dtype=torch.bool, device=device)
        keep &= ~(incompat_t & is_intra_asu)

    return keep
# ------------------------------------------------------------------ #
# Orchestrator
# ------------------------------------------------------------------ #
[docs]
@torch.no_grad()
def build_vdw_restraints_gpu(
    xyz_fn,
    vdw_radii_fn,
    cell: "Cell",
    sg: "SpaceGroup",
    pdb,
    exclusion_set: Set[Tuple[int, int]],
    cutoff: float = 5.0,
    sigma: float = 0.2,
    inter_residue_only: bool = True,
    verbose: int = 0,
) -> Dict[str, torch.Tensor]:
    """Build VDW restraints using GPU-native periodic grid search.

    Pipeline: pre-filter (symop, offset) combos, place every image on a
    fractional grid, sort into a cell list, find pairs within ``cutoff``
    with the 14-offset sweep, filter bonded / same-residue / altloc pairs,
    and deduplicate on ``(atom_i, atom_j, combo_j)``.

    Parameters
    ----------
    xyz_fn : callable returns (N, 3) Cartesian coordinates
    vdw_radii_fn : callable returns (N,) VDW radii
    cell : Cell
    sg : SpaceGroup
    pdb : DataFrame
    exclusion_set : set of (int, int) bonded exclusion pairs
    cutoff, sigma : float
    inter_residue_only : bool
    verbose : int

    Returns
    -------
    dict with keys: indices, min_distances, sigmas, symop_indices,
    cell_offsets, plus the cached search state valid_op_indices,
    valid_cell_offsets, grid_dims and identity_combo.  The cached keys are
    present even when no restraint is found.
    """
    from torchref.symmetry.spacegroup import SpaceGroup as SG

    xyz = xyz_fn()
    device = xyz.device
    fdtype = dtypes.float
    n_asu = xyz.shape[0]
    if not isinstance(sg, SG):
        sg = SG(sg)

    # Step 1: prefilter symop combos
    xyz_frac = cell.cartesian_to_fractional(xyz.detach().to(fdtype))
    op_indices, cell_offsets_valid = prefilter_symop_offsets(
        cell, sg, xyz_frac, cutoff
    )
    M = len(op_indices)
    if verbose > 0:
        print(f"  Symmetry expansion: {M} valid (symop, offset) combos")

    # Find the identity combo index
    is_identity = (
        (op_indices == 0)
        & (cell_offsets_valid == 0).all(dim=1)
    )
    identity_indices = is_identity.nonzero(as_tuple=True)[0]
    if len(identity_indices) == 0:
        # Identity not in valid combos — should not happen, but add it
        op_indices = torch.cat([
            torch.zeros(1, dtype=torch.long, device=device), op_indices
        ])
        cell_offsets_valid = torch.cat([
            torch.zeros(1, 3, dtype=torch.long, device=device), cell_offsets_valid
        ])
        identity_combo = 0
        M = len(op_indices)
    else:
        identity_combo = identity_indices[0].item()

    # Step 2: compute image positions + assign to grid.
    # The grid lives in fractional space, so a grid cell is a slab whose
    # thickness along axis k is the perpendicular width of the unit cell
    # (volume / area of the opposite face) divided by grid_dims[k].  The
    # +/-1-cell sweep is only complete when that thickness is >= cutoff.
    # Edge lengths over-estimate the width for non-orthogonal cells, so the
    # widths are derived from the lattice vectors instead.
    frac_probe = torch.cat([torch.zeros(1, 3), torch.eye(3)]).to(
        device=device, dtype=fdtype
    )
    # Evaluated on the host in float64; three 3-vectors only.
    cart_probe = cell.fractional_to_cartesian(frac_probe).detach().cpu().double()
    # Subtract the image of the origin so only the linear part remains.
    a_vec, b_vec, c_vec = (cart_probe[1:] - cart_probe[0]).unbind(0)
    bc = torch.cross(b_vec, c_vec, dim=0)
    ca = torch.cross(c_vec, a_vec, dim=0)
    ab = torch.cross(a_vec, b_vec, dim=0)
    volume = torch.dot(a_vec, bc).abs()
    widths = volume / torch.stack([bc.norm(), ca.norm(), ab.norm()])
    # The small tolerance keeps an exact multiple of the cutoff from losing
    # a cell to floating-point round-off.
    grid_dims = torch.clamp(
        torch.floor(widths / cutoff + 1e-6).long(), min=1
    ).to(device)  # (3,)

    # State reused by the forward-time H-VDW pair search; attached to every
    # result, including the empty one, so callers can rely on the keys.
    cached = {
        "valid_op_indices": op_indices,
        "valid_cell_offsets": cell_offsets_valid,
        "grid_dims": grid_dims,
        "identity_combo": torch.tensor(identity_combo, dtype=torch.long, device=device),
    }
    empty_result = {
        "indices": torch.zeros(0, 2, dtype=torch.long, device=device),
        "min_distances": torch.zeros(0, dtype=get_float_dtype(), device=device),
        "sigmas": torch.zeros(0, dtype=get_float_dtype(), device=device),
        "symop_indices": torch.zeros(0, dtype=torch.long, device=device),
        "cell_offsets": torch.zeros(0, 3, dtype=torch.long, device=device),
        **cached,
    }

    flat_cell, atom_idx, combo_idx, cart_pos = assign_to_grid(
        xyz_frac, cell, sg, op_indices, cell_offsets_valid, grid_dims
    )
    n_grid_total = grid_dims[0].item() * grid_dims[1].item() * grid_dims[2].item()

    # Step 3: sort into cell list
    sort_order, unique_cells, starts, cell_lookup = build_cell_list(
        flat_cell, n_grid_total
    )
    cart_sorted = cart_pos[sort_order]
    atom_idx_sorted = atom_idx[sort_order]
    combo_idx_sorted = combo_idx[sort_order]
    if verbose > 0:
        n_occupied = len(unique_cells)
        counts = starts[1:] - starts[:-1]
        print(f"  Grid: {grid_dims.tolist()}, "
              f"{n_occupied}/{n_grid_total} cells occupied, "
              f"max {counts.max().item()} entries/cell")

    # Step 4: find pairs via periodic grid + batched cdist, using the
    # canonical-14-offset path.  Its raw output can still contain repeated
    # pairs, which the hash dedup below removes.
    pair_atom_i, pair_atom_j, pair_combo_j = find_pairs_periodic_grid_v2(
        cart_sorted, atom_idx_sorted, combo_idx_sorted,
        unique_cells, starts, cell_lookup, grid_dims,
        cutoff, identity_combo,
    )
    if len(pair_atom_i) == 0:
        if verbose > 0:
            print("  Built 0 VDW restraints")
        return empty_result

    # Step 5: filter
    max_idx = max(n_asu, int(pair_atom_i.max().item()) + 1,
                  int(pair_atom_j.max().item()) + 1)
    excl_hash = exclusion_set_to_hash(exclusion_set, max_idx, device)
    keep = filter_pairs(
        pair_atom_i, pair_atom_j, pair_combo_j,
        identity_combo, excl_hash, max_idx,
        pdb, inter_residue_only,
    )
    pair_atom_i = pair_atom_i[keep]
    pair_atom_j = pair_atom_j[keep]
    pair_combo_j = pair_combo_j[keep]
    if len(pair_atom_i) == 0:
        if verbose > 0:
            print("  Built 0 VDW restraints (all filtered)")
        return empty_result

    # Deduplicate: keep first occurrence of each (atom_i, atom_j, combo_j)
    dedup_hash = pair_atom_i * (n_asu * M) + pair_atom_j * M + pair_combo_j
    _, inverse, counts = torch.unique(
        dedup_hash, return_inverse=True, return_counts=True
    )
    # First occurrence: for each unique hash, the minimum index.
    # Use the configured int dtype (int32 by default) — MPS does not support
    # int64 scatter_reduce and N_pairs fits comfortably in int32.
    _int_dtype = dtypes.int
    inverse_i = inverse.to(_int_dtype)
    perm = torch.arange(len(inverse), device=device, dtype=_int_dtype)
    first_occ = torch.full(
        (counts.shape[0],), len(inverse), device=device, dtype=_int_dtype
    )
    first_occ.scatter_reduce_(0, inverse_i, perm, reduce="amin")
    first_mask = torch.zeros(len(pair_atom_i), dtype=torch.bool, device=device)
    first_mask[first_occ.long()] = True
    pair_atom_i = pair_atom_i[first_mask]
    pair_atom_j = pair_atom_j[first_mask]
    pair_combo_j = pair_combo_j[first_mask]

    # Map combo_j back to symop index and cell offset
    symop_indices = op_indices[pair_combo_j]
    pair_cell_offsets = cell_offsets_valid[pair_combo_j]

    # VDW radii: the restraint target is the sum of the two radii.
    vdw_radii = vdw_radii_fn()
    min_distances = vdw_radii[pair_atom_i] + vdw_radii[pair_atom_j]

    # Build output
    indices = torch.stack([pair_atom_i, pair_atom_j], dim=1)
    result = {
        "indices": indices,
        "min_distances": min_distances.to(get_float_dtype()),
        "sigmas": torch.full(
            (len(indices),), sigma, dtype=get_float_dtype(), device=device
        ),
        "symop_indices": symop_indices,
        "cell_offsets": pair_cell_offsets,
        # Cached data for forward-time H-VDW pair search
        **cached,
    }
    if verbose > 0:
        n_sym = (
            (symop_indices != 0) | (pair_cell_offsets != 0).any(dim=1)
        ).sum().item()
        print(f"  Built {len(indices)} VDW restraints, {n_sym} symmetry contacts")
    return result
# ------------------------------------------------------------------ #
# H-involving pair search (forward-time, called every evaluation)
# ------------------------------------------------------------------ #
[docs]
@torch.no_grad()
def find_h_vdw_pairs_gpu(
    xyz_heavy: torch.Tensor,
    xyz_h: torch.Tensor,
    cell: "Cell",
    sg: "SpaceGroup",
    op_indices: torch.Tensor,
    cell_offsets_valid: torch.Tensor,
    grid_dims: torch.Tensor,
    identity_combo: int,
    n_heavy: int,
    cutoff: float = 3.5,
    h_excl_hash: Optional[torch.Tensor] = None,
    pdb=None,
    h_chainid_enc: Optional[torch.Tensor] = None,
    h_resseq: Optional[torch.Tensor] = None,
    inter_residue_only: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Find VDW pairs involving at least one hydrogen atom.

    Uses the same periodic-grid spatial hash as the heavy-atom search but
    operates on a combined (heavy + H) coordinate set and only returns
    pairs where at least one participant is a hydrogen (index >= n_heavy).
    Called at every forward evaluation — designed for low latency.

    The candidate pairs produced by the grid search are filtered in this
    order: hydrogen involvement, bonded exclusions, same-residue contacts,
    altloc compatibility, and finally de-duplication.  The last three
    structural filters only ever remove pairs whose partner lies in the
    identity combo (intra-ASU pairs).

    Parameters
    ----------
    xyz_heavy : (N_heavy, 3) Cartesian ASU heavy-atom positions
    xyz_h : (N_h, 3) Cartesian ASU hydrogen positions
    cell, sg : Cell, SpaceGroup
        ``sg`` is wrapped in ``SpaceGroup(sg)`` when it is not already an
        instance of that class.
    op_indices : (M,) cached valid symop indices
    cell_offsets_valid : (M, 3) cached valid cell translations
    grid_dims : (3,) cached grid dimensions
    identity_combo : int
        Index of the identity (symop, offset) combo within ``op_indices``.
    n_heavy : int
        Number of heavy atoms (indices 0..n_heavy-1 are heavy)
    cutoff : float
        Cartesian cutoff (Å), tighter than heavy-atom search
    h_excl_hash : (E,) sorted long
        Hash tensor for H-specific 1-2/1-3 exclusions.  Entries are
        compared against ``min(i, j) * n_all + max(i, j)``.
    pdb : DataFrame
        Heavy-atom pdb for same-residue filtering (columns ``chainid`` and
        ``resseq``) and altloc filtering (column ``altloc``, if present).
        NOTE(review): rows are indexed directly by heavy-atom index, so
        this presumably has exactly ``n_heavy`` rows in atom order —
        confirm against the caller.
    h_chainid_enc : (N_h,) long
        Chain ID encoding for H atoms.  NOTE(review): heavy atoms are
        encoded here as the rank of their chain ID among the sorted unique
        chain IDs in ``pdb``; the H encoding presumably follows the same
        scheme — confirm against the caller.
    h_resseq : (N_h,) long
        Residue sequence number for H atoms
    inter_residue_only : bool
        When true (and ``pdb`` plus both H residue tensors are given),
        intra-ASU pairs within the same chain and residue are dropped.

    Returns
    -------
    pair_atom_i, pair_atom_j, pair_combo_j : each (P,) long
        Indices into the combined (heavy + H) array.  atom_i is always
        from the ASU (identity combo).  Three empty long tensors are
        returned when no atoms are given or no pair survives filtering.
    """
    device = xyz_heavy.device

    # Combine heavy + H into a single coordinate set
    xyz_all = torch.cat([xyz_heavy, xyz_h], dim=0)  # (N_all, 3)
    n_all = xyz_all.shape[0]
    empty = torch.tensor([], dtype=torch.long, device=device)
    # Nothing to search: bail out before touching any project machinery.
    if n_all == 0:
        return empty, empty, empty

    from torchref.symmetry.spacegroup import SpaceGroup as SG

    fdtype = dtypes.float
    if not isinstance(sg, SG):
        sg = SG(sg)

    # Convert to fractional
    xyz_frac = cell.cartesian_to_fractional(xyz_all.detach().to(fdtype))
    M = op_indices.shape[0]

    # Step 2: assign to grid (reusing cached grid_dims and symop combos)
    flat_cell, atom_idx, combo_idx, cart_pos = assign_to_grid(
        xyz_frac, cell, sg, op_indices, cell_offsets_valid, grid_dims
    )
    n_grid_total = grid_dims[0].item() * grid_dims[1].item() * grid_dims[2].item()

    # Step 3: sort into cell list
    sort_order, unique_cells, starts, cell_lookup = build_cell_list(
        flat_cell, n_grid_total
    )
    cart_sorted = cart_pos[sort_order]
    atom_idx_sorted = atom_idx[sort_order]
    combo_idx_sorted = combo_idx[sort_order]

    # Step 4: find pairs (use the canonical-14-offset fast path)
    pair_atom_i, pair_atom_j, pair_combo_j = find_pairs_periodic_grid_v2(
        cart_sorted, atom_idx_sorted, combo_idx_sorted,
        unique_cells, starts, cell_lookup, grid_dims,
        cutoff, identity_combo,
    )
    if len(pair_atom_i) == 0:
        return empty, empty, empty

    # Filter to keep only pairs involving at least one H
    has_h = (pair_atom_i >= n_heavy) | (pair_atom_j >= n_heavy)
    pair_atom_i = pair_atom_i[has_h]
    pair_atom_j = pair_atom_j[has_h]
    pair_combo_j = pair_combo_j[has_h]
    if len(pair_atom_i) == 0:
        return empty, empty, empty

    # Step 5: filtering
    # ``is_intra_asu`` is kept in sync with the pair tensors through every
    # filter below so it never has to be recomputed.
    is_intra_asu = pair_combo_j == identity_combo

    # Bonded exclusions (1-2 H-parent, 1-3 H-parent_neighbor) — intra-ASU
    if h_excl_hash is not None and len(h_excl_hash) > 0 and is_intra_asu.any():
        max_idx = max(n_all, int(pair_atom_i.max().item()) + 1,
                      int(pair_atom_j.max().item()) + 1)
        # Build the order-independent pair hash in int64 so the product
        # cannot wrap if the pair indices arrive in a narrower int dtype,
        # then match the exclusion table's dtype for the sorted lookup.
        norm_i = torch.minimum(pair_atom_i, pair_atom_j).long()
        norm_j = torch.maximum(pair_atom_i, pair_atom_j).long()
        pair_hash = (norm_i * max_idx + norm_j).to(h_excl_hash.dtype)
        ins = torch.searchsorted(h_excl_hash, pair_hash)
        # An insertion point past the end means "not present"; clamp so
        # the gather below stays in bounds (the equality test then fails).
        ins = ins.clamp(max=len(h_excl_hash) - 1)
        is_excluded = h_excl_hash[ins] == pair_hash
        keep = ~(is_excluded & is_intra_asu)
        pair_atom_i = pair_atom_i[keep]
        pair_atom_j = pair_atom_j[keep]
        pair_combo_j = pair_combo_j[keep]
        is_intra_asu = is_intra_asu[keep]
        if len(pair_atom_i) == 0:
            return empty, empty, empty

    # Same-residue filter — intra-ASU only.  Needs residue identity for
    # both the heavy atoms (from pdb) and the hydrogens (from topology);
    # the pdb columns are only read when the filter can actually run.
    have_h_residue_info = h_chainid_enc is not None and h_resseq is not None
    if inter_residue_only and pdb is not None and have_h_residue_info:
        # Encode heavy-atom chain IDs as their rank among the sorted
        # unique chain IDs (vectorized; no per-atom Python lookup).
        chain_vals = pdb["chainid"].values.astype(str)
        _, heavy_chain_enc = np.unique(chain_vals, return_inverse=True)
        heavy_chain_enc = heavy_chain_enc.reshape(-1).astype(np.int64)
        heavy_resseq = pdb["resseq"].values.astype(np.int64)

        # Combined per-atom arrays, laid out like xyz_all (heavy, then H)
        all_chain_enc = np.concatenate([
            heavy_chain_enc,
            h_chainid_enc.cpu().numpy(),
        ])
        all_resseq = np.concatenate([
            heavy_resseq,
            h_resseq.cpu().numpy(),
        ])

        ai_np = pair_atom_i.cpu().numpy()
        aj_np = pair_atom_j.cpu().numpy()
        same_res = (
            (all_chain_enc[ai_np] == all_chain_enc[aj_np])
            & (all_resseq[ai_np] == all_resseq[aj_np])
        )
        same_res_t = torch.tensor(same_res, dtype=torch.bool, device=device)
        keep = ~(same_res_t & is_intra_asu)
        pair_atom_i = pair_atom_i[keep]
        pair_atom_j = pair_atom_j[keep]
        pair_combo_j = pair_combo_j[keep]
        is_intra_asu = is_intra_asu[keep]
        if len(pair_atom_i) == 0:
            return empty, empty, empty

    # Altloc compatibility — intra-ASU, only relevant for heavy atoms
    if pdb is not None and "altloc" in pdb.columns:
        # Only check altloc for pairs where both are heavy atoms
        both_heavy = (pair_atom_i < n_heavy) & (pair_atom_j < n_heavy) & is_intra_asu
        if both_heavy.any():
            altloc = pdb["altloc"].values.astype(str)
            # Treat empty and blank labels alike as "no altloc".
            altloc = np.where(np.isin(altloc, ["", " "]), " ", altloc)
            ai_np = pair_atom_i[both_heavy].cpu().numpy()
            aj_np = pair_atom_j[both_heavy].cpu().numpy()
            # Incompatible only when both atoms carry a label and the
            # labels differ.
            incompat = (
                (altloc[ai_np] != " ")
                & (altloc[aj_np] != " ")
                & (altloc[ai_np] != altloc[aj_np])
            )
            reject = torch.zeros(len(pair_atom_i), dtype=torch.bool, device=device)
            reject[both_heavy] = torch.tensor(incompat, dtype=torch.bool, device=device)
            keep = ~reject
            pair_atom_i = pair_atom_i[keep]
            pair_atom_j = pair_atom_j[keep]
            pair_combo_j = pair_combo_j[keep]
            if len(pair_atom_i) == 0:
                return empty, empty, empty

    # Deduplicate on the (atom_i, atom_j, combo_j) triple.  The hash is
    # built in int64 so that it stays collision-free even if the pair
    # tensors arrive in a narrower integer dtype.
    dedup_hash = (
        pair_atom_i.long() * (n_all * M)
        + pair_atom_j.long() * M
        + pair_combo_j.long()
    )
    unique_hashes, inverse = torch.unique(dedup_hash, return_inverse=True)

    # Keep the first occurrence of each unique hash: scatter the minimum
    # position into one slot per unique value.
    # MPS does not support int64 scatter_reduce; use the configured int dtype.
    _int_dtype = dtypes.int
    inverse_i = inverse.to(_int_dtype)
    perm = torch.arange(len(inverse), device=device, dtype=_int_dtype)
    first_occ = torch.full(
        (unique_hashes.shape[0],), len(inverse), device=device, dtype=_int_dtype
    )
    first_occ.scatter_reduce_(0, inverse_i, perm, reduce="amin")
    first_mask = torch.zeros(len(pair_atom_i), dtype=torch.bool, device=device)
    first_mask[first_occ.long()] = True
    return pair_atom_i[first_mask], pair_atom_j[first_mask], pair_combo_j[first_mask]