torchref.base.targets.chiral module

Chiral-volume restraint NLL.

torchref.base.targets.chiral.chiral_math(xyz, indices, ideal_volumes, sigmas)

Chiral-volume negative log-likelihood (NLL); matches ChiralTarget.forward.

For each chiral center, computes the signed tetrahedral volume V = v1 · (v2 × v3), where vi = xyz[ai] - xyz[center] for i = 1, 2, 3. Achiral centers (entries with ideal_volumes == 0) are restrained on the magnitude |V| against a fixed target of 2.5 instead of on the signed volume.
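The volume formula and the achiral rule above can be illustrated with a small plain-Python sketch. It is not the torch implementation. The helper names (signed_chiral_volume, restrained_pair) and the ACHIRAL_TARGET constant are hypothetical and exist only for this example. Swapping two substituents flips the sign of V, which is what makes the signed volume encode handedness.

```python
# Plain-Python sketch of the chiral-volume formula (not the torchref code).
# Hypothetical helpers; ACHIRAL_TARGET mirrors the 2.5 quoted in the docstring.
from typing import Sequence, Tuple

Vec3 = Tuple[float, float, float]

ACHIRAL_TARGET = 2.5  # target for |V| when the ideal volume is 0


def sub(a: Sequence[float], b: Sequence[float]) -> Vec3:
    """Component-wise a - b."""
    return (a[0] - b[0], a[1] - b[1], a[2] - b[2])


def cross(a: Vec3, b: Vec3) -> Vec3:
    """3D cross product a x b."""
    return (
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    )


def dot(a: Vec3, b: Vec3) -> float:
    """3D dot product a . b."""
    return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]


def signed_chiral_volume(xyz, center: int, a1: int, a2: int, a3: int) -> float:
    """V = v1 . (v2 x v3) with vi = xyz[ai] - xyz[center]."""
    c = xyz[center]
    v1, v2, v3 = sub(xyz[a1], c), sub(xyz[a2], c), sub(xyz[a3], c)
    return dot(v1, cross(v2, v3))


def restrained_pair(volume: float, ideal: float) -> Tuple[float, float]:
    """Return (model value, target) that the restraint compares."""
    if ideal == 0.0:
        # Achiral entry: compare the magnitude |V| with the fixed target.
        return abs(volume), ACHIRAL_TARGET
    return volume, ideal


xyz = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0)]
print(signed_chiral_volume(xyz, 0, 1, 2, 3))  # 1.0
print(signed_chiral_volume(xyz, 0, 1, 3, 2))  # -1.0 (a2 and a3 swapped)
```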

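When the inputs are float32 CUDA tensors, dispatches to torchref.base.targets.triton.chiral_math_triton(), which is about 4× faster for the combined forward and backward pass on an A100. Otherwise falls back to the eager PyTorch implementation.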

Parameters:
  • xyz (torch.Tensor) – (N_atoms, 3) Cartesian coordinates.

  • indices (torch.Tensor) – (N, 4) integer indices [center, a1, a2, a3].

  • ideal_volumes (torch.Tensor) – (N,) target signed volumes.

  • sigmas (torch.Tensor) – (N,) standard deviations.
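The shapes of the parameters above can be seen in a plain-Python sketch of a batched NLL. It uses nested lists in place of tensors. The function name chiral_nll_sketch is hypothetical. The sketch also assumes a Gaussian form 0.5 * ((V_model - V_target) / sigma)^2, summed over restraints with constant terms dropped. This page does not document the normalisation or reduction that chiral_math actually uses, so it may differ.

```python
# Plain-Python sketch of a batched chiral NLL (hypothetical, not the torchref code).
# Inputs mirror the documented shapes: xyz (N_atoms, 3), indices (N, 4) as
# [center, a1, a2, a3], ideal_volumes (N,), sigmas (N,).
# ASSUMPTION: Gaussian NLL 0.5 * ((V - V0) / sigma)^2, summed, constants dropped.
from typing import Sequence


def _volume(xyz: Sequence[Sequence[float]], row: Sequence[int]) -> float:
    """Signed volume v1 . (v2 x v3) for one [center, a1, a2, a3] row."""
    center, a1, a2, a3 = row
    c = xyz[center]
    v1 = [xyz[a1][k] - c[k] for k in range(3)]
    v2 = [xyz[a2][k] - c[k] for k in range(3)]
    v3 = [xyz[a3][k] - c[k] for k in range(3)]
    cr = [
        v2[1] * v3[2] - v2[2] * v3[1],
        v2[2] * v3[0] - v2[0] * v3[2],
        v2[0] * v3[1] - v2[1] * v3[0],
    ]
    return sum(v1[k] * cr[k] for k in range(3))


def chiral_nll_sketch(xyz, indices, ideal_volumes, sigmas, achiral_target=2.5) -> float:
    """Sum of per-restraint Gaussian terms (assumed form)."""
    total = 0.0
    for row, ideal, sigma in zip(indices, ideal_volumes, sigmas):
        v = _volume(xyz, row)
        if ideal == 0.0:
            # Achiral entry: restrain |V| against the fixed target.
            v, ideal = abs(v), achiral_target
        total += 0.5 * ((v - ideal) / sigma) ** 2
    return total


xyz = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
indices = [[0, 1, 2, 3], [0, 1, 3, 2]]  # second row has V = -1
print(chiral_nll_sketch(xyz, indices, [2.0, 0.0], [1.0, 0.5]))  # 5.0
```

In the example, the first row contributes 0.5 * ((1 - 2) / 1)^2 = 0.5. The second row is achiral, so |V| = 1 is compared with 2.5 and contributes 0.5 * (1.5 / 0.5)^2 = 4.5.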