"""Triton forward + analytic backward for the omega cis/trans torsion target."""
from __future__ import annotations
import math
import torch
import triton
import triton.language as tl
from triton.language.extra import libdevice
from ..torsion import torsion_omega_math as _omega_eager
from ._dihedral import dihedral_and_grad
# Host-side float constants handed to the kernels as ``tl.constexpr`` args.
_LOG_2PI = math.log(2.0 * math.pi)  # log(2π), von Mises normaliser offset
_DEG2RAD = math.pi / 180.0  # multiply degrees by this to get radians
# Forward kernel: one lane per omega restraint. Stores the per-restraint
# negative log-likelihood of a two-component von Mises mixture with a shared
# concentration κ = 1 / σ_rad²:
#   trans component centred at π (κ·cos(ω − π) = −κ·cos ω), weight 1 − w_cis
#   cis   component centred at 0 (+κ·cos ω),                 weight w_cis
#   nll = log(2π·I_0(κ)) − logsumexp(log_p_trans, log_p_cis)
# ``N`` is a runtime argument (not constexpr) so that a new restraint count
# does not trigger a fresh JIT compilation.
@triton.jit
def _omega_nll_fwd_kernel(
    xyz_ptr,         # coordinates, addressed as xyz[atom * 3 + axis]
    idx_ptr,         # (N, 4)
    sig_deg_ptr,     # (N,)
    is_proline_ptr,  # (N,) bool/i1
    out_ptr,         # (N,)
    w_cis_proline: tl.constexpr,
    w_cis_general: tl.constexpr,
    N,
    LOG_2PI: tl.constexpr,
    DEG2RAD: tl.constexpr,
    BLOCK: tl.constexpr,
):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    # Four atom indices per dihedral (row stride 4).
    a = tl.load(idx_ptr + offs * 4 + 0, mask=mask, other=0)
    b = tl.load(idx_ptr + offs * 4 + 1, mask=mask, other=0)
    c = tl.load(idx_ptr + offs * 4 + 2, mask=mask, other=0)
    d = tl.load(idx_ptr + offs * 4 + 3, mask=mask, other=0)
    # Gather coordinates (row stride 3). Masked lanes receive zeros and are
    # never stored, so any NaN they produce below is discarded.
    p1x = tl.load(xyz_ptr + a * 3 + 0, mask=mask, other=0.0)
    p1y = tl.load(xyz_ptr + a * 3 + 1, mask=mask, other=0.0)
    p1z = tl.load(xyz_ptr + a * 3 + 2, mask=mask, other=0.0)
    p2x = tl.load(xyz_ptr + b * 3 + 0, mask=mask, other=0.0)
    p2y = tl.load(xyz_ptr + b * 3 + 1, mask=mask, other=0.0)
    p2z = tl.load(xyz_ptr + b * 3 + 2, mask=mask, other=0.0)
    p3x = tl.load(xyz_ptr + c * 3 + 0, mask=mask, other=0.0)
    p3y = tl.load(xyz_ptr + c * 3 + 1, mask=mask, other=0.0)
    p3z = tl.load(xyz_ptr + c * 3 + 2, mask=mask, other=0.0)
    p4x = tl.load(xyz_ptr + d * 3 + 0, mask=mask, other=0.0)
    p4y = tl.load(xyz_ptr + d * 3 + 1, mask=mask, other=0.0)
    p4z = tl.load(xyz_ptr + d * 3 + 2, mask=mask, other=0.0)
    # Only the angle is needed here; the gradient outputs are discarded.
    (omega, _F1x, _F1y, _F1z, _F2x, _F2y, _F2z,
     _F3x, _F3y, _F3z, _F4x, _F4y, _F4z) = dihedral_and_grad(
        p1x, p1y, p1z, p2x, p2y, p2z, p3x, p3y, p3z, p4x, p4y, p4z,
    )
    sig_deg = tl.load(sig_deg_ptr + offs, mask=mask, other=1.0)
    sig_rad = sig_deg * DEG2RAD
    kappa = 1.0 / (sig_rad * sig_rad)
    # Keep κ in a finite, positive range for extreme sigmas.
    kappa = tl.minimum(tl.maximum(kappa, 1e-3), 1e4)
    # log(I_0(κ)). Triton libdevice has cyl_bessel_i0 but not the scaled
    # i0e, so use the direct value for κ < 30 (I_0(30) ≈ 7.8e11, safe in
    # float32) and the large-κ expansion otherwise:
    #   log I_0(κ) ≈ κ − ½·log(2πκ) + 1/(8κ) + 1/(16κ²) + O(κ⁻³)
    # Without the 1/(16κ²) term the branches differ by ~7e-5 at κ = 30. With
    # it, the residual (≈ 25/(384κ³) ≈ 2.4e-6) is about one float32 ulp of
    # log I_0(30) ≈ 27.4. tl.where evaluates both branches; the overflowed
    # direct branch (log(inf) = inf) is never selected for large κ.
    asym = (
        kappa
        - 0.5 * tl.log(2.0 * 3.141592653589793 * kappa)
        + 1.0 / (8.0 * kappa)
        + 1.0 / (16.0 * kappa * kappa)
    )
    log_i0_kappa = tl.where(
        kappa < 30.0,
        tl.log(libdevice.cyl_bessel_i0(kappa)),
        asym,
    )
    log_norm = LOG_2PI + log_i0_kappa
    is_pro = tl.load(is_proline_ptr + offs, mask=mask, other=0).to(tl.int1)
    w_cis = tl.where(is_pro, w_cis_proline, w_cis_general)
    w_trans = 1.0 - w_cis
    cos_o = tl.cos(omega)
    log_p_trans = tl.log(w_trans) - kappa * cos_o
    log_p_cis = tl.log(w_cis) + kappa * cos_o
    # Numerically stable log-sum-exp of the two weighted components.
    m = tl.maximum(log_p_trans, log_p_cis)
    log_mixture = m + tl.log(tl.exp(log_p_trans - m) + tl.exp(log_p_cis - m))
    nll = log_norm - log_mixture
    tl.store(out_ptr + offs, nll, mask=mask)
