# Source code for torchref.base.targets.triton.place_hydrogens

"""Triton forward + analytic backward for riding-hydrogen placement.

The eager helper (``_place_h_jit`` in :mod:`torchref.restraints.hydrogen_topology`)
is a ``@torch.jit.script`` that fuses the forward pass into ~30 launches.
Its backward, however, flows through PyTorch autograd op-by-op (cross
products, normalisations, ``where``-branch routing); for ~3 k hydrogens
that's ~100 launches and dominates the non-bonded backward cost (1.9 ms
out of 5.3 ms total fwd+bw on 1DAW / A100).

This kernel does both forward and analytic backward in one launch each.
The math mirrors ``_place_h_jit`` exactly:

    pp        = xyz[parent_idx]
    nb_pos[i] = xyz[nb_idx[i]]
    v[i]      = (nb_pos[i] − pp) · valid[i]
    s         = Σ_i v[i]
    base      = −s / (|s| + ε)

    cross12        = v[0] × v[1]
    cross_norm     = |cross12|
    perp1_cross    = cross12 / (cross_norm + ε)
    cardinal       = ê_argmin|base|     (detached — no grad)
    perp1_ortho    = (base × cardinal) / |base × cardinal|
    perp1          = perp1_cross  if cross_norm > 1e-6  else  perp1_ortho
    perp2          = base × perp1

    direction = c0·base + c1·perp1 + c2·perp2
    H         = pp + L · direction
"""

from __future__ import annotations

import torch
import triton
import triton.language as tl


# Additive guard handed to both kernels as ``EPS``: it is added to each vector
# norm before dividing, so a zero-length vector never divides by zero.
_EPS = 1e-8


