torchref.base.targets.xray_bhattacharyya module
Bhattacharyya overlap X-ray loss math.
Mirrors the loss computation in BhattacharyyaXrayTarget.forward after
sigma_m has been computed in a no-grad block (so sigma_m is treated
as a constant input here).
- torchref.base.targets.xray_bhattacharyya.bhattacharyya_xray_loss_math(F_obs, F_calc, sigma_d, sigma_m, mask)
Bhattacharyya overlap loss between the data Gaussian N(F_obs, sigma_d^2) and the model Gaussian N(|F_calc|, sigma_m^2), evaluated per reflection h:
$$L_h = \frac{(F_{obs} - |F_{calc}|)^2}{4(\sigma_d^2 + \sigma_m^2)} + \frac{1}{2}\log\frac{\sigma_d^2 + \sigma_m^2}{2\sigma_d\sigma_m}$$
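As a sanity check on this formula, the closed form should equal the Bhattacharyya distance -log ∫ sqrt(p(x) q(x)) dx between p = N(F_obs, sigma_d^2) and q = N(|F_calc|, sigma_m^2). The sketch below compares the two using only the Python standard library; the function names `bhattacharyya_closed_form` and `bhattacharyya_numeric` are hypothetical helpers for illustration, not part of torchref.

```python
import math


def bhattacharyya_closed_form(f_obs, f_calc_abs, sigma_d, sigma_m):
    """Closed-form Bhattacharyya distance between N(f_obs, sigma_d^2) and N(f_calc_abs, sigma_m^2)."""
    var_sum = sigma_d**2 + sigma_m**2
    mean_term = (f_obs - f_calc_abs) ** 2 / (4.0 * var_sum)
    width_term = 0.5 * math.log(var_sum / (2.0 * sigma_d * sigma_m))
    return mean_term + width_term


def gaussian_pdf(x, mu, sigma):
    """Normal density N(mu, sigma^2) evaluated at x."""
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2.0 * math.pi))


def bhattacharyya_numeric(f_obs, f_calc_abs, sigma_d, sigma_m, n=20001):
    """-log of the overlap integral of sqrt(p * q), computed with the trapezoidal rule."""
    # Integrate over a window wide enough (10 sigma) that both tails are negligible.
    lo = min(f_obs - 10 * sigma_d, f_calc_abs - 10 * sigma_m)
    hi = max(f_obs + 10 * sigma_d, f_calc_abs + 10 * sigma_m)
    dx = (hi - lo) / (n - 1)
    total = 0.0
    for i in range(n):
        x = lo + i * dx
        weight = 0.5 if i in (0, n - 1) else 1.0  # trapezoid end-point weights
        total += weight * math.sqrt(gaussian_pdf(x, f_obs, sigma_d) * gaussian_pdf(x, f_calc_abs, sigma_m))
    return -math.log(total * dx)


closed = bhattacharyya_closed_form(120.0, 110.0, 4.0, 6.0)
numeric = bhattacharyya_numeric(120.0, 110.0, 4.0, 6.0)
print(f"closed form: {closed:.6f}  numeric: {numeric:.6f}")
```

The two terms have distinct roles: the first penalises the gap between F_obs and |F_calc| scaled by the combined variance, and the second is zero when sigma_d == sigma_m and grows as the two widths diverge, so a perfect match with equal uncertainties gives L_h = 0.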
Dispatches to torchref.base.targets.triton.xray_bhattacharyya.bhattacharyya_xray_loss_math_triton() on CUDA float32 inputs; falls back to the eager implementation otherwise.
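The per-reflection arithmetic of the eager path can be pictured with the minimal scalar sketch below. It is not torchref's implementation: the name `bhattacharyya_loss_terms` is hypothetical, it loops over plain Python sequences instead of tensors, and it assumes that masked-out reflections contribute zero. How the real function reduces the per-reflection terms (sum or mean) is not stated here, so the sketch returns them unreduced.

```python
import math
from typing import List, Sequence


def bhattacharyya_loss_terms(
    f_obs: Sequence[float],
    f_calc: Sequence[complex],
    sigma_d: Sequence[float],
    sigma_m: Sequence[float],
    mask: Sequence[bool],
) -> List[float]:
    """Per-reflection L_h; entries with mask False are set to 0.0 (assumed convention)."""
    terms: List[float] = []
    for fo, fc, sd, sm, keep in zip(f_obs, f_calc, sigma_d, sigma_m, mask):
        if not keep:
            terms.append(0.0)  # assumption: excluded reflections add nothing
            continue
        f_calc_abs = abs(fc)  # |F_calc|; works for complex or real inputs
        var_sum = sd * sd + sm * sm
        mean_term = (fo - f_calc_abs) ** 2 / (4.0 * var_sum)
        width_term = 0.5 * math.log(var_sum / (2.0 * sd * sm))
        terms.append(mean_term + width_term)
    return terms


terms = bhattacharyya_loss_terms(
    f_obs=[100.0, 50.0, 80.0],
    f_calc=[100.0, 3 + 4j, 70.0],
    sigma_d=[5.0, 2.0, 3.0],
    sigma_m=[5.0, 2.0, 4.0],
    mask=[True, True, False],
)
print(terms, sum(terms))
```

Because sigma_m is treated as a constant here, gradients in the real (tensor) version would flow only through F_calc; this scalar sketch carries no gradients at all and is meant only as a reference for checking values.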