"""
Separable Gaussian density splatting via Triton.
Factorizes exp(-alpha * f^T G f) into 1D Gaussian tables along each
fractional axis, with optional 2D cross-term corrections for non-orthogonal
cells. One program per atom. All 1D/2D tables live in a small per-atom
scratch buffer (~0.8-3 KB) that stays hot in L1 cache.
Eliminates the real_space_grid tensor (~500 MB) and all PBC matrix
operations that the fused kernel requires.
Forward + backward kernels with full autograd support for xyz, b, occ.
For non-orthogonal cells, uses combined exponent exp(-alpha*r²) to avoid
numerical overflow from separate diagonal × cross-term exp() products.
"""
import math
from typing import Optional
import torch
import triton
import triton.language as tl
# =============================================================================
# Constants
# =============================================================================
# pi and its powers used in the Gaussian normalisation (the kernels embed the
# same values as literal `pi_sq` / `pi_1p5` constexprs).
PI: float = math.pi
PI_SQ: float = PI * PI  # pi^2 ≈ 9.869604401089358
PI_1P5: float = math.sqrt(PI) * PI  # pi^1.5 ≈ 5.568327996831708
# =============================================================================
# Forward kernel
# =============================================================================
@triton.jit
def _separable_fwd_kernel(
    # Pointers
    density_map_ptr,  # (nx*ny*nz,) float32 — accumulated via atomic_add
    xyz_ptr,  # (N_atoms, 3) float32 Cartesian positions
    b_ptr,  # (N_atoms,) float32
    A_ptr,  # (N_atoms, 5) float32
    B_ptr,  # (N_atoms, 5) float32
    occ_ptr,  # (N_atoms,) float32
    offsets_ptr,  # (N_sphere, 3) int16 — sphere voxel offsets
    inv_frac_ptr,  # (9,) float32 row-major
    scratch_ptr,  # (N_atoms, SCRATCH_SIZE) float32
    # Metric tensor components
    G11, G22, G33,
    G12, G13, G23,
    # Grid spacing (fractional)
    inv_grid_x, inv_grid_y, inv_grid_z,
    # Dimensions
    nx: tl.constexpr,
    ny: tl.constexpr,
    nz: tl.constexpr,
    N_sphere: tl.constexpr,
    N_AXIS: tl.constexpr,
    half_n: tl.constexpr,
    SCRATCH_PER_ATOM: tl.constexpr,
    BLOCK_V: tl.constexpr,
    # Cross-term flags
    COMPUTE_XY: tl.constexpr,
    COMPUTE_XZ: tl.constexpr,
    COMPUTE_YZ: tl.constexpr,
    STORE_CROSS_TABLES: tl.constexpr,
):
    """Splat one atom's 5-Gaussian density onto the periodic grid.

    One program per atom (``program_id(0)``). The program:

    1. Loads the atom's position, isotropic B, occupancy and five (A_c, B_c)
       coefficients. From these it forms, per component,
       ``Bt_c = max((B_c + b) / 4, 0.1)``, ``alpha_c = pi^2 / Bt_c`` and the
       amplitude ``A_c * occ * pi^1.5 / Bt_c^1.5``.
    2. Converts the position to fractional coordinates wrapped to [0, 1). It
       picks the nearest grid index as the anchor and keeps the sub-voxel
       remainder.
    3. Writes tables into this atom's slice of ``scratch_ptr``:
       - per-axis displacement tables;
       - 1D Gaussian tables, unless the combined-exponent path is selected;
       - 2D cross-term tables, if ``STORE_CROSS_TABLES``.
    4. Walks the sphere offsets in tiles of ``BLOCK_V``. It evaluates the
       density from the tables and ``atomic_add``s it at the PBC-wrapped
       voxel.

    Per-atom scratch layout (``SCRATCH_PER_ATOM`` floats at
    ``atom * SCRATCH_PER_ATOM``, ``N = N_AXIS``):
    ``[diag_x(5N), diag_y(5N), diag_z(5N), delta_x(N), delta_y(N), delta_z(N)]``.
    When ``STORE_CROSS_TABLES`` is set, 5*N*N-float cross tables follow, in
    the order XY, XZ, YZ (only for the enabled terms). Table indices are
    ``offset + half_n``, so every sphere offset must satisfy
    ``0 <= offset + half_n < N_AXIS``.
    """
    atom = tl.program_id(0)
    # ---- Stage 1: Load per-atom parameters ----
    b_iso = tl.load(b_ptr + atom)
    occ = tl.load(occ_ptr + atom)
    ax = tl.load(xyz_ptr + atom * 3 + 0)
    ay = tl.load(xyz_ptr + atom * 3 + 1)
    az = tl.load(xyz_ptr + atom * 3 + 2)
    # ITC92 coefficients (5 components)
    A0 = tl.load(A_ptr + atom * 5 + 0)
    A1 = tl.load(A_ptr + atom * 5 + 1)
    A2 = tl.load(A_ptr + atom * 5 + 2)
    A3 = tl.load(A_ptr + atom * 5 + 3)
    A4 = tl.load(A_ptr + atom * 5 + 4)
    B0 = tl.load(B_ptr + atom * 5 + 0)
    B1 = tl.load(B_ptr + atom * 5 + 1)
    B2 = tl.load(B_ptr + atom * 5 + 2)
    B3 = tl.load(B_ptr + atom * 5 + 3)
    B4 = tl.load(B_ptr + atom * 5 + 4)
    # B_total, alpha, A_norm per component (B_total clamped at 0.1 so alpha
    # and the normalisation stay finite)
    Bt0 = tl.maximum((B0 + b_iso) * 0.25, 0.1)
    Bt1 = tl.maximum((B1 + b_iso) * 0.25, 0.1)
    Bt2 = tl.maximum((B2 + b_iso) * 0.25, 0.1)
    Bt3 = tl.maximum((B3 + b_iso) * 0.25, 0.1)
    Bt4 = tl.maximum((B4 + b_iso) * 0.25, 0.1)
    pi_sq: tl.constexpr = 9.869604401089358
    pi_1p5: tl.constexpr = 5.568327996831708
    al0 = pi_sq / Bt0
    al1 = pi_sq / Bt1
    al2 = pi_sq / Bt2
    al3 = pi_sq / Bt3
    al4 = pi_sq / Bt4
    An0 = A0 * occ * pi_1p5 / (Bt0 * tl.sqrt(Bt0))
    An1 = A1 * occ * pi_1p5 / (Bt1 * tl.sqrt(Bt1))
    An2 = A2 * occ * pi_1p5 / (Bt2 * tl.sqrt(Bt2))
    An3 = A3 * occ * pi_1p5 / (Bt3 * tl.sqrt(Bt3))
    An4 = A4 * occ * pi_1p5 / (Bt4 * tl.sqrt(Bt4))
    # ---- Cartesian → fractional conversion ----
    if0 = tl.load(inv_frac_ptr + 0)
    if1 = tl.load(inv_frac_ptr + 1)
    if2 = tl.load(inv_frac_ptr + 2)
    if3 = tl.load(inv_frac_ptr + 3)
    if4 = tl.load(inv_frac_ptr + 4)
    if5 = tl.load(inv_frac_ptr + 5)
    if6 = tl.load(inv_frac_ptr + 6)
    if7 = tl.load(inv_frac_ptr + 7)
    if8 = tl.load(inv_frac_ptr + 8)
    frac_x = ax * if0 + ay * if1 + az * if2
    frac_y = ax * if3 + ay * if4 + az * if5
    frac_z = ax * if6 + ay * if7 + az * if8
    # Wrap to [0, 1)
    frac_x = frac_x - tl.extra.cuda.libdevice.floor(frac_x)
    frac_y = frac_y - tl.extra.cuda.libdevice.floor(frac_y)
    frac_z = frac_z - tl.extra.cuda.libdevice.floor(frac_z)
    # Grid anchor (nearest grid index; may equal n, wrapped by the modulo below)
    cix = tl.extra.cuda.libdevice.round(frac_x * nx).to(tl.int32)
    ciy = tl.extra.cuda.libdevice.round(frac_y * ny).to(tl.int32)
    ciz = tl.extra.cuda.libdevice.round(frac_z * nz).to(tl.int32)
    # Sub-grid offset (fractional)
    sub_x = frac_x - cix.to(tl.float32) * inv_grid_x
    sub_y = frac_y - ciy.to(tl.float32) * inv_grid_y
    sub_z = frac_z - ciz.to(tl.float32) * inv_grid_z
    # ---- Stage 2: Build 1D tables ----
    # Scratch layout: [diag_x(5*N), diag_y(5*N), diag_z(5*N),
    #                  delta_x(N), delta_y(N), delta_z(N)]
    # When using combined-exponent path (a), only deltas are read per voxel;
    # skip the 15 1D exp-table computations (saves 15*N_AXIS exp() calls).
    _USE_COMBINED: tl.constexpr = (
        not STORE_CROSS_TABLES and (COMPUTE_XY or COMPUTE_XZ or COMPUTE_YZ)
    )
    base = atom * SCRATCH_PER_ATOM
    axis_idx = tl.arange(0, N_AXIS)
    half_n_f: tl.constexpr = half_n  # promoted to float on use
    # delta[t] = (t - half_n) * inv_grid - sub: fractional displacement from
    # the atom to the voxel at sphere offset (t - half_n).
    # --- Axis X ---
    delta_x = (axis_idx.to(tl.float32) - half_n_f) * inv_grid_x - sub_x
    tl.store(scratch_ptr + base + 15 * N_AXIS + axis_idx, delta_x)  # store deltas
    if not _USE_COMBINED:
        dx2 = delta_x * delta_x
        diag_x0 = tl.exp(-al0 * G11 * dx2)
        diag_x1 = tl.exp(-al1 * G11 * dx2)
        diag_x2 = tl.exp(-al2 * G11 * dx2)
        diag_x3 = tl.exp(-al3 * G11 * dx2)
        diag_x4 = tl.exp(-al4 * G11 * dx2)
        tl.store(scratch_ptr + base + 0 * N_AXIS + axis_idx, diag_x0)
        tl.store(scratch_ptr + base + 1 * N_AXIS + axis_idx, diag_x1)
        tl.store(scratch_ptr + base + 2 * N_AXIS + axis_idx, diag_x2)
        tl.store(scratch_ptr + base + 3 * N_AXIS + axis_idx, diag_x3)
        tl.store(scratch_ptr + base + 4 * N_AXIS + axis_idx, diag_x4)
    # --- Axis Y ---
    delta_y = (axis_idx.to(tl.float32) - half_n_f) * inv_grid_y - sub_y
    tl.store(scratch_ptr + base + 16 * N_AXIS + axis_idx, delta_y)
    if not _USE_COMBINED:
        dy2 = delta_y * delta_y
        diag_y0 = tl.exp(-al0 * G22 * dy2)
        diag_y1 = tl.exp(-al1 * G22 * dy2)
        diag_y2 = tl.exp(-al2 * G22 * dy2)
        diag_y3 = tl.exp(-al3 * G22 * dy2)
        diag_y4 = tl.exp(-al4 * G22 * dy2)
        tl.store(scratch_ptr + base + 5 * N_AXIS + axis_idx, diag_y0)
        tl.store(scratch_ptr + base + 6 * N_AXIS + axis_idx, diag_y1)
        tl.store(scratch_ptr + base + 7 * N_AXIS + axis_idx, diag_y2)
        tl.store(scratch_ptr + base + 8 * N_AXIS + axis_idx, diag_y3)
        tl.store(scratch_ptr + base + 9 * N_AXIS + axis_idx, diag_y4)
    # --- Axis Z ---
    delta_z = (axis_idx.to(tl.float32) - half_n_f) * inv_grid_z - sub_z
    tl.store(scratch_ptr + base + 17 * N_AXIS + axis_idx, delta_z)
    if not _USE_COMBINED:
        dz2 = delta_z * delta_z
        diag_z0 = tl.exp(-al0 * G33 * dz2)
        diag_z1 = tl.exp(-al1 * G33 * dz2)
        diag_z2 = tl.exp(-al2 * G33 * dz2)
        diag_z3 = tl.exp(-al3 * G33 * dz2)
        diag_z4 = tl.exp(-al4 * G33 * dz2)
        tl.store(scratch_ptr + base + 10 * N_AXIS + axis_idx, diag_z0)
        tl.store(scratch_ptr + base + 11 * N_AXIS + axis_idx, diag_z1)
        tl.store(scratch_ptr + base + 12 * N_AXIS + axis_idx, diag_z2)
        tl.store(scratch_ptr + base + 13 * N_AXIS + axis_idx, diag_z3)
        tl.store(scratch_ptr + base + 14 * N_AXIS + axis_idx, diag_z4)
    # ---- Stage 3: 2D cross-term tables (conditional) ----
    # cross_base starts after the 1D tables + deltas (18 * N_AXIS)
    cross_base = base + 18 * N_AXIS
    if STORE_CROSS_TABLES:
        if COMPUTE_XY:
            # cross_xy[c][i*N_AXIS + j] for each component c
            # Vectorised over N_AXIS² elements
            idx_2d = tl.arange(0, N_AXIS * N_AXIS)
            ii = idx_2d // N_AXIS  # row (x index)
            jj = idx_2d % N_AXIS  # col (y index)
            dx_i = (ii.to(tl.float32) - half_n_f) * inv_grid_x - sub_x
            dy_j = (jj.to(tl.float32) - half_n_f) * inv_grid_y - sub_y
            prod_xy = dx_i * dy_j
            off_xy = cross_base
            tl.store(scratch_ptr + off_xy + 0 * N_AXIS * N_AXIS + idx_2d,
                     tl.exp(-al0 * 2.0 * G12 * prod_xy))
            tl.store(scratch_ptr + off_xy + 1 * N_AXIS * N_AXIS + idx_2d,
                     tl.exp(-al1 * 2.0 * G12 * prod_xy))
            tl.store(scratch_ptr + off_xy + 2 * N_AXIS * N_AXIS + idx_2d,
                     tl.exp(-al2 * 2.0 * G12 * prod_xy))
            tl.store(scratch_ptr + off_xy + 3 * N_AXIS * N_AXIS + idx_2d,
                     tl.exp(-al3 * 2.0 * G12 * prod_xy))
            tl.store(scratch_ptr + off_xy + 4 * N_AXIS * N_AXIS + idx_2d,
                     tl.exp(-al4 * 2.0 * G12 * prod_xy))
            cross_base = cross_base + 5 * N_AXIS * N_AXIS
        if COMPUTE_XZ:
            idx_2d = tl.arange(0, N_AXIS * N_AXIS)
            ii = idx_2d // N_AXIS
            kk = idx_2d % N_AXIS
            dx_i = (ii.to(tl.float32) - half_n_f) * inv_grid_x - sub_x
            dz_k = (kk.to(tl.float32) - half_n_f) * inv_grid_z - sub_z
            prod_xz = dx_i * dz_k
            off_xz = cross_base
            tl.store(scratch_ptr + off_xz + 0 * N_AXIS * N_AXIS + idx_2d,
                     tl.exp(-al0 * 2.0 * G13 * prod_xz))
            tl.store(scratch_ptr + off_xz + 1 * N_AXIS * N_AXIS + idx_2d,
                     tl.exp(-al1 * 2.0 * G13 * prod_xz))
            tl.store(scratch_ptr + off_xz + 2 * N_AXIS * N_AXIS + idx_2d,
                     tl.exp(-al2 * 2.0 * G13 * prod_xz))
            tl.store(scratch_ptr + off_xz + 3 * N_AXIS * N_AXIS + idx_2d,
                     tl.exp(-al3 * 2.0 * G13 * prod_xz))
            tl.store(scratch_ptr + off_xz + 4 * N_AXIS * N_AXIS + idx_2d,
                     tl.exp(-al4 * 2.0 * G13 * prod_xz))
            cross_base = cross_base + 5 * N_AXIS * N_AXIS
        if COMPUTE_YZ:
            idx_2d = tl.arange(0, N_AXIS * N_AXIS)
            jj = idx_2d // N_AXIS
            kk = idx_2d % N_AXIS
            dy_j = (jj.to(tl.float32) - half_n_f) * inv_grid_y - sub_y
            dz_k = (kk.to(tl.float32) - half_n_f) * inv_grid_z - sub_z
            prod_yz = dy_j * dz_k
            off_yz = cross_base
            tl.store(scratch_ptr + off_yz + 0 * N_AXIS * N_AXIS + idx_2d,
                     tl.exp(-al0 * 2.0 * G23 * prod_yz))
            tl.store(scratch_ptr + off_yz + 1 * N_AXIS * N_AXIS + idx_2d,
                     tl.exp(-al1 * 2.0 * G23 * prod_yz))
            tl.store(scratch_ptr + off_yz + 2 * N_AXIS * N_AXIS + idx_2d,
                     tl.exp(-al2 * 2.0 * G23 * prod_yz))
            tl.store(scratch_ptr + off_yz + 3 * N_AXIS * N_AXIS + idx_2d,
                     tl.exp(-al3 * 2.0 * G23 * prod_yz))
            tl.store(scratch_ptr + off_yz + 4 * N_AXIS * N_AXIS + idx_2d,
                     tl.exp(-al4 * 2.0 * G23 * prod_yz))
    # The tables above are stored by lanes laid out over axis_idx / idx_2d,
    # but gathered below at sphere-offset indices, i.e. by other threads of
    # this program. Triton does not order global-memory stores and loads
    # between threads, so synchronise all threads before the gather.
    tl.debug_barrier()
    # ---- Stage 4: Assemble sphere voxels and scatter ----
    # Precompute cross-term table base offsets (compile-time)
    cross_off_first = base + 18 * N_AXIS
    # For STORE_CROSS_TABLES: tables laid out in order XY, XZ, YZ
    # (only present if the corresponding COMPUTE flag is True)
    v_offsets = tl.arange(0, BLOCK_V)
    for v_start in range(0, N_sphere, BLOCK_V):
        v = v_start + v_offsets
        mask = v < N_sphere
        # Load sphere offsets (int16 → int32)
        si = tl.load(offsets_ptr + v * 3 + 0, mask=mask, other=0).to(tl.int32)
        sj = tl.load(offsets_ptr + v * 3 + 1, mask=mask, other=0).to(tl.int32)
        sk = tl.load(offsets_ptr + v * 3 + 2, mask=mask, other=0).to(tl.int32)
        # Table indices
        ti = si + half_n
        tj = sj + half_n
        tk = sk + half_n
        # Compute per-component density contributions c0-c4.
        #
        # Two paths:
        # (a) Combined exponent: compute full r²=d^T G d from stored deltas
        #     and use a single exp(-alpha*r²) per component.
        #     Avoids exp(-big_diag)*exp(+big_cross) = 0*inf = NaN.
        # (b) Pure separable or stored cross tables: load 1D diag values
        #     and (optionally) multiply by pre-stored cross-term tables.
        if _USE_COMBINED:
            # Path (a): combined exponent from stored deltas
            dxi = tl.load(scratch_ptr + base + 15 * N_AXIS + ti, mask=mask, other=0.0)
            dyj = tl.load(scratch_ptr + base + 16 * N_AXIS + tj, mask=mask, other=0.0)
            dzk = tl.load(scratch_ptr + base + 17 * N_AXIS + tk, mask=mask, other=0.0)
            r2 = G11 * dxi * dxi + G22 * dyj * dyj + G33 * dzk * dzk
            if COMPUTE_XY:
                r2 = r2 + 2.0 * G12 * dxi * dyj
            if COMPUTE_XZ:
                r2 = r2 + 2.0 * G13 * dxi * dzk
            if COMPUTE_YZ:
                r2 = r2 + 2.0 * G23 * dyj * dzk
            c0 = An0 * tl.exp(-al0 * r2)
            c1 = An1 * tl.exp(-al1 * r2)
            c2 = An2 * tl.exp(-al2 * r2)
            c3 = An3 * tl.exp(-al3 * r2)
            c4 = An4 * tl.exp(-al4 * r2)
        else:
            # Path (b): separable 1D diagonal tables
            vx0 = tl.load(scratch_ptr + base + 0 * N_AXIS + ti, mask=mask, other=0.0)
            vy0 = tl.load(scratch_ptr + base + 5 * N_AXIS + tj, mask=mask, other=0.0)
            vz0 = tl.load(scratch_ptr + base + 10 * N_AXIS + tk, mask=mask, other=0.0)
            c0 = An0 * vx0 * vy0 * vz0
            vx1 = tl.load(scratch_ptr + base + 1 * N_AXIS + ti, mask=mask, other=0.0)
            vy1 = tl.load(scratch_ptr + base + 6 * N_AXIS + tj, mask=mask, other=0.0)
            vz1 = tl.load(scratch_ptr + base + 11 * N_AXIS + tk, mask=mask, other=0.0)
            c1 = An1 * vx1 * vy1 * vz1
            vx2 = tl.load(scratch_ptr + base + 2 * N_AXIS + ti, mask=mask, other=0.0)
            vy2 = tl.load(scratch_ptr + base + 7 * N_AXIS + tj, mask=mask, other=0.0)
            vz2 = tl.load(scratch_ptr + base + 12 * N_AXIS + tk, mask=mask, other=0.0)
            c2 = An2 * vx2 * vy2 * vz2
            vx3 = tl.load(scratch_ptr + base + 3 * N_AXIS + ti, mask=mask, other=0.0)
            vy3 = tl.load(scratch_ptr + base + 8 * N_AXIS + tj, mask=mask, other=0.0)
            vz3 = tl.load(scratch_ptr + base + 13 * N_AXIS + tk, mask=mask, other=0.0)
            c3 = An3 * vx3 * vy3 * vz3
            vx4 = tl.load(scratch_ptr + base + 4 * N_AXIS + ti, mask=mask, other=0.0)
            vy4 = tl.load(scratch_ptr + base + 9 * N_AXIS + tj, mask=mask, other=0.0)
            vz4 = tl.load(scratch_ptr + base + 14 * N_AXIS + tk, mask=mask, other=0.0)
            c4 = An4 * vx4 * vy4 * vz4
            if STORE_CROSS_TABLES:
                ct_base = cross_off_first
                if COMPUTE_XY:
                    idx_xy = ti * N_AXIS + tj
                    c0 *= tl.load(scratch_ptr + ct_base + 0 * N_AXIS * N_AXIS + idx_xy,
                                  mask=mask, other=1.0)
                    c1 *= tl.load(scratch_ptr + ct_base + 1 * N_AXIS * N_AXIS + idx_xy,
                                  mask=mask, other=1.0)
                    c2 *= tl.load(scratch_ptr + ct_base + 2 * N_AXIS * N_AXIS + idx_xy,
                                  mask=mask, other=1.0)
                    c3 *= tl.load(scratch_ptr + ct_base + 3 * N_AXIS * N_AXIS + idx_xy,
                                  mask=mask, other=1.0)
                    c4 *= tl.load(scratch_ptr + ct_base + 4 * N_AXIS * N_AXIS + idx_xy,
                                  mask=mask, other=1.0)
                    ct_base = ct_base + 5 * N_AXIS * N_AXIS
                if COMPUTE_XZ:
                    idx_xz = ti * N_AXIS + tk
                    c0 *= tl.load(scratch_ptr + ct_base + 0 * N_AXIS * N_AXIS + idx_xz,
                                  mask=mask, other=1.0)
                    c1 *= tl.load(scratch_ptr + ct_base + 1 * N_AXIS * N_AXIS + idx_xz,
                                  mask=mask, other=1.0)
                    c2 *= tl.load(scratch_ptr + ct_base + 2 * N_AXIS * N_AXIS + idx_xz,
                                  mask=mask, other=1.0)
                    c3 *= tl.load(scratch_ptr + ct_base + 3 * N_AXIS * N_AXIS + idx_xz,
                                  mask=mask, other=1.0)
                    c4 *= tl.load(scratch_ptr + ct_base + 4 * N_AXIS * N_AXIS + idx_xz,
                                  mask=mask, other=1.0)
                    ct_base = ct_base + 5 * N_AXIS * N_AXIS
                if COMPUTE_YZ:
                    idx_yz = tj * N_AXIS + tk
                    c0 *= tl.load(scratch_ptr + ct_base + 0 * N_AXIS * N_AXIS + idx_yz,
                                  mask=mask, other=1.0)
                    c1 *= tl.load(scratch_ptr + ct_base + 1 * N_AXIS * N_AXIS + idx_yz,
                                  mask=mask, other=1.0)
                    c2 *= tl.load(scratch_ptr + ct_base + 2 * N_AXIS * N_AXIS + idx_yz,
                                  mask=mask, other=1.0)
                    c3 *= tl.load(scratch_ptr + ct_base + 3 * N_AXIS * N_AXIS + idx_yz,
                                  mask=mask, other=1.0)
                    c4 *= tl.load(scratch_ptr + ct_base + 4 * N_AXIS * N_AXIS + idx_yz,
                                  mask=mask, other=1.0)
        density = c0 + c1 + c2 + c3 + c4
        # PBC-wrapped grid index (the where() maps negative remainders into range)
        gi = (cix + si) % nx
        gi = tl.where(gi < 0, gi + nx, gi)
        gj = (ciy + sj) % ny
        gj = tl.where(gj < 0, gj + ny, gj)
        gk = (ciz + sk) % nz
        gk = tl.where(gk < 0, gk + nz, gk)
        flat_idx = (gi.to(tl.int64) * ny + gj.to(tl.int64)) * nz + gk.to(tl.int64)
        # Atomic: neighbouring atoms (other programs) hit the same voxels.
        tl.atomic_add(density_map_ptr + flat_idx, density, mask=mask)
