# Source code for torchref.base.kernels.triton_kernel

"""
Triton GPU kernel for fused electron density computation.

Fuses PBC wrapping, r² calculation, 5-Gaussian evaluation, and scatter-add
into a single GPU kernel, eliminating ~14 separate kernel launches and
~500MB of intermediate memory allocations.

Provides full autograd support for refinement of xyz, b, and occ.
"""

import torch
import triton
import triton.language as tl

# =============================================================================
# Constants
# =============================================================================

PI: float = 3.141592653589793
PI_SQ: float = PI * PI
PI_1P5: float = PI * 1.7724538509055159  # pi * sqrt(pi) = pi^1.5


# =============================================================================
# Forward kernel
# =============================================================================

@triton.jit
def _density_fwd_kernel(
    # Pointers
    surr_coords_ptr,   # (N_atoms, N_voxels, 3) float32
    voxel_idx_ptr,     # (N_atoms, N_voxels, 3) int32
    density_map_ptr,   # (nx*ny*nz,) float32
    xyz_ptr,           # (N_atoms, 3) float32
    b_ptr,             # (N_atoms,) float32
    inv_frac_ptr,      # (9,) float32 row-major
    frac_ptr,          # (9,) float32 row-major
    A_ptr,             # (N_atoms, 5) float32
    B_ptr,             # (N_atoms, 5) float32
    occ_ptr,           # (N_atoms,) float32
    # Dimensions
    N_voxels: tl.constexpr,
    ny: tl.constexpr,
    nz: tl.constexpr,
    # Block size
    BLOCK_V: tl.constexpr,
):
    """Scatter-add one atom's 5-Gaussian density onto the flat density map.

    One program per atom; each loop iteration handles BLOCK_V voxels.

    For every voxel listed for the atom the kernel computes
        diff    = surrounding_coord - xyz
        wrapped = diff - round(diff @ inv_frac.T) @ frac.T
        r2      = |wrapped|^2
        density = sum_g An_g * exp(-pi^2 * r2 / Bt_g)
    with Bt_g = max((B_g + b) / 4, 0.1) and
    An_g = A_g * occ * pi^1.5 / Bt_g^1.5, and atomically adds ``density`` to
    ``density_map[ix * ny * nz + iy * nz + iz]``.
    """
    atom = tl.program_id(0)

    # Per-atom scalars
    b_iso = tl.load(b_ptr + atom)
    occ = tl.load(occ_ptr + atom)
    ax = tl.load(xyz_ptr + atom * 3 + 0)
    ay = tl.load(xyz_ptr + atom * 3 + 1)
    az = tl.load(xyz_ptr + atom * 3 + 2)

    # ITC92 params (5 Gaussians)
    A0 = tl.load(A_ptr + atom * 5 + 0)
    A1 = tl.load(A_ptr + atom * 5 + 1)
    A2 = tl.load(A_ptr + atom * 5 + 2)
    A3 = tl.load(A_ptr + atom * 5 + 3)
    A4 = tl.load(A_ptr + atom * 5 + 4)

    B0 = tl.load(B_ptr + atom * 5 + 0)
    B1 = tl.load(B_ptr + atom * 5 + 1)
    B2 = tl.load(B_ptr + atom * 5 + 2)
    B3 = tl.load(B_ptr + atom * 5 + 3)
    B4 = tl.load(B_ptr + atom * 5 + 4)

    # Total width per Gaussian, clamped from below at 0.1
    Bt0 = tl.maximum((B0 + b_iso) * 0.25, 0.1)
    Bt1 = tl.maximum((B1 + b_iso) * 0.25, 0.1)
    Bt2 = tl.maximum((B2 + b_iso) * 0.25, 0.1)
    Bt3 = tl.maximum((B3 + b_iso) * 0.25, 0.1)
    Bt4 = tl.maximum((B4 + b_iso) * 0.25, 0.1)

    # Normalised amplitude: A * occ * pi^1.5 / Bt^1.5
    pi_1p5: tl.constexpr = 5.568327996831708  # pi^1.5
    An0 = A0 * occ * pi_1p5 / (Bt0 * tl.sqrt(Bt0))
    An1 = A1 * occ * pi_1p5 / (Bt1 * tl.sqrt(Bt1))
    An2 = A2 * occ * pi_1p5 / (Bt2 * tl.sqrt(Bt2))
    An3 = A3 * occ * pi_1p5 / (Bt3 * tl.sqrt(Bt3))
    An4 = A4 * occ * pi_1p5 / (Bt4 * tl.sqrt(Bt4))

    # 3x3 matrices, row-major: M[i, j] = ptr[i * 3 + j].
    # (diff @ M.T)[i] = sum_j diff[j] * M[i, j], i.e. a dot product with row i.
    if0 = tl.load(inv_frac_ptr + 0)
    if1 = tl.load(inv_frac_ptr + 1)
    if2 = tl.load(inv_frac_ptr + 2)
    if3 = tl.load(inv_frac_ptr + 3)
    if4 = tl.load(inv_frac_ptr + 4)
    if5 = tl.load(inv_frac_ptr + 5)
    if6 = tl.load(inv_frac_ptr + 6)
    if7 = tl.load(inv_frac_ptr + 7)
    if8 = tl.load(inv_frac_ptr + 8)

    f0 = tl.load(frac_ptr + 0)
    f1 = tl.load(frac_ptr + 1)
    f2 = tl.load(frac_ptr + 2)
    f3 = tl.load(frac_ptr + 3)
    f4 = tl.load(frac_ptr + 4)
    f5 = tl.load(frac_ptr + 5)
    f6 = tl.load(frac_ptr + 6)
    f7 = tl.load(frac_ptr + 7)
    f8 = tl.load(frac_ptr + 8)

    pi_sq: tl.constexpr = 9.869604401089358

    v_offsets = tl.arange(0, BLOCK_V)
    # 64-bit row base: (atom * N_voxels + v) * 3 exceeds the int32 range once
    # N_atoms * N_voxels * 3 >= 2**31, which would silently wrap the offsets.
    base = atom.to(tl.int64) * N_voxels

    for v_start in range(0, N_voxels, BLOCK_V):
        v = v_start + v_offsets
        mask = v < N_voxels

        # Element offset of [atom, v, 0] in both (N_atoms, N_voxels, 3) arrays
        row3 = (base + v) * 3

        # surrounding_coords[atom, v, 0:3]; masked lanes get a defined value
        sx = tl.load(surr_coords_ptr + row3 + 0, mask=mask, other=0.0)
        sy = tl.load(surr_coords_ptr + row3 + 1, mask=mask, other=0.0)
        sz = tl.load(surr_coords_ptr + row3 + 2, mask=mask, other=0.0)

        # diff = surrounding - xyz
        dx = sx - ax
        dy = sy - ay
        dz = sz - az

        # PBC: diff_frac = diff @ inv_frac_matrix.T
        fx = dx * if0 + dy * if1 + dz * if2
        fy = dx * if3 + dy * if4 + dz * if5
        fz = dx * if6 + dy * if7 + dz * if8

        # Nearest whole-cell translation
        tx = tl.extra.cuda.libdevice.round(fx)
        ty = tl.extra.cuda.libdevice.round(fy)
        tz = tl.extra.cuda.libdevice.round(fz)

        # correction = translation @ frac_matrix.T
        cx = tx * f0 + ty * f1 + tz * f2
        cy = tx * f3 + ty * f4 + tz * f5
        cz = tx * f6 + ty * f7 + tz * f8

        # Wrapped diff and squared distance
        wx = dx - cx
        wy = dy - cy
        wz = dz - cz
        r2 = wx * wx + wy * wy + wz * wz

        # 5-Gaussian density
        density = (
            An0 * tl.exp(-pi_sq * r2 / Bt0)
            + An1 * tl.exp(-pi_sq * r2 / Bt1)
            + An2 * tl.exp(-pi_sq * r2 / Bt2)
            + An3 * tl.exp(-pi_sq * r2 / Bt3)
            + An4 * tl.exp(-pi_sq * r2 / Bt4)
        )

        # voxel_indices[atom, v, 0:3] -> flat row-major index into the map
        ix = tl.load(voxel_idx_ptr + row3 + 0, mask=mask, other=0).to(tl.int64)
        iy = tl.load(voxel_idx_ptr + row3 + 1, mask=mask, other=0).to(tl.int64)
        iz = tl.load(voxel_idx_ptr + row3 + 2, mask=mask, other=0).to(tl.int64)
        flat_idx = ix * (ny * nz) + iy * nz + iz

        # Several atoms may hit the same voxel, hence the atomic add
        tl.atomic_add(density_map_ptr + flat_idx, density, mask=mask)


