# Source code for torchref.base.targets.triton.xray_bhattacharyya

"""Triton kernels for the Bhattacharyya X-ray target.

Matches :func:`torchref.base.targets.xray_bhattacharyya.bhattacharyya_xray_loss_math`
to within float32 precision. ``sigma_m`` enters as a constant input (the
eager target builds it under ``no_grad``).

The math is per-reflection; the kernels perform no reduction. The forward
kernel writes a per-reflection tensor, which is then summed with an ordinary
PyTorch ``.sum()`` issued from Python, so the autograd glue stays trivial.
"""

from __future__ import annotations

import torch
import triton
import triton.language as tl


# Lower clamp for sigma_d and sigma_m in both kernels (values below it are
# replaced by it), so the variance sum and the log-ratio stay finite even when
# a sigma is zero or negative.
_EPS = 1e-6


@triton.jit
def _bhatt_fwd_kernel(
    F_obs_ptr,        # (N,) observed amplitudes
    F_calc_ptr,       # (N,) real non-negative amplitudes
    sigma_d_ptr,      # (N,)
    sigma_m_ptr,      # (N,)
    mask_ptr,         # (N,) uint8; nonzero = reflection contributes
    out_ptr,          # (N,) per-reflection loss
    N,                # runtime scalar: a new reflection count must not recompile
    EPS: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Write the per-reflection Bhattacharyya loss for one BLOCK of reflections.

    With ``sd``/``sm`` being sigma_d/sigma_m clamped below at ``EPS`` and
    ``v = sd**2 + sm**2``, each reflection ``i`` stores::

        out[i] = (F_obs - F_calc)**2 / (4 v) + 0.5 * log(v / (2 sd sm))

    i.e. the Bhattacharyya distance between two 1-D Gaussians with means
    F_obs / F_calc and standard deviations sd / sm. Reflections whose mask
    entry is zero store exactly 0.
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    valid = offs < N

    F_obs = tl.load(F_obs_ptr + offs, mask=valid, other=0.0)
    F_calc = tl.load(F_calc_ptr + offs, mask=valid, other=0.0)
    # Out-of-range lanes read sigma = 1 so the log below stays finite there.
    sd = tl.load(sigma_d_ptr + offs, mask=valid, other=1.0)
    sm = tl.load(sigma_m_ptr + offs, mask=valid, other=1.0)
    keep = tl.load(mask_ptr + offs, mask=valid, other=0) != 0

    # Clamp sigmas away from zero before dividing / taking the log.
    sd_s = tl.where(sd < EPS, EPS, sd)
    sm_s = tl.where(sm < EPS, EPS, sm)
    var_sum = sd_s * sd_s + sm_s * sm_s

    diff = F_obs - F_calc
    l_mean = (diff * diff) / (4.0 * var_sum)
    l_var = 0.5 * tl.log(var_sum / (2.0 * sd_s * sm_s))

    # Select instead of multiplying by the mask: 0 * NaN is NaN, so an excluded
    # reflection carrying a non-finite value (e.g. NaN for an unmeasured F_obs
    # -- assumption, confirm with the data pipeline) would poison the sum.
    loss = tl.where(keep, l_mean + l_var, 0.0)
    tl.store(out_ptr + offs, loss, mask=valid)


@triton.jit
def _bhatt_bwd_kernel(
    F_obs_ptr,
    F_calc_ptr,
    sigma_d_ptr,
    sigma_m_ptr,
    mask_ptr,        # (N,) uint8; nonzero = reflection contributes
    grad_out_ptr,    # 0-D tensor — loaded in-kernel (no host .item() sync)
    dF_calc_ptr,     # (N,) — written directly (not atomic, one program per range)
    N,               # runtime scalar: a new reflection count must not recompile
    EPS: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Write dLoss/dF_calc for one BLOCK of reflections.

    Only the mean term ``(F_obs - F_calc)**2 / (4 v)`` depends on F_calc, so
    the per-reflection derivative is ``-(F_obs - F_calc) / (2 v)`` with
    ``v = sd**2 + sm**2`` (sigmas clamped below at ``EPS``). The result is
    scaled by the upstream scalar gradient; masked-out reflections get 0.
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    valid = offs < N
    grad_out = tl.load(grad_out_ptr)

    F_obs = tl.load(F_obs_ptr + offs, mask=valid, other=0.0)
    F_calc = tl.load(F_calc_ptr + offs, mask=valid, other=0.0)
    sd = tl.load(sigma_d_ptr + offs, mask=valid, other=1.0)
    sm = tl.load(sigma_m_ptr + offs, mask=valid, other=1.0)
    keep = tl.load(mask_ptr + offs, mask=valid, other=0) != 0

    # Same clamp as the forward kernel so the gradient matches its loss.
    sd_s = tl.where(sd < EPS, EPS, sd)
    sm_s = tl.where(sm < EPS, EPS, sm)
    var_sum = sd_s * sd_s + sm_s * sm_s

    diff = F_obs - F_calc
    # dL_h / dF_calc = -diff / (2 * var_sum) (only l_mean depends on F_calc)
    dl = -diff / (2.0 * var_sum)
    # Select rather than multiply: a NaN in an excluded reflection must not
    # produce a NaN gradient that then propagates upstream through F_calc.
    g = tl.where(keep, grad_out * dl, 0.0)
    tl.store(dF_calc_ptr + offs, g, mask=valid)


class _BhattXrayMathTriton(torch.autograd.Function):
    """Autograd bridge for the Triton Bhattacharyya kernels.

    ``forward`` launches ``_bhatt_fwd_kernel`` and returns the 0-D sum of the
    per-reflection losses. ``backward`` launches ``_bhatt_bwd_kernel`` and
    returns a gradient for ``F_calc`` only; ``F_obs``, ``sigma_d``,
    ``sigma_m`` and ``mask`` are treated as constants.
    """

    # Reflections handled per Triton program; shared by both launches.
    _BLOCK = 1024

    @staticmethod
    def forward(ctx, F_obs, F_calc, sigma_d, sigma_m, mask):
        # Explicit checks rather than ``assert`` so they survive ``python -O``.
        if not F_calc.is_cuda:
            raise ValueError("F_calc must be a CUDA tensor for the Triton path")
        if F_calc.dtype != torch.float32:
            raise ValueError(f"F_calc must be float32, got {F_calc.dtype}")

        N = F_calc.numel()
        # The kernels read N elements from every input; anything shorter
        # would be read out of bounds.
        for name, t in (
            ("F_obs", F_obs),
            ("sigma_d", sigma_d),
            ("sigma_m", sigma_m),
            ("mask", mask),
        ):
            if t.numel() != N:
                raise ValueError(
                    f"{name} has {t.numel()} elements, expected {N} "
                    "(one per reflection, like F_calc)"
                )

        # The kernels use unit-stride float32 loads from raw pointers, so every
        # input must be contiguous, float32 (uint8 for the mask) and live on
        # F_calc's device.
        device = F_calc.device
        F_obs = F_obs.to(device=device, dtype=torch.float32).contiguous()
        F_calc = F_calc.contiguous()
        sigma_d = sigma_d.to(device=device, dtype=torch.float32).contiguous()
        sigma_m = sigma_m.to(device=device, dtype=torch.float32).contiguous()
        mask_u8 = mask.to(device=device, dtype=torch.uint8).contiguous()

        out = torch.empty(N, dtype=torch.float32, device=device)
        if N > 0:
            BLOCK = _BhattXrayMathTriton._BLOCK
            grid = (triton.cdiv(N, BLOCK),)
            _bhatt_fwd_kernel[grid](
                F_obs, F_calc, sigma_d, sigma_m, mask_u8, out,
                N=N, EPS=_EPS, BLOCK=BLOCK,
            )
        ctx.save_for_backward(F_obs, F_calc, sigma_d, sigma_m, mask_u8)
        return out.sum()

    @staticmethod
    def backward(ctx, grad_out):
        # Only F_calc (input index 1) is differentiable here; skip the launch
        # entirely when autograd does not need its gradient.
        if not ctx.needs_input_grad[1]:
            return None, None, None, None, None

        F_obs, F_calc, sigma_d, sigma_m, mask_u8 = ctx.saved_tensors
        N = F_calc.numel()
        dF_calc = torch.empty_like(F_calc)
        if N > 0:
            # grad_out stays on device and is read inside the kernel (no .item()
            # host sync); make sure it matches the kernel's float32 arithmetic.
            grad_out = grad_out.to(dtype=torch.float32).contiguous()
            BLOCK = _BhattXrayMathTriton._BLOCK
            grid = (triton.cdiv(N, BLOCK),)
            _bhatt_bwd_kernel[grid](
                F_obs, F_calc, sigma_d, sigma_m, mask_u8,
                grad_out, dF_calc,
                N=N, EPS=_EPS, BLOCK=BLOCK,
            )
        return None, dF_calc, None, None, None


def bhattacharyya_xray_loss_math_triton(
    F_obs: torch.Tensor,
    F_calc: torch.Tensor,
    sigma_d: torch.Tensor,
    sigma_m: torch.Tensor,
    mask: torch.Tensor,
) -> torch.Tensor:
    """Triton-backed Bhattacharyya overlap loss.

    Drop-in replacement for
    :func:`torchref.base.targets.xray_bhattacharyya.bhattacharyya_xray_loss_math`.

    Args:
        F_obs: Observed amplitudes, one per reflection.
        F_calc: Calculated (real, non-negative) amplitudes; must be a CUDA
            float32 tensor. This is the only input that receives a gradient.
        sigma_d: Per-reflection data sigma (clamped below at ``_EPS``).
        sigma_m: Per-reflection model sigma (clamped below at ``_EPS``),
            treated as a constant.
        mask: Per-reflection boolean selection; only entries where it is
            true contribute to the loss.

    Returns:
        0-D tensor holding the sum over selected reflections of
        ``(F_obs - F_calc)**2 / (4 v) + 0.5 * log(v / (2 sigma_d sigma_m))``
        with ``v = sigma_d**2 + sigma_m**2``.
    """
    return _BhattXrayMathTriton.apply(F_obs, F_calc, sigma_d, sigma_m, mask)