torchref.scaling package

Structure factor scaling module for TorchRef.

This module provides classes for scaling calculated structure factors to match observed data, including:

  • Overall and anisotropic scale factors

  • Bulk solvent contribution modeling

Classes

ScalerBase

Base scaler class that does not require a Model object. All methods that need F_calc take it as an input argument.

Scaler

Full-featured scaler with Model integration. Extends ScalerBase with convenience methods that auto-compute F_calc.

SolventModel

Models the bulk solvent contribution to structure factors using a flat solvent model with k_sol and B_sol parameters.

Example

from torchref.scaling import Scaler, ScalerBase, SolventModel

# Using Scaler with a model (auto-computes F_calc)
scaler = Scaler(model, data, nbins=20)
scaler.initialize()
F_calc = scaler.compute_fcalc()  # F_calc from the internal model
F_calc_scaled = scaler(F_calc)

# Using ScalerBase without a model (requires F_calc as input)
scaler_base = ScalerBase(data=data, nbins=20)
scaler_base.initialize(fcalc)
F_calc_scaled = scaler_base(fcalc)
class torchref.scaling.Scaler(model=None, data=None, nbins=20, verbose=1, device=None)[source]

Bases: ScalerBase

Full-featured scaler with Model integration.

Extends ScalerBase by maintaining a reference to a Model object and providing convenience methods that automatically compute F_calc when not provided.

Supports two initialization patterns:

  1. Empty initialization (for state_dict loading):

    scaler = Scaler()  # Creates empty shell
    scaler.load_state_dict(torch.load('scaler.pt'))
    
  2. Full initialization with model and data:

    scaler = Scaler(model, reflection_data, nbins=20)
    scaler.initialize()
    
Parameters:
  • model (Model, optional) – Model object for structure factor calculation.

  • data (ReflectionData, optional) – ReflectionData object with observed data.

  • nbins (int, default 20) – Number of resolution bins.

  • verbose (int, default 1) – Verbosity level.

  • device (torch.device, default: configured device.current) – Computation device.

device

Current computation device.

Type:

torch.device

nbins

Number of resolution bins.

Type:

int

__init__(model=None, data=None, nbins=20, verbose=1, device=None)[source]

Initialize Scaler.

If model and data are provided, fully initializes the scaler. If not provided (empty init), creates a shell ready for load_state_dict().

Parameters:
  • model (Model, optional) – Model object for structure factor calculation.

  • data (ReflectionData, optional) – ReflectionData object with observed data.

  • nbins (int, default 20) – Number of resolution bins.

  • verbose (int, default 1) – Verbosity level.

  • device (torch.device, optional) – Computation device. If None, derived from model then data (model wins on mismatch); otherwise forces both onto the explicit device. See torchref.utils.resolve_device().

property model

Access the model object (not a registered submodule).

set_model_and_data(model, data)[source]

Set model and data references after empty initialization.

This is useful when loading from state_dict and then needing to reconnect to model/data objects.

Parameters:
  • model (Model) – Model object for structure factor calculation.

  • data (ReflectionData) – ReflectionData object with observed data.

initialize(fcalc=None)[source]

Initialize scaling parameters.

If fcalc is not provided, computes it from the internal model.

Parameters:

fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.

compute_fcalc()[source]

Compute F_calc from internal model.

Returns:

Calculated structure factors.

Return type:

torch.Tensor

Raises:

RuntimeError – If no model is set.

calc_initial_scale(fcalc=None)[source]

Calculate initial scale factors.

If fcalc is not provided, computes it from the internal model.

Parameters:

fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.

Returns:

The log scale parameter for each resolution bin.

Return type:

torch.nn.Parameter

fit_anisotropy(fcalc=None, nsteps=100)[source]

Fit anisotropic correction.

If fcalc is not provided, computes it from the internal model.

Parameters:
  • fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.

  • nsteps (int, default 100) – Number of optimization steps.

setup_solvent()[source]

Set up the solvent model using the internal model.

Creates a SolventModel using the internal model reference.

fit_all_scales(fcalc=None)[source]

Fit all scale parameters.

If fcalc is not provided, computes it from the internal model.

Parameters:

fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.

screen_solvent_params(fcalc=None, steps=15, use_low_res_weighting=True, low_res_cutoff=5.0, fit_on_low_res_only=True, low_res_limit=3.5)[source]

Screen solvent parameters using grid search.

If fcalc is not provided, computes it from the internal model.

Parameters:
  • fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.

  • steps (int, default 15) – Number of grid points for each parameter.

  • use_low_res_weighting (bool, default True) – If True, weight low-resolution reflections more heavily.

  • low_res_cutoff (float, default 5.0) – Resolution cutoff for weighting in Angstroms.

  • fit_on_low_res_only (bool, default True) – If True, fit using only low-resolution reflections.

  • low_res_limit (float, default 3.5) – Resolution limit for low-res only fitting in Angstroms.

refine_lbfgs(fcalc=None, nsteps=3, lr=1.0, max_iter=200, history_size=10, verbose=True)[source]

Refine scale parameters using LBFGS optimizer.

If fcalc is not provided, computes it from the internal model.

Parameters:
  • fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.

  • nsteps (int, default 3) – Number of LBFGS steps.

  • lr (float, default 1.0) – Learning rate (typically 1.0 for LBFGS).

  • max_iter (int, default 200) – Maximum iterations per line search.

  • history_size (int, default 10) – Number of previous gradients to store for Hessian approximation.

  • verbose (bool, default True) – Print progress information.

Returns:

Dictionary with refinement metrics.

Return type:

dict

rfactor(fcalc=None)[source]

Calculate R-factors.

If fcalc is not provided, computes it from the internal model.

Parameters:

fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.

Returns:

R-work and R-free values.

Return type:

tuple

bin_wise_rfactor(fcalc=None)[source]

Calculate bin-wise R-factors.

If fcalc is not provided, computes it from the internal model.

Parameters:

fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.

Returns:

  • mean_res_per_bin (torch.Tensor) – Mean resolution per bin.

  • rwork_per_bin (torch.Tensor) – R-work per bin.

  • rfree_per_bin (torch.Tensor) – R-free per bin.

get_binwise_mean_intensity(fcalc=None)[source]

Get bin-wise mean intensities.

If fcalc is not provided, computes it from the internal model.

Parameters:

fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.

Returns:

Mean observed intensity, mean calculated intensity, and mean resolution per bin.

Return type:

tuple

state_dict(destination=None, prefix='', keep_vars=False)[source]

Return a dictionary containing the complete state of the Scaler.

This includes:

  • All registered buffers and parameters (via parent class)

  • Scaler-specific metadata (nbins, etc.)

  • Solvent model state (if initialized)

Note: Model and data references are NOT saved (managed separately).

Parameters:
  • destination (dict, optional) – Optional dict to populate.

  • prefix (str, default '') – Prefix for parameter names.

  • keep_vars (bool, default False) – Whether to keep variables in computational graph.

Returns:

Complete state dictionary.

Return type:

dict

load_state_dict(state_dict, strict=True)[source]

Load the Scaler state from a dictionary.

Note: This assumes model and data are already set via __init__ or assignment.

Parameters:
  • state_dict (dict) – Dictionary containing scaler state.

  • strict (bool, default True) – Whether to strictly enforce that keys match.

class torchref.scaling.ScalerBase(data=None, nbins=20, verbose=1, device=None)[source]

Bases: DeviceMixin, DebugMixin, Module

Base scaler class for crystallographic scaling without model dependency.

All methods that require calculated structure factors (F_calc) take them as input arguments. This allows the scaler to be used independently of any specific model implementation.

Supports two initialization patterns:

  1. Empty initialization (for state_dict loading):

    scaler = ScalerBase()  # Creates empty shell
    scaler.load_state_dict(torch.load('scaler.pt'))
    
  2. Full initialization with data:

    scaler = ScalerBase(data=reflection_data, nbins=20)
    scaler.initialize(fcalc)
    
Parameters:
  • data (ReflectionData, optional) – ReflectionData object with observed data.

  • nbins (int, default 20) – Number of resolution bins.

  • verbose (int, default 1) – Verbosity level.

  • device (torch.device, default: configured device.current) – Computation device.

device

Current computation device.

Type:

torch.device

nbins

Number of resolution bins.

Type:

int

__init__(data=None, nbins=20, verbose=1, device=None)[source]

Initialize ScalerBase.

If data is provided, fully initializes the scaler. If not provided (empty init), creates a shell ready for load_state_dict().

Parameters:
  • data (ReflectionData, optional) – ReflectionData object with observed data.

  • nbins (int, default 20) – Number of resolution bins.

  • verbose (int, default 1) – Verbosity level.

  • device (torch.device, optional) – Computation device. If None, derived from data (if given) or the configured default via torchref.utils.resolve_device(). An explicit value forces data onto that device.

set_data(data)[source]

Set data reference after empty initialization.

This is useful when loading from state_dict and then needing to reconnect to a data object.

Parameters:

data (ReflectionData) – ReflectionData object with observed data.

initialize(fcalc)[source]

Initialize scaling parameters using provided F_calc.

Parameters:

fcalc (torch.Tensor) – Calculated structure factors (complex).

property hkl

Get HKL indices from data.

calc_initial_scale(fcalc)[source]

Calculate the initial scale factor based on the ratio of observed to calculated structure factors.

Excludes reflections with negative intensities to avoid bias from French-Wilson conversion.

Parameters:

fcalc (torch.Tensor) – Calculated structure factors (complex).

Returns:

The log scale parameter for each resolution bin.

Return type:

torch.nn.Parameter
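
A minimal sketch of a per-bin initial scale, written in plain Python rather than torch so it runs without dependencies. The exact estimator used by calc_initial_scale is not stated above; this sketch assumes the ratio sum(F_obs) / sum(|F_calc|) per resolution bin and skips reflections whose observed intensity is negative. The function name and the i_obs and bin_index arguments are hypothetical.

```python
import math

def initial_log_scale_per_bin(f_obs, i_obs, f_calc, bin_index, nbins):
    """Per-bin log scale from sum(F_obs) / sum(|F_calc|), skipping I_obs < 0."""
    sum_obs = [0.0] * nbins
    sum_calc = [0.0] * nbins
    for fo, io, fc, b in zip(f_obs, i_obs, f_calc, bin_index):
        if io < 0.0:  # negative intensity: excluded from the estimate
            continue
        sum_obs[b] += fo
        sum_calc[b] += abs(fc)  # complex F_calc -> amplitude
    # Empty or all-zero bins fall back to a neutral scale of 1 (log 1 = 0).
    return [math.log(so / sc) if so > 0.0 and sc > 0.0 else 0.0
            for so, sc in zip(sum_obs, sum_calc)]

f_obs = [10.0, 20.0, 6.0, 3.0, 0.5]
i_obs = [100.0, 400.0, 36.0, 9.0, -2.0]   # last reflection is excluded
f_calc = [3 + 4j, 10 + 0j, 2j, 1 + 0j, 4 + 0j]
bin_index = [0, 0, 1, 1, 1]
log_scale = initial_log_scale_per_bin(f_obs, i_obs, f_calc, bin_index, nbins=2)
```

Storing the logarithm keeps the scale positive when it is later optimized as a free parameter.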

setup_anisotropy_correction()[source]

Initialize anisotropic correction parameters.

anisotropy_correction()[source]

Compute anisotropic correction factors.

Returns:

Anisotropic correction factors for each reflection.

Return type:

torch.Tensor

fit_anisotropy(fcalc, nsteps=100)[source]

Fit anisotropic correction using provided F_calc.

Parameters:
  • fcalc (torch.Tensor) – Calculated structure factors (complex).

  • nsteps (int, default 100) – Number of optimization steps.

set_solvent_model(solvent_model)[source]

Set a pre-configured SolventModel for solvent contribution.

The SolventModel must be initialized externally (requires a Model object).

Parameters:

solvent_model (SolventModel) – Pre-configured solvent model that can compute solvent structure factors.

setup_binwise_solvent_scale()[source]

Set up bin-wise solvent scaling (Phenix-style kmask per bin).

This allows finer control over solvent contribution per resolution bin, which is more flexible than a single global B_sol parameter.

fit_all_scales(fcalc)[source]

Fit all scale parameters using provided F_calc.

Parameters:

fcalc (torch.Tensor) – Calculated structure factors (complex).

fit_simple(fobs, fcalc)[source]

Fit a single global scale factor analytically (least-squares).

This is the simple scaling approach:

k = sum(|F_obs||F_calc|) / sum(|F_calc|²)

Useful for rigid body refinement where only an overall scale is needed.

Parameters:
  • fobs (torch.Tensor) – Observed structure factor amplitudes.

  • fcalc (torch.Tensor) – Calculated structure factors (complex).
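
The closed-form expression above can be sketched in plain Python as follows. This is an illustration of the formula only, not the library's implementation; the function name is hypothetical and lists stand in for tensors.

```python
def fit_global_scale(f_obs, f_calc):
    """Least-squares k minimising sum((|F_obs| - k * |F_calc|)^2)."""
    amp_calc = [abs(f) for f in f_calc]  # complex F_calc -> amplitude
    num = sum(o * c for o, c in zip(f_obs, amp_calc))
    den = sum(c * c for c in amp_calc)
    if den == 0.0:
        raise ValueError("all calculated amplitudes are zero")
    return num / den

f_obs = [10.0, 20.0, 30.0]
f_calc = [3 + 4j, 6 + 8j, 9 + 12j]  # amplitudes 5, 10, 15
k = fit_global_scale(f_obs, f_calc)  # observed data is exactly 2x the model
```

Setting the derivative of the squared residual with respect to k to zero gives this ratio directly, which is why no iterative optimizer is needed for the single-scale case.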

get_scale()[source]

Get the current overall scale factor value.

Returns the mean scale factor across all bins.

Returns:

Current scale factor (not log).

Return type:

float

rfactor(fcalc)[source]

Calculate the R-factor between observed and calculated structure factors.

Parameters:

fcalc (torch.Tensor) – Calculated structure factors (complex).

Returns:

R-work and R-free values.

Return type:

tuple
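
As a rough illustration, the conventional crystallographic R-factor, R = sum(| |F_obs| - |F_calc| |) / sum(|F_obs|), can be computed separately for the work and test sets. The sketch below assumes that definition, assumes F_calc has already been scaled, and uses a hypothetical is_free flag list (True marks a test-set reflection); the library's data layout may differ.

```python
def r_factor(f_obs, f_calc_scaled):
    """R = sum(| |F_obs| - |F_calc| |) / sum(|F_obs|)."""
    num = sum(abs(fo - abs(fc)) for fo, fc in zip(f_obs, f_calc_scaled))
    den = sum(f_obs)
    return num / den if den > 0.0 else float("nan")

def r_work_free(f_obs, f_calc_scaled, is_free):
    """Return (R-work, R-free) using a per-reflection free flag."""
    def subset(values, want_free):
        return [v for v, free in zip(values, is_free) if free == want_free]
    r_work = r_factor(subset(f_obs, False), subset(f_calc_scaled, False))
    r_free = r_factor(subset(f_obs, True), subset(f_calc_scaled, True))
    return r_work, r_free

f_obs = [10.0, 20.0, 30.0, 40.0]
f_calc_scaled = [9 + 0j, 22 + 0j, 30j, 36 + 0j]
is_free = [False, False, False, True]
r_work, r_free = r_work_free(f_obs, f_calc_scaled, is_free)
```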

bin_wise_rfactor(fcalc)[source]

Calculate the bin-wise R-factor between observed and calculated structure factors.

Parameters:

fcalc (torch.Tensor) – Calculated structure factors (complex).

Returns:

  • mean_res_per_bin (torch.Tensor) – Mean resolution per bin.

  • rwork_per_bin (torch.Tensor) – R-work per bin.

  • rfree_per_bin (torch.Tensor) – R-free per bin.

setup_bin_wise_bfactor()[source]

Initialize bin-wise B-factor correction parameters.

bin_wise_bfactor_correction()[source]

Compute bin-wise B-factor correction factors.

Returns:

B-factor correction factors for each reflection.

Return type:

torch.Tensor

get_binwise_mean_intensity(fcalc)[source]

Get bin-wise mean intensities for observed and calculated structure factors.

Parameters:

fcalc (torch.Tensor) – Calculated structure factors (complex).

Returns:

Mean observed intensity, mean calculated intensity, and mean resolution per bin.

Return type:

tuple

screen_solvent_params(fcalc, steps=15, use_low_res_weighting=True, low_res_cutoff=5.0, fit_on_low_res_only=True, low_res_limit=3.5)[source]

Screen solvent parameters (k_sol, B_sol) using grid search.

The bulk solvent contributes primarily at low resolution. Fitting on low-resolution reflections only (fit_on_low_res_only=True) prevents high-resolution reflections from dominating the optimization and pushing B_sol too low.

Parameters:
  • fcalc (torch.Tensor) – Calculated structure factors (complex).

  • steps (int, default 15) – Number of grid points for each parameter.

  • use_low_res_weighting (bool, default True) – If True, weight low-resolution reflections more heavily since solvent primarily contributes at low resolution.

  • low_res_cutoff (float, default 5.0) – Resolution cutoff for weighting in Angstroms.

  • fit_on_low_res_only (bool, default True) – If True, fit using only low-resolution reflections.

  • low_res_limit (float, default 3.5) – Resolution limit for low-res only fitting in Angstroms.
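
A minimal sketch of such a grid search, in plain Python. It assumes the flat-solvent form F_total = F_calc + k_sol * exp(-B_sol * s²) * F_mask with s = 1/(2d), an analytic overall scale at each grid point, and a plain least-squares score. The grid ranges, function names and the f_mask and d_spacing inputs are hypothetical, and the low-resolution weighting option is left out.

```python
import math

def total_amplitudes(f_calc, f_mask, s_sq, k_sol, b_sol):
    """|F_calc + k_sol * exp(-B_sol * s^2) * F_mask| for each reflection."""
    return [abs(fc + k_sol * math.exp(-b_sol * s2) * fm)
            for fc, fm, s2 in zip(f_calc, f_mask, s_sq)]

def screen_solvent(f_obs, f_calc, f_mask, d_spacing, steps=15,
                   k_range=(0.0, 0.7), b_range=(10.0, 80.0),
                   low_res_limit=3.5):
    """Grid-search (k_sol, B_sol) on low-resolution reflections; steps >= 2."""
    # Keep only reflections with d >= low_res_limit, where solvent matters.
    keep = [i for i, d in enumerate(d_spacing) if d >= low_res_limit]
    obs = [f_obs[i] for i in keep]
    fc = [f_calc[i] for i in keep]
    fm = [f_mask[i] for i in keep]
    s_sq = [1.0 / (4.0 * d_spacing[i] ** 2) for i in keep]  # s = 1/(2d)

    best = None
    for i in range(steps):
        k_sol = k_range[0] + (k_range[1] - k_range[0]) * i / (steps - 1)
        for j in range(steps):
            b_sol = b_range[0] + (b_range[1] - b_range[0]) * j / (steps - 1)
            amp = total_amplitudes(fc, fm, s_sq, k_sol, b_sol)
            den = sum(a * a for a in amp)
            # Analytic overall scale for this (k_sol, B_sol) pair.
            scale = sum(o * a for o, a in zip(obs, amp)) / den if den > 0.0 else 0.0
            resid = sum((o - scale * a) ** 2 for o, a in zip(obs, amp))
            if best is None or resid < best[0]:
                best = (resid, k_sol, b_sol)
    return best[1], best[2]

# Synthetic data generated with k_sol = 0.35 and B_sol = 45 (both on the grid).
d_spacing = [20.0, 12.0, 8.0, 6.0, 5.0, 4.0, 3.6, 2.0, 1.5]
f_calc = [100 + 20j, 80 - 30j, 60 + 45j, 50 - 10j, 40 + 25j,
          35 - 20j, 30 + 10j, 20 + 5j, 10 - 8j]
f_mask = [-300 + 50j, -200 - 40j, -120 + 60j, -80 + 10j, -50 - 30j,
          -30 + 20j, -20 - 10j, -5 + 2j, -1 + 1j]
s_sq_all = [1.0 / (4.0 * d * d) for d in d_spacing]
f_obs = total_amplitudes(f_calc, f_mask, s_sq_all, 0.35, 45.0)
k_best, b_best = screen_solvent(f_obs, f_calc, f_mask, d_spacing)
```

Because the exponential damping is close to 1 at large d and falls off quickly at small d, the two reflections beyond 3.5 Å carry little solvent signal, which is the motivation for restricting the fit to the low-resolution subset.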

refine_lbfgs(fcalc, nsteps=3, lr=1.0, max_iter=200, history_size=10, verbose=True)[source]

Refine scale parameters using LBFGS optimizer.

This method optimizes the anisotropic scaling and B-factor parameters that relate calculated structure factors to observed structure factors. Uses the L-BFGS quasi-Newton optimization method for fast convergence.

Parameters:
  • fcalc (torch.Tensor) – Calculated structure factors (complex).

  • nsteps (int, default 3) – Number of LBFGS steps.

  • lr (float, default 1.0) – Learning rate (typically 1.0 for LBFGS).

  • max_iter (int, default 200) – Maximum iterations per line search.

  • history_size (int, default 10) – Number of previous gradients to store for Hessian approximation.

  • verbose (bool, default True) – Print progress information.

Returns:

Dictionary with refinement metrics including steps, xray_work, xray_test, rwork, rfree.

Return type:

dict

estimate_sigma_eff(fcalc, max_inflation=2.0)[source]

Estimate per-resolution-shell effective sigmas from current residuals.

Pannu & Read / SIGMAA-style correction: detects miscalibrated experimental sigmas by comparing residual variance to the claimed variance, per resolution bin.

For each resolution bin:

    D_bin        = < (F_obs - k * |F_calc|)^2 >    (using work set)
    ratio_bin    = sqrt(D_bin / <sigma_F^2>)
    ratio_capped = clamp(ratio_bin, 1.0, max_inflation)
    sigma_eff    = sigma_F * ratio_capped

Why the cap? At the start of refinement the model is bad, so residuals are dominated by model error (which is fixable by refining), not noise. Uncapped inflation creates a vicious cycle: bad model -> huge sigma_eff -> weak data gradient -> bad model. Capping at max_inflation (default 2.0, i.e. sigmas can grow at most 2x) prevents runaway while still correcting genuinely under-estimated sigmas.

As the model improves, residuals shrink and the ratio drops toward 1, so sigma_eff converges to the raw sigma (good calibration).

Uses the work set only so the test set doesn’t leak into sigma estimation.

Parameters:
  • fcalc (torch.Tensor) – Calculated structure factors (complex, unscaled).

  • max_inflation (float, optional) – Maximum allowed ratio sigma_eff / sigma_raw. Default 2.0.

Returns:

Per-reflection effective sigmas, shape (N,).

Return type:

torch.Tensor
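
The per-bin recipe above can be sketched in plain Python as follows. This follows the four formulas as written, with lists standing in for tensors; the argument names (bin_index, is_work, k) are hypothetical and the library's handling of empty bins is an assumption.

```python
import math

def estimate_sigma_eff(f_obs, f_calc_amp, sigma, bin_index, is_work,
                       k=1.0, max_inflation=2.0):
    """Inflate sigmas per resolution bin by a capped residual/sigma ratio."""
    nbins = max(bin_index) + 1
    sq_resid = [0.0] * nbins
    sq_sigma = [0.0] * nbins
    for fo, fc, s, b, work in zip(f_obs, f_calc_amp, sigma, bin_index, is_work):
        if not work:  # test-set reflections never enter the estimate
            continue
        sq_resid[b] += (fo - k * fc) ** 2
        sq_sigma[b] += s * s
    ratio = []
    for b in range(nbins):
        if sq_sigma[b] == 0.0:  # empty bin: leave sigmas unchanged (assumption)
            ratio.append(1.0)
            continue
        # Both sums run over the same reflections, so sum/sum equals mean/mean.
        r = math.sqrt(sq_resid[b] / sq_sigma[b])
        ratio.append(min(max(r, 1.0), max_inflation))
    # The bin ratio is applied to every reflection, work and test alike.
    return [s * ratio[b] for s, b in zip(sigma, bin_index)]

f_obs      = [10.0, 12.0, 5.0, 6.0, 7.0, 20.0, 20.0]
f_calc_amp = [10.5, 11.5, 8.0, 2.0, 7.0, 18.5, 21.5]
sigma      = [1.0, 1.0, 0.5, 0.5, 0.4, 1.0, 1.0]
bin_index  = [0, 0, 1, 1, 1, 2, 2]
is_work    = [True, True, True, True, False, True, True]
sigma_eff = estimate_sigma_eff(f_obs, f_calc_amp, sigma, bin_index, is_work)
```

In this toy data, bin 0 has residuals smaller than its sigmas (ratio clamped up to 1), bin 1 has residuals far larger (ratio capped at 2), and bin 2 lands between the two limits (ratio 1.5).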

forward(fcalc, use_mask=True, f_sol_override=None)[source]

Forward pass for the ScalerBase module.

Parameters:
  • fcalc (torch.Tensor) – Calculated structure factors. Expected shape (N,), an additional dimension for batch is possible. N should match the full HKL size.

  • use_mask (bool, default True) – Deprecated parameter, kept for backward compatibility.

  • f_sol_override (torch.Tensor, optional) – Pre-computed raw solvent structure factors. When provided, these replace the internally-cached _f_sol_raw. The scaler’s k_sol / B_sol / phase damping is still applied. This is used by CollectionScaler to supply mixed (fraction-weighted) solvent contributions.

Returns:

Scaled structure factors of same shape as input.

Return type:

torch.Tensor

state_dict(destination=None, prefix='', keep_vars=False)[source]

Return a dictionary containing the complete state of the ScalerBase.

This includes:

  • All registered buffers and parameters (via parent class)

  • Scaler-specific metadata (nbins, etc.)

  • Solvent model state (if set)

Note: Data reference is NOT saved (managed separately).

Parameters:
  • destination (dict, optional) – Optional dict to populate.

  • prefix (str, default '') – Prefix for parameter names.

  • keep_vars (bool, default False) – Whether to keep variables in computational graph.

Returns:

Complete state dictionary.

Return type:

dict

load_state_dict(state_dict, strict=True)[source]

Load the ScalerBase state from a dictionary.

Note: This assumes data is already set via __init__ or set_data().

Parameters:
  • state_dict (dict) – Dictionary containing scaler state.

  • strict (bool, default True) – Whether to strictly enforce that keys match.

save_state(path)[source]

Save the complete state of the scaler to a file.

Parameters:

path (str) – Path to save the state dictionary to.

load_state(path, strict=True)[source]

Load the complete state of the scaler from a file.

Parameters:
  • path (str) – Path to load the state dictionary from.

  • strict (bool, default True) – Whether to strictly enforce that keys match.

class torchref.scaling.SolventModel(model=None, radius=1.1, k_solvent=1.1, b_solvent=50.0, erosion_radius=0.9, transition=None, optimize_phase=True, initial_phase_offset=0.0, verbose=1, float_type=torch.float32, device=device(type='cpu'))[source]

Bases: DeviceMixin, DebugMixin, Module

Computes the solvent contribution to structure factors using a Phenix-like approach.

Supports two initialization patterns:

  1. Empty initialization (for state_dict loading):

    solvent = SolventModel()  # Creates empty shell
    solvent.load_state_dict(torch.load('solvent.pt'))
    
  2. Full initialization with model:

    solvent = SolventModel(model, k_solvent=0.35, b_solvent=46.0)
    
model

The atomic model for structure factor calculations.

Type:

ModelFT or None

device

Device for tensor operations.

Type:

torch.device

verbose

Verbosity level.

Type:

int

float_type

Floating point data type.

Type:

torch.dtype

solvent_radius

Probe radius in Angstroms for dilation.

Type:

float

erosion_radius

Radius in Angstroms for erosion step.

Type:

float

optimize_phase

Whether to optimize phase offset parameter.

Type:

bool

log_k_solvent

Log of solvent scattering scale factor.

Type:

torch.nn.Parameter

b_solvent

Solvent B-factor.

Type:

torch.nn.Parameter

phase_offset

Phase offset in radians.

Type:

torch.nn.Parameter or buffer

__init__(model=None, radius=1.1, k_solvent=1.1, b_solvent=50.0, erosion_radius=0.9, transition=None, optimize_phase=True, initial_phase_offset=0.0, verbose=1, float_type=torch.float32, device=device(type='cpu'))[source]

Initialize SolventModel.

If model is provided, fully initializes the solvent model. If not provided (empty init), creates a shell ready for load_state_dict().

Parameters:
  • model (ModelFT, optional) – The atomic model used for structure factor calculations (optional for empty init).

  • radius (float, default 1.1) – Probe radius in Angstroms for dilation (water radius).

  • k_solvent (float, default 1.1) – Solvent scattering scale factor.

  • b_solvent (float, default 50.0) – Solvent B-factor.

  • erosion_radius (float, default 0.9) – Radius in Angstroms for erosion step.

  • transition (float, optional) – Gaussian smoothing sigma for mask edges (default: radius/4 in voxels). Avoids ringing artifacts.

  • optimize_phase (bool, default True) – Whether to optimize phase offset parameter.

  • initial_phase_offset (float, default 0.0) – Initial phase offset in radians.

  • verbose (int, default 1) – Verbosity level.

  • float_type (torch.dtype, default torch.float32) – Floating point data type.

  • device (torch.device, default: configured device.current) – Device for tensor operations.

get_solvent_mask()[source]

Generate solvent mask following Phenix’s three-step process.

Step 1 (dilation): classify voxels around each atom as protein (inside VdW), boundary (between VdW and VdW+solvent_radius), or bulk solvent (further out). Built in chunks over atoms so peak memory is O(atom_chunk_size × N_box_voxels) rather than O(N_atoms × N_box_voxels). This is critical because for typical macromolecule + grid combinations the dense form is multi-GB.

Step 2 (symmetry expansion): transform the sparse ASU protein / boundary voxel indices through each symop and scatter into the P1 grid masks.

Step 3 (erosion): a boundary voxel becomes solvent if any voxel within erosion_radius of it is bulk solvent. Implemented as a single F.conv3d with a precomputed spherical kernel and circular padding. This replaces the previous Python-loop + per-voxel-neighbourhood expansion, which itself ran out of memory on chunks of 10^6 boundary voxels.

Returns:

Solvent mask (boolean) where True = solvent.

Return type:

torch.Tensor
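
The erosion rule in Step 3 can be illustrated with explicit loops on a tiny periodic grid. This is not the conv3d implementation described above, only a dependency-free sketch of the same rule; the label constants, function names and the voxel_size argument are hypothetical.

```python
import itertools

PROTEIN, BOUNDARY, BULK = 0, 1, 2

def spherical_offsets(radius, voxel_size):
    """Integer voxel offsets whose centres lie within `radius` of the origin."""
    reach = int(radius // voxel_size)
    return [off for off in itertools.product(range(-reach, reach + 1), repeat=3)
            if sum(c * c for c in off) * voxel_size ** 2 <= radius ** 2]

def erode(labels, shape, erosion_radius, voxel_size):
    """Return {voxel: is_solvent}; labels maps (x, y, z) to a label constant."""
    offsets = spherical_offsets(erosion_radius, voxel_size)
    solvent = {}
    for voxel, label in labels.items():
        if label == BOUNDARY:
            # The modulo wrap plays the role of circular padding on the grid.
            solvent[voxel] = any(
                labels[tuple((v + o) % n for v, o, n in zip(voxel, off, shape))] == BULK
                for off in offsets)
        else:
            solvent[voxel] = (label == BULK)
    return solvent

# One row of voxels along x; the grid is periodic, so x = 7 neighbours x = 0.
line = [BULK, BOUNDARY, BOUNDARY, PROTEIN, PROTEIN, BOUNDARY, BOUNDARY, BOUNDARY]
shape = (len(line), 1, 1)
labels = {(x, 0, 0): lab for x, lab in enumerate(line)}
mask = erode(labels, shape, erosion_radius=0.9, voxel_size=0.5)
solvent_along_x = [mask[(x, 0, 0)] for x in range(len(line))]
```

Only the boundary voxels at x = 1 and x = 7 touch bulk solvent within the erosion radius (the latter through the periodic wrap), so they join the solvent mask while the remaining boundary voxels stay excluded.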

update_solvent()[source]

smooth_solvent_mask()[source]

get_rec_solvent(hkl)[source]

Compute solvent structure factors.

Uses the standard crystallographic approach: compute SFs from the solvent mask. The mask represents regions where bulk solvent scattering occurs.

Parameters:

hkl (torch.Tensor) – Miller indices.

Returns:

Complex solvent structure factors.

Return type:

torch.Tensor

forward(hkl, update_fsol=False, F_protein=None)[source]

Compute solvent contribution to structure factors at given HKL.

This method is differentiable with respect to k_solvent, b_solvent, and phase_offset parameters.

The solvent model:

  1. Takes the binary solvent mask

  2. Smooths it with Gaussian filter (σ=1.5 voxels) to create soft edges

  3. Computes structure factors via FFT

  4. Applies B-factor damping: exp(-B * s²) where s = sin(θ)/λ

  5. If optimize_phase=True and F_protein is provided: blends mask phases with protein phases. phase_offset controls the blend: 0 = use mask phases, ±π = use protein phases

  6. Scales by k_solvent

Parameters:
  • hkl (torch.Tensor) – Miller indices, shape (N, 3).

  • update_fsol (bool, default False) – Whether to update solvent structure factors.

  • F_protein (torch.Tensor, optional) – Protein structure factors, used for phase blending.

Returns:

Complex solvent structure factors, shape (N,).

Return type:

torch.Tensor
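
Steps 4 and 6 of the list above (B-factor damping and scaling by k_solvent) can be sketched in plain Python as follows. The sketch takes already-computed mask structure factors as input and omits the phase blending of step 5, whose exact form is not given here. The function name and the d_spacing argument are hypothetical; s = sin(θ)/λ is obtained from Bragg's law as 1/(2d).

```python
import math

def solvent_structure_factors(f_mask, d_spacing, k_solvent, b_solvent):
    """Apply k_solvent * exp(-B * s^2) to raw mask structure factors."""
    out = []
    for f, d in zip(f_mask, d_spacing):
        s_sq = 1.0 / (4.0 * d * d)  # s = sin(theta)/lambda = 1/(2d)
        out.append(k_solvent * math.exp(-b_solvent * s_sq) * f)
    return out

f_mask = [100 + 0j, 100 + 0j]
d_spacing = [10.0, 2.0]  # one low-resolution and one high-resolution reflection
f_sol = solvent_structure_factors(f_mask, d_spacing, k_solvent=0.35, b_solvent=46.0)
```

With these values the 10 Å reflection keeps most of its solvent amplitude while the 2 Å reflection is damped by more than an order of magnitude, which reflects why bulk solvent matters mainly at low resolution.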

parameters()[source]

Return an iterator over module parameters.

This is typically passed to an optimizer.

Parameters:

recurse (bool) – If True, yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

Yields:

Parameter – Module parameter.

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for param in model.parameters():
>>>     print(type(param), param.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
class torchref.scaling.CollectionScaler(dataset_collection, model_collection, nbins=20, verbose=1, device=device(type='cpu'))[source]

Bases: ScalerBase

Joint scaler for DatasetCollection + ModelCollection.

Shares scale parameters (log_scale, U, bin_wise_bfactor, k_sol, B_sol, phase) across all data–model pairs. Manages per-component solvent models so that the bulk-solvent contribution for a mixed model is the fraction-weighted sum of individual component solvent SFs.

Parameters:
  • dataset_collection (DatasetCollection) – Collection of reflection datasets keyed by timepoint name.

  • model_collection (ModelCollection) – Collection of mixed models keyed by timepoint name.

  • nbins (int) – Number of resolution bins.

  • verbose (int) – Verbosity level.

  • device (torch.device) – Computation device.

Examples

scaler = CollectionScaler(datasets, models, device=device)
scaler.initialize()
scaler.refine_lbfgs_joint()

# In a target: scale a mixed-model F_calc with matching solvent
f_scaled = scaler.forward_mixed(f_calc, model.fractions)
__init__(dataset_collection, model_collection, nbins=20, verbose=1, device=device(type='cpu'))[source]

Initialize CollectionScaler.

Parameters:
  • dataset_collection (DatasetCollection) – Collection of reflection datasets keyed by timepoint name.

  • model_collection (ModelCollection) – Collection of mixed models keyed by timepoint name.

  • nbins (int, default 20) – Number of resolution bins.

  • verbose (int, default 1) – Verbosity level.

  • device (torch.device) – Computation device.

initialize()[source]

One-shot initialization: joint initial scale, component solvents, anisotropy correction.

Returns:

Self, for method chaining.

Return type:

CollectionScaler

get_mixed_solvent_raw(fractions)[source]

Compute fraction-weighted raw solvent SFs.

f_sol_mixed = sum_i(w_i * f_sol_raw_i)

Parameters:

fractions (torch.Tensor) – Population fractions, shape (n_base_models,).

Returns:

Mixed raw solvent structure factors (complex, un-damped).

Return type:

torch.Tensor
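
The weighted sum above can be sketched in plain Python, with nested lists standing in for the per-component tensors. The function name and the input layout (one list of complex values per component) are hypothetical.

```python
def mix_solvent(fractions, f_sol_raw_per_component):
    """f_sol_mixed[h] = sum_i(w_i * f_sol_raw_i[h]) for every reflection h."""
    n_refl = len(f_sol_raw_per_component[0])
    return [sum(w * comp[h] for w, comp in zip(fractions, f_sol_raw_per_component))
            for h in range(n_refl)]

fractions = [0.75, 0.25]              # population of each base model
f_sol_a = [4 + 0j, 8j]                # raw solvent SFs of component A
f_sol_b = [0j, 4 + 4j]                # raw solvent SFs of component B
mixed = mix_solvent(fractions, [f_sol_a, f_sol_b])
```

The mixing is done on complex values, so components whose solvent phases differ can partially cancel rather than simply adding in amplitude.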

forward_mixed(fcalc, fractions)[source]

Scale fcalc using the shared parameters and a fraction-weighted solvent contribution.

This sets _f_sol_raw to the mixed solvent and then delegates to ScalerBase.forward(), which applies k_sol / B_sol / phase damping and the overall + anisotropic scale.

Parameters:
  • fcalc (torch.Tensor) – Calculated structure factors (complex).

  • fractions (torch.Tensor) – Population fractions for the mixed model, shape (n_base_models,).

Returns:

Scaled structure factors.

Return type:

torch.Tensor

refine_lbfgs_joint(nsteps=3, lr=1.0, max_iter=200, history_size=10, verbose=True)[source]

Refine scale parameters using LBFGS against all datasets.

The closure sums the NLL across every matched dataset–model pair, so a single set of scale parameters is fitted jointly.

Parameters:
  • nsteps (int) – Number of LBFGS outer steps.

  • lr (float) – Learning rate (typically 1.0 for LBFGS).

  • max_iter (int) – Maximum line-search iterations per step.

  • history_size (int) – LBFGS history size.

  • verbose (bool) – Print progress.

Returns:

Refinement metrics (steps, rwork, rfree of dark dataset).

Return type:

dict

screen_solvent_params_joint(steps=15)[source]

Grid-search k_sol / B_sol using NLL summed across all datasets.

Parameters:

steps (int) – Grid points per parameter.

update_all_solvent()[source]

Recompute solvent masks for all component models.

Call this after structure refinement changes base-model coordinates.

invalidate_solvent_cache()[source]

Clear cached raw solvent SFs (forces recomputation on next call).

property component_solvent_models: ModuleList

Per-component SolventModel instances (read-only).

Submodules