"""Triton forward + analytic backward for the non-bonded prolsq VDW target.

The forward fuses the gather, the (optional) cartesian symmetry transform,
the prolsq shape energy and the per-pair constant into one kernel. With
diff = pos2 − pos1 and d = |diff|, the backward chains through:

    v(d)             = max(0, d_vdw + buf − d)
    E(v)             = c_rep · v^r_exp                (when v > 0)
    dE/dv            = c_rep · r_exp · v^(r_exp − 1)
    dv/dd            = −1                             (when v > 0, else 0)
    dd/dpos1         = −diff/d
    dd/dpos2         = +diff/d
    grad_pos2_source = R_cartᵀ · grad_pos2            (with symmetry)

Cartesian symmetry transforms (M·R·M⁻¹ and M·t) are precomputed once per
call outside the kernels, so each kernel only does a 3×3 matvec (forward)
or a 3×3 transposed matvec (backward) per pair.

The scalar parameters (c_rep, r_exp and the sigma-dependent constant) are
passed to the kernels as 0-D device tensors rather than as constexprs or
by-value arguments, so launching the kernels needs no host ``.item()``.
"""
from __future__ import annotations
import math
import torch
import triton
import triton.language as tl
_LOG_2PI = float(math.log(2.0 * math.pi))
@triton.jit
def _nb_fwd_kernel(
    xyz_ptr,            # (N_atoms, 3) coordinates, addressed row-major
    idx_ptr,            # (N_pairs, 2) atom indices (i, j), addressed row-major
    min_d_ptr,          # (N_pairs,) d_vdw for each pair
    symop_idx_ptr,      # (N_pairs,) int32 symop index applied to atom j
    cart_sym_mat_ptr,   # (n_symops, 3, 3) -- M R M^-1
    cart_sym_off_ptr,   # (n_symops, 3) -- M t
    cell_off_cart_ptr,  # (N_pairs, 3) -- M @ cell_offsets
    out_ptr,            # (N_pairs,) per-pair NLL (output)
    c_rep_ptr,          # 0-D tensor: c_rep, loaded in-kernel
    r_exp_ptr,          # 0-D tensor: r_exp, loaded in-kernel
    log_sig_plus_ptr,   # 0-D tensor: log(sigma_vdw) + 0.5*log(2pi)
    has_symmetry: tl.constexpr,
    N,                  # number of pairs; a runtime argument so that a new
                        # pair count does not trigger a JIT recompile
    BUFFER: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Per-pair prolsq VDW NLL.

    For each of the BLOCK pairs handled by this program:
        d   = |pos2 − pos_i|   (pos2 = atom j, symmetry-transformed if
                                has_symmetry)
        v   = max(0, d_vdw + BUFFER − d)
        nll = c_rep · v^r_exp · [v > 0] + log(sigma_vdw) + 0.5·log(2π)
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    # Scalars are read from device memory rather than passed by value, so
    # the host never has to ``.item()`` them before the launch.
    C_REP = tl.load(c_rep_ptr)
    R_EXP = tl.load(r_exp_ptr)
    LOG_SIG_PLUS_HALF_LOG_2PI = tl.load(log_sig_plus_ptr)

    # Gather both endpoints of every pair.
    i = tl.load(idx_ptr + offs * 2 + 0, mask=mask, other=0)
    j = tl.load(idx_ptr + offs * 2 + 1, mask=mask, other=0)
    p1x = tl.load(xyz_ptr + i * 3 + 0, mask=mask, other=0.0)
    p1y = tl.load(xyz_ptr + i * 3 + 1, mask=mask, other=0.0)
    p1z = tl.load(xyz_ptr + i * 3 + 2, mask=mask, other=0.0)
    msx = tl.load(xyz_ptr + j * 3 + 0, mask=mask, other=0.0)
    msy = tl.load(xyz_ptr + j * 3 + 1, mask=mask, other=0.0)
    msz = tl.load(xyz_ptr + j * 3 + 2, mask=mask, other=0.0)

    if has_symmetry:
        # pos2 = (M R M^-1) · pos_j + M t + M · cell_offset
        s = tl.load(symop_idx_ptr + offs, mask=mask, other=0)
        M00 = tl.load(cart_sym_mat_ptr + s * 9 + 0, mask=mask, other=1.0)
        M01 = tl.load(cart_sym_mat_ptr + s * 9 + 1, mask=mask, other=0.0)
        M02 = tl.load(cart_sym_mat_ptr + s * 9 + 2, mask=mask, other=0.0)
        M10 = tl.load(cart_sym_mat_ptr + s * 9 + 3, mask=mask, other=0.0)
        M11 = tl.load(cart_sym_mat_ptr + s * 9 + 4, mask=mask, other=1.0)
        M12 = tl.load(cart_sym_mat_ptr + s * 9 + 5, mask=mask, other=0.0)
        M20 = tl.load(cart_sym_mat_ptr + s * 9 + 6, mask=mask, other=0.0)
        M21 = tl.load(cart_sym_mat_ptr + s * 9 + 7, mask=mask, other=0.0)
        M22 = tl.load(cart_sym_mat_ptr + s * 9 + 8, mask=mask, other=1.0)
        ox = tl.load(cart_sym_off_ptr + s * 3 + 0, mask=mask, other=0.0)
        oy = tl.load(cart_sym_off_ptr + s * 3 + 1, mask=mask, other=0.0)
        oz = tl.load(cart_sym_off_ptr + s * 3 + 2, mask=mask, other=0.0)
        cox = tl.load(cell_off_cart_ptr + offs * 3 + 0, mask=mask, other=0.0)
        coy = tl.load(cell_off_cart_ptr + offs * 3 + 1, mask=mask, other=0.0)
        coz = tl.load(cell_off_cart_ptr + offs * 3 + 2, mask=mask, other=0.0)
        p2x = M00 * msx + M01 * msy + M02 * msz + ox + cox
        p2y = M10 * msx + M11 * msy + M12 * msz + oy + coy
        p2z = M20 * msx + M21 * msy + M22 * msz + oz + coz
    else:
        p2x = msx
        p2y = msy
        p2z = msz

    dx = p2x - p1x
    dy = p2y - p1y
    dz = p2z - p1z
    # The epsilon keeps d strictly positive for coincident atoms.
    d = tl.sqrt(dx * dx + dy * dy + dz * dz + 1e-8)
    md = tl.load(min_d_ptr + offs, mask=mask, other=0.0)
    v = md + BUFFER - d
    v = tl.where(v > 0.0, v, 0.0)
    # v^r_exp computed as exp(r_exp · log v); the 1e-30 avoids log(0) and
    # the trailing where() forces inactive pairs (v == 0) to exactly 0.
    e = C_REP * tl.exp(R_EXP * tl.log(v + 1e-30)) * tl.where(v > 0.0, 1.0, 0.0)
    nll = e + LOG_SIG_PLUS_HALF_LOG_2PI
    tl.store(out_ptr + offs, nll, mask=mask)
@triton.jit
def _nb_bwd_kernel(
    xyz_ptr,            # (N_atoms, 3) coordinates, addressed row-major
    idx_ptr,            # (N_pairs, 2) atom indices (i, j), addressed row-major
    min_d_ptr,          # (N_pairs,) d_vdw for each pair
    symop_idx_ptr,      # (N_pairs,) int32 symop index applied to atom j
    cart_sym_mat_ptr,   # (n_symops, 3, 3) -- M R M^-1
    cart_sym_off_ptr,   # (n_symops, 3) -- M t
    cell_off_cart_ptr,  # (N_pairs, 3) -- M @ cell_offsets
    grad_out_ptr,       # 0-D tensor: upstream gradient, loaded in-kernel
    dxyz_ptr,           # (N_atoms, 3) gradient accumulator (atomic adds)
    c_rep_ptr,          # 0-D tensor: c_rep, loaded in-kernel
    r_exp_ptr,          # 0-D tensor: r_exp, loaded in-kernel
    has_symmetry: tl.constexpr,
    N,                  # number of pairs; a runtime argument so that a new
                        # pair count does not trigger a JIT recompile
    BUFFER: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Analytic backward for the prolsq nonbonded NLL.

    For each pair, recompute diff = pos2 − pos1, d = |diff| and v, then
        coef             = grad_out · c_rep · r_exp · v^(r_exp−1) / d
                           when v > 0, else 0
        grad_pos1        = +coef · diff
        grad_pos2        = −coef · diff
        grad_pos2_source = R_cartᵀ · grad_pos2
                           (identity when has_symmetry is False)
    and scatter-add both into dxyz with atomics, since one atom can appear
    in many pairs.
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    grad_out = tl.load(grad_out_ptr)
    C_REP = tl.load(c_rep_ptr)
    R_EXP = tl.load(r_exp_ptr)

    # Gather both endpoints of every pair (same as the forward).
    i = tl.load(idx_ptr + offs * 2 + 0, mask=mask, other=0)
    j = tl.load(idx_ptr + offs * 2 + 1, mask=mask, other=0)
    p1x = tl.load(xyz_ptr + i * 3 + 0, mask=mask, other=0.0)
    p1y = tl.load(xyz_ptr + i * 3 + 1, mask=mask, other=0.0)
    p1z = tl.load(xyz_ptr + i * 3 + 2, mask=mask, other=0.0)
    msx = tl.load(xyz_ptr + j * 3 + 0, mask=mask, other=0.0)
    msy = tl.load(xyz_ptr + j * 3 + 1, mask=mask, other=0.0)
    msz = tl.load(xyz_ptr + j * 3 + 2, mask=mask, other=0.0)

    if has_symmetry:
        # pos2 = (M R M^-1) · pos_j + M t + M · cell_offset
        s = tl.load(symop_idx_ptr + offs, mask=mask, other=0)
        M00 = tl.load(cart_sym_mat_ptr + s * 9 + 0, mask=mask, other=1.0)
        M01 = tl.load(cart_sym_mat_ptr + s * 9 + 1, mask=mask, other=0.0)
        M02 = tl.load(cart_sym_mat_ptr + s * 9 + 2, mask=mask, other=0.0)
        M10 = tl.load(cart_sym_mat_ptr + s * 9 + 3, mask=mask, other=0.0)
        M11 = tl.load(cart_sym_mat_ptr + s * 9 + 4, mask=mask, other=1.0)
        M12 = tl.load(cart_sym_mat_ptr + s * 9 + 5, mask=mask, other=0.0)
        M20 = tl.load(cart_sym_mat_ptr + s * 9 + 6, mask=mask, other=0.0)
        M21 = tl.load(cart_sym_mat_ptr + s * 9 + 7, mask=mask, other=0.0)
        M22 = tl.load(cart_sym_mat_ptr + s * 9 + 8, mask=mask, other=1.0)
        ox = tl.load(cart_sym_off_ptr + s * 3 + 0, mask=mask, other=0.0)
        oy = tl.load(cart_sym_off_ptr + s * 3 + 1, mask=mask, other=0.0)
        oz = tl.load(cart_sym_off_ptr + s * 3 + 2, mask=mask, other=0.0)
        cox = tl.load(cell_off_cart_ptr + offs * 3 + 0, mask=mask, other=0.0)
        coy = tl.load(cell_off_cart_ptr + offs * 3 + 1, mask=mask, other=0.0)
        coz = tl.load(cell_off_cart_ptr + offs * 3 + 2, mask=mask, other=0.0)
        p2x = M00 * msx + M01 * msy + M02 * msz + ox + cox
        p2y = M10 * msx + M11 * msy + M12 * msz + oy + coy
        p2z = M20 * msx + M21 * msy + M22 * msz + oz + coz
    else:
        # Identity placeholders so every name is bound on this path too.
        M00 = 1.0
        M01 = 0.0
        M02 = 0.0
        M10 = 0.0
        M11 = 1.0
        M12 = 0.0
        M20 = 0.0
        M21 = 0.0
        M22 = 1.0
        p2x = msx
        p2y = msy
        p2z = msz

    dx = p2x - p1x
    dy = p2y - p1y
    dz = p2z - p1z
    # Same epsilon as the forward; also keeps the 1/d below finite.
    d = tl.sqrt(dx * dx + dy * dy + dz * dz + 1e-8)
    md = tl.load(min_d_ptr + offs, mask=mask, other=0.0)
    v = md + BUFFER - d
    active = v > 0.0
    v_safe = tl.where(active, v, 1.0)  # avoid log(0) on inactive pairs
    # dE/dv = C_REP * R_EXP * v^(R_EXP - 1), zeroed where v <= 0
    dEdv = C_REP * R_EXP * tl.exp((R_EXP - 1.0) * tl.log(v_safe))
    dEdv = tl.where(active, dEdv, 0.0)
    coef = grad_out * dEdv / d  # multiplier on (dx, dy, dz)

    # dv/dd = -1 and dd/dpos1 = -diff/d, so
    #   dE/dpos1 = +coef · diff = +g2
    #   dE/dpos2 = -coef · diff = -g2
    g2x = coef * dx
    g2y = coef * dy
    g2z = coef * dz
    if has_symmetry:
        # pos2 = R_cart · pos2_source + offset  =>
        #   dE/dpos2_source = R_cartᵀ · dE/dpos2 (sign applied at the add)
        gsx = M00 * g2x + M10 * g2y + M20 * g2z
        gsy = M01 * g2x + M11 * g2y + M21 * g2z
        gsz = M02 * g2x + M12 * g2y + M22 * g2z
    else:
        gsx = g2x
        gsy = g2y
        gsz = g2z

    tl.atomic_add(dxyz_ptr + j * 3 + 0, -gsx, mask=mask)
    tl.atomic_add(dxyz_ptr + j * 3 + 1, -gsy, mask=mask)
    tl.atomic_add(dxyz_ptr + j * 3 + 2, -gsz, mask=mask)
    tl.atomic_add(dxyz_ptr + i * 3 + 0, g2x, mask=mask)
    tl.atomic_add(dxyz_ptr + i * 3 + 1, g2y, mask=mask)
    tl.atomic_add(dxyz_ptr + i * 3 + 2, g2z, mask=mask)
def _build_cartesian_symops(symop_matrices, symop_translations,
fractional_matrix, inv_fractional_matrix):
"""Precompute the cartesian rotation+translation per symop: M·R·M⁻¹, M·t."""
M = fractional_matrix
Minv = inv_fractional_matrix
cart_mat = torch.einsum("ij,sjk,kl->sil",
M, symop_matrices.to(M.dtype), Minv)
cart_off = symop_translations.to(M.dtype) @ M.T
return cart_mat.contiguous(), cart_off.contiguous()
class _NonbondedHeavyMathTriton(torch.autograd.Function):
@staticmethod
def forward(ctx, xyz, indices, min_distances,
symop_indices, cell_offsets,
symop_matrices, symop_translations,
fractional_matrix, inv_fractional_matrix,
c_rep, r_exp, buffer_, sigma_vdw):
assert xyz.is_cuda and xyz.dtype == torch.float32
N = indices.shape[0]
nll = torch.empty(N, dtype=xyz.dtype, device=xyz.device)
has_sym = (
symop_indices is not None
and symop_indices.numel() > 0
and not bool((symop_indices == 0).all())
)
if has_sym:
cart_mat, cart_off = _build_cartesian_symops(
symop_matrices, symop_translations,
fractional_matrix, inv_fractional_matrix,
)
cell_off_cart = (cell_offsets.to(xyz.dtype)
@ fractional_matrix.T).contiguous()
symop_i32 = symop_indices.to(torch.int32).contiguous()
else:
cart_mat = torch.zeros(1, 3, 3, device=xyz.device, dtype=xyz.dtype)
cart_off = torch.zeros(1, 3, device=xyz.device, dtype=xyz.dtype)
cell_off_cart = torch.zeros(N, 3, device=xyz.device, dtype=xyz.dtype)
symop_i32 = torch.zeros(N, device=xyz.device, dtype=torch.int32)
# Scalar parameters become 0-D *device* tensors so neither the
# forward nor backward needs a host ``.item()`` synchronize.
# Use a helper that skips .to() when device+dtype already match
# — calling .to() on a same-device tensor still allocates a new
# tensor (forbidden during CUDA Graph capture).
def _as_device_tensor(t, ref):
if t.device == ref.device and t.dtype == ref.dtype:
return t if t.is_contiguous() else t.contiguous()
return t.to(device=ref.device, dtype=ref.dtype).contiguous()
sigma_vdw_t = _as_device_tensor(sigma_vdw, xyz)
c_rep_t = _as_device_tensor(c_rep, xyz)
r_exp_t = _as_device_tensor(r_exp, xyz)
log_sig_plus = torch.log(sigma_vdw_t) + 0.5 * _LOG_2PI
buf_f = float(buffer_)
BLOCK = 256
grid = (triton.cdiv(N, BLOCK),)
_nb_fwd_kernel[grid](
xyz, indices, min_distances, symop_i32, cart_mat, cart_off,
cell_off_cart, nll,
c_rep_t, r_exp_t, log_sig_plus,
has_symmetry=bool(has_sym),
N=N, BUFFER=buf_f, BLOCK=BLOCK,
)
ctx.save_for_backward(
xyz, indices, min_distances,
symop_i32, cart_mat, cart_off, cell_off_cart,
c_rep_t, r_exp_t,
)
ctx.buffer = buf_f
ctx.has_sym = bool(has_sym)
return nll.sum()
@staticmethod
def backward(ctx, grad_out):
(xyz, indices, min_distances, symop_i32,
cart_mat, cart_off, cell_off_cart, c_rep_t, r_exp_t) = ctx.saved_tensors
N = indices.shape[0]
dxyz = torch.zeros_like(xyz)
BLOCK = 256
grid = (triton.cdiv(N, BLOCK),)
_nb_bwd_kernel[grid](
xyz, indices, min_distances, symop_i32,
cart_mat, cart_off, cell_off_cart,
grad_out, dxyz,
c_rep_t, r_exp_t,
has_symmetry=ctx.has_sym,
N=N, BUFFER=ctx.buffer,
BLOCK=BLOCK,
)
return (dxyz,) + (None,) * 12
def nonbonded_heavy_math_triton(
    xyz, indices, min_distances,
    symop_indices, cell_offsets,
    symop_matrices, symop_translations,
    fractional_matrix, inv_fractional_matrix,
    c_rep, r_exp, buffer_, sigma_vdw,
):
    """Triton-backed heavy-heavy VDW prolsq NLL with analytic backward.

    Args:
        xyz: (N_atoms, 3) float32 CUDA coordinates.
        indices: (N_pairs, 2) atom indices (i, j) of each pair.
        min_distances: (N_pairs,) d_vdw for each pair.
        symop_indices: (N_pairs,) symop applied to atom j, or ``None`` for
            no symmetry.
        cell_offsets: (N_pairs, 3) lattice offsets (fractional) of atom j.
        symop_matrices, symop_translations: rotation / translation parts
            of the symops.
        fractional_matrix, inv_fractional_matrix: the (3, 3) cell matrix M
            and its inverse, used to build the cartesian symops.
        c_rep, r_exp, sigma_vdw: scalar parameters (0-D tensors).
        buffer_: scalar distance buffer added to d_vdw.

    Returns:
        0-D tensor: sum over pairs of
        c_rep · max(0, d_vdw + buffer_ − d)^r_exp + log(sigma_vdw) + 0.5·log(2π).
        Only ``xyz`` receives a gradient.
    """
    return _NonbondedHeavyMathTriton.apply(
        xyz, indices, min_distances,
        symop_indices, cell_offsets,
        symop_matrices, symop_translations,
        fractional_matrix, inv_fractional_matrix,
        c_rep, r_exp, buffer_, sigma_vdw,
    )