torchref.base.targets.triton.bond module
Triton kernels for the bond-length Gaussian-NLL target.
Results match torchref.base.targets.bond.bond_math() to within float32
precision. The module provides two kernels:
- _bond_nll_fwd_kernel: computes the per-bond Gaussian NLL.
- _bond_nll_bwd_kernel: scatters gradients into xyz (atomic add).
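For readers without a GPU, the math the two kernels are described as computing can be sketched in plain Python. This is a minimal reference under stated assumptions, not the module's implementation. It assumes each bond's NLL is 0.5*((d - d0)/sigma)**2 + log(sigma) + 0.5*log(2*pi), with the normalization constant included. It also assumes that `idx` holds (i, j) atom pairs and that no bond has zero length. The names `bond_nll_forward` and `bond_nll_backward` are hypothetical. The backward loop accumulates into a shared buffer, which is the serial analogue of the kernel's atomic adds: each bond touches two atoms, and one atom can appear in many bonds.

```python
import math
from typing import List, Sequence, Tuple

Pair = Tuple[int, int]


def bond_nll_forward(xyz: Sequence[Sequence[float]], idx: Sequence[Pair],
                     references: Sequence[float],
                     sigmas: Sequence[float]) -> List[float]:
    """Per-bond Gaussian NLL (ASSUMED form, normalization constant included)."""
    out = []
    for (i, j), d0, s in zip(idx, references, sigmas):
        dx = [xyz[i][k] - xyz[j][k] for k in range(3)]
        d = math.sqrt(sum(c * c for c in dx))  # bond length; assumed > 0
        z = (d - d0) / s
        out.append(0.5 * z * z + math.log(s) + 0.5 * math.log(2.0 * math.pi))
    return out


def bond_nll_backward(xyz: Sequence[Sequence[float]], idx: Sequence[Pair],
                      references: Sequence[float], sigmas: Sequence[float],
                      grad_out: Sequence[float]) -> List[List[float]]:
    """Scatter dL/dxyz from each bond onto its two atoms.

    The += / -= on a shared buffer plays the role of the GPU atomic add.
    """
    grad = [[0.0, 0.0, 0.0] for _ in xyz]
    for (i, j), d0, s, g in zip(idx, references, sigmas, grad_out):
        dx = [xyz[i][k] - xyz[j][k] for k in range(3)]
        d = math.sqrt(sum(c * c for c in dx))
        # dNLL/dd = (d - d0) / s^2, and dd/dx_i = dx / d = -dd/dx_j
        coef = g * (d - d0) / (s * s * d)
        for k in range(3):
            grad[i][k] += coef * dx[k]
            grad[j][k] -= coef * dx[k]
    return grad
```

Every bond adds a vector to one endpoint and subtracts the same vector from the other, so the gradient summed over all atoms vanishes. That makes a cheap sanity check for any scatter-based backward.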
A torch.autograd.Function wrapper pairs the two kernels (forward pass and
backward pass), so the result can be used directly in a backward graph.
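Such a wrapper typically follows the standard torch.autograd.Function pattern. `forward` runs the per-bond computation and saves what `backward` needs. `backward` then turns the incoming per-bond gradient into a gradient on `xyz`. The sketch below is an illustration under stated assumptions, not this module's wrapper. It substitutes plain PyTorch ops for the Triton kernels, using `Tensor.index_add_` in place of the atomic-add scatter. The class name `BondNLLReference` and the (n_bonds, 2) integer layout of `idx` are hypothetical. The per-bond NLL 0.5*((d - d0)/sigma)**2 + log(sigma) + 0.5*log(2*pi) is an assumed form. Gradients are returned only for `xyz`, matching the description of the backward kernel. Whether the real wrapper also differentiates `references` or `sigmas` is not stated here.

```python
import math

import torch


class BondNLLReference(torch.autograd.Function):
    """Hypothetical pure-PyTorch stand-in for the Triton fwd/bwd kernel pair."""

    @staticmethod
    def forward(ctx, xyz, idx, references, sigmas):
        # idx: (n_bonds, 2) integer tensor of atom pairs (assumed layout)
        diff = xyz[idx[:, 0]] - xyz[idx[:, 1]]           # (n_bonds, 3)
        d = torch.linalg.vector_norm(diff, dim=1)        # bond lengths, assumed > 0
        z = (d - references) / sigmas
        # Assumed NLL form, normalization constant included
        nll = 0.5 * z * z + torch.log(sigmas) + 0.5 * math.log(2.0 * math.pi)
        ctx.save_for_backward(idx, references, sigmas, diff, d)
        ctx.n_atoms = xyz.shape[0]
        return nll                                        # per-bond values

    @staticmethod
    def backward(ctx, grad_out):
        idx, references, sigmas, diff, d = ctx.saved_tensors
        # dNLL/dd * dd/dx_i, where dd/dx_i = diff / d
        coef = grad_out * (d - references) / (sigmas * sigmas * d)
        g = coef.unsqueeze(1) * diff                      # (n_bonds, 3)
        grad_xyz = diff.new_zeros((ctx.n_atoms, 3))
        # index_add_ accumulates repeated atom indices, like an atomic add
        grad_xyz.index_add_(0, idx[:, 0], g)
        grad_xyz.index_add_(0, idx[:, 1], -g)
        # No gradients for idx, references, or sigmas in this sketch
        return grad_xyz, None, None, None
```

Running torch.autograd.gradcheck in float64 is a convenient way to confirm that a hand-written backward agrees with finite differences before swapping in faster kernels.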
- torchref.base.targets.triton.bond.bond_math_triton(xyz, idx, references, sigmas)
Triton-backed bond-length Gaussian NLL.
Drop-in replacement for
torchref.base.targets.bond.bond_math().