# Source code for torchref.base.targets.torsion

"""Torsion restraint NLL: unimodal von Mises + omega cis/trans mixture."""

import torch

from ._common import DEG2RAD, LOG_2PI, torsions_from_xyz
from ._dispatch import use_triton


def torsion_unimodal_math(
    deviations_rad: torch.Tensor,
    sigmas_deg: torch.Tensor,
) -> torch.Tensor:
    """Return the summed unimodal von Mises NLL of already-wrapped deviations.

    Mirrors ``targets.geometry.torsions._von_mises_nll``. The caller must
    supply angular deviations that are already wrapped (currently done inside
    ``Restraints.torsion_deviations_with_sigmas``). There is no Triton
    dispatch yet, because the periodic-wrap logic upstream is still eager
    Python.

    Args:
        deviations_rad: Wrapped deviations from the target angle, in radians.
        sigmas_deg: Per-restraint standard deviations, in degrees.

    Returns:
        Scalar tensor equal to
        ``-sum(kappa * cos(d) - log I0(kappa) - log(2*pi))`` with
        ``kappa = clamp(1 / sigma_rad**2, 1e-3, 1e4)``.
    """
    sigma_rad = sigmas_deg * DEG2RAD
    # Concentration from the small-angle relation kappa ~ 1/sigma^2,
    # bounded to the range [1e-3, 1e4].
    kappa = (1.0 / (sigma_rad ** 2)).clamp(min=1e-3, max=1e4)

    # i0e(k) = exp(-|k|) * I0(k), so log I0(k) = log(i0e(k)) + k for k > 0.
    # This avoids overflowing I0 itself at large kappa.
    log_bessel = torch.special.i0e(kappa).log() + kappa

    per_term = kappa * deviations_rad.cos() - log_bessel - LOG_2PI
    return per_term.sum().neg()
def _torsion_omega_math_eager(
    xyz: torch.Tensor,
    idx: torch.Tensor,
    sigmas_deg: torch.Tensor,
    is_proline: torch.Tensor,
    w_cis_proline: float = 0.05,
    w_cis_general: float = 0.0005,
) -> torch.Tensor:
    """Eager PyTorch implementation of the omega cis/trans mixture NLL.

    Each omega is scored under the two-component mixture
    ``w_trans * VM(omega; pi, kappa) + w_cis * VM(omega; 0, kappa)``.
    Both components share the same kappa, and therefore the same normaliser
    ``2*pi*I0(kappa)``. That normaliser is factored out of the log-sum-exp.

    Args:
        xyz: Coordinates passed to ``torsions_from_xyz``.
        idx: Torsion atom indices passed to ``torsions_from_xyz`` (format
            defined by that helper).
        sigmas_deg: Per-torsion standard deviations, in degrees.
        is_proline: Boolean mask, used as the ``torch.where`` condition. It
            selects ``w_cis_proline`` where True and ``w_cis_general``
            elsewhere.
        w_cis_proline: Cis mixture weight for proline entries.
        w_cis_general: Cis mixture weight for all other entries.

    Returns:
        Scalar tensor: the summed negative log-likelihood.

    Notes:
        Weights are expected in ``[0, 1]``. A weight of exactly 0 or 1 gives
        ``log(0) = -inf`` for one component, which ``torch.logaddexp``
        absorbs. Weights outside that range produce ``log`` of a negative
        number, i.e. NaN.
    """
    # torsions_from_xyz yields degrees here (converted with DEG2RAD below).
    omega_rad = torsions_from_xyz(xyz, idx) * DEG2RAD

    sigmas_rad = sigmas_deg * DEG2RAD
    kappa = torch.clamp(1.0 / (sigmas_rad ** 2), min=1e-3, max=1e4)
    # log I0(kappa) via the exponentially scaled Bessel function (kappa > 0).
    log_i0_kappa = torch.log(torch.special.i0e(kappa)) + kappa
    log_norm = LOG_2PI + log_i0_kappa

    # Build the scalar weights with kappa's dtype/device so float64 inputs
    # are not silently mixed with default-dtype tensors.
    w_cis = torch.where(
        is_proline,
        torch.tensor(w_cis_proline, device=kappa.device, dtype=kappa.dtype),
        torch.tensor(w_cis_general, device=kappa.device, dtype=kappa.dtype),
    )
    w_trans = 1.0 - w_cis

    kappa_cos = kappa * torch.cos(omega_rad)
    # Trans is centred at pi: cos(omega - pi) = -cos(omega).
    log_p_trans = torch.log(w_trans) - kappa_cos
    # Cis is centred at 0.
    log_p_cis = torch.log(w_cis) + kappa_cos

    # Two-term log-sum-exp without materialising a stacked intermediate.
    log_mixture = torch.logaddexp(log_p_trans, log_p_cis)
    return (log_norm - log_mixture).sum()
def torsion_omega_math(
    xyz: torch.Tensor,
    idx: torch.Tensor,
    sigmas_deg: torch.Tensor,
    is_proline: torch.Tensor,
    w_cis_proline: float = 0.05,
    w_cis_general: float = 0.0005,
) -> torch.Tensor:
    """Return the omega cis/trans mixture NLL, choosing Triton or eager.

    Mirrors ``targets.geometry.torsions._omega_mixture_nll`` together with
    the omega-angle computation. Each omega is modelled as the 2-component
    mixture ``w_trans VM(omega; pi, kappa) + w_cis VM(omega; 0, kappa)``.

    When ``use_triton(xyz)`` approves (CUDA float32; ~5.6x faster fwd+bw on
    A100), this dispatches to
    :func:`torchref.base.targets.triton.torsion_omega_math_triton`.
    Otherwise it falls back to the eager PyTorch implementation.

    Args:
        xyz: Coordinates from which omega angles are computed.
        idx: Torsion atom indices for the omega angles.
        sigmas_deg: Per-torsion standard deviations, in degrees.
        is_proline: Mask selecting the proline cis weight.
        w_cis_proline: Cis mixture weight for proline entries.
        w_cis_general: Cis mixture weight for all other entries.

    Returns:
        Scalar tensor: the summed negative log-likelihood.
    """
    call_args = (
        xyz,
        idx,
        sigmas_deg,
        is_proline,
        w_cis_proline,
        w_cis_general,
    )

    if not use_triton(xyz):
        return _torsion_omega_math_eager(*call_args)

    # Deferred import: the Triton module is only loaded on the GPU path.
    from .triton.torsion import torsion_omega_math_triton

    return torsion_omega_math_triton(*call_args)