"""
Base scaler class for crystallographic scaling without model dependency.
This module provides `ScalerBase`, a class that implements all scaling functionality
but does NOT maintain a reference to a Model object. All methods that require
calculated structure factors (F_calc) take them as input arguments.
This allows the scaler to be used independently of any specific model implementation,
making it suitable for use cases like:
- Molecular replacement where F_calc comes from external sources
- Testing and validation with precomputed structure factors
- Custom model implementations
"""
from typing import Optional, Tuple, TYPE_CHECKING
import torch
import torch.nn as nn
from torchref.base.math_torch import U_to_matrix
from torchref.base.metrics import bin_wise_rfactors, get_rfactors, nll_xray, nll_xray_lognormal
from torchref.base.reciprocal import get_scattering_vectors
from torchref.config import get_complex_dtype, get_default_device, get_float_dtype
from torchref.utils.autograd_ops import gather_with_index_add
from torchref.utils.debug_utils import DebugMixin
from torchref.utils.utils import ModuleReference
from torchref.utils.device_mixin import DeviceMixin
from torchref.utils.device_resolution import resolve_device
if TYPE_CHECKING:
from torchref.io import ReflectionData
from torchref.scaling.solvent import SolventModel
[docs]
class ScalerBase(DeviceMixin, DebugMixin, nn.Module):
"""
Base scaler class for crystallographic scaling without model dependency.
All methods that require calculated structure factors (F_calc) take them
as input arguments. This allows the scaler to be used independently of
any specific model implementation.
Supports two initialization patterns:
1. Empty initialization (for state_dict loading)::
scaler = ScalerBase() # Creates empty shell
scaler.load_state_dict(torch.load('scaler.pt'))
2. Full initialization with data::
scaler = ScalerBase(data=reflection_data, nbins=20)
scaler.initialize(fcalc)
Parameters
----------
data : ReflectionData, optional
ReflectionData object with observed data.
nbins : int, default 20
Number of resolution bins.
verbose : int, default 1
Verbosity level.
device : torch.device, optional
Computation device. If ``None``, it is derived from ``data`` (if given) or the configured default device.
Attributes
----------
device : torch.device
Current computation device.
nbins : int
Number of resolution bins.
"""
[docs]
def __init__(
    self,
    data: Optional["ReflectionData"] = None,
    nbins: int = 20,
    verbose: int = 1,
    device: Optional[torch.device] = None,
):
    """
    Build the scaler, either as an empty shell or bound to reflection data.

    With ``data=None`` only the configuration is stored and every buffer is
    registered as ``None`` so that the instance can be populated later
    (e.g. through ``load_state_dict()`` / ``set_data()``). With ``data``
    given, scattering vectors, resolution bins and the effective-sigma
    buffers are derived from it immediately.

    Parameters
    ----------
    data : ReflectionData, optional
        ReflectionData object with observed data.
    nbins : int, default 20
        Requested number of resolution bins. The value actually stored is
        the one returned by ``data.get_bins``.
    verbose : int, default 1
        Verbosity level.
    device : torch.device, optional
        Computation device. If ``None``, it is resolved from ``data`` (if
        given) or the configured default via ``resolve_device``.
    """
    super(ScalerBase, self).__init__()
    self.device = resolve_device(data, device=device)
    self.verbose = verbose
    self.nbins = nbins

    if data is None:
        # Empty shell: keep the configuration and declare all buffers as
        # placeholders (same registration order as before).
        self._data = None
        self.cell = None
        for buffer_name in (
            "s",
            "bins",
            "_s_half_sq",
            "sigma_eff",
            "sigma_eff_per_bin",
        ):
            self.register_buffer(buffer_name, None)
        self._f_sol_raw = None
        return

    # Data-backed initialization.
    self.to(self.device)
    self._data = ModuleReference(data)
    self.cell = data.cell

    scattering = get_scattering_vectors(data.hkl, self.cell)
    self.register_buffer("s", scattering)
    # (|s| / 2)^2 is cached once so the B-factor damping in forward() does
    # not have to recompute it on every call.
    half_norm = torch.norm(scattering, dim=1) / 2.0
    self.register_buffer("_s_half_sq", half_norm**2)
    self._f_sol_raw = None

    bin_index, self.nbins = self._data.get_bins(self.nbins)
    self.register_buffer("bins", bin_index)

    # Effective sigmas start out as the raw sigmas (per reflection) and as
    # zeros (per bin); estimate_sigma_eff() overwrites both after scaling.
    _hkl, _fobs, raw_sigma, _rfree = self._data(mask=False)
    self.register_buffer("sigma_eff", raw_sigma.clone().to(self.device))
    per_bin_sigma = torch.zeros(
        self.nbins, device=self.device, dtype=raw_sigma.dtype
    )
    self.register_buffer("sigma_eff_per_bin", per_bin_sigma)

    if self.verbose > 0:
        print(f"Initialized ScalerBase with {self.nbins} bins.")
[docs]
def set_data(self, data: "ReflectionData"):
    """
    Attach a data object to a scaler that was created without one.

    The reference is always replaced. Geometry-derived buffers (``s``,
    ``_s_half_sq``, ``bins``) are only computed when they are still
    ``None``, so values that are already present are left untouched.

    Parameters
    ----------
    data : ReflectionData
        ReflectionData object with observed data.

    Notes
    -----
    NOTE(review): ``sigma_eff`` / ``sigma_eff_per_bin`` are not touched
    here; confirm they are populated elsewhere before
    ``estimate_sigma_eff`` is used on a scaler built this way.
    """
    self._data = ModuleReference(data)

    cell = data.cell
    if cell is not None:
        self.cell = cell

    # Scattering vectors need both the Miller indices and the cell.
    if self.s is None and data.hkl is not None and cell is not None:
        vectors = get_scattering_vectors(data.hkl, cell)
        self.register_buffer("s", vectors)
        half_norm = torch.norm(vectors, dim=1) / 2.0
        self.register_buffer("_s_half_sq", half_norm**2)
        # Any cached raw solvent structure factors are dropped.
        self._f_sol_raw = None

    if self.bins is None and data.hkl is not None:
        bin_index, self.nbins = self._data.get_bins(self.nbins)
        self.register_buffer("bins", bin_index)
