from typing import TYPE_CHECKING, Any, Dict

import torch

from torchref.utils.stats import (
    VERBOSITY_DEBUG,
    VERBOSITY_DETAILED,
    VERBOSITY_STANDARD,
    stat,
)

from ..base import Target

if TYPE_CHECKING:
    from torchref.scaling.scaler_base import ScalerBase
class ScalerLogScaleTrendTarget(Target):
    """
    Pin out the Debye-Waller-like trend in the per-bin ``log_scale``.

    Per-bin scales exist to absorb **localized** data-quality issues —
    an outlier bin, a slightly mis-merged shell. They are *not* supposed
    to absorb the global resolution-dependent attenuation envelope; that
    is what atomic B-factors are for. A ``log_scale[i] ≈ a + b·s²[i]``
    structure with nonzero ``b`` is exactly a B-factor masquerading as a
    per-bin scale: it shifts ``B_eff = B_atom − 4·b``, letting atoms drift
    broader or sharper while the overall ``F_scaled`` amplitudes stay
    unchanged.

    The target fits the least-squares slope of ``log_scale`` against the
    bin-mean ``|s|²`` and penalizes ``slope²``. The intercept ``a`` (the
    overall scale) is free. Residuals off the fit line are also free —
    those encode the outlier absorption the scales are *supposed* to do.

    Penalty ``slope² · N_ref / nbins`` — the ``/ nbins`` matches how xray
    gradient on a single ``log_scale[i]`` scales; the ``N_ref`` factor
    brings the total into xray's order of magnitude. Here ``nbins`` is the
    number of bins that actually contain at least one reflection.
    """

    name: str = "adp/scaler_log_scale"

    def __init__(
        self,
        scaler: "ScalerBase",
        n_reflections: int,
        verbose: int = 0,
    ):
        """
        Precompute the centered per-bin ``|s|²`` abscissa used by the fit.

        Parameters
        ----------
        scaler
            Object providing ``s``, ``bins``, ``nbins`` and ``log_scale``.
            Only ``log_scale`` is read again after construction; the
            binning is frozen at construction time.
            NOTE(review): assumes ``s`` is ``(N_ref, 3)`` and ``bins`` is a
            1-D per-reflection bin index of length ``N_ref`` — confirm.
        n_reflections
            Reflection count used as the numerator of the loss weight.
        verbose
            Forwarded unchanged to the base-class constructor.
        """
        super().__init__(verbose=verbose)
        # Bypass this class's own ``__setattr__``; presumably so the scaler
        # is held by reference without being registered as a child module.
        object.__setattr__(self, "_scaler", scaler)

        with torch.no_grad():
            # Squared scattering-vector length per reflection.
            s_sq = (scaler.s ** 2).sum(dim=1)
            bins = scaler.bins.to(torch.int64)
            nbins = int(scaler.nbins)

            # Per-bin sum and count of |s|² via scatter-add.
            bin_s_sq = torch.zeros(nbins, device=scaler.s.device, dtype=s_sq.dtype)
            bin_counts = torch.zeros_like(bin_s_sq)
            bin_s_sq.scatter_add_(0, bins, s_sq)
            bin_counts.scatter_add_(0, bins, torch.ones_like(s_sq))

            # Clamp avoids 0/0 for empty bins; those bins are masked out below.
            bin_mean_s_sq = bin_s_sq / bin_counts.clamp(min=1.0)
            valid = bin_counts > 0
            s_centers = bin_mean_s_sq[valid]

            # Centered abscissa and its sum of squares (least-squares
            # denominator). The floor keeps the slope finite (and zero) when
            # all populated bins share the same |s|² or only one bin exists.
            x = s_centers - s_centers.mean()
            x_var = (x * x).sum().clamp(min=1e-12)

        self.register_buffer("_valid_mask", valid)
        self.register_buffer("_x_centered", x)
        self.register_buffer("_x_var", x_var)
        self.register_buffer("_s_centers", s_centers)

        # Number of populated bins (not ``scaler.nbins``) sets the weight.
        self._nbins = int(valid.sum().item())
        self._scale = float(n_reflections) / max(self._nbins, 1)

    def _slope(self) -> torch.Tensor:
        """
        Return the least-squares slope of ``log_scale`` versus bin-mean ``|s|²``.

        Only populated bins take part. Both variables are mean-centered, so
        the slope is ``Σ x·y / Σ x²`` and is independent of the intercept.
        """
        log_scale = self._scaler.log_scale[self._valid_mask]
        y = log_scale - log_scale.mean()
        return (self._x_centered * y).sum() / self._x_var

    def forward(self) -> torch.Tensor:
        """Return the penalty ``(n_reflections / n_populated_bins) · slope²``."""
        slope = self._slope()
        return self._scale * slope ** 2

    def stats(self) -> Dict[str, Any]:
        """
        Return diagnostic entries for the current ``log_scale`` trend.

        Returns
        -------
        dict
            Entries built with ``stat(value, verbosity)`` under the keys
            ``loss``, ``slope``, ``B_equiv`` (``-4 · slope``),
            ``log_scale_mean`` and ``log_scale_range``.
        """
        with torch.no_grad():
            slope = self._slope().item()
            # NOTE(review): mean/range below cover *all* bins, including
            # unpopulated ones, whereas the slope uses populated bins only —
            # confirm this is intended.
            log_scale = self._scaler.log_scale.detach()

        # A slope b in log_scale acts like an isotropic B-factor of -4·b.
        b_equiv = -4.0 * slope
        return {
            "loss": stat(self._scale * slope ** 2, VERBOSITY_STANDARD),
            "slope": stat(slope, VERBOSITY_STANDARD),
            "B_equiv": stat(b_equiv, VERBOSITY_STANDARD),
            "log_scale_mean": stat(log_scale.mean().item(), VERBOSITY_DETAILED),
            "log_scale_range": stat(
                (log_scale.max() - log_scale.min()).item(), VERBOSITY_DETAILED
            ),
        }