@triton.jit
def _placeh_fwd_kernel(
    xyz_ptr,            # (N_heavy, 3)
    parent_ptr,         # (N_h,) int
    nb_idx_ptr,         # (N_h, 4) int (clamped to >=0)
    nb_valid_ptr,       # (N_h, 4) float (1.0 / 0.0)
    coeff_ptr,          # (N_h, 3) float
    blen_ptr,           # (N_h,) float
    out_ptr,            # (N_h, 3) float
    N_H: tl.constexpr,
    EPS: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Forward placement: one hydrogen per lane, ``H = pp + L * direction``.

    Each program instance handles ``BLOCK`` consecutive hydrogens; lanes whose
    index is ``>= N_H`` are masked out of every load and of the final store.

    All addressing is done with hard-coded row strides (``* 3`` for xyz,
    coeff and out, ``* 4`` for nb_idx and nb_valid), so the buffers must be
    laid out row-major and densely packed.

    ``N_H``, ``EPS`` and ``BLOCK`` are ``tl.constexpr`` arguments.
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    m = offs < N_H

    # Gather the parent (heavy atom the hydrogen rides on) position.
    pi = tl.load(parent_ptr + offs, mask=m, other=0)
    ppx = tl.load(xyz_ptr + pi * 3 + 0, mask=m, other=0.0)
    ppy = tl.load(xyz_ptr + pi * 3 + 1, mask=m, other=0.0)
    ppz = tl.load(xyz_ptr + pi * 3 + 2, mask=m, other=0.0)

    # Neighbour 0..3 — accumulate s and stash v0 and v1 for the cross product.
    nb0 = tl.load(nb_idx_ptr + offs * 4 + 0, mask=m, other=0)
    nb1 = tl.load(nb_idx_ptr + offs * 4 + 1, mask=m, other=0)
    nb2 = tl.load(nb_idx_ptr + offs * 4 + 2, mask=m, other=0)
    nb3 = tl.load(nb_idx_ptr + offs * 4 + 3, mask=m, other=0)
    vd0 = tl.load(nb_valid_ptr + offs * 4 + 0, mask=m, other=0.0)
    vd1 = tl.load(nb_valid_ptr + offs * 4 + 1, mask=m, other=0.0)
    vd2 = tl.load(nb_valid_ptr + offs * 4 + 2, mask=m, other=0.0)
    vd3 = tl.load(nb_valid_ptr + offs * 4 + 3, mask=m, other=0.0)

    # v[i] = (nb_pos[i] - pp) * valid[i]. The position of every slot is
    # fetched regardless of validity; the multiply by vd zeroes padded slots.
    v0x = (tl.load(xyz_ptr + nb0 * 3 + 0, mask=m, other=0.0) - ppx) * vd0
    v0y = (tl.load(xyz_ptr + nb0 * 3 + 1, mask=m, other=0.0) - ppy) * vd0
    v0z = (tl.load(xyz_ptr + nb0 * 3 + 2, mask=m, other=0.0) - ppz) * vd0
    v1x = (tl.load(xyz_ptr + nb1 * 3 + 0, mask=m, other=0.0) - ppx) * vd1
    v1y = (tl.load(xyz_ptr + nb1 * 3 + 1, mask=m, other=0.0) - ppy) * vd1
    v1z = (tl.load(xyz_ptr + nb1 * 3 + 2, mask=m, other=0.0) - ppz) * vd1
    v2x = (tl.load(xyz_ptr + nb2 * 3 + 0, mask=m, other=0.0) - ppx) * vd2
    v2y = (tl.load(xyz_ptr + nb2 * 3 + 1, mask=m, other=0.0) - ppy) * vd2
    v2z = (tl.load(xyz_ptr + nb2 * 3 + 2, mask=m, other=0.0) - ppz) * vd2
    v3x = (tl.load(xyz_ptr + nb3 * 3 + 0, mask=m, other=0.0) - ppx) * vd3
    v3y = (tl.load(xyz_ptr + nb3 * 3 + 1, mask=m, other=0.0) - ppy) * vd3
    v3z = (tl.load(xyz_ptr + nb3 * 3 + 2, mask=m, other=0.0) - ppz) * vd3

    # s = sum of the (masked) neighbour vectors; base = -s / (|s| + eps)
    # points away from the neighbours.
    sx = v0x + v1x + v2x + v3x
    sy = v0y + v1y + v2y + v3y
    sz = v0z + v1z + v2z + v3z
    s_norm = tl.sqrt(sx * sx + sy * sy + sz * sz)
    bx = -sx / (s_norm + EPS)
    by = -sy / (s_norm + EPS)
    bz = -sz / (s_norm + EPS)

    # cross12 = v0 × v1 (the eager helper uses indices [:,0,:] and [:,1,:])
    cx = v0y * v1z - v0z * v1y
    cy = v0z * v1x - v0x * v1z
    cz = v0x * v1y - v0y * v1x
    c_norm = tl.sqrt(cx * cx + cy * cy + cz * cz)
    # Below this threshold the cross product is treated as degenerate and the
    # cardinal-axis fallback further down is used instead.
    has_cross = c_norm > 1e-6
    p1cx = cx / (c_norm + EPS)
    p1cy = cy / (c_norm + EPS)
    p1cz = cz / (c_norm + EPS)

    # cardinal axis: one-hot at argmin |base|. Detached in eager (argmin
    # has no usable gradient anyway), so we compute it with comparisons
    # instead of a scatter. The `<=` comparisons resolve ties in the order
    # x, then y, then z.
    abx = tl.abs(bx); aby = tl.abs(by); abz = tl.abs(bz)
    sel_x = (abx <= aby) & (abx <= abz)
    sel_y = (~sel_x) & (aby <= abz)
    # else z
    car_x = tl.where(sel_x, 1.0, 0.0)
    car_y = tl.where(sel_y, 1.0, 0.0)
    car_z = tl.where(~(sel_x | sel_y), 1.0, 0.0)

    # perp1_ortho = normalize(base × cardinal)
    rox = by * car_z - bz * car_y
    roy = bz * car_x - bx * car_z
    roz = bx * car_y - by * car_x
    r_norm = tl.sqrt(rox * rox + roy * roy + roz * roz)
    p1ox = rox / (r_norm + EPS)
    p1oy = roy / (r_norm + EPS)
    p1oz = roz / (r_norm + EPS)

    # Both candidates are computed on every lane; the select picks per lane.
    p1x_ = tl.where(has_cross, p1cx, p1ox)
    p1y_ = tl.where(has_cross, p1cy, p1oy)
    p1z_ = tl.where(has_cross, p1cz, p1oz)

    # perp2 = base × perp1
    p2x_ = by * p1z_ - bz * p1y_
    p2y_ = bz * p1x_ - bx * p1z_
    p2z_ = bx * p1y_ - by * p1x_

    # Per-hydrogen frame coefficients (c0, c1, c2) and bond length L.
    c0 = tl.load(coeff_ptr + offs * 3 + 0, mask=m, other=0.0)
    c1 = tl.load(coeff_ptr + offs * 3 + 1, mask=m, other=0.0)
    c2 = tl.load(coeff_ptr + offs * 3 + 2, mask=m, other=0.0)
    L  = tl.load(blen_ptr + offs, mask=m, other=0.0)

    # direction = c0*base + c1*perp1 + c2*perp2
    dxv = c0 * bx + c1 * p1x_ + c2 * p2x_
    dyv = c0 * by + c1 * p1y_ + c2 * p2y_
    dzv = c0 * bz + c1 * p1z_ + c2 * p2z_

    # H = pp + L * direction
    hx = ppx + L * dxv
    hy = ppy + L * dyv
    hz = ppz + L * dzv

    tl.store(out_ptr + offs * 3 + 0, hx, mask=m)
    tl.store(out_ptr + offs * 3 + 1, hy, mask=m)
    tl.store(out_ptr + offs * 3 + 2, hz, mask=m)


@triton.jit
def _placeh_bwd_kernel(
    xyz_ptr,            # (N_heavy, 3)
    parent_ptr,         # (N_h,) int
    nb_idx_ptr,         # (N_h, 4) int (clamped to >=0)
    nb_valid_ptr,       # (N_h, 4) float (1.0 / 0.0)
    coeff_ptr,          # (N_h, 3) float
    blen_ptr,           # (N_h,) float
    grad_h_ptr,         # (N_h, 3) — d(loss)/d(xyz_h)
    dxyz_ptr,           # (N_heavy, 3) — accumulator
    N_H: tl.constexpr,
    EPS: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Analytic backward of the placement: scatter d(loss)/d(xyz_heavy).

    One hydrogen per lane. The forward intermediates (v, base, perp1, perp2)
    are recomputed here from the inputs rather than read from saved buffers,
    then the chain rule is applied in reverse order of the forward and the
    result is accumulated into ``dxyz_ptr`` with ``tl.atomic_add``. The
    accumulator is only ever added to, so the caller must hand in a
    zero-initialised buffer.

    All addressing uses hard-coded row strides (``* 3`` / ``* 4``), so every
    buffer must be row-major and densely packed.

    Normalisation backward uses the projection form
    ``(G - (G·u) u) / (|x| + eps)`` with ``u = x / (|x| + eps)``. Relative to
    differentiating ``x / (|x| + eps)`` literally, the projected term differs
    only by the factor ``|x| / (|x| + eps)``.
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    m = offs < N_H

    # ---- Recompute the forward pass -------------------------------------
    pi = tl.load(parent_ptr + offs, mask=m, other=0)
    ppx = tl.load(xyz_ptr + pi * 3 + 0, mask=m, other=0.0)
    ppy = tl.load(xyz_ptr + pi * 3 + 1, mask=m, other=0.0)
    ppz = tl.load(xyz_ptr + pi * 3 + 2, mask=m, other=0.0)

    nb0 = tl.load(nb_idx_ptr + offs * 4 + 0, mask=m, other=0)
    nb1 = tl.load(nb_idx_ptr + offs * 4 + 1, mask=m, other=0)
    nb2 = tl.load(nb_idx_ptr + offs * 4 + 2, mask=m, other=0)
    nb3 = tl.load(nb_idx_ptr + offs * 4 + 3, mask=m, other=0)
    vd0 = tl.load(nb_valid_ptr + offs * 4 + 0, mask=m, other=0.0)
    vd1 = tl.load(nb_valid_ptr + offs * 4 + 1, mask=m, other=0.0)
    vd2 = tl.load(nb_valid_ptr + offs * 4 + 2, mask=m, other=0.0)
    vd3 = tl.load(nb_valid_ptr + offs * 4 + 3, mask=m, other=0.0)

    # v[i] = (nb_pos[i] - pp) * valid[i]
    v0x = (tl.load(xyz_ptr + nb0 * 3 + 0, mask=m, other=0.0) - ppx) * vd0
    v0y = (tl.load(xyz_ptr + nb0 * 3 + 1, mask=m, other=0.0) - ppy) * vd0
    v0z = (tl.load(xyz_ptr + nb0 * 3 + 2, mask=m, other=0.0) - ppz) * vd0
    v1x = (tl.load(xyz_ptr + nb1 * 3 + 0, mask=m, other=0.0) - ppx) * vd1
    v1y = (tl.load(xyz_ptr + nb1 * 3 + 1, mask=m, other=0.0) - ppy) * vd1
    v1z = (tl.load(xyz_ptr + nb1 * 3 + 2, mask=m, other=0.0) - ppz) * vd1
    v2x = (tl.load(xyz_ptr + nb2 * 3 + 0, mask=m, other=0.0) - ppx) * vd2
    v2y = (tl.load(xyz_ptr + nb2 * 3 + 1, mask=m, other=0.0) - ppy) * vd2
    v2z = (tl.load(xyz_ptr + nb2 * 3 + 2, mask=m, other=0.0) - ppz) * vd2
    v3x = (tl.load(xyz_ptr + nb3 * 3 + 0, mask=m, other=0.0) - ppx) * vd3
    v3y = (tl.load(xyz_ptr + nb3 * 3 + 1, mask=m, other=0.0) - ppy) * vd3
    v3z = (tl.load(xyz_ptr + nb3 * 3 + 2, mask=m, other=0.0) - ppz) * vd3

    # base = -s / (|s| + eps); the reciprocal is kept for the backward step.
    sx = v0x + v1x + v2x + v3x
    sy = v0y + v1y + v2y + v3y
    sz = v0z + v1z + v2z + v3z
    s_norm = tl.sqrt(sx * sx + sy * sy + sz * sz)
    inv_sn = 1.0 / (s_norm + EPS)
    bx = -sx * inv_sn
    by = -sy * inv_sn
    bz = -sz * inv_sn

    # cross12 = v0 × v1 and its normalised form.
    cx = v0y * v1z - v0z * v1y
    cy = v0z * v1x - v0x * v1z
    cz = v0x * v1y - v0y * v1x
    c_norm = tl.sqrt(cx * cx + cy * cy + cz * cz)
    inv_cn = 1.0 / (c_norm + EPS)
    has_cross = c_norm > 1e-6
    p1cx = cx * inv_cn
    p1cy = cy * inv_cn
    p1cz = cz * inv_cn

    # One-hot cardinal axis at argmin |base| (ties: x, then y, then z).
    # It is treated as a constant: no gradient flows through the selection.
    abx = tl.abs(bx); aby = tl.abs(by); abz = tl.abs(bz)
    sel_x = (abx <= aby) & (abx <= abz)
    sel_y = (~sel_x) & (aby <= abz)
    car_x = tl.where(sel_x, 1.0, 0.0)
    car_y = tl.where(sel_y, 1.0, 0.0)
    car_z = tl.where(~(sel_x | sel_y), 1.0, 0.0)
    rox = by * car_z - bz * car_y
    roy = bz * car_x - bx * car_z
    roz = bx * car_y - by * car_x
    r_norm = tl.sqrt(rox * rox + roy * roy + roz * roz)
    inv_rn = 1.0 / (r_norm + EPS)
    p1ox = rox * inv_rn
    p1oy = roy * inv_rn
    p1oz = roz * inv_rn

    p1x_ = tl.where(has_cross, p1cx, p1ox)
    p1y_ = tl.where(has_cross, p1cy, p1oy)
    p1z_ = tl.where(has_cross, p1cz, p1oz)

    c0 = tl.load(coeff_ptr + offs * 3 + 0, mask=m, other=0.0)
    c1 = tl.load(coeff_ptr + offs * 3 + 1, mask=m, other=0.0)
    c2 = tl.load(coeff_ptr + offs * 3 + 2, mask=m, other=0.0)
    L  = tl.load(blen_ptr + offs, mask=m, other=0.0)

    # ---- Backward -------------------------------------------------------
    ghx = tl.load(grad_h_ptr + offs * 3 + 0, mask=m, other=0.0)
    ghy = tl.load(grad_h_ptr + offs * 3 + 1, mask=m, other=0.0)
    ghz = tl.load(grad_h_ptr + offs * 3 + 2, mask=m, other=0.0)

    # G_pp_direct = ghx,ghy,ghz   (from H = pp + ...)
    # G_dir = L * gh              (from H = pp + L*dir)
    Gdx = L * ghx; Gdy = L * ghy; Gdz = L * ghz

    # dir = c0*base + c1*perp1 + c2*perp2
    Gb_x = c0 * Gdx; Gb_y = c0 * Gdy; Gb_z = c0 * Gdz
    Gp1x = c1 * Gdx; Gp1y = c1 * Gdy; Gp1z = c1 * Gdz
    Gp2x = c2 * Gdx; Gp2y = c2 * Gdy; Gp2z = c2 * Gdz

    # perp2 = base × perp1  →  G_base += perp1 × G_perp2 ; G_perp1 += G_perp2 × base
    Gb_x += p1y_ * Gp2z - p1z_ * Gp2y
    Gb_y += p1z_ * Gp2x - p1x_ * Gp2z
    Gb_z += p1x_ * Gp2y - p1y_ * Gp2x
    Gp1x += Gp2y * bz - Gp2z * by
    Gp1y += Gp2z * bx - Gp2x * bz
    Gp1z += Gp2x * by - Gp2y * bx

    # Split G_perp1 by the where branch: only the branch that was selected
    # in the forward receives gradient, the other gets exactly zero.
    Gp1cx = tl.where(has_cross, Gp1x, 0.0)
    Gp1cy = tl.where(has_cross, Gp1y, 0.0)
    Gp1cz = tl.where(has_cross, Gp1z, 0.0)
    Gp1ox = tl.where(has_cross, 0.0, Gp1x)
    Gp1oy = tl.where(has_cross, 0.0, Gp1y)
    Gp1oz = tl.where(has_cross, 0.0, Gp1z)

    # perp1_cross = cross12 / (|cross12| + eps).
    # G_cross12 = (G − (G·perp1_cross) perp1_cross) / (|cross12| + eps)
    dot_pc = Gp1cx * p1cx + Gp1cy * p1cy + Gp1cz * p1cz
    Gcx = (Gp1cx - dot_pc * p1cx) * inv_cn
    Gcy = (Gp1cy - dot_pc * p1cy) * inv_cn
    Gcz = (Gp1cz - dot_pc * p1cz) * inv_cn

    # cross12 = v0 × v1  →  G_v0 = v1 × G_cross12 ; G_v1 = G_cross12 × v0
    Gv0x = v1y * Gcz - v1z * Gcy
    Gv0y = v1z * Gcx - v1x * Gcz
    Gv0z = v1x * Gcy - v1y * Gcx
    Gv1x = Gcy * v0z - Gcz * v0y
    Gv1y = Gcz * v0x - Gcx * v0z
    Gv1z = Gcx * v0y - Gcy * v0x

    # perp1_ortho = normalise(base × cardinal). Cardinal detached.
    dot_po = Gp1ox * p1ox + Gp1oy * p1oy + Gp1oz * p1oz
    Grox = (Gp1ox - dot_po * p1ox) * inv_rn
    Groy = (Gp1oy - dot_po * p1oy) * inv_rn
    Groz = (Gp1oz - dot_po * p1oz) * inv_rn
    # perp1_ortho_raw = base × cardinal  →  G_base += cardinal × G_raw
    Gb_x += car_y * Groz - car_z * Groy
    Gb_y += car_z * Grox - car_x * Groz
    Gb_z += car_x * Groy - car_y * Grox

    # base = -s / (|s| + eps)
    dot_sb = Gb_x * bx + Gb_y * by + Gb_z * bz
    Gsx = -(Gb_x - dot_sb * bx) * inv_sn
    Gsy = -(Gb_y - dot_sb * by) * inv_sn
    Gsz = -(Gb_z - dot_sb * bz) * inv_sn

    # s = Σ_i v[i] → G_v[i] += G_s
    # plus the cross-product contributions on v0, v1
    Gv0x += Gsx; Gv0y += Gsy; Gv0z += Gsz
    Gv1x += Gsx; Gv1y += Gsy; Gv1z += Gsz
    Gv2x = Gsx;  Gv2y = Gsy;  Gv2z = Gsz
    Gv3x = Gsx;  Gv3y = Gsy;  Gv3z = Gsz

    # v[i] = (nb_pos[i] − pp) · valid[i]
    # G_nb_pos[i] = valid[i] · G_v[i]
    # G_pp_extra = − Σ_i valid[i] · G_v[i]
    Gnb0x = vd0 * Gv0x; Gnb0y = vd0 * Gv0y; Gnb0z = vd0 * Gv0z
    Gnb1x = vd1 * Gv1x; Gnb1y = vd1 * Gv1y; Gnb1z = vd1 * Gv1z
    Gnb2x = vd2 * Gv2x; Gnb2y = vd2 * Gv2y; Gnb2z = vd2 * Gv2z
    Gnb3x = vd3 * Gv3x; Gnb3y = vd3 * Gv3y; Gnb3z = vd3 * Gv3z

    Gpx = ghx - (Gnb0x + Gnb1x + Gnb2x + Gnb3x)
    Gpy = ghy - (Gnb0y + Gnb1y + Gnb2y + Gnb3y)
    Gpz = ghz - (Gnb0z + Gnb1z + Gnb2z + Gnb3z)

    # Scatter to xyz_heavy.
    # A padded neighbour slot (valid == 0) contributes exactly valid * G = 0,
    # and its index is only a clamped placeholder. Masking those lanes out
    # skips the no-op atomics, which would otherwise all pile onto whatever
    # atom the clamped index points at and serialise there.
    m0 = m & (vd0 != 0.0)
    m1 = m & (vd1 != 0.0)
    m2 = m & (vd2 != 0.0)
    m3 = m & (vd3 != 0.0)

    tl.atomic_add(dxyz_ptr + pi * 3 + 0, Gpx, mask=m)
    tl.atomic_add(dxyz_ptr + pi * 3 + 1, Gpy, mask=m)
    tl.atomic_add(dxyz_ptr + pi * 3 + 2, Gpz, mask=m)
    tl.atomic_add(dxyz_ptr + nb0 * 3 + 0, Gnb0x, mask=m0)
    tl.atomic_add(dxyz_ptr + nb0 * 3 + 1, Gnb0y, mask=m0)
    tl.atomic_add(dxyz_ptr + nb0 * 3 + 2, Gnb0z, mask=m0)
    tl.atomic_add(dxyz_ptr + nb1 * 3 + 0, Gnb1x, mask=m1)
    tl.atomic_add(dxyz_ptr + nb1 * 3 + 1, Gnb1y, mask=m1)
    tl.atomic_add(dxyz_ptr + nb1 * 3 + 2, Gnb1z, mask=m1)
    tl.atomic_add(dxyz_ptr + nb2 * 3 + 0, Gnb2x, mask=m2)
    tl.atomic_add(dxyz_ptr + nb2 * 3 + 1, Gnb2y, mask=m2)
    tl.atomic_add(dxyz_ptr + nb2 * 3 + 2, Gnb2z, mask=m2)
    tl.atomic_add(dxyz_ptr + nb3 * 3 + 0, Gnb3x, mask=m3)
    tl.atomic_add(dxyz_ptr + nb3 * 3 + 1, Gnb3y, mask=m3)
    tl.atomic_add(dxyz_ptr + nb3 * 3 + 2, Gnb3z, mask=m3)

class _PlaceHydrogensTriton(torch.autograd.Function):
    """Autograd bridge between PyTorch and the two placement kernels.

    ``forward`` launches ``_placeh_fwd_kernel`` and returns the ``(N_h, 3)``
    hydrogen positions; ``backward`` launches ``_placeh_bwd_kernel`` and
    returns a gradient for ``xyz_heavy`` only (every other input gets
    ``None``).
    """

    # Hydrogens handled per program instance; shared by both launches.
    _BLOCK = 128

    @staticmethod
    def forward(ctx, xyz_heavy, parent_idx, nb_idx_clamped, nb_valid, coeffs, bond_length):
        """Place the hydrogens.

        Args:
            xyz_heavy: ``(N_heavy, 3)`` CUDA float32 heavy-atom coordinates.
            parent_idx: ``(N_h,)`` integer index of each hydrogen's parent.
            nb_idx_clamped: ``(N_h, 4)`` integer neighbour indices.
            nb_valid: ``(N_h, 4)`` or ``(N_h, 4, 1)`` validity weights.
            coeffs: ``(N_h, 3)`` frame coefficients.
            bond_length: ``(N_h,)`` or ``(N_h, 1)`` bond lengths.

        Returns:
            ``(N_h, 3)`` tensor of hydrogen coordinates, same dtype/device as
            ``xyz_heavy``.
        """
        assert xyz_heavy.is_cuda and xyz_heavy.dtype == torch.float32
        N_h = parent_idx.shape[0]
        # nb_valid was (N_h, 4, 1); flatten to (N_h, 4)
        if nb_valid.dim() == 3:
            nb_valid = nb_valid.squeeze(-1)
        # bond_length was (N_h, 1); flatten to (N_h,)
        if bond_length.dim() == 2:
            bond_length = bond_length.squeeze(-1)

        # The kernels address memory with hard-coded row strides (3 and 4),
        # so every operand has to be densely packed. ``contiguous()`` returns
        # the tensor itself when it already is, so this costs nothing in the
        # common case.
        xyz_heavy = xyz_heavy.contiguous()
        parent_idx = parent_idx.contiguous()
        nb_idx_clamped = nb_idx_clamped.contiguous()
        nb_valid = nb_valid.contiguous()
        coeffs = coeffs.contiguous()
        bond_length = bond_length.contiguous()

        out = torch.empty(N_h, 3, dtype=xyz_heavy.dtype, device=xyz_heavy.device)
        # No hydrogens means a zero-sized grid; skip the launch entirely.
        if N_h > 0:
            block = _PlaceHydrogensTriton._BLOCK
            grid = (triton.cdiv(N_h, block),)
            _placeh_fwd_kernel[grid](
                xyz_heavy, parent_idx, nb_idx_clamped, nb_valid, coeffs, bond_length, out,
                N_H=N_h, EPS=_EPS, BLOCK=block,
            )
        # Save the packed versions so backward can reuse them as-is.
        ctx.save_for_backward(xyz_heavy, parent_idx, nb_idx_clamped, nb_valid, coeffs, bond_length)
        return out

    @staticmethod
    def backward(ctx, grad_h):
        """Return ``(d loss / d xyz_heavy, None, None, None, None, None)``."""
        xyz_heavy, parent_idx, nb_idx_clamped, nb_valid, coeffs, bond_length = ctx.saved_tensors
        N_h = parent_idx.shape[0]
        # The kernel accumulates with atomics, so the buffer must start at
        # zero; force a packed layout to match the kernel's stride-3 indexing.
        dxyz = torch.zeros_like(xyz_heavy, memory_format=torch.contiguous_format)
        if N_h > 0:
            grad_h = grad_h.contiguous()
            block = _PlaceHydrogensTriton._BLOCK
            grid = (triton.cdiv(N_h, block),)
            _placeh_bwd_kernel[grid](
                xyz_heavy, parent_idx, nb_idx_clamped, nb_valid, coeffs, bond_length,
                grad_h, dxyz,
                N_H=N_h, EPS=_EPS, BLOCK=block,
            )
        return dxyz, None, None, None, None, None


def place_riding_hydrogens_triton(
    xyz_heavy: torch.Tensor,
    parent_idx: torch.Tensor,
    nb_idx_clamped: torch.Tensor,
    nb_valid: torch.Tensor,
    coeffs: torch.Tensor,
    bond_length: torch.Tensor,
) -> torch.Tensor:
    """Triton-backed riding-hydrogen placement.

    Drop-in replacement for ``_place_h_jit`` (forward) plus an analytic
    Triton backward that avoids autograd's op-by-op traversal.

    Args:
        xyz_heavy: ``(N_heavy, 3)`` heavy-atom coordinates; must be a CUDA
            float32 tensor. This is the only input that receives a gradient.
        parent_idx: ``(N_h,)`` integer index of each hydrogen's parent atom.
        nb_idx_clamped: ``(N_h, 4)`` integer neighbour indices, clamped to
            be non-negative.
        nb_valid: ``(N_h, 4)`` or ``(N_h, 4, 1)`` float weights, 1.0 for a
            real neighbour slot and 0.0 for padding.
        coeffs: ``(N_h, 3)`` coefficients of the local frame
            (base, perp1, perp2).
        bond_length: ``(N_h,)`` or ``(N_h, 1)`` parent–hydrogen distances.

    Returns:
        ``(N_h, 3)`` tensor of hydrogen coordinates.
    """
    return _PlaceHydrogensTriton.apply(
        xyz_heavy, parent_idx, nb_idx_clamped, nb_valid, coeffs, bond_length,
    )