# Backward kernel for the omega cis/trans mixture NLL. It recomputes each
# dihedral together with its Cartesian gradient dω/dxyz, forms dNLL/dω and
# scales it by the upstream scalar gradient. The result is scattered into
# dxyz with atomics because different restraints may index the same atom.
# ``N`` is a runtime argument, so a new restraint count does not recompile.
@triton.jit
def _omega_nll_bwd_kernel(
    xyz_ptr,
    idx_ptr,
    sig_deg_ptr,
    is_proline_ptr,
    grad_out_ptr,  # 0-D tensor — loaded in-kernel (no host .item() sync)
    dxyz_ptr,      # gradient buffer with xyz's layout; accumulated into
    w_cis_proline: tl.constexpr,
    w_cis_general: tl.constexpr,
    N,
    DEG2RAD: tl.constexpr,
    BLOCK: tl.constexpr,
):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    grad_out = tl.load(grad_out_ptr)
    a = tl.load(idx_ptr + offs * 4 + 0, mask=mask, other=0)
    b = tl.load(idx_ptr + offs * 4 + 1, mask=mask, other=0)
    c = tl.load(idx_ptr + offs * 4 + 2, mask=mask, other=0)
    d = tl.load(idx_ptr + offs * 4 + 3, mask=mask, other=0)
    p1x = tl.load(xyz_ptr + a * 3 + 0, mask=mask, other=0.0)
    p1y = tl.load(xyz_ptr + a * 3 + 1, mask=mask, other=0.0)
    p1z = tl.load(xyz_ptr + a * 3 + 2, mask=mask, other=0.0)
    p2x = tl.load(xyz_ptr + b * 3 + 0, mask=mask, other=0.0)
    p2y = tl.load(xyz_ptr + b * 3 + 1, mask=mask, other=0.0)
    p2z = tl.load(xyz_ptr + b * 3 + 2, mask=mask, other=0.0)
    p3x = tl.load(xyz_ptr + c * 3 + 0, mask=mask, other=0.0)
    p3y = tl.load(xyz_ptr + c * 3 + 1, mask=mask, other=0.0)
    p3z = tl.load(xyz_ptr + c * 3 + 2, mask=mask, other=0.0)
    p4x = tl.load(xyz_ptr + d * 3 + 0, mask=mask, other=0.0)
    p4y = tl.load(xyz_ptr + d * 3 + 1, mask=mask, other=0.0)
    p4z = tl.load(xyz_ptr + d * 3 + 2, mask=mask, other=0.0)
    (omega, F1x, F1y, F1z, F2x, F2y, F2z,
     F3x, F3y, F3z, F4x, F4y, F4z) = dihedral_and_grad(
        p1x, p1y, p1z, p2x, p2y, p2z, p3x, p3y, p3z, p4x, p4y, p4z,
    )
    # κ must be derived exactly as in the forward kernel.
    sig_deg = tl.load(sig_deg_ptr + offs, mask=mask, other=1.0)
    sig_rad = sig_deg * DEG2RAD
    kappa = 1.0 / (sig_rad * sig_rad)
    kappa = tl.minimum(tl.maximum(kappa, 1e-3), 1e4)
    is_pro = tl.load(is_proline_ptr + offs, mask=mask, other=0).to(tl.int1)
    w_cis = tl.where(is_pro, w_cis_proline, w_cis_general)
    w_trans = 1.0 - w_cis
    cos_o = tl.cos(omega)
    sin_o = tl.sin(omega)
    log_p_trans = tl.log(w_trans) - kappa * cos_o
    log_p_cis = tl.log(w_cis) + kappa * cos_o
    m = tl.maximum(log_p_trans, log_p_cis)
    log_mix = m + tl.log(tl.exp(log_p_trans - m) + tl.exp(log_p_cis - m))
    # Posterior responsibilities of the trans / cis components (s_t + s_c = 1).
    s_t = tl.exp(log_p_trans - log_mix)
    s_c = tl.exp(log_p_cis - log_mix)
    # d log_p_trans/dω = κ sin ω and d log_p_cis/dω = −κ sin ω, hence
    # dNLL/dω = −(s_t·κ sin ω − s_c·κ sin ω) = κ sin(ω) (s_c − s_t).
    coef = grad_out * kappa * sin_o * (s_c - s_t)
    g1x = coef * F1x
    g1y = coef * F1y
    g1z = coef * F1z
    g2x = coef * F2x
    g2y = coef * F2y
    g2z = coef * F2z
    g3x = coef * F3x
    g3y = coef * F3y
    g3z = coef * F3z
    g4x = coef * F4x
    g4y = coef * F4y
    g4z = coef * F4z
    tl.atomic_add(dxyz_ptr + a * 3 + 0, g1x, mask=mask)
    tl.atomic_add(dxyz_ptr + a * 3 + 1, g1y, mask=mask)
    tl.atomic_add(dxyz_ptr + a * 3 + 2, g1z, mask=mask)
    tl.atomic_add(dxyz_ptr + b * 3 + 0, g2x, mask=mask)
    tl.atomic_add(dxyz_ptr + b * 3 + 1, g2y, mask=mask)
    tl.atomic_add(dxyz_ptr + b * 3 + 2, g2z, mask=mask)
    tl.atomic_add(dxyz_ptr + c * 3 + 0, g3x, mask=mask)
    tl.atomic_add(dxyz_ptr + c * 3 + 1, g3y, mask=mask)
    tl.atomic_add(dxyz_ptr + c * 3 + 2, g3z, mask=mask)
    tl.atomic_add(dxyz_ptr + d * 3 + 0, g4x, mask=mask)
    tl.atomic_add(dxyz_ptr + d * 3 + 1, g4y, mask=mask)
    tl.atomic_add(dxyz_ptr + d * 3 + 2, g4z, mask=mask)
