torchref.restraints.neighbor_search module
GPU-native periodic neighbor search for VDW restraints.
Works in fractional space with periodic boundary conditions.
Avoids explicit symmetry expansion by assigning (atom, symop+offset)
entries to grid cells and using padded batched torch.cdist
for distance computation.
All operations run under torch.no_grad() on whatever device
the input coordinates live on (CPU or GPU).
- torchref.restraints.neighbor_search.prefilter_symop_offsets(cell, sg, xyz_frac, cutoff)
Select (symop, cell_offset) combos that could produce contacts.
Uses the ASU centroid and molecule radius to eliminate obviously distant combinations. Always includes identity (op=0, offset=0).
- Parameters:
cell (Cell) – unit cell
sg (SpaceGroup) – space group
xyz_frac ((N, 3) fractional ASU coordinates)
cutoff (float) – Cartesian cutoff in Å
- Returns:
op_indices ((M,) long – symop indices for each valid combo)
cell_offsets ((M, 3) long – integer cell translations)
- Return type:
tuple of torch.Tensor
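A minimal sketch of the centroid/radius test, for illustration only. It assumes the symmetry operators are available as fractional rotation matrices `rot` (S, 3, 3) and translations `trans` (S, 3) with the identity at index 0, a fractional-to-Cartesian matrix `frac_to_cart`, a search window of one cell in each direction, and a contact criterion of `distance <= 2 * radius + cutoff`. These names and the exact criterion are assumptions, not the module's actual internals.

```python
import itertools

import torch


def prefilter_combos_sketch(rot, trans, frac_to_cart, xyz_frac, cutoff, search=1):
    """Hypothetical sketch: keep (symop, offset) combos whose image could touch the ASU."""
    # Integer offsets in [-search, search]^3; the stable sort puts (0, 0, 0) first.
    rng = range(-search, search + 1)
    offsets = sorted(itertools.product(rng, rng, rng), key=lambda o: o != (0, 0, 0))
    offsets = torch.tensor(offsets, dtype=torch.long)                 # (O, 3)

    centroid = xyz_frac.mean(dim=0)                                   # (3,) fractional
    cart = xyz_frac @ frac_to_cart.T                                  # (N, 3) Cartesian
    radius = (cart - cart.mean(dim=0)).norm(dim=1).max()              # molecule radius (Å)

    # An affine symop maps the centroid to R c + t; then add each cell offset.
    img = torch.einsum("sij,j->si", rot, centroid) + trans            # (S, 3)
    cand = img[:, None, :] + offsets[None, :, :].to(img.dtype)        # (S, O, 3)
    dist = ((cand - centroid) @ frac_to_cart.T).norm(dim=-1)          # (S, O) in Å

    keep = dist <= 2 * radius + cutoff
    keep[0, 0] = True  # identity op (index 0) with zero offset is always kept
    op_idx, off_idx = keep.nonzero(as_tuple=True)
    return op_idx, offsets[off_idx]
```

In a P1 cell much larger than the molecule only the identity survives; in a 5 Å cubic cell with a 5 Å cutoff the six face-adjacent images are kept as well.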
- torchref.restraints.neighbor_search.assign_to_grid(xyz_frac, cell, sg, op_indices, cell_offsets, grid_dims)
Compute Cartesian image positions and assign to grid cells.
- Parameters:
xyz_frac ((N, 3))
cell (Cell)
sg (SpaceGroup)
op_indices ((M,) long)
cell_offsets ((M, 3) long)
grid_dims ((3,) long – number of grid cells per axis)
- Returns:
flat_cell ((N*M,) long – flat grid cell index per entry)
atom_idx ((N*M,) long – ASU atom index)
combo_idx ((N*M,) long – index into op_indices / cell_offsets)
cart_pos ((N*M, 3) float – Cartesian positions, reused for the distance computation in the pair search)
- Return type:
tuple of torch.Tensor
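The imaging and binning step can be sketched as follows. This is a hypothetical illustration: `rot`, `trans` and `frac_to_cart` stand in for whatever `Cell` and `SpaceGroup` provide, and the grid is assumed to partition one unit cell in fractional space, with each image binned by its wrapped fractional coordinate and cells flattened in row-major order. Entries are laid out combo-major, so entry `e = combo * N + atom`.

```python
import torch


def assign_to_grid_sketch(xyz_frac, rot, trans, frac_to_cart, op_indices, cell_offsets, grid_dims):
    """Hypothetical sketch: image every ASU atom under each combo and bin it on a fractional grid."""
    n, m = xyz_frac.shape[0], op_indices.shape[0]
    r = rot[op_indices]                                                    # (M, 3, 3)
    t = trans[op_indices] + cell_offsets.to(xyz_frac.dtype)                # (M, 3)
    img_frac = torch.einsum("mij,nj->mni", r, xyz_frac) + t[:, None, :]    # (M, N, 3)
    cart_pos = img_frac @ frac_to_cart.T                                   # (M, N, 3)

    # Bin on the wrapped fractional coordinate; the clamp guards remainder() rounding to 1.0.
    ijk = torch.floor(torch.remainder(img_frac, 1.0) * grid_dims.to(img_frac.dtype)).long()
    ijk = torch.minimum(ijk, grid_dims - 1)
    flat = (ijk[..., 0] * grid_dims[1] + ijk[..., 1]) * grid_dims[2] + ijk[..., 2]

    # Combo-major layout: entry e = combo * N + atom.
    combo_idx = torch.arange(m).repeat_interleave(n)
    atom_idx = torch.arange(n).repeat(m)
    return flat.reshape(-1), atom_idx, combo_idx, cart_pos.reshape(-1, 3)
```

A lattice translation lands an image in the same wrapped grid cell as its source atom, while its Cartesian position differs by one cell edge.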
- torchref.restraints.neighbor_search.build_cell_list(flat_cell, n_grid_total)
Sort entries by grid cell and build CSR boundary arrays.
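One way to build the CSR arrays is a stable sort, a run-length count and a prefix sum. The sketch below is an illustration, not the module's code. Its return values (`order`, `unique_cells`, `starts`, `cell_lookup`) are assumptions chosen to match the inputs that `find_pairs_periodic_grid()` takes; the real function's outputs are not documented here.

```python
import torch


def build_cell_list_sketch(flat_cell, n_grid_total):
    """Hypothetical sketch: CSR layout of grid entries (sort, count, prefix-sum)."""
    # Stable sort keeps entries of the same cell in their original order.
    sorted_cells, order = torch.sort(flat_cell, stable=True)
    unique_cells, counts = torch.unique_consecutive(sorted_cells, return_counts=True)
    # starts[c] : starts[c + 1] is the slice of sorted entries belonging to occupied cell c.
    starts = torch.zeros(unique_cells.numel() + 1, dtype=torch.long)
    starts[1:] = torch.cumsum(counts, dim=0)
    # Dense lookup from flat grid index to occupied-cell index (-1 = empty cell).
    cell_lookup = torch.full((n_grid_total,), -1, dtype=torch.long)
    cell_lookup[unique_cells] = torch.arange(unique_cells.numel())
    return order, unique_cells, starts, cell_lookup
```

`order` would then be used to permute the per-entry arrays into `cart_sorted`, `atom_idx_sorted` and `combo_idx_sorted`.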
- torchref.restraints.neighbor_search.find_pairs_periodic_grid(cart_sorted, atom_idx_sorted, combo_idx_sorted, unique_cells, starts, cell_lookup, grid_dims, cutoff, identity_combo, chunk_size=128)
Find all ASU-vs-everything pairs within cutoff using periodic grid.
Fully vectorized: 27 neighbor offsets × chunked batched cdist. No Python loops over individual pairs.
- Parameters:
cart_sorted ((E, 3) sorted Cartesian positions)
atom_idx_sorted ((E,) ASU atom index per entry)
combo_idx_sorted ((E,) combo index per entry)
unique_cells ((C,) occupied cell flat indices)
starts ((C+1,) CSR boundaries)
cell_lookup ((G,) flat cell → occupied index or -1)
grid_dims ((3,) grid dimensions)
cutoff (float)
identity_combo (int – combo index for identity (op=0, offset=0))
chunk_size (int – cell pairs per cdist batch)
- Returns:
pair_atom_i ((P,) ASU atom index (always from identity/ASU))
pair_atom_j ((P,) ASU atom index (mate source))
pair_combo_j ((P,) combo index for atom j)
- Return type:
tuple of torch.Tensor
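The "27 neighbor offsets" part can be vectorized without Python loops over cells. The sketch below enumerates every (occupied cell, occupied neighbour cell) pair at once. It assumes row-major flat indexing and periodic wrapping of neighbour cells modulo `grid_dims`; both are illustrative assumptions. The resulting `src`/`tgt` arrays are what a chunked batched `cdist` would then consume.

```python
import itertools

import torch

# All 27 neighbour offsets (including the cell itself), in lexicographic order.
OFFSETS_27 = torch.tensor(list(itertools.product((-1, 0, 1), repeat=3)), dtype=torch.long)


def neighbor_cell_pairs_sketch(unique_cells, cell_lookup, grid_dims, offsets=OFFSETS_27):
    """Hypothetical sketch: every (occupied cell, occupied neighbour cell) pair, vectorized."""
    ny, nz = int(grid_dims[1]), int(grid_dims[2])
    # Unflatten row-major indices: flat = (ix * ny + iy) * nz + iz.
    ijk = torch.stack((unique_cells // (ny * nz), (unique_cells // nz) % ny, unique_cells % nz), dim=1)
    # Periodic wrap of each neighbour cell back into the grid.
    nb = torch.remainder(ijk[:, None, :] + offsets[None, :, :], grid_dims)  # (C, K, 3)
    nb_flat = (nb[..., 0] * ny + nb[..., 1]) * nz + nb[..., 2]               # (C, K)
    tgt = cell_lookup[nb_flat]                                               # -1 = empty cell
    src = torch.arange(unique_cells.numel())[:, None].expand_as(tgt)
    keep = tgt >= 0
    return src[keep], tgt[keep]
```

On a grid with fewer than three cells along an axis, different offsets wrap to the same cell, which is one reason a final deduplication step is needed.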
- torchref.restraints.neighbor_search.find_pairs_periodic_grid_v2(cart_sorted, atom_idx_sorted, combo_idx_sorted, unique_cells, starts, cell_lookup, grid_dims, cutoff, identity_combo, chunk_size=None)
Optimised replacement for find_pairs_periodic_grid(). Three things changed vs the original:
1. 14 canonical offsets (self + 13 lex-positive). For cell pairs where both cells contain ASU atoms, the original algorithm finds each pair twice (once via +d and once via -d) and dedups at the end, which is pure waste. Processing only half the offsets eliminates the redundant cdist work. We still cover every ASU-containing pair by allowing the "target is ASU" case via an index swap on emission.
2. Precomputed padded tensor. padded_xyz, valid_mask and asu_mask are built once (not per chunk and per offset as in the original). Per-cell metadata is also precomputed: cell_has_asu, and the flat index offsets for mapping local (row, col) hits back to global CSR indices.
3. Use torch.cdist (matmul-based internally) instead of a broadcast subtract so the intermediate tensor stays cache-friendly. Chunking is kept but with a larger default (512) so we pay less Python overhead without blowing the cache.
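The half-shell offsets and the precomputed padded tensor described above can be sketched as follows. This is an illustration under assumptions, not the module's code: the helper names are hypothetical, padding slots are zero-filled and masked out, and `gidx` plays the role of the "flat index offsets" that map local `(row, col)` hits back to global CSR indices.

```python
import itertools

import torch

# Self offset plus the 13 lexicographically positive offsets: 14 of the 27.
HALF_SHELL = [o for o in itertools.product((-1, 0, 1), repeat=3) if o >= (0, 0, 0)]


def pad_cells_sketch(cart_sorted, starts):
    """Hypothetical sketch: build padded (C, K, 3) positions and a (C, K) validity mask once."""
    counts = starts[1:] - starts[:-1]                                   # entries per occupied cell
    k_max = int(counts.max())
    slot = torch.arange(k_max)
    valid_mask = slot[None, :] < counts[:, None]                        # (C, K)
    # Global CSR index of every slot (clamped so padding slots still index a real row).
    gidx = (starts[:-1, None] + slot[None, :]).clamp(max=cart_sorted.shape[0] - 1)
    padded_xyz = cart_sorted[gidx] * valid_mask[..., None].to(cart_sorted.dtype)
    return padded_xyz, valid_mask, gidx


def cell_pair_hits_sketch(padded_xyz, valid_mask, gidx, src, tgt, cutoff):
    """Batched cdist over one chunk of (src, tgt) cell pairs; returns global entry indices."""
    dist = torch.cdist(padded_xyz[src], padded_xyz[tgt])               # (B, K, K)
    hit = (dist < cutoff) & valid_mask[src][:, :, None] & valid_mask[tgt][:, None, :]
    b, row, col = hit.nonzero(as_tuple=True)
    return gidx[src[b], row], gidx[tgt[b], col]
```

The mask matters: a zero-filled padding slot sits at the origin and would otherwise register a spurious hit against any real atom within `cutoff` of it.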
Correctness invariant: within a hit (src_atom, tgt_atom) we emit a pair iff is_asu[src] | is_asu[tgt], and we canonicalise so the ASU atom is always on the i side. For intra-ASU pairs we further enforce atom_i < atom_j after the swap so both-ASU pairs emit in a single canonical order (matching the current v1 invariant used by the downstream dedup).
Parameters and return signature match find_pairs_periodic_grid().
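The correctness invariant above can be illustrated with a small vectorized canonicalisation step. In this hypothetical sketch, "is ASU" means `combo == identity_combo`. The non-ASU partner keeps its own combo index on the j side. Intra-ASU pairs are put in `atom_i < atom_j` order with min/max, and self-hits (`atom_i == atom_j` within the identity) are dropped; the handling of self-hits is an assumption here.

```python
import torch


def canonicalize_hits_sketch(atom_src, combo_src, atom_tgt, combo_tgt, identity_combo):
    """Hypothetical sketch of the emission rule: ASU atom on the i side, one order per pair."""
    asu_src = combo_src == identity_combo
    asu_tgt = combo_tgt == identity_combo
    # Swap when only the target is ASU; the non-ASU entry keeps its own combo on the j side.
    swap = ~asu_src & asu_tgt
    atom_i = torch.where(swap, atom_tgt, atom_src)
    atom_j = torch.where(swap, atom_src, atom_tgt)
    combo_j = torch.where(swap, combo_src, combo_tgt)
    # Intra-ASU pairs: canonical order atom_i < atom_j.
    both = asu_src & asu_tgt
    lo, hi = torch.minimum(atom_i, atom_j), torch.maximum(atom_i, atom_j)
    atom_i = torch.where(both, lo, atom_i)
    atom_j = torch.where(both, hi, atom_j)
    # Emit iff at least one side is ASU, excluding an ASU atom paired with itself.
    keep = (asu_src | asu_tgt) & ~(both & (atom_i == atom_j))
    return atom_i[keep], atom_j[keep], combo_j[keep]
```

An ASU atom paired with its own symmetry image (same atom index, different combo) is still a valid contact and is kept.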
- torchref.restraints.neighbor_search.exclusion_set_to_hash(exclusion_set, max_idx, device)
Convert Python exclusion set to sorted hash tensor.
Hash: min(i,j) * max_idx + max(i,j), sorted for searchsorted.
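A minimal sketch of the hash and the `searchsorted` membership test, assuming `exclusion_set` holds `(i, j)` integer tuples. The hash is collision-free only if every index is smaller than `max_idx`. Both helper names are hypothetical.

```python
import torch


def exclusion_set_to_hash_sketch(exclusion_set, max_idx, device="cpu"):
    """Hypothetical sketch: encode unordered pairs as min*max_idx + max, sorted and unique."""
    if not exclusion_set:
        return torch.empty(0, dtype=torch.long, device=device)
    pairs = torch.tensor(sorted(exclusion_set), dtype=torch.long, device=device)  # (E, 2)
    lo, hi = pairs.min(dim=1).values, pairs.max(dim=1).values
    return torch.unique(lo * max_idx + hi)  # torch.unique returns sorted values by default


def is_excluded_sketch(pair_i, pair_j, excl_hash, max_idx):
    """Membership test via binary search on the sorted hash tensor."""
    h = torch.minimum(pair_i, pair_j) * max_idx + torch.maximum(pair_i, pair_j)
    if excl_hash.numel() == 0:
        return torch.zeros_like(h, dtype=torch.bool)
    pos = torch.searchsorted(excl_hash, h).clamp(max=excl_hash.numel() - 1)
    return excl_hash[pos] == h
```

Because the hash uses min/max, `(3, 1)` and `(1, 3)` map to the same key, so the pair order from the search does not matter.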
- torchref.restraints.neighbor_search.filter_pairs(pair_atom_i, pair_atom_j, pair_combo_j, identity_combo, excl_hash, max_idx, pdb, inter_residue_only=True)
Apply exclusion, residue, and altloc filters. Returns keep mask.
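How the three filters might combine into one boolean mask is sketched below. Everything here is an assumption for illustration. Per-atom `chain_enc`, `resseq` and `altloc_enc` tensors stand in for the `pdb` DataFrame columns, with altloc 0 meaning "no altloc". The exclusion result is passed in precomputed. The same-residue and differing-altloc rules are applied only within the identity image, since a symmetry mate is a different copy of the molecule.

```python
import torch


def filter_pairs_sketch(pair_i, pair_j, pair_combo_j, identity_combo, excl_mask,
                        chain_enc, resseq, altloc_enc, inter_residue_only=True):
    """Hypothetical sketch: combine exclusion, residue and altloc rules into a keep mask."""
    keep = ~excl_mask                                  # drop bonded (1-2, 1-3, ...) pairs
    same_image = pair_combo_j == identity_combo        # j is in the ASU copy itself
    if inter_residue_only:
        same_res = (chain_enc[pair_i] == chain_enc[pair_j]) & (resseq[pair_i] == resseq[pair_j])
        keep &= ~(same_image & same_res)
    # Two different alternate conformers never coexist, so skip contacts between them.
    a_i, a_j = altloc_enc[pair_i], altloc_enc[pair_j]
    keep &= ~(same_image & (a_i != 0) & (a_j != 0) & (a_i != a_j))
    return keep
```

Returning a mask rather than filtered tensors lets the caller apply the same selection to `pair_atom_i`, `pair_atom_j` and `pair_combo_j` in one step.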
- torchref.restraints.neighbor_search.build_vdw_restraints_gpu(xyz_fn, vdw_radii_fn, cell, sg, pdb, exclusion_set, cutoff=5.0, sigma=0.2, inter_residue_only=True, verbose=0)
Build VDW restraints using GPU-native periodic grid search.
- Parameters:
- Returns:
dict with keys indices, min_distances, sigmas, symop_indices, cell_offsets
- Return type:
dict
- torchref.restraints.neighbor_search.find_h_vdw_pairs_gpu(xyz_heavy, xyz_h, cell, sg, op_indices, cell_offsets_valid, grid_dims, identity_combo, n_heavy, cutoff=3.5, h_excl_hash=None, pdb=None, h_chainid_enc=None, h_resseq=None, inter_residue_only=True)
Find VDW pairs involving at least one hydrogen atom.
Uses the same periodic-grid spatial hash as the heavy-atom search but operates on a combined (heavy + H) coordinate set and only returns pairs where at least one participant is a hydrogen (index >= n_heavy).
Called at every forward evaluation — designed for low latency.
- Parameters:
xyz_heavy ((N_heavy, 3) Cartesian ASU heavy-atom positions)
xyz_h ((N_h, 3) Cartesian ASU hydrogen positions)
cell (Cell) – unit cell
sg (SpaceGroup) – space group
op_indices ((M,) cached valid symop indices)
cell_offsets_valid ((M, 3) cached valid cell translations)
grid_dims ((3,) cached grid dimensions)
identity_combo (int)
n_heavy (int) – Number of heavy atoms (indices 0..n_heavy-1 are heavy)
cutoff (float) – Cartesian cutoff (Å), tighter than heavy-atom search
h_excl_hash ((E,) sorted long) – Hash tensor for H-specific 1-2/1-3 exclusions
pdb (DataFrame) – Heavy-atom pdb for same-residue filtering
h_chainid_enc ((N_h,) long) – Chain ID encoding for H atoms
h_resseq ((N_h,) long) – Residue sequence number for H atoms
inter_residue_only (bool)
- Returns:
pair_atom_i, pair_atom_j, pair_combo_j – Indices into the combined (heavy + H) array. atom_i is always from the ASU (identity combo).
- Return type:
each (P,) long
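The combined indexing convention (heavy atoms first, hydrogens at indices `>= n_heavy`) and the "at least one hydrogen" rule can be sketched as below. The helper is hypothetical and assumes the candidate pairs were already produced by a grid search over the concatenated coordinates.

```python
import torch


def h_pair_mask_sketch(xyz_heavy, xyz_h, pair_i, pair_j, pair_combo_j):
    """Hypothetical sketch: concatenate heavy + H coordinates and keep pairs with a hydrogen."""
    xyz_all = torch.cat((xyz_heavy, xyz_h), dim=0)      # indices 0..n_heavy-1 are heavy atoms
    n_heavy = xyz_heavy.shape[0]
    keep = (pair_i >= n_heavy) | (pair_j >= n_heavy)    # heavy-heavy pairs are handled elsewhere
    return xyz_all, pair_i[keep], pair_j[keep], pair_combo_j[keep]
```

Heavy-heavy pairs are dropped because they are already covered by the heavy-atom restraints, and a hydrogen index `k` maps back to row `k - n_heavy` of `xyz_h`.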