"""Bond-angle restraint NLL."""
import torch
from ._common import LOG_2PI
from ._dispatch import use_triton
def _angle_math_eager(
    xyz: torch.Tensor,
    idx: torch.Tensor,
    references_rad: torch.Tensor,
    sigmas_rad: torch.Tensor,
) -> torch.Tensor:
    """Eager PyTorch implementation of the angle-restraint NLL.

    Gathers the three atoms of every restraint, measures the angle at the
    vertex atom ``b`` and sums the per-restraint Gaussian negative
    log-likelihood ``0.5 * ((theta - theta0) / sigma) ** 2 + log(sigma)
    + 0.5 * LOG_2PI``.

    Parameters
    ----------
    xyz : torch.Tensor
        (N_atoms, 3) Cartesian coordinates.
    idx : torch.Tensor
        (N, 3) integer indices [a, b, c] with ``b`` as the vertex.
    references_rad : torch.Tensor
        (N,) target angles in radians.
    sigmas_rad : torch.Tensor
        (N,) standard deviations in radians.

    Returns
    -------
    torch.Tensor
        0-dim tensor: the NLL summed over all restraints.
    """
    pos_b = xyz[idx[:, 1]]
    v1 = xyz[idx[:, 0]] - pos_b
    v2 = xyz[idx[:, 2]] - pos_b
    # Angle via atan2(|v1 x v2|, v1 . v2) instead of acos(clamped cosine).
    # d/dx acos(x) = -1/sqrt(1 - x^2) diverges at x = +/-1, so linear (180 deg)
    # or fully closed angles produced inf/NaN gradients, and acos loses
    # precision near 0 and pi. atan2 yields the same angle with finite
    # gradients there. Both arguments carry the same |v1|*|v2| factor,
    # so no division by the bond lengths is needed.
    sin_term = torch.linalg.norm(torch.linalg.cross(v1, v2, dim=-1), dim=-1)
    cos_term = torch.sum(v1 * v2, dim=-1)
    angles_rad = torch.atan2(sin_term, cos_term)
    deviations = angles_rad - references_rad
    nll = 0.5 * (deviations / sigmas_rad) ** 2 + torch.log(sigmas_rad) + 0.5 * LOG_2PI
    return nll.sum()
def angle_math(
    xyz: torch.Tensor,
    idx: torch.Tensor,
    references_rad: torch.Tensor,
    sigmas_rad: torch.Tensor,
) -> torch.Tensor:
    """Angle NLL: gather, compute angle, Gaussian NLL, sum.

    Dispatches to :func:`torchref.base.targets.triton.angle_math_triton`
    on CUDA float32 (~4× faster fwd+bw on A100). Falls back to eager
    otherwise.

    Parameters
    ----------
    xyz : torch.Tensor
        (N_atoms, 3) Cartesian coordinates.
    idx : torch.Tensor
        (N, 3) integer indices [a, b, c] with ``b`` as the vertex.
    references_rad : torch.Tensor
        (N,) target angles in radians.
    sigmas_rad : torch.Tensor
        (N,) standard deviations in radians.

    Returns
    -------
    torch.Tensor
        The Gaussian NLL summed over all ``N`` restraints.
    """
    if use_triton(xyz):
        # Local import: the Triton module is only loaded when this fast
        # path is actually taken.
        from .triton.angle import angle_math_triton

        return angle_math_triton(xyz, idx, references_rad, sigmas_rad)
    return _angle_math_eager(xyz, idx, references_rad, sigmas_rad)