# =============================================================================
# Backward kernel
# =============================================================================
@triton.jit
def _separable_bwd_kernel(
    # Forward inputs (read-only)
    grad_density_map_ptr,  # (nx*ny*nz,) float32
    xyz_ptr,
    b_ptr,
    A_ptr,
    B_ptr,
    occ_ptr,
    offsets_ptr,
    inv_frac_ptr,
    scratch_ptr,
    # Metric tensor
    G11, G22, G33,
    G12, G13, G23,
    # Grid params
    inv_grid_x, inv_grid_y, inv_grid_z,
    # Gradient outputs (pointers before constexpr)
    grad_frac_ptr,  # (N_atoms, 3) float32
    grad_b_ptr,  # (N_atoms,) float32
    grad_occ_ptr,  # (N_atoms,) float32
    # Constexpr
    nx: tl.constexpr,
    ny: tl.constexpr,
    nz: tl.constexpr,
    N_sphere: tl.constexpr,
    N_AXIS: tl.constexpr,
    half_n: tl.constexpr,
    SCRATCH_PER_ATOM: tl.constexpr,
    BLOCK_V: tl.constexpr,
    COMPUTE_XY: tl.constexpr,
    COMPUTE_XZ: tl.constexpr,
    COMPUTE_YZ: tl.constexpr,
    STORE_CROSS_TABLES: tl.constexpr,
):
    """Accumulate one atom's gradients w.r.t. fractional position, B and occ.

    One program per atom (``program_id(0)``). The kernel recomputes exactly the
    per-atom quantities and scratch tables built by the forward kernel. It
    then gathers the upstream gradient at every PBC-wrapped sphere voxel and
    reduces the following into scalars per atom:

    - ``grad_frac``: d/d(frac) of ``rho_c = amp_c * exp(-alpha_c * d^T G d)``,
      where ``d`` is the voxel-minus-atom fractional displacement, so
      ``d(d)/d(frac) = -1``;
    - ``grad_b``: through ``Bt_c = max((B_c + b)/4, 0.1)``; zero when clamped;
    - ``grad_occ``: the occupancy-free density, since the density is linear
      in ``occ``.

    The scratch layout matches the forward kernel. Results are written with
    plain stores, since each program owns its atom's gradient slots.
    """
    atom = tl.program_id(0)
    # ---- Stage 1: Load & compute (identical to forward) ----
    b_iso = tl.load(b_ptr + atom)
    occ = tl.load(occ_ptr + atom)
    ax = tl.load(xyz_ptr + atom * 3 + 0)
    ay = tl.load(xyz_ptr + atom * 3 + 1)
    az = tl.load(xyz_ptr + atom * 3 + 2)
    A0 = tl.load(A_ptr + atom * 5 + 0)
    A1 = tl.load(A_ptr + atom * 5 + 1)
    A2 = tl.load(A_ptr + atom * 5 + 2)
    A3 = tl.load(A_ptr + atom * 5 + 3)
    A4 = tl.load(A_ptr + atom * 5 + 4)
    B0 = tl.load(B_ptr + atom * 5 + 0)
    B1 = tl.load(B_ptr + atom * 5 + 1)
    B2 = tl.load(B_ptr + atom * 5 + 2)
    B3 = tl.load(B_ptr + atom * 5 + 3)
    B4 = tl.load(B_ptr + atom * 5 + 4)
    Bt0 = tl.maximum((B0 + b_iso) * 0.25, 0.1)
    Bt1 = tl.maximum((B1 + b_iso) * 0.25, 0.1)
    Bt2 = tl.maximum((B2 + b_iso) * 0.25, 0.1)
    Bt3 = tl.maximum((B3 + b_iso) * 0.25, 0.1)
    Bt4 = tl.maximum((B4 + b_iso) * 0.25, 0.1)
    # 1.0 where Bt is unclamped, 0.0 where the 0.1 floor is active (no gradient)
    clamp0 = ((B0 + b_iso) * 0.25 > 0.1).to(tl.float32)
    clamp1 = ((B1 + b_iso) * 0.25 > 0.1).to(tl.float32)
    clamp2 = ((B2 + b_iso) * 0.25 > 0.1).to(tl.float32)
    clamp3 = ((B3 + b_iso) * 0.25 > 0.1).to(tl.float32)
    clamp4 = ((B4 + b_iso) * 0.25 > 0.1).to(tl.float32)
    pi_sq: tl.constexpr = 9.869604401089358
    pi_1p5: tl.constexpr = 5.568327996831708
    al0 = pi_sq / Bt0
    al1 = pi_sq / Bt1
    al2 = pi_sq / Bt2
    al3 = pi_sq / Bt3
    al4 = pi_sq / Bt4
    # Occupancy-free amplitudes: rho_c = occ * Au_c * gauss_c. Keeping occ
    # separate gives d(density)/d(occ) = sum_c Au_c * gauss_c directly, which
    # stays correct at occ == 0 (dividing the density by occ would not).
    Au0 = A0 * pi_1p5 / (Bt0 * tl.sqrt(Bt0))
    Au1 = A1 * pi_1p5 / (Bt1 * tl.sqrt(Bt1))
    Au2 = A2 * pi_1p5 / (Bt2 * tl.sqrt(Bt2))
    Au3 = A3 * pi_1p5 / (Bt3 * tl.sqrt(Bt3))
    Au4 = A4 * pi_1p5 / (Bt4 * tl.sqrt(Bt4))
    # Frac conversion
    if0 = tl.load(inv_frac_ptr + 0)
    if1 = tl.load(inv_frac_ptr + 1)
    if2 = tl.load(inv_frac_ptr + 2)
    if3 = tl.load(inv_frac_ptr + 3)
    if4 = tl.load(inv_frac_ptr + 4)
    if5 = tl.load(inv_frac_ptr + 5)
    if6 = tl.load(inv_frac_ptr + 6)
    if7 = tl.load(inv_frac_ptr + 7)
    if8 = tl.load(inv_frac_ptr + 8)
    frac_x = ax * if0 + ay * if1 + az * if2
    frac_y = ax * if3 + ay * if4 + az * if5
    frac_z = ax * if6 + ay * if7 + az * if8
    frac_x = frac_x - tl.extra.cuda.libdevice.floor(frac_x)
    frac_y = frac_y - tl.extra.cuda.libdevice.floor(frac_y)
    frac_z = frac_z - tl.extra.cuda.libdevice.floor(frac_z)
    cix = tl.extra.cuda.libdevice.round(frac_x * nx).to(tl.int32)
    ciy = tl.extra.cuda.libdevice.round(frac_y * ny).to(tl.int32)
    ciz = tl.extra.cuda.libdevice.round(frac_z * nz).to(tl.int32)
    sub_x = frac_x - cix.to(tl.float32) * inv_grid_x
    sub_y = frac_y - ciy.to(tl.float32) * inv_grid_y
    sub_z = frac_z - ciz.to(tl.float32) * inv_grid_z
    # ---- Stage 2-3: Rebuild tables ----
    # Combined-exponent path only reads deltas; skip 1D exp tables.
    _USE_COMBINED: tl.constexpr = (
        not STORE_CROSS_TABLES and (COMPUTE_XY or COMPUTE_XZ or COMPUTE_YZ)
    )
    base = atom * SCRATCH_PER_ATOM
    axis_idx = tl.arange(0, N_AXIS)
    half_n_f: tl.constexpr = half_n
    delta_x_vec = (axis_idx.to(tl.float32) - half_n_f) * inv_grid_x - sub_x
    tl.store(scratch_ptr + base + 15 * N_AXIS + axis_idx, delta_x_vec)
    if not _USE_COMBINED:
        dx2 = delta_x_vec * delta_x_vec
        tl.store(scratch_ptr + base + 0 * N_AXIS + axis_idx, tl.exp(-al0 * G11 * dx2))
        tl.store(scratch_ptr + base + 1 * N_AXIS + axis_idx, tl.exp(-al1 * G11 * dx2))
        tl.store(scratch_ptr + base + 2 * N_AXIS + axis_idx, tl.exp(-al2 * G11 * dx2))
        tl.store(scratch_ptr + base + 3 * N_AXIS + axis_idx, tl.exp(-al3 * G11 * dx2))
        tl.store(scratch_ptr + base + 4 * N_AXIS + axis_idx, tl.exp(-al4 * G11 * dx2))
    delta_y_vec = (axis_idx.to(tl.float32) - half_n_f) * inv_grid_y - sub_y
    tl.store(scratch_ptr + base + 16 * N_AXIS + axis_idx, delta_y_vec)
    if not _USE_COMBINED:
        dy2 = delta_y_vec * delta_y_vec
        tl.store(scratch_ptr + base + 5 * N_AXIS + axis_idx, tl.exp(-al0 * G22 * dy2))
        tl.store(scratch_ptr + base + 6 * N_AXIS + axis_idx, tl.exp(-al1 * G22 * dy2))
        tl.store(scratch_ptr + base + 7 * N_AXIS + axis_idx, tl.exp(-al2 * G22 * dy2))
        tl.store(scratch_ptr + base + 8 * N_AXIS + axis_idx, tl.exp(-al3 * G22 * dy2))
        tl.store(scratch_ptr + base + 9 * N_AXIS + axis_idx, tl.exp(-al4 * G22 * dy2))
    delta_z_vec = (axis_idx.to(tl.float32) - half_n_f) * inv_grid_z - sub_z
    tl.store(scratch_ptr + base + 17 * N_AXIS + axis_idx, delta_z_vec)
    if not _USE_COMBINED:
        dz2 = delta_z_vec * delta_z_vec
        tl.store(scratch_ptr + base + 10 * N_AXIS + axis_idx, tl.exp(-al0 * G33 * dz2))
        tl.store(scratch_ptr + base + 11 * N_AXIS + axis_idx, tl.exp(-al1 * G33 * dz2))
        tl.store(scratch_ptr + base + 12 * N_AXIS + axis_idx, tl.exp(-al2 * G33 * dz2))
        tl.store(scratch_ptr + base + 13 * N_AXIS + axis_idx, tl.exp(-al3 * G33 * dz2))
        tl.store(scratch_ptr + base + 14 * N_AXIS + axis_idx, tl.exp(-al4 * G33 * dz2))
    # Rebuild cross-term tables (same as forward)
    cross_base = base + 18 * N_AXIS
    if STORE_CROSS_TABLES:
        if COMPUTE_XY:
            idx_2d = tl.arange(0, N_AXIS * N_AXIS)
            ii = idx_2d // N_AXIS
            jj = idx_2d % N_AXIS
            dx_i = (ii.to(tl.float32) - half_n_f) * inv_grid_x - sub_x
            dy_j = (jj.to(tl.float32) - half_n_f) * inv_grid_y - sub_y
            prod_xy = dx_i * dy_j
            off_xy = cross_base
            tl.store(scratch_ptr + off_xy + 0 * N_AXIS * N_AXIS + idx_2d,
                     tl.exp(-al0 * 2.0 * G12 * prod_xy))
            tl.store(scratch_ptr + off_xy + 1 * N_AXIS * N_AXIS + idx_2d,
                     tl.exp(-al1 * 2.0 * G12 * prod_xy))
            tl.store(scratch_ptr + off_xy + 2 * N_AXIS * N_AXIS + idx_2d,
                     tl.exp(-al2 * 2.0 * G12 * prod_xy))
            tl.store(scratch_ptr + off_xy + 3 * N_AXIS * N_AXIS + idx_2d,
                     tl.exp(-al3 * 2.0 * G12 * prod_xy))
            tl.store(scratch_ptr + off_xy + 4 * N_AXIS * N_AXIS + idx_2d,
                     tl.exp(-al4 * 2.0 * G12 * prod_xy))
            cross_base = cross_base + 5 * N_AXIS * N_AXIS
        if COMPUTE_XZ:
            idx_2d = tl.arange(0, N_AXIS * N_AXIS)
            ii = idx_2d // N_AXIS
            kk = idx_2d % N_AXIS
            dx_i = (ii.to(tl.float32) - half_n_f) * inv_grid_x - sub_x
            dz_k = (kk.to(tl.float32) - half_n_f) * inv_grid_z - sub_z
            prod_xz = dx_i * dz_k
            off_xz = cross_base
            tl.store(scratch_ptr + off_xz + 0 * N_AXIS * N_AXIS + idx_2d,
                     tl.exp(-al0 * 2.0 * G13 * prod_xz))
            tl.store(scratch_ptr + off_xz + 1 * N_AXIS * N_AXIS + idx_2d,
                     tl.exp(-al1 * 2.0 * G13 * prod_xz))
            tl.store(scratch_ptr + off_xz + 2 * N_AXIS * N_AXIS + idx_2d,
                     tl.exp(-al2 * 2.0 * G13 * prod_xz))
            tl.store(scratch_ptr + off_xz + 3 * N_AXIS * N_AXIS + idx_2d,
                     tl.exp(-al3 * 2.0 * G13 * prod_xz))
            tl.store(scratch_ptr + off_xz + 4 * N_AXIS * N_AXIS + idx_2d,
                     tl.exp(-al4 * 2.0 * G13 * prod_xz))
            cross_base = cross_base + 5 * N_AXIS * N_AXIS
        if COMPUTE_YZ:
            idx_2d = tl.arange(0, N_AXIS * N_AXIS)
            jj = idx_2d // N_AXIS
            kk = idx_2d % N_AXIS
            dy_j = (jj.to(tl.float32) - half_n_f) * inv_grid_y - sub_y
            dz_k = (kk.to(tl.float32) - half_n_f) * inv_grid_z - sub_z
            prod_yz = dy_j * dz_k
            off_yz = cross_base
            tl.store(scratch_ptr + off_yz + 0 * N_AXIS * N_AXIS + idx_2d,
                     tl.exp(-al0 * 2.0 * G23 * prod_yz))
            tl.store(scratch_ptr + off_yz + 1 * N_AXIS * N_AXIS + idx_2d,
                     tl.exp(-al1 * 2.0 * G23 * prod_yz))
            tl.store(scratch_ptr + off_yz + 2 * N_AXIS * N_AXIS + idx_2d,
                     tl.exp(-al2 * 2.0 * G23 * prod_yz))
            tl.store(scratch_ptr + off_yz + 3 * N_AXIS * N_AXIS + idx_2d,
                     tl.exp(-al3 * 2.0 * G23 * prod_yz))
            tl.store(scratch_ptr + off_yz + 4 * N_AXIS * N_AXIS + idx_2d,
                     tl.exp(-al4 * 2.0 * G23 * prod_yz))
    # Tables were stored by lanes over axis_idx / idx_2d and are gathered
    # below at sphere-offset indices by other threads of this program;
    # synchronise so every store is visible before any gather.
    tl.debug_barrier()
    # ---- Stage 4: Gradient accumulation ----
    cross_off_first = base + 18 * N_AXIS
    # Init accumulators with matching dtype (supports float32 and float64)
    _zero = b_iso * 0.0
    g_fx = _zero
    g_fy = _zero
    g_fz = _zero
    g_b = _zero
    g_occ = _zero
    v_offsets = tl.arange(0, BLOCK_V)
    for v_start in range(0, N_sphere, BLOCK_V):
        v = v_start + v_offsets
        mask = v < N_sphere
        si = tl.load(offsets_ptr + v * 3 + 0, mask=mask, other=0).to(tl.int32)
        sj = tl.load(offsets_ptr + v * 3 + 1, mask=mask, other=0).to(tl.int32)
        sk = tl.load(offsets_ptr + v * 3 + 2, mask=mask, other=0).to(tl.int32)
        ti = si + half_n
        tj = sj + half_n
        tk = sk + half_n
        # Load fractional deltas once (used for rho in path a, gradients always)
        dxi = tl.load(scratch_ptr + base + 15 * N_AXIS + ti, mask=mask, other=0.0)
        dyj = tl.load(scratch_ptr + base + 16 * N_AXIS + tj, mask=mask, other=0.0)
        dzk = tl.load(scratch_ptr + base + 17 * N_AXIS + tk, mask=mask, other=0.0)
        # Occupancy-free component densities u0-u4 (same two-path strategy
        # as forward); rho_c = occ * u_c below.
        if _USE_COMBINED:
            # Path (a): combined exponent — single exp(-alpha*r²) per component
            r_sq = G11 * dxi * dxi + G22 * dyj * dyj + G33 * dzk * dzk
            if COMPUTE_XY:
                r_sq = r_sq + 2.0 * G12 * dxi * dyj
            if COMPUTE_XZ:
                r_sq = r_sq + 2.0 * G13 * dxi * dzk
            if COMPUTE_YZ:
                r_sq = r_sq + 2.0 * G23 * dyj * dzk
            u0 = Au0 * tl.exp(-al0 * r_sq)
            u1 = Au1 * tl.exp(-al1 * r_sq)
            u2 = Au2 * tl.exp(-al2 * r_sq)
            u3 = Au3 * tl.exp(-al3 * r_sq)
            u4 = Au4 * tl.exp(-al4 * r_sq)
        else:
            # Path (b): separable 1D diagonal tables
            vx0 = tl.load(scratch_ptr + base + 0 * N_AXIS + ti, mask=mask, other=0.0)
            vy0 = tl.load(scratch_ptr + base + 5 * N_AXIS + tj, mask=mask, other=0.0)
            vz0 = tl.load(scratch_ptr + base + 10 * N_AXIS + tk, mask=mask, other=0.0)
            vx1 = tl.load(scratch_ptr + base + 1 * N_AXIS + ti, mask=mask, other=0.0)
            vy1 = tl.load(scratch_ptr + base + 6 * N_AXIS + tj, mask=mask, other=0.0)
            vz1 = tl.load(scratch_ptr + base + 11 * N_AXIS + tk, mask=mask, other=0.0)
            vx2 = tl.load(scratch_ptr + base + 2 * N_AXIS + ti, mask=mask, other=0.0)
            vy2 = tl.load(scratch_ptr + base + 7 * N_AXIS + tj, mask=mask, other=0.0)
            vz2 = tl.load(scratch_ptr + base + 12 * N_AXIS + tk, mask=mask, other=0.0)
            vx3 = tl.load(scratch_ptr + base + 3 * N_AXIS + ti, mask=mask, other=0.0)
            vy3 = tl.load(scratch_ptr + base + 8 * N_AXIS + tj, mask=mask, other=0.0)
            vz3 = tl.load(scratch_ptr + base + 13 * N_AXIS + tk, mask=mask, other=0.0)
            vx4 = tl.load(scratch_ptr + base + 4 * N_AXIS + ti, mask=mask, other=0.0)
            vy4 = tl.load(scratch_ptr + base + 9 * N_AXIS + tj, mask=mask, other=0.0)
            vz4 = tl.load(scratch_ptr + base + 14 * N_AXIS + tk, mask=mask, other=0.0)
            u0 = Au0 * vx0 * vy0 * vz0
            u1 = Au1 * vx1 * vy1 * vz1
            u2 = Au2 * vx2 * vy2 * vz2
            u3 = Au3 * vx3 * vy3 * vz3
            u4 = Au4 * vx4 * vy4 * vz4
            if STORE_CROSS_TABLES:
                ct_base = cross_off_first
                if COMPUTE_XY:
                    idx_xy = ti * N_AXIS + tj
                    u0 *= tl.load(scratch_ptr + ct_base + 0 * N_AXIS * N_AXIS + idx_xy,
                                  mask=mask, other=1.0)
                    u1 *= tl.load(scratch_ptr + ct_base + 1 * N_AXIS * N_AXIS + idx_xy,
                                  mask=mask, other=1.0)
                    u2 *= tl.load(scratch_ptr + ct_base + 2 * N_AXIS * N_AXIS + idx_xy,
                                  mask=mask, other=1.0)
                    u3 *= tl.load(scratch_ptr + ct_base + 3 * N_AXIS * N_AXIS + idx_xy,
                                  mask=mask, other=1.0)
                    u4 *= tl.load(scratch_ptr + ct_base + 4 * N_AXIS * N_AXIS + idx_xy,
                                  mask=mask, other=1.0)
                    ct_base = ct_base + 5 * N_AXIS * N_AXIS
                if COMPUTE_XZ:
                    idx_xz = ti * N_AXIS + tk
                    u0 *= tl.load(scratch_ptr + ct_base + 0 * N_AXIS * N_AXIS + idx_xz,
                                  mask=mask, other=1.0)
                    u1 *= tl.load(scratch_ptr + ct_base + 1 * N_AXIS * N_AXIS + idx_xz,
                                  mask=mask, other=1.0)
                    u2 *= tl.load(scratch_ptr + ct_base + 2 * N_AXIS * N_AXIS + idx_xz,
                                  mask=mask, other=1.0)
                    u3 *= tl.load(scratch_ptr + ct_base + 3 * N_AXIS * N_AXIS + idx_xz,
                                  mask=mask, other=1.0)
                    u4 *= tl.load(scratch_ptr + ct_base + 4 * N_AXIS * N_AXIS + idx_xz,
                                  mask=mask, other=1.0)
                    ct_base = ct_base + 5 * N_AXIS * N_AXIS
                if COMPUTE_YZ:
                    idx_yz = tj * N_AXIS + tk
                    u0 *= tl.load(scratch_ptr + ct_base + 0 * N_AXIS * N_AXIS + idx_yz,
                                  mask=mask, other=1.0)
                    u1 *= tl.load(scratch_ptr + ct_base + 1 * N_AXIS * N_AXIS + idx_yz,
                                  mask=mask, other=1.0)
                    u2 *= tl.load(scratch_ptr + ct_base + 2 * N_AXIS * N_AXIS + idx_yz,
                                  mask=mask, other=1.0)
                    u3 *= tl.load(scratch_ptr + ct_base + 3 * N_AXIS * N_AXIS + idx_yz,
                                  mask=mask, other=1.0)
                    u4 *= tl.load(scratch_ptr + ct_base + 4 * N_AXIS * N_AXIS + idx_yz,
                                  mask=mask, other=1.0)
            # r² from deltas (path b only — path a already has r_sq)
            r_sq = G11 * dxi * dxi + G22 * dyj * dyj + G33 * dzk * dzk
            if COMPUTE_XY:
                r_sq = r_sq + 2.0 * G12 * dxi * dyj
            if COMPUTE_XZ:
                r_sq = r_sq + 2.0 * G13 * dxi * dzk
            if COMPUTE_YZ:
                r_sq = r_sq + 2.0 * G23 * dyj * dzk
        # Component densities including occupancy (as in the forward pass)
        rho0 = occ * u0
        rho1 = occ * u1
        rho2 = occ * u2
        rho3 = occ * u3
        rho4 = occ * u4
        # ---- Gather upstream gradient ----
        gi = (cix + si) % nx
        gi = tl.where(gi < 0, gi + nx, gi)
        gj = (ciy + sj) % ny
        gj = tl.where(gj < 0, gj + ny, gj)
        gk = (ciz + sk) % nz
        gk = tl.where(gk < 0, gk + nz, gk)
        flat_idx = (gi.to(tl.int64) * ny + gj.to(tl.int64)) * nz + gk.to(tl.int64)
        grad_out = tl.load(grad_density_map_ptr + flat_idx, mask=mask, other=0.0)
        # ---- Position gradient (fractional) ----
        # d(rho_c)/d(frac_x) = rho_c * 2*alpha_c * (G11*dx + G12*dy + G13*dz)
        # (positive sign: d(delta)/d(frac) = -1 combined with -alpha in exponent)
        dr_dx = G11 * dxi
        dr_dy = G22 * dyj
        dr_dz = G33 * dzk
        if COMPUTE_XY:
            dr_dx = dr_dx + G12 * dyj
            dr_dy = dr_dy + G12 * dxi
        if COMPUTE_XZ:
            dr_dx = dr_dx + G13 * dzk
            dr_dz = dr_dz + G13 * dxi
        if COMPUTE_YZ:
            dr_dy = dr_dy + G23 * dzk
            dr_dz = dr_dz + G23 * dyj
        coeff_pos = 2.0 * (al0 * rho0 + al1 * rho1 + al2 * rho2
                           + al3 * rho3 + al4 * rho4)
        scale_pos = grad_out * coeff_pos
        g_fx += tl.sum(tl.where(mask, scale_pos * dr_dx, 0.0), axis=0)
        g_fy += tl.sum(tl.where(mask, scale_pos * dr_dy, 0.0), axis=0)
        g_fz += tl.sum(tl.where(mask, scale_pos * dr_dz, 0.0), axis=0)
        # ---- B-factor gradient ----
        # d(rho_c)/d(Bt_c) = rho_c * (-1.5/Bt_c + alpha_c * r² / Bt_c);
        # dBt_c/db = 0.25 when unclamped, 0 when the 0.1 floor is active.
        db0 = rho0 * (-1.5 / Bt0 + al0 * r_sq / Bt0) * clamp0
        db1 = rho1 * (-1.5 / Bt1 + al1 * r_sq / Bt1) * clamp1
        db2 = rho2 * (-1.5 / Bt2 + al2 * r_sq / Bt2) * clamp2
        db3 = rho3 * (-1.5 / Bt3 + al3 * r_sq / Bt3) * clamp3
        db4 = rho4 * (-1.5 / Bt4 + al4 * r_sq / Bt4) * clamp4
        g_b += tl.sum(tl.where(mask, grad_out * 0.25 * (db0 + db1 + db2 + db3 + db4),
                               0.0), axis=0)
        # ---- Occupancy gradient ----
        # Density is linear in occ, so d(density)/d(occ) is the occupancy-free
        # density — well defined (and generally nonzero) even at occ == 0.
        unit_density = u0 + u1 + u2 + u3 + u4
        g_occ += tl.sum(tl.where(mask, grad_out * unit_density, 0.0), axis=0)
    # Write accumulated gradients
    tl.store(grad_frac_ptr + atom * 3 + 0, g_fx)
    tl.store(grad_frac_ptr + atom * 3 + 1, g_fy)
    tl.store(grad_frac_ptr + atom * 3 + 2, g_fz)
    tl.store(grad_b_ptr + atom, g_b)
    tl.store(grad_occ_ptr + atom, g_occ)
