torchref.refinement.targets.sampled_ml_phase_target module
SampledMLPhaseTarget - Phase-aware ML target using reparameterized sampling.
This module provides a maximum-likelihood refinement target that accounts for phase uncertainty derived from amplitude errors. Uses the reparameterization trick to enable differentiable Monte Carlo estimation of expected structure factor errors in complex space.
Key insight: Standard ML refinement treats calculated phases as exact. This approach samples phases from a distribution whose width depends on amplitude discrepancy, allowing gradients to naturally account for phase uncertainty.
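The reparameterization idea mentioned above can be illustrated without any autograd machinery. The following is a minimal sketch in plain Python, not torchref code: the function names are hypothetical, and a scalar cos(phi) stands in for the real loss. It writes a phase sample as phi = phi_ref + sigma_phi * eps with eps ~ N(0, 1), so the derivative with respect to sigma_phi can be taken through each sample, and compares the resulting Monte Carlo gradient with the closed form that follows from E[cos(phi)] = cos(phi_ref) * exp(-sigma_phi^2 / 2).

```python
import math
import random


def pathwise_grad_sigma(phi_ref, sigma_phi, n_samples, rng):
    """MC estimate of d/d(sigma_phi) E[cos(phi)], phi = phi_ref + sigma_phi * eps."""
    total = 0.0
    for _ in range(n_samples):
        eps = rng.gauss(0.0, 1.0)        # noise drawn independently of the parameters
        phi = phi_ref + sigma_phi * eps  # reparameterized sample, phi ~ N(phi_ref, sigma_phi^2)
        total += -math.sin(phi) * eps    # chain rule: d cos(phi) / d sigma_phi = -sin(phi) * eps
    return total / n_samples


def exact_grad_sigma(phi_ref, sigma_phi):
    """Derivative of cos(phi_ref) * exp(-sigma_phi^2 / 2) with respect to sigma_phi."""
    return -sigma_phi * math.cos(phi_ref) * math.exp(-0.5 * sigma_phi ** 2)


rng = random.Random(0)  # fixed seed so the run is repeatable
mc = pathwise_grad_sigma(0.4, 0.6, 100_000, rng)
exact = exact_grad_sigma(0.4, 0.6)
print(f"pathwise MC gradient: {mc:.4f}   closed form: {exact:.4f}")
```

In a tensor framework the same construction lets automatic differentiation propagate gradients through the sampled phases, because the randomness sits in eps rather than in the parameters.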
- class torchref.refinement.targets.sampled_ml_phase_target.SampledMLPhaseTarget(data=None, model=None, scaler=None, phi_ref=None, n_samples=32, sigma_model_log=0.15, use_analytical=False, use_antithetic=True, use_work_set=True, verbose=0)[source]
Bases: XrayTarget
Phase-aware ML target using reparameterized sampling.
Computes E[|F_obs*exp(i*phi) - F_calc|^2] where phi ~ N(phi_ref, sigma_phi^2) with sigma_phi derived from amplitude errors and discrepancy.
Uses French-Wilson posteriors for amplitude estimation and supports both Monte Carlo sampling and analytical evaluation.
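For a single reflection, the expectation above can be estimated by sampling or evaluated in closed form. The sketch below is illustrative only: the function names are hypothetical, it uses plain Python scalars rather than tensors, and this page does not state that the use_analytical path uses exactly this expression. The closed form relies on the Gaussian identity E[cos(phi - c)] = cos(phi_ref - c) * exp(-sigma_phi^2 / 2).

```python
import cmath
import math
import random


def expected_loss_mc(f_obs_amp, f_calc, phi_ref, sigma_phi, n_samples, rng):
    """Monte Carlo estimate of E[|F_obs * exp(i*phi) - F_calc|^2]."""
    total = 0.0
    for _ in range(n_samples):
        phi = phi_ref + sigma_phi * rng.gauss(0.0, 1.0)  # phi ~ N(phi_ref, sigma_phi^2)
        total += abs(f_obs_amp * cmath.exp(1j * phi) - f_calc) ** 2
    return total / n_samples


def expected_loss_closed_form(f_obs_amp, f_calc, phi_ref, sigma_phi):
    """Same expectation, with the cosine cross term damped by exp(-sigma_phi^2 / 2)."""
    damping = math.exp(-0.5 * sigma_phi ** 2)
    cross = math.cos(phi_ref - cmath.phase(f_calc)) * damping
    return f_obs_amp ** 2 + abs(f_calc) ** 2 - 2.0 * f_obs_amp * abs(f_calc) * cross


f_calc = cmath.rect(0.8, 0.3)  # amplitude 0.8, phase 0.3 rad (made-up values)
rng = random.Random(1)
mc = expected_loss_mc(1.0, f_calc, 0.5, 0.4, 50_000, rng)
exact = expected_loss_closed_form(1.0, f_calc, 0.5, 0.4)
print(f"MC estimate: {mc:.4f}   closed form: {exact:.4f}")
```

As sigma_phi grows, the damping factor shrinks the cross term, so reflections with uncertain phases contribute less phase-dependent signal to the loss.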
- Parameters:
data (ReflectionData) – Reference to the ReflectionData object.
model (Model or ModelFT, optional) – Reference to Model object for F_calc computation.
scaler (Scaler, optional) – Reference to the Scaler object.
phi_ref (torch.Tensor, optional) – Reference phases (e.g., from dark state). If None, uses phi_calc.
n_samples (int, optional) – Number of MC samples. Default is 32.
sigma_model_log (float, optional) – Model error in log(I) space (~R_work). Default is 0.15.
use_analytical (bool, optional) – Use closed-form instead of MC sampling. Default is False.
use_antithetic (bool, optional) – Use antithetic sampling for variance reduction. Default is True.
use_work_set (bool, optional) – If True, compute loss on work set. Default is True.
verbose (int, optional) – Verbosity level. Default is 0.
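The use_antithetic option refers to a standard variance-reduction technique: each noise draw eps is paired with its mirror image -eps. The sketch below shows the general idea on a scalar integrand and is not taken from torchref; the function name and the numeric values are made up for illustration.

```python
import math
import random
import statistics


def estimate_mean_cos(delta, sigma_phi, n_samples, rng, antithetic):
    """Estimate E[cos(delta + sigma_phi * eps)], eps ~ N(0, 1), from n_samples evaluations."""
    if antithetic:
        half = [rng.gauss(0.0, 1.0) for _ in range(n_samples // 2)]
        eps = half + [-e for e in half]  # mirror every draw
    else:
        eps = [rng.gauss(0.0, 1.0) for _ in range(n_samples)]
    return sum(math.cos(delta + sigma_phi * e) for e in eps) / len(eps)


rng = random.Random(2)
# Repeat each 32-sample estimate 500 times to compare the spread of the two estimators.
plain = [estimate_mean_cos(1.0, 0.5, 32, rng, False) for _ in range(500)]
mirrored = [estimate_mean_cos(1.0, 0.5, 32, rng, True) for _ in range(500)]
spread_plain = statistics.stdev(plain)
spread_mirrored = statistics.stdev(mirrored)
exact = math.cos(1.0) * math.exp(-0.125)  # cos(delta) * exp(-sigma_phi^2 / 2)
print(f"stdev plain: {spread_plain:.4f}   stdev antithetic: {spread_mirrored:.4f}")
```

How much mirroring helps depends on the integrand. It removes the part of the noise that is odd in eps, so an integrand that is even in eps (delta = 0 here) gains nothing from it.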
Examples
Basic usage with model:
    target = SampledMLPhaseTarget(
        data=reflection_data,
        model=model,
        scaler=scaler,
        n_samples=32,
    )
    loss = target()  # Computes F_calc internally
With pre-computed F_calc:
    target = SampledMLPhaseTarget(data=reflection_data)
    loss = target(fcalc=F_calc_precomputed)
With reference phases from dark state:
    target = SampledMLPhaseTarget(
        data=light_data,
        model=light_model,
        phi_ref=torch.angle(F_dark_calc),
    )
- __init__(data=None, model=None, scaler=None, phi_ref=None, n_samples=32, sigma_model_log=0.15, use_analytical=False, use_antithetic=True, use_work_set=True, verbose=0)[source]
Initialize X-ray target.
- Parameters:
data (ReflectionData, optional) – Reference to the ReflectionData object. Required for forward().
model (Model or ModelFT, optional) – Reference to Model object for F_calc computation. If None, fcalc must be provided to forward().
scaler (Scaler, optional) – Reference to the Scaler object.
use_work_set (bool, optional) – If True, compute loss on work set; if False, on test set. Default is True.
sigma_mode (str, optional) – Which sigma to use in the likelihood. Options:
'raw' (default): use the raw experimental sigmas from the data file. Empirically gives the best Rfree across the mid-resolution regime (1.5-3.0 Å) when paired with appropriate group weights.
'effective': use per-shell effective sigmas estimated from scaling residuals (capped SIGMAA-style correction). Opt-in for high-resolution refinement (< 1.5 Å) or datasets with known sigma miscalibration.
Note: Scaler.estimate_sigma_eff is always called, so the estimates are available regardless of which mode the target uses.
verbose (int, optional) – Verbosity level. Default is 0.
- french_wilson_moments(I_obs, sigma_I, Sigma_wilson=None)[source]
Compute posterior mean and variance of |F_true| given I_obs.
Properly handles negative and weak intensities using numerical integration over a grid.
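As an illustration of grid-based integration for this kind of posterior, the sketch below treats a single acentric reflection in plain Python. It assumes a Wilson prior p(J) proportional to exp(-J / Sigma_wilson) on the true intensity J >= 0 and a Gaussian measurement error on I_obs. Those assumptions, the function name, and the grid settings are illustrative choices, not a description of the torchref implementation.

```python
import math


def french_wilson_moments_acentric(i_obs, sigma_i, sigma_wilson, n_grid=2000):
    """Posterior mean and variance of |F| = sqrt(J) by midpoint-rule integration."""
    # Completing the square shows the posterior in J is a Gaussian centred here,
    # with standard deviation sigma_i, truncated to J >= 0.
    center = i_obs - sigma_i ** 2 / sigma_wilson
    lo = max(0.0, center - 8.0 * sigma_i)
    hi = max(0.0, center) + 8.0 * sigma_i
    step = (hi - lo) / n_grid
    grid = [lo + (k + 0.5) * step for k in range(n_grid)]
    # Unnormalized log posterior: Wilson prior times Gaussian likelihood.
    log_w = [-j / sigma_wilson - (i_obs - j) ** 2 / (2.0 * sigma_i ** 2) for j in grid]
    shift = max(log_w)  # subtract the maximum before exponentiating for stability
    w = [math.exp(lw - shift) for lw in log_w]
    norm = sum(w)
    mean_f = sum(wi * math.sqrt(j) for wi, j in zip(w, grid)) / norm
    mean_f2 = sum(wi * j for wi, j in zip(w, grid)) / norm  # E[|F|^2] = E[J]
    var_f = max(mean_f2 - mean_f ** 2, 0.0)
    return mean_f, var_f


strong = french_wilson_moments_acentric(100.0, 1.0, 100.0)  # well-measured reflection
weak = french_wilson_moments_acentric(-1.0, 1.0, 10.0)      # negative observed intensity
print("strong:", strong)
print("weak:  ", weak)
```

Because the prior excludes negative true intensities, a negative I_obs still yields a positive amplitude estimate with a finite variance, which is the behaviour the description above refers to.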
- Parameters:
I_obs (torch.Tensor) – Observed intensities.
sigma_I (torch.Tensor) – Intensity uncertainties.
Sigma_wilson (torch.Tensor, optional) – Wilson expected intensities.
- Returns:
- Return type:
- compute_sigma_phi(F_obs, sigma_F_obs, F_calc_amp)[source]
Compute phase uncertainty from amplitude uncertainties and discrepancy.
The phase uncertainty has three components:
1. Measurement uncertainty: sigma_F_obs / |F_obs|
2. Model uncertainty: sigma_model_log (multiplicative)
3. Excess from amplitude discrepancy beyond expected
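This page does not give the formula that combines the three components, so the following is only one plausible way to do it, written as a scalar sketch in plain Python. Everything specific here is an assumption: the quadrature sum, the factor 0.5 that converts a log(I) error into a log|F| error (since I = |F|^2), the use of a relative amplitude error as a phase error in radians, and the cap at pi / sqrt(3), the standard deviation of a uniformly distributed phase. The function name is hypothetical.

```python
import math


def sigma_phi_sketch(f_obs, sigma_f_obs, f_calc_amp, sigma_model_log=0.15, eps=1e-8):
    """Illustrative phase uncertainty (radians) from amplitude information only."""
    # 1. Relative measurement uncertainty of the observed amplitude.
    rel_meas = sigma_f_obs / max(f_obs, eps)
    # 2. Model uncertainty: a sigma in log(I) space is halved in log|F| space.
    rel_model = 0.5 * sigma_model_log
    expected_sq = rel_meas ** 2 + rel_model ** 2
    # 3. Excess: squared log-amplitude discrepancy beyond what 1 and 2 explain.
    log_ratio = math.log(max(f_obs, eps) / max(f_calc_amp, eps))
    excess_sq = max(log_ratio ** 2 - expected_sq, 0.0)
    # ASSUMPTION: add in quadrature, then cap at the spread of a uniform phase.
    return min(math.sqrt(expected_sq + excess_sq), math.pi / math.sqrt(3.0))


agree = sigma_phi_sketch(10.0, 0.5, 10.0)    # amplitudes agree
disagree = sigma_phi_sketch(10.0, 0.5, 5.0)  # calculated amplitude is half the observed one
print(f"agreeing amplitudes: {agree:.4f} rad   factor-of-two mismatch: {disagree:.4f} rad")
```

The qualitative behaviour is the point of the sketch: the estimate never drops below the measurement and model floor, and it widens when the calculated amplitude disagrees with the observed one by more than that floor explains.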
- Parameters:
F_obs (torch.Tensor) – Observed amplitudes (or French-Wilson means).
sigma_F_obs (torch.Tensor) – Amplitude uncertainties.
F_calc_amp (torch.Tensor) – Calculated amplitudes |F_calc|.
- Returns:
sigma_phi – Phase uncertainty in radians.
- Return type:
- forward(fcalc=None, recalc=True)[source]
Compute phase-aware ML loss.
- Parameters:
fcalc (torch.Tensor, optional) – Pre-computed complex structure factors. If provided, uses these instead of computing from model.
recalc (bool, optional) – Force recalculation if True. Default is True.
- Returns:
Mean weighted loss value.
- Return type:
- stats(fcalc=None)[source]
Get statistics for this target.
- Parameters:
fcalc (torch.Tensor, optional) – Pre-computed structure factors.
- Returns:
Statistics dict with StatEntry values containing verbosity levels.
- Return type:
- class torchref.refinement.targets.sampled_ml_phase_target.SampledMLDifferenceTarget(dataset_collection, model_light=None, model_dark=None, scaler_light=None, scaler_dark=None, n_samples=32, sigma_model_log=0.15, use_work_set=True, verbose=0)[source]
Bases: Target
Phase-aware difference target for two-dataset refinement.
Uses dark state phases as reference, with phase uncertainty informed by amplitude changes between states. Jointly refines against both dark and light datasets.
- Parameters:
dataset_collection (DatasetCollection) – Collection containing ‘dark’ and ‘light’ datasets.
model_light (ModelFT or MixedModel) – Model for the light/excited state.
model_dark (ModelFT) – Model for the dark/ground state.
scaler_light (Scaler, optional) – Scaler for light state F_calc.
scaler_dark (Scaler, optional) – Scaler for dark state F_calc.
n_samples (int, optional) – Number of MC samples. Default is 32.
sigma_model_log (float, optional) – Model error in log(I) space. Default is 0.15.
use_work_set (bool, optional) – If True, compute loss on work set. Default is True.
verbose (int, optional) – Verbosity level. Default is 0.
Examples
Basic usage:
    target = SampledMLDifferenceTarget(
        dataset_collection=collection,
        model_light=mixed_model,
        model_dark=model_dark,
        n_samples=32,
    )
    loss = target()
- __init__(dataset_collection, model_light=None, model_dark=None, scaler_light=None, scaler_dark=None, n_samples=32, sigma_model_log=0.15, use_work_set=True, verbose=0)[source]
Initialize target.
- Parameters:
verbose (int, optional) – Verbosity level. Default is 0.
- forward(fcalc_light=None, fcalc_dark=None, recalc=True)[source]
Compute phase-aware difference loss.
Jointly refines against dark and light datasets using dark phases as reference. Phase uncertainty increases for reflections with large amplitude changes between states.
- Parameters:
fcalc_light (torch.Tensor, optional) – Pre-computed light state structure factors.
fcalc_dark (torch.Tensor, optional) – Pre-computed dark state structure factors.
recalc (bool, optional) – Force recalculation if True. Default is True.
- Returns:
Combined loss for both datasets.
- Return type:
- stats(fcalc_light=None, fcalc_dark=None)[source]
Get statistics for difference refinement.
- Parameters:
fcalc_light (torch.Tensor, optional) – Pre-computed light state structure factors.
fcalc_dark (torch.Tensor, optional) – Pre-computed dark state structure factors.
- Returns:
Statistics dict with StatEntry values.
- Return type:
- torchref.refinement.targets.sampled_ml_phase_target.create_sampled_ml_target(data=None, model=None, scaler=None, phi_ref=None, n_samples=32, sigma_model_log=0.15, use_analytical=False, use_work_set=True, verbose=0)[source]
Factory function to create SampledMLPhaseTarget.
See SampledMLPhaseTarget for parameter documentation.
- Returns:
Configured target instance.
- Return type:
- torchref.refinement.targets.sampled_ml_phase_target.create_sampled_ml_difference_target(dataset_collection, model_light=None, model_dark=None, scaler_light=None, scaler_dark=None, n_samples=32, sigma_model_log=0.15, use_work_set=True, verbose=0)[source]
Factory function to create SampledMLDifferenceTarget.
See SampledMLDifferenceTarget for parameter documentation.
- Returns:
Configured target instance.
- Return type: