torchref.base.targets.angle module

Negative log-likelihood (NLL) target for bond-angle restraints.

torchref.base.targets.angle.angle_math(xyz, idx, references_rad, sigmas_rad)

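Bond-angle NLL. Gathers the three atom positions of each restraint, computes the angle at the vertex atom, evaluates a Gaussian NLL against the reference angle, and sums over all restraints.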

For CUDA tensors in float32, dispatches to torchref.base.targets.triton.angle_math_triton() (about 4× faster for the combined forward and backward pass on an A100). Otherwise falls back to the eager implementation.
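The dispatch rule above can be sketched as a simple predicate. This is only an illustration of the documented condition (CUDA device and float32 dtype); the helper name would_use_triton is hypothetical and is not part of torchref, and the library's real check may look at more than xyz.

```python
import torch


def would_use_triton(xyz: torch.Tensor) -> bool:
    """Hypothetical predicate mirroring the documented dispatch rule."""
    # The fused kernel is documented for CUDA tensors in float32 only;
    # anything else (CPU tensors, float64, ...) takes the eager path.
    return xyz.is_cuda and xyz.dtype == torch.float32


# CPU tensors never satisfy the condition, whatever their dtype.
cpu_f32 = torch.zeros(4, 3)
cpu_f64 = torch.zeros(4, 3, dtype=torch.float64)
```

Under this reading, moving the inputs to a GPU while keeping them in float64 would still take the eager path.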

Parameters:
  • xyz (torch.Tensor) – (N_atoms, 3) Cartesian coordinates.

  • idx (torch.Tensor) – (N, 3) integer indices into xyz, one row [a, b, c] per angle, with b as the vertex.

  • references_rad (torch.Tensor) – (N,) target angles in radians.

  • sigmas_rad (torch.Tensor) – (N,) standard deviations in radians.
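The four steps (gather, angle, Gaussian NLL, sum) can be sketched in plain PyTorch as follows. This is a minimal reference sketch, not the library's implementation: the function name angle_nll_reference is hypothetical, the atan2-based angle formula is one common choice, and including the Gaussian normalisation term log(sigma) + 0.5·log(2π) is an assumption, since the documentation does not say whether angle_math keeps or drops it.

```python
import math

import torch


def angle_nll_reference(xyz, idx, references_rad, sigmas_rad):
    """Hypothetical eager reference for a summed bond-angle Gaussian NLL."""
    # Gather the three atom positions of every restraint; each is (N, 3).
    a = xyz[idx[:, 0]]
    b = xyz[idx[:, 1]]  # vertex atom
    c = xyz[idx[:, 2]]

    # Vectors pointing from the vertex to the two outer atoms.
    v1 = a - b
    v2 = c - b

    # Angle at the vertex, in [0, pi]. atan2(|v1 x v2|, v1 . v2) is an
    # assumption; acos of the normalised dot product is another option.
    cross = torch.linalg.cross(v1, v2, dim=-1)
    angle = torch.atan2(torch.linalg.norm(cross, dim=-1), (v1 * v2).sum(dim=-1))

    # Per-restraint Gaussian NLL, with the normalisation term (assumption).
    z = (angle - references_rad) / sigmas_rad
    nll = 0.5 * z**2 + torch.log(sigmas_rad) + 0.5 * math.log(2.0 * math.pi)

    # Sum over all N restraints to obtain a scalar.
    return nll.sum()


# One right angle with its vertex at the origin, restrained to pi/2.
xyz = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
idx = torch.tensor([[0, 1, 2]])
references_rad = torch.tensor([math.pi / 2])
sigmas_rad = torch.tensor([0.05])
loss = angle_nll_reference(xyz, idx, references_rad, sigmas_rad)
```

Because the example angle matches its reference, the squared-deviation term vanishes and only the normalisation term contributes to loss.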