"""Triton forward + analytic backward for the planarity target.
The plane normals are computed on the host via a detached SVD (float64,
matching the eager path) — that step is not Triton-able and already
runs without autograd. The Triton kernel handles the per-plane gather
+ centroid + (pos - centroid)·normal + Gaussian NLL + sum, and the
backward kernel scatters the analytic gradient back to ``xyz``.
Backward derivation. NLL_p = sum_a 0.5 (d_pa / σ_p)^2 + n (log σ_p + 0.5 log 2π)
where d_pa = (pos_pa - centroid_p) · n_p. Because centroid_p depends on
every atom in the plane via the mean, the gradient w.r.t. each member
atom c is
∂NLL_p/∂pos_pc = [d_pc / σ_p² - mean_a(d_pa / σ_p²)] · n_p .
"""
from __future__ import annotations
import math
from typing import List, Tuple
import torch
import triton
import triton.language as tl
_LOG_2PI = float(math.log(2.0 * math.pi))
@triton.jit
def _plan_nll_fwd_kernel(
    xyz_ptr,  # atom coordinates, read as xyz_ptr + i * 3 + k (row-major, 3 per atom)
    indices_ptr,  # (P, N_ATOMS) int, row-major
    normals_ptr,  # (P, 3) float
    sigmas_ptr,  # (P,) float
    out_ptr,  # (P,) per-plane NLL
    P,  # plane count; runtime arg so a new P reuses the compiled kernel
    N_ATOMS: tl.constexpr,  # atoms per plane; constexpr so static_range unrolls
    LOG_2PI: tl.constexpr,
    BLOCK_P: tl.constexpr,
):
    # One lane per plane, BLOCK_P planes per program. Stores
    #   out[p] = sum_a 0.5 (d_pa / σ_p)^2 + N_ATOMS (log σ_p + 0.5 log 2π)
    # with d_pa = (pos_pa - centroid_p) · n_p.
    pid = tl.program_id(0)
    p_offs = pid * BLOCK_P + tl.arange(0, BLOCK_P)
    mask = p_offs < P
    # Masked-out lanes read index 0 / sigma 1.0 so their (unstored)
    # arithmetic stays finite.
    nx = tl.load(normals_ptr + p_offs * 3 + 0, mask=mask, other=0.0)
    ny = tl.load(normals_ptr + p_offs * 3 + 1, mask=mask, other=0.0)
    nz = tl.load(normals_ptr + p_offs * 3 + 2, mask=mask, other=0.0)
    sig = tl.load(sigmas_ptr + p_offs, mask=mask, other=1.0)
    inv_sig = 1.0 / sig
    # 1st pass: centroid
    cx = tl.zeros_like(sig)
    cy = tl.zeros_like(sig)
    cz = tl.zeros_like(sig)
    inv_N = 1.0 / N_ATOMS
    for a in tl.static_range(N_ATOMS):
        i = tl.load(indices_ptr + p_offs * N_ATOMS + a, mask=mask, other=0)
        cx += tl.load(xyz_ptr + i * 3 + 0, mask=mask, other=0.0) * inv_N
        cy += tl.load(xyz_ptr + i * 3 + 1, mask=mask, other=0.0) * inv_N
        cz += tl.load(xyz_ptr + i * 3 + 2, mask=mask, other=0.0) * inv_N
    # 2nd pass: sum of per-atom NLLs
    nll = tl.zeros_like(sig)
    per_atom_const = tl.log(sig) + 0.5 * LOG_2PI
    for a in tl.static_range(N_ATOMS):
        i = tl.load(indices_ptr + p_offs * N_ATOMS + a, mask=mask, other=0)
        x = tl.load(xyz_ptr + i * 3 + 0, mask=mask, other=0.0)
        y = tl.load(xyz_ptr + i * 3 + 1, mask=mask, other=0.0)
        z = tl.load(xyz_ptr + i * 3 + 2, mask=mask, other=0.0)
        # Signed distance of the atom from the plane through the centroid.
        dev = (x - cx) * nx + (y - cy) * ny + (z - cz) * nz
        nll += 0.5 * (dev * inv_sig) * (dev * inv_sig) + per_atom_const
    tl.store(out_ptr + p_offs, nll, mask=mask)
