Source code for torchref.base.targets.triton.xray_gaussian

"""Triton kernels for the Gaussian X-ray target."""

from __future__ import annotations

import math

import torch
import triton
import triton.language as tl
from torch.autograd.function import once_differentiable


# log(2*pi): the Gaussian normalising constant (math.tau is exactly 2*math.pi).
_LOG_2PI = math.log(math.tau)


@triton.jit
def _gauss_fwd_kernel(
    F_obs_ptr,
    F_calc_ptr,
    sigma_ptr,
    mask_ptr,
    sigma_floor_ptr,    # 0-D tensor holding the sigma floor, loaded in-kernel
    log_2pi,            # runtime float scalar, log(2*pi)
    out_ptr,            # (N,) per-element masked NLL
    N,                  # element count; runtime arg so a new N does not recompile
    BLOCK: tl.constexpr,
):
    """Per-element Gaussian negative log-likelihood.

    For each element i handled by this program (BLOCK consecutive elements
    per program):

        s[i]   = sigma[i] floored from below at sigma_floor
        out[i] = 0.5 * ((F_obs[i] - F_calc[i]) / s[i])**2 + log(s[i])
                 + 0.5 * log_2pi            if mask[i] != 0
        out[i] = 0                          otherwise
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    valid = offs < N
    sigma_floor = tl.load(sigma_floor_ptr)

    F_obs = tl.load(F_obs_ptr + offs, mask=valid, other=0.0)
    F_calc = tl.load(F_calc_ptr + offs, mask=valid, other=0.0)
    # Out-of-range lanes get a benign sigma of 1; their results are never stored.
    sig = tl.load(sigma_ptr + offs, mask=valid, other=1.0)
    keep = tl.load(mask_ptr + offs, mask=valid, other=0) != 0

    # Floor sigma from below so tiny uncertainties cannot blow up the NLL.
    sig_safe = tl.where(sig < sigma_floor, sigma_floor, sig)
    diff = F_obs - F_calc
    inv_s = 1.0 / sig_safe
    z = diff * inv_s
    nll = 0.5 * z * z + tl.log(sig_safe) + 0.5 * log_2pi
    # Select instead of multiplying by the mask: excluded elements may hold
    # non-finite values, and NaN * 0 is still NaN, which would poison the sum.
    tl.store(out_ptr + offs, tl.where(keep, nll, 0.0), mask=valid)


@triton.jit
def _gauss_bwd_kernel(
    F_obs_ptr,
    F_calc_ptr,
    sigma_ptr,
    mask_ptr,
    sigma_floor_ptr,    # 0-D tensor holding the sigma floor
    grad_out_ptr,       # 0-D tensor: upstream gradient of the summed loss
    dF_calc_ptr,        # (N,) output: gradient w.r.t. F_calc
    N,                  # element count; runtime arg so a new N does not recompile
    BLOCK: tl.constexpr,
):
    """Gradient of the summed Gaussian NLL with respect to F_calc.

    For each element i handled by this program (BLOCK consecutive elements
    per program), with s[i] = sigma[i] floored from below at sigma_floor:

        dF_calc[i] = grad_out * -(F_obs[i] - F_calc[i]) / s[i]**2   if mask[i] != 0
        dF_calc[i] = 0                                              otherwise
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    valid = offs < N
    sigma_floor = tl.load(sigma_floor_ptr)
    grad_out = tl.load(grad_out_ptr)

    F_obs = tl.load(F_obs_ptr + offs, mask=valid, other=0.0)
    F_calc = tl.load(F_calc_ptr + offs, mask=valid, other=0.0)
    # Out-of-range lanes get a benign sigma of 1; their results are never stored.
    sig = tl.load(sigma_ptr + offs, mask=valid, other=1.0)
    keep = tl.load(mask_ptr + offs, mask=valid, other=0) != 0

    # Must apply the same floor as the forward kernel.
    sig_safe = tl.where(sig < sigma_floor, sigma_floor, sig)
    diff = F_obs - F_calc
    inv_s2 = 1.0 / (sig_safe * sig_safe)
    # dNLL_i / dF_calc_i = -diff / sigma^2 (the log(sigma) term has no F_calc).
    g = grad_out * (-diff) * inv_s2
    # Select instead of multiplying by the mask so non-finite excluded
    # elements cannot leak NaN into the gradient.
    tl.store(dF_calc_ptr + offs, tl.where(keep, g, 0.0), mask=valid)


