# Source code for torchref.base.targets.triton.angle

"""Triton forward + backward for the bond-angle Gaussian-NLL target."""

from __future__ import annotations

import math

import torch
import triton
import triton.language as tl
from triton.language.extra import libdevice


# Normalising constant log(2π) of the Gaussian negative log-likelihood.
_LOG_2PI = math.log(2 * math.pi)
# Floor applied to sin²θ in the backward kernel so that 1/sin θ stays finite.
_EPS = 1.0e-8


@triton.jit
def _angle_nll_fwd_kernel(
    xyz_ptr,
    idx_ptr,
    ref_ptr,
    sig_ptr,
    out_ptr,
    N,
    LOG_2PI: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Per-angle Gaussian NLL of the a–b–c bond angle (b is the vertex).

    Each program handles ``BLOCK`` consecutive angles. For angle ``i`` the
    atom indices are read from ``idx_ptr[3*i + k]`` and coordinates from
    ``xyz_ptr[3*atom + k]`` (k = 0, 1, 2), i.e. both buffers are addressed
    as flat row-major ``(*, 3)`` arrays. With v1 = a − b, v2 = c − b and
    θ = acos(v1·v2 / (|v1| |v2|)), the value

        nll_i = 0.5 * ((θ − ref_i) / σ_i)² + log σ_i + 0.5 * log(2π)

    is stored to ``out_ptr[i]``.

    ``N`` (number of angles) is a runtime argument rather than
    ``tl.constexpr`` so that a new angle count does not trigger a fresh
    JIT compilation of the kernel.

    NOTE: no epsilon is applied to |v1| |v2|, so coincident atoms yield NaN.
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N

    # Atom indices (a, b, c) of each angle.
    a = tl.load(idx_ptr + offs * 3 + 0, mask=mask, other=0)
    b = tl.load(idx_ptr + offs * 3 + 1, mask=mask, other=0)
    c = tl.load(idx_ptr + offs * 3 + 2, mask=mask, other=0)

    ax = tl.load(xyz_ptr + a * 3 + 0, mask=mask, other=0.0)
    ay = tl.load(xyz_ptr + a * 3 + 1, mask=mask, other=0.0)
    az = tl.load(xyz_ptr + a * 3 + 2, mask=mask, other=0.0)
    bx = tl.load(xyz_ptr + b * 3 + 0, mask=mask, other=0.0)
    by = tl.load(xyz_ptr + b * 3 + 1, mask=mask, other=0.0)
    bz = tl.load(xyz_ptr + b * 3 + 2, mask=mask, other=0.0)
    cx = tl.load(xyz_ptr + c * 3 + 0, mask=mask, other=0.0)
    cy = tl.load(xyz_ptr + c * 3 + 1, mask=mask, other=0.0)
    cz = tl.load(xyz_ptr + c * 3 + 2, mask=mask, other=0.0)

    # Bond vectors pointing away from the vertex atom b.
    v1x = ax - bx
    v1y = ay - by
    v1z = az - bz
    v2x = cx - bx
    v2y = cy - by
    v2z = cz - bz

    n1 = tl.sqrt(v1x * v1x + v1y * v1y + v1z * v1z)
    n2 = tl.sqrt(v2x * v2x + v2y * v2y + v2z * v2z)
    dot = v1x * v2x + v1y * v2y + v1z * v2z
    cos_t = dot / (n1 * n2)
    # Rounding can push |cos θ| slightly past 1; clamp before acos.
    cos_t = tl.where(cos_t > 1.0, 1.0, cos_t)
    cos_t = tl.where(cos_t < -1.0, -1.0, cos_t)
    theta = libdevice.acos(cos_t)

    ref = tl.load(ref_ptr + offs, mask=mask, other=0.0)
    # other=1.0 keeps log(sig) finite on masked-off lanes.
    sig = tl.load(sig_ptr + offs, mask=mask, other=1.0)
    z = (theta - ref) / sig
    nll = 0.5 * z * z + tl.log(sig) + 0.5 * LOG_2PI

    tl.store(out_ptr + offs, nll, mask=mask)


@triton.jit
def _angle_nll_bwd_kernel(
    xyz_ptr,
    idx_ptr,
    ref_ptr,
    sig_ptr,
    grad_out_ptr,  # 0-D tensor — loaded in-kernel (avoid host .item() sync)
    dxyz_ptr,
    N,
    EPS: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Analytic gradient of the angle NLL.

    With v1 = a - b, v2 = c - b, n1 = |v1|, n2 = |v2|,
    cos θ = (v1·v2)/(n1 n2), θ = acos(cos θ):

        ∂NLL/∂θ = (θ − ref) / σ²
        ∂θ/∂(cos θ) = −1 / sin θ
        ∂(cos θ)/∂v1 = v2/(n1 n2) − (cos θ / n1²) v1
        ∂(cos θ)/∂v2 = v1/(n1 n2) − (cos θ / n2²) v2

    Then ∂NLL/∂a = ∂NLL/∂v1, ∂NLL/∂c = ∂NLL/∂v2,
        ∂NLL/∂b = −(∂NLL/∂a + ∂NLL/∂c). Multiplied by grad_out
        (loaded once per block from ``grad_out_ptr``).

    sin θ is floored at sqrt(EPS) so that 1/sin θ stays finite for
    (near-)linear angles.

    ``N`` (number of angles) is a runtime argument rather than
    ``tl.constexpr`` so that a new angle count does not trigger a fresh
    JIT compilation. Results are accumulated into ``dxyz_ptr`` with
    atomics (several angles can share an atom), so the caller must pass a
    zero-initialised buffer.
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    grad_out = tl.load(grad_out_ptr)

    # Atom indices (a, b, c) of each angle; b is the vertex.
    a = tl.load(idx_ptr + offs * 3 + 0, mask=mask, other=0)
    b = tl.load(idx_ptr + offs * 3 + 1, mask=mask, other=0)
    c = tl.load(idx_ptr + offs * 3 + 2, mask=mask, other=0)

    ax = tl.load(xyz_ptr + a * 3 + 0, mask=mask, other=0.0)
    ay = tl.load(xyz_ptr + a * 3 + 1, mask=mask, other=0.0)
    az = tl.load(xyz_ptr + a * 3 + 2, mask=mask, other=0.0)
    bx = tl.load(xyz_ptr + b * 3 + 0, mask=mask, other=0.0)
    by = tl.load(xyz_ptr + b * 3 + 1, mask=mask, other=0.0)
    bz = tl.load(xyz_ptr + b * 3 + 2, mask=mask, other=0.0)
    cx = tl.load(xyz_ptr + c * 3 + 0, mask=mask, other=0.0)
    cy = tl.load(xyz_ptr + c * 3 + 1, mask=mask, other=0.0)
    cz = tl.load(xyz_ptr + c * 3 + 2, mask=mask, other=0.0)

    v1x = ax - bx
    v1y = ay - by
    v1z = az - bz
    v2x = cx - bx
    v2y = cy - by
    v2z = cz - bz
    n1_sq = v1x * v1x + v1y * v1y + v1z * v1z
    n2_sq = v2x * v2x + v2y * v2y + v2z * v2z
    n1 = tl.sqrt(n1_sq)
    n2 = tl.sqrt(n2_sq)
    dot = v1x * v2x + v1y * v2y + v1z * v2z
    # Same cos θ / clamp sequence as the forward kernel so θ matches.
    cos_t = dot / (n1 * n2)
    cos_t = tl.where(cos_t > 1.0, 1.0, cos_t)
    cos_t = tl.where(cos_t < -1.0, -1.0, cos_t)
    sin_t = tl.sqrt(tl.maximum(1.0 - cos_t * cos_t, EPS))
    theta = libdevice.acos(cos_t)

    ref = tl.load(ref_ptr + offs, mask=mask, other=0.0)
    sig = tl.load(sig_ptr + offs, mask=mask, other=1.0)
    dev = theta - ref

    # Chain rule: ∂NLL/∂(cos θ) = (θ − ref)/σ² · (−1/sin θ), times grad_out.
    dnll_dtheta = dev / (sig * sig)
    coef = -dnll_dtheta / sin_t * grad_out

    inv_n1n2 = 1.0 / (n1 * n2)
    ct_over_n1sq = cos_t / n1_sq
    ct_over_n2sq = cos_t / n2_sq

    # ∂NLL/∂v1 = coef * (v2 * inv_n1n2 - v1 * ct_over_n1sq)
    dv1x = coef * (v2x * inv_n1n2 - v1x * ct_over_n1sq)
    dv1y = coef * (v2y * inv_n1n2 - v1y * ct_over_n1sq)
    dv1z = coef * (v2z * inv_n1n2 - v1z * ct_over_n1sq)
    # ∂NLL/∂v2 = coef * (v1 * inv_n1n2 - v2 * ct_over_n2sq)
    dv2x = coef * (v1x * inv_n1n2 - v2x * ct_over_n2sq)
    dv2y = coef * (v1y * inv_n1n2 - v2y * ct_over_n2sq)
    dv2z = coef * (v1z * inv_n1n2 - v2z * ct_over_n2sq)

    # Scatter: a += dv1, c += dv2, b -= (dv1 + dv2). Atomics because
    # different angles may reference the same atom.
    tl.atomic_add(dxyz_ptr + a * 3 + 0, dv1x, mask=mask)
    tl.atomic_add(dxyz_ptr + a * 3 + 1, dv1y, mask=mask)
    tl.atomic_add(dxyz_ptr + a * 3 + 2, dv1z, mask=mask)
    tl.atomic_add(dxyz_ptr + c * 3 + 0, dv2x, mask=mask)
    tl.atomic_add(dxyz_ptr + c * 3 + 1, dv2y, mask=mask)
    tl.atomic_add(dxyz_ptr + c * 3 + 2, dv2z, mask=mask)
    tl.atomic_add(dxyz_ptr + b * 3 + 0, -(dv1x + dv2x), mask=mask)
    tl.atomic_add(dxyz_ptr + b * 3 + 1, -(dv1y + dv2y), mask=mask)
    tl.atomic_add(dxyz_ptr + b * 3 + 2, -(dv1z + dv2z), mask=mask)


class _AngleMathTriton(torch.autograd.Function):
    """Autograd wrapper around the Triton bond-angle NLL kernels.

    ``forward`` returns the NLL summed over all angles as a 0-D tensor.
    ``backward`` returns a gradient for ``xyz`` only; ``idx``,
    ``references_rad`` and ``sigmas_rad`` receive ``None``.

    The kernels address every buffer with flat row-major offsets
    (``3 * row + k`` for ``xyz`` / ``idx``, ``row`` for the per-angle
    tensors), so all inputs are made contiguous before launch.
    """

    # Angles processed per Triton program (shared by both kernels).
    _BLOCK = 256

    @staticmethod
    def forward(ctx, xyz, idx, references_rad, sigmas_rad):
        # Explicit checks instead of ``assert`` so they survive ``python -O``.
        if not xyz.is_cuda:
            raise ValueError("angle_math_triton requires xyz to be a CUDA tensor")
        if xyz.dtype != torch.float32:
            raise ValueError(
                f"angle_math_triton requires float32 xyz, got {xyz.dtype}"
            )
        if idx.dim() != 2 or idx.shape[1] != 3:
            raise ValueError(
                f"idx must have shape (N, 3), got {tuple(idx.shape)}"
            )
        n_angles = idx.shape[0]
        # The kernels read exactly N entries from each per-angle tensor;
        # anything shorter would be an out-of-bounds read.
        if references_rad.numel() != n_angles or sigmas_rad.numel() != n_angles:
            raise ValueError(
                "references_rad and sigmas_rad must each have one entry per "
                f"angle ({n_angles}), got {references_rad.numel()} and "
                f"{sigmas_rad.numel()}"
            )

        # No-ops for already-contiguous tensors; otherwise a dense copy so
        # the kernels' flat indexing is valid.
        xyz = xyz.contiguous()
        idx = idx.contiguous()
        references_rad = references_rad.contiguous()
        sigmas_rad = sigmas_rad.contiguous()

        nll = torch.empty(n_angles, dtype=xyz.dtype, device=xyz.device)
        if n_angles > 0:
            block = _AngleMathTriton._BLOCK
            grid = (triton.cdiv(n_angles, block),)
            _angle_nll_fwd_kernel[grid](
                xyz, idx, references_rad, sigmas_rad, nll,
                N=n_angles, LOG_2PI=_LOG_2PI, BLOCK=block,
            )
        ctx.save_for_backward(xyz, idx, references_rad, sigmas_rad)
        return nll.sum()

    @staticmethod
    def backward(ctx, grad_out):
        # Only xyz is differentiable here; skip the kernel if it is not needed.
        if not ctx.needs_input_grad[0]:
            return None, None, None, None
        xyz, idx, refs, sigs = ctx.saved_tensors
        n_angles = idx.shape[0]
        # Zero-initialised: the backward kernel accumulates with atomic_add.
        dxyz = torch.zeros_like(xyz)
        if n_angles > 0:
            block = _AngleMathTriton._BLOCK
            grid = (triton.cdiv(n_angles, block),)
            # grad_out stays on-device (0-D); the kernel loads it directly.
            _angle_nll_bwd_kernel[grid](
                xyz, idx, refs, sigs, grad_out.contiguous(), dxyz,
                N=n_angles, EPS=_EPS, BLOCK=block,
            )
        return dxyz, None, None, None


def angle_math_triton(
    xyz: torch.Tensor,
    idx: torch.Tensor,
    references_rad: torch.Tensor,
    sigmas_rad: torch.Tensor,
) -> torch.Tensor:
    """Triton-backed bond-angle Gaussian NLL with analytic backward.

    Drop-in replacement for :func:`torchref.base.targets.angle.angle_math`.

    Parameters
    ----------
    xyz : torch.Tensor
        Atom coordinates as a float32 CUDA tensor. The kernels index it as
        ``3 * atom + k``, i.e. a row-major ``(n_atoms, 3)`` layout.
    idx : torch.Tensor
        Integer tensor of shape ``(N, 3)`` holding the atom indices
        ``(a, b, c)`` of each angle; ``b`` is the vertex atom.
    references_rad : torch.Tensor
        ``N`` ideal angles in radians (compared against ``acos`` output).
    sigmas_rad : torch.Tensor
        ``N`` angle standard deviations in radians.

    Returns
    -------
    torch.Tensor
        0-D tensor with the Gaussian NLL summed over all ``N`` angles.
        Gradients flow back to ``xyz`` only.
    """
    return _AngleMathTriton.apply(xyz, idx, references_rad, sigmas_rad)