torchref.base.targets.planarity module
Planarity restraint NLL.
- torchref.base.targets.planarity.planarity_math(xyz, plane_groups)
Planarity NLL summed over plane-size buckets.
Matches the output of PlanarityTarget.forward. Each entry of plane_groups is (indices, sigmas), where indices has shape (P, n_atoms) with n_atoms > 3 (3-atom planes have zero deviation by construction and are skipped by the caller). The plane normal is computed in float64 via SVD and detached: gradients flow through the deviation projection but not through the SVD.
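A minimal eager sketch of the computation described above, for illustration only. It assumes that sigmas is either one sigma per plane, shape (P,), or one per atom, shape (P, n_atoms), and that the NLL is the Gaussian term 0.5 * (d / sigma)^2 with constant log-sigma terms dropped. The name planarity_nll_sketch is hypothetical and not part of torchref; the actual function may differ in these details.

```python
import torch


def planarity_nll_sketch(xyz: torch.Tensor, plane_groups) -> torch.Tensor:
    """Hypothetical eager planarity NLL summed over plane-size buckets.

    xyz: (N, 3) atom coordinates.
    plane_groups: iterable of (indices, sigmas); indices is a (P, n_atoms)
    long tensor. sigmas is ASSUMED to be (P,) or (P, n_atoms).
    """
    total = xyz.new_zeros(())
    for indices, sigmas in plane_groups:
        pts = xyz[indices]                        # (P, n, 3) gathered atoms
        centroid = pts.mean(dim=1, keepdim=True)  # (P, 1, 3) per-plane centroid
        centered = pts - centroid                 # (P, n, 3)
        # Normal = right singular vector with the smallest singular value,
        # computed in float64 on a detached copy, so no gradient flows
        # through the SVD itself.
        vh = torch.linalg.svd(centered.detach().double(), full_matrices=False).Vh
        normal = vh[:, -1, :].to(xyz.dtype)       # (P, 3)
        # Signed distance of each atom from its best-fit plane.
        dev = (centered * normal.unsqueeze(1)).sum(dim=-1)  # (P, n)
        # Broadcast per-plane sigmas (P,) to (P, 1); per-atom sigmas pass through.
        sig = sigmas if sigmas.dim() == 2 else sigmas.unsqueeze(1)
        total = total + 0.5 * ((dev / sig) ** 2).sum()
    return total
```

Because the normal is detached, the gradient on atom j reduces to (d_j / sigma^2) times the normal, and since the deviations of a centered plane sum to zero, the per-plane gradients sum to zero: a rigid translation of the plane leaves the loss unchanged.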
Dispatches to torchref.base.targets.triton.planarity.planarity_math_triton() for CUDA fp32 inputs (a per-bucket fused gather/centroid/project/NLL kernel with an analytic backward). Falls back to the eager implementation otherwise.
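The dispatch rule above can be expressed as a simple predicate. This is a sketch under assumptions: use_triton_path is a hypothetical helper, and the check that Triton is importable is an assumed extra guard not stated in the docstring; the real dispatch may test other conditions as well.

```python
import importlib.util

import torch


def use_triton_path(xyz: torch.Tensor) -> bool:
    """Hypothetical dispatch check: take the fused Triton path only for
    CUDA float32 inputs (ASSUMED: and only when Triton is importable);
    otherwise fall back to the eager implementation."""
    triton_available = importlib.util.find_spec("triton") is not None
    return xyz.is_cuda and xyz.dtype == torch.float32 and triton_available
```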