# =============================================================================
# Backward kernel
# =============================================================================

@triton.jit
def _density_bwd_kernel(
    # Forward inputs (read-only)
    surr_coords_ptr,
    voxel_idx_ptr,
    grad_density_map_ptr,  # (nx*ny*nz,) gradient from upstream
    xyz_ptr,
    b_ptr,
    inv_frac_ptr,
    frac_ptr,
    A_ptr,
    B_ptr,
    occ_ptr,
    # Gradient outputs (one store per atom)
    grad_xyz_ptr,          # (N_atoms, 3)
    grad_b_ptr,            # (N_atoms,)
    grad_occ_ptr,          # (N_atoms,)
    # Dimensions
    N_voxels: tl.constexpr,
    ny: tl.constexpr,
    nz: tl.constexpr,
    BLOCK_V: tl.constexpr,
):
    """Backward pass of ``_density_fwd_kernel``; one program per atom.

    Recomputes the forward quantities for every voxel of the atom, gathers the
    upstream gradient at the voxel's flat map index and reduces the products
    into per-atom gradients for xyz, b and occ. The PBC translation
    (``round``) is treated as a constant. Each program writes its own atom's
    gradients with plain stores.
    """
    atom = tl.program_id(0)

    # Per-atom data (same as forward)
    b_iso = tl.load(b_ptr + atom)
    occ = tl.load(occ_ptr + atom)
    ax = tl.load(xyz_ptr + atom * 3 + 0)
    ay = tl.load(xyz_ptr + atom * 3 + 1)
    az = tl.load(xyz_ptr + atom * 3 + 2)

    A0 = tl.load(A_ptr + atom * 5 + 0)
    A1 = tl.load(A_ptr + atom * 5 + 1)
    A2 = tl.load(A_ptr + atom * 5 + 2)
    A3 = tl.load(A_ptr + atom * 5 + 3)
    A4 = tl.load(A_ptr + atom * 5 + 4)

    B0 = tl.load(B_ptr + atom * 5 + 0)
    B1 = tl.load(B_ptr + atom * 5 + 1)
    B2 = tl.load(B_ptr + atom * 5 + 2)
    B3 = tl.load(B_ptr + atom * 5 + 3)
    B4 = tl.load(B_ptr + atom * 5 + 4)

    Bt0 = tl.maximum((B0 + b_iso) * 0.25, 0.1)
    Bt1 = tl.maximum((B1 + b_iso) * 0.25, 0.1)
    Bt2 = tl.maximum((B2 + b_iso) * 0.25, 0.1)
    Bt3 = tl.maximum((B3 + b_iso) * 0.25, 0.1)
    Bt4 = tl.maximum((B4 + b_iso) * 0.25, 0.1)

    # 1.0 where the width is not clamped, 0.0 where it is (zero b-gradient)
    clamp0 = ((B0 + b_iso) * 0.25 > 0.1).to(tl.float32)
    clamp1 = ((B1 + b_iso) * 0.25 > 0.1).to(tl.float32)
    clamp2 = ((B2 + b_iso) * 0.25 > 0.1).to(tl.float32)
    clamp3 = ((B3 + b_iso) * 0.25 > 0.1).to(tl.float32)
    clamp4 = ((B4 + b_iso) * 0.25 > 0.1).to(tl.float32)

    pi_1p5: tl.constexpr = 5.568327996831708
    An0 = A0 * occ * pi_1p5 / (Bt0 * tl.sqrt(Bt0))
    An1 = A1 * occ * pi_1p5 / (Bt1 * tl.sqrt(Bt1))
    An2 = A2 * occ * pi_1p5 / (Bt2 * tl.sqrt(Bt2))
    An3 = A3 * occ * pi_1p5 / (Bt3 * tl.sqrt(Bt3))
    An4 = A4 * occ * pi_1p5 / (Bt4 * tl.sqrt(Bt4))

    # Occupancy-free normalisation, d(An_g)/d(occ). Using it instead of
    # density / occ keeps the occupancy gradient exact when occ == 0.
    Un0 = A0 * pi_1p5 / (Bt0 * tl.sqrt(Bt0))
    Un1 = A1 * pi_1p5 / (Bt1 * tl.sqrt(Bt1))
    Un2 = A2 * pi_1p5 / (Bt2 * tl.sqrt(Bt2))
    Un3 = A3 * pi_1p5 / (Bt3 * tl.sqrt(Bt3))
    Un4 = A4 * pi_1p5 / (Bt4 * tl.sqrt(Bt4))

    # Matrices (row-major 3x3)
    if0 = tl.load(inv_frac_ptr + 0)
    if1 = tl.load(inv_frac_ptr + 1)
    if2 = tl.load(inv_frac_ptr + 2)
    if3 = tl.load(inv_frac_ptr + 3)
    if4 = tl.load(inv_frac_ptr + 4)
    if5 = tl.load(inv_frac_ptr + 5)
    if6 = tl.load(inv_frac_ptr + 6)
    if7 = tl.load(inv_frac_ptr + 7)
    if8 = tl.load(inv_frac_ptr + 8)

    f0 = tl.load(frac_ptr + 0)
    f1 = tl.load(frac_ptr + 1)
    f2 = tl.load(frac_ptr + 2)
    f3 = tl.load(frac_ptr + 3)
    f4 = tl.load(frac_ptr + 4)
    f5 = tl.load(frac_ptr + 5)
    f6 = tl.load(frac_ptr + 6)
    f7 = tl.load(frac_ptr + 7)
    f8 = tl.load(frac_ptr + 8)

    pi_sq: tl.constexpr = 9.869604401089358

    # Accumulators for this atom's gradients
    g_ax = 0.0
    g_ay = 0.0
    g_az = 0.0
    g_b = 0.0
    g_occ = 0.0

    v_offsets = tl.arange(0, BLOCK_V)
    # 64-bit row base: (atom * N_voxels + v) * 3 can exceed the int32 range.
    base = atom.to(tl.int64) * N_voxels

    for v_start in range(0, N_voxels, BLOCK_V):
        v = v_start + v_offsets
        mask = v < N_voxels

        # Element offset of [atom, v, 0] in both (N_atoms, N_voxels, 3) arrays
        row3 = (base + v) * 3

        # Recompute forward quantities
        sx = tl.load(surr_coords_ptr + row3 + 0, mask=mask, other=0.0)
        sy = tl.load(surr_coords_ptr + row3 + 1, mask=mask, other=0.0)
        sz = tl.load(surr_coords_ptr + row3 + 2, mask=mask, other=0.0)

        dx = sx - ax
        dy = sy - ay
        dz = sz - az

        fx = dx * if0 + dy * if1 + dz * if2
        fy = dx * if3 + dy * if4 + dz * if5
        fz = dx * if6 + dy * if7 + dz * if8

        tx = tl.extra.cuda.libdevice.round(fx)
        ty = tl.extra.cuda.libdevice.round(fy)
        tz = tl.extra.cuda.libdevice.round(fz)

        cx = tx * f0 + ty * f1 + tz * f2
        cy = tx * f3 + ty * f4 + tz * f5
        cz = tx * f6 + ty * f7 + tz * f8

        wx = dx - cx
        wy = dy - cy
        wz = dz - cz
        r2 = wx * wx + wy * wy + wz * wz

        # Gather upstream gradient at the voxel locations
        ix = tl.load(voxel_idx_ptr + row3 + 0, mask=mask, other=0).to(tl.int64)
        iy = tl.load(voxel_idx_ptr + row3 + 1, mask=mask, other=0).to(tl.int64)
        iz = tl.load(voxel_idx_ptr + row3 + 2, mask=mask, other=0).to(tl.int64)
        flat_idx = ix * (ny * nz) + iy * nz + iz
        grad_out = tl.load(grad_density_map_ptr + flat_idx, mask=mask, other=0.0)

        # Exponentials for each Gaussian
        e0 = tl.exp(-pi_sq * r2 / Bt0)
        e1 = tl.exp(-pi_sq * r2 / Bt1)
        e2 = tl.exp(-pi_sq * r2 / Bt2)
        e3 = tl.exp(-pi_sq * r2 / Bt3)
        e4 = tl.exp(-pi_sq * r2 / Bt4)

        # --- Gradient w.r.t. xyz ---
        # d_density/d_xyz_i = sum_g An_g * exp_g * (-pi_sq / Bt_g) * 2 * w_i * (-1)
        #                   = 2 * pi_sq * w_i * sum_g(An_g * exp_g / Bt_g)
        # (diff = surr - xyz, so d(diff)/d(xyz) = -1 and the signs cancel.)
        coeff_xyz = (
            An0 * e0 / Bt0 + An1 * e1 / Bt1 + An2 * e2 / Bt2
            + An3 * e3 / Bt3 + An4 * e4 / Bt4
        )
        scale_xyz = grad_out * 2.0 * pi_sq * coeff_xyz

        g_ax += tl.sum(tl.where(mask, scale_xyz * wx, 0.0), axis=0)
        g_ay += tl.sum(tl.where(mask, scale_xyz * wy, 0.0), axis=0)
        g_az += tl.sum(tl.where(mask, scale_xyz * wz, 0.0), axis=0)

        # --- Gradient w.r.t. b ---
        # d(density)/d(Bt_g) = An_g * exp_g * (-1.5 / Bt_g + pi_sq * r2 / Bt_g^2)
        # d(Bt_g)/d(b) = 0.25 when not clamped, 0 otherwise
        db0 = An0 * e0 * (-1.5 / Bt0 + pi_sq * r2 / (Bt0 * Bt0)) * clamp0
        db1 = An1 * e1 * (-1.5 / Bt1 + pi_sq * r2 / (Bt1 * Bt1)) * clamp1
        db2 = An2 * e2 * (-1.5 / Bt2 + pi_sq * r2 / (Bt2 * Bt2)) * clamp2
        db3 = An3 * e3 * (-1.5 / Bt3 + pi_sq * r2 / (Bt3 * Bt3)) * clamp3
        db4 = An4 * e4 * (-1.5 / Bt4 + pi_sq * r2 / (Bt4 * Bt4)) * clamp4
        g_b += tl.sum(
            tl.where(mask, grad_out * 0.25 * (db0 + db1 + db2 + db3 + db4), 0.0),
            axis=0,
        )

        # --- Gradient w.r.t. occ ---
        # density is linear in occ, so d(density)/d(occ) is the density formula
        # without the occ factor. This is well defined for occ == 0 as well.
        unit_density = Un0 * e0 + Un1 * e1 + Un2 * e2 + Un3 * e3 + Un4 * e4
        g_occ += tl.sum(tl.where(mask, grad_out * unit_density, 0.0), axis=0)

    # Write accumulated gradients
    tl.store(grad_xyz_ptr + atom * 3 + 0, g_ax)
    tl.store(grad_xyz_ptr + atom * 3 + 1, g_ay)
    tl.store(grad_xyz_ptr + atom * 3 + 2, g_az)
    tl.store(grad_b_ptr + atom, g_b)
    tl.store(grad_occ_ptr + atom, g_occ)