class _TorsionOmegaMathTriton(torch.autograd.Function):
    """Autograd bridge for the omega cis/trans mixture NLL Triton kernels.

    ``forward`` launches the per-restraint NLL kernel and returns the scalar
    sum. ``backward`` launches the analytic-gradient kernel, which scatters
    into a zero-initialised ``dxyz``. Only ``xyz`` receives a gradient; every
    other input gets ``None``.
    """

    # Restraints processed per Triton program, shared by both launches.
    _BLOCK = 128

    @staticmethod
    def forward(ctx, xyz, idx, sigmas_deg, is_proline,
                w_cis_proline: float, w_cis_general: float):
        assert xyz.is_cuda and xyz.dtype == torch.float32, (
            "xyz must be a float32 CUDA tensor"
        )
        # The kernels address every buffer with hard-coded row-major strides
        # (xyz row = 3, idx row = 4, per-restraint arrays stride 1). Strided
        # views would otherwise be read silently wrong. ``.contiguous()`` is a
        # no-op for tensors that already are contiguous.
        xyz = xyz.contiguous()
        idx = idx.contiguous()
        sigmas_deg = sigmas_deg.contiguous()
        # is_proline may be bool; the kernel loads it and casts via .to(tl.int1).
        is_pro_u8 = is_proline.to(torch.uint8).contiguous()
        N = idx.shape[0]
        nll = torch.empty(N, dtype=xyz.dtype, device=xyz.device)
        block = _TorsionOmegaMathTriton._BLOCK
        grid = (triton.cdiv(N, block),)
        _omega_nll_fwd_kernel[grid](
            xyz, idx, sigmas_deg, is_pro_u8, nll,
            w_cis_proline=float(w_cis_proline),
            w_cis_general=float(w_cis_general),
            N=N, LOG_2PI=_LOG_2PI, DEG2RAD=_DEG2RAD, BLOCK=block,
        )
        ctx.save_for_backward(xyz, idx, sigmas_deg, is_pro_u8)
        ctx.w_cis_proline = float(w_cis_proline)
        ctx.w_cis_general = float(w_cis_general)
        return nll.sum()

    @staticmethod
    def backward(ctx, grad_out):
        xyz, idx, sigs, is_pro_u8 = ctx.saved_tensors
        N = idx.shape[0]
        # The saved xyz is contiguous, so this buffer is too. That matches the
        # stride-3 addressing the kernel uses for its atomic accumulation.
        dxyz = torch.zeros_like(xyz)
        block = _TorsionOmegaMathTriton._BLOCK
        grid = (triton.cdiv(N, block),)
        _omega_nll_bwd_kernel[grid](
            xyz, idx, sigs, is_pro_u8, grad_out, dxyz,
            w_cis_proline=ctx.w_cis_proline,
            w_cis_general=ctx.w_cis_general,
            N=N, DEG2RAD=_DEG2RAD, BLOCK=block,
        )
        return dxyz, None, None, None, None, None
