"""Triton kernels for the bond-length Gaussian-NLL target.

Matches :func:`torchref.base.targets.bond.bond_math` to within float32
precision. Two kernels:

* ``_bond_nll_fwd_kernel``: per-bond Gaussian NLL.
* ``_bond_nll_bwd_kernel``: scatters the gradient w.r.t. ``xyz`` into a
  zero-initialised buffer of the same shape (atomic add).

The ``autograd.Function`` wrapper composes them and returns the NLL summed
over all bonds, so the result can be plugged directly into a backward graph.
"""
from __future__ import annotations
import math
import torch
import triton
import triton.language as tl
# log(2*pi): the constant term of the Gaussian NLL (0.5 * log(2*pi) per bond).
_LOG_2PI = math.log(math.tau)
# --------------------------------------------------------------------- kernels
@triton.jit
def _bond_nll_fwd_kernel(
    xyz_ptr,  # (N_atoms, 3) float, row-major: atom a's coords at a*3 + {0,1,2}
    idx_ptr,  # (N, 2) int, row-major: bond k's atoms at k*2 + {0,1}
    ref_ptr,  # (N,) float -- reference (mean) bond length per bond
    sig_ptr,  # (N,) float -- sigma per bond
    out_ptr,  # (N,) float -- per-bond NLL (output)
    N,  # number of bonds; a runtime arg so a new bond count does not force
        # a fresh JIT compile (constexpr values are part of the cache key)
    LOG_2PI: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Per-bond Gaussian negative log-likelihood of the bond length.

    For bond k with atom pair (i, j) = idx[k]:
        d_k   = ||xyz[j] - xyz[i]||
        NLL_k = 0.5 * ((d_k - ref_k) / sig_k)^2 + log(sig_k) + 0.5 * LOG_2PI
    Program ``pid`` handles bonds ``[pid*BLOCK, pid*BLOCK + BLOCK)``; lanes at
    or beyond ``N`` are masked out of every load and of the store.
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N

    # Atom indices of the two bond endpoints.
    i = tl.load(idx_ptr + offs * 2 + 0, mask=mask, other=0)
    j = tl.load(idx_ptr + offs * 2 + 1, mask=mask, other=0)

    p1x = tl.load(xyz_ptr + i * 3 + 0, mask=mask, other=0.0)
    p1y = tl.load(xyz_ptr + i * 3 + 1, mask=mask, other=0.0)
    p1z = tl.load(xyz_ptr + i * 3 + 2, mask=mask, other=0.0)
    p2x = tl.load(xyz_ptr + j * 3 + 0, mask=mask, other=0.0)
    p2y = tl.load(xyz_ptr + j * 3 + 1, mask=mask, other=0.0)
    p2z = tl.load(xyz_ptr + j * 3 + 2, mask=mask, other=0.0)

    dx = p2x - p1x
    dy = p2y - p1y
    dz = p2z - p1z
    d = tl.sqrt(dx * dx + dy * dy + dz * dz)

    ref = tl.load(ref_ptr + offs, mask=mask, other=0.0)
    # other=1.0 keeps log(sig) finite in masked lanes (their result is never
    # stored, but this avoids producing -inf/NaN in dead lanes).
    sig = tl.load(sig_ptr + offs, mask=mask, other=1.0)

    z = (d - ref) / sig  # standardized deviation
    nll = 0.5 * z * z + tl.log(sig) + 0.5 * LOG_2PI
    tl.store(out_ptr + offs, nll, mask=mask)
@triton.jit
def _bond_nll_bwd_kernel(
    xyz_ptr,  # (N_atoms, 3) -- input coordinates, row-major (a*3 + c)
    idx_ptr,  # (N, 2) -- atom-index pairs, row-major (k*2 + c)
    ref_ptr,  # (N,)
    sig_ptr,  # (N,)
    grad_out_ptr,  # 0-D tensor -- gradient of upstream loss w.r.t. sum(nll).
                   # Read in-kernel so the host doesn't need to ``.item()``
                   # it (which would force a cuStreamSynchronize and block
                   # the host from queuing subsequent kernel launches).
    dxyz_ptr,  # (N_atoms, 3) -- gradient accumulator (atomic add)
    N,  # number of bonds; a runtime arg so a new bond count does not force
        # a fresh JIT compile (constexpr values are part of the cache key)
    BLOCK: tl.constexpr,
):
    """Gradient of NLL_i wrt p1, p2.

    Let u_i = (p2_i - p1_i), d_i = ||u_i||, dev_i = d_i - ref_i. Then
        dNLL_i/dp2_i = (dev_i / sig_i^2) * (u_i / d_i)
        dNLL_i/dp1_i = -dNLL_i/dp2_i
    Multiplied by grad_out (= dLoss/dNLL_summed = scalar, read from
    ``grad_out_ptr`` at runtime). Results are atomically added into
    ``dxyz_ptr``, so atoms shared by several bonds accumulate correctly;
    the caller must zero-initialise that buffer.
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N

    # One scalar load per block, broadcast across the threads.
    grad_out = tl.load(grad_out_ptr)

    i = tl.load(idx_ptr + offs * 2 + 0, mask=mask, other=0)
    j = tl.load(idx_ptr + offs * 2 + 1, mask=mask, other=0)

    p1x = tl.load(xyz_ptr + i * 3 + 0, mask=mask, other=0.0)
    p1y = tl.load(xyz_ptr + i * 3 + 1, mask=mask, other=0.0)
    p1z = tl.load(xyz_ptr + i * 3 + 2, mask=mask, other=0.0)
    p2x = tl.load(xyz_ptr + j * 3 + 0, mask=mask, other=0.0)
    p2y = tl.load(xyz_ptr + j * 3 + 1, mask=mask, other=0.0)
    p2z = tl.load(xyz_ptr + j * 3 + 2, mask=mask, other=0.0)

    ux = p2x - p1x
    uy = p2y - p1y
    uz = p2z - p1z
    d = tl.sqrt(ux * ux + uy * uy + uz * uz)
    # Avoid 0/0 at degenerate pairs: u == 0 there, so the gradient is 0.
    d_safe = tl.where(d > 0, d, 1.0)

    ref = tl.load(ref_ptr + offs, mask=mask, other=0.0)
    sig = tl.load(sig_ptr + offs, mask=mask, other=1.0)
    dev = d - ref
    coef = grad_out * dev / (sig * sig) / d_safe  # scalar per bond
    gx = coef * ux
    gy = coef * uy
    gz = coef * uz

    # Scatter-add into the gradient buffer. Two atoms per bond, opposite signs.
    tl.atomic_add(dxyz_ptr + j * 3 + 0, gx, mask=mask)
    tl.atomic_add(dxyz_ptr + j * 3 + 1, gy, mask=mask)
    tl.atomic_add(dxyz_ptr + j * 3 + 2, gz, mask=mask)
    tl.atomic_add(dxyz_ptr + i * 3 + 0, -gx, mask=mask)
    tl.atomic_add(dxyz_ptr + i * 3 + 1, -gy, mask=mask)
    tl.atomic_add(dxyz_ptr + i * 3 + 2, -gz, mask=mask)
# --------------------------------------------------------------- autograd wrap
class _BondMathTriton(torch.autograd.Function):
    """Autograd bridge: forward = summed bond NLL, backward = d(sum)/d(xyz).

    Only ``xyz`` receives a gradient; ``idx``, ``references`` and ``sigmas``
    get ``None`` in their gradient slots (treated as constants).
    """

    # Bonds handled per Triton program; shared by forward and backward.
    _BLOCK = 256

    @staticmethod
    def forward(ctx, xyz: torch.Tensor, idx: torch.Tensor,
                references: torch.Tensor, sigmas: torch.Tensor) -> torch.Tensor:
        """Return the NLL summed over all bonds as a 0-D tensor.

        Raises:
            ValueError: if ``xyz`` or ``idx`` is not a CUDA tensor.
            TypeError: if ``xyz`` is not float32.
        """
        # Explicit checks instead of ``assert`` so they survive ``python -O``.
        if not (xyz.is_cuda and idx.is_cuda):
            raise ValueError("bond_math_triton: xyz and idx must be CUDA tensors")
        if xyz.dtype != torch.float32:
            raise TypeError(
                f"bond_math_triton: xyz must be float32, got {xyz.dtype}"
            )
        # The kernels do raw pointer arithmetic that assumes row-major
        # contiguous storage (xyz stride (3, 1), idx stride (2, 1)). Enforce
        # it here so a transposed/sliced view cannot silently yield wrong
        # values; ``.contiguous()`` is a no-op when the layout is already fine.
        xyz = xyz.contiguous()
        idx = idx.contiguous()
        references = references.contiguous()
        sigmas = sigmas.contiguous()

        N = idx.shape[0]
        nll = torch.empty(N, dtype=xyz.dtype, device=xyz.device)
        if N > 0:  # no launch for an empty bond list; sum() below is then 0
            grid = (triton.cdiv(N, _BondMathTriton._BLOCK),)
            _bond_nll_fwd_kernel[grid](
                xyz, idx, references, sigmas, nll,
                N=N, LOG_2PI=_LOG_2PI, BLOCK=_BondMathTriton._BLOCK,
            )
        # Save the contiguous versions so backward sees the same layout.
        ctx.save_for_backward(xyz, idx, references, sigmas)
        return nll.sum()

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        """Scatter ``grad_out * dNLL/dxyz`` for every bond into a new buffer."""
        xyz, idx, references, sigmas = ctx.saved_tensors
        if not ctx.needs_input_grad[0]:
            # Backward was triggered only by another input; xyz needs nothing.
            return None, None, None, None
        N = idx.shape[0]
        # Zero-initialised: the kernel accumulates with atomic adds.
        dxyz = torch.zeros_like(xyz)
        if N > 0:
            grid = (triton.cdiv(N, _BondMathTriton._BLOCK),)
            # Pass grad_out as a 0-D device tensor (its data_ptr). The kernel
            # ``tl.load``s it once per block -- no ``.item()``, no host sync.
            _bond_nll_bwd_kernel[grid](
                xyz, idx, references, sigmas, grad_out, dxyz,
                N=N, BLOCK=_BondMathTriton._BLOCK,
            )
        return dxyz, None, None, None
def bond_math_triton(
    xyz: torch.Tensor,
    idx: torch.Tensor,
    references: torch.Tensor,
    sigmas: torch.Tensor,
) -> torch.Tensor:
    """Triton-backed bond-length Gaussian NLL.

    Drop-in replacement for :func:`torchref.base.targets.bond.bond_math`.

    Args:
        xyz: ``(N_atoms, 3)`` float32 CUDA coordinates.
        idx: ``(N, 2)`` CUDA atom-index pairs, one row per bond.
        references: ``(N,)`` reference bond lengths.
        sigmas: ``(N,)`` per-bond sigmas.

    Returns:
        0-D tensor holding the NLL summed over all bonds. It is differentiable
        w.r.t. ``xyz`` only; ``idx``, ``references`` and ``sigmas`` receive
        no gradient.
    """
    return _BondMathTriton.apply(xyz, idx, references, sigmas)