class _GaussXrayMathTriton(torch.autograd.Function):
    """Autograd wrapper around the Triton Gaussian NLL kernels.

    Forward returns the scalar sum of the per-element masked negative
    log-likelihood

        0.5 * ((F_obs - F_calc) / s)**2 + log(s) + 0.5 * log(2*pi),

    where ``s`` is ``sigma`` floored from below at ``0.1 * median(sigma)``.
    Only ``F_calc`` receives a gradient; ``F_obs``, ``sigma`` and ``mask``
    are treated as constants.
    """

    # Elements per Triton program, shared by the forward and backward launches.
    _BLOCK = 1024

    @staticmethod
    def forward(ctx, F_obs, F_calc, sigma, mask):
        """Launch the forward kernel and return the summed masked NLL.

        Args:
            ctx: Autograd context; the contiguous inputs, the uint8 mask and
                the sigma floor are saved for ``backward``.
            F_obs, F_calc, sigma, mask: Tensors with the same number of
                elements. ``F_calc`` must be a float32 CUDA tensor.

        Returns:
            0-D tensor holding the sum of the per-element NLL values.

        Raises:
            AssertionError: ``F_calc`` is not a float32 CUDA tensor.
            ValueError: another input's element count differs from ``F_calc``'s.
        """
        # Explicit raise (not ``assert``) so the check survives ``python -O``;
        # AssertionError is kept for compatibility with the original assert.
        if not (F_calc.is_cuda and F_calc.dtype == torch.float32):
            raise AssertionError(
                "F_calc must be a float32 CUDA tensor, got "
                f"dtype={F_calc.dtype} on device={F_calc.device}"
            )
        F_obs = F_obs.contiguous()
        F_calc = F_calc.contiguous()
        sigma = sigma.contiguous()
        mask_u8 = mask.to(torch.uint8).contiguous()
        N = F_calc.numel()
        # Both kernels index every buffer with the same offsets, so a shorter
        # buffer would be read out of bounds on the device.
        for name, t in (("F_obs", F_obs), ("sigma", sigma), ("mask", mask_u8)):
            if t.numel() != N:
                raise ValueError(
                    f"{name} has {t.numel()} elements; expected {N} to match F_calc"
                )

        # Match eager: floor = median(sigma) * 1e-1. Kept as a 0-D device
        # tensor so no .item() host sync is needed.
        sigma_floor_t = (torch.median(sigma) * 0.1).to(F_calc.dtype)

        out = torch.empty(N, dtype=F_calc.dtype, device=F_calc.device)
        block = _GaussXrayMathTriton._BLOCK
        _gauss_fwd_kernel[(triton.cdiv(N, block),)](
            F_obs, F_calc, sigma, mask_u8, sigma_floor_t, _LOG_2PI, out,
            N=N, BLOCK=block,
        )
        ctx.save_for_backward(F_obs, F_calc, sigma, mask_u8, sigma_floor_t)
        return out.sum()

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_out):
        """Return the gradient w.r.t. ``F_calc``; other inputs get ``None``.

        Marked ``once_differentiable`` because the Triton kernel builds no
        autograd graph: without it, double-backward would silently treat
        this gradient as a constant instead of raising.
        """
        if not ctx.needs_input_grad[1]:
            # Only non-F_calc inputs asked for grads; all of those are None.
            return None, None, None, None
        F_obs, F_calc, sigma, mask_u8, sigma_floor_t = ctx.saved_tensors
        N = F_calc.numel()
        dF_calc = torch.empty_like(F_calc)
        block = _GaussXrayMathTriton._BLOCK
        _gauss_bwd_kernel[(triton.cdiv(N, block),)](
            F_obs, F_calc, sigma, mask_u8, sigma_floor_t,
            grad_out, dF_calc,
            N=N, BLOCK=block,
        )
        return None, dF_calc, None, None


def gaussian_xray_loss_math_triton(F_obs, F_calc, sigma, mask):
    """Triton-backed Gaussian NLL X-ray loss.

    Sums, over the elements where ``mask`` is set, the per-element negative
    log-likelihood

        0.5 * ((F_obs - F_calc) / s)**2 + log(s) + 0.5 * log(2*pi),

    where ``s`` is ``sigma`` floored from below at ``0.1 * median(sigma)``.

    Args:
        F_obs: Observed values (presumably structure-factor amplitudes);
            no gradient is returned for it.
        F_calc: Calculated values; must be a float32 CUDA tensor. This is the
            only input that receives a gradient.
        sigma: Per-element uncertainties used to scale the residuals.
        mask: Selects contributing elements; converted to uint8 internally.
            ``F_obs``, ``sigma`` and ``mask`` are expected to have as many
            elements as ``F_calc``.

    Returns:
        0-D tensor holding the summed negative log-likelihood.
    """
    return _GaussXrayMathTriton.apply(F_obs, F_calc, sigma, mask)