torchref.base.targets.triton.bond module

Triton kernels for the bond-length Gaussian negative log-likelihood (NLL) target.

Results match torchref.base.targets.bond.bond_math() to within float32 precision. The module provides two kernels:

  • _bond_nll_fwd_kernel: computes the per-bond Gaussian NLL.

  • _bond_nll_bwd_kernel: scatters the per-bond gradients back into xyz using atomic adds.
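What the two kernels compute can be mirrored in plain PyTorch for reference. This is a sketch, not the module's code. It assumes xyz is an (N, 3) coordinate tensor, idx is a (B, 2) integer tensor of atom pairs, references and sigmas are (B,) tensors, and that the per-bond term is the full Gaussian NLL, 0.5 * ((d - d0) / sigma)**2 + log(sigma) + 0.5 * log(2*pi). bond_math() may drop the constant terms. The helper names bond_nll_forward and bond_nll_backward are hypothetical.

```python
import math

import torch


def bond_nll_forward(xyz, idx, references, sigmas):
    """Hypothetical per-bond Gaussian NLL (assumed form, constants included)."""
    diff = xyz[idx[:, 0]] - xyz[idx[:, 1]]   # (B, 3) bond vectors
    d = diff.norm(dim=1)                     # (B,) bond lengths
    z = (d - references) / sigmas            # standardized residuals
    return 0.5 * z**2 + torch.log(sigmas) + 0.5 * math.log(2 * math.pi)


def bond_nll_backward(grad_out, xyz, idx, references, sigmas):
    """Hypothetical gradient of the per-bond NLL with respect to xyz."""
    diff = xyz[idx[:, 0]] - xyz[idx[:, 1]]
    d = diff.norm(dim=1)
    # Chain rule: dNLL/dd = (d - d0) / sigma^2 and dd/dx_i = (x_i - x_j) / d
    coef = grad_out * (d - references) / (sigmas**2 * d)   # (B,)
    g = coef.unsqueeze(1) * diff                           # (B, 3) gradient for atom i
    grad_xyz = torch.zeros_like(xyz)
    # index_add_ sums contributions from bonds that share an atom;
    # atom j receives the negated gradient of atom i
    grad_xyz.index_add_(0, idx[:, 0], g)
    grad_xyz.index_add_(0, idx[:, 1], -g)
    return grad_xyz
```

Here index_add_ accumulates contributions from bonds that share an atom serially. On the GPU several bonds can write to the same atom at once, which is why the backward kernel needs atomic adds.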

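An autograd.Function wrapper runs the forward kernel in its forward pass and the backward kernel in its backward pass, so the output takes part in PyTorch autograd like any other differentiable operation.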
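A minimal sketch of that wrapper pattern follows. BondNLLFunction is a hypothetical name, plain PyTorch stands in for the two Triton launches, and the tensor shapes (xyz as (N, 3), idx as (B, 2), references and sigmas as (B,)) and the full Gaussian NLL form with constants are assumptions rather than details confirmed by this page.

```python
import math

import torch


class BondNLLFunction(torch.autograd.Function):
    """Hypothetical wrapper; PyTorch ops stand in for the Triton kernels."""

    @staticmethod
    def forward(ctx, xyz, idx, references, sigmas):
        ctx.save_for_backward(xyz, idx, references, sigmas)
        # Stand-in for _bond_nll_fwd_kernel (assumed full Gaussian NLL)
        d = (xyz[idx[:, 0]] - xyz[idx[:, 1]]).norm(dim=1)
        z = (d - references) / sigmas
        return 0.5 * z**2 + torch.log(sigmas) + 0.5 * math.log(2 * math.pi)

    @staticmethod
    def backward(ctx, grad_out):
        xyz, idx, references, sigmas = ctx.saved_tensors
        # Stand-in for _bond_nll_bwd_kernel: scatter per-bond gradients into xyz
        diff = xyz[idx[:, 0]] - xyz[idx[:, 1]]
        d = diff.norm(dim=1)
        coef = grad_out * (d - references) / (sigmas**2 * d)
        g = coef.unsqueeze(1) * diff
        grad_xyz = torch.zeros_like(xyz)
        grad_xyz.index_add_(0, idx[:, 0], g)
        grad_xyz.index_add_(0, idx[:, 1], -g)
        # Only xyz receives a gradient; None for idx, references, sigmas
        return grad_xyz, None, None, None
```

Call it through apply, for example BondNLLFunction.apply(xyz, idx, references, sigmas).sum().backward(). Because backward returns None for the last three inputs, only xyz accumulates a gradient, consistent with the backward kernel writing into xyz alone.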

torchref.base.targets.triton.bond.bond_math_triton(xyz, idx, references, sigmas)

Triton-backed bond-length Gaussian NLL.

Drop-in replacement for torchref.base.targets.bond.bond_math().
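One way to exercise the drop-in claim is to compare the outputs and coordinate gradients of the two implementations. check_parity is a hypothetical helper, and its default tolerances (rtol=1e-5, atol=1e-6) are an assumed reading of "within float32 precision". On a machine with CUDA and Triton available, the intended call would be check_parity(bond_math, bond_math_triton, xyz, idx, references, sigmas) with all tensors on the GPU.

```python
import torch


def check_parity(fn_ref, fn_new, xyz, *args, rtol=1e-5, atol=1e-6):
    """Hypothetical helper: compare outputs and xyz-gradients of two implementations."""
    # Independent leaf copies so the two backward passes do not share .grad
    x_ref = xyz.detach().clone().requires_grad_(True)
    x_new = xyz.detach().clone().requires_grad_(True)
    out_ref = fn_ref(x_ref, *args)
    out_new = fn_new(x_new, *args)
    # Reduce to a scalar so both backward passes see the same upstream gradient
    out_ref.sum().backward()
    out_new.sum().backward()
    same_out = torch.allclose(out_ref, out_new, rtol=rtol, atol=atol)
    same_grad = torch.allclose(x_ref.grad, x_new.grad, rtol=rtol, atol=atol)
    return same_out and same_grad
```

Returning a single boolean keeps the helper easy to use in an assertion; printing the maximum absolute difference instead would be more informative when a check fails.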