torchref.base.targets.planarity module

Planarity restraint NLL.

torchref.base.targets.planarity.planarity_math(xyz, plane_groups)

Planarity NLL summed over plane-size buckets.

Matches the computation in PlanarityTarget.forward. Each entry of plane_groups is a pair (indices, sigmas), where indices has shape (P, n_atoms) with n_atoms > 3. Three-atom planes are skipped by the caller because three points always lie exactly on a plane, so their deviation is zero by construction.
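The bucket layout can be illustrated by packing planes of different sizes into plane_groups, one (indices, sigmas) pair per atom count. This is a minimal sketch: the helper name build_plane_groups, its input format (a list of per-plane atom indices and per-atom sigmas), and the assumption that sigmas has shape (P, n_atoms) are hypothetical and not part of the torchref API.

```python
from collections import defaultdict

import torch


def build_plane_groups(planes):
    """Hypothetical helper: bucket variable-size planes by atom count.

    planes: list of (atom_indices, atom_sigmas) pairs, one per plane.
    Returns a list of (indices, sigmas) tensors, each of shape (P, n_atoms),
    ordered by increasing n_atoms.
    """
    buckets = defaultdict(lambda: ([], []))
    for atom_indices, atom_sigmas in planes:
        n_atoms = len(atom_indices)
        if n_atoms <= 3:
            # Three or fewer points always fit a plane exactly: zero deviation.
            continue
        idx_list, sig_list = buckets[n_atoms]
        idx_list.append(list(atom_indices))
        sig_list.append(list(atom_sigmas))
    return [
        (torch.tensor(idx, dtype=torch.long), torch.tensor(sig, dtype=torch.float32))
        for _, (idx, sig) in sorted(buckets.items())
    ]
```

Grouping by size keeps every bucket a dense rectangular tensor, which lets each bucket be processed with batched operations instead of a per-plane loop.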

The plane normal is computed in float64 via SVD and then detached, so gradients flow through the deviation projection but not through the decomposition itself.
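A minimal eager sketch of this scheme follows. The best-fit plane normal is the right singular vector of the centered coordinates with the smallest singular value; it is computed on a detached float64 copy, so only the projection onto the normal carries gradients. The function name planarity_nll_eager and the NLL form 0.5 * (d / sigma)^2 summed over atoms (no normalization constant) are assumptions; the actual PlanarityTarget may use a different form.

```python
import torch


def planarity_nll_eager(xyz, plane_groups):
    """Hypothetical eager planarity NLL, summed over plane-size buckets.

    xyz: (N, 3) coordinates. plane_groups: list of (indices, sigmas) with
    indices (P, n_atoms) long and sigmas broadcastable to (P, n_atoms).
    Assumed NLL form: 0.5 * (deviation / sigma) ** 2, summed over atoms.
    """
    total = xyz.new_zeros(())
    for indices, sigmas in plane_groups:
        pts = xyz[indices]                                  # (P, n, 3)
        centered = pts - pts.mean(dim=1, keepdim=True)      # subtract centroid
        # Normal from a detached float64 SVD: no gradient through the decomposition.
        _, _, vh = torch.linalg.svd(
            centered.detach().to(torch.float64), full_matrices=False
        )
        normal = vh[:, -1, :].to(xyz.dtype)                 # smallest singular value
        # Signed distance of each atom from the best-fit plane (differentiable).
        dev = (centered * normal.unsqueeze(1)).sum(dim=-1)  # (P, n)
        total = total + 0.5 * ((dev / sigmas) ** 2).sum()
    return total
```

Because the normal is treated as a constant, the gradient of each deviation with respect to the coordinates is simply the normal (minus its centroid share), which avoids differentiating through the SVD when singular values are degenerate.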

On CUDA with fp32 inputs, dispatches to torchref.base.targets.triton.planarity.planarity_math_triton(), which runs a fused per-bucket kernel (gather, centroid, projection, and NLL) with an analytic backward pass. Otherwise, falls back to the eager implementation.
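The dispatch rule can be sketched as a simple predicate on the input tensor. The names use_fused_path and dispatch_planarity are hypothetical, and the fused and eager functions are passed in as parameters rather than imported; the real dispatcher may check additional conditions (for example, whether Triton is available).

```python
from typing import Callable, Sequence, Tuple

import torch

PlaneGroups = Sequence[Tuple[torch.Tensor, torch.Tensor]]


def use_fused_path(xyz: torch.Tensor) -> bool:
    """Hypothetical predicate for the documented rule: CUDA tensors in float32."""
    return xyz.is_cuda and xyz.dtype == torch.float32


def dispatch_planarity(
    xyz: torch.Tensor,
    plane_groups: PlaneGroups,
    fused_fn: Callable[[torch.Tensor, PlaneGroups], torch.Tensor],
    eager_fn: Callable[[torch.Tensor, PlaneGroups], torch.Tensor],
) -> torch.Tensor:
    # Route qualifying inputs to the fused kernel; everything else runs eagerly.
    if use_fused_path(xyz):
        return fused_fn(xyz, plane_groups)
    return eager_fn(xyz, plane_groups)
```

Keeping the eager path as the general fallback means CPU tensors and float64 inputs still produce a result, at the cost of the separate gather, centroid, and projection steps that the fused kernel combines.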