def torsion_omega_math_triton(xyz, idx, sigmas_deg, is_proline,
                              w_cis_proline=0.05, w_cis_general=0.0005):
    """Triton-backed omega cis/trans mixture NLL with analytic backward.

    Parameters
    ----------
    xyz : float32 CUDA tensor of atom coordinates, addressed as (N_atoms, 3)
    idx : (N, 4) integer atom indices of each omega dihedral
    sigmas_deg : (N,) per-restraint sigma in degrees. It is turned into a
        von Mises concentration κ = 1 / σ_rad², clamped to [1e-3, 1e4].
    is_proline : (N,) bool mask. Where True, ``w_cis_proline`` is the cis
        weight; elsewhere ``w_cis_general`` is used.
    w_cis_proline, w_cis_general : float
        Mixture weight of the cis mode (centred at 0). The trans mode
        (centred at π) receives ``1 - w_cis``.

    Returns
    -------
    0-D tensor holding the summed NLL. It is differentiable w.r.t. ``xyz``
    only.
    """
    return _TorsionOmegaMathTriton.apply(
        xyz, idx, sigmas_deg, is_proline,
        float(w_cis_proline), float(w_cis_general),
    )
# ---------------------------------------------------------------------------
# Unimodal torsion (intra-residue + disulfide): dihedral + periodic wrap +
# von Mises NLL. The wrap handles arbitrary n-fold rotational symmetry by
# picking the equivalent angle with smallest absolute deviation.
# ---------------------------------------------------------------------------
# Angle constants for the unimodal (periodic von Mises) torsion section.
_PI = math.pi
_TWO_PI = 2.0 * _PI
@triton.jit
def _periodic_min_deviation(diff, period):
    # Angle equivalent to ``diff`` under ``period``-fold symmetry with the
    # smallest magnitude. Let step = 2π / n. The n candidates
    # wrap(diff + k·step), k = 0..n−1, are exactly the values diff + j·step
    # (j an integer) reduced mod 2π, because n·step = 2π. The one nearest
    # zero is therefore diff − step·round(diff / step), which lies in
    # [−step/2, step/2] ⊆ [−π, π]. This is exact for every period ≥ 1, so
    # no compile-time bound on the period is needed.
    step = (2.0 * 3.141592653589793) / period.to(tl.float32)
    return diff - step * libdevice.round(diff / step)


