torchref.base.targets.triton.xray_bhattacharyya module
Triton kernels for the Bhattacharyya X-ray target.
Matches torchref.base.targets.xray_bhattacharyya.bhattacharyya_xray_loss_math()
to within float32 precision. sigma_m is passed in as a constant input (no
gradient flows through it), because the eager target builds it under no_grad.
The math is per-reflection and the kernel performs no reduction; the
per-reflection output tensor is summed with .sum() on the host side, which
keeps the autograd glue trivial.
- torchref.base.targets.triton.xray_bhattacharyya.bhattacharyya_xray_loss_math_triton(F_obs, F_calc, sigma_d, sigma_m, mask)
Triton-backed Bhattacharyya overlap loss.
Drop-in replacement for
torchref.base.targets.xray_bhattacharyya.bhattacharyya_xray_loss_math().
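Since the Triton function is described as a drop-in replacement that matches the eager math to within float32 precision, one way to check that claim is to run both functions on identical inputs and compare the loss and its gradients. The helper assert_drop_in below is hypothetical and not part of torchref. It relies only on torch.autograd.grad and torch.testing.assert_close, which picks its default tolerances from the tensor dtype. It assumes both functions return a scalar loss.

```python
import torch


def assert_drop_in(reference_fn, candidate_fn, *args, **kwargs):
    """Hypothetical helper: check candidate_fn reproduces reference_fn's
    scalar output and its gradients w.r.t. every tensor arg with requires_grad."""

    def run(fn):
        # Fresh leaf copies so the two runs never share .grad buffers.
        leaves = [
            a.detach().clone().requires_grad_(a.requires_grad) if torch.is_tensor(a) else a
            for a in args
        ]
        out = fn(*leaves, **kwargs)
        diff_leaves = [t for t in leaves if torch.is_tensor(t) and t.requires_grad]
        grads = (
            torch.autograd.grad(out, diff_leaves, allow_unused=True) if diff_leaves else ()
        )
        return out.detach(), grads

    ref_out, ref_grads = run(reference_fn)
    new_out, new_grads = run(candidate_fn)
    # Default rtol/atol are chosen from the dtype (float32 inputs -> float32 tolerances).
    torch.testing.assert_close(new_out, ref_out)
    for g_new, g_ref in zip(new_grads, ref_grads):
        if g_new is None and g_ref is None:
            continue  # input unused by both functions
        torch.testing.assert_close(g_new, g_ref)
```

For the functions documented here, a call would look like assert_drop_in(bhattacharyya_xray_loss_math, bhattacharyya_xray_loss_math_triton, F_obs, F_calc, sigma_d, sigma_m, mask) with F_calc.requires_grad_() set, comparing the loss and its gradient with respect to F_calc.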