torchref.scaling.collection_scaler module

Joint scaler for multi-dataset / multi-model kinetic refinement.

CollectionScaler extends ScalerBase to operate on paired DatasetCollection + ModelCollection instances. A single set of scale parameters (log_scale, U, bin_wise_bfactor, k_sol, B_sol, phase) is shared across all data–model pairs, preventing artificial scale differences that would otherwise corrupt the difference signal in time-resolved refinement.

Per-component solvent models are created for each base model in the ModelCollection. For a mixed model at any timepoint, the solvent contribution is a linear combination of the individual component solvent structure factors, weighted by the same population fractions as the structural models.

class torchref.scaling.collection_scaler.CollectionScaler(dataset_collection, model_collection, nbins=20, verbose=1, device=device(type='cpu'))

Bases: ScalerBase

Joint scaler for DatasetCollection + ModelCollection.

Shares scale parameters (log_scale, U, bin_wise_bfactor, k_sol, B_sol, phase) across all data–model pairs. Manages per-component solvent models so that the bulk-solvent contribution for a mixed model is the fraction-weighted sum of individual component solvent SFs.

Parameters:
  • dataset_collection (DatasetCollection) – Collection of reflection datasets keyed by timepoint name.

  • model_collection (ModelCollection) – Collection of mixed models keyed by timepoint name.

  • nbins (int) – Number of resolution bins.

  • verbose (int) – Verbosity level.

  • device (torch.device) – Computation device.

Examples

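import torch
from torchref.scaling.collection_scaler import CollectionScaler

device = torch.device("cpu")
# datasets: DatasetCollection and models: ModelCollection, keyed by timepoint name
scaler = CollectionScaler(datasets, models, device=device)
scaler.initialize()
scaler.refine_lbfgs_joint()

# In a target: scale a mixed-model F_calc with matching solvent
f_scaled = scaler.forward_mixed(f_calc, model.fractions)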
__init__(dataset_collection, model_collection, nbins=20, verbose=1, device=device(type='cpu'))

Initialize the joint scaler from a paired DatasetCollection and ModelCollection.

Parameters:
  • dataset_collection (DatasetCollection) – Collection of reflection datasets keyed by timepoint name.

  • model_collection (ModelCollection) – Collection of mixed models keyed by timepoint name.

  • nbins (int, default 20) – Number of resolution bins.

  • verbose (int, default 1) – Verbosity level.

  • device (torch.device, default CPU) – Computation device.

initialize()

One-shot initialization: joint initial scale, component solvents, anisotropy correction.

Returns:

Self, for method chaining.

Return type:

CollectionScaler

get_mixed_solvent_raw(fractions)

Compute fraction-weighted raw solvent SFs.

f_sol_mixed = sum_i(w_i * f_sol_raw_i)

Parameters:

fractions (torch.Tensor) – Population fractions, shape (n_base_models,).

Returns:

Mixed raw solvent structure factors (complex, undamped: k_sol, B_sol and phase have not yet been applied).

Return type:

torch.Tensor
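The weighted sum f_sol_mixed = sum_i(w_i * f_sol_raw_i) can be sketched in plain Python as below. This is an illustration only: mix_solvent_raw is a hypothetical helper, and the real method works on complex torch tensors rather than lists. It assumes every component solvent model is evaluated on the same reflection set.

```python
from typing import Sequence


def mix_solvent_raw(
    fractions: Sequence[float],
    component_f_sol: Sequence[Sequence[complex]],
) -> list[complex]:
    """Return sum_i(w_i * f_sol_raw_i) over a shared reflection set."""
    if not component_f_sol or len(fractions) != len(component_f_sol):
        raise ValueError("expected one fraction per base model")
    n_refl = len(component_f_sol[0])
    if any(len(f_sol) != n_refl for f_sol in component_f_sol):
        raise ValueError("all components must share one reflection set")
    mixed = [0j] * n_refl
    for w, f_sol in zip(fractions, component_f_sol):
        for h, value in enumerate(f_sol):
            mixed[h] += w * value  # complex structure factors add linearly
    return mixed


# Two base models at 75 % / 25 % population, two reflections each.
f_sol_a = [4 + 0j, 2j]
f_sol_b = [0j, 2 + 2j]
print(mix_solvent_raw([0.75, 0.25], [f_sol_a, f_sol_b]))  # [(3+0j), (0.5+2j)]
```

Because the combination is linear and applied before any damping, the same k_sol / B_sol / phase can then be applied to the mixture as to a single-component solvent.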

forward_mixed(fcalc, fractions)

Scale fcalc using the shared parameters and a fraction-weighted solvent contribution.

This sets _f_sol_raw to the mixed solvent and then delegates to ScalerBase.forward(), which applies the bulk-solvent parameters (k_sol, B_sol, phase) and the overall and anisotropic scale.

Parameters:
  • fcalc (torch.Tensor) – Calculated structure factors (complex).

  • fractions (torch.Tensor) – Population fractions for the mixed model, shape (n_base_models,).

Returns:

Scaled structure factors.

Return type:

torch.Tensor
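The exact parameterization inside ScalerBase.forward() is not documented here. As an assumption, the sketch below uses the common flat bulk-solvent form with an isotropic overall scale, F_model = k * exp(-B s²/4) * (F_calc + k_sol * exp(-B_sol s²/4) * F_sol_mixed), with k = exp(log_scale) and s² = 1/d². The phase parameter, the anisotropic U and bin_wise_bfactor are deliberately omitted, and scale_mixed is a hypothetical name.

```python
import math
from typing import Sequence


def scale_mixed(
    f_calc: Sequence[complex],
    f_sol_mixed: Sequence[complex],
    s_squared: Sequence[float],
    log_scale: float,
    b_overall: float,
    k_sol: float,
    b_sol: float,
) -> list[complex]:
    """Flat bulk-solvent model with an isotropic overall scale (sketch).

    s_squared holds 1/d^2 per reflection. Phase and anisotropy are omitted.
    """
    k_overall = math.exp(log_scale)
    out = []
    for fc, fs, s2 in zip(f_calc, f_sol_mixed, s_squared):
        solvent = k_sol * math.exp(-b_sol * s2 / 4.0) * fs   # damped bulk solvent
        k_res = k_overall * math.exp(-b_overall * s2 / 4.0)  # isotropic overall scale
        out.append(k_res * (fc + solvent))
    return out


# At s^2 = 0 every exponential is 1, so F_model = F_calc + k_sol * F_sol.
print(scale_mixed([10 + 0j], [2 + 0j], [0.0], 0.0, 20.0, 0.35, 50.0))
```

The key property mirrored from forward_mixed is that one parameter set is used for every timepoint; only f_sol_mixed (through the fractions) changes between models.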

refine_lbfgs_joint(nsteps=3, lr=1.0, max_iter=200, history_size=10, verbose=True)

Refine scale parameters using LBFGS against all datasets.

The closure sums the NLL across every matched dataset–model pair, so a single set of scale parameters is fitted jointly.

Parameters:
  • nsteps (int) – Number of LBFGS outer steps.

  • lr (float) – Learning rate (typically 1.0 for LBFGS).

  • max_iter (int) – Maximum LBFGS iterations per outer step.

  • history_size (int) – LBFGS history size.

  • verbose (bool) – Print progress.

Returns:

Refinement metrics (steps, plus rwork and rfree of the dark dataset).

Return type:

dict
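To show how a jointly fitted scale differs from per-dataset fits, the sketch below substitutes a closed-form least-squares scale on amplitudes for the NLL + LBFGS objective the method actually uses. per_pair_scales and joint_scale are hypothetical helpers; the joint version minimizes the residual summed over every dataset–model pair, as the closure described above does.

```python
from typing import Sequence

# (|F_obs|, |F_calc|) for one matched dataset-model pair
Pair = tuple[Sequence[float], Sequence[float]]


def per_pair_scales(pairs: Sequence[Pair]) -> list[float]:
    """Independent least-squares scale k_t for each pair."""
    return [
        sum(o * c for o, c in zip(f_obs, f_calc)) / sum(c * c for c in f_calc)
        for f_obs, f_calc in pairs
    ]


def joint_scale(pairs: Sequence[Pair]) -> float:
    """Single k minimizing sum_t sum_h (|F_obs| - k |F_calc|)^2 over all pairs."""
    num = sum(o * c for f_obs, f_calc in pairs for o, c in zip(f_obs, f_calc))
    den = sum(c * c for _, f_calc in pairs for c in f_calc)
    return num / den


dark = ([2.0, 4.0], [1.0, 2.0])
light = ([3.0, 6.0], [1.0, 2.0])
print(per_pair_scales([dark, light]))  # [2.0, 3.0]
print(joint_scale([dark, light]))      # 2.5
```

With independent scales, both toy datasets fit perfectly and the 1.5× amplitude change between dark and light disappears into k_t; the shared scale leaves that change in the residuals, where the structural models can account for it.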

screen_solvent_params_joint(steps=15)

Grid-search k_sol / B_sol using NLL summed across all datasets.

Parameters:

steps (int) – Grid points per parameter.
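A joint grid search can be sketched as follows, assuming a full two-dimensional grid with steps points per parameter (steps² evaluations) and illustrative bounds of 0.1–0.6 for k_sol and 10–100 Å² for B_sol. The bounds, the function names and the per-dataset loss callables are hypothetical; the actual method sums the NLL from each dataset–model pair.

```python
import itertools
from typing import Callable, Sequence

Loss = Callable[[float, float], float]  # (k_sol, B_sol) -> loss for one dataset


def linspace(lo: float, hi: float, n: int) -> list[float]:
    """n evenly spaced points from lo to hi inclusive."""
    if n == 1:
        return [lo]
    return [lo + (hi - lo) * i / (n - 1) for i in range(n)]


def screen_solvent_grid(
    losses: Sequence[Loss],
    steps: int = 15,
    k_range: tuple[float, float] = (0.1, 0.6),     # assumed k_sol bounds
    b_range: tuple[float, float] = (10.0, 100.0),  # assumed B_sol bounds (A^2)
) -> tuple[float, float]:
    """Return the (k_sol, B_sol) grid point with the lowest loss summed over datasets."""
    best_total, best_point = float("inf"), (k_range[0], b_range[0])
    for k_sol, b_sol in itertools.product(
        linspace(*k_range, steps), linspace(*b_range, steps)
    ):
        # Joint criterion: one shared (k_sol, B_sol) scored against every dataset.
        total = sum(loss(k_sol, b_sol) for loss in losses)
        if total < best_total:
            best_total, best_point = total, (k_sol, b_sol)
    return best_point
```

Summing before taking the minimum means the chosen pair is a compromise across timepoints rather than the optimum of any single dataset.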

update_all_solvent()

Recompute solvent masks for all component models.

Call this after structure refinement changes base-model coordinates.

invalidate_solvent_cache()

Clear cached raw solvent SFs (forces recomputation on next call).
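update_all_solvent() and invalidate_solvent_cache() split the work between recomputing solvent masks and discarding cached raw solvent SFs. The sketch below shows the generic lazy-cache pattern such a pair typically follows; SolventCache and its attributes are hypothetical and are not a description of CollectionScaler's internals.

```python
from typing import Callable, Optional


class SolventCache:
    """Hypothetical lazy cache of raw solvent SFs for one component model."""

    def __init__(self, compute: Callable[[], list[complex]]) -> None:
        self._compute = compute                      # builds the raw solvent SFs
        self._cached: Optional[list[complex]] = None
        self.n_computes = 0                          # how often compute ran

    def get(self) -> list[complex]:
        # Recompute only when nothing is cached.
        if self._cached is None:
            self._cached = self._compute()
            self.n_computes += 1
        return self._cached

    def invalidate(self) -> None:
        # Drop the cached value; the next get() recomputes it.
        self._cached = None


cache = SolventCache(lambda: [1 + 0j, 0.5j])
cache.get()
cache.get()          # served from the cache
cache.invalidate()   # e.g. after base-model coordinates changed
cache.get()          # recomputed
print(cache.n_computes)  # 2
```

If the cache is not cleared after coordinates change, forward passes keep using solvent SFs computed from the old structure.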

property component_solvent_models: ModuleList

Per-component SolventModel instances (read-only).