# Forward kernel: one lane per restraint. Computes the dihedral, reduces
# (dihedral − reference) to its smallest symmetry-equivalent deviation and
# stores the von Mises NLL  log(2π·I_0(κ)) − κ·cos(dev),  κ = 1 / σ_rad².
# ``N`` is a runtime argument so a new restraint count does not recompile.
@triton.jit
def _torsion_uni_fwd_kernel(
    xyz_ptr,      # coordinates, addressed as xyz[atom * 3 + axis]
    idx_ptr,      # (N, 4)
    ref_deg_ptr,  # (N,)
    sig_deg_ptr,  # (N,)
    period_ptr,   # (N,) int (clamped >= 1)
    out_ptr,      # (N,)
    N,
    MAX_PERIOD: tl.constexpr,  # unused: the closed-form wrap has no bound
    LOG_2PI: tl.constexpr,
    DEG2RAD: tl.constexpr,
    BLOCK: tl.constexpr,
):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    a = tl.load(idx_ptr + offs * 4 + 0, mask=mask, other=0)
    b = tl.load(idx_ptr + offs * 4 + 1, mask=mask, other=0)
    c = tl.load(idx_ptr + offs * 4 + 2, mask=mask, other=0)
    d = tl.load(idx_ptr + offs * 4 + 3, mask=mask, other=0)
    p1x = tl.load(xyz_ptr + a * 3 + 0, mask=mask, other=0.0)
    p1y = tl.load(xyz_ptr + a * 3 + 1, mask=mask, other=0.0)
    p1z = tl.load(xyz_ptr + a * 3 + 2, mask=mask, other=0.0)
    p2x = tl.load(xyz_ptr + b * 3 + 0, mask=mask, other=0.0)
    p2y = tl.load(xyz_ptr + b * 3 + 1, mask=mask, other=0.0)
    p2z = tl.load(xyz_ptr + b * 3 + 2, mask=mask, other=0.0)
    p3x = tl.load(xyz_ptr + c * 3 + 0, mask=mask, other=0.0)
    p3y = tl.load(xyz_ptr + c * 3 + 1, mask=mask, other=0.0)
    p3z = tl.load(xyz_ptr + c * 3 + 2, mask=mask, other=0.0)
    p4x = tl.load(xyz_ptr + d * 3 + 0, mask=mask, other=0.0)
    p4y = tl.load(xyz_ptr + d * 3 + 1, mask=mask, other=0.0)
    p4z = tl.load(xyz_ptr + d * 3 + 2, mask=mask, other=0.0)
    (chi, _F1x, _F1y, _F1z, _F2x, _F2y, _F2z,
     _F3x, _F3y, _F3z, _F4x, _F4y, _F4z) = dihedral_and_grad(
        p1x, p1y, p1z, p2x, p2y, p2z, p3x, p3y, p3z, p4x, p4y, p4z,
    )
    # chi is in radians (atan2 output)
    ref_rad = tl.load(ref_deg_ptr + offs, mask=mask, other=0.0) * DEG2RAD
    period = tl.load(period_ptr + offs, mask=mask, other=1)
    dev = _periodic_min_deviation(chi - ref_rad, period)
    sig_deg = tl.load(sig_deg_ptr + offs, mask=mask, other=1.0)
    sig_rad = sig_deg * DEG2RAD
    kappa = 1.0 / (sig_rad * sig_rad)
    kappa = tl.minimum(tl.maximum(kappa, 1e-3), 1e4)
    # log I_0(κ): direct value for κ < 30, otherwise the expansion
    # κ − ½·log(2πκ) + 1/(8κ) + 1/(16κ²). The second correction term brings
    # the branch mismatch at κ = 30 down from ~7e-5 to ~2.4e-6 (≈ 1 ulp).
    asym = (
        kappa
        - 0.5 * tl.log(2.0 * 3.141592653589793 * kappa)
        + 1.0 / (8.0 * kappa)
        + 1.0 / (16.0 * kappa * kappa)
    )
    log_i0_kappa = tl.where(
        kappa < 30.0,
        tl.log(libdevice.cyl_bessel_i0(kappa)),
        asym,
    )
    log_prob = kappa * tl.cos(dev) - log_i0_kappa - LOG_2PI
    nll = -log_prob
    tl.store(out_ptr + offs, nll, mask=mask)


