torchref.scaling.scaler_base module

Base scaler class for crystallographic scaling without model dependency.

This module provides ScalerBase, a class that implements all scaling functionality but does NOT maintain a reference to a Model object. All methods that require calculated structure factors (F_calc) take them as input arguments.

This allows the scaler to be used independently of any specific model implementation, making it suitable for use cases such as:

  • Molecular replacement, where F_calc comes from external sources

  • Testing and validation with precomputed structure factors

  • Custom model implementations

class torchref.scaling.scaler_base.ScalerBase(data=None, nbins=20, verbose=1, device=None)[source]

Bases: DeviceMixin, DebugMixin, Module

Base scaler class for crystallographic scaling without model dependency.

All methods that require calculated structure factors (F_calc) take them as input arguments. This allows the scaler to be used independently of any specific model implementation.

Supports two initialization patterns:

  1. Empty initialization (for state_dict loading):

    scaler = ScalerBase()  # Creates empty shell
    scaler.load_state_dict(torch.load('scaler.pt'))
    
  2. Full initialization with data:

    scaler = ScalerBase(data=reflection_data, nbins=20)
    scaler.initialize(fcalc)
    
Parameters:
  • data (ReflectionData, optional) – ReflectionData object with observed data.

  • nbins (int, default 20) – Number of resolution bins.

  • verbose (int, default 1) – Verbosity level.

  • device (torch.device, default: configured device.current) – Computation device.

device

Current computation device.

Type:

torch.device

nbins

Number of resolution bins.

Type:

int

__init__(data=None, nbins=20, verbose=1, device=None)[source]

Initialize ScalerBase.

If data is provided, fully initializes the scaler. If not provided (empty init), creates a shell ready for load_state_dict().

Parameters:
  • data (ReflectionData, optional) – ReflectionData object with observed data.

  • nbins (int, default 20) – Number of resolution bins.

  • verbose (int, default 1) – Verbosity level.

  • device (torch.device, optional) – Computation device. If None, derived from data (if given) or the configured default via torchref.utils.resolve_device(). An explicit value forces data onto that device.

set_data(data)[source]

Set data reference after empty initialization.

This is useful when loading from state_dict and then needing to reconnect to a data object.

Parameters:

data (ReflectionData) – ReflectionData object with observed data.

initialize(fcalc)[source]

Initialize scaling parameters using provided F_calc.

Parameters:

fcalc (torch.Tensor) – Calculated structure factors (complex).

property hkl

Get HKL indices from data.

calc_initial_scale(fcalc)[source]

Calculate the initial scale factor based on the ratio of observed to calculated structure factors.

Excludes reflections with negative intensities to avoid bias from French-Wilson conversion.

Parameters:

fcalc (torch.Tensor) – Calculated structure factors (complex).

Returns:

The log scale parameter for each resolution bin.

Return type:

torch.nn.Parameter
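
As a rough illustration of calc_initial_scale, the sketch below computes one log scale per resolution bin from the ratio of summed observed to summed calculated amplitudes, skipping reflections with negative measured intensity. The function name, the iobs and bin_idx arguments, and the exact ratio used are assumptions made for this example; the actual estimator inside ScalerBase may differ.

```python
import torch

def initial_log_scale(fobs, iobs, fcalc, bin_idx, nbins):
    """Hypothetical helper: per-bin log(sum F_obs / sum |F_calc|)."""
    keep = iobs >= 0                      # drop reflections with negative intensity
    idx = bin_idx[keep]
    sum_fo = torch.zeros(nbins).index_add_(0, idx, fobs[keep])
    sum_fc = torch.zeros(nbins).index_add_(0, idx, fcalc.abs()[keep])
    ratio = sum_fo / sum_fc.clamp_min(1e-12)
    # Clamp before the log so an empty bin gives a finite (very negative) value.
    return torch.nn.Parameter(torch.log(ratio.clamp_min(1e-12)))

fobs = torch.tensor([2.0, 4.0, 6.0, 8.0])
iobs = torch.tensor([4.0, 16.0, -1.0, 64.0])       # third reflection is excluded
fcalc = torch.tensor([1.0 + 0j, 2.0 + 0j, 100.0 + 0j, 4.0 + 0j])
bin_idx = torch.tensor([0, 0, 1, 1])
log_scale = initial_log_scale(fobs, iobs, fcalc, bin_idx, nbins=2)
```

In this toy data the excluded reflection has a wildly wrong F_calc, so dropping it leaves a ratio of 2 in both bins.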

setup_anisotropy_correction()[source]

Initialize anisotropic correction parameters.

anisotropy_correction()[source]

Compute anisotropic correction factors.

Returns:

Anisotropic correction factors for each reflection.

Return type:

torch.Tensor
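
The documentation does not state which parameterization anisotropy_correction uses. One common convention is the anisotropic Debye-Waller form exp(-2π² sᵀ U s), with s the Cartesian reciprocal-space vector of each reflection and U a symmetric 3x3 tensor. The sketch below assumes that convention; the function name and the six-parameter ordering are illustrative only.

```python
import math
import torch

def aniso_factor(s_vec, u6):
    """Hypothetical helper: exp(-2*pi^2 * s^T U s) for s_vec of shape (N, 3)."""
    u11, u22, u33, u12, u13, u23 = u6      # assumed ordering of the 6 unique terms
    U = torch.stack([
        torch.stack([u11, u12, u13]),
        torch.stack([u12, u22, u23]),
        torch.stack([u13, u23, u33]),
    ])
    quad = torch.einsum("ni,ij,nj->n", s_vec, U, s_vec)   # s^T U s per reflection
    return torch.exp(-2.0 * math.pi ** 2 * quad)

s_vec = torch.tensor([[0.1, 0.0, 0.0], [0.0, 0.2, 0.0]])
u_iso = torch.tensor([0.5, 0.5, 0.5, 0.0, 0.0, 0.0])      # isotropic special case
factors = aniso_factor(s_vec, u_iso)
identity = aniso_factor(s_vec, torch.zeros(6))            # U = 0 gives no correction
```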

fit_anisotropy(fcalc, nsteps=100)[source]

Fit anisotropic correction using provided F_calc.

Parameters:
  • fcalc (torch.Tensor) – Calculated structure factors (complex).

  • nsteps (int, default 100) – Number of optimization steps.

set_solvent_model(solvent_model)[source]

Set a pre-configured SolventModel for solvent contribution.

The SolventModel must be initialized externally (requires a Model object).

Parameters:

solvent_model (SolventModel) – Pre-configured solvent model that can compute solvent structure factors.

setup_binwise_solvent_scale()[source]

Setup bin-wise solvent scaling (Phenix-style kmask per bin).

This allows finer control over solvent contribution per resolution bin, which is more flexible than a single global B_sol parameter.

fit_all_scales(fcalc)[source]

Fit all scale parameters using provided F_calc.

Parameters:

fcalc (torch.Tensor) – Calculated structure factors (complex).

fit_simple(fobs, fcalc)[source]

Fit a single global scale factor analytically (least-squares).

This is the simple scaling approach:

k = sum(|F_obs||F_calc|) / sum(|F_calc|²)

Useful for rigid body refinement where only an overall scale is needed.

Parameters:
  • fobs (torch.Tensor) – Observed structure factor amplitudes.

  • fcalc (torch.Tensor) – Calculated structure factors (complex).
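
The closed-form expression for k given above can be checked with a few lines of PyTorch. This is a minimal sketch of the formula only; fit_global_scale is a hypothetical helper name, and how ScalerBase stores the result is not shown here.

```python
import torch

def fit_global_scale(fobs, fcalc):
    """Least-squares k minimizing sum((|F_obs| - k * |F_calc|)^2)."""
    fc_amp = fcalc.abs()                  # amplitudes of the complex F_calc
    fo_amp = fobs.abs()
    return (fo_amp * fc_amp).sum() / (fc_amp ** 2).sum()

# Toy data: F_obs is exactly 2.5 times |F_calc|, so the fit should recover 2.5.
fcalc = torch.tensor([3.0 + 4.0j, 1.0 - 1.0j, 0.0 + 2.0j])
fobs = 2.5 * fcalc.abs()
k = fit_global_scale(fobs, fcalc)
```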

get_scale()[source]

Get the current overall scale factor value.

Returns the mean scale factor across all bins.

Returns:

Current scale factor (not log).

Return type:

float

rfactor(fcalc)[source]

Calculate the R-factor between observed and calculated structure factors.

Parameters:

fcalc (torch.Tensor) – Calculated structure factors (complex).

Returns:

R-work and R-free values.

Return type:

tuple
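
For reference, the conventional R-factor is sum(| |F_obs| - |F_scaled| |) / sum(|F_obs|), evaluated separately on the work and free sets. The sketch below assumes already-scaled structure factors and a boolean free-set mask; both the helper name and the mask are hypothetical, since the documentation does not describe how ScalerBase obtains the flags from its data object.

```python
import torch

def r_factor(fobs, fcalc_scaled, mask):
    """R = sum| F_obs - |F_calc| | / sum F_obs over the selected reflections."""
    fo = fobs[mask]
    fc = fcalc_scaled[mask].abs()
    return ((fo - fc).abs().sum() / fo.sum()).item()

fobs = torch.tensor([10.0, 20.0, 30.0, 40.0])
fcalc_scaled = torch.tensor([9.0 + 0j, 22.0 + 0j, 30.0 + 0j, 36.0 + 0j])
free = torch.tensor([False, False, False, True])   # hypothetical free-set flags
rwork = r_factor(fobs, fcalc_scaled, ~free)        # (1 + 2 + 0) / 60
rfree = r_factor(fobs, fcalc_scaled, free)         # 4 / 40
```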

bin_wise_rfactor(fcalc)[source]

Calculate the bin-wise R-factor between observed and calculated structure factors.

Parameters:

fcalc (torch.Tensor) – Calculated structure factors (complex).

Returns:

  • mean_res_per_bin (torch.Tensor) – Mean resolution per bin.

  • rwork_per_bin (torch.Tensor) – R-work per bin.

  • rfree_per_bin (torch.Tensor) – R-free per bin.
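
Per-bin statistics of this kind can be accumulated without a Python loop by scattering into a length-nbins tensor. The sketch below assumes each reflection already carries an integer bin index (bin_idx is a hypothetical input; how ScalerBase assigns bins is not described here) and shows the accumulation for a single reflection set.

```python
import torch

def binwise_r(fobs, fcalc_amp, bin_idx, nbins):
    """Per-bin sum|F_obs - F_calc| / sum F_obs, accumulated with index_add_."""
    num = torch.zeros(nbins).index_add_(0, bin_idx, (fobs - fcalc_amp).abs())
    den = torch.zeros(nbins).index_add_(0, bin_idx, fobs)
    return num / den.clamp_min(1e-12)     # empty bins yield 0 instead of NaN

fobs = torch.tensor([10.0, 20.0, 30.0, 40.0])
fcalc_amp = torch.tensor([9.0, 22.0, 30.0, 36.0])
bin_idx = torch.tensor([0, 0, 1, 1])
r_per_bin = binwise_r(fobs, fcalc_amp, bin_idx, nbins=2)
```

Calling the helper once with work reflections and once with free reflections would give the two per-bin curves.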

setup_bin_wise_bfactor()[source]

Initialize bin-wise B-factor correction parameters.

bin_wise_bfactor_correction()[source]

Compute bin-wise B-factor correction factors.

Returns:

B-factor correction factors for each reflection.

Return type:

torch.Tensor
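
A per-bin B-factor correction is typically the isotropic Debye-Waller term exp(-B / (4 d²)) with B looked up from each reflection's bin. The sketch below assumes that form; b_per_bin, bin_idx, and d_spacing are illustrative names, not ScalerBase attributes.

```python
import torch

def binwise_b_correction(b_per_bin, bin_idx, d_spacing):
    """Hypothetical helper: exp(-B_bin / (4 d^2)) for every reflection."""
    b = b_per_bin[bin_idx]                 # gather the B value of each reflection's bin
    return torch.exp(-b / (4.0 * d_spacing ** 2))

b_per_bin = torch.tensor([0.0, 20.0])      # B in square Angstroms, one per bin
bin_idx = torch.tensor([0, 1, 1])
d_spacing = torch.tensor([4.0, 2.0, 1.0])  # Angstroms
correction = binwise_b_correction(b_per_bin, bin_idx, d_spacing)
```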

get_binwise_mean_intensity(fcalc)[source]

Get bin-wise mean intensities for observed and calculated structure factors.

Parameters:

fcalc (torch.Tensor) – Calculated structure factors (complex).

Returns:

Mean observed intensity, mean calculated intensity, and mean resolution per bin.

Return type:

tuple

screen_solvent_params(fcalc, steps=15, use_low_res_weighting=True, low_res_cutoff=5.0, fit_on_low_res_only=True, low_res_limit=3.5)[source]

Screen solvent parameters (k_sol, B_sol) using grid search.

The bulk solvent contributes primarily at low resolution. Fitting on low-resolution reflections only (fit_on_low_res_only=True) prevents high-resolution reflections from dominating the optimization and pushing B_sol too low.

Parameters:
  • fcalc (torch.Tensor) – Calculated structure factors (complex).

  • steps (int, default 15) – Number of grid points for each parameter.

  • use_low_res_weighting (bool, default True) – If True, weight low-resolution reflections more heavily since solvent primarily contributes at low resolution.

  • low_res_cutoff (float, default 5.0) – Resolution cutoff for weighting in Angstroms.

  • fit_on_low_res_only (bool, default True) – If True, fit using only low-resolution reflections.

  • low_res_limit (float, default 3.5) – Resolution limit for low-res only fitting in Angstroms.
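
The grid search described for screen_solvent_params can be sketched as follows, assuming the flat bulk-solvent model F_model = F_calc + k_sol * exp(-B_sol / (4 d²)) * F_mask and an analytically fitted overall scale at each grid point. The function name, the grid ranges, and the least-squares target are assumptions for illustration; the low-resolution weighting option is omitted.

```python
import torch

def screen_solvent(fobs, fcalc, f_mask, d, steps=15, low_res_limit=3.5):
    """Exhaustive (k_sol, B_sol) grid search on reflections with d >= low_res_limit."""
    sel = d >= low_res_limit
    fo, fc, fm = fobs[sel], fcalc[sel], f_mask[sel]
    s2_over_4 = 1.0 / (4.0 * d[sel] ** 2)
    best = (float("inf"), None, None)
    # Grid ranges are illustrative assumptions, not torchref defaults.
    for k_sol in torch.linspace(0.0, 0.6, steps):
        for b_sol in torch.linspace(10.0, 100.0, steps):
            f_model = (fc + k_sol * torch.exp(-b_sol * s2_over_4) * fm).abs()
            k = (fo * f_model).sum() / (f_model ** 2).sum()   # analytic overall scale
            loss = ((fo - k * f_model) ** 2).sum().item()
            if loss < best[0]:
                best = (loss, k_sol.item(), b_sol.item())
    return best

# Synthetic data generated with k_sol = 0.3 and B_sol = 55, both on the grid.
torch.manual_seed(0)
n = 200
d = torch.linspace(2.0, 20.0, n)
fcalc = torch.complex(torch.randn(n), torch.randn(n))
f_mask = torch.complex(torch.randn(n), torch.randn(n))
fobs = 2.0 * (fcalc + 0.3 * torch.exp(-55.0 / (4.0 * d ** 2)) * f_mask).abs()
loss, k_sol, b_sol = screen_solvent(fobs, fcalc, f_mask, d)
```

Restricting the fit to d >= low_res_limit keeps the many high-resolution reflections, where the solvent term is nearly zero, from dominating the residual.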

refine_lbfgs(fcalc, nsteps=3, lr=1.0, max_iter=200, history_size=10, verbose=True)[source]

Refine scale parameters using LBFGS optimizer.

This method optimizes the anisotropic scaling and B-factor parameters that relate calculated structure factors to observed structure factors. Uses the L-BFGS quasi-Newton optimization method for fast convergence.

Parameters:
  • fcalc (torch.Tensor) – Calculated structure factors (complex).

  • nsteps (int, default 3) – Number of LBFGS steps.

  • lr (float, default 1.0) – Learning rate (typically 1.0 for LBFGS).

  • max_iter (int, default 200) – Maximum iterations per line search.

  • history_size (int, default 10) – Number of previous gradients to store for Hessian approximation.

  • verbose (bool, default True) – Print progress information.

Returns:

Dictionary with refinement metrics including steps, xray_work, xray_test, rwork, rfree.

Return type:

dict
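
The parameter names above mirror those of torch.optim.LBFGS, which requires a closure that re-evaluates the loss. The sketch below is a toy fit of an overall scale and isotropic B-factor to synthetic amplitudes; the loss, the parameters, and the choice of the strong Wolfe line search are assumptions for illustration, not a description of what refine_lbfgs optimizes internally.

```python
import torch

torch.manual_seed(0)
d = torch.linspace(1.5, 10.0, 100, dtype=torch.float64)     # d-spacings in Angstroms
fc_amp = torch.rand(100, dtype=torch.float64) + 0.5         # stand-in for |F_calc|
fobs = 3.0 * torch.exp(-12.0 / (4.0 * d ** 2)) * fc_amp     # truth: k = 3, B = 12

log_k = torch.nn.Parameter(torch.zeros(1, dtype=torch.float64))
b_iso = torch.nn.Parameter(torch.zeros(1, dtype=torch.float64))
opt = torch.optim.LBFGS([log_k, b_iso], lr=1.0, max_iter=200,
                        history_size=10, line_search_fn="strong_wolfe")

def closure():
    # LBFGS calls this repeatedly, so gradients are reset on every evaluation.
    opt.zero_grad()
    pred = torch.exp(log_k - b_iso / (4.0 * d ** 2)) * fc_amp
    loss = ((fobs - pred) ** 2).mean()
    loss.backward()
    return loss

for _ in range(3):           # outer steps; each may run up to max_iter inner iterations
    opt.step(closure)

k_fit = log_k.exp().item()
b_fit = b_iso.item()
```

Optimizing log k rather than k keeps the scale positive without an explicit constraint.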

estimate_sigma_eff(fcalc, max_inflation=2.0)[source]

Estimate per-resolution-shell effective sigmas from current residuals.

Pannu & Read / SIGMAA-style correction: detects miscalibrated experimental sigmas by comparing residual variance to the claimed variance, per resolution bin.

For each resolution bin:

    D_bin        = < (F_obs - k * |F_calc|)^2 >    (using work set)
    ratio_bin    = sqrt(D_bin / <sigma_F^2>)
    ratio_capped = clamp(ratio_bin, 1.0, max_inflation)
    sigma_eff    = sigma_F * ratio_capped

Why the cap? At the start of refinement the model is bad, so residuals are dominated by model error (which is fixable by refining), not noise. Uncapped inflation creates a vicious cycle: bad model -> huge sigma_eff -> weak data gradient -> bad model. Capping at max_inflation (default 2.0, i.e. sigmas can grow at most 2x) prevents runaway while still correcting genuinely under-estimated sigmas.

As the model improves, residuals shrink and the ratio drops toward 1, so sigma_eff converges to the raw sigma (good calibration).

Uses the work set only so the test set doesn’t leak into sigma estimation.

Parameters:
  • fcalc (torch.Tensor) – Calculated structure factors (complex, unscaled).

  • max_inflation (float, optional) – Maximum allowed ratio sigma_eff / sigma_raw. Default 2.0.

Returns:

Per-reflection effective sigmas, shape (N,).

Return type:

torch.Tensor
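
The four-step recipe for sigma_eff given above translates fairly directly into PyTorch. The sketch below follows those steps under the assumption that an overall scale k, per-reflection bin indices, and a boolean work-set mask are available; the function signature is hypothetical and differs from estimate_sigma_eff, which takes only fcalc and max_inflation.

```python
import torch

def sigma_eff(fobs, fcalc_amp, sigma_f, k, bin_idx, nbins, work, max_inflation=2.0):
    """Inflate sigmas per bin by the capped ratio sqrt(<resid^2> / <sigma^2>)."""
    resid2 = (fobs - k * fcalc_amp) ** 2
    idx = bin_idx[work]                                   # work-set reflections only
    count = torch.zeros(nbins).index_add_(0, idx, torch.ones_like(resid2[work]))
    count = count.clamp_min(1.0)                          # avoid division by zero
    d_bin = torch.zeros(nbins).index_add_(0, idx, resid2[work]) / count
    var_bin = torch.zeros(nbins).index_add_(0, idx, sigma_f[work] ** 2) / count
    ratio = torch.sqrt(d_bin / var_bin.clamp_min(1e-12)).clamp(1.0, max_inflation)
    return sigma_f * ratio[bin_idx]                       # applied to all reflections

fobs = torch.tensor([11.0, 9.0, 25.0, 15.0, 30.0])
fcalc_amp = torch.tensor([10.0, 10.0, 20.0, 20.0, 20.0])
sigma_f = torch.ones(5)
bin_idx = torch.tensor([0, 0, 1, 1, 1])
work = torch.tensor([True, True, True, True, False])      # last reflection is test set
sig = sigma_eff(fobs, fcalc_amp, sigma_f, 1.0, bin_idx, 2, work)
```

In this toy case bin 0 has a ratio of 1 (well-calibrated sigmas) and bin 1 has a raw ratio of 5, which the cap reduces to 2. The test-set reflection in bin 1 receives the inflated sigma but does not contribute to the estimate.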

forward(fcalc, use_mask=True, f_sol_override=None)[source]

Forward pass for the ScalerBase module.

Parameters:
  • fcalc (torch.Tensor) – Calculated structure factors. Expected shape (N,); an additional batch dimension is allowed. N should match the full HKL size.

  • use_mask (bool, default True) – Deprecated parameter, kept for backward compatibility.

  • f_sol_override (torch.Tensor, optional) – Pre-computed raw solvent structure factors. When provided, these replace the internally-cached _f_sol_raw. The scaler’s k_sol / B_sol / phase damping is still applied. This is used by CollectionScaler to supply mixed (fraction-weighted) solvent contributions.

Returns:

Scaled structure factors of same shape as input.

Return type:

torch.Tensor

state_dict(destination=None, prefix='', keep_vars=False)[source]

Return a dictionary containing the complete state of the ScalerBase.

This includes:

  • All registered buffers and parameters (via parent class)

  • Scaler-specific metadata (nbins, etc.)

  • Solvent model state (if set)

Note: Data reference is NOT saved (managed separately).

Parameters:
  • destination (dict, optional) – Optional dict to populate.

  • prefix (str, default '') – Prefix for parameter names.

  • keep_vars (bool, default False) – If False, returned tensors are detached from autograd; if True, they stay attached to the computational graph.

Returns:

Complete state dictionary.

Return type:

dict

load_state_dict(state_dict, strict=True)[source]

Load the ScalerBase state from a dictionary.

Note: This assumes data is already set via __init__ or set_data().

Parameters:
  • state_dict (dict) – Dictionary containing scaler state.

  • strict (bool, default True) – Whether to strictly enforce that keys match.

save_state(path)[source]

Save the complete state of the scaler to a file.

Parameters:

path (str) – Path to save the state dictionary to.

load_state(path, strict=True)[source]

Load the complete state of the scaler from a file.

Parameters:
  • path (str) – Path to load the state dictionary from.

  • strict (bool, default True) – Whether to strictly enforce that keys match.
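
The save and load methods above follow the usual PyTorch state-dict round trip. The sketch below shows that pattern on a small stand-in module that stores one piece of metadata next to its parameters, using PyTorch's get_extra_state and set_extra_state hooks. TinyScaler is hypothetical and is not the torchref class; ScalerBase overrides state_dict directly, and its key layout is not documented here.

```python
import io
import torch

class TinyScaler(torch.nn.Module):
    """Hypothetical stand-in for a scaler; not the torchref implementation."""

    def __init__(self, nbins=20):
        super().__init__()
        self.nbins = nbins
        self.log_scale = torch.nn.Parameter(torch.zeros(nbins))

    def get_extra_state(self):
        # Non-tensor metadata stored alongside parameters in state_dict().
        return {"nbins": self.nbins}

    def set_extra_state(self, state):
        self.nbins = state["nbins"]

src = TinyScaler(nbins=20)
with torch.no_grad():
    src.log_scale.fill_(0.7)

buf = io.BytesIO()                   # in-memory file; a path string works the same way
torch.save(src.state_dict(), buf)
buf.seek(0)

dst = TinyScaler()                   # empty shell, then restore
dst.load_state_dict(torch.load(buf))
```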