# =============================================================================
# Autograd wrapper
# =============================================================================

class _FusedDensityFunction(torch.autograd.Function):
    """Autograd wrapper around ``_density_fwd_kernel`` / ``_density_bwd_kernel``.

    ``forward`` returns a copy of ``density_map`` with the atoms' density
    added; the input map itself is not modified. ``backward`` produces
    gradients for ``density_map`` (pass-through), ``xyz``, ``b`` and ``occ``.
    """

    @staticmethod
    def forward(
        ctx,
        surrounding_coords,  # (N_atoms, N_voxels, 3)
        voxel_indices,       # (N_atoms, N_voxels, 3)
        density_map,         # (nx, ny, nz) — copied, not modified
        xyz,                 # (N_atoms, 3)
        b,                   # (N_atoms,)
        inv_frac_matrix,     # (3, 3)
        frac_matrix,         # (3, 3)
        A,                   # (N_atoms, 5)
        B,                   # (N_atoms, 5)
        occ,                 # (N_atoms,)
    ):
        """Launch the forward kernel and return the updated copy of the map."""
        N_atoms, N_voxels = surrounding_coords.shape[:2]
        ny, nz = density_map.shape[1], density_map.shape[2]

        # The kernels index raw row-major memory.
        surrounding_coords = surrounding_coords.contiguous()
        voxel_indices = voxel_indices.contiguous()
        xyz = xyz.contiguous()
        b = b.contiguous()
        A = A.contiguous()
        B = B.contiguous()
        occ = occ.contiguous()
        inv_frac_flat = inv_frac_matrix.contiguous().view(-1)
        frac_flat = frac_matrix.contiguous().view(-1)

        # Clone so the output is owned by this Function and safe for
        # subsequent in-place ops (e.g. anisotropic scatter_add). The copy is
        # forced to be contiguous: a default clone preserves the input's
        # strides, which would break view(-1) and the kernel's flat index.
        output = density_map.clone(memory_format=torch.contiguous_format)

        # Nothing to add for an empty workload; N_voxels == 0 would also give
        # BLOCK_V == 0, which tl.arange cannot handle.
        if N_atoms > 0 and N_voxels > 0:
            BLOCK_V = triton.next_power_of_2(min(N_voxels, 1024))
            _density_fwd_kernel[(N_atoms,)](
                surrounding_coords, voxel_indices, output.view(-1),
                xyz, b, inv_frac_flat, frac_flat, A, B, occ,
                N_voxels=N_voxels, ny=ny, nz=nz, BLOCK_V=BLOCK_V,
            )

        # Save for backward
        ctx.save_for_backward(
            surrounding_coords, voxel_indices, xyz, b,
            inv_frac_matrix, frac_matrix, A, B, occ,
        )
        ctx.ny = ny
        ctx.nz = nz
        ctx.density_map_shape = density_map.shape

        return output

    @staticmethod
    def backward(ctx, grad_density_map):
        """Return gradients in the order of the ``forward`` arguments."""
        (surrounding_coords, voxel_indices, xyz, b,
         inv_frac_matrix, frac_matrix, A, B, occ) = ctx.saved_tensors
        ny, nz = ctx.ny, ctx.nz
        needs = ctx.needs_input_grad

        N_atoms, N_voxels = surrounding_coords.shape[:2]

        # output = density_map + atom contributions, so the upstream gradient
        # flows unchanged into density_map. Dropping it would silently cut the
        # graph when maps produced by earlier differentiable calls are chained.
        grad_map_in = grad_density_map if needs[2] else None

        grad_xyz = None
        grad_b = None
        grad_occ = None

        # Argument positions: xyz=3, b=4, occ=9
        if needs[3] or needs[4] or needs[9]:
            grad_xyz = torch.zeros_like(xyz)
            grad_b = torch.zeros_like(b)
            grad_occ = torch.zeros_like(occ)

            if N_atoms > 0 and N_voxels > 0:
                grad_flat = grad_density_map.contiguous().view(-1)
                inv_frac_flat = inv_frac_matrix.contiguous().view(-1)
                frac_flat = frac_matrix.contiguous().view(-1)

                BLOCK_V = triton.next_power_of_2(min(N_voxels, 1024))
                _density_bwd_kernel[(N_atoms,)](
                    surrounding_coords, voxel_indices, grad_flat,
                    xyz, b, inv_frac_flat, frac_flat, A, B, occ,
                    grad_xyz, grad_b, grad_occ,
                    N_voxels=N_voxels, ny=ny, nz=nz, BLOCK_V=BLOCK_V,
                )

        # Order: surrounding_coords, voxel_indices, density_map, xyz, b,
        # inv_frac_matrix, frac_matrix, A, B, occ
        return (
            None, None, grad_map_in, grad_xyz, grad_b,
            None, None, None, None, grad_occ,
        )


