# Source code for torchref.base.targets.triton.chiral

"""Triton forward + backward for the chiral-volume target."""

from __future__ import annotations

import math

import torch
import triton
import triton.language as tl
from triton.language.extra import libdevice


# ln(2*pi): the constant term of the Gaussian negative log-likelihood,
# passed to the forward kernel as a compile-time constant.
_LOG_2PI = math.log(2.0 * math.pi)


@triton.jit
def _chiral_nll_fwd_kernel(
    xyz_ptr,        # flat coordinates, addressed as atom_index * 3 + {0, 1, 2}
    idx_ptr,        # (N, 4) [center, a1, a2, a3]
    ideal_ptr,      # (N,)
    sig_ptr,        # (N,)
    out_ptr,        # (N,)
    N,              # number of rows; runtime scalar so a new N does not recompile
    LOG_2PI: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Per-row Gaussian NLL of the chiral volume.

    For each row the signed volume is V = v1 · (v2 × v3) with vi = pi − pc,
    where pc is the first index of the row and p1..p3 are the other three.

    Rows whose ideal volume is exactly 0 are handled sign-agnostically:
    |V| is compared against a fixed target of 2.5 instead of V against the
    stored ideal value.

    Each program handles ``BLOCK`` consecutive rows and writes
        0.5 * (dev / σ)² + log(σ) + 0.5 * LOG_2PI
    to ``out_ptr``; lanes past ``N`` are masked out of every load and store.
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N

    # Atom indices of the row; masked lanes fall back to atom 0 but never
    # touch memory because the coordinate loads below are masked as well.
    c0 = tl.load(idx_ptr + offs * 4 + 0, mask=mask, other=0)
    i1 = tl.load(idx_ptr + offs * 4 + 1, mask=mask, other=0)
    i2 = tl.load(idx_ptr + offs * 4 + 2, mask=mask, other=0)
    i3 = tl.load(idx_ptr + offs * 4 + 3, mask=mask, other=0)

    # Center position.
    pcx = tl.load(xyz_ptr + c0 * 3 + 0, mask=mask, other=0.0)
    pcy = tl.load(xyz_ptr + c0 * 3 + 1, mask=mask, other=0.0)
    pcz = tl.load(xyz_ptr + c0 * 3 + 2, mask=mask, other=0.0)

    # Neighbour positions relative to the center.
    v1x = tl.load(xyz_ptr + i1 * 3 + 0, mask=mask, other=0.0) - pcx
    v1y = tl.load(xyz_ptr + i1 * 3 + 1, mask=mask, other=0.0) - pcy
    v1z = tl.load(xyz_ptr + i1 * 3 + 2, mask=mask, other=0.0) - pcz
    v2x = tl.load(xyz_ptr + i2 * 3 + 0, mask=mask, other=0.0) - pcx
    v2y = tl.load(xyz_ptr + i2 * 3 + 1, mask=mask, other=0.0) - pcy
    v2z = tl.load(xyz_ptr + i2 * 3 + 2, mask=mask, other=0.0) - pcz
    v3x = tl.load(xyz_ptr + i3 * 3 + 0, mask=mask, other=0.0) - pcx
    v3y = tl.load(xyz_ptr + i3 * 3 + 1, mask=mask, other=0.0) - pcy
    v3z = tl.load(xyz_ptr + i3 * 3 + 2, mask=mask, other=0.0) - pcz

    # cross23 = v2 x v3
    cx = v2y * v3z - v2z * v3y
    cy = v2z * v3x - v2x * v3z
    cz = v2x * v3y - v2y * v3x
    # Scalar triple product v1 · (v2 x v3).
    vol = v1x * cx + v1y * cy + v1z * cz

    ideal = tl.load(ideal_ptr + offs, mask=mask, other=0.0)
    # other=1.0 keeps log(sig) and the division finite on masked lanes.
    sig = tl.load(sig_ptr + offs, mask=mask, other=1.0)

    achiral = ideal == 0.0
    eff_ideal = tl.where(achiral, 2.5, ideal)
    eff_vol = tl.where(achiral, libdevice.abs(vol), vol)
    dev = eff_vol - eff_ideal

    nll = 0.5 * (dev / sig) * (dev / sig) + tl.log(sig) + 0.5 * LOG_2PI
    tl.store(out_ptr + offs, nll, mask=mask)


@triton.jit
def _chiral_nll_bwd_kernel(
    xyz_ptr,       # flat coordinates, addressed as atom_index * 3 + {0, 1, 2}
    idx_ptr,       # (N, 4) [center, a1, a2, a3]
    ideal_ptr,     # (N,)
    sig_ptr,       # (N,)
    grad_out_ptr,  # 0-D tensor — loaded in-kernel (no host .item() sync)
    dxyz_ptr,      # same flat layout as xyz; accumulated into with atomics
    N,             # number of rows; runtime scalar so a new N does not recompile
    BLOCK: tl.constexpr,
):
    """Gradient of the chiral NLL.

    V = v1·(v2 × v3), vi = pi − pc. Then
        ∂V/∂p1 = v2 × v3
        ∂V/∂p2 = v3 × v1
        ∂V/∂p3 = v1 × v2
        ∂V/∂pc = −(∂V/∂p1 + ∂V/∂p2 + ∂V/∂p3)

    eff_vol = |V|   if achiral else V    ⇒ ∂eff_vol/∂V = sign(V) or 1
    dNLL/d(eff_vol) = (eff_vol − eff_ideal) / σ²

    ``grad_out_ptr`` is a 0-D tensor pointer; we load it once per block
    so the host doesn't need a ``.item()`` synchronize.

    Contributions are added into ``dxyz_ptr`` with ``tl.atomic_add``, so the
    kernel only accumulates; it never clears the buffer itself.
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    grad_out = tl.load(grad_out_ptr)

    c0 = tl.load(idx_ptr + offs * 4 + 0, mask=mask, other=0)
    i1 = tl.load(idx_ptr + offs * 4 + 1, mask=mask, other=0)
    i2 = tl.load(idx_ptr + offs * 4 + 2, mask=mask, other=0)
    i3 = tl.load(idx_ptr + offs * 4 + 3, mask=mask, other=0)

    # Recompute the forward geometry instead of reading saved intermediates.
    pcx = tl.load(xyz_ptr + c0 * 3 + 0, mask=mask, other=0.0)
    pcy = tl.load(xyz_ptr + c0 * 3 + 1, mask=mask, other=0.0)
    pcz = tl.load(xyz_ptr + c0 * 3 + 2, mask=mask, other=0.0)

    v1x = tl.load(xyz_ptr + i1 * 3 + 0, mask=mask, other=0.0) - pcx
    v1y = tl.load(xyz_ptr + i1 * 3 + 1, mask=mask, other=0.0) - pcy
    v1z = tl.load(xyz_ptr + i1 * 3 + 2, mask=mask, other=0.0) - pcz
    v2x = tl.load(xyz_ptr + i2 * 3 + 0, mask=mask, other=0.0) - pcx
    v2y = tl.load(xyz_ptr + i2 * 3 + 1, mask=mask, other=0.0) - pcy
    v2z = tl.load(xyz_ptr + i2 * 3 + 2, mask=mask, other=0.0) - pcz
    v3x = tl.load(xyz_ptr + i3 * 3 + 0, mask=mask, other=0.0) - pcx
    v3y = tl.load(xyz_ptr + i3 * 3 + 1, mask=mask, other=0.0) - pcy
    v3z = tl.load(xyz_ptr + i3 * 3 + 2, mask=mask, other=0.0) - pcz

    # v2 x v3, reused both for the volume and as ∂V/∂p1.
    cx23 = v2y * v3z - v2z * v3y
    cy23 = v2z * v3x - v2x * v3z
    cz23 = v2x * v3y - v2y * v3x
    vol = v1x * cx23 + v1y * cy23 + v1z * cz23

    ideal = tl.load(ideal_ptr + offs, mask=mask, other=0.0)
    # other=1.0 avoids a division by zero on masked lanes.
    sig = tl.load(sig_ptr + offs, mask=mask, other=1.0)

    achiral = ideal == 0.0
    eff_ideal = tl.where(achiral, 2.5, ideal)
    eff_vol = tl.where(achiral, libdevice.abs(vol), vol)
    dev = eff_vol - eff_ideal

    # d eff_vol / dV  ∈ {sign(vol), 1.0}; vol == 0 takes the +1 branch.
    sign_v = tl.where(vol >= 0.0, 1.0, -1.0)
    dEff_dV = tl.where(achiral, sign_v, 1.0)
    coef = grad_out * dev / (sig * sig) * dEff_dV  # = dL/dV per chiral center

    # ∂V/∂p1 = v2 x v3 = (cx23, cy23, cz23)
    g1x = coef * cx23
    g1y = coef * cy23
    g1z = coef * cz23
    # ∂V/∂p2 = v3 x v1
    g2x = coef * (v3y * v1z - v3z * v1y)
    g2y = coef * (v3z * v1x - v3x * v1z)
    g2z = coef * (v3x * v1y - v3y * v1x)
    # ∂V/∂p3 = v1 x v2
    g3x = coef * (v1y * v2z - v1z * v2y)
    g3y = coef * (v1z * v2x - v1x * v2z)
    g3z = coef * (v1x * v2y - v1y * v2x)

    # Atomics: the same atom index can appear in several rows (and blocks).
    tl.atomic_add(dxyz_ptr + i1 * 3 + 0, g1x, mask=mask)
    tl.atomic_add(dxyz_ptr + i1 * 3 + 1, g1y, mask=mask)
    tl.atomic_add(dxyz_ptr + i1 * 3 + 2, g1z, mask=mask)
    tl.atomic_add(dxyz_ptr + i2 * 3 + 0, g2x, mask=mask)
    tl.atomic_add(dxyz_ptr + i2 * 3 + 1, g2y, mask=mask)
    tl.atomic_add(dxyz_ptr + i2 * 3 + 2, g2z, mask=mask)
    tl.atomic_add(dxyz_ptr + i3 * 3 + 0, g3x, mask=mask)
    tl.atomic_add(dxyz_ptr + i3 * 3 + 1, g3y, mask=mask)
    tl.atomic_add(dxyz_ptr + i3 * 3 + 2, g3z, mask=mask)
    # center: -(g1 + g2 + g3)
    tl.atomic_add(dxyz_ptr + c0 * 3 + 0, -(g1x + g2x + g3x), mask=mask)
    tl.atomic_add(dxyz_ptr + c0 * 3 + 1, -(g1y + g2y + g3y), mask=mask)
    tl.atomic_add(dxyz_ptr + c0 * 3 + 2, -(g1z + g2z + g3z), mask=mask)


class _ChiralMathTriton(torch.autograd.Function):
    """Autograd bridge between PyTorch and the chiral NLL Triton kernels.

    ``forward`` launches the forward kernel and returns the sum of the
    per-row NLL values as a 0-D tensor. ``backward`` launches the backward
    kernel and returns a gradient for ``xyz`` only; ``indices``,
    ``ideal_volumes`` and ``sigmas`` receive ``None``.
    """

    # Rows handled per Triton program; must match between launch grid and
    # the kernels' BLOCK constexpr.
    _BLOCK = 256

    @staticmethod
    def forward(ctx, xyz, indices, ideal_volumes, sigmas):
        """Return the summed chiral NLL.

        Args:
            xyz: CUDA float32 coordinates, addressed by the kernel as
                ``atom * 3 + axis``.
            indices: ``(N, 4)`` atom indices ``[center, a1, a2, a3]``.
            ideal_volumes: ``(N,)`` ideal volumes; 0 selects the
                sign-agnostic branch of the kernel.
            sigmas: ``(N,)`` standard deviations.

        Raises:
            AssertionError: if ``xyz`` is not a CUDA float32 tensor.
        """
        assert xyz.is_cuda and xyz.dtype == torch.float32, (
            "xyz must be a CUDA float32 tensor"
        )
        # The kernels do raw pointer arithmetic that assumes densely packed
        # row-major storage. ``contiguous()`` returns the tensor itself when
        # that already holds, so the common path is unchanged; strided views
        # would otherwise be read at the wrong offsets without any error.
        xyz = xyz.contiguous()
        indices = indices.contiguous()
        ideal_volumes = ideal_volumes.contiguous()
        sigmas = sigmas.contiguous()

        N = indices.shape[0]
        nll = torch.empty(N, dtype=xyz.dtype, device=xyz.device)
        BLOCK = _ChiralMathTriton._BLOCK
        if N > 0:  # nothing to launch for an empty restraint list
            grid = (triton.cdiv(N, BLOCK),)
            _chiral_nll_fwd_kernel[grid](
                xyz, indices, ideal_volumes, sigmas, nll,
                N=N, LOG_2PI=_LOG_2PI, BLOCK=BLOCK,
            )
        # Save the contiguous versions so backward sees the same layout.
        ctx.save_for_backward(xyz, indices, ideal_volumes, sigmas)
        return nll.sum()

    @staticmethod
    def backward(ctx, grad_out):
        """Return ``(d loss / d xyz, None, None, None)``.

        ``grad_out`` is handed to the kernel as a tensor and read on the
        device, so no host synchronisation is needed here.
        """
        xyz, idx, ideal, sigs = ctx.saved_tensors
        N = idx.shape[0]
        # Must start at zero (the kernel only accumulates) and must be
        # contiguous (the kernel writes at ``atom * 3 + axis``).
        dxyz = torch.zeros_like(xyz, memory_format=torch.contiguous_format)
        BLOCK = _ChiralMathTriton._BLOCK
        if N > 0:
            grid = (triton.cdiv(N, BLOCK),)
            _chiral_nll_bwd_kernel[grid](
                xyz, idx, ideal, sigs, grad_out.contiguous(), dxyz,
                N=N, BLOCK=BLOCK,
            )
        return dxyz, None, None, None


def chiral_math_triton(xyz, indices, ideal_volumes, sigmas):
    """Triton-backed chiral-volume Gaussian NLL with analytic backward.

    Thin functional wrapper around ``_ChiralMathTriton.apply``.

    Args:
        xyz: Atom coordinates; must be a CUDA float32 tensor.
        indices: ``(N, 4)`` atom indices ``[center, a1, a2, a3]`` per
            chiral center.
        ideal_volumes: ``(N,)`` ideal chiral volumes.
        sigmas: ``(N,)`` standard deviations.

    Returns:
        A 0-D tensor holding the NLL summed over all ``N`` chiral centers.
        Gradients flow back to ``xyz`` only.
    """
    return _ChiralMathTriton.apply(xyz, indices, ideal_volumes, sigmas)