# =============================================================================
# Python helpers
# =============================================================================
# Configuration dicts built by _get_cached_config, keyed by
# (frac_matrix values, grid_shape, radius_angstrom, str(device)).
# Nothing in this view evicts entries; each new geometry adds one.
_config_cache: dict = {}
def _get_cached_config(frac_matrix, grid_shape, radius_angstrom, device):
    """Return all precomputed config, sphere offsets, and sizing.

    Results are cached in ``_config_cache`` by (frac_matrix values,
    grid_shape, radius, device), so the GPU tensors are allocated once and
    then reused. Building the cache key copies ``frac_matrix`` to the host on
    every call. All other ``.item()`` host syncs happen only on a cache miss.

    Args:
        frac_matrix: (3, 3) tensor ``F``; the metric tensor is ``G = F.T @ F``.
            Presumably this is the fractional→Cartesian matrix (TODO confirm
            against the caller).
        grid_shape: (nx, ny, nz) grid dimensions (any iterable of three ints).
        radius_angstrom: cutoff radius of the splatting sphere.
        device: device for the sphere offsets and parameter tensors.

    Returns:
        dict containing:
        - the metric tensor: ``G_flat`` tensor and ``G_vals`` list, both
          ordered G11, G22, G33, G12, G13, G23;
        - the inverse grid sizes;
        - the cross-term flags;
        - int16 ``sphere_offsets`` of shape (N_sphere, 3);
        - the Triton sizing parameters: ``N_AXIS``, ``half_n``,
          ``base_scratch``, ``BLOCK_V``, ``bwd_num_warps`` and
          ``bwd_BLOCK_V``.
    """
    # Normalise to a hashable tuple (lists would make the key unhashable).
    grid_shape = tuple(int(n) for n in grid_shape)
    fm_cpu = frac_matrix.detach().cpu()
    key = (tuple(fm_cpu.flatten().tolist()), grid_shape, radius_angstrom,
           str(device))
    cached = _config_cache.get(key)
    if cached is not None:
        return cached
    # --- Metric tensor (from the host copy: no further GPU syncs) ---
    G = fm_cpu.T @ fm_cpu
    G_vals = [G[0, 0].item(), G[1, 1].item(), G[2, 2].item(),
              G[0, 1].item(), G[0, 2].item(), G[1, 2].item()]
    G11, G22, G33, G12, G13, G23 = G_vals
    # Off-diagonal terms below 0.1% of the diagonal norm are treated as zero.
    tol = 1e-3 * math.sqrt(G11**2 + G22**2 + G33**2)
    compute_xy = abs(G12) > tol
    compute_xz = abs(G13) > tol
    compute_yz = abs(G23) > tol
    # Always recompute cross-terms on-the-fly (pure ALU, no memory traffic).
    # Storing 2D tables in scratch adds 5*N_AXIS² global memory reads/writes
    # per atom which is slower than a few extra exp() calls.
    store_cross_tables = False
    inv_grid_vals = [1.0 / grid_shape[0], 1.0 / grid_shape[1],
                     1.0 / grid_shape[2]]
    # --- GPU tensors (allocated once) ---
    G_flat = torch.tensor(G_vals, device=device, dtype=torch.float32)
    inv_grid = torch.tensor(inv_grid_vals, device=device, dtype=torch.float32)
    # --- Sphere offsets ---
    cell_lengths = [math.sqrt(G_vals[d]) for d in range(3)]
    voxel_sizes = [inv_grid_vals[d] * cell_lengths[d] for d in range(3)]
    # Along fractional axis d the sphere |r| <= R extends R * sqrt((G^-1)_dd)
    # (fractional units). That is >= R / |a_d| and strictly larger for
    # non-orthogonal cells, so bounding by R / voxel_size alone can truncate
    # the sphere. Take the larger of both bounds; orthogonal cells are
    # unchanged.
    G_full = torch.tensor([[G11, G12, G13],
                           [G12, G22, G23],
                           [G13, G23, G33]], dtype=torch.float64)
    G_inv_diag = torch.linalg.inv(G_full).diagonal().tolist()
    max_offsets = [
        max(
            int(math.ceil(radius_angstrom / voxel_sizes[d])),
            int(math.ceil(radius_angstrom * math.sqrt(G_inv_diag[d])
                          / inv_grid_vals[d])),
        ) + 1
        for d in range(3)
    ]
    ranges = [torch.arange(-m, m + 1, device=device) for m in max_offsets]
    gx, gy, gz = torch.meshgrid(*ranges, indexing="ij")
    coords = torch.stack((gx, gy, gz), dim=-1)
    delta_frac = coords.float() * inv_grid
    # Keep all einsum operands in one dtype (frac_matrix may be float64).
    G_dev = G.to(device=device, dtype=delta_frac.dtype)
    r_sq = torch.einsum("...i,ij,...j->...", delta_frac, G_dev, delta_frac)
    in_sphere = coords[r_sq <= radius_angstrom**2]
    sphere_offsets = in_sphere.to(torch.int16).contiguous()
    # --- N_AXIS, half_n ---
    # The kernels index their per-axis tables at (offset + half_n), which must
    # lie in [0, N_AXIS). Skewed cells can emit offsets beyond
    # ceil(R / min_voxel), so widen half_n to cover every emitted offset.
    # in_sphere always contains the origin, so .max() is well defined.
    min_voxel_size = min(voxel_sizes)
    half_n = max(int(math.ceil(radius_angstrom / min_voxel_size)),
                 int(in_sphere.abs().max().item()))
    N_AXIS = triton.next_power_of_2(2 * half_n + 1)
    # --- Scratch sizing: 1D tables (3 axes × 5 components) + deltas (3 axes),
    # plus 5 N_AXIS² cross tables per active cross term when those are stored.
    base_scratch = 18 * N_AXIS
    if store_cross_tables:
        n_cross = int(compute_xy) + int(compute_xz) + int(compute_yz)
        base_scratch += 5 * N_AXIS * N_AXIS * n_cross
    # --- BLOCK_V ---
    N_sphere = sphere_offsets.shape[0]
    BLOCK_V = triton.next_power_of_2(min(N_sphere, 512))
    # Backward kernel: BLOCK_V=256 + num_warps=2 is universally optimal.
    # The smaller tile size reduces register pressure (146→fewer per-thread
    # vectors), allowing 2 warps to distribute across more blocks per SM.
    bwd_num_warps = 2
    bwd_BLOCK_V = min(256, triton.next_power_of_2(N_sphere))
    config = {
        "G_flat": G_flat,
        "G_vals": G_vals,
        "inv_grid": inv_grid,
        "inv_grid_vals": inv_grid_vals,
        "compute_xy": compute_xy,
        "compute_xz": compute_xz,
        "compute_yz": compute_yz,
        "store_cross_tables": store_cross_tables,
        "sphere_offsets": sphere_offsets,
        "N_sphere": N_sphere,
        "N_AXIS": N_AXIS,
        "half_n": half_n,
        "base_scratch": base_scratch,
        "BLOCK_V": BLOCK_V,
        "bwd_num_warps": bwd_num_warps,
        "bwd_BLOCK_V": bwd_BLOCK_V,
    }
    # Warmup: launch a 1-atom forward kernel so Triton JIT-compiles this
    # constexpr specialisation before any real data is processed.
    # NOTE(review): the original rationale was that the first compilation can
    # produce incorrect results. Verify whether that symptom actually comes
    # from an unsynchronised scratch store→load inside the kernels rather
    # than from the JIT itself.
    _warmup_kernel(config, grid_shape, device)
    _config_cache[key] = config
    return config