# =============================================================================
# Public API
# =============================================================================

def fused_add_to_map_gpu(
    surrounding_coords: torch.Tensor,
    voxel_indices: torch.Tensor,
    density_map: torch.Tensor,
    xyz: torch.Tensor,
    b: torch.Tensor,
    inv_frac_matrix: torch.Tensor,
    frac_matrix: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    occ: torch.Tensor,
) -> torch.Tensor:
    """
    Fused GPU density computation using Triton.

    Drop-in replacement for the JIT GPU kernel with full autograd support.
    Fuses PBC wrapping, r² computation, 5-Gaussian evaluation, and scatter-add
    into a single GPU kernel launch.

    Parameters
    ----------
    surrounding_coords : torch.Tensor
        Cartesian coordinates of voxels, shape (N_atoms, N_voxels, 3).
    voxel_indices : torch.Tensor
        Grid indices of voxels, shape (N_atoms, N_voxels, 3).
    density_map : torch.Tensor
        Electron density map to add to, shape (nx, ny, nz). The tensor passed
        in is not modified; the update is applied to a copy.
    xyz : torch.Tensor
        Atom positions in Cartesian space, shape (N_atoms, 3).
    b : torch.Tensor
        Isotropic B-factors, shape (N_atoms,).
    inv_frac_matrix : torch.Tensor
        Inverse fractionalization matrix, shape (3, 3).
    frac_matrix : torch.Tensor
        Fractionalization matrix, shape (3, 3).
    A : torch.Tensor
        ITC92 amplitude coefficients, shape (N_atoms, 5).
    B : torch.Tensor
        ITC92 width coefficients, shape (N_atoms, 5).
    occ : torch.Tensor
        Atomic occupancies, shape (N_atoms,).

    Returns
    -------
    torch.Tensor
        New density map equal to ``density_map`` plus the atoms' density.
        Callers must use the returned tensor.
    """
    return _FusedDensityFunction.apply(
        surrounding_coords, voxel_indices, density_map,
        xyz, b, inv_frac_matrix, frac_matrix, A, B, occ,
    )
