torchref.refinement.targets.base module
Target Functions for Crystallographic Refinement
This module provides target (loss) functions for crystallographic refinement. Each target is instantiated once with a reference to the refinement object, then evaluated on each iteration by calling the target.
Target Types:
- X-ray targets: Least Squares, Maximum Likelihood, Gaussian NLL
- Geometry restraint targets: Bonds, Angles, Torsions
- ADP restraint targets: Similarity (SIMU), Rigid Bond (DELU)
LossState Integration:
- Targets can optionally receive a LossState and add their loss to it.
- class torchref.refinement.targets.base.Target(verbose=0, **kwargs)[source]
Bases: DeviceMixin, Module
Abstract base class for all target functions.
All tunable parameters should be registered as buffers using register_buffer() so they can be accessed/modified via state_dict notation.
Supports empty initialization for state_dict loading:
target = Target()  # Creates empty shell
target.load_state_dict(torch.load('target.pt'))
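A minimal sketch of the buffer convention, assuming a hypothetical `SigmaTarget` built on plain `torch.nn.Module` (the library's `Target` and its `DeviceMixin` are not reproduced). The empty constructor registers a placeholder buffer so that `load_state_dict()` finds the key it needs to fill; how the real `Target` handles empty construction is not specified here. An in-memory `io.BytesIO` stands in for `'target.pt'`.

```python
import io

import torch
from torch import nn


class SigmaTarget(nn.Module):
    """Hypothetical target with one tunable parameter stored as a buffer."""

    def __init__(self, sigma=None):
        super().__init__()
        # Register the buffer even for an empty shell (placeholder value),
        # so load_state_dict() has a "sigma" entry to overwrite.
        value = 1.0 if sigma is None else float(sigma)
        self.register_buffer("sigma", torch.tensor(value))

    def forward(self, deviations):
        # Weighted least-squares loss using the buffered sigma.
        return ((deviations / self.sigma) ** 2).sum()


# Save a configured target, then restore it into an empty shell.
configured = SigmaTarget(sigma=0.5)
stream = io.BytesIO()
torch.save(configured.state_dict(), stream)
stream.seek(0)

shell = SigmaTarget()  # empty shell with placeholder sigma
shell.load_state_dict(torch.load(stream))
```

Because `sigma` is a buffer rather than a Python float, it travels with `state_dict()` and follows `.to(device)` without extra code.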
- LossState Integration:
Targets can work with LossState for the new pipeline:
state = target.add_to_state(state) # Adds loss to state
- Parameters:
verbose (int, optional) – Verbosity level. Default is 0.
- __init__(verbose=0, **kwargs)[source]
Initialize target.
- Parameters:
verbose (int, optional) – Verbosity level. Default is 0.
- add_to_state(state)[source]
Compute loss and add it to the LossState.
This method enables the new LossState pipeline pattern where targets receive a state object, compute their loss, add it to the state, and return the state for chaining.
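The chaining pattern can be sketched in plain Python. `MiniLossState`, its `add`/`total` methods, and `ConstantTarget` are hypothetical stand-ins for illustration; the real `LossState` interface is not documented in this section.

```python
class MiniLossState:
    """Hypothetical stand-in for LossState: collects named loss terms."""

    def __init__(self):
        self.losses = {}

    def add(self, name, value):
        self.losses[name] = value
        return self

    def total(self):
        return sum(self.losses.values())


class ConstantTarget:
    """Hypothetical target returning a fixed loss."""

    def __init__(self, name, value):
        self.name = name
        self.value = value

    def __call__(self):
        # A real target would evaluate its loss on the current model here.
        return self.value

    def add_to_state(self, state):
        # Compute the loss, record it, and return the state for chaining.
        state.add(self.name, self())
        return state


# Targets are built once, then evaluated on every iteration.
targets = [ConstantTarget("xray", 1.5), ConstantTarget("bonds", 0.25)]
state = MiniLossState()
for target in targets:
    state = target.add_to_state(state)
```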
- maintenance()[source]
Between-step housekeeping hook (no-op by default).
LossState calls this on every registered target after each successful outer optimizer step. Targets override it to rebuild stale internal state (VDW pair lists, solvent masks, etc.) based on how far the parameters have drifted since the last refresh.
Contract:
Must be idempotent: calling it several times in a row on an unchanged model should not mutate the target.
Fast path first: do a cheap staleness check up front and the expensive rebuild only when strictly necessary. LossState calls this every outer step, so the happy-path cost is paid every time.
Must not raise on routine drift. If a rebuild itself fails, let the exception propagate; that is a real bug.
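A sketch of a `maintenance()` override that follows this contract, assuming a hypothetical `PairListTarget` that caches a neighbour pair list. The staleness criterion is swapped in from the Verlet-list technique used in molecular dynamics (rebuild once any atom has moved more than half the skin distance); the section does not say which criterion the library uses.

```python
import math


class PairListTarget:
    """Hypothetical target caching a neighbour pair list (e.g. for VDW terms)."""

    def __init__(self, coords, cutoff=3.0, skin=1.0):
        self.coords = coords  # list of [x, y, z], updated in place by the optimizer
        self.cutoff = cutoff
        self.skin = skin
        self.rebuilds = 0
        self._rebuild()

    def _rebuild(self):
        # Expensive path: O(N^2) search within cutoff + skin.
        reach = self.cutoff + self.skin
        n = len(self.coords)
        self.pairs = [
            (i, j)
            for i in range(n)
            for j in range(i + 1, n)
            if math.dist(self.coords[i], self.coords[j]) < reach
        ]
        self._reference = [tuple(c) for c in self.coords]
        self.rebuilds += 1

    def maintenance(self):
        # Fast path: cheap drift check against coordinates at the last rebuild.
        drift = max(math.dist(c, r) for c, r in zip(self.coords, self._reference))
        # Rebuild only when some atom moved more than half the skin.
        if drift > 0.5 * self.skin:
            self._rebuild()


coords = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [10.0, 0.0, 0.0]]
target = PairListTarget(coords)
target.maintenance()
target.maintenance()  # unchanged model: no rebuild, so the call is idempotent
coords[2][0] = 4.5    # large drift triggers exactly one rebuild
target.maintenance()
```

A rebuild resets the reference coordinates, so a second call on the same model falls straight through the fast path.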
- class torchref.refinement.targets.base.ModelTarget(model=None, verbose=0, **kwargs)[source]
Bases: Target
Base class for targets that only need a Model reference.
This class provides a simpler interface for geometry and ADP targets that don’t need access to reflection data or refinement machinery. Targets inherit from this class when they only need the atomic model.
The model is registered as a proper submodule, allowing PyTorch to handle device movement and state_dict operations automatically.
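Submodule registration in PyTorch happens simply by assigning an `nn.Module` to an attribute. The sketch below uses hypothetical `ToyModel` and `ToyModelTarget` classes (not the library's `Model` or `ModelTarget`) to show that `state_dict()`, `parameters()` and `.to()` then recurse into the model automatically.

```python
import torch
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical stand-in for an atomic Model: holds coordinates."""

    def __init__(self, n_atoms):
        super().__init__()
        self.xyz = nn.Parameter(torch.zeros(n_atoms, 3))


class ToyModelTarget(nn.Module):
    """Hypothetical ModelTarget-like wrapper around a model submodule."""

    def __init__(self, model=None):
        super().__init__()
        # Assigning an nn.Module attribute registers it as a submodule.
        self.model = model

    def forward(self):
        # Toy restraint: pull every coordinate towards the origin.
        return (self.model.xyz ** 2).sum()


target = ToyModelTarget(ToyModel(n_atoms=4))
target = target.to(torch.float64)  # recurses into the model's parameters
```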
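- Parameters:
model (Model, optional) – Reference to the atomic Model object.
verbose (int, optional) – Verbosity level. Default is 0.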
- property restraints
Access model’s restraints (built lazily on first access).
- class torchref.refinement.targets.base.DataTarget(data=None, model=None, scaler=None, verbose=0, **kwargs)[source]
Bases: Target
Base class for targets that need ReflectionData and optionally Model/Scaler.
This class provides a flexible interface for X-ray targets that can work in two modes:
With Model: Computes F_calc from the model on each forward pass
Without Model: Uses pre-computed F_calc passed directly
This decoupling allows targets to be used for:
- Standard refinement (with model)
- Analysis/scoring of pre-computed structure factors (without model)
- Testing and validation workflows
All objects (model, data, scaler) are registered as proper submodules, allowing PyTorch to handle device movement and state_dict operations.
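The two modes can be sketched without PyTorch. `ToyDataTarget` is hypothetical: its `model` is any callable returning complex F_calc, and the loss is a simple amplitude least-squares residual chosen for illustration, not the library's X-ray target.

```python
class ToyDataTarget:
    """Hypothetical sketch of the two DataTarget modes."""

    def __init__(self, f_obs, model=None):
        self.f_obs = f_obs  # observed amplitudes
        self.model = model  # callable returning complex F_calc, or None

    def get_fcalc(self):
        if self.model is None:
            raise RuntimeError("No model set; pass fcalc explicitly.")
        return self.model()

    def forward(self, fcalc=None):
        # Without model: use the pre-computed F_calc passed in.
        # With model: compute F_calc on every call.
        if fcalc is None:
            fcalc = self.get_fcalc()
        # Least-squares residual on amplitudes |F_calc|.
        return sum((fo - abs(fc)) ** 2 for fo, fc in zip(self.f_obs, fcalc))


f_obs = [3.0, 4.0]
with_model = ToyDataTarget(f_obs, model=lambda: [3 + 0j, 3j])
without_model = ToyDataTarget(f_obs)
score = without_model.forward(fcalc=[3 + 4j, 4 + 0j])
```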
- Parameters:
data (ReflectionData, optional) – Reference to the ReflectionData object. Required for forward().
model (Model or ModelFT, optional) – Reference to a Model object for F_calc computation. If None, F_calc must be provided to forward().
scaler (Scaler, optional) – Reference to the Scaler object for scaling F_calc.
verbose (int, optional) – Verbosity level. Default is 0.
target_value (float, optional) – Target value for this loss. Default is 0.0.
sigma (float, optional) – Sigma parameter for weighting. Default is 0.5.
- _data
Reference to the reflection data object (registered as submodule).
- Type:
ReflectionData
- __init__(data=None, model=None, scaler=None, verbose=0, **kwargs)[source]
Initialize data target.
- Parameters:
data (ReflectionData, optional) – Reference to the ReflectionData object. Required for forward().
model (Model or ModelFT, optional) – Reference to Model object for F_calc computation. If None, F_calc must be provided when calling forward().
scaler (Scaler, optional) – Reference to the Scaler object.
verbose (int, optional) – Verbosity level. Default is 0.
- property data: ReflectionData
Access the reflection data object.
- get_fcalc(hkl=None, recalc=False)[source]
Compute structure factors from the model.
- Parameters:
hkl (torch.Tensor, optional) – Miller indices. If None, uses data’s hkl.
recalc (bool, optional) – Force recalculation. Default is False.
- Returns:
Complex structure factors.
- Return type:
torch.Tensor
- Raises:
RuntimeError – If no model is set.
- get_fcalc_scaled(hkl=None, recalc=False, fcalc=None)[source]
Compute structure factors from the model (or take pre-computed ones) and apply the scaler.
- Parameters:
hkl (torch.Tensor, optional) – Miller indices. If None, uses data’s hkl.
recalc (bool, optional) – Force recalculation. Default is False.
fcalc (torch.Tensor, optional) – Pre-computed structure factors. If provided, skips model computation.
- Returns:
Scaled complex structure factors.
- Return type:
torch.Tensor
- get_F_calc_scaled(hkl=None, recalc=False, fcalc=None)[source]
Compute scaled structure factor amplitudes.
- Parameters:
hkl (torch.Tensor, optional) – Miller indices. If None, uses data’s hkl.
recalc (bool, optional) – Force recalculation. Default is False.
fcalc (torch.Tensor, optional) – Pre-computed structure factors. If provided, skips model computation.
- Returns:
Scaled structure factor amplitudes |F_calc|.
- Return type:
torch.Tensor
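The two methods differ only in whether the result is the scaled complex F_calc or its modulus. A sketch assuming a single overall scale `k` and an isotropic `b_iso` term, a common scaling model but not necessarily what the library's `Scaler` does; `scale_fcalc` and `scaled_amplitudes` are hypothetical names.

```python
import math


def scale_fcalc(fcalc, s_squared, k=1.0, b_iso=0.0):
    """Hypothetical overall scaling: k * exp(-B * s^2 / 4) * F_calc, s^2 = 1/d^2."""
    return [k * math.exp(-b_iso * s2 / 4.0) * f for f, s2 in zip(fcalc, s_squared)]


def scaled_amplitudes(fcalc, s_squared, k=1.0, b_iso=0.0):
    """|F_calc| after scaling (the get_F_calc_scaled analogue)."""
    return [abs(f) for f in scale_fcalc(fcalc, s_squared, k, b_iso)]


fcalc = [3.0 + 4.0j, 1.0 + 0.0j]
s_squared = [0.0, 1.0]  # 1/d^2 in 1/A^2
complex_scaled = scale_fcalc(fcalc, s_squared, k=2.0, b_iso=4.0)
amplitudes = scaled_amplitudes(fcalc, s_squared, k=2.0, b_iso=4.0)
```

A positive real scale factor leaves the phases unchanged, so only the amplitudes differ between the raw and scaled values.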
- get_rfactor()[source]
Compute R-factors using the scaler.
- Returns:
(R_work, R_free) values.
- Return type:
tuple
- Raises:
RuntimeError – If no scaler is set.
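For reference, the conventional crystallographic R-factor is R = Σ|F_obs − F_calc| / Σ F_obs, evaluated separately on the working and free (cross-validation) reflections. A plain-Python sketch, assuming the amplitudes are already scaled; `r_factor` and `r_work_free` are hypothetical helpers, not this module's implementation.

```python
def r_factor(f_obs, f_calc):
    """Conventional R = sum(|F_obs - F_calc|) / sum(F_obs)."""
    return sum(abs(fo - fc) for fo, fc in zip(f_obs, f_calc)) / sum(f_obs)


def r_work_free(f_obs, f_calc, is_free):
    """Split reflections by the free-set flag and return (R_work, R_free)."""
    work = [(fo, fc) for fo, fc, flag in zip(f_obs, f_calc, is_free) if not flag]
    free = [(fo, fc) for fo, fc, flag in zip(f_obs, f_calc, is_free) if flag]
    return r_factor(*zip(*work)), r_factor(*zip(*free))


f_obs = [10.0, 20.0, 30.0, 40.0]
f_calc = [11.0, 18.0, 30.0, 36.0]  # assumed already scaled to F_obs
is_free = [False, False, False, True]
r_work, r_free = r_work_free(f_obs, f_calc, is_free)
```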
- torchref.refinement.targets.base.gaussian_nll(deviations, sigmas)[source]
Compute Gaussian negative log-likelihood.
NLL = 0.5 * ((x - μ) / σ)² + log(σ) + 0.5 * log(2π)
- Parameters:
deviations (torch.Tensor) – Deviations from target values (x - μ).
sigmas (torch.Tensor) – Standard deviations.
- Returns:
Tensor of NLL values (same shape as input).
- Return type:
torch.Tensor
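A scalar sketch of the formula above using the standard library instead of tensors (the module's version works elementwise on `torch.Tensor` inputs). It cross-checks the result against `statistics.NormalDist`.

```python
import math
from statistics import NormalDist


def gaussian_nll(deviation, sigma):
    """Scalar NLL = 0.5*((x - mu)/sigma)**2 + log(sigma) + 0.5*log(2*pi)."""
    z = deviation / sigma
    return 0.5 * z * z + math.log(sigma) + 0.5 * math.log(2.0 * math.pi)


# Bond 0.03 A longer than ideal, sigma = 0.02 A.
nll = gaussian_nll(0.03, 0.02)
# Same quantity from the standard library's normal density.
reference = -math.log(NormalDist(mu=0.0, sigma=0.02).pdf(0.03))
```

With σ fixed, the log(σ) and log(2π) terms are constants, so the NLL differs from a weighted least-squares term only by an offset; the log(σ) term matters once σ itself is refined.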
- torchref.refinement.targets.base.von_mises_nll(deviations_rad, sigmas_deg)[source]
Compute von Mises negative log-likelihood for angular data.
NLL = -κ*cos(θ) + log(I₀(κ)) + log(2π) where κ = 1/σ²
- Parameters:
deviations_rad (torch.Tensor) – Angular deviations in radians.
sigmas_deg (torch.Tensor) – Standard deviations in degrees.
- Returns:
Tensor of NLL values (same shape as input).
- Return type:
torch.Tensor
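A scalar, stdlib-only sketch of the formula. Two assumptions: `sigmas_deg` is converted to radians before forming κ = 1/σ², and log I₀(κ) is obtained by numerical quadrature with e^κ factored out so small σ (large κ) does not overflow. A tensor implementation would more likely use `torch.special.i0e`, the exponentially scaled Bessel function.

```python
import math


def log_i0(kappa, n=2000):
    """log I0(kappa) via I0(k) = (1/pi) * integral_0^pi exp(k*cos t) dt.

    exp(kappa) is factored out of the integrand to avoid overflow; the
    trapezoid rule converges fast for this smooth periodic integrand.
    """
    h = math.pi / n
    total = 0.0
    for i in range(n + 1):
        weight = 0.5 if i in (0, n) else 1.0
        total += weight * math.exp(kappa * (math.cos(i * h) - 1.0))
    return kappa + math.log(total * h / math.pi)


def von_mises_nll(deviation_rad, sigma_deg):
    # Assumption: sigma is converted to radians before kappa = 1 / sigma^2.
    kappa = 1.0 / math.radians(sigma_deg) ** 2
    # NLL = -kappa*cos(theta) + log(I0(kappa)) + log(2*pi)
    return -kappa * math.cos(deviation_rad) + log_i0(kappa) + math.log(2.0 * math.pi)
```

Unlike the Gaussian NLL, this is periodic in the deviation, so a torsion off by 2π incurs no penalty; for small σ it approaches the Gaussian NLL with σ in radians.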
- torchref.refinement.targets.base.adp_similarity_nll(adp_diffs, sigma=2.0)[source]
Compute ADP similarity NLL (SIMU restraint).
- Parameters:
adp_diffs (torch.Tensor) – ADP differences between bonded atoms.
sigma (float, optional) – Target standard deviation. Default is 2.0 Ų.
- Returns:
Tensor of NLL values (same shape as input).
- Return type:
torch.Tensor
- torchref.refinement.targets.base.detach_phases(fcalc)[source]
Extract phases from complex structure factors with gradient detachment.
- Parameters:
fcalc (torch.Tensor) – Complex structure factors.
- Returns:
Detached phase angles in radians.
- Return type:
torch.Tensor
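A sketch of what such a function plausibly does, assuming `torch.angle` followed by `.detach()` (the library's body is not shown here). Recombining the live amplitude with the frozen phase means gradients flow through |F_calc| only.

```python
import math

import torch


def detach_phases(fcalc):
    """Phase angles of complex F_calc, cut out of the autograd graph."""
    return torch.angle(fcalc).detach()


fcalc = torch.tensor([3.0 + 4.0j, -1.0 + 0.0j], requires_grad=True)
phases = detach_phases(fcalc)
# Live amplitude times frozen phase: the phase contributes no gradient.
f_model = fcalc.abs() * torch.exp(1j * phases)
```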