# Backward kernel: recomputes the dihedral and its Cartesian gradient, then
# scales the gradient by dNLL/dχ = κ·sin(dev) and by the upstream scalar
# gradient. Contributions are accumulated into dxyz with atomics because
# different restraints may index the same atom.
@triton.jit
def _torsion_uni_bwd_kernel(
    xyz_ptr,
    idx_ptr,
    ref_deg_ptr,
    sig_deg_ptr,
    period_ptr,
    grad_out_ptr,  # 0-D tensor — loaded in-kernel (no host .item() sync)
    dxyz_ptr,
    N,
    MAX_PERIOD: tl.constexpr,  # unused: the closed-form wrap has no bound
    DEG2RAD: tl.constexpr,
    BLOCK: tl.constexpr,
):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    grad_out = tl.load(grad_out_ptr)
    a = tl.load(idx_ptr + offs * 4 + 0, mask=mask, other=0)
    b = tl.load(idx_ptr + offs * 4 + 1, mask=mask, other=0)
    c = tl.load(idx_ptr + offs * 4 + 2, mask=mask, other=0)
    d = tl.load(idx_ptr + offs * 4 + 3, mask=mask, other=0)
    p1x = tl.load(xyz_ptr + a * 3 + 0, mask=mask, other=0.0)
    p1y = tl.load(xyz_ptr + a * 3 + 1, mask=mask, other=0.0)
    p1z = tl.load(xyz_ptr + a * 3 + 2, mask=mask, other=0.0)
    p2x = tl.load(xyz_ptr + b * 3 + 0, mask=mask, other=0.0)
    p2y = tl.load(xyz_ptr + b * 3 + 1, mask=mask, other=0.0)
    p2z = tl.load(xyz_ptr + b * 3 + 2, mask=mask, other=0.0)
    p3x = tl.load(xyz_ptr + c * 3 + 0, mask=mask, other=0.0)
    p3y = tl.load(xyz_ptr + c * 3 + 1, mask=mask, other=0.0)
    p3z = tl.load(xyz_ptr + c * 3 + 2, mask=mask, other=0.0)
    p4x = tl.load(xyz_ptr + d * 3 + 0, mask=mask, other=0.0)
    p4y = tl.load(xyz_ptr + d * 3 + 1, mask=mask, other=0.0)
    p4z = tl.load(xyz_ptr + d * 3 + 2, mask=mask, other=0.0)
    (chi, F1x, F1y, F1z, F2x, F2y, F2z,
     F3x, F3y, F3z, F4x, F4y, F4z) = dihedral_and_grad(
        p1x, p1y, p1z, p2x, p2y, p2z, p3x, p3y, p3z, p4x, p4y, p4z,
    )
    ref_rad = tl.load(ref_deg_ptr + offs, mask=mask, other=0.0) * DEG2RAD
    period = tl.load(period_ptr + offs, mask=mask, other=1)
    dev = _periodic_min_deviation(chi - ref_rad, period)
    sig_deg = tl.load(sig_deg_ptr + offs, mask=mask, other=1.0)
    sig_rad = sig_deg * DEG2RAD
    kappa = 1.0 / (sig_rad * sig_rad)
    kappa = tl.minimum(tl.maximum(kappa, 1e-3), 1e4)
    # dNLL/d(dev) = κ·sin(dev). d(dev)/dχ = 1 almost everywhere: the
    # integer shift chosen by round() is piecewise constant. Hence
    # dNLL/dχ = κ·sin(dev).
    coef = grad_out * kappa * tl.sin(dev)
    g1x = coef * F1x
    g1y = coef * F1y
    g1z = coef * F1z
    g2x = coef * F2x
    g2y = coef * F2y
    g2z = coef * F2z
    g3x = coef * F3x
    g3y = coef * F3y
    g3z = coef * F3z
    g4x = coef * F4x
    g4y = coef * F4y
    g4z = coef * F4z
    tl.atomic_add(dxyz_ptr + a * 3 + 0, g1x, mask=mask)
    tl.atomic_add(dxyz_ptr + a * 3 + 1, g1y, mask=mask)
    tl.atomic_add(dxyz_ptr + a * 3 + 2, g1z, mask=mask)
    tl.atomic_add(dxyz_ptr + b * 3 + 0, g2x, mask=mask)
    tl.atomic_add(dxyz_ptr + b * 3 + 1, g2y, mask=mask)
    tl.atomic_add(dxyz_ptr + b * 3 + 2, g2z, mask=mask)
    tl.atomic_add(dxyz_ptr + c * 3 + 0, g3x, mask=mask)
    tl.atomic_add(dxyz_ptr + c * 3 + 1, g3y, mask=mask)
    tl.atomic_add(dxyz_ptr + c * 3 + 2, g3z, mask=mask)
    tl.atomic_add(dxyz_ptr + d * 3 + 0, g4x, mask=mask)
    tl.atomic_add(dxyz_ptr + d * 3 + 1, g4y, mask=mask)
    tl.atomic_add(dxyz_ptr + d * 3 + 2, g4z, mask=mask)
