# Source code for torchref.base.targets.planarity

"""Planarity restraint NLL."""

from typing import List, Tuple

import torch

from ._common import LOG_2PI
from ._dispatch import use_triton


def _plane_normals_detached(centered: torch.Tensor) -> torch.Tensor:
    """SVD-derived plane normal (right singular vector at smallest σ).

    Caller must wrap in ``torch.no_grad()``; the result is intentionally
    detached so backward flows only through the deviation projection.

    SVD is run in the input dtype — for small (P, N, 3) matrices over
    O(Å) atom coordinates, float32 is numerically sufficient.
    """
    _U, _S, Vh = torch.linalg.svd(centered, full_matrices=False)
    return Vh[:, -1, :]


def _planarity_math_eager(
    xyz: torch.Tensor,
    plane_groups: List[Tuple[torch.Tensor, torch.Tensor]],
) -> torch.Tensor:
    """Pure-PyTorch planarity NLL, summed over every plane in every bucket.

    For each ``(indices, sigmas)`` bucket the atoms ``xyz[indices]`` are
    centered on their per-plane centroid, a detached plane normal is taken
    from :func:`_plane_normals_detached`, and each atom's signed distance
    to the plane is scored with a Gaussian negative log-likelihood. The
    per-atom terms of all buckets are concatenated and summed.

    ``xyz[indices]`` must be rank 3 (the einsum below indexes it as
    ``(P, n_atoms, 3)``); ``sigmas`` must broadcast against the resulting
    ``(P, n_atoms)`` distances (NOTE(review): exact sigma shape is set by
    the caller — confirm).

    Returns a 0-d tensor; an empty ``plane_groups`` yields ``0`` in the
    dtype and device of ``xyz``.
    """
    if not plane_groups:
        return torch.zeros((), device=xyz.device, dtype=xyz.dtype)

    def bucket_terms(indices: torch.Tensor, sigmas: torch.Tensor) -> torch.Tensor:
        # Flattened per-atom NLL terms for one plane-size bucket.
        coords = xyz[indices]
        offsets = coords - coords.mean(dim=1, keepdim=True)
        # Normals are treated as constants: gradient flows only through
        # the projection of the (differentiable) offsets onto them.
        with torch.no_grad():
            normal = _plane_normals_detached(offsets.detach())
        dist = torch.einsum("paj,pj->pa", offsets, normal)
        z = dist / sigmas
        return (0.5 * z ** 2 + torch.log(sigmas) + 0.5 * LOG_2PI).flatten()

    per_atom = [bucket_terms(idx, sig) for idx, sig in plane_groups]
    return torch.cat(per_atom).sum()


def planarity_math(
    xyz: torch.Tensor,
    plane_groups: List[Tuple[torch.Tensor, torch.Tensor]],
) -> torch.Tensor:
    """Planarity NLL summed over plane-size buckets.

    Matches ``PlanarityTarget.forward``. Each entry of ``plane_groups`` is
    ``(indices, sigmas)`` where ``indices`` has shape ``(P, n_atoms)`` with
    ``n_atoms > 3`` (3-atom planes have zero deviation by construction and
    are skipped by the caller).

    The plane normal is the right singular vector of the centered
    coordinates at the smallest singular value. It is detached, so
    gradients flow through the deviation projection but not through the
    decomposition. On the eager path the SVD is *not* upcast to float64;
    see :func:`_plane_normals_detached` for the dtype it runs in.

    Dispatches to
    :func:`torchref.base.targets.triton.planarity.planarity_math_triton`
    when ``use_triton(xyz)`` is true (documented as CUDA fp32: per-bucket
    fused gather/centroid/project/NLL kernel + analytic backward). Falls
    back to the eager implementation otherwise.

    Returns:
        A 0-d tensor with the summed NLL.
    """
    if use_triton(xyz):
        # Imported lazily so the Triton module is only loaded when used.
        from .triton.planarity import planarity_math_triton

        return planarity_math_triton(xyz, plane_groups)
    return _planarity_math_eager(xyz, plane_groups)