torchref.scaling.collection_scaler module
Joint scaler for multi-dataset / multi-model kinetic refinement.
CollectionScaler extends ScalerBase to operate on paired
DatasetCollection + ModelCollection instances. A single set of
scale parameters (log_scale, U, k_sol, B_sol, phase) is shared across
all data–model pairs, preventing artificial scale differences that
corrupt the difference signal in time-resolved refinement.
Per-component solvent models are created for each base model in the
ModelCollection. For a mixed model at any timepoint the solvent
contribution is the linear combination of individual component solvent
structure factors, weighted by the same population fractions as the
structural models.
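The linear combination described above can be sketched in plain Python. This is only an illustration of the arithmetic, not the library's implementation: the helper `mix_solvent`, the two component arrays, and the fraction values are hypothetical, and plain lists of complex numbers stand in for the tensors the module works with.

```python
def mix_solvent(fractions, component_f_sol):
    """Return sum_i(w_i * f_sol_i) for every reflection.

    fractions:        one population fraction per base model
    component_f_sol:  one list of complex solvent structure factors per
                      base model, all indexed by the same reflections
    """
    if len(fractions) != len(component_f_sol):
        raise ValueError("need exactly one fraction per component")
    n_refl = len(component_f_sol[0])
    if any(len(f) != n_refl for f in component_f_sol):
        raise ValueError("all components must cover the same reflections")
    return [
        sum(w * f[h] for w, f in zip(fractions, component_f_sol))
        for h in range(n_refl)
    ]


# Two hypothetical base models, three reflections each.
f_sol_a = [1 + 2j, 0 - 1j, 3 + 0j]
f_sol_b = [2 + 0j, 1 + 1j, -1 + 4j]
fractions = [0.75, 0.25]  # the same weights that mix the structural models

f_sol_mixed = mix_solvent(fractions, [f_sol_a, f_sol_b])
```

Because the mix is linear in the fractions, the component solvent structure factors can be computed once per base model and reused for every timepoint; only the weights change.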
- class torchref.scaling.collection_scaler.CollectionScaler(dataset_collection, model_collection, nbins=20, verbose=1, device=device(type='cpu'))[source]
Bases:
ScalerBase
Joint scaler for DatasetCollection + ModelCollection.
Shares scale parameters (log_scale, U, bin_wise_bfactor, k_sol, B_sol, phase) across all data–model pairs. Manages per-component solvent models so that the bulk-solvent contribution for a mixed model is the fraction-weighted sum of the individual component solvent structure factors.
- Parameters:
dataset_collection (DatasetCollection) – Collection of reflection datasets keyed by timepoint name.
model_collection (ModelCollection) – Collection of mixed models keyed by timepoint name.
nbins (int) – Number of resolution bins.
verbose (int) – Verbosity level.
device (torch.device) – Computation device.
Examples
from torchref.scaling.collection_scaler import CollectionScaler
scaler = CollectionScaler(datasets, models, device=device)
scaler.initialize()
scaler.refine_lbfgs_joint()
# In a target: scale a mixed-model F_calc with matching solvent
f_scaled = scaler.forward_mixed(f_calc, model.fractions)
- __init__(dataset_collection, model_collection, nbins=20, verbose=1, device=device(type='cpu'))[source]
Initialize ScalerBase.
If data is provided, fully initializes the scaler. If not provided (empty init), creates a shell ready for load_state_dict().
- Parameters:
data (ReflectionData, optional) – ReflectionData object with observed data.
nbins (int, default 20) – Number of resolution bins.
verbose (int, default 1) – Verbosity level.
device (torch.device, optional) – Computation device. If None, derived from data (if given) or the configured default via torchref.utils.resolve_device(). An explicit value forces data onto that device.
- initialize()[source]
One-shot initialization: joint initial scale, component solvents, anisotropy correction.
- Returns:
Self, for method chaining.
- Return type:
- get_mixed_solvent_raw(fractions)[source]
Compute fraction-weighted raw solvent SFs.
f_sol_mixed = sum_i(w_i * f_sol_raw_i)
- Parameters:
fractions (torch.Tensor) – Population fractions, shape (n_base_models,).
- Returns:
Mixed raw solvent structure factors (complex, un-damped).
- Return type:
- forward_mixed(fcalc, fractions)[source]
Scale fcalc using the shared parameters and a fraction-weighted solvent contribution.
This sets _f_sol_raw to the mixed solvent and then delegates to ScalerBase.forward(), which applies k_sol / B_sol / phase damping and the overall + anisotropic scale.
- Parameters:
fcalc (torch.Tensor) – Calculated structure factors (complex).
fractions (torch.Tensor) – Population fractions for the mixed model, shape (n_base_models,).
- Returns:
Scaled structure factors.
- Return type:
- refine_lbfgs_joint(nsteps=3, lr=1.0, max_iter=200, history_size=10, verbose=True)[source]
Refine scale parameters using LBFGS against all datasets.
The closure sums the NLL across every matched dataset–model pair, so a single set of scale parameters is fitted jointly.
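The idea of summing one objective over all pairs can be illustrated with a deliberately reduced model. The sketch below assumes a Gaussian NLL on amplitudes with constant terms dropped and a single shared `log_scale`; the function names and the toy data are hypothetical, and a closed-form minimiser replaces LBFGS so the example needs only the standard library.

```python
import math


def pair_nll(log_scale, f_obs, f_calc, sigma):
    """Gaussian NLL (constants dropped) for one dataset-model pair."""
    k = math.exp(log_scale)
    return sum(
        0.5 * ((fo - k * fc) / s) ** 2
        for fo, fc, s in zip(f_obs, f_calc, sigma)
    )


def joint_nll(log_scale, pairs):
    """Sum the per-pair NLL so every pair constrains the same scale."""
    return sum(pair_nll(log_scale, *pair) for pair in pairs)


def fit_shared_scale(pairs):
    """Closed-form minimiser of joint_nll for this simple Gaussian model."""
    num = den = 0.0
    for f_obs, f_calc, sigma in pairs:
        for fo, fc, s in zip(f_obs, f_calc, sigma):
            num += fo * fc / s**2
            den += fc * fc / s**2
    return math.log(num / den)


# Toy (f_obs, f_calc, sigma) triples for two hypothetical timepoints.
pairs = [
    ([20.0, 10.0, 6.0], [10.0, 5.0, 3.0], [1.0, 1.0, 1.0]),  # "dark"
    ([22.0, 8.0, 6.0], [11.0, 4.0, 3.0], [1.0, 2.0, 1.0]),   # "t1"
]

log_scale = fit_shared_scale(pairs)
print(math.exp(log_scale), joint_nll(log_scale, pairs))
```

Fitting each pair separately would instead give every timepoint its own scale, which is exactly the freedom a shared parameter set removes.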
- Parameters:
- Returns:
Refinement metrics (steps, rwork, rfree of dark dataset).
- Return type:
- screen_solvent_params_joint(steps=15)[source]
Grid-search k_sol / B_sol using NLL summed across all datasets.
- Parameters:
steps (int) – Grid points per parameter.
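A joint grid search of this kind might look like the following sketch. It assumes the common flat bulk-solvent form F_model = F_calc + k_sol * exp(-B_sol * s^2 / 4) * F_mask with s^2 = 1/d^2, uses a summed squared amplitude residual as a stand-in for the NLL, and ignores the phase parameter. The grid ranges, helper names, and synthetic data are all hypothetical and are not taken from the library.

```python
import math


def linspace(lo, hi, n):
    """n evenly spaced values from lo to hi inclusive."""
    if n < 2:
        raise ValueError("need at least two grid points")
    return [lo + i * (hi - lo) / (n - 1) for i in range(n)]


def model_amplitudes(f_calc, f_mask, s2, k_sol, b_sol):
    """|F_calc + k_sol * exp(-B_sol * s^2 / 4) * F_mask| per reflection."""
    return [
        abs(fc + k_sol * math.exp(-b_sol * s / 4.0) * fm)
        for fc, fm, s in zip(f_calc, f_mask, s2)
    ]


def joint_loss(datasets, k_sol, b_sol):
    """Squared amplitude residual summed over every dataset."""
    total = 0.0
    for f_obs, f_calc, f_mask, s2 in datasets:
        f_model = model_amplitudes(f_calc, f_mask, s2, k_sol, b_sol)
        total += sum((fo - fm) ** 2 for fo, fm in zip(f_obs, f_model))
    return total


def screen_solvent(datasets, steps=15, k_range=(0.1, 0.6), b_range=(20.0, 70.0)):
    """Exhaustive steps x steps search; returns the best (k_sol, B_sol)."""
    best = None
    for k_sol in linspace(*k_range, steps):
        for b_sol in linspace(*b_range, steps):
            loss = joint_loss(datasets, k_sol, b_sol)
            if best is None or loss < best[0]:
                best = (loss, k_sol, b_sol)
    return best[1], best[2]


# Synthetic data generated with k_sol = 0.3 and B_sol = 50.
s2 = [0.005, 0.02, 0.05, 0.1]


def make_dataset(f_calc, f_mask):
    return (model_amplitudes(f_calc, f_mask, s2, 0.3, 50.0), f_calc, f_mask, s2)


datasets = [
    make_dataset([50 + 10j, 30 - 20j, 12 + 5j, 8 - 3j],
                 [-40 - 5j, -15 + 8j, -3 - 1j, -1 + 0.5j]),
    make_dataset([48 + 12j, 31 - 18j, 11 + 6j, 8 - 2j],
                 [-39 - 6j, -14 + 9j, -3 - 1j, -1 + 0.5j]),
]

best_k, best_b = screen_solvent(datasets, steps=6)
```

The cost grows as steps squared times the number of reflections across all datasets, so a coarse screen like this is best used to seed a gradient-based refinement rather than to replace it.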
- update_all_solvent()[source]
Recompute solvent masks for all component models.
Call this after structure refinement changes base-model coordinates.
- invalidate_solvent_cache()[source]
Clear cached raw solvent SFs (forces recomputation on next call).
- property component_solvent_models: ModuleList
Per-component SolventModel instances (read-only).