"""Bhattacharyya overlap X-ray loss math.
Mirrors the loss computation in :class:`BhattacharyyaXrayTarget.forward` after
``sigma_m`` has been computed in a no-grad block (so ``sigma_m`` is treated
as a constant input here).
"""
import torch
from ._dispatch import use_triton
def _bhattacharyya_xray_loss_math_eager(
    F_obs: torch.Tensor,
    F_calc: torch.Tensor,
    sigma_d: torch.Tensor,
    sigma_m: torch.Tensor,
    mask: torch.Tensor,
) -> torch.Tensor:
    """Compute the Bhattacharyya overlap loss with plain PyTorch ops.

    This is the fallback path used when the Triton kernel is not selected.
    ``sigma_m`` is treated as a constant input: it is expected to have been
    computed without gradient tracking by the caller.

    Args:
        F_obs: Observed amplitudes.
        F_calc: Model structure factors. Only their magnitude (``abs``) is
            used, so real or complex tensors are accepted.
        sigma_d: Data standard deviations. Clamped to at least ``1e-6``.
        sigma_m: Model standard deviations. Clamped to at least ``1e-6``.
        mask: Multiplied element-wise into the per-element loss before
            summation. Zeros drop an element; other values act as weights.

    All tensors must broadcast together (presumably one entry per
    reflection; confirm with callers).

    Returns:
        A 0-dim tensor: the masked sum of the per-element losses.
    """
    floor = 1e-6
    # Keep both widths strictly positive so the division and the log stay finite.
    sd = sigma_d.clamp(min=floor)
    sm = sigma_m.clamp(min=floor)
    total_var = sd ** 2 + sm ** 2

    residual = F_obs - F_calc.abs()
    # Mean-separation term: (F_obs - |F_calc|)^2 / (4 (sigma_d^2 + sigma_m^2)).
    mean_term = residual ** 2 / (4.0 * total_var)
    # Width-mismatch term: 0.5 * log((sigma_d^2 + sigma_m^2) / (2 sigma_d sigma_m)).
    width_term = 0.5 * (total_var / (2.0 * sd * sm)).log()

    # NOTE: masking is by multiplication, so masked-out entries must still be
    # finite (NaN * 0 is NaN and would poison the sum).
    weighted = (mean_term + width_term) * mask
    return weighted.sum()
[docs]
def bhattacharyya_xray_loss_math(
    F_obs: torch.Tensor,
    F_calc: torch.Tensor,
    sigma_d: torch.Tensor,
    sigma_m: torch.Tensor,
    mask: torch.Tensor,
) -> torch.Tensor:
    r"""Bhattacharyya overlap loss between data and model Gaussians.

    Per element:

    .. math::

        L_h = \frac{(F_{obs} - |F_{calc}|)^2}{4(\sigma_d^2 + \sigma_m^2)}
            + \frac{1}{2}\log\frac{\sigma_d^2 + \sigma_m^2}{2\sigma_d\sigma_m}

    Args:
        F_obs: Observed amplitudes.
        F_calc: Model structure factors (only their magnitude enters the loss).
        sigma_d: Data standard deviations.
        sigma_m: Model standard deviations, treated as a constant input.
        mask: Per-element multiplier applied before the final sum.

    Returns:
        A scalar tensor holding the masked sum of ``L_h``.

    Dispatches to
    :func:`torchref.base.targets.triton.xray_bhattacharyya.bhattacharyya_xray_loss_math_triton`
    when ``use_triton`` accepts the inputs (intended for CUDA float32). Otherwise
    it falls back to the eager PyTorch implementation, which clamps
    ``sigma_d`` and ``sigma_m`` to at least ``1e-6``.
    """
    if not use_triton(F_calc, F_obs, sigma_d, sigma_m):
        return _bhattacharyya_xray_loss_math_eager(F_obs, F_calc, sigma_d, sigma_m, mask)
    # The Triton module is only imported when the fast path is actually taken.
    from .triton.xray_bhattacharyya import bhattacharyya_xray_loss_math_triton
    return bhattacharyya_xray_loss_math_triton(F_obs, F_calc, sigma_d, sigma_m, mask)