# =============================================================================
# Fused voxel-finding + density computation (skips surrounding_coords entirely)
# =============================================================================

# (device, radius, rounded voxel_size components) -> (N_offsets, 3) int32 tensor
_offsets_cache: dict = {}


def _compute_local_offsets(voxel_size: torch.Tensor, radius_angstrom: float,
                           device: torch.device) -> torch.Tensor:
    """Compute and cache spherical voxel offsets (int32).

    Returns every integer offset (dx, dy, dz) whose length, after scaling by
    ``voxel_size``, is at most ``radius_angstrom``. Results are cached per
    device, radius and voxel size.
    """
    # The offsets depend on every component of voxel_size, so all of them go
    # into the key. Keying on the minimum alone makes anisotropic grids that
    # share the same smallest edge collide.
    size_key = tuple(
        round(float(s), 6) for s in voxel_size.detach().reshape(-1).tolist()
    )
    key = (device, radius_angstrom, size_key)
    if key not in _offsets_cache:
        # Half-width of the search cube, derived from the smallest voxel edge
        min_r = int(torch.ceil(radius_angstrom / voxel_size.min()).item())
        g = torch.arange(-min_r, min_r + 1, device=device)
        x, y, z = torch.meshgrid(g, g, g, indexing="ij")
        coords = torch.stack((x, y, z), dim=-1)
        dist = torch.sqrt(
            torch.sum((coords.float() * voxel_size) ** 2, dim=-1)
        )
        _offsets_cache[key] = (
            coords[dist <= radius_angstrom].to(torch.int32).contiguous()
        )
    return _offsets_cache[key]


@triton.jit
def _fused_voxel_fwd_kernel(
    # Pointers
    grid_ptr,          # real_space_grid flat (nx*ny*nz*3,) float32
    density_map_ptr,   # (nx*ny*nz,) float32
    xyz_ptr,           # (N_atoms, 3) float32
    b_ptr,             # (N_atoms,) float32
    inv_frac_ptr,      # (9,) float32 row-major
    frac_ptr,          # (9,) float32 row-major
    A_ptr,             # (N_atoms, 5) float32
    B_ptr,             # (N_atoms, 5) float32
    occ_ptr,           # (N_atoms,) float32
    offsets_ptr,       # (N_offsets, 3) int32
    # Dimensions
    nx: tl.constexpr,
    ny: tl.constexpr,
    nz: tl.constexpr,
    N_offsets: tl.constexpr,
    BLOCK_V: tl.constexpr,
):
    """One program per atom. Fuses voxel-finding with density computation.

    The atom's centre grid index is derived from its fractional coordinate;
    each spherical offset is added to it and wrapped into the grid. The
    voxel's Cartesian coordinate is read from ``grid_ptr`` and the 5-Gaussian
    density is atomically added to the flat density map.
    """
    atom = tl.program_id(0)

    # Per-atom scalars
    b_iso = tl.load(b_ptr + atom)
    occ = tl.load(occ_ptr + atom)
    ax = tl.load(xyz_ptr + atom * 3 + 0)
    ay = tl.load(xyz_ptr + atom * 3 + 1)
    az = tl.load(xyz_ptr + atom * 3 + 2)

    # ITC92 parameters (5 Gaussians)
    A0 = tl.load(A_ptr + atom * 5 + 0)
    A1 = tl.load(A_ptr + atom * 5 + 1)
    A2 = tl.load(A_ptr + atom * 5 + 2)
    A3 = tl.load(A_ptr + atom * 5 + 3)
    A4 = tl.load(A_ptr + atom * 5 + 4)

    B0 = tl.load(B_ptr + atom * 5 + 0)
    B1 = tl.load(B_ptr + atom * 5 + 1)
    B2 = tl.load(B_ptr + atom * 5 + 2)
    B3 = tl.load(B_ptr + atom * 5 + 3)
    B4 = tl.load(B_ptr + atom * 5 + 4)

    Bt0 = tl.maximum((B0 + b_iso) * 0.25, 0.1)
    Bt1 = tl.maximum((B1 + b_iso) * 0.25, 0.1)
    Bt2 = tl.maximum((B2 + b_iso) * 0.25, 0.1)
    Bt3 = tl.maximum((B3 + b_iso) * 0.25, 0.1)
    Bt4 = tl.maximum((B4 + b_iso) * 0.25, 0.1)

    pi_1p5: tl.constexpr = 5.568327996831708
    An0 = A0 * occ * pi_1p5 / (Bt0 * tl.sqrt(Bt0))
    An1 = A1 * occ * pi_1p5 / (Bt1 * tl.sqrt(Bt1))
    An2 = A2 * occ * pi_1p5 / (Bt2 * tl.sqrt(Bt2))
    An3 = A3 * occ * pi_1p5 / (Bt3 * tl.sqrt(Bt3))
    An4 = A4 * occ * pi_1p5 / (Bt4 * tl.sqrt(Bt4))

    # Load 3x3 matrices (row-major: M[i,j] = ptr[i*3+j])
    if0 = tl.load(inv_frac_ptr + 0)
    if1 = tl.load(inv_frac_ptr + 1)
    if2 = tl.load(inv_frac_ptr + 2)
    if3 = tl.load(inv_frac_ptr + 3)
    if4 = tl.load(inv_frac_ptr + 4)
    if5 = tl.load(inv_frac_ptr + 5)
    if6 = tl.load(inv_frac_ptr + 6)
    if7 = tl.load(inv_frac_ptr + 7)
    if8 = tl.load(inv_frac_ptr + 8)

    f0 = tl.load(frac_ptr + 0)
    f1 = tl.load(frac_ptr + 1)
    f2 = tl.load(frac_ptr + 2)
    f3 = tl.load(frac_ptr + 3)
    f4 = tl.load(frac_ptr + 4)
    f5 = tl.load(frac_ptr + 5)
    f6 = tl.load(frac_ptr + 6)
    f7 = tl.load(frac_ptr + 7)
    f8 = tl.load(frac_ptr + 8)

    pi_sq: tl.constexpr = 9.869604401089358

    # --- Fused voxel finding: xyz -> fractional -> grid center index ---
    # xyz_frac = inv_frac_matrix @ xyz (row-major dot product)
    frac_x = ax * if0 + ay * if1 + az * if2
    frac_y = ax * if3 + ay * if4 + az * if5
    frac_z = ax * if6 + ay * if7 + az * if8

    # Wrap to [0, 1): x - floor(x)
    frac_x = frac_x - tl.extra.cuda.libdevice.floor(frac_x)
    frac_y = frac_y - tl.extra.cuda.libdevice.floor(frac_y)
    frac_z = frac_z - tl.extra.cuda.libdevice.floor(frac_z)

    # Round to nearest grid index (may equal n; wrapped by the modulo below)
    cix = tl.extra.cuda.libdevice.round(frac_x * nx).to(tl.int32)
    ciy = tl.extra.cuda.libdevice.round(frac_y * ny).to(tl.int32)
    ciz = tl.extra.cuda.libdevice.round(frac_z * nz).to(tl.int32)

    # Process local offsets in blocks
    v_offsets = tl.arange(0, BLOCK_V)

    for v_start in range(0, N_offsets, BLOCK_V):
        v = v_start + v_offsets
        mask = v < N_offsets

        # Load spherical offset
        ob = v * 3
        off_x = tl.load(offsets_ptr + ob + 0, mask=mask, other=0)
        off_y = tl.load(offsets_ptr + ob + 1, mask=mask, other=0)
        off_z = tl.load(offsets_ptr + ob + 2, mask=mask, other=0)

        # Wrapped voxel indices (handle negative modulo)
        vix = (cix + off_x) % nx
        vix = tl.where(vix < 0, vix + nx, vix)
        viy = (ciy + off_y) % ny
        viy = tl.where(viy < 0, viy + ny, viy)
        viz = (ciz + off_z) % nz
        viz = tl.where(viz < 0, viz + nz, viz)

        # Flat row-major voxel index, computed in 64 bit
        dm_flat = (vix.to(tl.int64) * ny + viy) * nz + viz

        # Gather Cartesian coords from real_space_grid[vix, viy, viz, :]
        gf = dm_flat * 3
        sx = tl.load(grid_ptr + gf + 0, mask=mask, other=0.0)
        sy = tl.load(grid_ptr + gf + 1, mask=mask, other=0.0)
        sz = tl.load(grid_ptr + gf + 2, mask=mask, other=0.0)

        # diff = surrounding_coord - atom_coord
        dx = sx - ax
        dy = sy - ay
        dz = sz - az

        # PBC: diff_frac = diff @ inv_frac.T
        fx = dx * if0 + dy * if1 + dz * if2
        fy = dx * if3 + dy * if4 + dz * if5
        fz = dx * if6 + dy * if7 + dz * if8

        tx = tl.extra.cuda.libdevice.round(fx)
        ty = tl.extra.cuda.libdevice.round(fy)
        tz = tl.extra.cuda.libdevice.round(fz)

        cx = tx * f0 + ty * f1 + tz * f2
        cy = tx * f3 + ty * f4 + tz * f5
        cz = tx * f6 + ty * f7 + tz * f8

        wx = dx - cx
        wy = dy - cy
        wz = dz - cz
        r2 = wx * wx + wy * wy + wz * wz

        # 5-Gaussian density
        density = (
            An0 * tl.exp(-pi_sq * r2 / Bt0)
            + An1 * tl.exp(-pi_sq * r2 / Bt1)
            + An2 * tl.exp(-pi_sq * r2 / Bt2)
            + An3 * tl.exp(-pi_sq * r2 / Bt3)
            + An4 * tl.exp(-pi_sq * r2 / Bt4)
        )

        # Atomic add to density map
        tl.atomic_add(density_map_ptr + dm_flat, density, mask=mask)


@triton.jit
def _fused_voxel_bwd_kernel(
    # Forward inputs (read-only)
    grid_ptr,
    grad_density_map_ptr,
    xyz_ptr,
    b_ptr,
    inv_frac_ptr,
    frac_ptr,
    A_ptr,
    B_ptr,
    occ_ptr,
    offsets_ptr,
    # Gradient outputs
    grad_xyz_ptr,
    grad_b_ptr,
    grad_occ_ptr,
    # Dimensions
    nx: tl.constexpr,
    ny: tl.constexpr,
    nz: tl.constexpr,
    N_offsets: tl.constexpr,
    BLOCK_V: tl.constexpr,
):
    """Backward for fused voxel-finding + density. One program per atom.

    Recomputes the forward voxel selection and density terms, gathers the
    upstream gradient at each voxel and reduces into per-atom gradients for
    xyz, b and occ. Voxel selection and the PBC translation are treated as
    constants.
    """
    atom = tl.program_id(0)

    b_iso = tl.load(b_ptr + atom)
    occ = tl.load(occ_ptr + atom)
    ax = tl.load(xyz_ptr + atom * 3 + 0)
    ay = tl.load(xyz_ptr + atom * 3 + 1)
    az = tl.load(xyz_ptr + atom * 3 + 2)

    A0 = tl.load(A_ptr + atom * 5 + 0)
    A1 = tl.load(A_ptr + atom * 5 + 1)
    A2 = tl.load(A_ptr + atom * 5 + 2)
    A3 = tl.load(A_ptr + atom * 5 + 3)
    A4 = tl.load(A_ptr + atom * 5 + 4)

    B0 = tl.load(B_ptr + atom * 5 + 0)
    B1 = tl.load(B_ptr + atom * 5 + 1)
    B2 = tl.load(B_ptr + atom * 5 + 2)
    B3 = tl.load(B_ptr + atom * 5 + 3)
    B4 = tl.load(B_ptr + atom * 5 + 4)

    Bt0 = tl.maximum((B0 + b_iso) * 0.25, 0.1)
    Bt1 = tl.maximum((B1 + b_iso) * 0.25, 0.1)
    Bt2 = tl.maximum((B2 + b_iso) * 0.25, 0.1)
    Bt3 = tl.maximum((B3 + b_iso) * 0.25, 0.1)
    Bt4 = tl.maximum((B4 + b_iso) * 0.25, 0.1)

    # 1.0 where the width is not clamped, 0.0 where it is (zero b-gradient)
    clamp0 = ((B0 + b_iso) * 0.25 > 0.1).to(tl.float32)
    clamp1 = ((B1 + b_iso) * 0.25 > 0.1).to(tl.float32)
    clamp2 = ((B2 + b_iso) * 0.25 > 0.1).to(tl.float32)
    clamp3 = ((B3 + b_iso) * 0.25 > 0.1).to(tl.float32)
    clamp4 = ((B4 + b_iso) * 0.25 > 0.1).to(tl.float32)

    pi_1p5: tl.constexpr = 5.568327996831708
    An0 = A0 * occ * pi_1p5 / (Bt0 * tl.sqrt(Bt0))
    An1 = A1 * occ * pi_1p5 / (Bt1 * tl.sqrt(Bt1))
    An2 = A2 * occ * pi_1p5 / (Bt2 * tl.sqrt(Bt2))
    An3 = A3 * occ * pi_1p5 / (Bt3 * tl.sqrt(Bt3))
    An4 = A4 * occ * pi_1p5 / (Bt4 * tl.sqrt(Bt4))

    # Occupancy-free normalisation, d(An_g)/d(occ). Using it instead of
    # density / occ keeps the occupancy gradient exact when occ == 0.
    Un0 = A0 * pi_1p5 / (Bt0 * tl.sqrt(Bt0))
    Un1 = A1 * pi_1p5 / (Bt1 * tl.sqrt(Bt1))
    Un2 = A2 * pi_1p5 / (Bt2 * tl.sqrt(Bt2))
    Un3 = A3 * pi_1p5 / (Bt3 * tl.sqrt(Bt3))
    Un4 = A4 * pi_1p5 / (Bt4 * tl.sqrt(Bt4))

    if0 = tl.load(inv_frac_ptr + 0)
    if1 = tl.load(inv_frac_ptr + 1)
    if2 = tl.load(inv_frac_ptr + 2)
    if3 = tl.load(inv_frac_ptr + 3)
    if4 = tl.load(inv_frac_ptr + 4)
    if5 = tl.load(inv_frac_ptr + 5)
    if6 = tl.load(inv_frac_ptr + 6)
    if7 = tl.load(inv_frac_ptr + 7)
    if8 = tl.load(inv_frac_ptr + 8)

    f0 = tl.load(frac_ptr + 0)
    f1 = tl.load(frac_ptr + 1)
    f2 = tl.load(frac_ptr + 2)
    f3 = tl.load(frac_ptr + 3)
    f4 = tl.load(frac_ptr + 4)
    f5 = tl.load(frac_ptr + 5)
    f6 = tl.load(frac_ptr + 6)
    f7 = tl.load(frac_ptr + 7)
    f8 = tl.load(frac_ptr + 8)

    pi_sq: tl.constexpr = 9.869604401089358

    # Recompute center grid index (same as forward)
    frac_x = ax * if0 + ay * if1 + az * if2
    frac_y = ax * if3 + ay * if4 + az * if5
    frac_z = ax * if6 + ay * if7 + az * if8
    frac_x = frac_x - tl.extra.cuda.libdevice.floor(frac_x)
    frac_y = frac_y - tl.extra.cuda.libdevice.floor(frac_y)
    frac_z = frac_z - tl.extra.cuda.libdevice.floor(frac_z)
    cix = tl.extra.cuda.libdevice.round(frac_x * nx).to(tl.int32)
    ciy = tl.extra.cuda.libdevice.round(frac_y * ny).to(tl.int32)
    ciz = tl.extra.cuda.libdevice.round(frac_z * nz).to(tl.int32)

    # Gradient accumulators
    g_ax = 0.0
    g_ay = 0.0
    g_az = 0.0
    g_b = 0.0
    g_occ = 0.0

    v_offsets = tl.arange(0, BLOCK_V)

    for v_start in range(0, N_offsets, BLOCK_V):
        v = v_start + v_offsets
        mask = v < N_offsets

        ob = v * 3
        off_x = tl.load(offsets_ptr + ob + 0, mask=mask, other=0)
        off_y = tl.load(offsets_ptr + ob + 1, mask=mask, other=0)
        off_z = tl.load(offsets_ptr + ob + 2, mask=mask, other=0)

        # Wrapped voxel indices (handle negative modulo)
        vix = (cix + off_x) % nx
        vix = tl.where(vix < 0, vix + nx, vix)
        viy = (ciy + off_y) % ny
        viy = tl.where(viy < 0, viy + ny, viy)
        viz = (ciz + off_z) % nz
        viz = tl.where(viz < 0, viz + nz, viz)

        # Flat row-major voxel index, computed in 64 bit
        dm_flat = (vix.to(tl.int64) * ny + viy) * nz + viz

        gf = dm_flat * 3
        sx = tl.load(grid_ptr + gf + 0, mask=mask, other=0.0)
        sy = tl.load(grid_ptr + gf + 1, mask=mask, other=0.0)
        sz = tl.load(grid_ptr + gf + 2, mask=mask, other=0.0)

        dx = sx - ax
        dy = sy - ay
        dz = sz - az

        fx = dx * if0 + dy * if1 + dz * if2
        fy = dx * if3 + dy * if4 + dz * if5
        fz = dx * if6 + dy * if7 + dz * if8

        tx = tl.extra.cuda.libdevice.round(fx)
        ty = tl.extra.cuda.libdevice.round(fy)
        tz = tl.extra.cuda.libdevice.round(fz)

        cx = tx * f0 + ty * f1 + tz * f2
        cy = tx * f3 + ty * f4 + tz * f5
        cz = tx * f6 + ty * f7 + tz * f8

        wx = dx - cx
        wy = dy - cy
        wz = dz - cz
        r2 = wx * wx + wy * wy + wz * wz

        # Gather upstream gradient
        grad_out = tl.load(grad_density_map_ptr + dm_flat, mask=mask, other=0.0)

        e0 = tl.exp(-pi_sq * r2 / Bt0)
        e1 = tl.exp(-pi_sq * r2 / Bt1)
        e2 = tl.exp(-pi_sq * r2 / Bt2)
        e3 = tl.exp(-pi_sq * r2 / Bt3)
        e4 = tl.exp(-pi_sq * r2 / Bt4)

        # grad xyz: 2 * pi^2 * w_i * sum_g(An_g * exp_g / Bt_g)
        coeff_xyz = (
            An0 * e0 / Bt0 + An1 * e1 / Bt1 + An2 * e2 / Bt2
            + An3 * e3 / Bt3 + An4 * e4 / Bt4
        )
        scale_xyz = grad_out * 2.0 * pi_sq * coeff_xyz
        g_ax += tl.sum(tl.where(mask, scale_xyz * wx, 0.0), axis=0)
        g_ay += tl.sum(tl.where(mask, scale_xyz * wy, 0.0), axis=0)
        g_az += tl.sum(tl.where(mask, scale_xyz * wz, 0.0), axis=0)

        # grad b: d(Bt_g)/d(b) = 0.25 when not clamped
        db0 = An0 * e0 * (-1.5 / Bt0 + pi_sq * r2 / (Bt0 * Bt0)) * clamp0
        db1 = An1 * e1 * (-1.5 / Bt1 + pi_sq * r2 / (Bt1 * Bt1)) * clamp1
        db2 = An2 * e2 * (-1.5 / Bt2 + pi_sq * r2 / (Bt2 * Bt2)) * clamp2
        db3 = An3 * e3 * (-1.5 / Bt3 + pi_sq * r2 / (Bt3 * Bt3)) * clamp3
        db4 = An4 * e4 * (-1.5 / Bt4 + pi_sq * r2 / (Bt4 * Bt4)) * clamp4
        g_b += tl.sum(
            tl.where(mask, grad_out * 0.25 * (db0 + db1 + db2 + db3 + db4), 0.0),
            axis=0,
        )

        # grad occ: density is linear in occ, so the derivative is the density
        # formula without the occ factor (valid for occ == 0 as well).
        unit_density = Un0 * e0 + Un1 * e1 + Un2 * e2 + Un3 * e3 + Un4 * e4
        g_occ += tl.sum(tl.where(mask, grad_out * unit_density, 0.0), axis=0)

    tl.store(grad_xyz_ptr + atom * 3 + 0, g_ax)
    tl.store(grad_xyz_ptr + atom * 3 + 1, g_ay)
    tl.store(grad_xyz_ptr + atom * 3 + 2, g_az)
    tl.store(grad_b_ptr + atom, g_b)
    tl.store(grad_occ_ptr + atom, g_occ)


