# Source code for torchref.base.targets.xray_gaussian

"""Gaussian X-ray loss math.

The caller is responsible for producing the post-``XrayTarget.get_data``
tensors (F_obs, F_calc, sigma, mask) that are passed to these functions.

The public entry point ``gaussian_xray_loss_math`` dispatches to the
Triton kernel on CUDA float32 and to the eager implementation
otherwise, matching the pattern used by ``bond_math``, ``angle_math``,
etc.
"""

import numpy as np
import torch

from ._dispatch import use_triton


def _gaussian_xray_loss_math_eager(
    F_obs: torch.Tensor,
    F_calc: torch.Tensor,
    sigma: torch.Tensor,
    mask: torch.Tensor,
) -> torch.Tensor:
    F_calc_amp = torch.abs(F_calc)
    diff = F_obs - F_calc_amp

    eps = torch.median(sigma) * 1e-1
    sigma_safe = torch.clamp(sigma, min=eps)

    log_2pi = torch.log(
        torch.tensor(2.0 * np.pi, device=sigma.device, dtype=sigma.dtype)
    )
    nll = 0.5 * (diff ** 2) / (sigma_safe ** 2) + torch.log(sigma_safe) + 0.5 * log_2pi
    return (nll * mask).sum()


def gaussian_xray_loss_math(
    F_obs: torch.Tensor,
    F_calc: torch.Tensor,
    sigma: torch.Tensor,
    mask: torch.Tensor,
) -> torch.Tensor:
    """Gaussian NLL on already-scaled amplitudes.

    Routes to
    :func:`torchref.base.targets.triton.xray_gaussian.gaussian_xray_loss_math_triton`
    when ``use_triton`` accepts the inputs (CUDA float32). Every other case
    runs the eager PyTorch implementation.

    Args:
        F_obs: Observed amplitudes.
        F_calc: Calculated values compared against ``F_obs``.
        sigma: Uncertainties associated with ``F_obs``.
        mask: Selection / weighting tensor, forwarded unchanged to the
            chosen backend.

    Returns:
        The loss tensor produced by the selected backend.
    """
    # The dispatch predicate only looks at F_calc, F_obs and sigma. The mask
    # is never inspected here.
    if not use_triton(F_calc, F_obs, sigma):
        return _gaussian_xray_loss_math_eager(F_obs, F_calc, sigma, mask)

    # Deferred import: the Triton backend module is only loaded on this path.
    from .triton.xray_gaussian import gaussian_xray_loss_math_triton

    return gaussian_xray_loss_math_triton(F_obs, F_calc, sigma, mask)