[docs]
def initialize(self, fcalc: torch.Tensor):
    """
    Create the initial scaling parameters from a set of F_calc values.

    Runs ``calc_initial_scale(fcalc)`` followed by
    ``setup_anisotropy_correction()``.

    Parameters
    ----------
    fcalc : torch.Tensor
        Calculated structure factors (complex).
    """
    # Per-bin log scale first, then the anisotropic U parameters.
    self.calc_initial_scale(fcalc)
    self.setup_anisotropy_correction()
@property
def hkl(self):
    """Return the ``hkl`` attribute of the attached data object."""
    data_ref = self._data
    return data_ref.hkl
[docs]
def calc_initial_scale(self, fcalc: torch.Tensor):
    """
    Estimate a per-bin log scale from the observed / calculated amplitude ratio.

    For every resolution bin the mean of ``log(F_obs) - log(|F_calc|)`` is
    taken over the reflections that are valid (``self._data.masks()``),
    flagged by ``rfree`` and, when the data object exposes intensities
    ``I``, have ``I > 0``. Both amplitudes are clamped to at least ``1e-3``
    before the logarithm. The result is stored as ``self.log_scale``.

    Parameters
    ----------
    fcalc : torch.Tensor
        Calculated structure factors (complex).

    Returns
    -------
    torch.nn.Parameter
        The log scale parameter for each resolution bin.

    Raises
    ------
    AssertionError
        If ``fcalc`` or the resulting log ratios contain non-finite values.
    """
    hkl, fobs, sigma, rfree = self._data(mask=False)
    if self.verbose > 0:
        print(f"Calculating initial scale factors using {self.nbins} bins.")
    assert torch.all(
        torch.isfinite(fcalc)
    ), "Non-finite values found in fcalc during initial scale calculation."

    value_dtype = fobs.dtype
    calc_amplitude = torch.abs(fcalc).to(value_dtype)

    # Reflections with non-positive intensity are left out: their amplitudes
    # are biased by the French-Wilson conversion.
    intensities = getattr(self._data, "I", None)
    if intensities is None:
        positive_mask = torch.ones_like(fobs, dtype=torch.bool)
    else:
        positive_mask = intensities > 0
        if self.verbose > 1:
            n_excluded = (~positive_mask).sum().item()
            print(
                f"Excluding {n_excluded} negative intensity reflections from scale calculation"
            )

    selection = (self._data.masks() & rfree & positive_mask).to(torch.bool)
    bin_index = self.bins[selection].to(torch.int64)
    obs_selected = fobs.clamp(min=1e-3)[selection]
    calc_selected = calc_amplitude.clamp(min=1e-3)[selection]

    log_ratios = torch.log(obs_selected) - torch.log(calc_selected)
    assert torch.all(
        torch.isfinite(log_ratios)
    ), f"Non-finite log ratios encountered in initial scale calculation {torch.sum(~torch.isfinite(log_ratios)).item()}"

    # scatter_add needs index, source and destination on one device.
    log_ratios = log_ratios.to(self.device)
    bin_index = bin_index.to(self.device)

    def _empty_bins():
        return torch.zeros(self.nbins, device=self.device, dtype=value_dtype)

    ratio_sum = _empty_bins().scatter_add(0, bin_index, log_ratios)
    # One unit per selected reflection gives the bin population.
    population = _empty_bins().scatter_add(0, bin_index, torch.ones_like(log_ratios))

    # The small offset keeps empty bins at log-scale 0 instead of NaN.
    initial_log_scale = ratio_sum / (population + 1e-6)
    if self.verbose > 1:
        print(
            "Initial scale factors per bin:",
            initial_log_scale.detach().cpu().numpy(),
        )

    self.log_scale = nn.Parameter(initial_log_scale.detach().to(self.device))
    return self.log_scale
[docs]
def setup_anisotropy_correction(self):
    """
    Create the learnable anisotropy parameter ``self.U``.

    ``U`` holds six values drawn from a normal distribution with mean 0 and
    standard deviation 0.001, in the configured float dtype on
    ``self.device``.
    """
    initial_u = torch.normal(
        0, 0.001, (6,), dtype=get_float_dtype(), device=self.device
    )
    self.U = nn.Parameter(initial_u)
[docs]
def anisotropy_correction(self):
    """
    Evaluate the anisotropic scale factor ``exp(-2 pi^2 s^T U s)``.

    The exponent is clamped to ``[-10, 10]`` before exponentiation.

    Returns
    -------
    torch.Tensor
        Anisotropic correction factors for each reflection.
    """
    u_matrix = U_to_matrix(self.U)
    # Bilinear form s^T U s written as matmul + row-wise dot product rather
    # than einsum (the original author measured this as faster on CPU).
    projected = torch.matmul(self.s, u_matrix)  # (N, 3)
    quadratic = (projected * self.s).sum(dim=1)
    exponent = -2 * torch.pi**2 * quadratic
    return torch.exp(torch.clamp(exponent, min=-10.0, max=10.0))
[docs]
def fit_anisotropy(self, fcalc: torch.Tensor, nsteps: int = 100):
    """
    Optimize ``U`` and ``log_scale`` with Adam against the X-ray NLL.

    If ``self.U`` does not exist yet it is created from a normal
    distribution with standard deviation 0.01. The loss is ``nll_xray``
    restricted to the reflections selected by ``rfree``.

    Parameters
    ----------
    fcalc : torch.Tensor
        Calculated structure factors (complex). Only the amplitudes are used.
    nsteps : int, default 100
        Number of optimization steps.
    """
    float_dtype = get_float_dtype()
    if not hasattr(self, "U"):
        self.U = nn.Parameter(
            torch.normal(0, 0.01, (6,), dtype=float_dtype, device=self.device)
        )

    hkl, fobs, sigma, rfree = self._data()
    fobs = fobs.to(float_dtype).detach()
    fcalc = torch.abs(fcalc).to(float_dtype).detach()

    # The observed side of the loss does not change between iterations.
    fobs_selected = fobs[rfree]
    sigma_selected = sigma[rfree]
    final_iteration = nsteps - 1

    optimizer = torch.optim.Adam([self.U, self.log_scale], lr=1e-1)
    for i in range(nsteps):
        optimizer.zero_grad()
        scaled_fcalc = self.forward(fcalc)
        loss = nll_xray(fobs_selected, scaled_fcalc[rfree], sigma_selected)
        loss.backward()
        optimizer.step()
        # Report every tenth iteration and the last one.
        report = i % 10 == 0 or i == final_iteration
        if self.verbose > 0 and report:
            print(f"Anisotropy fit iteration {i+1}/{nsteps}, Loss: {loss.item():.4f}")
