# Source code for torchref.refinement.targets.adp.scaler_u

from typing import TYPE_CHECKING, Any, Dict

import torch

from torchref.utils.stats import (
    VERBOSITY_DEBUG,
    VERBOSITY_DETAILED,
    VERBOSITY_STANDARD,
    stat,
)

from ..base import Target

if TYPE_CHECKING:
    from torchref.scaling.scaler_base import ScalerBase


class ScalerURegularizationTarget(Target):
    """
    Pin the isotropic component of the scaler's anisotropic U to zero.

    ``F_scaled = scale · exp(-2π²·sᵀUs) · F_calc`` couples ``U`` into the
    same Debye-Waller slot as the atomic B-factors. The *anisotropic*
    deviatoric part of U encodes real directional attenuation in the data
    and should stay free. The *isotropic* part — the trace
    ``U11 + U22 + U33`` — is a straight trade with a uniform atomic B shift
    and must be anchored so the atoms absorb the B-factor roll-off, not the
    scaler.

    Penalty ``(tr U)² · N_ref / 6`` — squared trace, scaled by the order of
    magnitude of the xray gradient on U. The ``/ 6`` matches the
    per-component gradient scaling the user originally requested.

    Parameters
    ----------
    scaler : ScalerBase
        Object exposing the anisotropic tensor as ``scaler.U``. The first
        three entries are treated as the diagonal (``U11, U22, U33``) and
        the remaining entries as the off-diagonal terms.
    n_reflections : int
        Number of reflections ``N_ref``; sets the penalty weight
        ``N_ref / 6``.
    verbose : int, optional
        Verbosity level forwarded to the ``Target`` base class.
    """

    name: str = "adp/scaler_U"

    def __init__(
        self,
        scaler: "ScalerBase",
        n_reflections: int,
        verbose: int = 0,
    ):
        """Store the scaler reference and precompute the penalty weight."""
        super().__init__(verbose=verbose)
        # Bypass any ``__setattr__`` override in the class hierarchy.
        # NOTE(review): presumably this keeps the scaler from being
        # registered as a child of this target — confirm against ``Target``.
        object.__setattr__(self, "_scaler", scaler)
        self._scale = float(n_reflections) / 6.0

    def forward(self) -> torch.Tensor:
        """Return the penalty ``(N_ref / 6) · (U[0] + U[1] + U[2])²``."""
        U = self._scaler.U
        trace = U[0] + U[1] + U[2]
        return self._scale * trace ** 2

    def stats(self) -> Dict[str, Any]:
        """
        Report diagnostics for the current scaler U.

        Computed on a detached copy of ``scaler.U``, so no gradient is
        tracked here.

        Returns
        -------
        Dict[str, Any]
            ``stat(...)`` entries keyed by ``loss``, ``tr_U``, ``U_iso``,
            ``U_aniso_norm``, ``U_off_max`` and ``U_vec``, each tagged with
            the verbosity level at which it should be shown.
        """
        U = self._scaler.U.detach()
        trace = (U[0] + U[1] + U[2]).item()
        u_diag = U[:3]
        u_off = U[3:]
        # Spread of the diagonal around its own mean, i.e. the part of the
        # diagonal that the trace penalty leaves untouched.
        aniso_norm = torch.norm(u_diag - u_diag.mean()).item()
        return {
            "loss": stat(self._scale * trace ** 2, VERBOSITY_STANDARD),
            "tr_U": stat(trace, VERBOSITY_STANDARD),
            "U_iso": stat(trace / 3.0, VERBOSITY_DETAILED),
            "U_aniso_norm": stat(aniso_norm, VERBOSITY_DETAILED),
            "U_off_max": stat(u_off.abs().max().item(), VERBOSITY_DETAILED),
            "U_vec": stat(U.tolist(), VERBOSITY_DEBUG),
        }