def _warmup_kernel(cfg, grid_shape, device):
    """Launch a disposable 1-atom forward kernel to force Triton compilation.

    The launch uses the same constexpr values (grid dims, N_sphere, N_AXIS,
    half_n, SCRATCH_PER_ATOM, BLOCK_V, cross-term flags) that the real forward
    pass passes for this config, so the intent is for Triton to reuse the
    compiled kernel afterwards. Every buffer is a throwaway tensor and the
    result is discarded.

    Parameters
    ----------
    cfg : dict
        Configuration dict built by ``_get_cached_config`` (reads
        ``sphere_offsets``, ``base_scratch``, ``G_vals``, ``inv_grid_vals``,
        ``N_sphere``, ``N_AXIS``, ``half_n``, ``BLOCK_V`` and the
        ``compute_*`` / ``store_cross_tables`` flags).
    grid_shape : sequence of 3 ints
        (nx, ny, nz) of the density grid.
    device : torch.device
        CUDA device on which to compile and launch.
    """
    nx, ny, nz = grid_shape
    scratch_per_atom = cfg["base_scratch"]

    def _zeros(*shape):
        # All dummy inputs are float32 device tensors, like the real ones.
        return torch.zeros(*shape, device=device, dtype=torch.float32)

    # Triton launches on the *current* CUDA device, and torch.cuda.synchronize()
    # without an argument only waits on the current device. Make ``device``
    # current so the launch and the wait both target the device that owns
    # these buffers.
    with torch.cuda.device(device):
        _separable_fwd_kernel[(1,)](
            _zeros(nx * ny * nz),  # throwaway density map (flat)
            _zeros(1, 3),  # xyz
            _zeros(1),  # b
            _zeros(1, 5),  # A
            _zeros(1, 5),  # B
            torch.ones(1, device=device, dtype=torch.float32),  # occ
            cfg["sphere_offsets"],
            torch.eye(3, device=device, dtype=torch.float32).view(-1),  # inv_frac
            _zeros(1, scratch_per_atom).view(-1),  # scratch for one atom
            cfg["G_vals"][0], cfg["G_vals"][1], cfg["G_vals"][2],
            cfg["G_vals"][3], cfg["G_vals"][4], cfg["G_vals"][5],
            cfg["inv_grid_vals"][0], cfg["inv_grid_vals"][1], cfg["inv_grid_vals"][2],
            nx=nx, ny=ny, nz=nz,
            N_sphere=cfg["N_sphere"], N_AXIS=cfg["N_AXIS"], half_n=cfg["half_n"],
            SCRATCH_PER_ATOM=scratch_per_atom, BLOCK_V=cfg["BLOCK_V"],
            COMPUTE_XY=cfg["compute_xy"], COMPUTE_XZ=cfg["compute_xz"],
            COMPUTE_YZ=cfg["compute_yz"], STORE_CROSS_TABLES=cfg["store_cross_tables"],
        )
        # Block until compilation + launch have finished on ``device``.
        torch.cuda.synchronize(device)
