"""
SampledMLPhaseTarget - Phase-aware ML target using reparameterized sampling.
This module provides a maximum-likelihood refinement target that accounts for
phase uncertainty derived from amplitude errors. Uses the reparameterization
trick to enable differentiable Monte Carlo estimation of expected structure
factor errors in complex space.
Key insight: Standard ML refinement treats calculated phases as exact. This
approach samples phases from a distribution whose width depends on amplitude
discrepancy, allowing gradients to naturally account for phase uncertainty.
"""
import numpy as np
import torch
from typing import TYPE_CHECKING, Dict, Tuple
from .base import Target
from .xray import XrayTarget
from torchref.utils.stats import (
VERBOSITY_STANDARD,
VERBOSITY_DETAILED,
VERBOSITY_DEBUG,
StatEntry,
stat,
)
if TYPE_CHECKING:
from torchref.io import ReflectionData
from torchref.io.datasets.collection import DatasetCollection
from torchref.model import Model, ModelFT
from torchref.scaling import Scaler
class SampledMLPhaseTarget(XrayTarget):
    """
    Phase-aware ML target using reparameterized sampling.

    Computes E[|F_obs*exp(i*phi) - F_calc|^2] where phi ~ N(phi_ref, sigma_phi^2)
    with sigma_phi derived from amplitude errors and discrepancy.

    Uses French-Wilson posteriors for amplitude estimation and supports
    both Monte Carlo sampling and analytical evaluation.

    Parameters
    ----------
    data : ReflectionData
        Reference to the ReflectionData object.
    model : Model or ModelFT, optional
        Reference to Model object for F_calc computation.
    scaler : Scaler, optional
        Reference to the Scaler object.
    phi_ref : torch.Tensor, optional
        Reference phases (e.g., from dark state) for the full reflection
        list of ``data``; the work/test subset is selected at forward time
        using the data's R-free mask. If None, uses phi_calc.
    n_samples : int, optional
        Number of MC samples. Default is 32. At least one sample is always
        drawn; with antithetic sampling the count is rounded up to an even
        number so every draw has its mirrored partner.
    sigma_model_log : float, optional
        Model error in log(I) space (~R_work). Default is 0.15.
    use_analytical : bool, optional
        Use closed-form instead of MC sampling. Default is False.
    use_antithetic : bool, optional
        Use antithetic sampling for variance reduction. Default is True.
    use_work_set : bool, optional
        If True, compute loss on work set. Default is True.
    verbose : int, optional
        Verbosity level. Default is 0.

    Attributes
    ----------
    name : str
        Target name for LossState registration.

    Examples
    --------
    Basic usage with model::

        target = SampledMLPhaseTarget(
            data=reflection_data,
            model=model,
            scaler=scaler,
            n_samples=32,
        )
        loss = target()  # Computes F_calc internally

    With pre-computed F_calc::

        target = SampledMLPhaseTarget(data=reflection_data)
        loss = target(fcalc=F_calc_precomputed)

    With reference phases from dark state::

        target = SampledMLPhaseTarget(
            data=light_data,
            model=light_model,
            phi_ref=torch.angle(F_dark_calc),
        )
    """

    name: str = "xray_sampled_ml"

    # Lower bound applied to intensity sigmas and prior scales before they are
    # used as divisors in the French-Wilson posterior. Without it, F_obs == 0 or
    # sigma_F == 0 makes the log-posterior 0/0 and the loss NaN.
    _MIN_SIGMA: float = 1e-6

    def __init__(
        self,
        data: "ReflectionData" = None,
        model: "Model" = None,
        scaler: "Scaler" = None,
        phi_ref: torch.Tensor = None,
        n_samples: int = 32,
        sigma_model_log: float = 0.15,
        use_analytical: bool = False,
        use_antithetic: bool = True,
        use_work_set: bool = True,
        verbose: int = 0,
    ):
        """Initialize the target; see the class docstring for parameters."""
        super().__init__(
            data=data,
            model=model,
            scaler=scaler,
            use_work_set=use_work_set,
            verbose=verbose,
        )

        # Update name based on work/test set
        self.name = "xray_sampled_ml_work" if use_work_set else "xray_sampled_ml_test"

        # Register tunable parameters as buffers for state_dict access
        self.register_buffer("_n_samples", torch.tensor(n_samples, dtype=torch.int64))
        self.register_buffer("_sigma_model_log", torch.tensor(sigma_model_log))
        self.register_buffer("_use_analytical", torch.tensor(use_analytical))
        self.register_buffer("_use_antithetic", torch.tensor(use_antithetic))

        # Reference phases (optional - if None, uses phi_calc). Registering a
        # None buffer keeps the slot a buffer, so a tensor assigned later is
        # still moved by .to() and included in state_dict.
        self.register_buffer("_phi_ref", phi_ref)

        # Cache for diagnostics (populated on forward)
        self._last_diagnostics: Dict[str, torch.Tensor] = {}

    # =========================================================================
    # Property accessors for buffer parameters
    # =========================================================================

    @property
    def n_samples(self) -> int:
        """Get number of MC samples."""
        return self._n_samples.item()

    @n_samples.setter
    def n_samples(self, value: int):
        """Set number of MC samples."""
        self._n_samples.fill_(value)

    @property
    def sigma_model_log(self) -> float:
        """Get model error in log(I) space."""
        return self._sigma_model_log.item()

    @sigma_model_log.setter
    def sigma_model_log(self, value: float):
        """Set model error in log(I) space."""
        self._sigma_model_log.fill_(value)

    @property
    def use_analytical(self) -> bool:
        """Get whether to use analytical form."""
        return self._use_analytical.item()

    @use_analytical.setter
    def use_analytical(self, value: bool):
        """Set whether to use analytical form."""
        self._use_analytical.fill_(value)

    @property
    def use_antithetic(self) -> bool:
        """Get whether to use antithetic sampling."""
        return self._use_antithetic.item()

    @use_antithetic.setter
    def use_antithetic(self, value: bool):
        """Set whether to use antithetic sampling."""
        self._use_antithetic.fill_(value)

    # =========================================================================
    # Core computation methods
    # =========================================================================

    def french_wilson_moments(
        self,
        I_obs: torch.Tensor,
        sigma_I: torch.Tensor,
        Sigma_wilson: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute posterior mean and variance of |F_true| given I_obs.

        Properly handles negative and weak intensities using numerical
        integration over a grid. The prior is the acentric Wilson
        distribution for every reflection (centric reflections are not
        treated separately).

        Parameters
        ----------
        I_obs : torch.Tensor
            Observed intensities.
        sigma_I : torch.Tensor
            Intensity uncertainties. Values below ``_MIN_SIGMA`` are raised
            to it so the Gaussian likelihood stays finite.
        Sigma_wilson : torch.Tensor, optional
            Wilson expected intensities. Defaults to ``max(I_obs, sigma_I)``
            per reflection; floored at ``_MIN_SIGMA``.

        Returns
        -------
        F_mean : torch.Tensor
            Posterior mean of |F|.
        F_var : torch.Tensor
            Posterior variance of |F| (clamped to at least 1e-6).
        """
        device = I_obs.device
        dtype = I_obs.dtype

        # Floors prevent 0/0 in the likelihood and prior terms below.
        sigma_I = torch.clamp(sigma_I, min=self._MIN_SIGMA)
        if Sigma_wilson is None:
            Sigma_wilson = torch.clamp(I_obs, min=sigma_I)
        Sigma_wilson = torch.clamp(Sigma_wilson, min=self._MIN_SIGMA)

        # Integration grid (30 points is sufficient for smooth posteriors)
        n_points = 30
        F_max = (
            torch.sqrt(torch.clamp(I_obs + 4 * sigma_I, min=sigma_I))
            + torch.sqrt(Sigma_wilson)
        )
        t = torch.linspace(0.01, 1.0, n_points, device=device, dtype=dtype)
        F_grid = t.unsqueeze(0) * F_max.unsqueeze(1)  # (n_refl, n_points)
        dF = F_grid[:, 1:2] - F_grid[:, :1]  # uniform spacing per reflection

        I_obs_exp = I_obs.unsqueeze(1)
        sigma_I_exp = sigma_I.unsqueeze(1)
        Sigma_exp = Sigma_wilson.unsqueeze(1)

        # Log posterior over |F| (up to a constant): Gaussian likelihood of
        # I_obs given I = F^2, plus the acentric Wilson prior
        # log(2F/Sigma) - F^2/Sigma (its constant dropped).
        log_post = (
            -(I_obs_exp - F_grid**2) ** 2 / (2 * sigma_I_exp**2)  # Gaussian in I
            + torch.log(F_grid + 1e-10)  # Wilson prior: F factor
            - F_grid**2 / Sigma_exp  # Wilson prior: exponential
        )
        # Subtract the per-row max before exp() for numerical stability.
        log_post = log_post - log_post.max(dim=1, keepdim=True)[0]
        post = torch.exp(log_post)
        post = post / (post.sum(dim=1, keepdim=True) * dF + 1e-10)

        # Moments (Riemann sums on the grid)
        F_mean = (post * F_grid * dF).sum(dim=1)
        F_sq_mean = (post * F_grid**2 * dF).sum(dim=1)
        F_var = torch.clamp(F_sq_mean - F_mean**2, min=1e-6)

        return F_mean, F_var

    def compute_sigma_phi(
        self,
        F_obs: torch.Tensor,
        sigma_F_obs: torch.Tensor,
        F_calc_amp: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute phase uncertainty from amplitude uncertainties and discrepancy.

        The phase uncertainty has three components:
        1. Measurement uncertainty: sigma_F_obs / |F_obs|
        2. Model uncertainty: sigma_model_log (multiplicative)
        3. Excess from amplitude discrepancy beyond expected

        Parameters
        ----------
        F_obs : torch.Tensor
            Observed amplitudes (or French-Wilson means).
        sigma_F_obs : torch.Tensor
            Amplitude uncertainties.
        F_calc_amp : torch.Tensor
            Calculated amplitudes |F_calc|.

        Returns
        -------
        sigma_phi : torch.Tensor
            Phase uncertainty in radians, clamped to [0.01, 2.0].
        """
        # Model error (lognormal -> multiplicative in |F|)
        sigma_F_model = self.sigma_model_log * F_obs

        # Total amplitude uncertainty
        sigma_F_total = torch.sqrt(sigma_F_obs**2 + sigma_F_model**2)

        # Base phase uncertainty from amplitude uncertainties
        sigma_phi_meas = sigma_F_obs / (F_obs + 1e-6)
        sigma_phi_model = sigma_F_model / (F_calc_amp + 1e-6)

        # Excess from amplitude discrepancy beyond what the errors explain
        amplitude_discrepancy = torch.abs(F_obs - F_calc_amp)
        excess = torch.clamp(amplitude_discrepancy - sigma_F_total, min=0)
        sigma_phi_excess = excess / (torch.minimum(F_obs, F_calc_amp) + 1e-6)

        # Combined in quadrature
        sigma_phi = torch.sqrt(
            sigma_phi_meas**2 + sigma_phi_model**2 + sigma_phi_excess**2
        )
        return torch.clamp(sigma_phi, min=0.01, max=2.0)

    def _select_reference_phases(self, phi_calc: torch.Tensor) -> torch.Tensor:
        """Return reference phases aligned with the reflections of ``phi_calc``.

        Falls back to ``phi_calc`` when no ``phi_ref`` was supplied. Otherwise
        the stored full-dataset phases are subset with the data's R-free mask,
        using the same work/test convention as ``self.use_work_set``.

        Raises
        ------
        ValueError
            If the selected phases do not match ``phi_calc`` in shape.
        """
        if self._phi_ref is None:
            return phi_calc

        _, _, _, rfree_mask = self._data()
        # rfree_mask may be int32 (0/1); convert to bool for proper masking.
        rfree_bool = rfree_mask.bool()
        selection = rfree_bool if self.use_work_set else ~rfree_bool
        phi_ref = self._phi_ref[selection]

        if phi_ref.shape != phi_calc.shape:
            subset = "work" if self.use_work_set else "test"
            raise ValueError(
                f"phi_ref selects {tuple(phi_ref.shape)} phases for the {subset} "
                f"set, but get_data() returned {tuple(phi_calc.shape)} reflections; "
                "phi_ref must cover the full reflection list of the data object."
            )
        return phi_ref

    @staticmethod
    def _standard_normal(
        n_refl: int,
        n_samples: int,
        antithetic: bool,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Draw a (n_refl, k) block of standard-normal variates.

        At least one sample is drawn. For antithetic sampling the draws are
        returned as ``[eps, -eps]``; an odd request is rounded up so that
        ``n_samples == 1`` still produces samples (``n // 2`` pairs would
        produce none, and their mean would be NaN).
        """
        n_samples = max(int(n_samples), 1)
        if antithetic:
            half = (n_samples + 1) // 2
            eps = torch.randn(n_refl, half, device=device, dtype=dtype)
            return torch.cat([eps, -eps], dim=1)
        return torch.randn(n_refl, n_samples, device=device, dtype=dtype)

    def forward(
        self,
        fcalc: torch.Tensor = None,
        recalc: bool = True,
    ) -> torch.Tensor:
        """
        Compute phase-aware ML loss.

        Parameters
        ----------
        fcalc : torch.Tensor, optional
            Pre-computed complex structure factors. If provided, uses these
            instead of computing from model.
        recalc : bool, optional
            Accepted for interface compatibility; this method does not
            forward it to ``get_data``. Default is True.

        Returns
        -------
        torch.Tensor
            Mean weighted loss value.

        Raises
        ------
        ValueError
            If ``phi_ref`` does not line up with the selected reflections.
        """
        # get_data() handles work/test selection. Centric flags are unused:
        # the French-Wilson prior here is acentric for all reflections.
        F_obs, F_calc, sigma_F_obs, _centric = self.get_data(fcalc=fcalc)
        n_refl = F_obs.shape[0]

        F_calc_amp = torch.abs(F_calc)
        phi_calc = torch.angle(F_calc)
        phi_ref = self._select_reference_phases(phi_calc)

        # French-Wilson takes intensities; convert F_obs with first-order
        # error propagation (sigma_I = 2 F sigma_F).
        I_obs = F_obs**2
        sigma_I = 2 * F_obs * sigma_F_obs
        F_mean, F_var = self.french_wilson_moments(I_obs, sigma_I)
        sigma_F_meas = torch.sqrt(F_var)

        # Phase uncertainty and effective figure of merit
        # (E[cos(dphi)] = exp(-sigma^2 / 2) for dphi ~ N(0, sigma^2)).
        sigma_phi = self.compute_sigma_phi(F_mean, sigma_F_meas, F_calc_amp)
        m_eff = torch.exp(-sigma_phi**2 / 2)

        if self.use_analytical:
            # Closed form of the MC expectation below (no sampling noise).
            phase_shift = phi_ref - phi_calc
            expected_sq_error = (
                F_mean**2
                + F_calc_amp**2
                - 2 * F_mean * F_calc_amp * torch.cos(phase_shift) * m_eff
            )
        else:
            eps = self._standard_normal(
                n_refl,
                self.n_samples,
                self.use_antithetic,
                F_mean.device,
                F_mean.dtype,
            )
            # Reparameterization: phi = phi_ref + sigma_phi * eps keeps the
            # samples differentiable with respect to sigma_phi and phi_ref.
            phi_samples = phi_ref.unsqueeze(1) + sigma_phi.unsqueeze(1) * eps

            # Complex error F_obs_sample - F_calc, split into real/imag parts.
            error_real = F_mean.unsqueeze(1) * torch.cos(phi_samples) - F_calc.real.unsqueeze(1)
            error_imag = F_mean.unsqueeze(1) * torch.sin(phi_samples) - F_calc.imag.unsqueeze(1)

            # Monte Carlo estimate of E|error|^2
            expected_sq_error = (error_real**2 + error_imag**2).mean(dim=1)

        # Inverse-variance weights from measurement + model error
        sigma_F_model = self.sigma_model_log * F_mean
        sigma_F_total = torch.sqrt(sigma_F_meas**2 + sigma_F_model**2)
        weights = 1.0 / (sigma_F_total**2 + 1e-6)
        loss = (weights * expected_sq_error).mean()

        # Store diagnostics (F_obs / F_calc_amp let stats() avoid a second
        # get_data() call).
        self._last_diagnostics = {
            "F_mean": F_mean.detach(),
            "sigma_F_total": sigma_F_total.detach(),
            "sigma_phi": sigma_phi.detach(),
            "m_eff": m_eff.detach(),
            "expected_sq_error": expected_sq_error.detach(),
            "weights": weights.detach(),
            "F_obs": F_obs.detach(),
            "F_calc_amp": F_calc_amp.detach(),
        }

        return loss

    def stats(self, fcalc: torch.Tensor = None) -> Dict[str, StatEntry]:
        """
        Get statistics for this target.

        Parameters
        ----------
        fcalc : torch.Tensor, optional
            Pre-computed structure factors.

        Returns
        -------
        dict
            Statistics dict with StatEntry values containing verbosity levels.
        """
        # One forward pass (without autograd) populates every diagnostic,
        # including the amplitudes needed for the R-factor.
        with torch.no_grad():
            loss = self.forward(fcalc=fcalc)
        diag = self._last_diagnostics

        F_obs = diag["F_obs"]
        r_factor = (
            torch.abs(F_obs - diag["F_calc_amp"]).sum() / (F_obs.sum() + 1e-8)
        ).item()

        return {
            "loss": stat(loss.item(), VERBOSITY_STANDARD),
            "n": stat(len(F_obs), VERBOSITY_DEBUG),
            "r_factor": stat(r_factor, VERBOSITY_STANDARD),
            "mean_m_eff": stat(diag["m_eff"].mean().item(), VERBOSITY_STANDARD),
            "mean_sigma_phi": stat(diag["sigma_phi"].mean().item(), VERBOSITY_DETAILED),
            "min_m_eff": stat(diag["m_eff"].min().item(), VERBOSITY_DETAILED),
            "max_sigma_phi": stat(diag["sigma_phi"].max().item(), VERBOSITY_DETAILED),
            "n_samples": stat(self.n_samples, VERBOSITY_DEBUG),
            "sigma_model_log": stat(self.sigma_model_log, VERBOSITY_DEBUG),
        }

    def __repr__(self) -> str:
        mode = "analytical" if self.use_analytical else f"MC({self.n_samples})"
        return f"SampledMLPhaseTarget(mode={mode}, sigma_model={self.sigma_model_log:.3f})"
class SampledMLDifferenceTarget(Target):
    """
    Phase-aware difference target for two-dataset refinement.

    Uses dark state phases as reference, with phase uncertainty
    informed by amplitude changes between states. Jointly refines
    against both dark and light datasets.

    Parameters
    ----------
    dataset_collection : DatasetCollection
        Collection containing 'dark' and 'light' datasets.
    model_light : ModelFT or MixedModel
        Model for the light/excited state.
    model_dark : ModelFT
        Model for the dark/ground state.
    scaler_light : Scaler, optional
        Scaler for light state F_calc.
    scaler_dark : Scaler, optional
        Scaler for dark state F_calc.
    n_samples : int, optional
        Number of MC samples. Default is 32. Sampling is always antithetic,
        so the count is rounded up to an even number (minimum 2).
    sigma_model_log : float, optional
        Model error in log(I) space. Default is 0.15.
    use_work_set : bool, optional
        If True, compute loss on work set. Default is True.
    verbose : int, optional
        Verbosity level. Default is 0.

    Raises
    ------
    ValueError
        If the collection lacks a 'dark' or 'light' dataset.

    Examples
    --------
    Basic usage::

        target = SampledMLDifferenceTarget(
            dataset_collection=collection,
            model_light=mixed_model,
            model_dark=model_dark,
            n_samples=32,
        )
        loss = target()
    """

    name: str = "sampled_ml_difference"

    def __init__(
        self,
        dataset_collection: "DatasetCollection",
        model_light: "ModelFT" = None,
        model_dark: "ModelFT" = None,
        scaler_light: "Scaler" = None,
        scaler_dark: "Scaler" = None,
        n_samples: int = 32,
        sigma_model_log: float = 0.15,
        use_work_set: bool = True,
        verbose: int = 0,
    ):
        """Initialize the target; see the class docstring for parameters."""
        super().__init__(verbose=verbose)

        # Validation
        if "dark" not in dataset_collection:
            raise ValueError("DatasetCollection must contain 'dark' dataset")
        if "light" not in dataset_collection:
            raise ValueError("DatasetCollection must contain 'light' dataset")

        self._dataset_collection = dataset_collection
        self._data_dark = dataset_collection["dark"]
        self._data_light = dataset_collection["light"]

        # Register models/scalers as submodules; None is allowed, which leaves
        # an empty slot that forward() checks with `is None`.
        self.add_module("_model_light", model_light)
        self.add_module("_model_dark", model_dark)
        self.add_module("_scaler_light", scaler_light)
        self.add_module("_scaler_dark", scaler_dark)

        # Tunable parameters as buffers
        self.register_buffer("_n_samples", torch.tensor(n_samples, dtype=torch.int64))
        self.register_buffer("_sigma_model_log", torch.tensor(sigma_model_log))

        self.use_work_set = use_work_set

        # Cache for diagnostics
        self._last_diagnostics: Dict[str, torch.Tensor] = {}

        # Setup data buffers
        self._setup_data()

    def _setup_data(self):
        """Register observed amplitudes, sigmas and the work/test mask as buffers.

        The mask keeps only reflections that are in the requested set
        (work or test) in BOTH datasets.
        """
        _, F_light, sigma_light, rfree_light = self._data_light()
        _, F_dark, sigma_dark, rfree_dark = self._data_dark()

        # Unwrap MaskedTensor-like containers via get_data().
        # NOTE(review): any validity mask they carry is not folded into
        # self._mask here — confirm masked-out entries are safe to use.
        if hasattr(F_light, "get_data"):
            F_light = F_light.get_data()
            sigma_light = sigma_light.get_data()
        if hasattr(F_dark, "get_data"):
            F_dark = F_dark.get_data()
            sigma_dark = sigma_dark.get_data()

        self.register_buffer("_F_obs_light", F_light)
        self.register_buffer("_F_obs_dark", F_dark)
        self.register_buffer("_sigma_light", sigma_light)
        self.register_buffer("_sigma_dark", sigma_dark)

        # Work/test set mask
        if self.use_work_set:
            mask = rfree_light.bool() & rfree_dark.bool()
        else:
            mask = ~rfree_light.bool() & ~rfree_dark.bool()
        self.register_buffer("_mask", mask)

    @property
    def hkl(self) -> torch.Tensor:
        """Common HKL indices."""
        return self._dataset_collection.hkl

    @property
    def n_samples(self) -> int:
        """Get number of MC samples."""
        return self._n_samples.item()

    @n_samples.setter
    def n_samples(self, value: int):
        """Set number of MC samples."""
        self._n_samples.fill_(value)

    @property
    def sigma_model_log(self) -> float:
        """Get model error in log(I) space."""
        return self._sigma_model_log.item()

    @sigma_model_log.setter
    def sigma_model_log(self, value: float):
        """Set model error in log(I) space."""
        self._sigma_model_log.fill_(value)

    def _compute_sigma_phi(
        self,
        F_obs: torch.Tensor,
        sigma_F_obs: torch.Tensor,
        F_calc_amp: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute phase uncertainty from amplitude uncertainties and discrepancy.

        Parameters
        ----------
        F_obs : torch.Tensor
            Observed amplitudes.
        sigma_F_obs : torch.Tensor
            Amplitude uncertainties.
        F_calc_amp : torch.Tensor
            Calculated amplitudes |F_calc|.

        Returns
        -------
        sigma_phi : torch.Tensor
            Phase uncertainty in radians, clamped to [0.01, 2.0].
        """
        sigma_F_model = self.sigma_model_log * F_obs
        sigma_F_total = torch.sqrt(sigma_F_obs**2 + sigma_F_model**2)

        sigma_phi_meas = sigma_F_obs / (F_obs + 1e-6)
        sigma_phi_model = sigma_F_model / (F_calc_amp + 1e-6)

        amplitude_discrepancy = torch.abs(F_obs - F_calc_amp)
        excess = torch.clamp(amplitude_discrepancy - sigma_F_total, min=0)
        sigma_phi_excess = excess / (torch.minimum(F_obs, F_calc_amp) + 1e-6)

        sigma_phi = torch.sqrt(
            sigma_phi_meas**2 + sigma_phi_model**2 + sigma_phi_excess**2
        )
        return torch.clamp(sigma_phi, min=0.01, max=2.0)

    @staticmethod
    def _antithetic_normal(
        n_refl: int,
        n_samples: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Draw antithetic standard-normal variates ``[eps, -eps]``.

        The number of pairs is ``ceil(n_samples / 2)`` with a minimum of one,
        so ``n_samples <= 1`` still yields samples (``n // 2`` pairs would
        produce none, and their mean would be NaN).
        """
        half = max((int(n_samples) + 1) // 2, 1)
        eps = torch.randn(n_refl, half, device=device, dtype=dtype)
        return torch.cat([eps, -eps], dim=1)

    def forward(
        self,
        fcalc_light: torch.Tensor = None,
        fcalc_dark: torch.Tensor = None,
        recalc: bool = True,
    ) -> torch.Tensor:
        """
        Compute phase-aware difference loss.

        Jointly refines against dark and light datasets using dark phases
        as reference. Phase uncertainty increases for reflections with
        large amplitude changes between states.

        Parameters
        ----------
        fcalc_light : torch.Tensor, optional
            Pre-computed light state structure factors.
        fcalc_dark : torch.Tensor, optional
            Pre-computed dark state structure factors.
        recalc : bool, optional
            Passed to the models when they are evaluated. Default is True.

        Returns
        -------
        torch.Tensor
            Combined loss for both datasets.

        Raises
        ------
        RuntimeError
            If a structure factor is not supplied and its model is unset.
        """
        hkl = self.hkl

        # Get F_calc for light
        if fcalc_light is None:
            if self._model_light is None:
                raise RuntimeError("No model_light set")
            fcalc_light = self._model_light(hkl, recalc=recalc)
            if self._scaler_light is not None:
                fcalc_light = self._scaler_light(fcalc_light)

        # Get F_calc for dark
        if fcalc_dark is None:
            if self._model_dark is None:
                raise RuntimeError("No model_dark set")
            fcalc_dark = self._model_dark(hkl, recalc=recalc)
            if self._scaler_dark is not None:
                fcalc_dark = self._scaler_dark(fcalc_dark)

        # Reference phases from dark state (detached to prevent spurious gradients)
        phi_dark = torch.angle(fcalc_dark).detach()

        # Apply work/test mask
        mask = self._mask
        F_dark_obs = self._F_obs_dark[mask]
        F_light_obs = self._F_obs_light[mask]
        sigma_dark = self._sigma_dark[mask]
        sigma_light = self._sigma_light[mask]
        phi_dark = phi_dark[mask]
        F_calc_dark = fcalc_dark[mask]
        F_calc_light = fcalc_light[mask]

        n_refl = F_dark_obs.shape[0]

        # Phase uncertainty for the dark dataset
        sigma_F_model_dark = self.sigma_model_log * F_dark_obs
        sigma_F_dark_total = torch.sqrt(sigma_dark**2 + sigma_F_model_dark**2)
        sigma_phi_dark = self._compute_sigma_phi(
            F_dark_obs, sigma_dark, torch.abs(F_calc_dark)
        )

        # Extra light-state uncertainty when the observed amplitude change
        # disagrees with the calculated one.
        delta_F_obs = torch.abs(F_light_obs - F_dark_obs)
        delta_F_calc = torch.abs(torch.abs(F_calc_light) - torch.abs(F_calc_dark))
        delta_discrepancy = torch.abs(delta_F_obs - delta_F_calc)
        sigma_phi_change = delta_discrepancy / (F_light_obs + 1e-6)

        sigma_phi_light_base = self._compute_sigma_phi(
            F_light_obs, sigma_light, torch.abs(F_calc_light)
        )
        sigma_phi_light = torch.sqrt(sigma_phi_light_base**2 + sigma_phi_change**2)
        sigma_phi_light = torch.clamp(sigma_phi_light, min=0.01, max=2.0)

        # Reparameterized antithetic sampling; the same eps is shared by both
        # states. Drawn on the device/dtype of the masked observations it is
        # combined with.
        eps = self._antithetic_normal(
            n_refl, self.n_samples, F_dark_obs.device, F_dark_obs.dtype
        )

        # Dark samples
        phi_samples_dark = phi_dark.unsqueeze(1) + sigma_phi_dark.unsqueeze(1) * eps
        F_dark_samples = F_dark_obs.unsqueeze(1) * torch.exp(1j * phi_samples_dark)

        # Light samples (using dark phases as reference)
        phi_samples_light = phi_dark.unsqueeze(1) + sigma_phi_light.unsqueeze(1) * eps
        F_light_samples = F_light_obs.unsqueeze(1) * torch.exp(1j * phi_samples_light)

        # Monte Carlo estimates of E|F_sample - F_calc|^2
        sq_error_dark = torch.abs(F_dark_samples - F_calc_dark.unsqueeze(1)) ** 2
        sq_error_light = torch.abs(F_light_samples - F_calc_light.unsqueeze(1)) ** 2
        expected_error_dark = sq_error_dark.mean(dim=1)
        expected_error_light = sq_error_light.mean(dim=1)

        # Inverse-variance weights
        sigma_F_model_light = self.sigma_model_log * F_light_obs
        sigma_light_total = torch.sqrt(sigma_light**2 + sigma_F_model_light**2)
        weights_dark = 1.0 / (sigma_F_dark_total**2 + 1e-6)
        weights_light = 1.0 / (sigma_light_total**2 + 1e-6)

        loss = (weights_dark * expected_error_dark).mean() + (
            weights_light * expected_error_light
        ).mean()

        # Store diagnostics
        self._last_diagnostics = {
            "sigma_phi_dark": sigma_phi_dark.detach(),
            "sigma_phi_light": sigma_phi_light.detach(),
            "m_dark": torch.exp(-sigma_phi_dark**2 / 2).detach(),
            "m_light": torch.exp(-sigma_phi_light**2 / 2).detach(),
            "delta_F_obs": delta_F_obs.detach(),
            "delta_discrepancy": delta_discrepancy.detach(),
        }

        return loss

    def stats(
        self,
        fcalc_light: torch.Tensor = None,
        fcalc_dark: torch.Tensor = None,
    ) -> Dict[str, StatEntry]:
        """
        Get statistics for difference refinement.

        Parameters
        ----------
        fcalc_light : torch.Tensor, optional
            Pre-computed light state structure factors.
        fcalc_dark : torch.Tensor, optional
            Pre-computed dark state structure factors.

        Returns
        -------
        dict
            Statistics dict with StatEntry values.
        """
        with torch.no_grad():
            loss = self.forward(fcalc_light=fcalc_light, fcalc_dark=fcalc_dark)
        diag = self._last_diagnostics

        return {
            "loss": stat(loss.item(), VERBOSITY_STANDARD),
            "n": stat(self._mask.sum().item(), VERBOSITY_DEBUG),
            "mean_m_dark": stat(diag["m_dark"].mean().item(), VERBOSITY_STANDARD),
            "mean_m_light": stat(diag["m_light"].mean().item(), VERBOSITY_STANDARD),
            "mean_sigma_phi_dark": stat(
                diag["sigma_phi_dark"].mean().item(), VERBOSITY_DETAILED
            ),
            "mean_sigma_phi_light": stat(
                diag["sigma_phi_light"].mean().item(), VERBOSITY_DETAILED
            ),
            "mean_delta_F_obs": stat(
                diag["delta_F_obs"].mean().item(), VERBOSITY_DETAILED
            ),
            "mean_delta_discrepancy": stat(
                diag["delta_discrepancy"].mean().item(), VERBOSITY_DETAILED
            ),
        }

    def __repr__(self) -> str:
        return (
            f"SampledMLDifferenceTarget(n_samples={self.n_samples}, "
            f"sigma_model={self.sigma_model_log:.3f})"
        )
# =============================================================================
# Factory functions
# =============================================================================
def create_sampled_ml_target(
    data: "ReflectionData" = None,
    model: "Model" = None,
    scaler: "Scaler" = None,
    phi_ref: torch.Tensor = None,
    n_samples: int = 32,
    sigma_model_log: float = 0.15,
    use_analytical: bool = False,
    use_work_set: bool = True,
    verbose: int = 0,
    use_antithetic: bool = True,
) -> SampledMLPhaseTarget:
    """
    Factory function to create SampledMLPhaseTarget.

    See SampledMLPhaseTarget for parameter documentation. ``use_antithetic``
    is the last parameter so existing positional calls keep working.

    Parameters
    ----------
    use_antithetic : bool, optional
        Use antithetic sampling for variance reduction. Default is True
        (the same default as SampledMLPhaseTarget).

    Returns
    -------
    SampledMLPhaseTarget
        Configured target instance.
    """
    return SampledMLPhaseTarget(
        data=data,
        model=model,
        scaler=scaler,
        phi_ref=phi_ref,
        n_samples=n_samples,
        sigma_model_log=sigma_model_log,
        use_analytical=use_analytical,
        use_antithetic=use_antithetic,
        use_work_set=use_work_set,
        verbose=verbose,
    )
def create_sampled_ml_difference_target(
    dataset_collection: "DatasetCollection",
    model_light: "ModelFT" = None,
    model_dark: "ModelFT" = None,
    scaler_light: "Scaler" = None,
    scaler_dark: "Scaler" = None,
    n_samples: int = 32,
    sigma_model_log: float = 0.15,
    use_work_set: bool = True,
    verbose: int = 0,
) -> SampledMLDifferenceTarget:
    """
    Build a SampledMLDifferenceTarget.

    Thin convenience wrapper: every argument is passed through unchanged.
    See SampledMLDifferenceTarget for what each parameter means.

    Returns
    -------
    SampledMLDifferenceTarget
        Newly constructed target instance.
    """
    target = SampledMLDifferenceTarget(
        dataset_collection,
        verbose=verbose,
        use_work_set=use_work_set,
        sigma_model_log=sigma_model_log,
        n_samples=n_samples,
        scaler_dark=scaler_dark,
        scaler_light=scaler_light,
        model_dark=model_dark,
        model_light=model_light,
    )
    return target
# Public API of this module: the two target classes and their factories.
__all__ = [
    "SampledMLPhaseTarget", "SampledMLDifferenceTarget",
    "create_sampled_ml_target", "create_sampled_ml_difference_target",
]