torchref.scaling.scaler module

A class for scaling and post-correction of structure factors.

Currently implements:

  • Overall scale per resolution bin

  • B-factor per resolution bin

  • Anisotropy correction

  • Solvent model correction
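The first two corrections are commonly written as a per-bin factor k_bin · exp(-B_bin · s²/4), where s = 1/d is the reciprocal resolution. The NumPy sketch below applies that factor to |F_calc|. The helper name and the parameterization (log-scale per bin, B in Å², a precomputed bin index per reflection) are assumptions for illustration, not necessarily how torchref stores these parameters.

```python
import numpy as np

def apply_bin_scale(f_calc, d_spacing, bin_idx, log_k, b_bin):
    """Scale |F_calc| by k_bin * exp(-B_bin * s^2 / 4), with s = 1/d.

    Hypothetical helper (assumed parameterization): bin_idx maps each
    reflection to a bin; log_k and b_bin hold one value per bin.
    """
    s2 = 1.0 / d_spacing**2                          # s^2 = 1/d^2 in 1/Angstrom^2
    k = np.exp(log_k[bin_idx])                       # per-reflection overall scale
    debye_waller = np.exp(-b_bin[bin_idx] * s2 / 4.0)  # per-reflection B attenuation
    return k * debye_waller * np.abs(f_calc)

# Two bins: the low-resolution bin is scaled by 2, the high-resolution bin damped by B = 10.
scaled = apply_bin_scale(
    np.array([100.0, 80.0, 50.0, 20.0]),
    np.array([8.0, 4.0, 2.0, 1.5]),
    np.array([0, 0, 1, 1]),
    np.log(np.array([2.0, 1.0])),
    np.array([0.0, 10.0]),
)
```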

This module provides the full-featured Scaler class that maintains a reference to a Model object. For a model-independent scaler, see ScalerBase.

class torchref.scaling.scaler.Scaler(model=None, data=None, nbins=20, verbose=1, device=None)

Bases: ScalerBase

Full-featured scaler with Model integration.

Extends ScalerBase by maintaining a reference to a Model object and providing convenience methods that automatically compute F_calc when not provided.

Supports two initialization patterns:

  1. Empty initialization (for state_dict loading):

    import torch
    from torchref.scaling.scaler import Scaler
    scaler = Scaler()  # Creates an empty shell
    scaler.load_state_dict(torch.load('scaler.pt'))
    
  2. Full initialization with model and data:

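    from torchref.scaling.scaler import Scaler
    # model (Model) and reflection_data (ReflectionData) are created beforehand
    scaler = Scaler(model, reflection_data, nbins=20)
    scaler.initialize()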
    
Parameters:
  • model (Model, optional) – Model object for structure factor calculation.

  • data (ReflectionData, optional) – ReflectionData object with observed data.

  • nbins (int, default 20) – Number of resolution bins.

  • verbose (int, default 1) – Verbosity level.

  • device (torch.device, default: configured device.current) – Computation device.

device

Current computation device.

Type:

torch.device

nbins

Number of resolution bins.

Type:

int

__init__(model=None, data=None, nbins=20, verbose=1, device=None)

Initialize Scaler.

If model and data are provided, fully initializes the scaler. Otherwise (empty initialization), creates an empty shell ready for load_state_dict().

Parameters:
  • model (Model, optional) – Model object for structure factor calculation.

  • data (ReflectionData, optional) – ReflectionData object with observed data.

  • nbins (int, default 20) – Number of resolution bins.

  • verbose (int, default 1) – Verbosity level.

  • device (torch.device, optional) – Computation device. If None, the device is taken from model, falling back to data; when the two disagree, the model's device wins. If a device is given, both model and data are moved onto it. See torchref.utils.resolve_device().
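The device rule can be read as a short decision procedure. The sketch below is a hypothetical plain-Python restatement: pick_device is not part of torchref, and devices are plain strings here so the logic stands alone. Consult torchref.utils.resolve_device() for the actual behavior.

```python
from typing import Optional

def pick_device(explicit: Optional[str], model_device: Optional[str],
                data_device: Optional[str]) -> Optional[str]:
    """Hypothetical restatement of the documented device rule (not torchref code)."""
    if explicit is not None:
        return explicit        # explicit device wins; model and data would be moved onto it
    if model_device is not None:
        return model_device    # model takes precedence, even if data sits elsewhere
    return data_device         # fall back to data; None if neither object is given

print(pick_device(None, "cuda:0", "cpu"))
```

When neither model nor data is given, the class-level parameter list names the configured device.current as the default; this sketch simply returns None in that case.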

property model

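Access the model object. It is held as a plain reference rather than a registered submodule, so its parameters are not included in the scaler's parameters() or state_dict().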

set_model_and_data(model, data)

Set model and data references after empty initialization.

This is useful when loading from state_dict and then needing to reconnect to model/data objects.

Parameters:
  • model (Model) – Model object for structure factor calculation.

  • data (ReflectionData) – ReflectionData object with observed data.

initialize(fcalc=None)

Initialize scaling parameters.

If fcalc is not provided, computes it from the internal model.

Parameters:

fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.

compute_fcalc()

Compute F_calc from internal model.

Returns:

Calculated structure factors.

Return type:

torch.Tensor

Raises:

RuntimeError – If no model is set.

calc_initial_scale(fcalc=None)

Calculate initial scale factors.

If fcalc is not provided, computes it from the internal model.

Parameters:

fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.

Returns:

The log scale parameter for each resolution bin.

Return type:

torch.nn.Parameter
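A common way to obtain an initial per-bin scale is the closed-form least-squares solution k = Σ|F_obs||F_calc| / Σ|F_calc|², computed after splitting reflections into resolution bins of roughly equal size. The sketch below returns log k per bin to match the documented return value. Equal-count binning, the least-squares estimator, and both helper names are assumptions; calc_initial_scale may bin or estimate differently.

```python
import numpy as np

def equal_count_bins(d_spacing, nbins):
    """Assumed binning: nbins bins of roughly equal count, lowest resolution first."""
    order = np.argsort(-d_spacing)                 # largest d (lowest resolution) first
    bin_idx = np.empty(d_spacing.size, dtype=int)
    bin_idx[order] = np.arange(d_spacing.size) * nbins // d_spacing.size
    return bin_idx

def initial_log_scale(f_obs, f_calc, bin_idx, nbins):
    """Least-squares log scale per bin: log(sum(|Fo||Fc|) / sum(|Fc|^2)).

    Assumes every bin contains at least one reflection.
    """
    fo, fc = np.abs(f_obs), np.abs(f_calc)
    num = np.bincount(bin_idx, weights=fo * fc, minlength=nbins)
    den = np.bincount(bin_idx, weights=fc * fc, minlength=nbins)
    return np.log(num / den)
```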

fit_anisotropy(fcalc=None, nsteps=100)

Fit anisotropic correction.

If fcalc is not provided, computes it from the internal model.

Parameters:
  • fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.

  • nsteps (int, default 100) – Number of optimization steps.
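fit_anisotropy runs an iterative optimization for nsteps steps. As a compact illustration of what an anisotropic correction models, the sketch below swaps in a closed-form alternative: because ln(|F_obs|/|F_calc|) = ln k - (1/4) sᵀ B s is linear in ln k and the six unique components of a symmetric B tensor, an estimate can be obtained by linear least squares. The Cartesian reciprocal-space vectors s and this parameterization are assumptions, not torchref's internal representation.

```python
import numpy as np

def fit_aniso_lstsq(s_vec, f_obs, f_calc):
    """Fit k and a symmetric B tensor (Angstrom^2) by linear least squares.

    Swapped-in technique (assumption): s_vec is (N, 3) Cartesian reciprocal-space
    vectors in 1/Angstrom; model ln(|Fo|/|Fc|) = ln k - 1/4 * s^T B s.
    """
    sx, sy, sz = s_vec.T
    # Columns: ln k, then B11, B22, B33, B12, B13, B23.
    A = np.column_stack([
        np.ones_like(sx),
        -sx * sx / 4, -sy * sy / 4, -sz * sz / 4,
        -sx * sy / 2, -sx * sz / 2, -sy * sz / 2,  # off-diagonals appear twice in s^T B s
    ])
    y = np.log(np.abs(f_obs) / np.abs(f_calc))
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    b11, b22, b33, b12, b13, b23 = coef[1:]
    B = np.array([[b11, b12, b13], [b12, b22, b23], [b13, b23, b33]])
    return np.exp(coef[0]), B
```

A closed-form estimate like this is often a reasonable starting point for an iterative fit, since the iterative route can add weighting or restraints that a plain log-ratio fit cannot.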

setup_solvent()

Set up the solvent model using the internal model.

Creates a SolventModel using the internal model reference.
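A widely used solvent correction is the flat bulk-solvent model, F_model = F_calc + k_sol · exp(-B_sol · s²/4) · F_mask, where F_mask is the structure factor of a solvent mask. The sketch below assumes that form; whether SolventModel uses exactly this parameterization is not stated here, and add_bulk_solvent is a hypothetical name.

```python
import numpy as np

def add_bulk_solvent(f_calc, f_mask, d_spacing, k_sol, b_sol):
    """Assumed flat bulk-solvent model: F_calc + k_sol * exp(-B_sol * s^2 / 4) * F_mask.

    f_calc and f_mask are complex structure factors; s^2 = 1/d^2.
    """
    s2 = 1.0 / d_spacing**2
    return f_calc + k_sol * np.exp(-b_sol * s2 / 4.0) * f_mask

# Same F_calc and F_mask at 10 A and 1 A: the solvent term matters only at low resolution.
out = add_bulk_solvent(np.array([10 + 0j, 10 + 0j]), np.array([-20 + 0j, -20 + 0j]),
                       np.array([10.0, 1.0]), k_sol=0.35, b_sol=50.0)
```

Because exp(-B_sol · s²/4) decays quickly with s, the solvent term mainly affects low-resolution reflections.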

fit_all_scales(fcalc=None)

Fit all scale parameters.

If fcalc is not provided, computes it from the internal model.

Parameters:

fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.

screen_solvent_params(fcalc=None, steps=15, use_low_res_weighting=True, low_res_cutoff=5.0, fit_on_low_res_only=True, low_res_limit=3.5)

Screen solvent parameters using grid search.

If fcalc is not provided, computes it from the internal model.

Parameters:
  • fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.

  • steps (int, default 15) – Number of grid points for each parameter.

  • use_low_res_weighting (bool, default True) – If True, weight low-resolution reflections more heavily.

  • low_res_cutoff (float, default 5.0) – Resolution cutoff, in Angstroms, for the low-resolution weighting.

  • fit_on_low_res_only (bool, default True) – If True, fit using only low-resolution reflections.

  • low_res_limit (float, default 3.5) – Resolution limit, in Angstroms, for low-resolution-only fitting.
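The grid search can be sketched as follows: evaluate an R-factor for every (k_sol, B_sol) pair on a steps × steps grid and keep the best pair. The parameter names mirror the documented ones, but the grid ranges, the factor-of-2 weight for reflections with d at or above low_res_cutoff, the per-candidate least-squares scale, and the flat bulk-solvent form are all hypothetical choices for illustration.

```python
import numpy as np

def screen_solvent(f_obs, f_calc, f_mask, d, steps=15, use_low_res_weighting=True,
                   low_res_cutoff=5.0, fit_on_low_res_only=True, low_res_limit=3.5):
    """Grid search over (k_sol, B_sol); returns (best_r, k_sol, b_sol)."""
    # Hypothetical search ranges; the real method may use others.
    k_grid = np.linspace(0.1, 0.6, steps)
    b_grid = np.linspace(10.0, 100.0, steps)
    # Optionally restrict the fit to reflections with d >= low_res_limit.
    sel = (d >= low_res_limit) if fit_on_low_res_only else np.ones(d.size, dtype=bool)
    # Hypothetical weighting: double weight for reflections with d >= low_res_cutoff.
    w = np.where(d >= low_res_cutoff, 2.0, 1.0) if use_low_res_weighting else np.ones(d.size)
    fo, fc, fm, w = np.abs(f_obs[sel]), f_calc[sel], f_mask[sel], w[sel]
    s2 = 1.0 / d[sel] ** 2
    best = (np.inf, None, None)
    for k_sol in k_grid:
        for b_sol in b_grid:
            f_model = np.abs(fc + k_sol * np.exp(-b_sol * s2 / 4.0) * fm)
            # Weighted least-squares overall scale, then weighted R-factor.
            k = np.sum(w * fo * f_model) / np.sum(w * f_model**2)
            r = np.sum(w * np.abs(fo - k * f_model)) / np.sum(w * fo)
            if r < best[0]:
                best = (r, k_sol, b_sol)
    return best
```

Refitting the overall scale for each candidate keeps the search focused on the shape of the solvent contribution rather than its absolute magnitude.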

refine_lbfgs(fcalc=None, nsteps=3, lr=1.0, max_iter=200, history_size=10, verbose=True)

Refine scale parameters using the LBFGS optimizer.

If fcalc is not provided, computes it from the internal model.

Parameters:
  • fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.

  • nsteps (int, default 3) – Number of LBFGS steps.

  • lr (float, default 1.0) – Learning rate (typically 1.0 for LBFGS).

  • max_iter (int, default 200) – Maximum number of iterations per LBFGS step.

  • history_size (int, default 10) – Number of past updates stored for the inverse-Hessian approximation.

  • verbose (bool, default True) – Print progress information.

Returns:

Dictionary with refinement metrics.

Return type:

dict
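A minimal sketch of an outer LBFGS loop, assuming refine_lbfgs wraps torch.optim.LBFGS. The names lr, max_iter and history_size match that optimizer's arguments, but the wrapping itself is an assumption. In torch.optim.LBFGS, each of the nsteps calls to optimizer.step(closure) runs up to max_iter inner iterations. The toy objective (a least-squares fit of a per-bin log scale), the strong-Wolfe line search, and the returned dictionary keys are hypothetical choices.

```python
import torch

def refine_log_scale_lbfgs(f_obs, f_calc, bin_idx, nbins, nsteps=3, lr=1.0,
                           max_iter=200, history_size=10):
    """Toy objective (assumption): minimize sum((|Fo| - exp(log_k[bin]) * |Fc|)^2)."""
    log_k = torch.zeros(nbins, dtype=torch.float64, requires_grad=True)
    # line_search_fn="strong_wolfe" is a choice made for this sketch.
    optimizer = torch.optim.LBFGS([log_k], lr=lr, max_iter=max_iter,
                                  history_size=history_size,
                                  line_search_fn="strong_wolfe")

    def loss_fn():
        residual = f_obs.abs() - torch.exp(log_k[bin_idx]) * f_calc.abs()
        return (residual ** 2).sum()

    def closure():
        # LBFGS may call this several times per step (inner iterations, line search).
        optimizer.zero_grad()
        loss = loss_fn()
        loss.backward()
        return loss

    for _ in range(nsteps):
        optimizer.step(closure)  # up to max_iter inner iterations per call
    with torch.no_grad():
        final_loss = loss_fn().item()
    return {"loss": final_loss, "log_k": log_k.detach().clone()}
```

A learning rate of 1.0 is the natural choice here because the quasi-Newton direction already carries a step length.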

rfactor(fcalc=None)

Calculate R-factors.

If fcalc is not provided, computes it from the internal model.

Parameters:

fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.

Returns:

R-work and R-free values.

Return type:

tuple
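The conventional crystallographic R-factor is R = Σ| |F_obs| - |F_model| | / Σ|F_obs|, evaluated separately over the working and free (test) reflections. The sketch below assumes the model amplitudes are already scaled and that a boolean free_flag array marks the test set. Both helper names and free_flag are hypothetical.

```python
import numpy as np

def r_factor(f_obs, f_model):
    """R = sum(| |Fo| - |Fmodel| |) / sum(|Fo|)."""
    fo, fm = np.abs(f_obs), np.abs(f_model)
    return np.sum(np.abs(fo - fm)) / np.sum(fo)

def r_work_free(f_obs, f_model, free_flag):
    """Return (R-work, R-free); free_flag (assumed) is True for test-set reflections."""
    work = ~free_flag
    return (r_factor(f_obs[work], f_model[work]),
            r_factor(f_obs[free_flag], f_model[free_flag]))
```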

bin_wise_rfactor(fcalc=None)

Calculate bin-wise R-factors.

If fcalc is not provided, computes it from the internal model.

Parameters:

fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.

Returns:

  • mean_res_per_bin (torch.Tensor) – Mean resolution per bin.

  • rwork_per_bin (torch.Tensor) – R-work per bin.

  • rfree_per_bin (torch.Tensor) – R-free per bin.
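Bin-wise R-factors apply the same ratio within each resolution bin. The NumPy sketch below produces the three documented outputs using np.bincount. It assumes a precomputed bin index per reflection, already scaled model amplitudes, and a boolean free-set flag; all names are hypothetical.

```python
import numpy as np

def bin_wise_r(f_obs, f_model, d, bin_idx, free_flag, nbins):
    """Return (mean_res_per_bin, rwork_per_bin, rfree_per_bin).

    Assumes every bin holds at least one work and one free reflection.
    """
    fo = np.abs(f_obs)
    diff = np.abs(fo - np.abs(f_model))
    counts = np.bincount(bin_idx, minlength=nbins)
    mean_res = np.bincount(bin_idx, weights=d, minlength=nbins) / counts

    def r_in_bins(mask):
        # Numerator and denominator of R summed per bin over the selected reflections.
        num = np.bincount(bin_idx[mask], weights=diff[mask], minlength=nbins)
        den = np.bincount(bin_idx[mask], weights=fo[mask], minlength=nbins)
        return num / den

    return mean_res, r_in_bins(~free_flag), r_in_bins(free_flag)
```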

get_binwise_mean_intensity(fcalc=None)

Get bin-wise mean intensities.

If fcalc is not provided, computes it from the internal model.

Parameters:

fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.

Returns:

Mean observed intensity, mean calculated intensity, and mean resolution per bin.

Return type:

tuple
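Bin-wise mean intensities follow from I = |F|². The sketch below returns the three documented quantities in the documented order. It assumes amplitudes (not intensities) as input and a precomputed bin index; the function name is hypothetical.

```python
import numpy as np

def binwise_mean_intensity(f_obs, f_calc, d, bin_idx, nbins):
    """Return (mean I_obs, mean I_calc, mean resolution) per bin, with I = |F|^2.

    Assumes amplitude inputs and non-empty bins.
    """
    counts = np.bincount(bin_idx, minlength=nbins)

    def bin_mean(values):
        return np.bincount(bin_idx, weights=values, minlength=nbins) / counts

    return bin_mean(np.abs(f_obs) ** 2), bin_mean(np.abs(f_calc) ** 2), bin_mean(d)
```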

state_dict(destination=None, prefix='', keep_vars=False)

Return a dictionary containing the complete state of the Scaler.

This includes:

  • All registered buffers and parameters (via parent class)

  • Scaler-specific metadata (nbins, etc.)

  • Solvent model state (if initialized)

Note: Model and data references are NOT saved (managed separately).

Parameters:
  • destination (dict, optional) – Optional dict to populate.

  • prefix (str, default '') – Prefix for parameter names.

  • keep_vars (bool, default False) – If True, tensors in the returned dictionary are not detached from the autograd graph.

Returns:

Complete state dictionary.

Return type:

dict

load_state_dict(state_dict, strict=True)

Load the Scaler state from a dictionary.

Note: This assumes model and data are already set via __init__ or assignment.

Parameters:
  • state_dict (dict) – Dictionary containing scaler state.

  • strict (bool, default True) – Whether to strictly enforce that keys match.
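The save/load behavior described for state_dict and load_state_dict (parameters and buffers plus scalar metadata such as nbins, with model and data references excluded) can be sketched with a toy torch.nn.Module. ToyScaler, its _model/_data attributes, and the "_meta.nbins" key are hypothetical, and torchref's actual key names may differ. The sketch only handles a top-level module, not one nested inside another module.

```python
import torch
from torch import nn

class ToyScaler(nn.Module):
    """Hypothetical stand-in: saves parameters and nbins, never the model or data."""

    def __init__(self, nbins=20):
        super().__init__()
        self.nbins = nbins
        self.log_scale = nn.Parameter(torch.zeros(nbins))
        # Bypass nn.Module registration so model/data stay plain references.
        object.__setattr__(self, "_model", None)
        object.__setattr__(self, "_data", None)

    def set_model_and_data(self, model, data):
        object.__setattr__(self, "_model", model)
        object.__setattr__(self, "_data", data)

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        state = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
        state[prefix + "_meta.nbins"] = self.nbins  # scalar metadata (hypothetical key)
        return state

    def load_state_dict(self, state_dict, strict=True):
        state = dict(state_dict)  # copy so the caller's dictionary is not modified
        nbins = state.pop("_meta.nbins", self.nbins)
        if nbins != self.nbins:
            # Resize the parameter so shapes match before copying tensors.
            self.nbins = nbins
            self.log_scale = nn.Parameter(torch.zeros(nbins))
        return super().load_state_dict(state, strict=strict)
```

After loading into an empty shell, set_model_and_data() reconnects the model and data references, which the saved state never contains.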