[docs]
def set_solvent_model(self, solvent_model: "SolventModel") -> None:
    """
    Attach an externally configured solvent model to the scaler.

    Parameters
    ----------
    solvent_model : SolventModel
        Pre-configured solvent model that can compute solvent structure factors.
    """
    # Raw solvent structure factors cached for a previous model are stale.
    self._f_sol_raw = None
    self.solvent = solvent_model
[docs]
def setup_binwise_solvent_scale(self):
    """
    Create ``self.log_kmask``, one learnable solvent scale per resolution bin.

    The starting values follow ``k_mask = 0.35 * exp(-46 * s^2)`` with
    ``s = 1 / (2 * d_mean)`` per bin. Values below 0.05 are set to zero
    before the logarithm (which is taken with a small positive offset).
    """
    mean_resolution = self._data.mean_res_per_bin()

    # sin(theta)/lambda per bin; the offset guards against division by zero.
    s_per_bin = 1.0 / (2.0 * mean_resolution + 1e-6)
    kmask_start = 0.35 * torch.exp(-46.0 * s_per_bin**2)

    # Bins with a negligible solvent contribution start at zero.
    kmask_start = kmask_start.masked_fill(kmask_start < 0.05, 0.0)

    log_start = torch.log(kmask_start.clamp(min=1e-6) + 1e-6)
    self.log_kmask = nn.Parameter(log_start.to(self.device))
[docs]
def fit_all_scales(self, fcalc: torch.Tensor):
    """
    Fit all scale parameters using provided F_calc.

    Runs three Adam stages (learning rates 1e-1, 5e-2, 1e-2) of 20 steps
    each over ``self.parameters()``, minimizing ``nll_xray`` on the
    reflections selected by ``rfree``.

    Parameters
    ----------
    fcalc : torch.Tensor
        Calculated structure factors (complex).

    Raises
    ------
    ValueError
        If the NLL loss becomes NaN during fitting.
    """
    _, fobs, sigma, rfree = self._data()
    fobs = fobs.to(get_float_dtype()).detach()
    fcalc = fcalc.detach()

    # The observed side of the loss is constant across all iterations.
    fobs_selected = fobs[rfree]
    sigma_selected = sigma[rfree]

    for lr in [1e-1, 5e-2, 1e-2]:
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        for _ in range(20):
            optimizer.zero_grad()
            fcalc_selected = self.forward(fcalc)[rfree]
            nll_loss = nll_xray(fobs_selected, fcalc_selected, sigma_selected)
            if torch.isnan(nll_loss):
                raise ValueError(
                    "NaN encountered in NLL loss during scale fitting."
                )
            loss = nll_loss
            loss.backward()
            optimizer.step()
            if self.verbose > 1:
                # The log-normal NLL is a diagnostic only: evaluate it just
                # when it is printed, and without autograd tracking.
                with torch.no_grad():
                    nll_log_loss_xray = nll_xray_lognormal(
                        fobs_selected, fcalc_selected, sigma_selected
                    )
                print(
                    f"Solvent fit after step, Loss: {loss.item():.4f}, NLL: {nll_loss.item():.4f}, LogLoss: {nll_log_loss_xray.item():.4f}"
                )
[docs]
def fit_simple(self, fobs: torch.Tensor, fcalc: torch.Tensor) -> None:
    """
    Set every bin to one analytically determined global scale factor.

    The least-squares scale is::

        k = sum(|F_obs||F_calc|) / sum(|F_calc|²)

    and ``log(k)`` is written into all entries of ``self.log_scale``
    (which is created with ``nbins`` zeros if it does not exist yet).

    Parameters
    ----------
    fobs : torch.Tensor
        Observed structure factor amplitudes.
    fcalc : torch.Tensor
        Calculated structure factors (complex).
    """
    amp_calc = torch.abs(fcalc)
    amp_obs = torch.abs(fobs)

    cross_term = torch.sum(amp_obs * amp_calc)
    # Clamp the normalisation so an all-zero F_calc cannot divide by zero.
    calc_power = torch.sum(amp_calc**2).clamp(min=1e-10)
    scale = cross_term / calc_power

    if getattr(self, "log_scale", None) is None:
        fresh = torch.zeros(self.nbins, dtype=fobs.dtype, device=self.device)
        self.log_scale = nn.Parameter(fresh)

    # In-place overwrite of the parameter must not be tracked by autograd.
    with torch.no_grad():
        log_value = torch.log(scale.clamp(min=1e-6))
        self.log_scale.fill_(log_value)
[docs]
def get_scale(self) -> float:
    """
    Return the overall scale factor as a plain float.

    The value is ``exp`` of the mean of ``log_scale`` (clamped to
    ``[-10, 10]``). When no ``log_scale`` is available, 1.0 is returned.

    Returns
    -------
    float
        Current scale factor (not log).
    """
    log_scale = getattr(self, "log_scale", None)
    if log_scale is None:
        return 1.0
    mean_log = log_scale.mean().clamp(-10, 10)
    return torch.exp(mean_log).item()