# Flag flipped to True by ``separable_density_gpu`` after its very first call
# has run the autograd function once and thrown the result away.
# NOTE(review): the stated rationale is that the first Triton JIT-compiled
# launch gave different results from later ones; the root cause is unverified.
_fwd_kernel_warmed_up: bool = False
# Module-wide float32 scratch buffer shared by all calls. It is replaced when
# it is missing, too small for the current atom count, or on another device;
# otherwise a leading slice of it is reused.
_scratch_buf: Optional[torch.Tensor] = None
# =============================================================================
# Autograd wrapper
# =============================================================================
class _SeparableDensityFunction(torch.autograd.Function):
    """Autograd wrapper around the separable forward/backward Triton kernels.

    ``forward`` copies ``density_map`` into a fresh contiguous buffer, which
    the forward kernel (one program per atom) then updates with each atom's
    density. ``backward`` returns gradients for ``density_map``, ``xyz``,
    ``b`` and ``occ``. All other inputs (form-factor coefficients, cell
    geometry, precomputed tables, launch parameters) are treated as
    constants.
    """

    @staticmethod
    def forward(
        ctx,
        density_map,  # (nx, ny, nz)
        xyz,  # (N_atoms, 3)
        b,  # (N_atoms,)
        A,  # (N_atoms, 5)
        B,  # (N_atoms, 5)
        occ,  # (N_atoms,)
        inv_frac_matrix,  # (3, 3)
        sphere_offsets,  # (N_sphere, 3) int16
        G_flat,  # (6,) float32
        inv_grid,  # (3,) float32
        scratch,  # (N_atoms, scratch_per_atom) float32
        N_AXIS,  # int
        half_n,  # int
        compute_xy,  # bool
        compute_xz,  # bool
        compute_yz,  # bool
        store_cross_tables,  # bool
        # Pre-extracted Python floats (no .item() calls needed)
        G_vals,  # list of 6 Python floats
        inv_grid_vals,  # list of 3 Python floats
        BLOCK_V,  # int
        bwd_num_warps,  # int
        bwd_BLOCK_V,  # int
    ):
        """Return a copy of ``density_map`` updated by the forward kernel.

        ``density_map`` itself is not modified.
        """
        N_atoms = xyz.shape[0]
        nx, ny, nz = density_map.shape
        N_sphere = sphere_offsets.shape[0]
        scratch_per_atom = scratch.shape[1]

        # The kernels receive raw pointers and treat them as dense flat
        # arrays, so every tensor handed to them must be contiguous.
        xyz = xyz.contiguous()
        b = b.contiguous()
        A = A.contiguous()
        B = B.contiguous()
        occ = occ.contiguous()
        inv_frac_flat = inv_frac_matrix.contiguous().view(-1)
        # A plain ``clone()`` preserves the input's strides (e.g. a permuted
        # map), which would make ``view(-1)`` fail. Force a row-major copy so
        # the flat buffer matches the (nx, ny, nz) layout the kernel is given.
        output = density_map.clone(memory_format=torch.contiguous_format)

        _separable_fwd_kernel[(N_atoms,)](
            output.view(-1), xyz, b, A, B, occ,
            sphere_offsets, inv_frac_flat, scratch.view(-1),
            G_vals[0], G_vals[1], G_vals[2],
            G_vals[3], G_vals[4], G_vals[5],
            inv_grid_vals[0], inv_grid_vals[1], inv_grid_vals[2],
            nx=nx, ny=ny, nz=nz,
            N_sphere=N_sphere, N_AXIS=N_AXIS, half_n=half_n,
            SCRATCH_PER_ATOM=scratch_per_atom, BLOCK_V=BLOCK_V,
            COMPUTE_XY=compute_xy, COMPUTE_XZ=compute_xz,
            COMPUTE_YZ=compute_yz, STORE_CROSS_TABLES=store_cross_tables,
        )

        ctx.save_for_backward(
            xyz, b, A, B, occ, inv_frac_matrix,
            sphere_offsets, G_flat, inv_grid, scratch,
        )
        # Non-tensor launch parameters needed to re-launch in backward.
        ctx.grid_shape = (nx, ny, nz)
        ctx.N_AXIS = N_AXIS
        ctx.half_n = half_n
        ctx.compute_xy = compute_xy
        ctx.compute_xz = compute_xz
        ctx.compute_yz = compute_yz
        ctx.store_cross_tables = store_cross_tables
        ctx.G_vals = G_vals
        ctx.inv_grid_vals = inv_grid_vals
        ctx.BLOCK_V = BLOCK_V
        ctx.bwd_num_warps = bwd_num_warps
        ctx.bwd_BLOCK_V = bwd_BLOCK_V
        return output

    @staticmethod
    def backward(ctx, grad_density_map):
        """Propagate ``grad_density_map`` to density_map, xyz, b and occ."""
        (xyz, b, A, B, occ, inv_frac_matrix,
         sphere_offsets, G_flat, inv_grid, scratch) = ctx.saved_tensors
        needs = ctx.needs_input_grad

        # The output starts as a copy of density_map and the kernel adds each
        # atom's contribution on top, so d(output)/d(density_map) is the
        # identity. NOTE(review): this assumes the forward kernel accumulates
        # into the copied map (as "density grid to update" implies) rather than
        # overwriting voxels.
        grad_map = grad_density_map if needs[0] else None

        grad_xyz = grad_b = grad_occ = None
        if needs[1] or needs[2] or needs[5]:
            nx, ny, nz = ctx.grid_shape
            N_atoms = xyz.shape[0]
            N_sphere = sphere_offsets.shape[0]
            scratch_per_atom = scratch.shape[1]
            G_vals = ctx.G_vals
            inv_grid_vals = ctx.inv_grid_vals

            # Incoming gradients may be non-contiguous or expanded (stride 0);
            # the kernel reads a dense flat buffer.
            grad_flat = grad_density_map.contiguous().view(-1)
            inv_frac_flat = inv_frac_matrix.contiguous().view(-1)
            grad_frac = torch.zeros(N_atoms, 3, device=xyz.device, dtype=xyz.dtype)
            grad_b = torch.zeros_like(b)
            grad_occ = torch.zeros_like(occ)

            # NOTE(review): ``scratch`` is normally a view of the module-level
            # buffer shared by every call. If the backward kernel reads tables
            # written by forward (instead of rebuilding them), another forward
            # before this backward would clobber them. Confirm.
            _separable_bwd_kernel[(N_atoms,)](
                grad_flat,
                xyz, b, A, B, occ,
                sphere_offsets, inv_frac_flat, scratch.view(-1),
                G_vals[0], G_vals[1], G_vals[2],
                G_vals[3], G_vals[4], G_vals[5],
                inv_grid_vals[0], inv_grid_vals[1], inv_grid_vals[2],
                grad_frac, grad_b, grad_occ,
                nx=nx, ny=ny, nz=nz,
                N_sphere=N_sphere, N_AXIS=ctx.N_AXIS, half_n=ctx.half_n,
                SCRATCH_PER_ATOM=scratch_per_atom, BLOCK_V=ctx.bwd_BLOCK_V,
                COMPUTE_XY=ctx.compute_xy, COMPUTE_XZ=ctx.compute_xz,
                COMPUTE_YZ=ctx.compute_yz,
                STORE_CROSS_TABLES=ctx.store_cross_tables,
                num_warps=ctx.bwd_num_warps,
            )
            # frac = xyz @ inv_frac.T  =>  grad_xyz = grad_frac @ inv_frac
            grad_xyz = grad_frac @ inv_frac_matrix

        # One entry per forward input (22 total); None for constant inputs.
        return (grad_map, grad_xyz, grad_b, None, None, grad_occ, None,
                None, None, None, None,
                None, None, None, None, None, None,
                None, None, None, None, None)