@triton.jit
def _plan_nll_bwd_kernel(
    xyz_ptr,  # atom coordinates, read as xyz_ptr + i * 3 + k (row-major, 3 per atom)
    indices_ptr,  # (P, N_ATOMS) int, row-major
    normals_ptr,  # (P, 3) float
    sigmas_ptr,  # (P,) float
    grad_out_ptr,  # 0-D tensor — loaded in-kernel (no host .item() sync)
    dxyz_ptr,  # gradient buffer, same flat layout as xyz; accumulated into
    P,  # plane count; runtime arg so a new P reuses the compiled kernel
    N_ATOMS: tl.constexpr,  # atoms per plane; constexpr so static_range unrolls
    BLOCK_P: tl.constexpr,
):
    # One lane per plane, BLOCK_P planes per program. For each member atom c
    # of plane p it adds
    #   grad_out * (d_pc / σ_p² - mean_a(d_pa / σ_p²)) * n_p
    # to dxyz[c]. Different lanes/programs may hit the same atom row (an atom
    # can sit in several planes), hence tl.atomic_add.
    pid = tl.program_id(0)
    p_offs = pid * BLOCK_P + tl.arange(0, BLOCK_P)
    mask = p_offs < P
    grad_out = tl.load(grad_out_ptr)
    nx = tl.load(normals_ptr + p_offs * 3 + 0, mask=mask, other=0.0)
    ny = tl.load(normals_ptr + p_offs * 3 + 1, mask=mask, other=0.0)
    nz = tl.load(normals_ptr + p_offs * 3 + 2, mask=mask, other=0.0)
    sig = tl.load(sigmas_ptr + p_offs, mask=mask, other=1.0)
    inv_sig2 = 1.0 / (sig * sig)
    inv_N = 1.0 / N_ATOMS
    # 1st pass: centroid
    cx = tl.zeros_like(sig); cy = tl.zeros_like(sig); cz = tl.zeros_like(sig)
    for a in tl.static_range(N_ATOMS):
        i = tl.load(indices_ptr + p_offs * N_ATOMS + a, mask=mask, other=0)
        cx += tl.load(xyz_ptr + i * 3 + 0, mask=mask, other=0.0) * inv_N
        cy += tl.load(xyz_ptr + i * 3 + 1, mask=mask, other=0.0) * inv_N
        cz += tl.load(xyz_ptr + i * 3 + 2, mask=mask, other=0.0) * inv_N
    # 2nd pass: mean(dev / σ²) over atoms in plane. This is zero in exact
    # arithmetic; computing it keeps the gradient consistent with the
    # floating-point centroid actually used.
    mean_dso = tl.zeros_like(sig)
    for a in tl.static_range(N_ATOMS):
        i = tl.load(indices_ptr + p_offs * N_ATOMS + a, mask=mask, other=0)
        x = tl.load(xyz_ptr + i * 3 + 0, mask=mask, other=0.0)
        y = tl.load(xyz_ptr + i * 3 + 1, mask=mask, other=0.0)
        z = tl.load(xyz_ptr + i * 3 + 2, mask=mask, other=0.0)
        dev = (x - cx) * nx + (y - cy) * ny + (z - cz) * nz
        mean_dso += dev * inv_sig2 * inv_N
    # 3rd pass: scatter gradients
    for a in tl.static_range(N_ATOMS):
        i = tl.load(indices_ptr + p_offs * N_ATOMS + a, mask=mask, other=0)
        x = tl.load(xyz_ptr + i * 3 + 0, mask=mask, other=0.0)
        y = tl.load(xyz_ptr + i * 3 + 1, mask=mask, other=0.0)
        z = tl.load(xyz_ptr + i * 3 + 2, mask=mask, other=0.0)
        dev = (x - cx) * nx + (y - cy) * ny + (z - cz) * nz
        coef = grad_out * (dev * inv_sig2 - mean_dso)
        tl.atomic_add(dxyz_ptr + i * 3 + 0, coef * nx, mask=mask)
        tl.atomic_add(dxyz_ptr + i * 3 + 1, coef * ny, mask=mask)
        tl.atomic_add(dxyz_ptr + i * 3 + 2, coef * nz, mask=mask)