[docs]
def rfactor(self, fcalc: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute R-work and R-free for the scaled F_calc.

    Parameters
    ----------
    fcalc : torch.Tensor
        Calculated structure factors (complex).

    Returns
    -------
    tuple
        R-work and R-free values, as returned by ``get_rfactors``.
    """
    _, observed, _, rfree = self._data()
    scaled = self.forward(fcalc)

    # Observations that expose get_data()/get_mask() carry their own
    # validity mask; apply it to all three arrays consistently.
    if hasattr(observed, "get_data"):
        keep = observed.get_mask()
        observed = observed.get_data()[keep]
        scaled = scaled[keep]
        rfree = rfree[keep]

    obs_amp = torch.abs(observed)
    calc_amp = torch.abs(scaled)
    return get_rfactors(obs_amp, calc_amp, rfree)
[docs]
def bin_wise_rfactor(self, fcalc: torch.Tensor):
    """
    Compute R-work and R-free separately for every resolution bin.

    Parameters
    ----------
    fcalc : torch.Tensor
        Calculated structure factors (complex).

    Returns
    -------
    mean_res_per_bin : torch.Tensor
        Mean resolution per bin.
    rwork_per_bin : torch.Tensor
        R-work per bin.
    rfree_per_bin : torch.Tensor
        R-free per bin.
    """
    _, observed, _, rfree = self._data()
    scaled = self.forward(fcalc)

    # Observations that expose get_data()/get_mask() carry their own
    # validity mask; apply it to all three arrays consistently.
    if hasattr(observed, "get_data"):
        keep = observed.get_mask()
        observed = observed.get_data()[keep]
        scaled = scaled[keep]
        rfree = rfree[keep]

    mean_res_per_bin = self._data.mean_res_per_bin()
    # NOTE(review): the bin indices are selected with self._data.masks()
    # only; when the branch above filters by `keep`, confirm that both
    # selections have the same length.
    per_bin = bin_wise_rfactors(
        torch.abs(observed),
        torch.abs(scaled),
        rfree,
        self.bins[self._data.masks()],
    )
    return (mean_res_per_bin, *per_bin)
[docs]
def setup_bin_wise_bfactor(self):
    """
    Create ``self.bin_wise_bfactor``: one learnable B-factor per bin, all zero.
    """
    zero_b = torch.zeros(self.nbins, dtype=get_float_dtype(), device=self.device)
    self.bin_wise_bfactor = nn.Parameter(zero_b)
[docs]
def bin_wise_bfactor_correction(self):
    """
    Evaluate ``exp(-B_bin * |s|^2 / 4)`` for every reflection.

    The exponent is clamped to ``[-10, 10]`` before exponentiation.

    Returns
    -------
    torch.Tensor
        B-factor correction factors for each reflection.
    """
    # The per-bin parameter is expanded to per-reflection values through
    # gather_with_index_add rather than plain ``[bins]`` indexing; per the
    # original author this avoids sorting all reflection indices in the
    # backward pass.
    per_reflection_b = gather_with_index_add(self.bin_wise_bfactor, self.bins)
    s_norm_sq = torch.norm(self.s, dim=1) ** 2
    exponent = -per_reflection_b * s_norm_sq / 4
    return torch.exp(torch.clamp(exponent, min=-10.0, max=10.0))
[docs]
def get_binwise_mean_intensity(self, fcalc: torch.Tensor):
    """
    Get bin-wise mean intensities for observed and calculated structure factors.

    Intensities are the squared amplitudes of ``F_obs`` and of the scaled
    ``F_calc`` (``self(fcalc)``). Only reflections selected by ``rfree``
    contribute. All accumulators use the dtype of the observed intensities,
    so data stored in a non-default float dtype (e.g. float64) no longer
    trips ``scatter_add``'s matching-dtype requirement.

    Parameters
    ----------
    fcalc : torch.Tensor
        Calculated structure factors (complex).

    Returns
    -------
    tuple
        Mean observed intensity, mean calculated intensity, and mean resolution per bin.
    """
    _, fobs, _, rfree = self._data()
    scaled_amplitude = torch.abs(self(fcalc))

    obs_intensity = torch.abs(fobs) ** 2
    acc_dtype = obs_intensity.dtype
    calc_intensity = (torch.abs(scaled_amplitude) ** 2).to(acc_dtype)

    # Bin index of every contributing reflection, computed once.
    mask = self._data.get_mask()
    selected_bins = self.bins.to(torch.int64)[mask][rfree]
    obs_selected = obs_intensity[rfree]
    calc_selected = calc_intensity[rfree]

    def _sum_per_bin(values):
        """Sum ``values`` into ``nbins`` slots according to ``selected_bins``."""
        zeros = torch.zeros(self.nbins, device=self.device, dtype=acc_dtype)
        return torch.scatter_add(zeros, 0, selected_bins, values)

    counts = _sum_per_bin(torch.ones_like(obs_selected))
    # The small offset keeps empty bins at 0 instead of NaN.
    mean_obs_intensity = _sum_per_bin(obs_selected) / (counts + 1e-6)
    mean_calc_intensity = _sum_per_bin(calc_selected) / (counts + 1e-6)
    return mean_obs_intensity, mean_calc_intensity, self._data.mean_res_per_bin()
[docs]
def screen_solvent_params(
    self,
    fcalc: torch.Tensor,
    steps: int = 15,
    use_low_res_weighting: bool = True,
    low_res_cutoff: float = 5.0,
    fit_on_low_res_only: bool = True,
    low_res_limit: float = 3.5,
):
    """
    Screen solvent parameters (k_sol, B_sol) using grid search.

    A ``steps`` x ``steps`` grid is scanned: ``log(k_sol)`` linearly between
    ``log(0.1)`` and ``log(0.6)``, ``B_sol`` linearly between 30 and 100.
    For every grid point the solvent model's ``log_k_solvent`` /
    ``b_solvent`` are overwritten, ``forward(fcalc)`` is evaluated and a
    sigma-weighted squared residual is computed on the selected
    reflections. The best pair is written back to the solvent model.

    The bulk solvent contributes primarily at low resolution. Fitting on
    low-resolution reflections only (``fit_on_low_res_only=True``) prevents
    high-resolution reflections from dominating the optimization and
    pushing B_sol too low.

    Parameters
    ----------
    fcalc : torch.Tensor
        Calculated structure factors (complex).
    steps : int, default 15
        Number of grid points for each parameter.
    use_low_res_weighting : bool, default True
        If True, weight reflections by ``exp(-|s| * low_res_cutoff)``
        (normalized over the selected reflections) and sum the weighted
        residuals; otherwise the unweighted mean is used.
    low_res_cutoff : float, default 5.0
        Resolution cutoff for weighting in Angstroms.
    fit_on_low_res_only : bool, default True
        If True, fit using only reflections with resolution above
        ``low_res_limit``. Falls back to all ``rfree``-selected reflections
        when fewer than 100 qualify.
    low_res_limit : float, default 3.5
        Resolution limit for low-res only fitting in Angstroms.

    Raises
    ------
    RuntimeError
        If no solvent model has been set.
    """
    if not hasattr(self, "solvent") or self.solvent is None:
        raise RuntimeError("No solvent model set. Call set_solvent_model() first.")

    hkl, fobs, sigma, rfree = self._data()
    fobs = fobs.to(get_float_dtype()).detach()
    fcalc = fcalc.detach()

    # Resolution of every reflection, used for filtering and weighting.
    s = torch.norm(get_scattering_vectors(hkl, self.cell), dim=1)
    resolution = 1.0 / (s + 1e-6)

    # Select the reflections that take part in the screening.
    if fit_on_low_res_only:
        low_res_mask = (resolution > low_res_limit) & rfree
        n_low_res = low_res_mask.sum().item()
        if self.verbose > 1:
            print(
                f"Solvent screening using {n_low_res} low-res reflections (>{low_res_limit}Å)"
            )
        if n_low_res < 100:
            print(
                f"Warning: Only {n_low_res} low-res reflections, using all reflections instead"
            )
            fit_on_low_res_only = False
    if not fit_on_low_res_only:
        low_res_mask = rfree

    # Weights favouring low resolution, normalized over the selection.
    if use_low_res_weighting:
        weights = torch.exp(-s * low_res_cutoff).detach()
        weights = weights / weights[low_res_mask].sum()
        if self.verbose > 1:
            low_res_frac = (resolution > low_res_cutoff).float().mean()
            print(
                f"Low-resolution weighting: {low_res_frac*100:.1f}% reflections above {low_res_cutoff}Å"
            )
    else:
        weights = torch.ones_like(fobs)
        weights = weights / weights[low_res_mask].sum()

    # Everything below depends only on the selection, not on the grid
    # point, so it is computed once instead of steps**2 times (the median
    # also forces a host synchronisation through .item()).
    fobs_subset = fobs[low_res_mask]
    weights_subset = weights[low_res_mask]
    sigma_subset = sigma[low_res_mask]
    if hasattr(sigma_subset, "get_mask"):
        sigma_data = sigma_subset.get_data()[sigma_subset.get_mask()]
        eps = torch.median(sigma_data).item() * 1e-1
    else:
        eps = torch.median(sigma_subset).item() * 1e-1
    # Floor sigma at a tenth of its median so tiny sigmas cannot dominate.
    sigma_safe = torch.clamp(sigma_subset, min=eps)
    variance_safe = sigma_safe**2

    best_log_k_solvent = self.solvent.log_k_solvent.clone()
    best_b_solvent = self.solvent.b_solvent.clone()
    best_loss = float("inf")

    ksol_start = torch.log(torch.tensor(0.1, device=self.device))
    ksol_end = torch.log(torch.tensor(0.6, device=self.device))
    for log_k_solvent in torch.linspace(
        ksol_start, ksol_end, steps=steps, device=self.device
    ):
        for b_solvent in torch.linspace(
            30.0, 100.0, steps=steps, device=self.device
        ):
            # Try this grid point by overwriting the solvent parameters.
            self.solvent.log_k_solvent.data = log_k_solvent.to(
                dtype=self.solvent.log_k_solvent.dtype
            )
            self.solvent.b_solvent.data = b_solvent.to(
                dtype=self.solvent.b_solvent.dtype
            )
            scaled_fcalc = self.forward(fcalc)
            diff = fobs_subset - torch.abs(scaled_fcalc[low_res_mask])
            nll_per_refl = 0.5 * (diff**2) / variance_safe
            if use_low_res_weighting:
                nll_loss = (nll_per_refl * weights_subset).sum()
            else:
                nll_loss = nll_per_refl.mean()
            if nll_loss.item() < best_loss:
                best_loss = nll_loss.item()
                best_log_k_solvent = log_k_solvent.clone()
                best_b_solvent = b_solvent.clone()

    # Leave the solvent model at the best grid point found.
    self.solvent.log_k_solvent.data = best_log_k_solvent.to(
        dtype=self.solvent.log_k_solvent.dtype
    )
    self.solvent.b_solvent.data = best_b_solvent.to(
        dtype=self.solvent.b_solvent.dtype
    )
    if self.verbose > 0:
        k_sol = torch.exp(best_log_k_solvent).item()
        print(
            f"Optimal solvent parameters found: k_sol={k_sol:.4f}, B_sol={best_b_solvent.item():.1f}, "
            f"NLL Loss={best_loss:.4f}"
        )
[docs]
def refine_lbfgs(
    self,
    fcalc: torch.Tensor,
    nsteps: int = 3,
    lr: float = 1.0,
    max_iter: int = 200,
    history_size: int = 10,
    verbose: bool = True,
):
    """
    Refine scale parameters using LBFGS optimizer.

    The objective is ``nll_xray`` of the scaled F_calc plus the squared
    norm of ``self.U``, wrapped as a ``LossState`` target and stepped with
    ``torch.optim.LBFGS`` (strong-Wolfe line search) over
    ``self.parameters()``. After the last step ``estimate_sigma_eff`` is
    called to refresh the effective sigmas.

    Parameters
    ----------
    fcalc : torch.Tensor
        Calculated structure factors (complex). Detached before use.
    nsteps : int, default 3
        Number of LBFGS steps. With ``nsteps=0`` no optimization is run and
        the reported R-factors are NaN.
    lr : float, default 1.0
        Learning rate (typically 1.0 for LBFGS).
    max_iter : int, default 200
        Maximum iterations per LBFGS step.
    history_size : int, default 10
        Number of previous gradients to store for Hessian approximation.
    verbose : bool, default True
        Print progress information (combined with ``self.verbose``).

    Returns
    -------
    dict
        Dictionary with refinement metrics including steps, xray_work,
        xray_test, rwork, rfree.
    """
    from torchref.refinement.loss_state import LossState

    hkl, fobs, sigma, rfree_mask = self._data()
    fcalc = fcalc.detach()

    # Wrap the scaler loss as a LossState target so this path uses the
    # same closure/NaN-validation/auto-freeze infrastructure as the main
    # refinement loop. Because fcalc is detached, the only leaves the
    # loss touches are the scaler's own parameters — LossState's probe
    # picks those up at register time. validate_loss inside state.step
    # handles NaN/Inf rejection so no per-target try/except is needed.
    scaler_self = self

    class _ScalerXrayTarget(nn.Module):
        name = "scaler/xray"

        def forward(self):
            fcalc_scaled = scaler_self.forward(fcalc)
            u_penalty = torch.sum(scaler_self.U**2)
            # NOTE(review): unlike the per-step metrics below, this target
            # is not restricted by rfree_mask — confirm that fitting on all
            # returned reflections is intended.
            return nll_xray(fobs, fcalc_scaled, sigma) + u_penalty

    state = LossState(device=self.device)
    state.register_target("scaler/xray", _ScalerXrayTarget())

    # Create LBFGS optimizer for scaler parameters only
    optimizer = torch.optim.LBFGS(
        self.parameters(),
        lr=lr,
        max_iter=max_iter,
        history_size=history_size,
        line_search_fn="strong_wolfe",
    )

    # Track metrics
    metrics = {
        "target": "scales",
        "steps": [],
        "xray_work": [],
        "xray_test": [],
        "rwork": [],
        "rfree": [],
    }
    if verbose and self.verbose > 0:
        print("Refining scales with LBFGS...")
    if self.verbose > 2:
        assert torch.all(
            torch.isfinite(fcalc)
        ), "Non-finite values found in fcalc during scale optimization."

    # Defined up front so the final summary cannot hit unbound names when
    # the loop body never runs (nsteps == 0).
    rwork = rfree_val = float("nan")

    # Run optimization
    for step in range(nsteps):
        state.step(optimizer, context="scaler.refine_lbfgs")

        # Evaluate metrics
        with torch.no_grad():
            hkl, fobs, sigma, rfree_mask = self._data()
            fcalc_scaled = self.forward(fcalc)
            xray_work = nll_xray(fobs[rfree_mask], fcalc_scaled[rfree_mask], sigma[rfree_mask])
            xray_test = nll_xray(fobs[~rfree_mask], fcalc_scaled[~rfree_mask], sigma[~rfree_mask])
            rwork, rfree_val = get_rfactors(
                torch.abs(fobs), torch.abs(fcalc_scaled), rfree_mask
            )
        metrics["steps"].append(step + 1)
        metrics["xray_work"].append(xray_work.item())
        metrics["xray_test"].append(xray_test.item())
        metrics["rwork"].append(rwork)
        metrics["rfree"].append(rfree_val)
        if verbose and self.verbose > 2:
            print(
                f"  Step {step+1}/{nsteps}: "
                f"Rwork={rwork:.4f}, Rfree={rfree_val:.4f}, "
                f"NLL_work={xray_work.item():.2f}, NLL_test={xray_test.item():.2f}"
            )

    # Estimate per-shell effective sigmas from residuals (SIGMAA-style)
    # This makes the X-ray likelihood robust to miscalibrated experimental sigmas.
    self.estimate_sigma_eff(fcalc)

    if verbose and self.verbose > 0:
        with torch.no_grad():
            print(
                f"Scale refinement complete. rwork: {rwork:.4f}, rfree: {rfree_val:.4f}\n"
            )
        print("Final Scale Parameters: ")
        for name, param in self.named_parameters():
            if param.requires_grad:
                print(f"  {name}: {param.data}")
    return metrics
[docs]
def estimate_sigma_eff(
    self,
    fcalc: torch.Tensor,
    max_inflation: float = 2.0,
) -> torch.Tensor:
    """
    Estimate per-resolution-shell effective sigmas from current residuals.

    Pannu & Read / SIGMAA-style correction: detects miscalibrated
    experimental sigmas by comparing residual variance to the claimed
    variance, per resolution bin.

    For each resolution bin::

        D_bin        = < (F_obs - k * |F_calc|)^2 >   (using work set)
        ratio_bin    = sqrt(D_bin / <sigma_F^2>)
        ratio_capped = clamp(ratio_bin, 1.0, max_inflation)
        sigma_eff    = sigma_F * ratio_capped

    **Why the cap?** At the start of refinement the model is bad, so
    residuals are dominated by *model error* (which is fixable by
    refining), not noise. Uncapped inflation creates a vicious cycle:
    bad model -> huge sigma_eff -> weak data gradient -> bad model.
    Capping at ``max_inflation`` prevents runaway while still correcting
    genuinely under-estimated sigmas.

    Uses the work set only so the test set doesn't leak into sigma
    estimation. The results are also written into the ``sigma_eff`` and
    ``sigma_eff_per_bin`` buffers.

    Parameters
    ----------
    fcalc : torch.Tensor
        Calculated structure factors (complex, unscaled). Either ``(N,)``
        or with a leading batch dimension of size 1.
    max_inflation : float, optional
        Maximum allowed ratio sigma_eff / sigma_raw. Default 2.0.

    Returns
    -------
    torch.Tensor
        Per-reflection effective sigmas, shape (N,).
    """
    with torch.no_grad():
        _, fobs_raw, sigma_raw, rfree_mask = self._data(mask=False)

        # Apply scaling to F_calc. Only a genuine leading batch dimension
        # is squeezed; a 1-D result is left alone so a single-reflection
        # tensor is not collapsed to 0-dim.
        fcalc_scaled = self.forward(fcalc)
        if fcalc_scaled.ndim > 1:
            fcalc_scaled = fcalc_scaled.squeeze(0)
        fcalc_amp = torch.abs(fcalc_scaled).to(fobs_raw.dtype)

        # Work set only (rfree=True = work in this codebase convention)
        validity = self._data.masks().to(torch.bool)
        work_mask = validity & rfree_mask.bool()
        bins_work = self.bins[work_mask].to(torch.int64)
        fobs_work = fobs_raw[work_mask]
        fcalc_work = fcalc_amp[work_mask]
        sigma_work = sigma_raw[work_mask]

        # Residuals using F_calc directly (alpha=1 assumed post-scaling)
        residuals_sq = (fobs_work - fcalc_work) ** 2

        def _sum_per_bin(values):
            """Sum ``values`` into ``nbins`` slots according to ``bins_work``."""
            zeros = torch.zeros(
                self.nbins, device=self.device, dtype=fobs_raw.dtype
            )
            return torch.scatter_add(zeros, 0, bins_work, values)

        sum_d = _sum_per_bin(residuals_sq)
        sum_s2 = _sum_per_bin(sigma_work**2)
        counts = _sum_per_bin(torch.ones_like(fobs_work))

        # Per-bin empirical residual variance and mean raw variance; empty
        # bins divide by 1 and therefore stay at 0.
        d_per_bin = sum_d / counts.clamp(min=1.0)
        mean_sigma2_per_bin = sum_s2 / counts.clamp(min=1.0)

        # Ratio sigma_eff / sigma_raw, capped to [1, max_inflation]
        ratio_per_bin = torch.sqrt(
            (d_per_bin / mean_sigma2_per_bin.clamp(min=1e-12)).clamp(min=1e-12)
        )
        ratio_per_bin = torch.clamp(ratio_per_bin, 1.0, float(max_inflation))

        # sigma_eff per reflection = sigma_raw * ratio_for_its_bin
        all_bins = self.bins.to(torch.int64)
        ratio_per_refl = ratio_per_bin[all_bins]
        sigma_eff_all = sigma_raw * ratio_per_refl

        # Store per-bin representative sigma_eff (using mean raw sigma in bin)
        sigma_eff_per_bin = (
            torch.sqrt(mean_sigma2_per_bin.clamp(min=1e-12)) * ratio_per_bin
        )
        self.sigma_eff_per_bin.copy_(sigma_eff_per_bin)
        self.sigma_eff.copy_(sigma_eff_all)

        if self.verbose > 1:
            mean_raw = sigma_raw[work_mask].mean().item()
            mean_eff = sigma_eff_all[work_mask].mean().item()
            print(
                f"  sigma_eff estimation: mean raw={mean_raw:.3f}, "
                f"mean effective={mean_eff:.3f}, "
                f"ratio={mean_eff/max(mean_raw,1e-6):.2f}, "
                f"per-bin ratios={ratio_per_bin.cpu().tolist()}"
            )
    return sigma_eff_all
[docs]
def forward(
    self,
    fcalc: torch.Tensor,
    use_mask: bool = True,
    f_sol_override: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Forward pass for the ScalerBase module.

    Computes::

        K_overall * b_overall * (aniso_correction * fcalc + f_sol)

    where each factor falls back to a neutral scalar (1 for the
    multiplicative terms, 0 for the solvent term) when the corresponding
    parameter or solvent model has not been set up.

    Parameters
    ----------
    fcalc : torch.Tensor
        Calculated structure factors. Expected shape (N,), an additional
        dimension for batch is possible. N should match the full HKL size;
        when it does not, all per-reflection factors are restricted with
        ``self._data.masks()`` before being applied.
    use_mask : bool, default True
        Deprecated parameter, kept for backward compatibility. It is not
        read by this method.
    f_sol_override : torch.Tensor, optional
        Pre-computed raw solvent structure factors. When provided, these
        replace the internally-cached ``_f_sol_raw`` (the cache itself is
        overwritten, so the override also applies to later calls). The
        scaler's k_sol / B_sol / phase damping is still applied. This is
        used by ``CollectionScaler`` to supply mixed (fraction-weighted)
        solvent contributions.

    Returns
    -------
    torch.Tensor
        Scaled structure factors of same shape as input.
    """
    # Work internally with a leading batch dimension; remember whether the
    # caller passed one so it can be removed again at the end.
    batched = True
    if fcalc.ndim == 1:
        fcalc = fcalc.unsqueeze(0)
        batched = False

    # Determine if we should mask internally or work with full arrays.
    # ``mask`` is only bound when the input length differs from the number
    # of stored bin indices; every use below is guarded by the same flag.
    n_full = len(self.bins)
    n_fcalc = fcalc.shape[1]
    if n_fcalc == n_full:
        apply_internal_mask = False
    else:
        apply_internal_mask = True
        mask = self._data.masks().to(torch.bool)

    # Anisotropic correction (neutral scalar 1 if U has not been set up).
    if hasattr(self, "U"):
        anisotropy_factors = self.anisotropy_correction()
        aniso_correction = (
            anisotropy_factors[mask] if apply_internal_mask else anisotropy_factors
        )
    else:
        aniso_correction = torch.tensor(1.0, device=self.device, dtype=fcalc.dtype)

    # An explicit override replaces the cached raw solvent structure factors.
    if f_sol_override is not None:
        self._f_sol_raw = f_sol_override

    if hasattr(self, "solvent") and self.solvent is not None:
        # Lazily cache raw solvent SFs (FFT of mask) — only recomputed
        # when invalidated via _f_sol_raw = None (e.g. after update_solvent)
        if self._f_sol_raw is None:
            self._f_sol_raw = self.solvent.get_rec_solvent(self.hkl)
        f_sol_raw = self._f_sol_raw[mask] if apply_internal_mask else self._f_sol_raw
        if hasattr(self, "log_kmask"):
            # Bin-wise solvent scale: exp(log_kmask), limited to [0, 10].
            kmask = torch.exp(self.log_kmask.clamp(min=-10.0, max=10.0))
            kmask = torch.clamp(kmask, min=0.0, max=10.0)
            bins_to_use = self.bins[mask] if apply_internal_mask else self.bins
            # Use index_add-backward gather so the gradient back into
            # log_kmask is a single index_add_ rather than PyTorch's
            # radix-sort + scatter index_put_ path.
            kmask_per_refl = gather_with_index_add(kmask, bins_to_use)
            f_sol = kmask_per_refl * f_sol_raw
        else:
            # Inline solvent scaling: k_sol * exp(i*phase) * exp(-B*s²) * f_mask
            # Uses precomputed self._s_half_sq instead of recomputing scattering vectors
            sol = self.solvent
            k_sol = torch.exp(sol.log_k_solvent.clamp(min=-10.0, max=10.0))
            s_half_sq = self._s_half_sq[mask] if apply_internal_mask else self._s_half_sq
            b_factor = torch.exp(
                (-sol.b_solvent.clamp(min=-500.0, max=500.0) * s_half_sq).clamp(min=-10.0, max=10.0)
            )
            if sol.optimize_phase:
                # 1j is a Python complex literal -> promotes to complex128.
                # Build the phase factor in the configured complex dtype instead.
                j = torch.tensor(1j, dtype=get_complex_dtype(), device=self.device)
                f_sol_raw = f_sol_raw * torch.exp(j * sol.phase_offset)
            f_sol = k_sol * f_sol_raw * b_factor
    else:
        # No solvent model: the additive solvent term is zero.
        f_sol = torch.tensor(0.0, device=self.device, dtype=fcalc.dtype)

    # Overall per-bin scale (neutral scalar 1 if log_scale is missing).
    if hasattr(self, "log_scale") and self.log_scale is not None:
        bins_to_use = self.bins[mask] if apply_internal_mask else self.bins
        # Use index_add-backward gather: log_scale is a tiny (~nbins)
        # accumulator and the default index_put_ backward is bottlenecked
        # by a cub::DeviceRadixSort over all ~N_refl indices (see profile
        # data on A100/3GR5, ~370 us/iter pre-fix).
        K_overall = torch.exp(
            gather_with_index_add(self.log_scale, bins_to_use)
            .clamp(min=-10.0, max=10.0)
        )
    else:
        K_overall = torch.tensor(1.0, device=self.device, dtype=fcalc.dtype)

    # Bin-wise B-factor damping (neutral scalar 1 if not set up).
    if hasattr(self, "bin_wise_bfactor") and self.bin_wise_bfactor is not None:
        bfactor_factors = self.bin_wise_bfactor_correction()
        b_overall = (
            bfactor_factors[mask] if apply_internal_mask else bfactor_factors
        )
    else:
        b_overall = torch.tensor(1.0, device=self.device, dtype=fcalc.dtype)

    # Each factor gets a leading axis so it broadcasts over the batch.
    fcalc = (
        K_overall.unsqueeze(0)
        * b_overall.unsqueeze(0)
        * (aniso_correction.unsqueeze(0) * fcalc + f_sol.unsqueeze(0))
    )
    if not batched:
        fcalc = fcalc.squeeze(0)
    return fcalc
[docs]
def state_dict(self, destination=None, prefix="", keep_vars=False):
    """
    Return the module state extended with scaler-specific entries.

    On top of what the parent class collects (parameters and buffers) the
    dictionary receives ``nbins``, ``verbose`` and, when a solvent model is
    attached, that model's own ``state_dict()`` under ``solvent``. All
    three keys are prefixed with ``prefix``. The data reference is not
    part of the state.

    Parameters
    ----------
    destination : dict, optional
        Optional dict to populate.
    prefix : str, default ''
        Prefix for parameter names.
    keep_vars : bool, default False
        Whether to keep variables in computational graph.

    Returns
    -------
    dict
        Complete state dictionary.
    """
    state = super().state_dict(
        destination=destination, prefix=prefix, keep_vars=keep_vars
    )

    extras = {"nbins": self.nbins, "verbose": self.verbose}
    solvent = getattr(self, "solvent", None)
    if solvent is not None:
        extras["solvent"] = solvent.state_dict()

    for key, value in extras.items():
        state[prefix + key] = value
    return state
[docs]
def load_state_dict(self, state_dict, strict=True):
    """
    Load the ScalerBase state from a dictionary.

    The scaler-specific entries (``nbins``, ``verbose``, ``solvent`` and
    the legacy ``frozen``) are split off before the remaining entries are
    handed to the parent implementation. The caller's dictionary is left
    untouched, so the same ``state_dict`` can be loaded more than once.
    When ``nbins`` or ``verbose`` is absent, the scaler keeps its current
    value.

    Note: This assumes data is already set via __init__ or set_data().

    Parameters
    ----------
    state_dict : dict
        Dictionary containing scaler state.
    strict : bool, default True
        Whether to strictly enforce that keys match.

    Returns
    -------
    object
        Whatever the parent ``load_state_dict`` returns.
    """
    from collections import OrderedDict

    # Work on a shallow copy: popping from the caller's dict would strip
    # the metadata keys and make a second load fall back to defaults.
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = OrderedDict(state_dict)
    if metadata is not None:
        # Keep the versioning info that torch attaches to state dicts.
        state_dict._metadata = metadata

    self.nbins = state_dict.pop("nbins", getattr(self, "nbins", 20))
    self.verbose = state_dict.pop("verbose", getattr(self, "verbose", 1))
    # Legacy state dicts may contain a "frozen" entry; drop it silently.
    state_dict.pop("frozen", None)
    solvent_state = state_dict.pop("solvent", None)

    result = super().load_state_dict(state_dict, strict=strict)

    # The nested solvent state is only restored when a solvent model is
    # already attached to this scaler.
    solvent = getattr(self, "solvent", None)
    if solvent_state is not None and solvent is not None:
        solvent.load_state_dict(solvent_state)
    return result
[docs]
def save_state(self, path: str):
    """
    Write ``self.state_dict()`` to ``path`` with ``torch.save``.

    Parameters
    ----------
    path : str
        Path to save the state dictionary to.
    """
    snapshot = self.state_dict()
    torch.save(snapshot, path)
    if self.verbose > 0:
        print(f"Saved scaler state to {path}")
[docs]
def load_state(self, path: str, strict: bool = True):
    """
    Read a state dictionary from ``path`` and load it into this scaler.

    Parameters
    ----------
    path : str
        Path to load the state dictionary from.
    strict : bool, default True
        Whether to strictly enforce that keys match.
    """
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects — only use this with files from a trusted source.
    loaded = torch.load(path, map_location=self.device, weights_only=False)
    self.load_state_dict(loaded, strict=strict)
    if self.verbose > 0:
        print(f"Loaded scaler state from {path}")