torchref.base.targets.angle module
Bond-angle restraint NLL.
- torchref.base.targets.angle.angle_math(xyz, idx, references_rad, sigmas_rad)
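Angle NLL: gathers the three atoms of each angle, computes the angle at the vertex, evaluates a Gaussian negative log-likelihood against the reference angle, and sums over all angles.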
Dispatches to torchref.base.targets.triton.angle_math_triton() on CUDA float32 inputs (~4× faster forward + backward on an A100). Falls back to the eager implementation otherwise.
- Parameters:
xyz (torch.Tensor) – (N_atoms, 3) Cartesian coordinates.
idx (torch.Tensor) – (N, 3) integer indices [a, b, c], with b as the vertex.
references_rad (torch.Tensor) – (N,) target angles in radians.
sigmas_rad (torch.Tensor) – (N,) standard deviations in radians.
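The four steps named above (gather, compute angle, Gaussian NLL, sum) can be sketched as an eager PyTorch function. This is a minimal sketch, not torchref's code: the names angle_nll_eager and would_use_fast_path are hypothetical, the angle is computed with atan2(|v1 × v2|, v1 · v2) (the library may use acos instead), and the Gaussian NLL is assumed to keep its normalization terms (log σ and ½·log 2π), which the real implementation may drop.

```python
import math

import torch


def angle_nll_eager(xyz, idx, references_rad, sigmas_rad):
    """Hypothetical eager sketch of a bond-angle Gaussian NLL (not torchref's code)."""
    # Gather the three atoms of each angle; column 1 (b) is the vertex.
    a = xyz[idx[:, 0]]
    b = xyz[idx[:, 1]]
    c = xyz[idx[:, 2]]
    # Bond vectors pointing away from the vertex.
    v1 = a - b
    v2 = c - b
    # atan2(|v1 x v2|, v1 . v2) yields the angle in [0, pi] and avoids the
    # unbounded gradient of acos as the angle approaches 0 or pi.
    cross_norm = torch.linalg.norm(torch.linalg.cross(v1, v2, dim=-1), dim=-1)
    dot = (v1 * v2).sum(dim=-1)
    theta = torch.atan2(cross_norm, dot)
    # Per-angle Gaussian NLL; ASSUMPTION: normalization terms are kept.
    z = (theta - references_rad) / sigmas_rad
    nll = 0.5 * z**2 + torch.log(sigmas_rad) + 0.5 * math.log(2.0 * math.pi)
    # Sum over all restraints to a scalar.
    return nll.sum()


def would_use_fast_path(xyz):
    # Mirrors the documented dispatch condition: a CUDA tensor in float32.
    return xyz.is_cuda and xyz.dtype == torch.float32


# Usage: one right angle (vertex at the origin) restrained to pi/2.
xyz = torch.tensor([[1.0, 0, 0], [0, 0, 0], [0, 1.0, 0]], dtype=torch.float64)
idx = torch.tensor([[0, 1, 2]])
ref = torch.tensor([math.pi / 2], dtype=torch.float64)
sig = torch.tensor([1.0], dtype=torch.float64)
loss = angle_nll_eager(xyz, idx, ref, sig)  # only 0.5 * log(2*pi) remains
```

With the angle exactly at its target and σ = 1, only the constant term survives; a CPU float64 tensor like this one would take the eager path under the documented dispatch rule.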