torchref.scaling.scaler_base module
Base scaler class for crystallographic scaling without model dependency.
This module provides ScalerBase, a class that implements all scaling functionality but does NOT maintain a reference to a Model object. All methods that require calculated structure factors (F_calc) take them as input arguments.
This allows the scaler to be used independently of any specific model implementation, making it suitable for use cases like:
- Molecular replacement, where F_calc comes from external sources
- Testing and validation with precomputed structure factors
- Custom model implementations
- class torchref.scaling.scaler_base.ScalerBase(data=None, nbins=20, verbose=1, device=None)
Bases: DeviceMixin, DebugMixin, Module
Base scaler class for crystallographic scaling without model dependency.
All methods that require calculated structure factors (F_calc) take them as input arguments. This allows the scaler to be used independently of any specific model implementation.
Supports two initialization patterns:
Empty initialization (for state_dict loading):
    scaler = ScalerBase()  # creates an empty shell
    scaler.load_state_dict(torch.load('scaler.pt'))
Full initialization with data:
    scaler = ScalerBase(data=reflection_data, nbins=20)
    scaler.initialize(fcalc)
- Parameters:
data (ReflectionData, optional) – ReflectionData object with observed data.
nbins (int, default 20) – Number of resolution bins.
verbose (int, default 1) – Verbosity level.
device (torch.device, default: configured device.current) – Computation device.
- device
Current computation device.
- Type:
torch.device
- __init__(data=None, nbins=20, verbose=1, device=None)
Initialize ScalerBase.
If data is provided, fully initializes the scaler. If not provided (empty init), creates a shell ready for load_state_dict().
- Parameters:
data (ReflectionData, optional) – ReflectionData object with observed data.
nbins (int, default 20) – Number of resolution bins.
verbose (int, default 1) – Verbosity level.
device (torch.device, optional) – Computation device. If None, it is derived from data (if given) or from the configured default via torchref.utils.resolve_device(). An explicit value forces data onto that device.
- set_data(data)
Set data reference after empty initialization.
This is useful when loading from state_dict and then needing to reconnect to a data object.
- Parameters:
data (ReflectionData) – ReflectionData object with observed data.
- initialize(fcalc)
Initialize scaling parameters using provided F_calc.
- Parameters:
fcalc (torch.Tensor) – Calculated structure factors (complex).
- property hkl
Get HKL indices from data.
- calc_initial_scale(fcalc)
Calculate the initial scale factor based on the ratio of observed to calculated structure factors.
Excludes reflections with negative intensities to avoid bias from French-Wilson conversion.
- Parameters:
fcalc (torch.Tensor) – Calculated structure factors (complex).
- Returns:
The log scale parameter for each resolution bin.
- Return type:
torch.nn.Parameter
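A dependency-free sketch of this per-bin estimate follows. It is not torchref's code. It assumes that each reflection already carries a resolution-bin index (`bin_idx`), that the measured intensities are available as `i_obs` so negative values can be flagged, and that the bin scale is the ratio of summed amplitudes rather than, say, a mean of ratios. The function name `initial_log_scales` is hypothetical.

```python
import math


def initial_log_scales(f_obs, f_calc, i_obs, bin_idx, nbins):
    """Hypothetical sketch: per-bin log(sum F_obs / sum |F_calc|),
    skipping reflections whose measured intensity is negative."""
    sum_obs = [0.0] * nbins
    sum_calc = [0.0] * nbins
    for fo, fc, i, b in zip(f_obs, f_calc, i_obs, bin_idx):
        if i < 0:
            # French-Wilson still assigns these a positive F_obs; leave them out
            continue
        sum_obs[b] += fo
        sum_calc[b] += abs(fc)  # amplitude of the complex F_calc
    # Empty or degenerate bins fall back to log(1) = 0
    return [math.log(so / sc) if so > 0 and sc > 0 else 0.0
            for so, sc in zip(sum_obs, sum_calc)]
```

Returning log scales mirrors the documented return value: the scale itself is exp of each entry, which stays positive during later optimization.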
- anisotropy_correction()
Compute anisotropic correction factors.
- Returns:
Anisotropic correction factors for each reflection.
- Return type:
torch.Tensor
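The source does not state how the anisotropic factor is parameterized. A common convention, assumed here, is k_aniso = exp(-1/4 sᵀBs), with a symmetric 3×3 tensor B and the reciprocal-space vector s of the reflection. To keep it short, the hypothetical `aniso_factor` below restricts itself to an orthogonal cell, so s = (h/a, k/b, l/c).

```python
import math


def aniso_factor(hkl, cell_abc, b_tensor):
    """Hypothetical sketch of exp(-1/4 * s^T B s) for an orthogonal cell.

    hkl      -- (h, k, l) Miller indices
    cell_abc -- (a, b, c) cell edges in Angstroms (all angles 90 degrees assumed)
    b_tensor -- (B11, B22, B33, B12, B13, B23) in Angstrom^2
    """
    # Reciprocal-space vector for an orthogonal cell: s = (h/a, k/b, l/c)
    s1, s2, s3 = (idx / edge for idx, edge in zip(hkl, cell_abc))
    b11, b22, b33, b12, b13, b23 = b_tensor
    # Quadratic form s^T B s for a symmetric B
    quad = (b11 * s1 * s1 + b22 * s2 * s2 + b33 * s3 * s3
            + 2.0 * (b12 * s1 * s2 + b13 * s1 * s3 + b23 * s2 * s3))
    return math.exp(-0.25 * quad)
```

With B11 = B22 = B33 = B_iso and zero off-diagonal terms, sᵀBs = B_iso/d², so the factor reduces to the familiar isotropic exp(-B_iso/(4d²)).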
- fit_anisotropy(fcalc, nsteps=100)
Fit anisotropic correction using provided F_calc.
- Parameters:
fcalc (torch.Tensor) – Calculated structure factors (complex).
nsteps (int, default 100) – Number of optimization steps.
- set_solvent_model(solvent_model)
Set a pre-configured SolventModel for solvent contribution.
The SolventModel must be initialized externally (requires a Model object).
- Parameters:
solvent_model (SolventModel) – Pre-configured solvent model that can compute solvent structure factors.
- setup_binwise_solvent_scale()
Set up bin-wise solvent scaling (Phenix-style kmask per bin).
This allows finer control over solvent contribution per resolution bin, which is more flexible than a single global B_sol parameter.
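To show what a per-bin k_mask means, the sketch below fits one value per resolution bin. It uses a simple 1-D grid search, which is not necessarily how torchref or Phenix perform the fit. It assumes the model |F_model| = |F_calc + k_mask[bin]·F_mask|, a precomputed mask structure factor `f_mask`, F_obs already on the scale of F_calc, and a hypothetical k_mask range of 0 to 1. `fit_binwise_kmask` is a hypothetical name.

```python
def fit_binwise_kmask(f_obs, f_calc, f_mask, bin_idx, nbins, grid=None):
    """Hypothetical sketch: pick one k_mask per resolution bin by grid search.

    Assumed model: |F_model| = |F_calc + k_mask[bin] * F_mask|,
    with F_obs already on the same scale as F_calc.
    """
    if grid is None:
        grid = [j * 0.05 for j in range(21)]  # hypothetical range 0.0 .. 1.0
    kmask = []
    for b in range(nbins):
        members = [i for i, bi in enumerate(bin_idx) if bi == b]

        def residual(km, members=members):
            # Least-squares amplitude residual of this bin for a trial k_mask
            return sum((f_obs[i] - abs(f_calc[i] + km * f_mask[i])) ** 2
                       for i in members)

        # min() returns the first grid value with the smallest residual
        kmask.append(min(grid, key=residual))
    return kmask
```

Because each bin gets its own value, the solvent contribution can fall off with resolution in whatever shape the data support, instead of following the single exp(-B_sol/(4d²)) curve.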
- fit_all_scales(fcalc)
Fit all scale parameters using provided F_calc.
- Parameters:
fcalc (torch.Tensor) – Calculated structure factors (complex).
- fit_simple(fobs, fcalc)
Fit a single global scale factor analytically (least-squares).
This is the simple scaling approach:
$$k = \frac{\sum |F_{\text{obs}}|\,|F_{\text{calc}}|}{\sum |F_{\text{calc}}|^2}$$
Useful for rigid body refinement where only an overall scale is needed.
- Parameters:
fobs (torch.Tensor) – Observed structure factor amplitudes.
fcalc (torch.Tensor) – Calculated structure factors (complex).
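This closed form comes from setting the derivative of Σ(F_obs - k|F_calc|)² with respect to k to zero. A plain-Python sketch with the hypothetical name `simple_scale`:

```python
def simple_scale(f_obs, f_calc):
    """Closed-form k minimizing sum (F_obs - k * |F_calc|)^2."""
    num = sum(fo * abs(fc) for fo, fc in zip(f_obs, f_calc))  # sum F_obs |F_calc|
    den = sum(abs(fc) ** 2 for fc in f_calc)                  # sum |F_calc|^2
    return num / den
```

For example, if every F_obs is exactly twice |F_calc|, the fitted k is 2.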
- get_scale()
Get the current overall scale factor value.
Returns the mean scale factor across all bins.
- Returns:
Current scale factor (not log).
- Return type:
- rfactor(fcalc)
Calculate the R-factor between observed and calculated structure factors.
- Parameters:
fcalc (torch.Tensor) – Calculated structure factors (complex).
- Returns:
R-work and R-free values.
- Return type:
tuple
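The sketch below computes the conventional R-factor, R = Σ|F_obs - |F_model|| / ΣF_obs, separately over work and test reflections. It assumes the model amplitudes have already been scaled (for example by the scaler's forward pass) and that a boolean `is_free` flag marks the test set. `r_factors_sketch` is a hypothetical name.

```python
def r_factors_sketch(f_obs, f_model, is_free):
    """Return (R_work, R_free) with R = sum|F_obs - |F_model|| / sum F_obs."""
    num = [0.0, 0.0]  # index 0: work set, index 1: test set
    den = [0.0, 0.0]
    for fo, fm, free in zip(f_obs, f_model, is_free):
        s = 1 if free else 0
        num[s] += abs(fo - abs(fm))  # abs(fm) works for real or complex F_model
        den[s] += fo
    return num[0] / den[0], num[1] / den[1]
```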
- bin_wise_rfactor(fcalc)
Calculate the bin-wise R-factor between observed and calculated structure factors.
- Parameters:
fcalc (torch.Tensor) – Calculated structure factors (complex).
- Returns:
mean_res_per_bin (torch.Tensor) – Mean resolution per bin.
rwork_per_bin (torch.Tensor) – R-work per bin.
rfree_per_bin (torch.Tensor) – R-free per bin.
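The section does not say how torchref assigns resolution bins. The hypothetical sketch below assumes equal-count bins ordered from low to high resolution by d-spacing. It then reports per-bin mean d, R-work and R-free, and returns NaN for any empty subset.

```python
def binwise_r_sketch(d_spacing, f_obs, f_model, is_free, nbins):
    """Hypothetical sketch: equal-count resolution bins, then per-bin
    mean d, R-work and R-free."""
    n = len(d_spacing)
    # Sort reflection indices from low resolution (large d) to high resolution
    order = sorted(range(n), key=lambda i: d_spacing[i], reverse=True)
    bins = [order[b * n // nbins:(b + 1) * n // nbins] for b in range(nbins)]

    def r_value(members):
        den = sum(f_obs[i] for i in members)
        num = sum(abs(f_obs[i] - abs(f_model[i])) for i in members)
        return num / den if den > 0 else float("nan")

    mean_res, r_work, r_free = [], [], []
    for members in bins:
        mean_res.append(sum(d_spacing[i] for i in members) / len(members)
                        if members else float("nan"))
        r_work.append(r_value([i for i in members if not is_free[i]]))
        r_free.append(r_value([i for i in members if is_free[i]]))
    return mean_res, r_work, r_free
```

Equal-count bins keep the statistics in each shell comparably noisy. Equal-width bins in 1/d² are another common choice.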
- bin_wise_bfactor_correction()
Compute bin-wise B-factor correction factors.
- Returns:
B-factor correction factors for each reflection.
- Return type:
torch.Tensor
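The source does not give the exact form of the bin-wise correction. One plausible form, assumed here, applies an isotropic Debye-Waller-style factor exp(-B_bin/(4d²)), where B_bin is the B value of the bin that contains the reflection. `binwise_b_correction` is hypothetical.

```python
import math


def binwise_b_correction(d_spacing, bin_idx, b_per_bin):
    """Hypothetical sketch: exp(-B_bin / (4 d^2)) for each reflection,
    using the B value of the resolution bin the reflection falls in."""
    return [math.exp(-b_per_bin[b] / (4.0 * d * d))
            for d, b in zip(d_spacing, bin_idx)]
```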
- get_binwise_mean_intensity(fcalc)
Get bin-wise mean intensities for observed and calculated structure factors.
- Parameters:
fcalc (torch.Tensor) – Calculated structure factors (complex).
- Returns:
Mean observed intensity, mean calculated intensity, and mean resolution per bin.
- Return type:
tuple
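Here is a minimal sketch of the per-bin averages. It approximates observed intensity as F_obs² and assumes the bin indices are precomputed. `binwise_mean_intensity` is a hypothetical name.

```python
def binwise_mean_intensity(f_obs, f_calc, d_spacing, bin_idx, nbins):
    """Hypothetical sketch: per-bin <F_obs^2>, <|F_calc|^2> and mean d-spacing."""
    counts = [0] * nbins
    sum_iobs = [0.0] * nbins
    sum_icalc = [0.0] * nbins
    sum_d = [0.0] * nbins
    for fo, fc, d, b in zip(f_obs, f_calc, d_spacing, bin_idx):
        counts[b] += 1
        sum_iobs[b] += fo * fo          # intensity approximated as F_obs^2
        sum_icalc[b] += abs(fc) ** 2    # |F_calc|^2 for complex F_calc
        sum_d[b] += d

    def mean(acc):
        # Empty bins are reported as NaN rather than raising ZeroDivisionError
        return [a / n if n else float("nan") for a, n in zip(acc, counts)]

    return mean(sum_iobs), mean(sum_icalc), mean(sum_d)
```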
- screen_solvent_params(fcalc, steps=15, use_low_res_weighting=True, low_res_cutoff=5.0, fit_on_low_res_only=True, low_res_limit=3.5)
Screen solvent parameters (k_sol, B_sol) using grid search.
The bulk solvent contributes primarily at low resolution. Fitting on low-resolution reflections only (fit_on_low_res_only=True) prevents high-resolution reflections from dominating the optimization and pushing B_sol too low.
- Parameters:
fcalc (torch.Tensor) – Calculated structure factors (complex).
steps (int, default 15) – Number of grid points for each parameter.
use_low_res_weighting (bool, default True) – If True, weight low-resolution reflections more heavily since solvent primarily contributes at low resolution.
low_res_cutoff (float, default 5.0) – Resolution cutoff for weighting in Angstroms.
fit_on_low_res_only (bool, default True) – If True, fit using only low-resolution reflections.
low_res_limit (float, default 3.5) – Resolution limit for low-res only fitting in Angstroms.
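The grid search can be sketched as follows. The sketch rests on these assumptions:
- A flat bulk-solvent model, F_model = k(F_calc + k_sol·exp(-B_sol/(4d²))·F_mask).
- A precomputed mask structure factor `f_mask`.
- Hypothetical search ranges `ksol_range` and `bsol_range`.

It implements only the `fit_on_low_res_only` path and omits `use_low_res_weighting`. At every grid point, the overall scale k is refit with the closed form documented for fit_simple. `screen_solvent` is a hypothetical name.

```python
import math


def screen_solvent(f_obs, f_calc, f_mask, d_spacing, steps=15, low_res_limit=3.5,
                   ksol_range=(0.0, 0.6), bsol_range=(10.0, 150.0)):
    """Hypothetical grid search for (k_sol, B_sol) on low-resolution data.

    Assumed flat bulk-solvent model:
        F_model = k * (F_calc + k_sol * exp(-B_sol / (4 d^2)) * F_mask)
    The overall scale k is refit in closed form at every grid point.
    """
    # Keep only reflections at or below the resolution limit (d >= limit)
    sel = [i for i, d in enumerate(d_spacing) if d >= low_res_limit]
    obs = [f_obs[i] for i in sel]

    def grid(lo, hi):
        # `steps` evenly spaced values from lo to hi inclusive
        return [lo + (hi - lo) * j / max(steps - 1, 1) for j in range(steps)]

    best_resid, best_ksol, best_bsol = float("inf"), None, None
    for k_sol in grid(*ksol_range):
        for b_sol in grid(*bsol_range):
            amp = []
            for i in sel:
                damp = math.exp(-b_sol / (4.0 * d_spacing[i] ** 2))
                amp.append(abs(f_calc[i] + k_sol * damp * f_mask[i]))
            den = sum(a * a for a in amp)
            k = sum(o * a for o, a in zip(obs, amp)) / den if den > 0 else 0.0
            resid = sum((o - k * a) ** 2 for o, a in zip(obs, amp))
            if resid < best_resid:
                best_resid, best_ksol, best_bsol = resid, k_sol, b_sol
    return best_ksol, best_bsol
```

Refitting k at each grid point keeps the search two-dimensional. Dropping reflections with d below `low_res_limit` keeps the high-resolution data from pulling B_sol down, as motivated above.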
- refine_lbfgs(fcalc, nsteps=3, lr=1.0, max_iter=200, history_size=10, verbose=True)
Refine scale parameters using LBFGS optimizer.
This method optimizes the anisotropic scaling and B-factor parameters that relate calculated structure factors to observed structure factors. Uses the L-BFGS quasi-Newton optimization method for fast convergence.
- Parameters:
fcalc (torch.Tensor) – Calculated structure factors (complex).
nsteps (int, default 3) – Number of LBFGS steps.
lr (float, default 1.0) – Learning rate (typically 1.0 for LBFGS).
max_iter (int, default 200) – Maximum number of L-BFGS iterations per optimizer step.
history_size (int, default 10) – Number of previous gradients to store for Hessian approximation.
verbose (bool, default True) – Print progress information.
- Returns:
Dictionary with refinement metrics including steps, xray_work, xray_test, rwork, rfree.
- Return type:
dict
- estimate_sigma_eff(fcalc, max_inflation=2.0)
Estimate per-resolution-shell effective sigmas from current residuals.
Pannu & Read / SIGMAA-style correction: detects miscalibrated experimental sigmas by comparing residual variance to the claimed variance, per resolution bin.
For each resolution bin:
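$$D_{\text{bin}} = \left\langle \left(F_{\text{obs}} - k\,|F_{\text{calc}}|\right)^2 \right\rangle_{\text{work}}$$
$$\text{ratio}_{\text{bin}} = \sqrt{D_{\text{bin}} \,/\, \langle \sigma_F^2 \rangle}$$
$$\text{ratio}_{\text{capped}} = \operatorname{clamp}(\text{ratio}_{\text{bin}},\ 1.0,\ \text{max\_inflation})$$
$$\sigma_{\text{eff}} = \sigma_F \cdot \text{ratio}_{\text{capped}}$$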
Why the cap? At the start of refinement the model is poor, so residuals are dominated by model error (which refinement can fix), not by noise. Uncapped inflation creates a vicious cycle: bad model -> huge sigma_eff -> weak data gradient -> bad model. Capping at max_inflation (default 2.0, i.e. sigmas can grow at most 2x) prevents runaway while still correcting genuinely underestimated sigmas. As the model improves, residuals shrink and the ratio drops toward 1, so sigma_eff converges to the raw sigma (good calibration).
Uses the work set only so the test set doesn’t leak into sigma estimation.
- Parameters:
fcalc (torch.Tensor) – Calculated structure factors (complex, unscaled).
max_inflation (float, optional) – Maximum allowed ratio sigma_eff / sigma_raw. Default 2.0.
- Returns:
Per-reflection effective sigmas, shape (N,).
- Return type:
torch.Tensor
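The recipe above translates almost line for line into the sketch below. Its inputs are assumptions: `f_model_amp` holds the already-scaled k·|F_calc|, `bin_idx` gives each reflection's resolution bin, and `is_free` marks the test set. Because D_bin/⟨σ_F²⟩ is a ratio of two means over the same reflections, the counts cancel and plain sums suffice. `sigma_eff_sketch` is a hypothetical name.

```python
import math


def sigma_eff_sketch(f_obs, f_model_amp, sigma, bin_idx, is_free, nbins, max_inflation=2.0):
    """Hypothetical per-bin sigma inflation following the D_bin / ratio / clamp recipe.

    f_model_amp -- scaled model amplitudes k * |F_calc|
    is_free     -- True for test-set reflections (excluded from the estimate)
    """
    sum_sq_resid = [0.0] * nbins
    sum_sq_sigma = [0.0] * nbins
    n_work = [0] * nbins
    for fo, fm, s, b, free in zip(f_obs, f_model_amp, sigma, bin_idx, is_free):
        if free:
            continue  # keep the test set out of sigma estimation
        sum_sq_resid[b] += (fo - fm) ** 2
        sum_sq_sigma[b] += s * s
        n_work[b] += 1

    ratio_capped = []
    for r2, s2, n in zip(sum_sq_resid, sum_sq_sigma, n_work):
        # D_bin / <sigma^2> = (r2 / n) / (s2 / n) = r2 / s2
        ratio_bin = math.sqrt(r2 / s2) if n and s2 > 0 else 1.0
        ratio_capped.append(min(max(ratio_bin, 1.0), max_inflation))

    # Every reflection (work and test) receives its bin's capped ratio
    return [s * ratio_capped[b] for s, b in zip(sigma, bin_idx)]
```

Only the estimate is restricted to the work set. Test reflections still receive the inflated sigma of their bin, and bins without work reflections fall back to a ratio of 1.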
- forward(fcalc, use_mask=True, f_sol_override=None)
Forward pass for the ScalerBase module.
- Parameters:
fcalc (torch.Tensor) – Calculated structure factors. Expected shape (N,), an additional dimension for batch is possible. N should match the full HKL size.
use_mask (bool, default True) – Deprecated parameter, kept for backward compatibility.
f_sol_override (torch.Tensor, optional) – Pre-computed raw solvent structure factors. When provided, these replace the internally cached _f_sol_raw. The scaler's k_sol, B_sol and phase damping are still applied. CollectionScaler uses this to supply mixed (fraction-weighted) solvent contributions.
- Returns:
Scaled structure factors of same shape as input.
- Return type:
torch.Tensor
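The sketch below shows how the pieces could combine in the forward pass. It assumes the standard form F_model = k_overall · k_aniso · (F_calc + k_sol·exp(-B_sol/(4d²))·F_sol). It omits the phase damping mentioned above, uses a single scalar overall scale rather than per-bin scales, and all of its names are hypothetical. It also illustrates the `f_sol_override` behaviour: the override replaces the raw solvent term, while k_sol and B_sol still apply to it.

```python
import math


def scaled_model_sketch(f_calc, f_sol_raw, d_spacing, k_overall, k_aniso, k_sol, b_sol,
                        f_sol_override=None):
    """Hypothetical composition of scaled model structure factors.

    F_model = k_overall * k_aniso * (F_calc + k_sol * exp(-B_sol / (4 d^2)) * F_sol)
    """
    # The override, when given, stands in for the cached raw solvent term
    f_sol = f_sol_raw if f_sol_override is None else f_sol_override
    out = []
    for fc, fs, d, ka in zip(f_calc, f_sol, d_spacing, k_aniso):
        solvent = k_sol * math.exp(-b_sol / (4.0 * d * d)) * fs
        out.append(k_overall * ka * (fc + solvent))
    return out
```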
- state_dict(destination=None, prefix='', keep_vars=False)
Return a dictionary containing the complete state of the ScalerBase.
This includes:
- All registered buffers and parameters (via parent class)
- Scaler-specific metadata (nbins, etc.)
- Solvent model state (if set)
Note: Data reference is NOT saved (managed separately).
- load_state_dict(state_dict, strict=True)
Load the ScalerBase state from a dictionary.
Note: This assumes data is already set via __init__ or set_data().