torchref.base.targets.triton.planarity module

Triton forward + analytic backward for the planarity target.

The plane normals are computed on the host via a detached SVD (float64, matching the eager path). That step cannot be expressed in Triton and already runs without autograd. The Triton forward kernel handles the per-plane gather, centroid, projection (pos - centroid)·normal, Gaussian NLL, and sum; the backward kernel scatters the analytic gradient back to xyz.

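Backward derivation. For a plane p with n member atoms, NLL_p = sum_a 0.5 (d_pa / σ_p)^2 + n (log σ_p + 0.5 log 2π), where d_pa = (pos_pa - centroid_p) · n_p and n_p is the detached plane normal, treated as a constant (n without a subscript is the atom count, not the normal). Because centroid_p depends on every atom in the plane via the mean, ∂d_pa/∂pos_pc = (δ_ac - 1/n) n_p, so the gradient w.r.t. each member atom c is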

∂NLL_p/∂pos_pc = [d_pc / σ_p² - mean_a(d_pa / σ_p²)] · n_p.
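
The gradient formula above can be checked numerically. The following is a minimal plain-Python sketch, not the Triton kernel: the helper names (plane_nll, plane_nll_grad, finite_difference_grad) and the sample coordinates are hypothetical, and the normal is simply supplied as a fixed vector to mimic the detached SVD result rather than being fitted to the atoms.

```python
import math

def plane_nll(pos, normal, sigma):
    """Gaussian NLL of the signed distances to the plane through the centroid."""
    n = len(pos)
    centroid = [sum(p[k] for p in pos) / n for k in range(3)]
    # d_a = (pos_a - centroid) . normal
    d = [sum((p[k] - centroid[k]) * normal[k] for k in range(3)) for p in pos]
    nll = sum(0.5 * (di / sigma) ** 2 for di in d)
    nll += n * (math.log(sigma) + 0.5 * math.log(2.0 * math.pi))
    return nll, d

def plane_nll_grad(pos, normal, sigma):
    """Analytic gradient: [d_c / sigma^2 - mean_a(d_a / sigma^2)] * normal."""
    _, d = plane_nll(pos, normal, sigma)
    w = [di / sigma ** 2 for di in d]
    w_mean = sum(w) / len(w)
    return [[(wi - w_mean) * normal[k] for k in range(3)] for wi in w]

def finite_difference_grad(pos, normal, sigma, eps=1e-6):
    """Central-difference gradient, holding the normal constant."""
    grad = []
    for c in range(len(pos)):
        row = []
        for k in range(3):
            plus = [list(p) for p in pos]
            minus = [list(p) for p in pos]
            plus[c][k] += eps
            minus[c][k] -= eps
            f_plus = plane_nll(plus, normal, sigma)[0]
            f_minus = plane_nll(minus, normal, sigma)[0]
            row.append((f_plus - f_minus) / (2.0 * eps))
        grad.append(row)
    return grad

# Four nearly coplanar atoms; the normal is assumed given (stand-in for the detached SVD).
pos = [[0.0, 0.0, 0.10], [1.0, 0.0, -0.05], [0.0, 1.0, 0.02], [1.0, 1.0, -0.08]]
normal = [0.0, 0.0, 1.0]
sigma = 0.02

analytic = plane_nll_grad(pos, normal, sigma)
numeric = finite_difference_grad(pos, normal, sigma)
max_err = max(abs(a - b) for ra, rb in zip(analytic, numeric) for a, b in zip(ra, rb))
```

Because the centroid moves with every atom, the per-atom gradients of one plane sum to zero. This is the role of the mean_a term in the formula.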

torchref.base.targets.triton.planarity.planarity_math_triton(xyz, plane_groups)

Triton-backed planarity NLL with analytic backward.

Drop-in replacement for torchref.base.targets.planarity.planarity_math(). SVD-derived plane normals are computed on the host (detached, same as eager); the gather + project + NLL + sum and the gradient scatter run in Triton.