# =============================================================================
# Fused autograd wrapper
# =============================================================================

class _FusedVoxelDensityFunction(torch.autograd.Function):
    """Autograd wrapper around the fused voxel-finding density kernels.

    ``forward`` returns a copy of ``density_map`` with the atoms' density
    added; ``backward`` produces gradients for ``density_map``
    (pass-through), ``xyz``, ``b`` and ``occ``.
    """

    @staticmethod
    def forward(ctx, real_space_grid, density_map, xyz, b,
                inv_frac_matrix, frac_matrix, A, B, occ, local_offsets):
        """Launch the fused forward kernel and return the updated map copy."""
        N_atoms = xyz.shape[0]
        nx, ny, nz = real_space_grid.shape[:3]
        N_offsets = local_offsets.shape[0]

        # The kernels index raw row-major memory.
        grid_flat = real_space_grid.contiguous().view(-1)
        xyz = xyz.contiguous()
        b = b.contiguous()
        A = A.contiguous()
        B = B.contiguous()
        occ = occ.contiguous()
        inv_frac_flat = inv_frac_matrix.contiguous().view(-1)
        frac_flat = frac_matrix.contiguous().view(-1)
        local_offsets = local_offsets.contiguous()

        # Forced-contiguous copy: a default clone preserves the input's
        # strides, which would break view(-1) and the kernel's flat index.
        output = density_map.clone(memory_format=torch.contiguous_format)

        # Nothing to add for an empty workload; N_offsets == 0 would also give
        # BLOCK_V == 0, which tl.arange cannot handle.
        if N_atoms > 0 and N_offsets > 0:
            BLOCK_V = triton.next_power_of_2(min(N_offsets, 1024))
            _fused_voxel_fwd_kernel[(N_atoms,)](
                grid_flat, output.view(-1),
                xyz, b, inv_frac_flat, frac_flat, A, B, occ,
                local_offsets,
                nx=nx, ny=ny, nz=nz, N_offsets=N_offsets, BLOCK_V=BLOCK_V,
            )

        ctx.save_for_backward(
            real_space_grid, xyz, b, inv_frac_matrix, frac_matrix,
            A, B, occ, local_offsets,
        )
        ctx.density_map_shape = density_map.shape
        return output

    @staticmethod
    def backward(ctx, grad_density_map):
        """Return gradients in the order of the ``forward`` arguments."""
        (real_space_grid, xyz, b, inv_frac_matrix, frac_matrix,
         A, B, occ, local_offsets) = ctx.saved_tensors
        needs = ctx.needs_input_grad

        N_atoms = xyz.shape[0]
        nx, ny, nz = real_space_grid.shape[:3]
        N_offsets = local_offsets.shape[0]

        # output = density_map + atom contributions, so the upstream gradient
        # flows unchanged into density_map (needed when maps are chained).
        grad_map_in = grad_density_map if needs[1] else None

        grad_xyz = None
        grad_b = None
        grad_occ = None

        # Argument positions: xyz=2, b=3, occ=8
        if needs[2] or needs[3] or needs[8]:
            grad_xyz = torch.zeros_like(xyz)
            grad_b = torch.zeros_like(b)
            grad_occ = torch.zeros_like(occ)

            if N_atoms > 0 and N_offsets > 0:
                grid_flat = real_space_grid.contiguous().view(-1)
                grad_flat = grad_density_map.contiguous().view(-1)
                inv_frac_flat = inv_frac_matrix.contiguous().view(-1)
                frac_flat = frac_matrix.contiguous().view(-1)

                BLOCK_V = triton.next_power_of_2(min(N_offsets, 1024))
                _fused_voxel_bwd_kernel[(N_atoms,)](
                    grid_flat, grad_flat,
                    xyz, b, inv_frac_flat, frac_flat, A, B, occ,
                    local_offsets,
                    grad_xyz, grad_b, grad_occ,
                    nx=nx, ny=ny, nz=nz, N_offsets=N_offsets, BLOCK_V=BLOCK_V,
                )

        # Grads for: real_space_grid, density_map, xyz, b,
        #            inv_frac, frac, A, B, occ, local_offsets
        return (
            None, grad_map_in, grad_xyz, grad_b,
            None, None, None, None, grad_occ, None,
        )