def _plane_normals_via_eigh(
covariances: torch.Tensor,
) -> torch.Tensor:
"""Smallest-eigenvalue eigenvector of a batch of 3×3 SPD covariances.
For an (n_atoms, 3) centered matrix, the right singular vector with
smallest singular value coincides with the eigenvector of smallest
eigenvalue of ``centeredᵀ centered``. Working on the (P, 3, 3)
covariance and using ``eigh`` (faster than ``svd`` for symmetric
inputs, batches well across plane-size buckets) cuts the
plane-normal cost ~5× on A100 vs the per-bucket fp64 SVD on the
full ``centered`` matrix.
Promoted to fp64 for the eigh itself, matching the eager helper's
precision regime — important near collinear-atom degeneracies
where the smallest-eigvec lives in a 2-D ambiguous subspace and
the chosen direction is implementation-defined. The fp64 cost
is still negligible relative to the SVD it replaces.
Robustness — LBFGS line search occasionally probes wild trial steps
that can drive ``xyz`` (and hence the covariance) to NaN / Inf. In
that regime ``eigh`` raises ``torch._C._LinAlgError`` whereas SVD
on the original ``centered`` returns a finite (arbitrary) vector.
To keep the trial step well-defined we catch the LinAlgError and
return zero normals; the resulting deviations are zero, the NLL
contribution is finite (``per-atom-const · n_atoms``), the gradient
w.r.t. atom positions is exactly zero (the trial step gets a
finite loss with no pull from planarity → validate_loss /
line-search reject it on its own). Bit-perfect collinearity, by
contrast, doesn't trigger this path — both ``eigh`` and ``svd``
return a finite unit vector in the ambiguous subspace.
"""
in_dtype = covariances.dtype
try:
_vals, vecs = torch.linalg.eigh(covariances.to(torch.float64))
normals = vecs[..., 0].to(in_dtype)
except torch._C._LinAlgError:
# Pathological input (NaN / Inf / non-convergent). Return zero
# normals — finite NLL, zero gradient, trial step rejected.
normals = torch.zeros(
covariances.shape[0], 3, dtype=in_dtype, device=covariances.device,
)
return normals
# Even when eigh succeeds it can emit NaN rows if a per-batch input
# was finite but contained subnormals or extreme range. Same fix:
# zero-out non-finite rows. We do the ``torch.where`` unconditionally
# — skipping the Python ``bool(...item())`` guard so this function
# is CUDA-Graph-capture-safe. The where is a single elementwise
# kernel that's a no-op on the (overwhelmingly common) all-finite
# path; cheaper than the previous host sync anyway.
finite = torch.isfinite(normals).all(dim=-1, keepdim=True)
normals = torch.where(finite, normals, torch.zeros_like(normals))
return normals
def _contiguous_nonempty_groups(plane_groups):
    """Drop zero-plane buckets and make every bucket tensor contiguous.

    The kernels address ``indices`` as a row-major (P, n) buffer and
    ``sigmas`` as a flat (P,) buffer, so both must be contiguous. Buckets
    with no planes contribute nothing to the NLL and are skipped.
    """
    return [
        (indices.contiguous(), sigmas.contiguous())
        for indices, sigmas in plane_groups
        if indices.shape[0] > 0
    ]


def _bucket_plane_normals(xyz, groups):
    """Per-bucket plane normals from a single batched eigh over all planes.

    Returns contiguous (P_b, 3) tensors in ``xyz.dtype``, aligned with
    ``groups``. Stacking every bucket's covariances into one ``eigh``
    call replaces per-bucket SVDs (5.4 ms → ~0.7 ms on 1DAW / A100 per
    the original measurement).
    """
    covs = []
    for indices, _sigmas in groups:
        positions = xyz[indices]  # (P, n, 3)
        centered = positions - positions.mean(dim=1, keepdim=True)
        covs.append(centered.transpose(-1, -2) @ centered)  # (P, 3, 3)
    all_normals = _plane_normals_via_eigh(torch.cat(covs, dim=0)).to(xyz.dtype)
    sizes = [cov.shape[0] for cov in covs]
    return [part.contiguous() for part in torch.split(all_normals, sizes, dim=0)]