# =============================================================================
# Public API
# =============================================================================
def separable_density_gpu(
    density_map: torch.Tensor,
    xyz: torch.Tensor,
    b: torch.Tensor,
    inv_frac_matrix: torch.Tensor,
    frac_matrix: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    occ: torch.Tensor,
    radius_angstrom: float,
) -> torch.Tensor:
    """Separable Gaussian density splatting on GPU via Triton.

    Eliminates the real_space_grid tensor and PBC matrix operations by
    working directly in fractional space with the metric tensor. Precomputes
    1D Gaussian tables per atom and gathers per sphere voxel.

    Parameters
    ----------
    density_map : (nx, ny, nz) — density grid to update (not modified in-place)
    xyz : (N_atoms, 3) — Cartesian positions
    b : (N_atoms,) — isotropic B-factors
    inv_frac_matrix : (3, 3) — Cartesian→fractional
    frac_matrix : (3, 3) — fractional→Cartesian
    A : (N_atoms, 5) — ITC92 amplitudes
    B : (N_atoms, 5) — ITC92 widths
    occ : (N_atoms,) — occupancies
    radius_angstrom : float — cutoff radius

    Returns
    -------
    torch.Tensor — updated density map; gradients flow back to ``xyz``,
    ``b`` and ``occ``.

    Raises
    ------
    ValueError
        If ``density_map`` is not 3-D, ``xyz`` is not (N_atoms, 3), ``b`` or
        ``occ`` do not hold N_atoms values, or ``A`` or ``B`` do not hold
        N_atoms × 5 values.
    """
    global _scratch_buf, _fwd_kernel_warmed_up

    # The kernels launch one program per atom (N_atoms = xyz.shape[0]) and
    # receive the per-atom tensors as raw pointers. A size mismatch would
    # likely read past the end of a shorter tensor instead of raising, so
    # reject it here.
    if density_map.dim() != 3:
        raise ValueError(
            f"density_map must be 3-D (nx, ny, nz), got shape {tuple(density_map.shape)}"
        )
    if xyz.dim() != 2 or xyz.shape[1] != 3:
        raise ValueError(f"xyz must have shape (N_atoms, 3), got {tuple(xyz.shape)}")
    N_atoms = xyz.shape[0]
    for name, tensor, per_atom in (("b", b, 1), ("occ", occ, 1), ("A", A, 5), ("B", B, 5)):
        if tensor.numel() != N_atoms * per_atom:
            raise ValueError(
                f"{name} must hold {per_atom} value(s) per atom "
                f"({N_atoms * per_atom} total for {N_atoms} atoms), "
                f"got shape {tuple(tensor.shape)}"
            )

    device = density_map.device
    grid_shape = density_map.shape
    # All config is cached — no .item() or .cpu() calls after first call
    cfg = _get_cached_config(frac_matrix, grid_shape, radius_angstrom, device)

    needed = N_atoms * cfg["base_scratch"]
    # Reuse the module-level scratch buffer when it is large enough and on the
    # right device; otherwise (re)allocate it.
    if (_scratch_buf is None or _scratch_buf.numel() < needed
            or _scratch_buf.device != device):
        _scratch_buf = torch.zeros(needed, device=device, dtype=torch.float32)
    scratch = _scratch_buf[:needed].view(N_atoms, cfg["base_scratch"])

    def _run():
        return _SeparableDensityFunction.apply(
            density_map, xyz, b, A, B, occ, inv_frac_matrix,
            cfg["sphere_offsets"], cfg["G_flat"], cfg["inv_grid"], scratch,
            cfg["N_AXIS"], cfg["half_n"],
            cfg["compute_xy"], cfg["compute_xz"], cfg["compute_yz"],
            cfg["store_cross_tables"],
            cfg["G_vals"], cfg["inv_grid_vals"], cfg["BLOCK_V"],
            cfg["bwd_num_warps"], cfg["bwd_BLOCK_V"],
        )

    if not _fwd_kernel_warmed_up:
        # First Triton JIT compilation produces unreliable results.
        # Run once to compile, discard, then re-run for correct output.
        # NOTE(review): _get_cached_config already performs a disposable
        # warmup launch per new config; this second safeguard may be redundant.
        _run()
        _fwd_kernel_warmed_up = True
    return _run()