torchref.base.targets.triton.xray_bhattacharyya module

Triton kernels for the Bhattacharyya X-ray target.

Results match torchref.base.targets.xray_bhattacharyya.bhattacharyya_xray_loss_math() to within float32 precision. sigma_m enters as a constant input, because the eager target builds it under no_grad.

The math is per-reflection, and the kernel performs no reduction. It produces one value per reflection, and the host-side code calls .sum() on that tensor. Keeping the reduction outside the kernel keeps the autograd glue trivial.
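The per-reflection-then-sum structure can be illustrated with a minimal pure-Python sketch. It is not the torchref implementation. It assumes the per-reflection term is the Bhattacharyya distance (-ln of the Bhattacharyya coefficient) between two 1-D Gaussians, N(F_obs, sigma_d²) and N(F_calc, sigma_m²), with F_calc taken as a real amplitude and mask selecting which reflections contribute. The helper names are hypothetical.

```python
import math
from typing import Sequence


def bhattacharyya_distance(mu1: float, s1: float, mu2: float, s2: float) -> float:
    """-ln(BC) between N(mu1, s1^2) and N(mu2, s2^2) (assumed per-reflection term)."""
    var_sum = s1 * s1 + s2 * s2
    # Penalises the distance between the two means, scaled by the combined variance.
    mean_term = (mu1 - mu2) ** 2 / (4.0 * var_sum)
    # Penalises mismatched widths; zero when s1 == s2.
    spread_term = 0.5 * math.log(var_sum / (2.0 * s1 * s2))
    return mean_term + spread_term


def per_reflection_terms(
    f_obs: Sequence[float],
    f_calc: Sequence[float],
    sigma_d: Sequence[float],
    sigma_m: Sequence[float],
    mask: Sequence[bool],
) -> list[float]:
    # Each entry depends only on its own reflection (the "kernel" step);
    # masked-out reflections contribute 0.
    return [
        bhattacharyya_distance(fo, sd, fc, sm) if m else 0.0
        for fo, fc, sd, sm, m in zip(f_obs, f_calc, sigma_d, sigma_m, mask)
    ]


def total_loss(f_obs, f_calc, sigma_d, sigma_m, mask) -> float:
    # The only reduction: a plain sum applied outside the per-reflection step.
    return sum(per_reflection_terms(f_obs, f_calc, sigma_d, sigma_m, mask))


print(total_loss([10.0, 5.0], [10.0, 7.0], [1.0, 1.0], [1.0, 1.0], [True, True]))  # 0.5
```

Because the only reduction is a plain sum, the derivative of the total with respect to each per-reflection term is 1. A backward pass therefore only needs the per-reflection derivatives, each scaled by the incoming gradient.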

torchref.base.targets.triton.xray_bhattacharyya.bhattacharyya_xray_loss_math_triton(F_obs, F_calc, sigma_d, sigma_m, mask)

Triton-backed Bhattacharyya overlap loss.

Drop-in replacement for torchref.base.targets.xray_bhattacharyya.bhattacharyya_xray_loss_math().