class _TorsionUnimodalMathTriton(torch.autograd.Function):
    """Autograd bridge for the unimodal (periodic von Mises) torsion kernels.

    ``forward`` launches the per-restraint NLL kernel and returns the scalar
    sum. ``backward`` recomputes the dihedral gradients in a second kernel
    and accumulates them into a zero-initialised ``dxyz``. Only ``xyz``
    receives a gradient.
    """

    # ``MAX_PERIOD`` is forwarded to both kernels as a Triton constexpr. It
    # is pinned to 6, the realistic upper bound for protein restraint
    # libraries (1-, 2-, 3-, 4- and 6-fold symmetries in the Monomer
    # Library / cctbx geostd). It is not derived from
    # ``periods.max().item()`` because that would force a host sync on every
    # call and block CUDA Graph capture.
    _FIXED_MAX_PERIOD = 6
    # Restraints processed per Triton program, shared by both launches.
    _BLOCK = 128

    @staticmethod
    def forward(ctx, xyz, idx, references_deg, sigmas_deg, periods):
        assert xyz.is_cuda and xyz.dtype == torch.float32, (
            "xyz must be a float32 CUDA tensor"
        )
        # The kernels address every buffer with hard-coded row-major strides
        # (xyz row = 3, idx row = 4, per-restraint arrays stride 1), so
        # strided views must be materialised first. This is a no-op for
        # tensors that are already contiguous.
        xyz = xyz.contiguous()
        idx = idx.contiguous()
        references_deg = references_deg.contiguous()
        sigmas_deg = sigmas_deg.contiguous()
        N = idx.shape[0]
        max_period = _TorsionUnimodalMathTriton._FIXED_MAX_PERIOD
        # Periods below 1 are clamped to 1 (a single-fold restraint).
        periods_i32 = periods.clamp(min=1).to(torch.int32).contiguous()
        nll = torch.empty(N, dtype=xyz.dtype, device=xyz.device)
        block = _TorsionUnimodalMathTriton._BLOCK
        grid = (triton.cdiv(N, block),)
        _torsion_uni_fwd_kernel[grid](
            xyz, idx, references_deg, sigmas_deg, periods_i32, nll,
            N=N, MAX_PERIOD=max_period,
            LOG_2PI=_LOG_2PI, DEG2RAD=_DEG2RAD, BLOCK=block,
        )
        ctx.save_for_backward(xyz, idx, references_deg, sigmas_deg, periods_i32)
        ctx.max_period = max_period
        return nll.sum()

    @staticmethod
    def backward(ctx, grad_out):
        xyz, idx, refs, sigs, periods_i32 = ctx.saved_tensors
        N = idx.shape[0]
        # Contiguous (the saved xyz is contiguous), matching the kernel's
        # stride-3 atomic accumulation.
        dxyz = torch.zeros_like(xyz)
        block = _TorsionUnimodalMathTriton._BLOCK
        grid = (triton.cdiv(N, block),)
        _torsion_uni_bwd_kernel[grid](
            xyz, idx, refs, sigs, periods_i32, grad_out, dxyz,
            N=N, MAX_PERIOD=ctx.max_period,
            DEG2RAD=_DEG2RAD, BLOCK=block,
        )
        return dxyz, None, None, None, None
def torsion_unimodal_full_math_triton(xyz, idx, references_deg, sigmas_deg, periods):
    """Triton-backed full unimodal torsion NLL (dihedral + periodic wrap +
    von Mises NLL + sum), with analytic backward through the dihedral
    formula. Inputs match ``Restraints.restraints['torsion']['all']``:

    Parameters
    ----------
    xyz : (N_atoms, 3) float32 CUDA
    idx : (N, 4) int64 atom indices
    references_deg : (N,) float32 — target angles in degrees
    sigmas_deg : (N,) float32 — sigmas in degrees
    periods : (N,) int — n-fold periodicity (≥ 1; smaller values are
        clamped to 1)

    Returns
    -------
    0-D tensor holding the summed NLL. It is differentiable w.r.t. ``xyz``
    only.
    """
    return _TorsionUnimodalMathTriton.apply(
        xyz, idx, references_deg, sigmas_deg, periods,
    )