torchref.base.targets.triton.place_hydrogens module

Triton forward + analytic backward for riding-hydrogen placement.

The eager helper (_place_h_jit in torchref.restraints.hydrogen_topology) is a @torch.jit.script function that fuses the forward pass into ~30 kernel launches. Its backward, however, runs through PyTorch autograd op by op (cross products, normalisations, where-branch routing). For ~3k hydrogens that amounts to ~100 launches and dominates the non-bonded backward cost (1.9 ms of the 5.3 ms total forward + backward time on 1DAW / A100).

This kernel computes the forward pass and the analytic backward in one launch each. The math mirrors _place_h_jit exactly:

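$$
\begin{aligned}
\mathrm{pp} &= \mathrm{xyz}[\mathrm{parent\_idx}] \\
\mathrm{nb\_pos}_i &= \mathrm{xyz}[\mathrm{nb\_idx}_i] \\
v_i &= (\mathrm{nb\_pos}_i - \mathrm{pp}) \cdot \mathrm{valid}_i \\
s &= \textstyle\sum_i v_i \\
\mathrm{base} &= -\,s \,/\, (\lVert s \rVert + \varepsilon)
\end{aligned}
$$

$$
\begin{aligned}
\mathrm{cross12} &= v_1 \times v_2 \\
\mathrm{cross\_norm} &= \lVert \mathrm{cross12} \rVert \\
\mathrm{perp1\_cross} &= \mathrm{cross12} \,/\, (\mathrm{cross\_norm} + \varepsilon) \\
\mathrm{cardinal} &= \hat{e}_{\arg\min_k |\mathrm{base}_k|} \quad \text{(detached, no grad)} \\
\mathrm{perp1\_ortho} &= (\mathrm{base} \times \mathrm{cardinal}) \,/\, \lVert \mathrm{base} \times \mathrm{cardinal} \rVert \\
\mathrm{perp1} &= \begin{cases} \mathrm{perp1\_cross} & \text{if } \mathrm{cross\_norm} > 10^{-6} \\ \mathrm{perp1\_ortho} & \text{otherwise} \end{cases} \\
\mathrm{perp2} &= \mathrm{base} \times \mathrm{perp1}
\end{aligned}
$$

$$
\begin{aligned}
\mathrm{direction} &= c_0\,\mathrm{base} + c_1\,\mathrm{perp1} + c_2\,\mathrm{perp2} \\
H &= \mathrm{pp} + L \cdot \mathrm{direction}
\end{aligned}
$$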
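For reading or spot-testing the kernel, a scalar reference for a single hydrogen can help, since a Triton program likewise works through the x, y, z components explicitly. The sketch below is plain Python written from the equations above, not torchref code. The name `place_one_hydrogen`, the default ε = 1e-8, and reading `v1`, `v2` as the first two neighbour slots are assumptions. It also branches on `cross_norm` directly, whereas the eager and Triton code evaluate both branches and select with a where.

```python
import math
from typing import Sequence, Tuple

Vec = Tuple[float, float, float]


def _add(a: Vec, b: Vec) -> Vec:
    return (a[0] + b[0], a[1] + b[1], a[2] + b[2])


def _scale(a: Vec, k: float) -> Vec:
    return (a[0] * k, a[1] * k, a[2] * k)


def _cross(a: Vec, b: Vec) -> Vec:
    return (a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0])


def _norm(a: Vec) -> float:
    return math.sqrt(a[0] ** 2 + a[1] ** 2 + a[2] ** 2)


def place_one_hydrogen(pp: Vec, nb_pos: Sequence[Vec], valid: Sequence[float],
                       coeffs: Vec, bond_length: float, eps: float = 1e-8) -> Vec:
    """Hypothetical scalar reference; assumes at least one valid neighbour."""
    # v[i] = (nb_pos[i] - pp) * valid[i]: masked-out slots contribute zero
    v = [_scale(_add(nb, _scale(pp, -1.0)), w) for nb, w in zip(nb_pos, valid)]
    s: Vec = (0.0, 0.0, 0.0)
    for vi in v:
        s = _add(s, vi)
    # base points away from the sum of neighbour bond vectors
    base = _scale(s, -1.0 / (_norm(s) + eps))

    # Assumption: v1, v2 are the first two neighbour slots
    cross12 = _cross(v[0], v[1])
    cross_norm = _norm(cross12)
    if cross_norm > 1e-6:
        perp1 = _scale(cross12, 1.0 / (cross_norm + eps))
    else:
        # Fallback: cross base with the axis along which |base| is smallest
        k = min(range(3), key=lambda j: abs(base[j]))
        cardinal = tuple(1.0 if j == k else 0.0 for j in range(3))
        bc = _cross(base, cardinal)
        perp1 = _scale(bc, 1.0 / _norm(bc))
    perp2 = _cross(base, perp1)

    # direction = c0*base + c1*perp1 + c2*perp2;  H = pp + L*direction
    direction = _add(_add(_scale(base, coeffs[0]), _scale(perp1, coeffs[1])),
                     _scale(perp2, coeffs[2]))
    return _add(pp, _scale(direction, bond_length))
```

With neighbours at (1, 0, 0) and (0, 1, 0) around a parent at the origin, `coeffs=(1, 0, 0)` places the hydrogen along (−1, −1, 0)/√2. Two collinear neighbours exercise the `perp1_ortho` fallback.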

torchref.base.targets.triton.place_hydrogens.place_riding_hydrogens_triton(xyz_heavy, parent_idx, nb_idx_clamped, nb_valid, coeffs, bond_length)

Triton-backed riding-hydrogen placement.

Drop-in replacement for _place_h_jit (forward) plus an analytic Triton backward that avoids autograd’s op-by-op traversal.
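To illustrate what an analytic backward evaluates in place of a chain of autograd nodes, the sketch below gives closed-form vector-Jacobian products for two building blocks of the forward math: the cross product and the ε-regularised normalisation x / (|x| + ε). Each is checked against central finite differences. The helper names are hypothetical and this is not the Triton kernel's code. For `base = −s / (|s| + ε)`, the gradient with respect to `s` is simply the negated normalisation VJP.

```python
import math
from typing import Callable, Tuple

Vec = Tuple[float, float, float]


def cross(a: Vec, b: Vec) -> Vec:
    return (a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0])


def dot(a: Vec, b: Vec) -> float:
    return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]


def normalize_eps(x: Vec, eps: float) -> Vec:
    n = math.sqrt(dot(x, x))
    return tuple(xi / (n + eps) for xi in x)


def normalize_eps_vjp(x: Vec, g: Vec, eps: float) -> Vec:
    # J^T g = g/(n+eps) - x (x.g) / (n (n+eps)^2), with n = |x| > 0 assumed
    n = math.sqrt(dot(x, x))
    xg = dot(x, g)
    return tuple(gi / (n + eps) - xi * xg / (n * (n + eps) ** 2)
                 for xi, gi in zip(x, g))


def cross_vjp(a: Vec, b: Vec, g: Vec) -> Tuple[Vec, Vec]:
    # L = g . (a x b) = a . (b x g) = b . (g x a)
    return cross(b, g), cross(g, a)


def fd_vjp(f: Callable[[Vec], Vec], x: Vec, g: Vec, h: float = 1e-6) -> Vec:
    # Central-difference estimate of J(x)^T g for f: R^3 -> R^3
    out = []
    for k in range(3):
        xp, xm = list(x), list(x)
        xp[k] += h
        xm[k] -= h
        fp, fm = f(tuple(xp)), f(tuple(xm))
        out.append(sum(gi * (p - m) / (2 * h) for gi, p, m in zip(g, fp, fm)))
    return tuple(out)
```

Because each VJP is a handful of multiplies and adds per component, a kernel can evaluate the whole chain in registers for each hydrogen instead of materialising every intermediate, as autograd does.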