class _PlanarityMathTriton(torch.autograd.Function):
    """Triton planarity NLL (forward) with an analytic gradient (backward).

    ``forward(xyz, plane_groups)`` returns the 0-D sum of per-plane NLLs
    over every ``(indices, sigmas)`` bucket. ``backward`` returns the
    gradient w.r.t. ``xyz`` and ``None`` for ``plane_groups``. Plane
    normals are treated as constants.
    """

    # Planes handled per Triton program, shared by both kernel launches.
    _BLOCK_P = 64

    @staticmethod
    def forward(ctx, xyz, plane_groups):
        assert xyz.is_cuda and xyz.dtype == torch.float32
        # The kernels read xyz with flat i * 3 + k offsets.
        xyz = xyz.contiguous()
        groups = _contiguous_nonempty_groups(plane_groups)
        ctx.n_groups = len(groups)
        if not groups:
            # Checked before any torch.cat, which would fail on an empty list.
            ctx.save_for_backward(xyz)
            return torch.zeros((), device=xyz.device, dtype=xyz.dtype)
        block_p = _PlanarityMathTriton._BLOCK_P
        with torch.no_grad():
            normals_per_bucket = _bucket_plane_normals(xyz, groups)
            outs = []
            for (indices, sigmas), normals in zip(groups, normals_per_bucket):
                n_planes, n_atoms = indices.shape[0], indices.shape[1]
                out = torch.empty(n_planes, dtype=xyz.dtype, device=xyz.device)
                _plan_nll_fwd_kernel[(triton.cdiv(n_planes, block_p),)](
                    xyz, indices, normals, sigmas, out,
                    P=n_planes, N_ATOMS=int(n_atoms),
                    LOG_2PI=_LOG_2PI, BLOCK_P=block_p,
                )
                outs.append(out)
        # Everything backward reads goes through save_for_backward, so
        # autograd's in-place version checks cover it. Layout:
        # [xyz, idx_0, sig_0, idx_1, sig_1, ..., normals_0, normals_1, ...]
        ctx.save_for_backward(
            xyz, *(t for pair in groups for t in pair), *normals_per_bucket,
        )
        return torch.cat(outs).sum()

    @staticmethod
    def backward(ctx, grad_out):
        saved = ctx.saved_tensors
        xyz = saved[0]
        n_groups = ctx.n_groups
        bucket_tensors = saved[1:1 + 2 * n_groups]
        normals_per_bucket = saved[1 + 2 * n_groups:]
        # Fresh contiguous buffer: the kernel scatters with flat i * 3 + k
        # offsets, so it must not inherit a non-contiguous input's strides.
        dxyz = torch.zeros(xyz.shape, dtype=xyz.dtype, device=xyz.device)
        block_p = _PlanarityMathTriton._BLOCK_P
        for indices, sigmas, normals in zip(
            bucket_tensors[0::2], bucket_tensors[1::2], normals_per_bucket
        ):
            n_planes, n_atoms = indices.shape[0], indices.shape[1]
            # grad_out is a 0-D device tensor; passed by pointer.
            _plan_nll_bwd_kernel[(triton.cdiv(n_planes, block_p),)](
                xyz, indices, normals, sigmas, grad_out, dxyz,
                P=n_planes, N_ATOMS=int(n_atoms), BLOCK_P=block_p,
            )
        return dxyz, None
def planarity_math_triton(
    xyz: torch.Tensor,
    plane_groups: List[Tuple[torch.Tensor, torch.Tensor]],
) -> torch.Tensor:
    """Triton-backed planarity NLL with analytic backward.

    Drop-in replacement for
    :func:`torchref.base.targets.planarity.planarity_math`. Plane normals
    are computed without autograd as the smallest-eigenvalue eigenvectors
    of the per-plane covariances (one batched float64 ``eigh``). The
    gather + project + NLL + sum and the gradient scatter run in Triton.

    Args:
        xyz: atom coordinates. Must be a CUDA float32 tensor whenever
            ``plane_groups`` is non-empty.
        plane_groups: buckets of ``(indices, sigmas)``. ``indices`` holds the
            atom rows of each plane in a bucket (all planes in a bucket share
            the same atom count); ``sigmas`` holds one sigma per plane.

    Returns:
        0-D tensor with the summed planarity NLL. It is a detached zero,
        in ``xyz``'s dtype and device, when ``plane_groups`` is empty.
    """
    if not plane_groups:
        return torch.zeros((), device=xyz.device, dtype=xyz.dtype)
    return _PlanarityMathTriton.apply(xyz, plane_groups)