torchref.base.targets.triton.planarity module
Triton forward + analytic backward for the planarity target.
The plane normals are computed on the host via a detached SVD (float64,
matching the eager path). That step cannot be expressed in Triton and
already runs without autograd. The Triton forward kernel handles the
per-plane gather, centroid, projection (pos - centroid)·normal, Gaussian
NLL, and sum; the backward kernel scatters the analytic gradient back to xyz.
Backward derivation. For plane p with N_p member atoms, NLL_p = sum_a 0.5 (d_pa / σ_p)^2 + N_p (log σ_p + 0.5 log 2π), where d_pa = (pos_pa - centroid_p) · n_p and n_p is the detached unit plane normal. Because centroid_p depends on every atom in the plane via the mean, ∂d_pa/∂pos_pc = (δ_ac - 1/N_p) n_p, so the gradient w.r.t. each member atom c is
∂NLL_p/∂pos_pc = [d_pc / σ_p² - mean_a(d_pa / σ_p²)] n_p.
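The gradient formula can be checked numerically. The following is a minimal pure-Python sketch, not code from torchref: it holds the normal fixed (mirroring the detached SVD) and compares the analytic expression against central finite differences. The coordinates, the normal, and sigma are made-up illustrative values.

```python
import math


def nll(pos, normal, sigma):
    """Planarity NLL for one plane with the normal held fixed (detached)."""
    k = len(pos)
    centroid = [sum(p[i] for p in pos) / k for i in range(3)]
    d = [sum((p[i] - centroid[i]) * normal[i] for i in range(3)) for p in pos]
    value = sum(0.5 * (x / sigma) ** 2 for x in d)
    value += k * (math.log(sigma) + 0.5 * math.log(2 * math.pi))
    return value, d


def analytic_grad(pos, normal, sigma):
    """[d_pc / sigma^2 - mean_a(d_pa / sigma^2)] * n_p for each member atom c."""
    _, d = nll(pos, normal, sigma)
    w = [x / sigma ** 2 for x in d]
    w_mean = sum(w) / len(w)
    return [[(wc - w_mean) * normal[i] for i in range(3)] for wc in w]


def numeric_grad(pos, normal, sigma, h=1e-6):
    """Central finite differences, one coordinate at a time."""
    grad = []
    for c in range(len(pos)):
        row = []
        for i in range(3):
            plus = [list(p) for p in pos]
            minus = [list(p) for p in pos]
            plus[c][i] += h
            minus[c][i] -= h
            diff = nll(plus, normal, sigma)[0] - nll(minus, normal, sigma)[0]
            row.append(diff / (2 * h))
        grad.append(row)
    return grad


# Illustrative inputs: four atoms and a fixed unit normal (1, 2, 2) / 3.
pos = [[0.0, 0.0, 0.05], [1.0, 0.0, -0.03], [0.0, 1.0, 0.02], [1.0, 1.0, -0.04]]
normal = [1.0 / 3.0, 2.0 / 3.0, 2.0 / 3.0]
sigma = 0.1

ga = analytic_grad(pos, normal, sigma)
gn = numeric_grad(pos, normal, sigma)
max_err = max(abs(a - b) for ra, rb in zip(ga, gn) for a, b in zip(ra, rb))
```

Because the NLL is quadratic in the positions once the normal is fixed, central differences agree with the analytic gradient up to floating-point roundoff. The per-atom gradients also sum to zero over the plane, as expected for a quantity that is invariant under rigid translation.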
- torchref.base.targets.triton.planarity.planarity_math_triton(xyz, plane_groups)
Triton-backed planarity NLL with analytic backward.
Drop-in replacement for
torchref.base.targets.planarity.planarity_math(). SVD-derived plane normals are computed on the host (detached, same as eager); the gather + project + NLL + sum and the gradient scatter run in Triton.