"""Maximum-likelihood X-ray loss math.
Mirrors lines 47-84 of
``torchref/refinement/targets/xray/maximum_likelihood.py`` verbatim. The
caller is responsible for everything ``XrayTarget.get_data`` does:
unpacking ``ReflectionData``, running the ``Scaler`` forward to produce
``|F_calc|`` from a complex ``f_calc``, and building the work/free mask.
"""
import numpy as np
import torch
from ._dispatch import use_triton
def _ml_xray_loss_math_eager(
    F_obs: torch.Tensor,
    F_calc: torch.Tensor,
    sigma: torch.Tensor,
    centric_flags: torch.Tensor,
    mask: torch.Tensor,
) -> torch.Tensor:
    """Pure-PyTorch maximum-likelihood X-ray loss.

    This is the path ``ml_xray_loss_math`` takes whenever the Triton kernel
    is not selected.

    With ``alpha = epsilon = 1``, ``Fc = |F_calc|`` and the variance floor
    ``S = max(epsilon * sigma**2, 1e-6)``, each reflection contributes

    * acentric:
      ``-log(2F/S) + F^2/S + (alpha*Fc)^2/S - log I0(2*alpha*F*Fc/S)``.
      ``log I0(x)`` is evaluated as ``log(i0e(x)) + x``, with ``x``
      clamped to at most ``1e6``.
    * centric:
      ``-0.5*log(2/(pi*S)) + F^2/(2S) + (alpha*Fc)^2/(2S) - alpha*F*Fc/S
      - log((1 + exp(-2*alpha*F*Fc/S)) / 2)``.
      The ``exp`` argument is clamped to ``[-80, 80]``.

    Each ``log`` gets a ``1e-12`` offset. Per-reflection values that are
    not finite are replaced by ``1e6`` before the masked sum.

    Parameters
    ----------
    F_obs : torch.Tensor
        Observed amplitudes, one per reflection.
    F_calc : torch.Tensor
        Calculated amplitudes. Only ``torch.abs(F_calc)`` is used.
    sigma : torch.Tensor
        Per-reflection sigma; its square is the variance term.
    centric_flags : torch.Tensor or None
        Boolean selector, True meaning centric. ``None`` means every
        reflection is acentric.
    mask : torch.Tensor
        Multiplied into the per-reflection loss before summing.

    Returns
    -------
    torch.Tensor
        0-d tensor equal to ``(loss * mask).sum()``.
    """
    # Model parameters are fixed at unity here.
    alpha = torch.ones_like(F_obs)
    epsilon = torch.ones_like(F_obs)
    beta = sigma ** 2
    if centric_flags is None:
        centric_flags = torch.zeros_like(F_obs, dtype=torch.bool)

    amp_calc = torch.abs(F_calc)
    # Floor the variance so every division below is by a positive number.
    var = torch.clamp(epsilon * beta, min=1e-6)

    # Acentric branch.
    # Because i0e(x) = exp(-x) * I0(x) for x >= 0,
    # log I0(x) = log(i0e(x)) + x, which avoids overflowing I0.
    bessel_x = torch.clamp(2 * alpha * F_obs * amp_calc / var, max=1e6)
    acentric = (
        -torch.log(2 * F_obs / var + 1e-12)
        + (F_obs ** 2) / var
        + (alpha * amp_calc) ** 2 / var
        - (torch.log(torch.special.i0e(bessel_x) + 1e-12) + bessel_x)
    )

    # Centric branch.
    # log cosh(y) is split as y + log((1 + exp(-2y)) / 2), so the
    # exponential only ever sees a non-positive, clamped argument.
    exp_arg = torch.clamp(
        -2 * alpha * F_obs * amp_calc / var, min=-80.0, max=80.0
    )
    centric = (
        -0.5 * torch.log(2 / (np.pi * var) + 1e-12)
        + (F_obs ** 2) / (2 * var)
        + (alpha * amp_calc) ** 2 / (2 * var)
        - (alpha * F_obs * amp_calc) / var
        - torch.log((1 + torch.exp(exp_arg)) / 2 + 1e-12)
    )

    per_reflection = torch.where(centric_flags, centric, acentric)
    # Replace inf/NaN entries with a large finite penalty.
    is_ok = torch.isfinite(per_reflection)
    per_reflection = torch.where(
        is_ok, per_reflection, torch.full_like(per_reflection, 1e6)
    )
    return (per_reflection * mask).sum()
[docs]
def ml_xray_loss_math(
    F_obs: torch.Tensor,
    F_calc: torch.Tensor,
    sigma: torch.Tensor,
    centric_flags: torch.Tensor,
    mask: torch.Tensor,
) -> torch.Tensor:
    """Compute the maximum-likelihood X-ray loss on already-scaled amplitudes.

    Mirrors the math of ``MaximumLikelihoodXrayTarget.forward``, lines
    37-84.

    NOTE(review): the module docstring cites lines 47-84 instead; confirm
    which range is current.

    Backend selection:

    * When ``use_triton(F_calc, F_obs, sigma)`` returns True, this calls
      :func:`torchref.base.targets.triton.xray_ml.ml_xray_loss_math_triton`.
      According to the original documentation, that is the CUDA float32
      case.
    * Otherwise it uses the eager PyTorch implementation.

    Parameters
    ----------
    F_obs : torch.Tensor
        (N,) observed amplitudes (zeros outside ``mask``).
    F_calc : torch.Tensor
        (N,) scaled calculated amplitudes. These are already real-valued
        and are zero outside ``mask``.
    sigma : torch.Tensor
        (N,) per-reflection sigma (ones outside ``mask``).
    centric_flags : torch.Tensor or None
        (N,) bool, True for centric reflections. ``None`` means all
        reflections are acentric.
    mask : torch.Tensor
        (N,) bool work-set mask applied to the final sum.

    Returns
    -------
    torch.Tensor
        The scalar loss produced by the selected backend.
    """
    call_args = (F_obs, F_calc, sigma, centric_flags, mask)
    if not use_triton(F_calc, F_obs, sigma):
        return _ml_xray_loss_math_eager(*call_args)

    # Import lazily so the Triton module is only loaded on this path.
    from .triton.xray_ml import ml_xray_loss_math_triton

    return ml_xray_loss_math_triton(*call_args)