def fused_find_and_place_atoms(
    real_space_grid: torch.Tensor,
    density_map: torch.Tensor,
    xyz: torch.Tensor,
    b: torch.Tensor,
    inv_frac_matrix: torch.Tensor,
    frac_matrix: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    occ: torch.Tensor,
    radius_angstrom: float,
    voxel_size: torch.Tensor,
) -> torch.Tensor:
    """
    Fused voxel-finding + density computation using Triton.

    Eliminates the separate find_relevant_voxels step and the large
    surrounding_coords / voxel_indices intermediate tensors. Computes
    center grid indices and spherical voxel offsets directly inside
    the GPU kernel.

    Parameters
    ----------
    real_space_grid : torch.Tensor
        Real space coordinate grid, shape (nx, ny, nz, 3).
    density_map : torch.Tensor
        Density map to add to, shape (nx, ny, nz). The update is applied to
        a copy; use the returned tensor.
    xyz, b, inv_frac_matrix, frac_matrix, A, B, occ :
        Same as fused_add_to_map_gpu.
    radius_angstrom : float
        Radius around each atom in Angstroms.
    voxel_size : torch.Tensor
        Voxel dimensions, shape (3,).

    Returns
    -------
    torch.Tensor
        Updated density map.
    """
    # Spherical integer offsets are computed once per configuration and
    # cached; they are built on the device that holds xyz.
    local_offsets = _compute_local_offsets(voxel_size, radius_angstrom, xyz.device)
    return _FusedVoxelDensityFunction.apply(
        real_space_grid, density_map, xyz, b,
        inv_frac_matrix, frac_matrix, A, B, occ,
        local_offsets,
    )