torchref.restraints.hydrogen_topology module

Riding hydrogen topology and vectorized placement for VDW restraints.

Builds a static topology map at restraint-construction time that describes how to generate transient hydrogen positions from heavy-atom coordinates. At each VDW evaluation, place_riding_hydrogens produces all H positions in a single vectorized pass (no Python loops over atoms).

Hydrogen positions are fully determined by the parent heavy atom and its bonded heavy-atom neighbours, so gradients flow from the VDW loss through the H positions back to the heavy-atom coordinates via standard autograd.
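You can check numerically that gradients reach the heavy atoms with torch.autograd.gradcheck. The sketch below uses a deliberately simplified, hypothetical placement rule. It places the H on the extension of the neighbour-to-parent bond at a fixed bond length. It is not one of the module's actual placement geometries. The helper place_single_neighbor_h and the toy coordinates are illustrative only.

```python
import torch


def place_single_neighbor_h(xyz_heavy: torch.Tensor,
                            parent_idx: torch.Tensor,
                            neighbor_idx: torch.Tensor,
                            bond_length: torch.Tensor) -> torch.Tensor:
    """Hypothetical rule: put each H on the neighbour->parent axis, beyond the parent.

    Every H position is a differentiable function of heavy-atom coordinates only.
    """
    parent = xyz_heavy[parent_idx]        # (N_h, 3)
    neighbor = xyz_heavy[neighbor_idx]    # (N_h, 3)
    direction = parent - neighbor
    unit = direction / direction.norm(dim=-1, keepdim=True)
    return parent + bond_length.unsqueeze(-1) * unit


# Tiny three-atom fragment in double precision (gradcheck expects float64).
xyz = torch.tensor([[0.0, 0.0, 0.0],
                    [1.5, 0.0, 0.0],
                    [2.0, 1.4, 0.0]], dtype=torch.float64, requires_grad=True)
parent_idx = torch.tensor([1, 2])
neighbor_idx = torch.tensor([0, 1])
bond_length = torch.tensor([1.09, 1.09], dtype=torch.float64)

xyz_h = place_single_neighbor_h(xyz, parent_idx, neighbor_idx, bond_length)

# A toy loss on the H positions back-propagates into the heavy-atom coordinates.
loss = (xyz_h ** 2).sum()
loss.backward()
print(xyz.grad)

# Compare the analytic gradient against finite differences.
ok = torch.autograd.gradcheck(
    lambda x: place_single_neighbor_h(x, parent_idx, neighbor_idx, bond_length),
    (xyz.detach().clone().requires_grad_(True),),
)
print(ok)
```

The H atoms are never leaf tensors, so no separate H gradient bookkeeping is needed.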

class torchref.restraints.hydrogen_topology.HydrogenTopology

Bases: DeviceMixin, Module

Static topology describing riding hydrogens for VDW evaluation.

All data are stored as registered buffers so they move automatically with .to(device) and appear in state_dict.
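A minimal stand-in module illustrates this buffer-based design. ToyTopology is hypothetical and mirrors only two of the attribute names listed below. Tensors registered with register_buffer appear in state_dict and follow .to(...). parameters() does not return them, so an optimizer never updates them.

```python
import torch
from torch import nn


class ToyTopology(nn.Module):
    """Hypothetical stand-in for HydrogenTopology: buffers only, no parameters."""

    def __init__(self, parent_idx: torch.Tensor, bond_length: torch.Tensor):
        super().__init__()
        # Buffers are saved in state_dict and moved by .to(), but never trained.
        self.register_buffer("h_parent_idx", parent_idx.long())
        self.register_buffer("h_bond_length", bond_length.float())


topo = ToyTopology(torch.tensor([0, 0, 3]), torch.tensor([1.09, 1.09, 1.01]))
print(sorted(topo.state_dict().keys()))   # ['h_bond_length', 'h_parent_idx']
print(len(list(topo.parameters())))       # 0

# Moving the module moves every buffer with it (falls back to CPU without CUDA).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
topo = topo.to(device)
print(topo.h_parent_idx.device.type)
```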

h_parent_idx

Index into heavy-atom array for each riding H.

Type:

(N_h,) long

h_bond_length

Ideal H–parent bond length (Å).

Type:

(N_h,) float

h_vdw_radius

Van der Waals radius for each H (1.20 Å).

Type:

(N_h,) float

h_placement_type

Placement-geometry enum (see module-level constants).

Type:

(N_h,) long

h_slot_in_parent

Ordinal within sibling H atoms on the same parent (0, 1, 2).

Type:

(N_h,) long

parent_neighbor_idx

Heavy-atom neighbour indices of the parent (-1 = padding).

Type:

(N_h, MAX_HEAVY_NB) long

parent_neighbor_count

Actual number of heavy-atom neighbours for the parent.

Type:

(N_h,) long

h_chainid_enc

Encoded chain ID (for same-residue filtering).

Type:

(N_h,) long

h_resseq

Residue sequence number (for same-residue filtering).

Type:

(N_h,) long

__init__()
property n_hydrogens: int
property has_candidates: bool

Whether precomputed H candidate pairs are available.

torchref.restraints.hydrogen_topology.build_hydrogen_topology(pdb, device=device(type='cpu'), verbose=0)

Build the riding-hydrogen topology from the model’s heavy-atom DataFrame.

Parameters:
  • pdb (pd.DataFrame) – Heavy-atom DataFrame with hydrogens removed (strip_H=True).

  • device (torch.device) – Target device for tensors.

  • verbose (int) – Verbosity level.

Returns:

Topology module whose tensors are stored as registered buffers.

Return type:

HydrogenTopology
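A builder like this needs to assign h_slot_in_parent, the ordinal (0, 1, 2) of each H among the siblings on its parent. One vectorized way to do this in plain torch is a stable sort followed by a position-within-run count. The sketch below shows that idea. The helper name slot_in_parent is hypothetical and is not part of the module. It also assumes siblings are numbered in their input order.

```python
import torch


def slot_in_parent(parent_idx: torch.Tensor) -> torch.Tensor:
    """Ordinal (0, 1, 2, ...) of each H among the H atoms sharing its parent.

    Hypothetical helper: siblings are numbered in their original input order.
    """
    n = parent_idx.numel()
    if n == 0:
        return torch.zeros(0, dtype=torch.long, device=parent_idx.device)
    # Stable sort keeps siblings in input order within each parent group.
    sorted_parent, order = torch.sort(parent_idx, stable=True)
    pos = torch.arange(n, device=parent_idx.device)
    # Mark the first H of each parent group in sorted order.
    is_start = torch.ones(n, dtype=torch.bool, device=parent_idx.device)
    is_start[1:] = sorted_parent[1:] != sorted_parent[:-1]
    # Running maximum of group-start positions gives each H its group start.
    group_start = torch.cummax(torch.where(is_start, pos, torch.zeros_like(pos)), dim=0).values
    slot_sorted = pos - group_start
    # Scatter back to the original H ordering.
    slot = torch.empty_like(slot_sorted)
    slot[order] = slot_sorted
    return slot


parent = torch.tensor([5, 2, 5, 5, 2, 7])
print(slot_in_parent(parent).tolist())   # [0, 0, 1, 2, 1, 0]
```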

torchref.restraints.hydrogen_topology.place_riding_hydrogens(xyz_heavy, topo)

Generate riding hydrogen positions from heavy-atom coordinates.

Delegates to a JIT-compiled kernel that fuses element-wise ops, reducing GPU kernel launches from ~230 to ~30.

Parameters:
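  • xyz_heavy ((N_heavy, 3) float tensor) – Heavy-atom coordinates; typically requires_grad=True so that gradients reach them.

  • topo (HydrogenTopology) – Riding-hydrogen topology, e.g. from build_hydrogen_topology.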

Returns:

xyz_h – riding hydrogen positions.

Return type:

(N_h, 3) float tensor, differentiable w.r.t. xyz_heavy
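The sketch below illustrates the kind of vectorized computation involved, using one common riding-H construction. The H points away from the normalized sum of the unit bond vectors from the parent to its heavy-atom neighbours, as for a tertiary C–H or an sp2 N–H. It uses the padding convention of parent_neighbor_idx (-1) and h_bond_length as documented above. It covers only this single geometry. The module's actual placement types and its JIT kernel are not reproduced, and the function name place_opposite_neighbors is hypothetical.

```python
import torch


def place_opposite_neighbors(xyz_heavy: torch.Tensor,
                             h_parent_idx: torch.Tensor,
                             parent_neighbor_idx: torch.Tensor,
                             h_bond_length: torch.Tensor,
                             eps: float = 1e-8) -> torch.Tensor:
    """Place each H opposite the sum of unit bond vectors from its parent.

    parent_neighbor_idx: (N_h, K) long, -1 marks padding (as in HydrogenTopology).
    """
    parent = xyz_heavy[h_parent_idx]                    # (N_h, 3)
    valid = parent_neighbor_idx >= 0                    # (N_h, K)
    nb = xyz_heavy[parent_neighbor_idx.clamp(min=0)]    # (N_h, K, 3); padding made indexable
    bond = nb - parent.unsqueeze(1)                     # parent -> neighbour vectors
    # Replace padded slots with a dummy non-zero vector so the norm and its
    # gradient stay finite; these slots are masked out just below.
    bond = torch.where(valid.unsqueeze(-1), bond, torch.ones_like(bond))
    unit = bond / bond.norm(dim=-1, keepdim=True).clamp_min(eps)
    unit = unit * valid.unsqueeze(-1).to(unit.dtype)
    direction = -unit.sum(dim=1)                        # point away from the neighbours
    direction = direction / direction.norm(dim=-1, keepdim=True).clamp_min(eps)
    return parent + h_bond_length.unsqueeze(-1) * direction


xyz = torch.tensor([[0.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0],
                    [0.0, 1.0, 0.0]], requires_grad=True)
h_parent_idx = torch.tensor([0, 1])
parent_neighbor_idx = torch.tensor([[1, 2], [0, -1]])   # -1 = padding
h_bond_length = torch.tensor([1.01, 1.09])

xyz_h = place_opposite_neighbors(xyz, h_parent_idx, parent_neighbor_idx, h_bond_length)
xyz_h.sum().backward()   # gradients reach the heavy-atom coordinates
print(xyz_h)
```

Clamping the -1 padding to a real index and masking afterwards keeps every gather rectangular. This is what allows a single pass over all H atoms.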

torchref.restraints.hydrogen_topology.build_h_candidate_pairs(h_topo, vdw_data, pdb, h_excl_hash, device=device(type='cpu'), verbose=0)

Precompute candidate H-involving VDW pairs from the heavy-atom pair list.

For each heavy-heavy VDW pair (A, B, symop, offset), derives candidate H-heavy pairs where H rides on A and could interact with B (or vice versa). Applies exclusion and same-residue filters at build time so that the forward pass only needs to compute distances and energy.

Results are stored as registered buffers on h_topo:

  • cand_idx_i (C,) long — first atom (combined index)

  • cand_idx_j (C,) long — second atom (combined index)

  • cand_symop_idx (C,) long — symop index for the heavy atom

  • cand_cell_offset (C, 3) long — cell translation for the heavy atom

  • cand_min_dist (C,) float — VDW radius sum (H + heavy)
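The core expansion step can be vectorized by matching h_parent_idx against the pair list: every H riding on A inherits A's heavy-atom partner B. The sketch below rests on several labelled assumptions:

  • It covers only the H-on-A direction.
  • H number h is assumed to have combined index n_heavy + h.
  • The pair's symop and cell offset are assumed to carry over unchanged to B.
  • It uses a brute-force (P, N_h) comparison, which is fine for illustration but quadratic in memory.

The exclusion and same-residue filters are omitted. build_h_candidate_pairs itself may index differently. The helper name expand_h_on_a and the radii below are example values only.

```python
import torch


def expand_h_on_a(pair_a, pair_b, pair_symop, pair_offset,
                  h_parent_idx, h_vdw_radius, heavy_vdw_radius, n_heavy):
    """For each heavy pair (A, B), emit (H, B) for every H whose parent is A.

    Assumptions (hypothetical, for illustration): combined index of H number h
    is n_heavy + h; the symop/cell offset of the pair applies to B unchanged.
    """
    # match[p, h] is True when H h rides on atom A of pair p.  O(P * N_h) memory.
    match = pair_a.unsqueeze(1) == h_parent_idx.unsqueeze(0)
    p_idx, h_idx = match.nonzero(as_tuple=True)
    heavy_j = pair_b[p_idx]
    return {
        "cand_idx_i": n_heavy + h_idx,
        "cand_idx_j": heavy_j,
        "cand_symop_idx": pair_symop[p_idx],
        "cand_cell_offset": pair_offset[p_idx],
        "cand_min_dist": h_vdw_radius[h_idx] + heavy_vdw_radius[heavy_j],
    }


n_heavy = 4
pair_a = torch.tensor([0, 2])
pair_b = torch.tensor([3, 1])
pair_symop = torch.tensor([0, 1])
pair_offset = torch.tensor([[0, 0, 0], [1, 0, 0]])
h_parent_idx = torch.tensor([0, 0, 2])          # two H on atom 0, one on atom 2
h_vdw_radius = torch.full((3,), 1.20)
heavy_vdw_radius = torch.tensor([1.70, 1.55, 1.70, 1.52])   # example radii

cand = expand_h_on_a(pair_a, pair_b, pair_symop, pair_offset,
                     h_parent_idx, h_vdw_radius, heavy_vdw_radius, n_heavy)
print(cand["cand_idx_i"].tolist(), cand["cand_idx_j"].tolist())   # [4, 5, 6] [3, 3, 1]
```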

Parameters:
  • h_topo (HydrogenTopology) – Topology on which the candidate buffers are registered.

  • vdw_data (dict) – Output of build_vdw_restraints_gpu (keys: indices, symop_indices, cell_offsets, etc.).

  • pdb (DataFrame) – Heavy-atom DataFrame.

  • h_excl_hash ((E,) long) – Sorted exclusion hash tensor for H-specific 1-2/1-3 pairs.

  • device (torch.device) – Target device for tensors.

  • verbose (int) – Verbosity level.
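These docs describe h_excl_hash only as a sorted exclusion hash tensor. A common way to use such a tensor is to encode each unordered pair as a single integer and test membership with torch.searchsorted. The sketch below assumes the encoding min(i, j) * n_total + max(i, j), which is not necessarily the one the module uses. The helpers pair_hash and is_excluded are hypothetical.

```python
import torch


def pair_hash(i: torch.Tensor, j: torch.Tensor, n_total: int) -> torch.Tensor:
    """Order-independent integer key for the unordered pair (i, j)."""
    lo = torch.minimum(i, j)
    hi = torch.maximum(i, j)
    return lo * n_total + hi


def is_excluded(i, j, excl_hash, n_total):
    """Membership test of (i, j) pairs against a sorted 1-D hash tensor."""
    keys = pair_hash(i, j, n_total)
    if excl_hash.numel() == 0:
        return torch.zeros_like(keys, dtype=torch.bool)
    pos = torch.searchsorted(excl_hash, keys)
    # searchsorted returns len(excl_hash) for keys past the end; clamp to stay in range.
    pos = pos.clamp(max=excl_hash.numel() - 1)
    return excl_hash[pos] == keys


n_total = 7
# Example exclusions: H4-A0 and H5-A0 (1-2), H4-H5 on the same parent (1-3).
excl_i = torch.tensor([4, 5, 4])
excl_j = torch.tensor([0, 0, 5])
h_excl_hash = torch.sort(pair_hash(excl_i, excl_j, n_total)).values

mask = is_excluded(torch.tensor([4, 5, 6, 5]), torch.tensor([0, 4, 1, 0]), h_excl_hash, n_total)
print(mask.tolist())   # [True, True, False, True]
```

Sorting once at build time makes each lookup a binary search. The forward pass then needs no Python-side sets or dictionaries.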