torchref.refinement.targets package

Refinement target functions for crystallographic structure refinement.

This module provides target (loss) functions for X-ray data, geometry restraints, and ADP restraints.

class torchref.refinement.targets.Target(verbose=0, **kwargs)[source]

Bases: DeviceMixin, Module

Abstract base class for all target functions.

All tunable parameters should be registered as buffers using register_buffer() so they can be accessed/modified via state_dict notation.
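A minimal sketch of this convention follows. The hypothetical ToyTarget subclasses plain torch.nn.Module (not the library's Target) and registers its tunable sigma with register_buffer(). The buffer then shows up in state_dict(), can be overwritten through load_state_dict(), and is excluded from parameters(), so an optimizer will not touch it:

```python
import torch
from torch import nn


class ToyTarget(nn.Module):
    """Hypothetical target: Gaussian penalty with a tunable sigma buffer."""

    def __init__(self, sigma: float = 0.5):
        super().__init__()
        # Buffers are saved in state_dict() and follow .to(device), but are
        # not returned by parameters(), so optimizers leave them alone.
        self.register_buffer("sigma", torch.tensor(sigma))

    def forward(self, deviations: torch.Tensor) -> torch.Tensor:
        return 0.5 * ((deviations / self.sigma) ** 2).mean()


target = ToyTarget(sigma=0.5)
state = target.state_dict()         # contains the "sigma" buffer
state["sigma"] = torch.tensor(1.0)  # modify via state_dict notation
target.load_state_dict(state)
```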

Supports empty initialization for state_dict loading:

import torch
from torchref.refinement.targets import Target

target = Target()  # Creates an empty shell
target.load_state_dict(torch.load('target.pt'))

LossState Integration:

Targets integrate with the LossState pipeline:

state = target.add_to_state(state)  # Adds this target's loss to the state

Parameters:

verbose (int, optional) – Verbosity level. Default is 0.

name

Unique name for this target (used as loss key in LossState).

Type:

str

verbose

Verbosity level.

Type:

int

name: str = 'base_target'
__init__(verbose=0, **kwargs)[source]

Initialize target.

Parameters:

verbose (int, optional) – Verbosity level. Default is 0.

forward()[source]

Compute and return the loss. Override in subclasses.

add_to_state(state)[source]

Compute loss and add it to the LossState.

This method implements the LossState pipeline pattern: the target receives a state object, computes its loss, adds the loss to the state, and returns the state so that calls can be chained.

Parameters:

state (LossState) – Current loss state with computed data.

Returns:

State with this target’s loss added.

Return type:

LossState
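The chaining pattern can be sketched with stand-ins. ToyLossState and ConstantTarget below are hypothetical (the real LossState interface is not documented in this section); the sketch only shows a target computing its loss, storing it under its name, and handing the state back:

```python
from dataclasses import dataclass, field

import torch
from torch import nn


@dataclass
class ToyLossState:
    """Hypothetical stand-in for LossState: a dict of named losses."""
    losses: dict = field(default_factory=dict)

    def add_loss(self, name: str, value: torch.Tensor) -> None:
        self.losses[name] = value

    def total(self) -> torch.Tensor:
        return sum(self.losses.values())


class ConstantTarget(nn.Module):
    """Hypothetical target that returns a fixed loss, keyed by its name."""

    def __init__(self, name: str, value: float):
        super().__init__()
        self.name = name
        self.register_buffer("value", torch.tensor(value))

    def forward(self) -> torch.Tensor:
        return self.value

    def add_to_state(self, state: ToyLossState) -> ToyLossState:
        # Compute this target's loss, record it under self.name,
        # and return the state so calls can be chained.
        state.add_loss(self.name, self())
        return state


state = ToyLossState()
for t in (ConstantTarget("xray", 2.0), ConstantTarget("geometry/bond", 0.5)):
    state = t.add_to_state(state)
```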

maintenance()[source]

Between-step housekeeping hook (no-op by default).

LossState calls this on every registered target after each successful outer optimizer step returns. Targets override this to rebuild stale internal state (VDW pair lists, solvent masks, etc.) based on how far parameters have drifted since the last refresh.

Contract

  • Must be idempotent: once the target is up to date, calling it again on an unchanged model must not mutate the target.

  • Fast path first: do a cheap staleness check up front and the expensive rebuild only when strictly necessary. LossState calls this after every outer step, so the happy-path cost is paid every time.

  • Must not raise on routine drift. If a rebuild itself fails, let the exception propagate: that indicates a real bug.
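A minimal sketch of this contract, using a hypothetical ToyPairListTarget that caches a neighbor-pair list with a skin margin (a common Verlet-list strategy, assumed here rather than taken from the library). The fast path measures the largest atomic drift since the last rebuild; the rebuild runs only when some atom has moved more than half the skin, and a repeated call on an unchanged model does nothing:

```python
import torch
from torch import nn


class ToyPairListTarget(nn.Module):
    """Hypothetical target caching a neighbor list built with a skin margin."""

    def __init__(self, coords: torch.Tensor, cutoff: float = 4.0, skin: float = 1.0):
        super().__init__()
        self.coords = coords  # live coordinates, shared with the "model"
        self.cutoff, self.skin = cutoff, skin
        self.n_rebuilds = 0
        self._rebuild()

    def _rebuild(self) -> None:
        with torch.no_grad():
            ref = self.coords.detach().clone()
            d = torch.cdist(ref, ref)
            i, j = torch.triu_indices(len(ref), len(ref), offset=1)
            keep = d[i, j] < self.cutoff + self.skin
            self.pairs = torch.stack([i[keep], j[keep]])
            self._ref_coords = ref
        self.n_rebuilds += 1

    def maintenance(self) -> None:
        # Fast path: a cheap staleness check paid on every outer step.
        with torch.no_grad():
            drift = (self.coords - self._ref_coords).norm(dim=-1).max()
        # Expensive path: rebuild only if an atom may have crossed the skin.
        if drift > 0.5 * self.skin:
            self._rebuild()
```

After a rebuild the reference coordinates equal the live ones, so the next call takes the fast path; that is what makes the hook idempotent.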

class torchref.refinement.targets.ModelTarget(model=None, verbose=0, **kwargs)[source]

Bases: Target

Base class for targets that only need a Model reference.

This class provides a simpler interface for geometry and ADP targets that don’t need access to reflection data or refinement machinery. Targets inherit from this class when they only need the atomic model.

The model is registered as a proper submodule, allowing PyTorch to handle device movement and state_dict operations automatically.

Parameters:
  • model (Model, optional) – Reference to the Model object.

  • verbose (int, optional) – Verbosity level. Default is 0.

  • target_value (float, optional) – Target value for this loss. Default is 0.0.

  • sigma (float, optional) – Sigma parameter for weighting. Default is 0.5.

name

Unique name for this target (used as loss key in LossState).

Type:

str

_model

Reference to the model object (registered as submodule).

Type:

Model

verbose

Verbosity level.

Type:

int

name: str = 'model_target'
__init__(model=None, verbose=0, **kwargs)[source]

Initialize model target.

Parameters:
  • model (Model, optional) – Reference to the Model object (optional for empty init).

  • verbose (int, optional) – Verbosity level. Default is 0.

property model: Model

Access the model object.

property restraints

Access model’s restraints (built lazily on first access).
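To illustrate what submodule registration provides, here is a small sketch with hypothetical ToyModel and ToyModelTarget classes (plain PyTorch, not the library's Model or ModelTarget). Assigning an nn.Module to an attribute registers it as a submodule, so its tensors appear in the target's state_dict() under a prefixed key and .to() converts them along with the target:

```python
import torch
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical stand-in for Model: coordinates as a learnable parameter."""

    def __init__(self):
        super().__init__()
        self.xyz = nn.Parameter(torch.zeros(4, 3))


class ToyModelTarget(nn.Module):
    """Hypothetical ModelTarget-like class holding the model as a submodule."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self._model = model  # nn.Module value => registered as a submodule

    def forward(self) -> torch.Tensor:
        return (self._model.xyz ** 2).sum()


target = ToyModelTarget(ToyModel())
print(list(target.state_dict()))  # ['_model.xyz']
target.to(torch.float64)          # device moves propagate the same way as dtype casts
```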

class torchref.refinement.targets.DataTarget(data=None, model=None, scaler=None, verbose=0, **kwargs)[source]

Bases: Target

Base class for targets that need ReflectionData and optionally Model/Scaler.

This class provides a flexible interface for X-ray targets that can work in two modes:

  1. With Model: Computes F_calc from the model on each forward pass

  2. Without Model: Uses pre-computed F_calc passed directly

This decoupling allows targets to be used for:

  • Standard refinement (with a model)

  • Analysis/scoring of pre-computed structure factors (without a model)

  • Testing and validation workflows

All objects (model, data, scaler) are registered as proper submodules, allowing PyTorch to handle device movement and state_dict operations.

Parameters:
  • data (ReflectionData, optional) – Reference to the ReflectionData object. Required for forward().

  • model (Model or ModelFT, optional) – Reference to a Model object for F_calc computation. If None, F_calc must be provided to forward().

  • scaler (Scaler, optional) – Reference to the Scaler object for scaling F_calc.

  • verbose (int, optional) – Verbosity level. Default is 0.

  • target_value (float, optional) – Target value for this loss. Default is 0.0.

  • sigma (float, optional) – Sigma parameter for weighting. Default is 0.5.

name

Unique name for this target (used as loss key in LossState).

Type:

str

_model

Reference to the model object (registered as submodule).

Type:

Model

_data

Reference to the reflection data object (registered as submodule).

Type:

ReflectionData

_scaler

Reference to the scaler object (registered as submodule).

Type:

Scaler

verbose

Verbosity level.

Type:

int

name: str = 'data_target'
__init__(data=None, model=None, scaler=None, verbose=0, **kwargs)[source]

Initialize data target.

Parameters:
  • data (ReflectionData, optional) – Reference to the ReflectionData object. Required for forward().

  • model (Model or ModelFT, optional) – Reference to Model object for F_calc computation. If None, F_calc must be provided when calling forward().

  • scaler (Scaler, optional) – Reference to the Scaler object.

  • verbose (int, optional) – Verbosity level. Default is 0.

property model: Model

Access the model object.

property data: ReflectionData

Access the reflection data object.

property scaler: Scaler

Access the scaler object.

property has_model: bool

Check if a model is available for F_calc computation.

get_fcalc(hkl=None, recalc=False)[source]

Compute structure factors from model.

Parameters:
  • hkl (torch.Tensor, optional) – Miller indices. If None, uses data’s hkl.

  • recalc (bool, optional) – Force recalculation. Default is False.

Returns:

Complex structure factors.

Return type:

torch.Tensor

Raises:

RuntimeError – If no model is set.
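The two DataTarget modes reduce to a simple precedence rule, sketched below with a hypothetical resolve_fcalc helper (not part of the documented API): pre-computed F_calc wins when supplied, otherwise F_calc is computed from the model, and with neither available a RuntimeError is raised, as documented for get_fcalc():

```python
from typing import Callable, Optional

import torch


def resolve_fcalc(
    fcalc: Optional[torch.Tensor],
    compute_from_model: Optional[Callable[[], torch.Tensor]],
) -> torch.Tensor:
    """Hypothetical helper mirroring the two DataTarget modes."""
    if fcalc is not None:
        # Without-model mode: use the caller's pre-computed F_calc.
        return fcalc
    if compute_from_model is None:
        raise RuntimeError("No model set: pass pre-computed fcalc instead.")
    # With-model mode: recompute F_calc from the model.
    return compute_from_model()


precomputed = torch.polar(torch.tensor([10.0, 5.0]), torch.tensor([0.0, 1.0]))
f_analysis = resolve_fcalc(precomputed, None)            # scoring mode
f_refine = resolve_fcalc(None, lambda: precomputed * 2)  # refinement mode
```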

get_fcalc_scaled(hkl=None, recalc=False, fcalc=None)[source]

Return scaled complex structure factors, computing them from the model unless pre-computed values are passed via fcalc.

Parameters:
  • hkl (torch.Tensor, optional) – Miller indices. If None, uses data’s hkl.

  • recalc (bool, optional) – Force recalculation. Default is False.

  • fcalc (torch.Tensor, optional) – Pre-computed structure factors. If provided, skips model computation.

Returns:

Scaled complex structure factors.

Return type:

torch.Tensor

get_F_calc_scaled(hkl=None, recalc=False, fcalc=None)[source]

Compute scaled structure factor amplitudes.

Parameters:
  • hkl (torch.Tensor, optional) – Miller indices. If None, uses data’s hkl.

  • recalc (bool, optional) – Force recalculation. Default is False.

  • fcalc (torch.Tensor, optional) – Pre-computed structure factors. If provided, skips model computation.

Returns:

Scaled structure factor amplitudes |F_calc|.

Return type:

torch.Tensor

get_rfactor()[source]

Compute R-factors using scaler.

Returns:

(R_work, R_free) values.

Return type:

tuple

Raises:

RuntimeError – If no scaler is set.
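For reference, the conventional crystallographic R-factor is R = Σ | F_obs - |F_calc_scaled| | / Σ F_obs, evaluated separately over the work and test (free) reflections. The sketch below computes both from already-scaled structure factors; the function name and the boolean free_flag input are assumptions, and the Scaler's own implementation may differ in details:

```python
import torch


def r_factors(f_obs: torch.Tensor, f_calc_scaled: torch.Tensor,
              free_flag: torch.Tensor):
    """Conventional R_work and R_free (a sketch, not the Scaler code)."""
    resid = (f_obs - f_calc_scaled.abs()).abs()

    def r(sel: torch.Tensor) -> torch.Tensor:
        return resid[sel].sum() / f_obs[sel].sum()

    return r(~free_flag), r(free_flag)


f_obs = torch.tensor([100.0, 50.0, 80.0, 20.0])
f_calc = torch.polar(torch.tensor([90.0, 55.0, 80.0, 25.0]),
                     torch.tensor([0.1, 2.0, -1.0, 3.0]))
free = torch.tensor([False, False, True, True])
r_work, r_free = r_factors(f_obs, f_calc, free)
```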

torchref.refinement.targets.gaussian_nll(deviations, sigmas)[source]

Compute Gaussian negative log-likelihood.

NLL = 0.5 * ((x - μ) / σ)² + log(σ) + 0.5 * log(2π)

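Parameters:
  • deviations (torch.Tensor) – Deviations from the target values (x - μ).

  • sigmas (torch.Tensor) – Standard deviations σ, in the same units as deviations.

Returns: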

Tensor of NLL values (same shape as input).

Return type:

torch.Tensor
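A direct transcription of the formula, named gaussian_nll_sketch so it is not mistaken for the library function. It agrees with torch.distributions.Normal(0, σ).log_prob up to sign:

```python
import math

import torch


def gaussian_nll_sketch(deviations: torch.Tensor, sigmas: torch.Tensor) -> torch.Tensor:
    """Elementwise 0.5 * (dev / σ)² + log σ + 0.5 * log(2π)."""
    sigmas = torch.as_tensor(sigmas, dtype=deviations.dtype)
    z = deviations / sigmas
    return 0.5 * z**2 + torch.log(sigmas) + 0.5 * math.log(2 * math.pi)


nll = gaussian_nll_sketch(torch.tensor([0.0, 0.02]), torch.tensor([0.02, 0.02]))
```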

torchref.refinement.targets.von_mises_nll(deviations_rad, sigmas_deg)[source]

Compute von Mises negative log-likelihood for angular data.

NLL = -κ*cos(θ) + log(I₀(κ)) + log(2π) where κ = 1/σ²

Parameters:
  • deviations_rad (torch.Tensor) – Angular deviations in radians.

  • sigmas_deg (torch.Tensor) – Standard deviations in degrees.

Returns:

Tensor of NLL values (same shape as input).

Return type:

torch.Tensor
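A sketch of the formula as a hypothetical von_mises_nll_sketch. It assumes σ is converted from degrees to radians before forming κ = 1/σ² (only then does the distribution match a Gaussian of width σ for small deviations), and it evaluates log I₀(κ) through the exponentially scaled Bessel function torch.special.i0e, because I₀ overflows for the large κ that tight angular restraints produce. The result is 2π-periodic in the deviation and approaches the Gaussian NLL for small σ:

```python
import math

import torch


def von_mises_nll_sketch(deviations_rad: torch.Tensor,
                         sigmas_deg: torch.Tensor) -> torch.Tensor:
    """-κ cos θ + log I₀(κ) + log 2π, with κ = 1/σ² and σ in radians (assumed)."""
    sigma_rad = torch.deg2rad(sigmas_deg)
    kappa = 1.0 / sigma_rad**2
    # Stable log I₀(κ) = log(i0e(κ)) + κ, since i0e(κ) = exp(-κ) I₀(κ).
    log_i0 = torch.log(torch.special.i0e(kappa)) + kappa
    return -kappa * torch.cos(deviations_rad) + log_i0 + math.log(2 * math.pi)


dev = torch.deg2rad(torch.tensor([1.0, 361.0], dtype=torch.float64))
sig = torch.tensor([2.0, 2.0], dtype=torch.float64)
nll = von_mises_nll_sketch(dev, sig)
```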

torchref.refinement.targets.adp_similarity_nll(adp_diffs, sigma=2.0)[source]

Compute ADP similarity NLL (SIMU restraint).

Parameters:
  • adp_diffs (torch.Tensor) – ADP differences between bonded atoms.

  • sigma (float, optional) – Target standard deviation. Default is 2.0 Ų.

Returns:

Tensor of NLL values (same shape as input).

Return type:

torch.Tensor

class torchref.refinement.targets.XrayTarget(data=None, model=None, scaler=None, use_work_set=True, sigma_mode='raw', verbose=0)[source]

Bases: DataTarget

Base class for X-ray targets.

Provides common functionality for accessing F_obs, F_calc, etc. Supports two modes of operation:

  1. With Model: Computes F_calc from model on each forward pass

  2. Without Model: Uses pre-computed F_calc passed to forward()/get_data()

Parameters:
  • data (ReflectionData, optional) – Reference to the ReflectionData object. Required for forward().

  • model (Model or ModelFT, optional) – Reference to Model object for F_calc computation. If None, fcalc must be provided to forward().

  • scaler (Scaler, optional) – Reference to the Scaler object.

  • use_work_set (bool, optional) – If True, compute loss on work set; if False, on test set. Default is True.

  • verbose (int, optional) – Verbosity level. Default is 0.

use_work_set

Whether to use work set or test set.

Type:

bool

__init__(data=None, model=None, scaler=None, use_work_set=True, sigma_mode='raw', verbose=0)[source]

Initialize X-ray target.

Parameters:
  • data (ReflectionData, optional) – Reference to the ReflectionData object. Required for forward().

  • model (Model or ModelFT, optional) – Reference to Model object for F_calc computation. If None, fcalc must be provided to forward().

  • scaler (Scaler, optional) – Reference to the Scaler object.

  • use_work_set (bool, optional) – If True, compute loss on work set; if False, on test set. Default is True.

  • sigma_mode (str, optional) –

    Which sigma to use in the likelihood. Options:

    • 'raw' (default): use the raw experimental sigmas from the data file. Empirically gives the best R_free across the mid-resolution regime (1.5-3.0 Å) when paired with appropriate group weights.

    • 'effective': use per-shell effective sigmas estimated from scaling residuals (capped SIGMAA-style correction). Opt-in for high-resolution refinement (< 1.5 Å) or datasets with known sigma miscalibration. Note: Scaler.estimate_sigma_eff is always called, so the estimates are available regardless of which mode the target uses.

  • verbose (int, optional) – Verbosity level. Default is 0.

name: str = 'xray'
reset_get_data_cache()[source]

Drop the cached bookkeeping tensors.

Call this if you mutate self._data.log_scale or self._data.U_aniso in-place outside of the normal fingerprint-tracked flow, or if you want to free the memory.

get_data(fcalc=None)[source]

Get F_obs, F_calc, sigma, and centric flags for the appropriate set.

Bookkeeping tensors (F_obs_sel, sigma_sel, mask, centric_sel) are cached and reused as long as the upstream scaling parameters (log_scale, U_aniso) of the ReflectionData haven’t been mutated. Only F_calc_sel is recomputed from the live fcalc on each call.

Parameters:

fcalc (torch.Tensor, optional) – Pre-computed structure factors. If provided, uses these instead of computing from model. Useful when model is not set.

Returns:

(F_obs_sel, F_calc_sel, sigma_sel, centric_sel, mask).

Return type:

tuple
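One way such a cache can behave is sketched below with a hypothetical BookkeepingCache; the library's actual fingerprint is not documented here. The fingerprint is the identity of the upstream log_scale and U_aniso tensors: rebinding either to a new tensor triggers a rebuild, but an in-place edit keeps the same identity and goes unnoticed, which is the situation reset_get_data_cache() exists for:

```python
from typing import Any, Callable, Optional, Tuple

import torch


class BookkeepingCache:
    """Hypothetical identity-fingerprinted cache for selection tensors."""

    def __init__(self) -> None:
        self._src: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
        self._value: Any = None
        self.misses = 0

    def get(self, log_scale: torch.Tensor, u_aniso: torch.Tensor,
            build: Callable[[], Any]) -> Any:
        src = self._src
        if src is None or src[0] is not log_scale or src[1] is not u_aniso:
            self._value = build()
            # Hold references so identity comparisons stay meaningful.
            self._src = (log_scale, u_aniso)
            self.misses += 1
        return self._value

    def reset(self) -> None:
        """Counterpart of reset_get_data_cache(): forget everything."""
        self._src = None
        self._value = None
```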

stats(fcalc=None)[source]

Get statistics for this X-ray target.

Parameters:

fcalc (torch.Tensor, optional) – Pre-computed structure factors.

Returns:

Statistics dict with StatEntry values containing verbosity levels.

Return type:

dict

class torchref.refinement.targets.GaussianXrayTarget(data=None, model=None, scaler=None, use_work_set=True, sigma_mode='raw', verbose=0)[source]

Bases: XrayTarget

Simple Gaussian NLL target for X-ray data.

NLL = 0.5*(F_obs - |F_calc|)²/σ² + log(σ) + 0.5*log(2π)

target_value: float = 1.0
forward(fcalc=None)[source]

Compute Gaussian NLL loss.

Parameters:

fcalc (torch.Tensor, optional) – Pre-computed structure factors. If provided, uses these instead of computing from model.

Returns:

Mean NLL loss value.

Return type:

torch.Tensor
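Applied to structure factors, the formula uses the amplitude of the complex F_calc and averages over the selected reflections. The hypothetical gaussian_xray_nll below takes the same quantities that get_data() returns (F_obs, F_calc, sigma, mask) but is only a sketch, not the class's forward():

```python
import math

import torch


def gaussian_xray_nll(f_obs: torch.Tensor, fcalc: torch.Tensor,
                      sigma: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mean Gaussian NLL over selected reflections; only |F_calc| enters."""
    resid = f_obs[mask] - fcalc[mask].abs()
    s = sigma[mask]
    nll = 0.5 * (resid / s) ** 2 + torch.log(s) + 0.5 * math.log(2 * math.pi)
    return nll.mean()


f_obs = torch.tensor([10.0, 20.0, 30.0])
fcalc = torch.polar(torch.tensor([10.0, 18.0, 5.0]), torch.tensor([0.3, 1.2, 0.0]))
sigma = torch.ones(3)
mask = torch.tensor([True, True, False])  # e.g. work-set selection
loss = gaussian_xray_nll(f_obs, fcalc, sigma, mask)
```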

class torchref.refinement.targets.MaximumLikelihoodXrayTarget(data=None, model=None, scaler=None, use_work_set=True, sigma_mode='raw', verbose=0)[source]

Bases: XrayTarget

Maximum Likelihood target function with proper centric/acentric handling.

forward(fcalc=None)[source]

Compute maximum likelihood loss.

Parameters:

fcalc (torch.Tensor, optional) – Pre-computed structure factors. If provided, uses these instead of computing from model.

Returns:

Mean ML loss value.

Return type:

torch.Tensor
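This section does not spell out the likelihood. As background, these are the textbook amplitude distributions that centric/acentric handling usually refers to: a Rice distribution for acentric reflections and a Woolfson (folded-Gaussian) distribution for centric ones, both conditioned on |F_calc| with a single total variance Σ. Practical ML targets fold model error and experimental σ into Σ; how this class does so is not shown here, so treat amplitude_nll as an illustrative sketch:

```python
import math

import torch


def log_i0(x: torch.Tensor) -> torch.Tensor:
    # Stable log I₀(x) for x >= 0 via the exponentially scaled Bessel function.
    return torch.log(torch.special.i0e(x)) + x


def log_cosh(x: torch.Tensor) -> torch.Tensor:
    # Stable log cosh(x) for x >= 0.
    return x + torch.log1p(torch.exp(-2 * x)) - math.log(2.0)


def amplitude_nll(f_obs, f_calc_amp, var, centric):
    """Textbook amplitude NLLs given |F_calc| and total variance Σ = var.

    Acentric (Rice):     p = (2F/Σ) exp(-(F² + Fc²)/Σ) I₀(2 F Fc / Σ)
    Centric (Woolfson):  p = sqrt(2/(πΣ)) exp(-(F² + Fc²)/(2Σ)) cosh(F Fc / Σ)
    """
    f2 = f_obs**2 + f_calc_amp**2
    acen = -torch.log(2 * f_obs / var) + f2 / var - log_i0(2 * f_obs * f_calc_amp / var)
    cen = (-0.5 * torch.log(2 / (math.pi * var)) + f2 / (2 * var)
           - log_cosh(f_obs * f_calc_amp / var))
    return torch.where(centric, cen, acen)
```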

class torchref.refinement.targets.LeastSquaresXrayTarget(data=None, model=None, scaler=None, weighting='sigma', use_work_set=True, sigma_mode='raw', verbose=0)[source]

Bases: XrayTarget

Least squares target function:

L_LS = Σ w_i * (|F_obs| - k * |F_calc|)²

where w_i are per-reflection weights and k is a scale factor.

__init__(data=None, model=None, scaler=None, weighting='sigma', use_work_set=True, sigma_mode='raw', verbose=0)[source]

Initialize X-ray target.

Parameters:
  • data (ReflectionData, optional) – Reference to the ReflectionData object. Required for forward().

  • model (Model or ModelFT, optional) – Reference to Model object for F_calc computation. If None, fcalc must be provided to forward().

  • scaler (Scaler, optional) – Reference to the Scaler object.

  • weighting (str, optional) – Weighting scheme for the residuals. Default is 'sigma'.

  • use_work_set (bool, optional) – If True, compute loss on work set; if False, on test set. Default is True.

  • sigma_mode (str, optional) –

    Which sigma to use in the likelihood. Options:

    • 'raw' (default): use the raw experimental sigmas from the data file. Empirically gives the best R_free across the mid-resolution regime (1.5-3.0 Å) when paired with appropriate group weights.

    • 'effective': use per-shell effective sigmas estimated from scaling residuals (capped SIGMAA-style correction). Opt-in for high-resolution refinement (< 1.5 Å) or datasets with known sigma miscalibration. Note: Scaler.estimate_sigma_eff is always called, so the estimates are available regardless of which mode the target uses.

  • verbose (int, optional) – Verbosity level. Default is 0.

forward(fcalc=None)[source]

Compute least squares loss.

Parameters:

fcalc (torch.Tensor, optional) – Pre-computed structure factors. If provided, uses these instead of computing from model.

Returns:

Mean weighted least squares loss.

Return type:

torch.Tensor
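For a weighted least-squares loss of this form, the minimizing scale has a closed form, obtained by setting dL/dk = 0: k = Σ w_i |F_obs,i| |F_calc,i| / Σ w_i |F_calc,i|². The sketch below assumes weighting='sigma' means w_i = 1/σ_i² (not confirmed by this section) and reports the mean, as forward() is documented to do:

```python
import torch


def ls_loss_optimal_scale(f_obs: torch.Tensor, f_calc_amp: torch.Tensor,
                          sigma: torch.Tensor):
    """Mean weighted LS loss evaluated at the closed-form optimal scale k."""
    w = 1.0 / sigma**2  # assumed meaning of weighting='sigma'
    k = (w * f_obs * f_calc_amp).sum() / (w * f_calc_amp**2).sum()
    loss = (w * (f_obs - k * f_calc_amp) ** 2).mean()
    return loss, k


f_calc_amp = torch.tensor([10.0, 20.0, 30.0], dtype=torch.float64)
f_obs = torch.tensor([21.0, 39.0, 61.0], dtype=torch.float64)
sigma = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
loss, k = ls_loss_optimal_scale(f_obs, f_calc_amp, sigma)
```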

torchref.refinement.targets.create_xray_target(data=None, model=None, scaler=None, mode='gaussian', use_work_set=True, sigma_mode='raw', sigma_m_scale=1.0, verbose=0, device=None)[source]

Factory function to create X-ray target.

Parameters:
  • data (ReflectionData) – Reference to ReflectionData object. Required for forward().

  • model (Model or ModelFT, optional) – Reference to Model object for F_calc computation. If None, fcalc must be provided when calling forward().

  • scaler (Scaler, optional) – Reference to Scaler object.

  • mode (str, optional) – Target mode: ‘gaussian’, ‘ls’, or ‘ml’. Default is ‘gaussian’.

  • use_work_set (bool, optional) – Use work set (True) or test set (False). Default is True.

  • sigma_mode (str, optional) – ‘raw’ (default) to use the raw experimental sigmas from the data file, or ‘effective’ to use per-shell effective sigmas from the scaler (SIGMAA-style, robust).

  • verbose (int, optional) – Verbosity level. Default is 0.

Returns:

Appropriate XrayTarget instance.

Return type:

XrayTarget

class torchref.refinement.targets.DifferenceXrayTarget(dataset_collection=None, data_light=None, data_dark=None, model_light=None, model_dark=None, scaler_light=None, scaler_dark=None, use_work_set=True, verbose=0)[source]

Bases: Target

Target for time-resolved crystallography comparing light/dark states.

Computes difference amplitudes, ΔF_obs = F_light_obs - F_dark_obs and ΔF_calc = |F_light_calc| - |F_dark_calc|, and compares the calculated differences against the observed ones.

Uses Gaussian NLL with proper error propagation:

  • σ_diff = sqrt(σ_light² + σ_dark²)

  • NLL = 0.5 * (ΔF_obs - ΔF_calc)² / σ_diff² + log(σ_diff) + 0.5*log(2π)

Supports two initialization modes:

  1. DatasetCollection mode (recommended): Pass a DatasetCollection with pre-aligned datasets. This is more efficient and ensures consistency with other targets using the same data.

  2. Separate datasets mode: Pass individual ReflectionData objects. HKL matching is performed automatically.

Parameters:
  • dataset_collection (DatasetCollection, optional) – Collection containing ‘dark’ and ‘light’ datasets (pre-aligned HKL). If provided, data_light and data_dark are ignored.

  • data_light (ReflectionData, optional) – Reflection data for the light (excited) state.

  • data_dark (ReflectionData, optional) – Reflection data for the dark (ground) state.

  • model_light (ModelFT or MixedModel) – Model for the light state structure factor calculation.

  • model_dark (ModelFT) – Model for the dark state structure factor calculation.

  • scaler_light (ScalerBase, optional) – Scaler for the light state F_calc. Can be shared with other targets.

  • scaler_dark (ScalerBase, optional) – Scaler for the dark state F_calc. Can be shared with other targets.

  • use_work_set (bool, optional) – If True, compute loss on work set. Default is True.

  • verbose (int, optional) – Verbosity level. Default is 0.

Examples

Using DatasetCollection (recommended for sharing scalers):

# Create collection with aligned HKL
collection = DatasetCollection()
collection.add_dataset('dark', data_dark, set_as_reference=True)
collection.add_dataset('light', data_light)

# Create shared scalers
scaler_dark = IsotropicScaler(data=collection['dark'], model=model_dark)
scaler_light = IsotropicScaler(data=collection['light'], model=model_mixed)

# Create targets that share scalers
xray_dark = GaussianXrayTarget(
    data=collection['dark'], model=model_dark, scaler=scaler_dark
)
xray_light = GaussianXrayTarget(
    data=collection['light'], model=model_mixed, scaler=scaler_light
)
diff_target = DifferenceXrayTarget(
    dataset_collection=collection,
    model_light=model_mixed,
    model_dark=model_dark,
    scaler_light=scaler_light,
    scaler_dark=scaler_dark,
)

# Combined loss
loss = xray_dark() + xray_light() + diff_target()

Using separate datasets:

diff_target = DifferenceXrayTarget(
    data_light=data_light,
    data_dark=data_dark,
    model_light=model_light,
    model_dark=model_dark,
)
loss = diff_target()

With mixed model for partial occupancy:

mixed_light = MixedModel([model_dark, model_light], [0.7, 0.3])
diff_target = DifferenceXrayTarget(
    dataset_collection=collection,
    model_light=mixed_light,
    model_dark=model_dark,
    scaler_light=scaler_light,
    scaler_dark=scaler_dark,
)

name: str = 'difference_xray'
__init__(dataset_collection=None, data_light=None, data_dark=None, model_light=None, model_dark=None, scaler_light=None, scaler_dark=None, use_work_set=True, verbose=0)[source]

Initialize DifferenceXrayTarget.

property dataset_collection

DatasetCollection if using collection mode.

property data_light: ReflectionData

Light state reflection data.

property data_dark: ReflectionData

Dark state reflection data.

property model_light: ModelFT

Light state model.

property model_dark: ModelFT

Dark state model.

property scaler_light: Scaler

Light state scaler.

property scaler_dark: Scaler

Dark state scaler.

property hkl: Tensor

Common HKL indices for both datasets.

Returns the aligned HKL from DatasetCollection if available, otherwise the matched HKL computed from separate datasets.

get_delta_F_obs()[source]

Get observed difference structure factors with error propagation.

Returns:

  • delta_F_obs (torch.Tensor) – ΔF_obs = F_light_obs - F_dark_obs

  • sigma_diff (torch.Tensor) – σ_diff = sqrt(σ_light² + σ_dark²)

  • mask (torch.Tensor) – Boolean mask for work/test set selection and valid data.

Return type:

Tuple[Tensor, Tensor, Tensor]

get_delta_F_calc(fcalc_light=None, fcalc_dark=None, recalc=False)[source]

Compute calculated difference structure factors.

ΔF_calc = |F_light_calc| - |F_dark_calc|

Parameters:
  • fcalc_light (torch.Tensor, optional) – Pre-computed light state structure factors.

  • fcalc_dark (torch.Tensor, optional) – Pre-computed dark state structure factors.

  • recalc (bool, optional) – Force recalculation if True. Default is False.

Returns:

ΔF_calc for all reflections (unmasked; apply the mask returned by get_delta_F_obs()).

Return type:

torch.Tensor

forward(fcalc_light=None, fcalc_dark=None, recalc=False)[source]

Compute Gaussian NLL loss for difference structure factors.

NLL = 0.5 * (ΔF_obs - ΔF_calc)² / σ_diff² + log(σ_diff) + 0.5*log(2π)

Parameters:
  • fcalc_light (torch.Tensor, optional) – Pre-computed light state structure factors.

  • fcalc_dark (torch.Tensor, optional) – Pre-computed dark state structure factors.

  • recalc (bool, optional) – Force recalculation if True. Default is False.

Returns:

Mean NLL loss value.

Return type:

torch.Tensor
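The error propagation and NLL for DifferenceXrayTarget can be sketched as follows. difference_nll is a hypothetical helper; adding σ_light and σ_dark in quadrature assumes the two datasets' errors are independent:

```python
import math

import torch


def difference_nll(f_light_obs, f_dark_obs, sig_light, sig_dark,
                   fc_light, fc_dark, mask):
    """Mean Gaussian NLL on amplitude differences with propagated errors."""
    d_obs = f_light_obs - f_dark_obs              # ΔF_obs
    d_calc = fc_light.abs() - fc_dark.abs()       # ΔF_calc = |F_light| - |F_dark|
    sig = torch.sqrt(sig_light**2 + sig_dark**2)  # σ_diff (independent errors)
    nll = 0.5 * ((d_obs - d_calc) / sig) ** 2 + torch.log(sig) + 0.5 * math.log(2 * math.pi)
    return nll[mask].mean()
```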

stats(fcalc_light=None, fcalc_dark=None)[source]

Get statistics for difference refinement.

Parameters:
  • fcalc_light (torch.Tensor, optional) – Pre-computed light state structure factors.

  • fcalc_dark (torch.Tensor, optional) – Pre-computed dark state structure factors.

Returns:

Statistics dict with correlation, R_diff, etc.

Return type:

dict

class torchref.refinement.targets.PhaseInformedDifferenceTarget(dataset_collection, model_light=None, model_dark=None, scaler_light=None, scaler_dark=None, phase_source='difference', use_work_set=True, verbose=0)[source]

Bases: Target

Phase-informed difference target for time-resolved crystallography.

Uses model phases to create complex observed differences, then compares with calculated complex differences:

    ΔF_calc        = F_mixed_calc - F_dark_calc         (complex)
    ΔF_obs_complex = ΔF_obs * exp(i * φ)                (using model phases)
    Loss           = |ΔF_obs_complex - ΔF_calc|² / σ_diff²

The phase source can be configured:

  • “dark”: use dark model phases (stable reference).

  • “difference”: use the phase of the calculated difference ΔF_calc (self-consistent).

  • “mixed”: use mixed/light model phases.

Using current model phases is standard practice in difference Fourier methods. Because refinement is iterative, phase bias can be progressively reduced, and the localized nature of difference peaks helps reveal weak signals.

Parameters:
  • dataset_collection (DatasetCollection) – Collection containing ‘dark’ and ‘light’ datasets.

  • model_light (ModelFT or MixedModel) – Model for the light/excited state.

  • model_dark (ModelFT) – Model for the dark/ground state.

  • scaler_light (Scaler, optional) – Scaler for light state F_calc.

  • scaler_dark (Scaler, optional) – Scaler for dark state F_calc.

  • phase_source (str, optional) – Source for phases: “dark”, “difference”, or “mixed”. Default is “difference”.

  • use_work_set (bool, optional) – If True, compute loss on work set only. Default is True.

  • verbose (int, optional) – Verbosity level. Default is 0.

Examples

Using difference phases (recommended):

target = PhaseInformedDifferenceTarget(
    dataset_collection=collection,
    model_light=mixed_model,
    model_dark=model_dark,
    phase_source="difference",
)

Using dark phases:

target = PhaseInformedDifferenceTarget(
    dataset_collection=collection,
    model_light=mixed_model,
    model_dark=model_dark,
    phase_source="dark",
)

name: str = 'phase_informed_difference'
__init__(dataset_collection, model_light=None, model_dark=None, scaler_light=None, scaler_dark=None, phase_source='difference', use_work_set=True, verbose=0)[source]

Initialize target.

Parameters:

verbose (int, optional) – Verbosity level. Default is 0.

property hkl: Tensor

Common HKL indices.

forward(fcalc_light=None, fcalc_dark=None, recalc=True)[source]

Compute phase-informed difference loss.

Parameters:
  • fcalc_light (torch.Tensor, optional) – Pre-computed light state structure factors.

  • fcalc_dark (torch.Tensor, optional) – Pre-computed dark state structure factors.

  • recalc (bool, optional) – Force recalculation if True. Default is True.

Returns:

Mean weighted squared error.

Return type:

torch.Tensor
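A sketch of the phase-informed loss with the three phase sources, written as a hypothetical phase_informed_loss. Holding the borrowed phases fixed with detach() is an assumption here. Because ΔF_obs can be negative, the phase factor is applied as ΔF_obs * exp(iφ) rather than through torch.polar, whose behavior is undefined for negative magnitudes:

```python
import torch


def phase_informed_loss(d_f_obs: torch.Tensor, fc_light: torch.Tensor,
                        fc_dark: torch.Tensor, sigma_diff: torch.Tensor,
                        phase_source: str = "difference") -> torch.Tensor:
    """Mean of |ΔF_obs·exp(iφ) - ΔF_calc|² / σ_diff² (sketch)."""
    d_f_calc = fc_light - fc_dark  # complex ΔF_calc
    if phase_source == "dark":
        phi = torch.angle(fc_dark)
    elif phase_source == "difference":
        phi = torch.angle(d_f_calc)
    elif phase_source == "mixed":
        phi = torch.angle(fc_light)
    else:
        raise ValueError(f"unknown phase_source: {phase_source!r}")
    phi = phi.detach()                               # phases held fixed (assumption)
    d_f_obs_complex = d_f_obs * torch.exp(1j * phi)  # ΔF_obs · exp(iφ)
    return ((d_f_obs_complex - d_f_calc).abs() ** 2 / sigma_diff**2).mean()
```

With phase_source="difference" and |ΔF_calc| used as the "observed" difference, the loss vanishes, which is the self-consistency the option name refers to.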

stats(fcalc_light=None, fcalc_dark=None)[source]

Get statistics for the difference refinement.

Returns:

Dictionary with loss, correlation, R_diff, etc.

Return type:

dict

class torchref.refinement.targets.RiceDifferenceTarget(dataset_collection, model_light=None, model_dark=None, scaler_light=None, scaler_dark=None, use_work_set=True, verbose=0)[source]

Bases: Target

Rice-distribution difference target for time-resolved crystallography.

Works in complex space by grafting detached model phases onto observed amplitudes, then taking the complex difference. The magnitude of this complex difference is always non-negative, enabling a proper Rice distribution likelihood.

The procedure:

  1. Reconstruct complex observed structure factors using detached model phases:

    F_obs_light_complex = F_obs_light * exp(i * φ_calc_light)
    F_obs_dark_complex  = F_obs_dark  * exp(i * φ_calc_dark)
    
  2. Form complex differences:

    ΔF_obs_complex = F_obs_light_complex - F_obs_dark_complex
    ΔF_calc        = F_calc_light - F_calc_dark
    
  3. Compute strictly positive amplitudes:

    A_obs = |ΔF_obs_complex|   (always ≥ 0)
    ν     = |ΔF_calc|          (always ≥ 0)
    
  4. Apply Rice distribution NLL:

    NLL = -log(A) + log(σ²) + (A² + ν²)/(2σ²)
          - log(I₀(A·ν/σ²))
    

The Rice distribution naturally models the magnitude of a complex signal plus Gaussian noise, making it statistically appropriate for comparing amplitudes that are always positive by construction.

Parameters:
  • dataset_collection (DatasetCollection) – Collection containing ‘dark’ and ‘light’ datasets.

  • model_light (ModelFT or MixedModel) – Model for the light/excited state.

  • model_dark (ModelFT) – Model for the dark/ground state.

  • scaler_light (Scaler, optional) – Scaler for light state F_calc.

  • scaler_dark (Scaler, optional) – Scaler for dark state F_calc.

  • use_work_set (bool, optional) – If True, compute loss on work set only. Default is True.

  • verbose (int, optional) – Verbosity level. Default is 0.

Examples

Basic usage:

target = RiceDifferenceTarget(
    dataset_collection=collection,
    model_light=mixed_model,
    model_dark=model_dark,
)

With scalers:

target = RiceDifferenceTarget(
    dataset_collection=collection,
    model_light=mixed_model,
    model_dark=model_dark,
    scaler_light=scaler_light,
    scaler_dark=scaler_dark,
)

name: str = 'rice_difference'
__init__(dataset_collection, model_light=None, model_dark=None, scaler_light=None, scaler_dark=None, use_work_set=True, verbose=0)[source]

Initialize target.

Parameters:

verbose (int, optional) – Verbosity level. Default is 0.

property hkl: Tensor

Common HKL indices.

forward(fcalc_light=None, fcalc_dark=None, recalc=True)[source]

Compute Rice distribution NLL loss for difference structure factors.

Parameters:
  • fcalc_light (torch.Tensor, optional) – Pre-computed light state structure factors.

  • fcalc_dark (torch.Tensor, optional) – Pre-computed dark state structure factors.

  • recalc (bool, optional) – Force recalculation if True. Default is True.

Returns:

Mean Rice NLL loss value.

Return type:

torch.Tensor
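The four steps in the RiceDifferenceTarget description can be sketched as below. Both functions are hypothetical helpers rather than the class's methods; log I₀ is evaluated through torch.special.i0e to avoid overflow, and A_obs must be strictly positive for the -log(A) term to stay finite:

```python
import torch


def rice_nll(a_obs: torch.Tensor, nu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    """-log A + log σ² + (A² + ν²)/(2σ²) - log I₀(Aν/σ²)."""
    s2 = sigma**2
    x = a_obs * nu / s2
    log_i0 = torch.log(torch.special.i0e(x)) + x  # stable log I₀
    return -torch.log(a_obs) + torch.log(s2) + (a_obs**2 + nu**2) / (2 * s2) - log_i0


def complex_difference_amplitudes(f_obs_light, f_obs_dark, fc_light, fc_dark):
    """Steps 1-3: graft detached model phases, take differences, then magnitudes."""
    phi_l = torch.angle(fc_light).detach()
    phi_d = torch.angle(fc_dark).detach()
    d_obs = torch.polar(f_obs_light, phi_l) - torch.polar(f_obs_dark, phi_d)
    return d_obs.abs(), (fc_light - fc_dark).abs()  # A_obs, ν
```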

compute_free_metrics(fcalc_light=None, fcalc_dark=None)[source]

Compute loss and correlation on the FREE (test) set.

Returns:

Dictionary with ‘free_loss’ and ‘free_correlation’.

Return type:

dict

stats(fcalc_light=None, fcalc_dark=None)[source]

Get statistics for the Rice difference refinement.

Returns:

Dictionary with loss, correlation, R_diff, etc.

Return type:

dict

class torchref.refinement.targets.TaylorCorrectedDifferenceTarget(dataset_collection, model_light=None, model_dark=None, scaler_light=None, scaler_dark=None, use_work_set=True, verbose=0)[source]

Bases: Target

Taylor-corrected difference target for time-resolved crystallography.

Uses an exact Taylor expansion to properly account for the phase shift between dark and light states when constructing observed complex differences:

ΔF_obs = exp(i*φ_dark) * [F_obs_dark * (exp(i*dφ) - 1) + dF_obs * exp(i*dφ)]

Where:
  • dφ = φ_light_calc - φ_dark_calc (phase rotation from model)

  • dF_obs = F_obs_light - F_obs_dark (observed amplitude difference)

This formulation:
  1. Uses the exact complex exponential (no small-angle approximation)

  2. Properly accounts for both the amplitude difference and phase rotation

  3. Eliminates the false minimum that causes refinement to stop at ~70%

The loss is computed as:

Loss = |ΔF_obs_corrected - ΔF_calc|² / σ_diff²

Parameters:
  • dataset_collection (DatasetCollection) – Collection containing ‘dark’ and ‘light’ datasets.

  • model_light (ModelFT or MixedModel) – Model for the light/excited state.

  • model_dark (ModelFT) – Model for the dark/ground state.

  • scaler_light (Scaler, optional) – Scaler for light state F_calc.

  • scaler_dark (Scaler, optional) – Scaler for dark state F_calc.

  • use_work_set (bool, optional) – If True, compute loss on work set only. Default is True.

  • verbose (int, optional) – Verbosity level. Default is 0.

Examples

Basic usage:

target = TaylorCorrectedDifferenceTarget(
    dataset_collection=collection,
    model_light=mixed_model,
    model_dark=model_dark,
)

With scalers:

target = TaylorCorrectedDifferenceTarget(
    dataset_collection=collection,
    model_light=mixed_model,
    model_dark=model_dark,
    scaler_light=scaler_light,
    scaler_dark=scaler_dark,
)

name: str = 'taylor_corrected_difference'
__init__(dataset_collection, model_light=None, model_dark=None, scaler_light=None, scaler_dark=None, use_work_set=True, verbose=0)[source]

Initialize target.

Parameters:

verbose (int, optional) – Verbosity level. Default is 0.

property hkl: Tensor

Common HKL indices.

forward(fcalc_light=None, fcalc_dark=None, recalc=True)[source]

Compute Taylor-corrected difference loss.

The observed complex difference is constructed using the exact Taylor expansion:

ΔF_obs = exp(i*φ_dark) * [F_obs_dark * (exp(i*dφ) - 1) + dF_obs * exp(i*dφ)]

Parameters:
  • fcalc_light (torch.Tensor, optional) – Pre-computed light state structure factors.

  • fcalc_dark (torch.Tensor, optional) – Pre-computed dark state structure factors.

  • recalc (bool, optional) – Force recalculation if True. Default is True.

Returns:

Mean weighted squared error.

Return type:

torch.Tensor
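Expanding the bracket shows that the construction collapses exactly: exp(iφ_dark)·[F_obs_dark·(exp(i·dφ) - 1) + (F_obs_light - F_obs_dark)·exp(i·dφ)] = F_obs_light·exp(iφ_light) - F_obs_dark·exp(iφ_dark), the same complex difference formed in step 2 of RiceDifferenceTarget. The sketch below (taylor_corrected_delta_f_obs is a hypothetical name) implements the formula and checks the identity numerically:

```python
import torch


def taylor_corrected_delta_f_obs(f_obs_dark, f_obs_light, phi_dark, phi_light):
    """ΔF_obs = exp(iφ_dark)·[F_obs_dark·(exp(i·dφ) - 1) + dF_obs·exp(i·dφ)]."""
    dphi = phi_light - phi_dark       # phase rotation from the model
    d_f = f_obs_light - f_obs_dark    # observed amplitude difference
    rot = torch.exp(1j * dphi)
    return torch.exp(1j * phi_dark) * (f_obs_dark * (rot - 1) + d_f * rot)
```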

compute_free_metrics(fcalc_light=None, fcalc_dark=None)[source]

Compute loss and correlation on the FREE (test) set.

This is the key metric for detecting overfitting in the α-δF degeneracy. The correct solution should have better free set metrics.

Returns:

Dictionary with ‘free_loss’ and ‘free_correlation’.

Return type:

dict

stats(fcalc_light=None, fcalc_dark=None)[source]

Get statistics for the difference refinement.

Returns:

Dictionary with loss, correlation, R_diff, etc.

Return type:

dict

class torchref.refinement.targets.GeometryTarget(model=None, verbose=0, **kwargs)[source]

Bases: ModelTarget

Base class for geometry restraint targets.

Geometry targets access the model’s restraints property (built lazily) to compute losses for bonds, angles, torsions, planes, etc.

Parameters:
  • model (Model, optional) – Reference to the Model object.

  • verbose (int, optional) – Verbosity level. Default is 0.

  • target_value (float, optional) – Target value for this loss. Default is -1.0.

  • sigma (float, optional) – Sigma parameter for weighting. Default is 0.5.

__init__(model=None, verbose=0, **kwargs)[source]

Initialize model target.

Parameters:
  • model (Model, optional) – Reference to the Model object (optional for empty init).

  • verbose (int, optional) – Verbosity level. Default is 0.

stats()[source]

Get statistics for this restraint type.

Returns dict with StatEntry values. Filter with filter_stats() at display time.

Returns:

Statistics dict with StatEntry values containing verbosity levels.

Return type:

dict

class torchref.refinement.targets.BondTarget(model=None, verbose=0)[source]

Bases: GeometryTarget

Bond length restraint target (Gaussian NLL).

NLL = 0.5 * ((d - d₀) / σ)² + log(σ) + 0.5 * log(2π)
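The per-restraint NLL maps directly onto tensor operations. The sketch below assumes bonds are stored as an (M, 2) tensor of atom index pairs with per-bond ideal lengths and sigmas. That layout, the toy coordinates and the sum reduction are illustrative assumptions, not BondTarget's internals. The same `gaussian_nll` function applies to angles with θ in place of d:

```python
import math
import torch

def gaussian_nll(value, ideal, sigma):
    """Per-restraint Gaussian NLL: 0.5*((x - x0)/sigma)**2 + log(sigma) + 0.5*log(2*pi)."""
    z = (value - ideal) / sigma
    return 0.5 * z ** 2 + torch.log(sigma) + 0.5 * math.log(2 * math.pi)

# Hypothetical toy data: 3 atoms and 2 bonds given as index pairs
xyz = torch.tensor([[0.0, 0.0, 0.0], [1.53, 0.0, 0.0], [1.53, 1.43, 0.0]], requires_grad=True)
bond_idx = torch.tensor([[0, 1], [1, 2]])
d0 = torch.tensor([1.52, 1.46])       # ideal bond lengths (Å)
sigma = torch.tensor([0.02, 0.02])    # restraint sigmas (Å)

d = (xyz[bond_idx[:, 0]] - xyz[bond_idx[:, 1]]).norm(dim=-1)  # current bond lengths
nll = gaussian_nll(d, d0, sigma)      # z = +0.5 and -1.5 for the two bonds
loss = nll.sum()
loss.backward()                       # gradients flow back to the coordinates
```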

name: str = 'geometry/bond'
__init__(model=None, verbose=0)[source]

Initialize model target.

Parameters:
  • model (Model, optional) – Reference to the Model object (optional for empty init).

  • verbose (int, optional) – Verbosity level. Default is 0.

forward()[source]

Compute and return the bond-length restraint loss.

stats()[source]

Get bond restraint statistics.

class torchref.refinement.targets.AngleTarget(model=None, verbose=0)[source]

Bases: GeometryTarget

Angle restraint target (Gaussian NLL).

NLL = 0.5 * ((θ - θ₀) / σ)² + log(σ) + 0.5 * log(2π)

name: str = 'geometry/angle'
__init__(model=None, verbose=0)[source]

Initialize model target.

Parameters:
  • model (Model, optional) – Reference to the Model object (optional for empty init).

  • verbose (int, optional) – Verbosity level. Default is 0.

forward()[source]

Compute and return the angle restraint loss.

stats()[source]

Get angle restraint statistics.

class torchref.refinement.targets.TorsionTarget(model=None, verbose=0, w_cis_proline=0.05, w_cis_general=0.0005)[source]

Bases: GeometryTarget

Torsion angle restraint target.

Handles all torsion restraints in one target:

  • Intra-residue & disulfide torsions: unimodal von Mises NLL with periodicity handling.

  • Omega (peptide bond) torsions: cis/trans von Mises mixture, so both cis and trans conformations are stable wells and the X-ray data decides which one to adopt.
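The omega mixture can be sketched as NLL(ω) = −log[(1 − w) VM(ω; π, κ) + w VM(ω; 0, κ)]. Here VM(x; μ, κ) = exp(κ cos(x − μ)) / (2π I₀(κ)), and w is the cis prior (w_cis_proline or w_cis_general). The shared concentration κ = 20 is an illustrative assumption, because TorsionTarget's actual well widths are not documented here:

```python
import math
import torch

def von_mises_log_pdf(x, mu, kappa):
    # log VM(x; mu, kappa) = kappa*cos(x - mu) - log(2*pi*I0(kappa)); periodic in x.
    # log I0(kappa) is computed as log(i0e(kappa)) + kappa to avoid overflow.
    kappa = torch.as_tensor(kappa, dtype=x.dtype)
    log_i0 = torch.log(torch.special.i0e(kappa)) + kappa
    return kappa * torch.cos(x - mu) - math.log(2 * math.pi) - log_i0

def omega_mixture_nll(omega, w_cis, kappa=20.0):
    """-log[(1 - w) VM(omega; pi, kappa) + w VM(omega; 0, kappa)], omega in radians."""
    log_trans = math.log(1.0 - w_cis) + von_mises_log_pdf(omega, math.pi, kappa)
    log_cis = math.log(w_cis) + von_mises_log_pdf(omega, 0.0, kappa)
    # logsumexp keeps the mixture numerically stable when one well dominates
    return -torch.logsumexp(torch.stack([log_trans, log_cis]), dim=0)

omega = torch.deg2rad(torch.tensor([180.0, -180.0, 0.0, 90.0]))
nll_pro = omega_mixture_nll(omega, w_cis=0.05)     # X-Pro peptide bond
nll_gen = omega_mixture_nll(omega, w_cis=0.0005)   # non-Pro peptide bond
```

Both wells are local minima. The cis well sits higher than the trans well by roughly log((1 − w)/w), so strong X-ray evidence is needed to pull a non-proline bond into cis.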

Parameters:
  • model (Model, optional) – Reference to the Model object.

  • verbose (int, optional) – Verbosity level. Default is 0.

  • w_cis_proline (float, optional) – Prior weight for cis conformation at pre-proline peptide bonds. Default 0.05 (~5% of X-Pro bonds are cis in the PDB).

  • w_cis_general (float, optional) – Prior weight for cis conformation at non-proline peptide bonds. Default 0.0005 (~0.05% of non-Pro bonds are cis in the PDB).

name: str = 'geometry/torsion'
__init__(model=None, verbose=0, w_cis_proline=0.05, w_cis_general=0.0005)[source]

Initialize model target.

Parameters:
  • model (Model, optional) – Reference to the Model object (optional for empty init).

  • verbose (int, optional) – Verbosity level. Default is 0.

forward()[source]

Compute and return the torsion restraint loss.

stats()[source]

Get torsion angle statistics.

class torchref.refinement.targets.PlanarityTarget(model=None, verbose=0)[source]

Bases: GeometryTarget

Planarity restraint target (Gaussian NLL).

For each planar group (e.g., aromatic rings, peptide planes), computes the distance of each atom from the best-fit plane.

The best-fit plane normal is found by eigendecomposition of the 3x3 covariance matrix of centered coordinates (eigh). The normal is detached from the computational graph so that gradients flow only through the deviation projection, not through the eigendecomposition. This is standard practice in crystallographic refinement (SHELXL, Phenix, Refmac) and is more numerically robust than differentiating through SVD — in particular it avoids NaN gradients when atoms are exactly coplanar.

Plane groups with three or fewer atoms are skipped. Any three points define a plane exactly, so their deviations are zero by construction and contribute no gradient signal.

NLL = 0.5 * (d_i / σ_i)² + log(σ_i) + 0.5 * log(2π)

where d_i is the distance of atom i from the best-fit plane.
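The following is a minimal sketch of the detached-normal scheme described above; the function name and toy coordinates are hypothetical. Because the normal is treated as a constant, eigh can run on a detached covariance matrix, so no gradient ever passes through the eigendecomposition:

```python
import torch

def plane_deviations(coords):
    """Signed distances of each atom from the best-fit plane; coords has shape (n_atoms, 3)."""
    centered = coords - coords.mean(dim=0, keepdim=True)
    cov = centered.T @ centered / coords.shape[0]       # 3x3 covariance of centered coordinates
    _, eigvecs = torch.linalg.eigh(cov.detach())        # eigenvalues in ascending order
    normal = eigvecs[:, 0]                              # smallest-variance direction = plane normal
    return centered @ normal                            # gradient flows only through this projection

# Hypothetical 4-atom group: three atoms at z = 0, one lifted by 0.1 Å
coords = torch.tensor([[0.0, 0.0, 0.0], [1.4, 0.0, 0.0],
                       [0.7, 1.2, 0.0], [0.7, 0.4, 0.1]], requires_grad=True)
d = plane_deviations(coords)
nll = 0.5 * (d / 0.02) ** 2    # sigma = 0.02 Å; constant terms omitted
nll.sum().backward()
```

For an exactly planar group the smallest eigenvalue is zero. The gradient is still finite, because it never involves the eigendecomposition.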

name: str = 'geometry/planarity'
__init__(model=None, verbose=0)[source]

Initialize model target.

Parameters:
  • model (Model, optional) – Reference to the Model object (optional for empty init).

  • verbose (int, optional) – Verbosity level. Default is 0.

forward()[source]

Compute and return the planarity restraint loss.

stats()[source]

Get planarity restraint statistics.

class torchref.refinement.targets.ChiralTarget(model=None, verbose=0)[source]

Bases: GeometryTarget

Chiral volume restraint target.

Restrains the signed volume of tetrahedral chiral centers to maintain correct stereochemistry (R vs S configuration, L vs D amino acids).

The chiral volume is computed as:

V = v1 · (v2 × v3)

where vi = position of neighbor i - position of center.

For standard protein Cα atoms with neighbor ordering (N, C, CB):

  • L-amino acids: positive volume (~+2.5 Å³)

  • D-amino acids: negative volume (~-2.5 Å³)

The loss function penalizes deviations from the ideal signed volume:

NLL = 0.5 * ((V - V_ideal) / σ)² + log(σ) + 0.5 * log(2π)

For centers whose handedness is not specified (volume_sign='both'), the absolute volume |V| is restrained, so either sign is accepted.
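The sketch below is a direct transcription of V = v1 · (v2 × v3) and the quadratic part of the NLL. The function names and σ = 0.2 Å³ are illustrative assumptions. Reflecting one neighbor through a plane, or swapping two neighbors, flips the sign of V, and this sign is what distinguishes L from D:

```python
import torch

def chiral_volume(center, n1, n2, n3):
    """Signed volume V = v1 . (v2 x v3) with vi = neighbor_i - center; inputs shaped (..., 3)."""
    v1, v2, v3 = n1 - center, n2 - center, n3 - center
    return (v1 * torch.linalg.cross(v2, v3, dim=-1)).sum(dim=-1)

def chiral_nll(volume, ideal, sigma=0.2, sign_both=False):
    """Quadratic part of the chiral NLL (constant terms omitted); sigma is illustrative."""
    if sign_both:
        # Handedness unspecified: compare |V| with |V_ideal|
        volume, ideal = volume.abs(), abs(ideal)
    return 0.5 * ((volume - ideal) / sigma) ** 2

center = torch.zeros(3)
n1, n2, n3 = torch.eye(3)                        # unit vectors along x, y, z
v = chiral_volume(center, n1, n2, n3)            # +1.0
v_mirror = chiral_volume(center, n1, n2, -n3)    # reflected through the xy-plane: -1.0
```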

name: str = 'geometry/chiral'
__init__(model=None, verbose=0)[source]

Initialize model target.

Parameters:
  • model (Model, optional) – Reference to the Model object (optional for empty init).

  • verbose (int, optional) – Verbosity level. Default is 0.

forward()[source]

Compute and return the chiral volume restraint loss.

get_violations(threshold=0.5)[source]

Get information about chiral volume violations.

Parameters:

threshold (float, optional) – Report deviations larger than this (Å³). Default is 0.5.

Returns:

Dictionary with ‘indices’, ‘volumes’, ‘ideal_volumes’, ‘deviations’.

Return type:

dict

stats()[source]

Get chiral volume statistics.

class torchref.refinement.targets.NonBondedTarget(model=None, mode='prolsq', sigma=0.3, r_exp=4.0, c_rep=None, buffer=0.0, rebuild_threshold=1.0, verbose=0, scale=10.0)[source]

Bases: GeometryTarget

Non-bonded (van der Waals) restraint target using PROLSQ-style repulsion, parameterized as a generalized-Gaussian NLL on the overlap with scale \(\sigma\).

Per-pair NLL (mode='prolsq'):

\[\mathrm{NLL}(v) \;=\; \frac{v^{p}}{p\,\sigma^{p}} \;+\; \log\sigma \;+\; \tfrac{1}{2}\log(2\pi) \qquad v \;=\; \max\!\bigl(0,\, d_{\text{vdw}} + b - d\bigr)\]

where \(d\) is the interatomic distance, \(d_{\text{vdw}}\) the sum of the two VDW radii, \(b\) the distance buffer, and \(p = r_\text{exp}\) (default 4). The shape term \(v^p/(p\sigma^p)\) is algebraically identical to the classical PROLSQ energy \(c_{\text{rep}}\,v^p\) with \(c_{\text{rep}} = 1/(p\,\sigma^{p})\). Exposing \(\sigma\) makes the physics legible (σ is an “effective tolerance” on the overlap) and puts the VDW loss on the same NLL footing as the bond, angle and planarity targets.

Default sigma = 0.3 Å. This is a practical middle ground. It is stiff enough to reliably resolve medium-to-large clashes (~0.4–0.5 Å) during refinement, but loose enough that starting from a shaken or rough model does not generate LBFGS-destabilising gradients. Counting only the shape term, a 0.4 Å MolProbity clash sits at ~1.3σ and contributes ~0.8 NLL units, and a 0.5 Å severe clash sits at ~1.7σ and contributes ~1.9 NLL units. The classical PROLSQ strength c_rep=16, r_exp=4 is equivalent to \(\sigma \approx 0.354\ \text{\AA}\), very close to this default. Set sigma explicitly for deliberate tightening or loosening; for example, σ = 0.13 Å places a 0.4 Å clash at ~3σ, contributing ~22 NLL units.
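The calibration numbers above can be reproduced with a few lines of PyTorch. This is a sketch of the per-pair 'prolsq' NLL only. The function name and the 3.4 Å radii sum are illustrative, and the library's vectorised pair handling is not shown. The quoted values are the shape term, that is, the NLL minus the constant log σ + ½ log(2π):

```python
import math
import torch

def prolsq_nll(d, d_vdw, sigma=0.3, r_exp=4.0, buffer=0.0):
    """Per-pair NLL: v**p / (p * sigma**p) + log(sigma) + 0.5*log(2*pi), v = max(0, d_vdw + b - d)."""
    v = torch.clamp(d_vdw + buffer - d, min=0.0)       # overlap (Å); zero when there is no clash
    shape = v ** r_exp / (r_exp * sigma ** r_exp)      # equals c_rep * v**p with c_rep = 1/(p*sigma**p)
    return shape + math.log(sigma) + 0.5 * math.log(2 * math.pi)

const = math.log(0.3) + 0.5 * math.log(2 * math.pi)
d_vdw = torch.tensor([3.4, 3.4, 3.4])                  # illustrative radii sum (Å)
d = torch.tensor([3.6, 3.0, 2.9])                      # no contact, 0.4 Å clash, 0.5 Å clash
shape_terms = prolsq_nll(d, d_vdw) - const             # ≈ 0.00, 0.79, 1.93

# sigma equivalent to the classical PROLSQ strength c_rep = 16, r_exp = 4
sigma_prolsq = (1.0 / (4 * 16)) ** 0.25                # ≈ 0.354 Å
```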

Alternative modes:

  • 'prolsq': generalized-Gaussian NLL with exponent r_exp (default)

  • 'gaussian': Gaussian NLL on the overlap using per-pair sigmas

  • 'soft': soft repulsion with linear core outside threshold

When symmetry information is available (cell and spacegroup on the model), also handles contacts between ASU atoms and symmetry-related copies. Symmetry mate positions are recomputed on-the-fly from current ASU coordinates so that gradients flow to both atoms in each pair.

Reference: cctbx/geometry_restraints/nonbonded.h, PROLSQ documentation, MolProbity clash criterion (Davis et al., NAR 2007).

Parameters:
  • model (Model, optional) – Reference to Model object.

  • mode (str, optional) – Repulsion function type (‘prolsq’, ‘gaussian’, ‘soft’). Default is ‘prolsq’.

  • sigma (float, optional) – Effective tolerance on the overlap in Angstroms. Default is 0.3.

  • r_exp (float, optional) – Exponent of the repulsion term. Default is 4.0.

  • c_rep (float, optional) – Back-door repulsion coefficient. If provided, overrides the sigma-derived value and the NLL becomes \(c_{\text{rep}} v^{r_\text{exp}} + \log\sigma + \tfrac{1}{2}\log(2\pi)\). Useful for reproducing legacy PROLSQ weights. Default is None (derive from sigma).

  • buffer (float, optional) – Distance buffer in Angstroms added to VDW radii sum. Shifts the repulsion onset outward so atoms feel repulsion before they clash. Default is 0.0.

  • verbose (int, optional) – Verbosity level. Default is 0.

name: str = 'geometry/nonbonded'
__init__(model=None, mode='prolsq', sigma=0.3, r_exp=4.0, c_rep=None, buffer=0.0, rebuild_threshold=1.0, verbose=0, scale=10.0)[source]

Initialize non-bonded target.

Parameters:
  • model (Model, optional) – Reference to Model object.

  • mode (str, optional) – Repulsion function type (‘prolsq’, ‘gaussian’, ‘soft’). Default is ‘prolsq’.

  • sigma (float, optional) – Effective tolerance on the overlap (Å). Default 0.3. Only used when c_rep is None.

  • r_exp (float, optional) – Repulsion exponent. Default is 4.0.

  • c_rep (float or None, optional) – Legacy coefficient override. If None (default), derived from sigma as 1 / (r_exp * sigma ** r_exp).

  • buffer (float, optional) – Distance buffer in Angstroms added to VDW radii sum. Default is 0.0.

  • rebuild_threshold (float, optional) – Maximum ASU atom displacement in Angstroms since the last VDW pair-list build before maintenance() triggers a rebuild. Default is 1.0 Å. This is below the 1.2 Å bound set by the ~2.4 Å slack of the default 6.0 Å cutoff (the slack is halved because both atoms of a pair can move), so newly formed contacts cannot slip through the list.

  • verbose (int, optional) – Verbosity level. Default is 0.

property c_rep: float

Get repulsion coefficient.

property sigma_vdw: float

Get the effective overlap tolerance sigma (Å).

property r_exp: float

Get repulsion exponent.

property buffer: float

Get distance buffer.

maintenance()[source]

Rebuild the VDW pair list if any ASU atom drifted too far.

Fast path: one max().item() sync on the per-atom displacement norm between the current ASU coordinates and the snapshot taken at the last VDW build. If the max displacement stays within _rebuild_threshold we return immediately.

Slow path (only when triggered): delegate to restraints.rebuild_vdw_restraints which refreshes the pair list using the original build kwargs and updates the snapshot. See Target.maintenance() for the general contract.

Safety invariant: the default build cutoff (6.0 Å) leaves roughly cutoff - max_vdw_sum ≈ 2.4 Å of slack before a previously non-contact atom pair could form a new clash. Both atoms of a pair can move, each by up to rebuild_threshold. Setting rebuild_threshold < 2.4 / 2 = 1.2 Å therefore guarantees that no such pair can slip through the list, because a rebuild fires before the slack is consumed.
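The fast-path/slow-path logic is the classic Verlet-list rebuild criterion. The following minimal, self-contained sketch uses PairListCache, a hypothetical class. It performs a brute-force O(N²) pair search in place of torchref's pair builder and ignores symmetry mates:

```python
import torch

class PairListCache:
    """Hypothetical minimal pair-list cache illustrating a displacement-triggered rebuild."""

    def __init__(self, xyz, cutoff=6.0, rebuild_threshold=1.0):
        self.cutoff = cutoff
        self.rebuild_threshold = rebuild_threshold
        self.n_rebuilds = 0
        self.rebuild(xyz)

    def rebuild(self, xyz):
        # Brute-force O(N^2) pair search, for illustration only
        dist = torch.cdist(xyz, xyz)
        i, j = torch.triu_indices(len(xyz), len(xyz), offset=1)
        keep = dist[i, j] < self.cutoff
        self.pairs = torch.stack([i[keep], j[keep]], dim=1)
        self.snapshot = xyz.detach().clone()   # reference coordinates for the drift check
        self.n_rebuilds += 1

    def maintenance(self, xyz):
        # Fast path: one host sync on the largest per-atom displacement
        max_disp = (xyz.detach() - self.snapshot).norm(dim=-1).max().item()
        if max_disp <= self.rebuild_threshold:
            return False
        self.rebuild(xyz)                       # slow path: refresh pairs and snapshot
        return True

torch.manual_seed(0)
xyz = torch.rand(50, 3) * 20
cache = PairListCache(xyz)
cache.maintenance(xyz + 0.3)   # each atom moved ~0.52 Å: no rebuild
cache.maintenance(xyz + 2.0)   # each atom moved ~3.46 Å: rebuild
cache.maintenance(xyz + 2.0)   # unchanged since the rebuild: no-op (idempotent)
```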

forward()[source]

Compute and return the non-bonded (VDW) restraint loss.

get_violations(threshold=0.0)[source]

Get information about VDW violations.

Parameters:

threshold (float, optional) – Only report violations greater than this (Å). Default is 0.0.

Returns:

Dictionary with ‘indices’, ‘violations’, ‘distances’, ‘min_distances’.

Return type:

dict

stats()[source]

Get non-bonded restraint statistics.

class torchref.refinement.targets.NonBondedHTarget(model=None, mode='prolsq', sigma=0.3, r_exp=4.0, c_rep=None, buffer=0.0, rebuild_threshold=1.0, verbose=0)[source]

Bases: NonBondedTarget

Non-bonded target with transient riding hydrogen VDW contacts.

Drop-in replacement for NonBondedTarget. The heavy-heavy VDW loss is computed by the parent class; this subclass adds an H-VDW term from precomputed candidate H-heavy pairs.

Candidate pairs are derived at build time from the heavy-heavy VDW pair list. At forward time, only H placement + vectorized distance computation is needed — no spatial hashing.

Uses the same generalized-Gaussian NLL as the parent class; see NonBondedTarget for the sigma calibration.

Parameters:
  • model (Model, optional) – Reference to Model object.

  • mode (str, optional) – Repulsion function type. Default 'prolsq'.

  • sigma (float, optional) – Effective tolerance on the overlap (Å). Default 0.3.

  • r_exp (float, optional) – Repulsion exponent. Default 4.0.

  • c_rep (float or None, optional) – Legacy coefficient override; derived from sigma when None.

  • buffer (float, optional) – Distance buffer (Å). Default 0.0.

  • verbose (int, optional) – Verbosity level. Default 0.

name: str = 'geometry/nonbonded'
__init__(model=None, mode='prolsq', sigma=0.3, r_exp=4.0, c_rep=None, buffer=0.0, rebuild_threshold=1.0, verbose=0)[source]

Initialize non-bonded target.

Parameters:
  • model (Model, optional) – Reference to Model object.

  • mode (str, optional) – Repulsion function type (‘prolsq’, ‘gaussian’, ‘soft’). Default is ‘prolsq’.

  • sigma (float, optional) – Effective tolerance on the overlap (Å). Default 0.3. Only used when c_rep is None.

  • r_exp (float, optional) – Repulsion exponent. Default is 4.0.

  • c_rep (float or None, optional) – Legacy coefficient override. If None (default), derived from sigma as 1 / (r_exp * sigma ** r_exp).

  • buffer (float, optional) – Distance buffer in Angstroms added to VDW radii sum. Default is 0.0.

  • rebuild_threshold (float, optional) – Maximum ASU atom displacement in Angstroms since the last VDW pair-list build before maintenance() triggers a rebuild. Default is 1.0 Å. This is below the 1.2 Å bound set by the ~2.4 Å slack of the default 6.0 Å cutoff (the slack is halved because both atoms of a pair can move), so newly formed contacts cannot slip through the list.

  • verbose (int, optional) – Verbosity level. Default is 0.

forward()[source]

Compute and return the non-bonded loss, including the H-VDW term.

get_violations(threshold=0.0)[source]

Get violations including H-involving contacts.

stats()[source]

Get statistics including H-VDW contacts.

class torchref.refinement.targets.RamachandranTarget(model=None, verbose=0)[source]

Bases: GeometryTarget

Ramachandran restraint via pre-computed NLL surfaces.

Uses 6 residue-type-dependent NLL surfaces (general, glycine, cis-proline, trans-proline, pre-proline, ile/val) at 1-degree resolution. The loss is computed by bilinear interpolation of the NLL surface at the current (phi, psi) angles.

The surfaces store NLL = -log P(phi, psi | residue_type), so favored regions have low values and outlier regions have high values — consistent with all other geometry targets.
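Bilinear lookup on a periodic grid takes only a few lines of tensor indexing. The sketch assumes each surface is a 360 × 360 tensor whose node [i, j] holds the NLL at φ = −180 + i and ψ = −180 + j degrees. The actual grid origin and storage in torchref may differ. Gradients reach φ and ψ through the fractional offsets:

```python
import torch

def interp_periodic_surface(surface, phi_deg, psi_deg):
    """Bilinear interpolation of an NLL surface on a periodic 1-degree grid.

    Assumes surface[i, j] is the NLL at phi = -180 + i, psi = -180 + j (degrees).
    """
    n = surface.shape[0]
    x = torch.remainder(phi_deg + 180.0, 360.0)      # continuous grid coordinate in [0, 360)
    y = torch.remainder(psi_deg + 180.0, 360.0)
    x0, y0 = torch.floor(x), torch.floor(y)
    fx, fy = x - x0, y - y0                          # fractional offsets carry the gradient
    i0, j0 = x0.long() % n, y0.long() % n
    i1, j1 = (i0 + 1) % n, (j0 + 1) % n              # wrap +180 back to -180
    return ((1 - fx) * (1 - fy) * surface[i0, j0] + fx * (1 - fy) * surface[i1, j0]
            + (1 - fx) * fy * surface[i0, j1] + fx * fy * surface[i1, j1])

torch.manual_seed(0)
surface = torch.rand(360, 360, dtype=torch.float64)   # stand-in for a real NLL surface
phi = torch.tensor([-169.5], dtype=torch.float64, requires_grad=True)
psi = torch.tensor([-159.5], dtype=torch.float64)
nll = interp_periodic_surface(surface, phi, psi)      # mean of surface[10:12, 20:22]
nll.sum().backward()
```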

name: str = 'geometry/ramachandran'
__init__(model=None, verbose=0)[source]

Initialize model target.

Parameters:
  • model (Model, optional) – Reference to the Model object (optional for empty init).

  • verbose (int, optional) – Verbosity level. Default is 0.

forward()[source]

Compute and return the Ramachandran restraint loss.

stats()[source]

Get Ramachandran restraint statistics.

class torchref.refinement.targets.ADPTarget(model=None, verbose=0, **kwargs)[source]

Bases: ModelTarget

Base class for ADP restraint targets.

ADP targets access the model’s ADP values and restraints for similarity, rigid bond, and other ADP-related restraints.

Parameters:
  • model (Model, optional) – Reference to the Model object.

  • verbose (int, optional) – Verbosity level. Default is 0.

  • target_value (float, optional) – Target value for this loss. Default is -1.0.

  • sigma (float, optional) – Sigma parameter for weighting. Default is 1.0.

__init__(model=None, verbose=0, **kwargs)[source]

Initialize model target.

Parameters:
  • model (Model, optional) – Reference to the Model object (optional for empty init).

  • verbose (int, optional) – Verbosity level. Default is 0.

class torchref.refinement.targets.ADPSimilarityTarget(model=None, simu_sigma=2.0, verbose=0)[source]

Bases: ADPTarget

ADP Similarity restraint (SIMU in Phenix/SHELX).

Restrains B-factors of bonded atoms to be similar:

NLL = 0.5 * ((B_i - B_j) / σ)² + log(σ) + 0.5 * log(2π)

Tunable parameters (as buffers):

  • _simu_sigma: float, sigma for B-factor differences (default 2.0 Å²)

name: str = 'adp/simu'
__init__(model=None, simu_sigma=2.0, verbose=0)[source]

Initialize model target.

Parameters:
  • model (Model, optional) – Reference to the Model object (optional for empty init).

  • verbose (int, optional) – Verbosity level. Default is 0.

property simu_sigma: float

Get SIMU sigma value.

forward()[source]

Compute and return the SIMU restraint loss.

stats()[source]

Get SIMU restraint statistics.

class torchref.refinement.targets.RigidBondTarget(model=None, sigma=0.004, use_aniso=True, verbose=0)[source]

Bases: ADPTarget

Rigid Bond restraint (DELU in SHELX, Hirshfeld test).

Based on Hirshfeld’s rigid bond test (Acta Cryst. A32, 239, 1976).

For a truly rigid bond, the mean-square displacement amplitudes (MSDA) of the two bonded atoms along the bond direction should be equal. A rigid bond cannot stretch, so any displacement of one atom along the bond axis must be matched by the other.

For anisotropic ADPs (U tensors):

z_12 = l_12^T U_1 l_12 / |l_12|²  (MSDA of atom 1 along bond)
z_21 = l_21^T U_2 l_21 / |l_21|²  (MSDA of atom 2 along bond)
Δz = z_12 - z_21 should be ~0

For isotropic B-factors, the difference in B_iso is used as a proxy:

ΔB = B_1 - B_2

This differs from SIMU (ADPSimilarityTarget) which restrains the full ADP tensors to be similar. Rigid bond only restrains the component along the bond direction.

Energy: E = w * Δz²

NLL: NLL = 0.5 * (Δz / σ)² + log(σ) + 0.5 * log(2π)
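The following is a vectorised sketch of the anisotropic Δz. The function name and tensor layout are assumptions: Cartesian U tensors of shape (N, 3, 3) and an (M, 2) bond index tensor. The example shows the key difference from SIMU. Two atoms whose ADPs differ only perpendicular to the bond give Δz = 0:

```python
import torch

def rigid_bond_delta_z(xyz, U, bond_idx):
    """Hirshfeld delta_z = z_12 - z_21 per bond.

    xyz: (N, 3) Cartesian coordinates (Å); U: (N, 3, 3) Cartesian ADP tensors (Å²);
    bond_idx: (M, 2) atom index pairs.
    """
    i, j = bond_idx[:, 0], bond_idx[:, 1]
    b_vec = xyz[j] - xyz[i]                                         # bond vector l_12
    b2 = (b_vec * b_vec).sum(dim=-1)                                # |l_12|²
    z12 = torch.einsum("mi,mij,mj->m", b_vec, U[i], b_vec) / b2     # MSDA of atom 1 along the bond
    z21 = torch.einsum("mi,mij,mj->m", b_vec, U[j], b_vec) / b2     # l_21 = -l_12 gives the same quadratic form
    return z12 - z21

# Two atoms bonded along x whose ADPs differ only perpendicular to the bond
xyz = torch.tensor([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]])
U = torch.stack([torch.diag(torch.tensor([0.02, 0.03, 0.04])),
                 torch.diag(torch.tensor([0.02, 0.05, 0.06]))])
dz = rigid_bond_delta_z(xyz, U, torch.tensor([[0, 1]]))   # 0.0, although SIMU would penalise these ADPs
```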

References

  • Hirshfeld, F.L. (1976). Acta Cryst. A32, 239.

  • cctbx/adp_restraints/rigid_bond.h

Parameters:
  • model (Model) – Reference to Model object.

  • sigma (float, optional) – Target standard deviation for Δz. Default is 0.004 Å². Hirshfeld found typical values of 0.001 Å² for good structures.

  • use_aniso (bool, optional) – If True and model has anisotropic ADPs, use proper tensor calculation. Default is True.

  • verbose (int, optional) – Verbosity level. Default is 0.

name: str = 'adp/delu'
__init__(model=None, sigma=0.004, use_aniso=True, verbose=0)[source]

Initialize model target.

Parameters:
  • model (Model, optional) – Reference to the Model object (optional for empty init).

  • verbose (int, optional) – Verbosity level. Default is 0.

forward()[source]

Compute rigid bond restraint.

For isotropic refinement, uses B-factor differences along bonds. For anisotropic refinement, computes proper MSDA differences.

get_delta_z_stats()[source]

Get statistics of Δz values for analysis.

Returns:

Dictionary with mean, std, max, min of |Δz| values and Z-scores.

Return type:

dict

stats()[source]

Get rigid bond restraint statistics.

Returns statistics including Δz values along bonds.

class torchref.refinement.targets.ADPEntropyTarget(model=None, verbose=0)[source]

Bases: ADPTarget

ADP Entropy regularization target.

Uses the model’s existing adp_kl_divergence_loss or similar.

name: str = 'adp/KL'
__init__(model=None, verbose=0)[source]

Initialize model target.

Parameters:
  • model (Model, optional) – Reference to the Model object (optional for empty init).

  • verbose (int, optional) – Verbosity level. Default is 0.

forward()[source]

Compute and return the ADP KL-divergence loss.

stats()[source]

Get KL divergence statistics.

class torchref.refinement.targets.ADPLocalityTarget(model=None, k_neighbors=50, correlation_length=5.0, scale=5.0, exclude_bonded=True, verbose=0)[source]

Bases: ADPTarget

Proximity-based ADP restraint using K nearest neighbors.

Uses a spatial cell-list (O(N) memory, O(N·k) time) instead of a full N×N distance matrix, so it scales to arbitrarily large structures without memory issues.

Parameters:
  • model (Model) – Reference to Model object.

  • k_neighbors (int, optional) – Number of nearest neighbors to consider. Default is 50.

  • correlation_length (float, optional) – Distance scale for weight decay in Angstrom. Default is 5.0.

  • scale (float, optional) – Scaling factor for loss magnitude. Default is 5.0.

  • exclude_bonded (bool, optional) – Exclude directly bonded atoms. Default is True.

  • verbose (int, optional) – Verbosity level. Default is 0.

name: str = 'adp/locality'
__init__(model=None, k_neighbors=50, correlation_length=5.0, scale=5.0, exclude_bonded=True, verbose=0)[source]

Initialize model target.

Parameters:
  • model (Model, optional) – Reference to the Model object (optional for empty init).

  • verbose (int, optional) – Verbosity level. Default is 0.

property k_neighbors: int
property correlation_length: float
property scale: float
forward(recompute_neighbors=False)[source]

Compute weighted MSE on log(B) differences with inverse-distance weights.

loss = scale * mean_ij [w_ij * (log(B_i) - log(B_j))^2] where w_ij = 1 / (d_ij + eps)
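The following brute-force sketch of this loss uses O(N²) memory via torch.cdist, so it is suitable only for small examples, unlike the cell-list search described above. The sketch simplifies in three ways: it treats neighbour selection and weights as constants, and it omits both the exclude_bonded filter and correlation_length:

```python
import torch

def locality_loss(xyz, b_factors, k=4, scale=5.0, eps=1e-6):
    """scale * mean over (i, j in kNN(i)) of w_ij * (log B_i - log B_j)**2, w_ij = 1/(d_ij + eps)."""
    with torch.no_grad():                              # neighbours and weights treated as constants
        dist = torch.cdist(xyz, xyz)
        dist.fill_diagonal_(float("inf"))              # an atom is not its own neighbour
        d_knn, idx = torch.topk(dist, k, dim=1, largest=False)
        w = 1.0 / (d_knn + eps)                        # inverse-distance weights, shape (N, k)
    log_b = torch.log(b_factors)
    diff2 = (log_b[:, None] - log_b[idx]) ** 2         # squared log-B differences, shape (N, k)
    return scale * (w * diff2).mean()

torch.manual_seed(0)
xyz = torch.rand(30, 3) * 10
b = (20 + 10 * torch.rand(30)).requires_grad_()        # illustrative B-factors (Å²)
loss = locality_loss(xyz, b)
loss.backward()                                        # gradients reach the B-factors only
```

Working in log(B) makes the penalty scale-free, so a 20 → 25 Å² difference costs the same as 40 → 50 Å².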

stats()[source]

Get locality restraint statistics.

class torchref.refinement.targets.CombinedTargets(verbose=0)[source]

Bases: Target

Base class for combined targets.

Uses nn.ModuleDict to store component targets for clean organization and easy access via dictionary-style notation.

Subclasses should override _create_targets() to define their component targets.
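The ModuleDict pattern fits in a few lines. ConstantTarget and WeightedSum are hypothetical stand-ins. The per-name weight dictionary is an assumption about how weights could be attached, not the actual CombinedTargets mechanism. Registering components in an nn.ModuleDict is what makes their buffers appear in state_dict under dotted keys:

```python
import torch
from torch import nn

class ConstantTarget(nn.Module):
    """Hypothetical component target returning a fixed loss stored as a buffer."""

    def __init__(self, value):
        super().__init__()
        self.register_buffer("value", torch.tensor(float(value)))

    def forward(self):
        return self.value

class WeightedSum(nn.Module):
    """Minimal sketch: component targets in a ModuleDict, combined as a weighted sum."""

    def __init__(self, targets, weights):
        super().__init__()
        self._targets = nn.ModuleDict(targets)
        self.weights = dict(weights)          # missing names default to weight 1.0

    def __getitem__(self, key):
        return self._targets[key]

    def target_losses(self):
        # Individual component losses, without weights
        return {name: t() for name, t in self._targets.items()}

    def forward(self):
        # A weight of 0 removes a component from the total
        return sum(self.weights.get(n, 1.0) * loss for n, loss in self.target_losses().items())

combo = WeightedSum({"bond": ConstantTarget(2.0), "angle": ConstantTarget(3.0)}, {"angle": 0.5})
total = combo()                    # 2.0 + 0.5 * 3.0 = 3.5
keys = list(combo.state_dict())    # ['_targets.bond.value', '_targets.angle.value']
```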

Parameters:

verbose (int, optional) – Verbosity level. Default is 0.

_targets

Dictionary of component targets.

Type:

nn.ModuleDict

__init__(verbose=0)[source]

Initialize CombinedTargets.

Parameters:

verbose (int, optional) – Verbosity level. Default is 0.

targets()[source]

Return registered sub-targets as ModuleDict.

__getitem__(key)[source]

Get a target by name using dictionary-style access.

__contains__(key)[source]

Check if a target exists.

keys()[source]

Return target names.

values()[source]

Return target instances.

items()[source]

Return (name, target) pairs.

target_losses()[source]

Get individual component losses (without weights).

forward()[source]

Compute total combined target loss.

stats()[source]

Get statistics from all registered targets.

get()[source]

Get individual component losses.

add_to_state(state)[source]

Compute loss and add it to the LossState.

This method enables the new LossState pipeline pattern where targets receive a state object, compute their loss, add it to the state, and return the state for chaining.

Parameters:

state (LossState) – Current loss state with computed data.

Returns:

State with this target’s loss added.

Return type:

LossState

class torchref.refinement.targets.TotalGeometryTarget(model=None, verbose=0)[source]

Bases: CombinedModelTargets

Computes weighted sum of all geometry restraint NLLs.

Uses nn.ModuleDict to store component targets:

  • ‘bond’: BondTarget

  • ‘angle’: AngleTarget

  • ‘torsion’: TorsionTarget

  • ‘planarity’: PlanarityTarget

  • ‘chiral’: ChiralTarget

  • ‘nonbonded’: NonBondedHTarget (includes riding hydrogen VDW)

The torsion weight is reduced because:

  1. Protein torsions naturally deviate from ideal (Ramachandran plot)

  2. Side chain rotamers have discrete populations, not single ideals

  3. High torsion weight can over-constrain the structure

The nonbonded weight is very low because:

  1. PROLSQ repulsion is already steep (E ~ violation^4)

  2. Most contacts should be satisfied by covalent geometry

  3. High VDW weight can prevent proper packing

Set weight to 0 to disable a component.

Parameters:
  • model (Model, optional) – Reference to Model object.

  • verbose (int, optional) – Verbosity level. Default is 0.

Examples

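from torchref.refinement.targets import TotalGeometryTarget

# `model` is an existing torchref Model
geom_target = TotalGeometryTarget(model)
loss = geom_target()                 # weighted sum of all components
bond_loss = geom_target['bond']()    # a single component
for name, target in geom_target.items():
    print(f"{name}: {target()}")
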
get_metrics(verbosity=2)[source]

Get all geometry metrics as a flat dictionary for logging/reporting.

Parameters:

verbosity (int, optional) – Verbosity level for filtering. Default is 2 (VERBOSITY_DETAILED).

Returns:

Dictionary with validation metrics from all component targets. All values are Python floats (not tensors).

Return type:

dict

print_statistics()[source]

Print REFMAC-style geometry statistics with losses.

class torchref.refinement.targets.TotalADPTarget(model=None, verbose=0)[source]

Bases: CombinedModelTargets

Total ADP restraint target combining global, similarity, and local components.

Uses nn.ModuleDict to store component targets:

  • ‘simu’: ADPSimilarityTarget (SIMU-like bond similarity)

  • ‘locality’: ADPLocalityTarget (spatial smoothness)

  • ‘KL’: ADPEntropyTarget (KL divergence regularization)

B-factors are modelled as log-normally distributed (B > 0, right-skewed). If B ~ LogNormal(μ, σ), then log(B) ~ Normal(μ, σ).

This target combines:

  1. Similarity restraint (SIMU-like): bond-based B-factor similarity

    • Enforces similar B-factors for bonded atoms

    • Based on covalent bond topology (the strongest local constraint)

  2. Locality restraint: spatial smoothness

    • Nearby atoms should have similar B

    • Uses K-NN with distance-based sigma (d² scaling)

    • Captures medium-range spatial correlation

  3. KL divergence: controls the spread of the B-factor distribution

    • Prevents overfitting by controlling the distribution width

Log-normal distribution properties:

  • If log(B) ~ N(μ, σ), then:

    • Mean of B: exp(μ + σ²/2)

    • Mode of B: exp(μ - σ²)

  • For typical proteins: σ_logB ≈ 0.3-0.5 (in log space)
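The mean and mode formulas can be checked numerically with torch.distributions.LogNormal. The median B of 30 Å² and σ_logB = 0.4 below are illustrative values, not torchref defaults:

```python
import math
import torch
from torch.distributions import LogNormal

mu, sigma = math.log(30.0), 0.4      # illustrative: median B = 30 Å², sigma_logB = 0.4
dist = LogNormal(torch.tensor(mu, dtype=torch.float64), torch.tensor(sigma, dtype=torch.float64))

mean_formula = math.exp(mu + sigma ** 2 / 2)   # ≈ 32.50
mode_formula = math.exp(mu - sigma ** 2)       # ≈ 25.56

# Locate the peak of the density on a fine grid of B values
b = torch.linspace(1.0, 100.0, 200_001, dtype=torch.float64)
mode_numeric = b[dist.log_prob(b).argmax()].item()
```

The ordering mode < median < mean (25.6 < 30 < 32.5 Å² here) is the right skew mentioned above.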

Parameters:
  • model (Model) – Reference to Model object.

  • verbose (int, optional) – Verbosity level. Default is 0.

Examples

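from torchref.refinement.targets import TotalADPTarget

# `model` is an existing torchref Model
adp_target = TotalADPTarget(model)
loss = adp_target()                  # combined ADP restraint loss
simu_loss = adp_target['simu']()     # a single component
for name, target in adp_target.items():
    print(f"{name}: {target()}")
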
print_statistics()[source]

Print comprehensive ADP restraint statistics.

Displays statistics from all registered ADP targets.

get_metrics(verbosity=2)[source]

Get all ADP metrics as a flat dictionary for logging/reporting.

Parameters:

verbosity (int, optional) – Verbosity level for filtering. Default is 2 (VERBOSITY_DETAILED).

Returns:

Dictionary with validation metrics from all component targets. All values are Python floats (not tensors).

Return type:

dict

class torchref.refinement.targets.ForceFieldTarget(model=None, model_path=None, cutoff=5.0, normalize_by_atoms=True, verbose=0)[source]

Bases: ModelTarget

Force field energy target using TorchMD-Net ML potentials.

Computes molecular energy from atomic coordinates using a pre-trained neural network potential. Returns energy as a differentiable tensor suitable for gradient-based refinement.

Parameters:
  • model (Model, optional) – Reference to the Model object. Must have hydrogens (load with strip_H=False).

  • model_path (str, optional) – Path to TorchMD-Net checkpoint file (.ckpt).

  • cutoff (float, optional) – Interaction cutoff distance in Angstroms. Default is 5.0.

  • normalize_by_atoms (bool, optional) – If True, return energy per atom. Default is True.

  • verbose (int, optional) – Verbosity level. Default is 0.

Examples

>>> from torchref.model import Model
>>> from torchref.refinement.targets import ForceFieldTarget
>>>
>>> # Load model WITH hydrogens
>>> model = Model(strip_H=False)
>>> model.load_pdb('structure_with_H.pdb')
>>>
>>> # Create force field target
>>> ff_target = ForceFieldTarget(
...     model=model,
...     model_path='path/to/torchmdnet.ckpt',
... )
>>>
>>> # Get energy
>>> energy = ff_target()

Notes

Requires torchmd-net package: pip install torchmd-net

Pre-trained models available at: https://github.com/torchmd/torchmd-net/tree/main/examples

name: str = 'forcefield'
__init__(model=None, model_path=None, cutoff=5.0, normalize_by_atoms=True, verbose=0)[source]

Initialize model target.

Parameters:
  • model (Model, optional) – Reference to the Model object (optional for empty init).

  • verbose (int, optional) – Verbosity level. Default is 0.

forward()[source]

Compute force field energy for current model coordinates.

Returns:

Scalar energy tensor with gradient support.

Return type:

torch.Tensor

stats()[source]

Get statistics for this target.

Returns:

Dictionary with StatEntry values.

Return type:

dict

class torchref.refinement.targets.AmberTarget(model=None, cutoff=5.0, normalize_by_atoms=True, residue_charges=None, gaff2_files=None, verbose=0)[source]

Bases: ModelTarget

Differentiable AMBER14/GAFF2 force-field energy restraint.

On construction the target:

  1. Detects non-standard residues (HETATM not in AMBER14_STANDARD).

  2. Runs antechamber + parmchk2 (parallel, cached) for each non-standard residue.

  3. Builds an OpenMM system:

    • Standard path (no non-standard residues): filter model PDB to primary conformation + heavy atoms, use openmm.app.Modeller to re-add H with AMBER14-compatible names, create system with ForceField('amber14-all.xml').

    • GAFF2 path (with non-standard residues): same protein PDB (additionally removing OXT) handed to tleap together with each ligand’s mol2 via combine{}. Combined AMBER14+GAFF2 topology is parameterised by parmed.

  4. Creates an OpenMM Context on the platform that matches the model’s device: CUDA for model.device.type == 'cuda', CPU otherwise. Falls back CUDA → OpenCL → CPU if the preferred platform is unavailable.

  5. Builds a model-atom → OpenMM-atom index map so that only heavy atoms are transferred; H positions are kept from the initial OpenMM setup.

Parameters:
  • model (Model) –

    TorchRef model. Heavy-atom-only models (strip_H=True) are accepted. H atoms are added internally by OpenMM’s Modeller or tleap and are NOT included in the atom map or gradient.

    Passing a model that already has H atoms (via model.generate_hydrogens() or loading a PDB with H) speeds up initialisation ~4× because Modeller.addHydrogens() converges faster from existing positions.

    Required for GAFF2 ligands: antechamber’s BCC charge scheme runs a semiempirical QM step (sqm) that requires a fully protonated molecule. If the model has no H atoms for a non-standard residue, an explicit error is raised. Call model.generate_hydrogens() or load the PDB with strip_H=False before creating the target.

  • cutoff (float) – Non-bonded cutoff in Angstroms. Default 5.0.

  • normalize_by_atoms (bool) – If True the energy is divided by the number of model atoms. Default True.

  • residue_charges (dict[str, int], optional) – Net formal charge per non-standard residue name, e.g. {'LIG': -1, 'ATP': -4}. Residues not listed default to 0 with a warning.

  • verbose (int) – Verbosity level (0 = silent, 1 = informational, 2 = debug).

name: str = 'amber'
__init__(model=None, cutoff=5.0, normalize_by_atoms=True, residue_charges=None, gaff2_files=None, verbose=0)[source]

Initialize model target.

Parameters:
  • model (Model, optional) – Reference to the Model object (optional for empty init).

  • verbose (int, optional) – Verbosity level. Default is 0.

forward()[source]

Compute AMBER14 energy for current model coordinates.

Returns:

Scalar energy in kJ/mol (or kJ/mol/atom if normalize_by_atoms). Gradient flows to model.xyz via OpenMM analytical forces.

Return type:

torch.Tensor

stats()[source]

Return target statistics for the logging pipeline.

class torchref.refinement.targets.OccupancyFloorDiagnostic(model_dark, model_light, grid_spacing=0.5, negative_threshold=-0.5)[source]

Bases: object

Diagnostic tool to estimate activation fraction floor from electron density.

Analyzes the electron density of the light/refined model and checks for unphysical negative density, which indicates the activation fraction is too small.

Parameters:
  • model_dark (ModelFT) – The dark/ground state model.

  • model_light (ModelFT) – The light/excited state model (the refined one, not MixedModel).

  • grid_spacing (float, optional) – Grid spacing in Angstroms for density calculation. Default is 0.5.

  • negative_threshold (float, optional) – Threshold below which density is considered “significantly negative”. Default is -0.5 (in sigma units after normalization).

Examples

Basic usage:

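from torchref.refinement.targets import OccupancyFloorDiagnostic

# model_dark and model_light_refine are existing ModelFT instances
diagnostic = OccupancyFloorDiagnostic(model_dark, model_light_refine)
result = diagnostic.analyze()
print(f"Estimated alpha floor: {result['alpha_floor']:.3f}")
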
__init__(model_dark, model_light, grid_spacing=0.5, negative_threshold=-0.5)[source]
compute_density_at_positions(model, positions, hkl)[source]

Compute electron density at specific positions using Fourier summation.

This is a simplified calculation that sums F_calc * exp(2πi * h·r).

Parameters:
  • model (ModelFT) – Model to compute density from.

  • positions (torch.Tensor) – Positions in fractional coordinates, shape (N, 3).

  • hkl (torch.Tensor) – Miller indices, shape (M, 3).

Returns:

Electron density values at each position, shape (N,).

Return type:

torch.Tensor
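A naive version of this summation can be sketched in plain Python. The function name density_at_positions is hypothetical and this is not the torchref implementation, which works on tensors and obtains F_calc from the model. Like the formula above, the sketch omits the 1/V normalization and uses the + sign in the exponent.

```python
import cmath
import math


def density_at_positions(hkl, fcalc, positions):
    """Direct Fourier summation rho(x) = Re sum_h F(h) * exp(2*pi*i * h.x).

    Hypothetical helper: no 1/V normalization, + sign as in the docstring formula.
    """
    rho = []
    for x in positions:  # fractional coordinates (x, y, z)
        total = 0j
        for h, f in zip(hkl, fcalc):
            h_dot_x = sum(hi * xi for hi, xi in zip(h, x))
            total += f * cmath.exp(2j * math.pi * h_dot_x)
        # Keep the real part; the imaginary parts cancel when F(-h) = conj(F(h)).
        rho.append(total.real)
    return rho


# A Friedel pair F(1,0,0) = F(-1,0,0) = 1 gives rho(x) = 2*cos(2*pi*x).
rho = density_at_positions([(1, 0, 0), (-1, 0, 0)], [1 + 0j, 1 + 0j],
                           [(0.0, 0.0, 0.0), (0.5, 0.0, 0.0)])
```

Direct summation costs O(N·M) for N positions and M reflections, so it suits a handful of atom positions rather than a full map grid, where an FFT is the usual choice.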

analyze_at_dark_positions(hkl, atom_mask=None)[source]

Analyze light model density at dark atom positions.

Parameters:
  • hkl (torch.Tensor) – Miller indices for Fourier calculation.

  • atom_mask (torch.Tensor, optional) – Boolean mask selecting which atoms to analyze (e.g., waters only).

Returns:

Dictionary with analysis results, including:

  • ‘rho_dark’ – Dark model density at atom positions.

  • ‘rho_light’ – Light model density at atom positions.

  • ‘rho_ratio’ – ρ_light / ρ_dark (should be ≥ 0).

  • ‘negative_mask’ – Boolean mask of atoms with negative light density.

  • ‘alpha_floor’ – Estimated lower bound on the activation fraction.

  • ‘worst_atoms’ – Indices of the atoms with the most negative density.

Return type:

dict

estimate_alpha_floor_from_difference_map(hkl, delta_F_obs, sigma_diff, n_peaks=10, sigma_cutoff=3.0)[source]

Estimate alpha floor from significant negative peaks in difference map.

For each significant negative peak in the difference map, estimate the minimum α that could produce that peak without requiring negative density in the light state.

Parameters:
  • hkl (torch.Tensor) – Miller indices.

  • delta_F_obs (torch.Tensor) – Observed difference amplitudes (can be negative).

  • sigma_diff (torch.Tensor) – Uncertainties on difference amplitudes.

  • n_peaks (int, optional) – Number of peaks to analyze. Default is 10.

  • sigma_cutoff (float, optional) – Minimum significance (|ΔF|/σ) for peaks. Default is 3.0.

Returns:

Dictionary with alpha floor estimates.

Return type:

dict
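The per-peak bound can be made concrete under an assumption that the docstring does not state: the light dataset is a linear mixture ρ_obs = (1 - α)·ρ_dark + α·ρ_light, so the difference density is Δρ = ρ_obs - ρ_dark = α·(ρ_light - ρ_dark). Requiring ρ_light = ρ_dark + Δρ/α ≥ 0 at a negative peak gives α ≥ -Δρ/ρ_dark there. The hypothetical helper below applies that bound to each peak and keeps the largest; peak picking, σ-cutoff filtering and the Fourier synthesis are left out.

```python
def alpha_floor_from_peaks(delta_rho, rho_dark):
    """Lower bound on alpha from negative difference-map peaks (hypothetical helper).

    Assumes delta_rho = alpha * (rho_light - rho_dark), so rho_light >= 0
    requires alpha >= -delta_rho / rho_dark at every negative peak.
    """
    floors = []
    for d, r in zip(delta_rho, rho_dark):
        if d < 0 and r > 0:  # only negative peaks on positive dark density constrain alpha
            floors.append(min(1.0, -d / r))  # alpha cannot exceed 1
    return max(floors, default=0.0)  # the bound must hold at every peak simultaneously


print(alpha_floor_from_peaks([-0.3, -0.1, 0.5], [1.0, 0.5, 1.0]))  # → 0.3
```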

class torchref.refinement.targets.NegativeDensityPenalty(mixed_model, model_dark, hkl, atom_mask=None, check_grid=False)[source]

Bases: DeviceMixin, Module

Loss term that penalizes negative electron density in the MIXED model.

This provides a soft constraint that discourages the activation fraction from becoming too small (a too-small fraction would require unphysical negative density).

The key insight: the MIXED state (not pure light) should have non-negative density everywhere. If α is too small and atoms have moved, the mixed model might predict negative density at some positions, which is unphysical.

Parameters:
  • mixed_model (MixedModel) – The mixed model (combines dark and light states with fractions).

  • model_dark (ModelFT) – The dark/ground state model (provides reference positions to check).

  • hkl (torch.Tensor) – Miller indices for density calculation.

  • atom_mask (torch.Tensor, optional) – Mask selecting which atoms to monitor.

  • check_grid (bool, optional) – If True, also check density on a grid (more thorough but slower). Default is False.

__init__(mixed_model, model_dark, hkl, atom_mask=None, check_grid=False)[source]
forward()[source]

Compute penalty for negative density in mixed model.

Returns:

Scalar penalty value (0 if no negative density).

Return type:

torch.Tensor
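The docstring does not give the functional form of the penalty. A squared hinge on the negative part is one common choice that matches the documented behaviour (exactly 0 when no density is negative, growing smoothly otherwise); the sketch below assumes that form and takes the mixed-model density values at the monitored positions as input. The function name is hypothetical.

```python
def negative_density_penalty(rho_mixed):
    """Mean squared hinge on negative density (assumed form, hypothetical helper).

    Returns exactly 0.0 when every value in rho_mixed is >= 0.
    """
    if not rho_mixed:
        return 0.0
    # min(r, 0) keeps only the negative part; squaring makes the penalty smooth at 0.
    return sum(min(r, 0.0) ** 2 for r in rho_mixed) / len(rho_mixed)
```

With this form, positions with positive density contribute neither loss nor gradient, so only the violating positions push α upward.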

class torchref.refinement.targets.DisplacementRegularizer(model_light, model_dark, atom_mask=None, max_displacement=2.0)[source]

Bases: DeviceMixin, Module

Regularizer that penalizes large atomic displacements from reference structure.

This directly breaks the α-δF degeneracy by favoring solutions where atoms haven’t moved too far from the dark structure, which implies larger α.

The loss is: mean((xyz_light - xyz_dark)²)

Parameters:
  • model_light (ModelFT) – The light model being refined.

  • model_dark (ModelFT) – The dark reference model (frozen).

  • atom_mask (torch.Tensor, optional) – Boolean mask selecting which atoms to include.

  • max_displacement (float, optional) – Maximum allowed displacement in Angstroms. Displacements beyond this are penalized quadratically. Default is 2.0 Å.

__init__(model_light, model_dark, atom_mask=None, max_displacement=2.0)[source]
forward()[source]

Compute displacement penalty.

Returns:

Mean squared displacement penalty.

Return type:

torch.Tensor
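Read literally, mean((xyz_light - xyz_dark)²) is an element-wise mean over all selected coordinate components, i.e. the mean squared per-atom displacement divided by 3. The sketch below follows that reading with an optional boolean atom mask. It leaves out the max_displacement term, whose exact form the docstring does not specify, and the function name is hypothetical.

```python
def displacement_penalty(xyz_light, xyz_dark, atom_mask=None):
    """mean((xyz_light - xyz_dark)**2) over the coordinates of the selected atoms."""
    if atom_mask is None:
        atom_mask = [True] * len(xyz_light)
    # One squared difference per selected (atom, axis) component.
    sq = [
        (a - b) ** 2
        for keep, p_light, p_dark in zip(atom_mask, xyz_light, xyz_dark)
        if keep
        for a, b in zip(p_light, p_dark)
    ]
    return sum(sq) / len(sq) if sq else 0.0
```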

class torchref.refinement.targets.DifferenceAmplitudeRegularizer(dataset_collection, mixed_model, model_dark)[source]

Bases: DeviceMixin, Module

Regularizer that encourages consistency between α and difference amplitudes.

The key insight: the ratio of calculated to observed difference amplitudes should be consistent. If α is too small, the model compensates by making larger structural changes, which changes this ratio in a detectable way.

This regularizer penalizes deviations from the expected relationship:

|ΔF_calc| ≈ |ΔF_obs|

When α and the structure are correct, these should match. When α is too small and the structure has moved too far, the pattern of |ΔF_calc| vs |ΔF_obs| will be distorted.

Parameters:
  • dataset_collection (DatasetCollection) – Collection with ‘dark’ and ‘light’ datasets.

  • mixed_model (MixedModel) – The mixed model being refined.

  • model_dark (ModelFT) – The dark reference model.

__init__(dataset_collection, mixed_model, model_dark)[source]
property hkl
forward()[source]

Compute regularization loss.

Penalizes the variance in the ratio |ΔF_calc|/|ΔF_obs|. If α and structure are correct, this ratio should be ~1 everywhere. If α is wrong, this ratio will have high variance.
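The documented loss can be sketched for plain lists of difference amplitudes. The threshold min_obs is a hypothetical addition: reflections with |ΔF_obs| near zero would make the ratio explode, so the sketch simply drops them, and torchref may handle them differently. The population variance is used here as one reasonable reading of "variance".

```python
from statistics import pvariance


def ratio_variance(delta_f_calc, delta_f_obs, min_obs=1e-3):
    """Population variance of |dF_calc| / |dF_obs| (hypothetical helper).

    Reflections with |dF_obs| <= min_obs are skipped to avoid dividing by ~0.
    """
    ratios = [abs(c) / abs(o) for c, o in zip(delta_f_calc, delta_f_obs) if abs(o) > min_obs]
    return pvariance(ratios) if ratios else 0.0
```

Note that a pure variance penalty is minimized by any constant ratio, not only by a ratio of 1.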

class torchref.refinement.targets.SampledMLPhaseTarget(data=None, model=None, scaler=None, phi_ref=None, n_samples=32, sigma_model_log=0.15, use_analytical=False, use_antithetic=True, use_work_set=True, verbose=0)[source]

Bases: XrayTarget

Phase-aware ML target using reparameterized sampling.

Computes E[|F_obs*exp(i*phi) - F_calc|^2] where phi ~ N(phi_ref, sigma_phi^2) with sigma_phi derived from amplitude errors and discrepancy.

Uses French-Wilson posteriors for amplitude estimation and supports both Monte Carlo sampling and analytical evaluation.
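For a Gaussian phase the expectation has a closed form, because E[exp(i·phi)] = exp(i·phi_ref)·exp(-sigma_phi²/2). Expanding the squared modulus gives E[|F_obs·exp(i·phi) - F_calc|²] = F_obs² + |F_calc|² - 2·F_obs·|F_calc|·exp(-sigma_phi²/2)·cos(phi_ref - phi_calc). The sketch below compares that expression with a reparameterized Monte Carlo estimate using antithetic pairs for a single reflection with a fixed F_obs. It is not torchref's implementation: the function names are hypothetical, and whether use_analytical evaluates exactly this expression is an assumption.

```python
import cmath
import math
import random


def expected_residual_analytic(f_obs, f_calc, phi_ref, sigma_phi):
    """E|f_obs*exp(i*phi) - f_calc|^2 for phi ~ N(phi_ref, sigma_phi^2)."""
    # Gaussian characteristic function: E[exp(i*phi)] = exp(i*phi_ref) * exp(-sigma^2 / 2)
    damping = math.exp(-0.5 * sigma_phi ** 2)
    cross = f_obs * abs(f_calc) * damping * math.cos(phi_ref - cmath.phase(f_calc))
    return f_obs ** 2 + abs(f_calc) ** 2 - 2.0 * cross


def expected_residual_mc(f_obs, f_calc, phi_ref, sigma_phi, n_samples=32,
                         antithetic=True, seed=0):
    """Monte Carlo estimate with phi = phi_ref + sigma_phi * eps, eps ~ N(0, 1)."""
    rng = random.Random(seed)
    eps = []
    while len(eps) < n_samples:
        e = rng.gauss(0.0, 1.0)
        eps.extend([e, -e] if antithetic else [e])  # antithetic pair (e, -e)
    eps = eps[:n_samples]
    total = 0.0
    for e in eps:
        phi = phi_ref + sigma_phi * e  # reparameterized phase sample
        total += abs(f_obs * cmath.exp(1j * phi) - f_calc) ** 2
    return total / n_samples
```

Averaging each pair (ε, -ε) cancels the part of the integrand that is odd in ε, which is where the variance reduction of antithetic sampling comes from.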

Parameters:
  • data (ReflectionData) – Reference to the ReflectionData object.

  • model (Model or ModelFT, optional) – Reference to Model object for F_calc computation.

  • scaler (Scaler, optional) – Reference to the Scaler object.

  • phi_ref (torch.Tensor, optional) – Reference phases (e.g., from dark state). If None, uses phi_calc.

  • n_samples (int, optional) – Number of MC samples. Default is 32.

  • sigma_model_log (float, optional) – Model error in log(I) space (~R_work). Default is 0.15.

  • use_analytical (bool, optional) – Use closed-form instead of MC sampling. Default is False.

  • use_antithetic (bool, optional) – Use antithetic sampling for variance reduction. Default is True.

  • use_work_set (bool, optional) – If True, compute loss on work set. Default is True.

  • verbose (int, optional) – Verbosity level. Default is 0.

name

Target name for LossState registration.

Type:

str

Examples

Basic usage with model:

target = SampledMLPhaseTarget(
    data=reflection_data,
    model=model,
    scaler=scaler,
    n_samples=32,
)
loss = target()  # Computes F_calc internally

With pre-computed F_calc:

target = SampledMLPhaseTarget(data=reflection_data)
loss = target(fcalc=F_calc_precomputed)

With reference phases from dark state:

target = SampledMLPhaseTarget(
    data=light_data,
    model=light_model,
    phi_ref=torch.angle(F_dark_calc),
)
__init__(data=None, model=None, scaler=None, phi_ref=None, n_samples=32, sigma_model_log=0.15, use_analytical=False, use_antithetic=True, use_work_set=True, verbose=0)[source]

Initialize X-ray target.

Parameters:
  • data (ReflectionData, optional) – Reference to the ReflectionData object. Required for forward().

  • model (Model or ModelFT, optional) – Reference to Model object for F_calc computation. If None, fcalc must be provided to forward().

  • scaler (Scaler, optional) – Reference to the Scaler object.

  • use_work_set (bool, optional) – If True, compute loss on work set; if False, on test set. Default is True.

  • sigma_mode (str, optional) –

    Which sigma to use in the likelihood. Options:

    • 'raw' (default): use the raw experimental sigmas from the data file. Empirically gives the best Rfree across the mid-resolution regime (1.5-3.0 Å) when paired with appropriate group weights.

    • 'effective': use per-shell effective sigmas estimated from scaling residuals (capped SIGMAA-style correction). Opt-in for high-resolution refinement (< 1.5 Å) or datasets with known sigma miscalibration. Note: Scaler.estimate_sigma_eff is always called so the estimates are available regardless of which mode the target uses.

  • verbose (int, optional) – Verbosity level. Default is 0.

name: str = 'xray_sampled_ml'
property n_samples: int

Get number of MC samples.

property sigma_model_log: float

Get model error in log(I) space.

property use_analytical: bool

Get whether to use analytical form.

property use_antithetic: bool

Get whether to use antithetic sampling.

french_wilson_moments(I_obs, sigma_I, Sigma_wilson=None)[source]

Compute posterior mean and variance of |F_true| given I_obs.

Properly handles negative and weak intensities using numerical integration over a grid.

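Parameters:
  • I_obs – Observed intensities (may be weak or negative).

  • sigma_I – Uncertainties on I_obs.

  • Sigma_wilson (optional) – Wilson prior scale (expected intensity). Default is None.

Returns: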

  • F_mean (torch.Tensor) – Posterior mean of |F|.

  • F_var (torch.Tensor) – Posterior variance of |F|.

Return type:

Tuple[Tensor, Tensor]
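The grid integration can be sketched for one acentric reflection, assuming the standard French-Wilson setup: a Wilson prior p(F) = (2F/Σ)·exp(-F²/Σ) and a Gaussian likelihood of I_obs centred on F². The arguments mirror I_obs, sigma_I and Sigma_wilson; the function name, grid size and grid extent are illustrative choices, and centric reflections (which use a different prior) are not handled.

```python
import math


def french_wilson_moments_acentric(i_obs, sigma_i, sigma_wilson, n_grid=4000):
    """Posterior mean and variance of |F| for one acentric reflection (grid sketch)."""
    # Grid extent: cover the likelihood peak plus a wide margin of the prior tail.
    f_max = math.sqrt(max(i_obs, 0.0) + 10.0 * sigma_i + 20.0 * sigma_wilson)
    df = f_max / n_grid
    grid = [(k + 0.5) * df for k in range(n_grid)]  # midpoints avoid log(0) at F = 0
    log_w = [
        math.log(2.0 * f / sigma_wilson) - f * f / sigma_wilson  # acentric Wilson prior
        - (i_obs - f * f) ** 2 / (2.0 * sigma_i ** 2)           # Gaussian likelihood of I_obs
        for f in grid
    ]
    peak = max(log_w)
    w = [math.exp(lw - peak) for lw in log_w]  # subtract the max for numerical stability
    z = sum(w)
    mean = sum(wk * f for wk, f in zip(w, grid)) / z
    var = sum(wk * (f - mean) ** 2 for wk, f in zip(w, grid)) / z
    return mean, var
```

Because the posterior is normalized on the grid, negative I_obs still yields a positive posterior mean, which is the property the docstring refers to.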

compute_sigma_phi(F_obs, sigma_F_obs, F_calc_amp)[source]

Compute phase uncertainty from amplitude uncertainties and discrepancy.

The phase uncertainty has three components:

  1. Measurement uncertainty: sigma_F_obs / |F_obs|

  2. Model uncertainty: sigma_model_log (multiplicative)

  3. Excess from the amplitude discrepancy beyond what is expected

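Parameters:
  • F_obs – Observed amplitudes |F_obs|.

  • sigma_F_obs – Uncertainties on F_obs.

  • F_calc_amp – Calculated amplitudes |F_calc|.

Returns: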

sigma_phi – Phase uncertainty in radians.

Return type:

torch.Tensor

forward(fcalc=None, recalc=True)[source]

Compute phase-aware ML loss.

Parameters:
  • fcalc (torch.Tensor, optional) – Pre-computed complex structure factors. If provided, uses these instead of computing from model.

  • recalc (bool, optional) – Force recalculation if True. Default is True.

Returns:

Mean weighted loss value.

Return type:

torch.Tensor

stats(fcalc=None)[source]

Get statistics for this target.

Parameters:

fcalc (torch.Tensor, optional) – Pre-computed structure factors.

Returns:

Statistics dict with StatEntry values containing verbosity levels.

Return type:

dict

class torchref.refinement.targets.SampledMLDifferenceTarget(dataset_collection, model_light=None, model_dark=None, scaler_light=None, scaler_dark=None, n_samples=32, sigma_model_log=0.15, use_work_set=True, verbose=0)[source]

Bases: Target

Phase-aware difference target for two-dataset refinement.

Uses dark state phases as reference, with phase uncertainty informed by amplitude changes between states. Jointly refines against both dark and light datasets.

Parameters:
  • dataset_collection (DatasetCollection) – Collection containing ‘dark’ and ‘light’ datasets.

  • model_light (ModelFT or MixedModel) – Model for the light/excited state.

  • model_dark (ModelFT) – Model for the dark/ground state.

  • scaler_light (Scaler, optional) – Scaler for light state F_calc.

  • scaler_dark (Scaler, optional) – Scaler for dark state F_calc.

  • n_samples (int, optional) – Number of MC samples. Default is 32.

  • sigma_model_log (float, optional) – Model error in log(I) space. Default is 0.15.

  • use_work_set (bool, optional) – If True, compute loss on work set. Default is True.

  • verbose (int, optional) – Verbosity level. Default is 0.

Examples

Basic usage:

target = SampledMLDifferenceTarget(
    dataset_collection=collection,
    model_light=mixed_model,
    model_dark=model_dark,
    n_samples=32,
)
loss = target()
name: str = 'sampled_ml_difference'
__init__(dataset_collection, model_light=None, model_dark=None, scaler_light=None, scaler_dark=None, n_samples=32, sigma_model_log=0.15, use_work_set=True, verbose=0)[source]

Initialize target.

Parameters:

verbose (int, optional) – Verbosity level. Default is 0.

property hkl: Tensor

Common HKL indices.

property n_samples: int

Get number of MC samples.

property sigma_model_log: float

Get model error in log(I) space.

forward(fcalc_light=None, fcalc_dark=None, recalc=True)[source]

Compute phase-aware difference loss.

Jointly refines against dark and light datasets using dark phases as reference. Phase uncertainty increases for reflections with large amplitude changes between states.

Parameters:
  • fcalc_light (torch.Tensor, optional) – Pre-computed light state structure factors.

  • fcalc_dark (torch.Tensor, optional) – Pre-computed dark state structure factors.

  • recalc (bool, optional) – Force recalculation if True. Default is True.

Returns:

Combined loss for both datasets.

Return type:

torch.Tensor

stats(fcalc_light=None, fcalc_dark=None)[source]

Get statistics for difference refinement.

Parameters:
  • fcalc_light (torch.Tensor, optional) – Pre-computed light state structure factors.

  • fcalc_dark (torch.Tensor, optional) – Pre-computed dark state structure factors.

Returns:

Statistics dict with StatEntry values.

Return type:

dict

torchref.refinement.targets.create_sampled_ml_target(data=None, model=None, scaler=None, phi_ref=None, n_samples=32, sigma_model_log=0.15, use_analytical=False, use_work_set=True, verbose=0)[source]

Factory function to create SampledMLPhaseTarget.

See SampledMLPhaseTarget for parameter documentation.

Returns:

Configured target instance.

Return type:

SampledMLPhaseTarget

torchref.refinement.targets.create_sampled_ml_difference_target(dataset_collection, model_light=None, model_dark=None, scaler_light=None, scaler_dark=None, n_samples=32, sigma_model_log=0.15, use_work_set=True, verbose=0)[source]

Factory function to create SampledMLDifferenceTarget.

See SampledMLDifferenceTarget for parameter documentation.

Returns:

Configured target instance.

Return type:

SampledMLDifferenceTarget

class torchref.refinement.targets.RealSpaceTarget(data=None, model=None, scaler=None, map_type='2mFo-DFc', mask_solvent=True, solvent_radius=1.1, erosion_radius=0.9, verbose=0, target_value=0.0, sigma=0.5)[source]

Bases: DataTarget

Base class for real-space electron density targets.

Inherits from DataTarget to get model, data, and scaler references. Provides common infrastructure for computing observed maps, model density, and molecular masks used by the concrete subclasses.

Gradient Flow Design

  • Model density: gradients flow through Fcalc -> grid -> IFFT -> density

  • Observed map (2mFo-DFc): phases and |Fcalc| detached, no gradients

  • Observed map (Fo-Fc): |Fcalc| retains gradients, phases detached

  • Molecular mask: boolean, no gradients
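The gradient-flow notes refer to the map coefficients built from observed amplitudes and model phases. Assuming the conventional forms (2m|Fo| - D|Fc|)·exp(i·phi_c) for 2mFo-DFc and (m|Fo| - D|Fc|)·exp(i·phi_c) for Fo-Fc, with m = D = 1 as placeholder defaults, one reflection's coefficient can be computed as below. The function name is hypothetical; in a tensor implementation the phase would be detached (and, for 2mFo-DFc, |Fc| as well), as listed above.

```python
import cmath


def map_coefficient(f_obs, f_calc, m=1.0, d=1.0, map_type="2mFo-DFc"):
    """Complex map coefficient for one reflection, using the phase of f_calc."""
    phi_calc = cmath.phase(f_calc)  # model phase (detached in an autodiff setting)
    if map_type == "2mFo-DFc":
        amp = 2.0 * m * f_obs - d * abs(f_calc)
    elif map_type == "Fo-Fc":
        amp = m * f_obs - d * abs(f_calc)
    else:
        raise ValueError(f"unknown map_type: {map_type!r}")
    return amp * cmath.exp(1j * phi_calc)
```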

Parameters:
  • data (ReflectionData) – Observed reflection data.

  • model (ModelFT) – Model for computing Fcalc.

  • scaler (Scaler, optional) – Scaler for Fcalc (applied before map coefficient computation).

  • map_type (str) – "2mFo-DFc" or "Fo-Fc". Default "2mFo-DFc".

  • mask_solvent (bool) – Whether to apply the molecular mask. Default True.

  • solvent_radius (float) – Probe radius for mask dilation in Angstroms. Default 1.1.

  • erosion_radius (float) – Radius for mask erosion in Angstroms. Default 0.9.

  • verbose (int) – Verbosity level. Default 0.

  • target_value (float) – Target value for loss. Default 0.0.

  • sigma (float) – Sigma for weighting. Default 0.5.

VALID_MAP_TYPES = ('2mFo-DFc', 'Fo-Fc')
__init__(data=None, model=None, scaler=None, map_type='2mFo-DFc', mask_solvent=True, solvent_radius=1.1, erosion_radius=0.9, verbose=0, target_value=0.0, sigma=0.5)[source]

Initialize data target.

Parameters:
  • data (ReflectionData, optional) – Reference to the ReflectionData object. Required for forward().

  • model (Model or ModelFT, optional) – Reference to Model object for F_calc computation. If None, F_calc must be provided when calling forward().

  • scaler (Scaler, optional) – Reference to the Scaler object.

  • verbose (int, optional) – Verbosity level. Default is 0.

update_mask()[source]

Explicitly recompute the molecular mask.

class torchref.refinement.targets.RealSpaceCorrelationTarget(data=None, model=None, scaler=None, mask_solvent=True, solvent_radius=1.1, erosion_radius=0.9, verbose=0)[source]

Bases: RealSpaceTarget

Real-space correlation coefficient (RSCC) target.

Computes RSCC between a 2mFo-DFc observed map and Fcalc model density within the molecular mask. The loss is 1 - RSCC.

The observed map uses detached model phases and amplitudes, so gradients flow only through the model density side.
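The loss can be sketched over flat lists of voxel values, taking RSCC to be the Pearson correlation of the two maps over the masked voxels. That is the standard definition; whether torchref centres the maps exactly this way is an assumption, and the function name is hypothetical. The sketch does not guard against a constant map, where the correlation is undefined.

```python
import math


def rscc_loss(rho_obs, rho_model, mask=None):
    """Return 1 - Pearson correlation of two maps over voxels where mask is True."""
    if mask is None:
        mask = [True] * len(rho_obs)
    pairs = [(o, c) for o, c, keep in zip(rho_obs, rho_model, mask) if keep]
    n = len(pairs)
    mean_o = sum(o for o, _ in pairs) / n
    mean_c = sum(c for _, c in pairs) / n
    cov = sum((o - mean_o) * (c - mean_c) for o, c in pairs)
    var_o = sum((o - mean_o) ** 2 for o, _ in pairs)
    var_c = sum((c - mean_c) ** 2 for _, c in pairs)
    rscc = cov / math.sqrt(var_o * var_c)
    return 1.0 - rscc
```

Since RSCC lies in [-1, 1], the loss ranges from 0 (perfectly correlated) to 2 (perfectly anti-correlated) and is insensitive to the overall scale of either map.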

Parameters:
  • data (ReflectionData) – Observed reflection data.

  • model (ModelFT) – Model for computing Fcalc.

  • scaler (Scaler, optional) – Scaler for Fcalc.

  • mask_solvent (bool) – Whether to apply molecular mask. Default True.

  • solvent_radius (float) – Probe radius for mask in Angstroms. Default 1.1.

  • erosion_radius (float) – Radius for mask erosion in Angstroms. Default 0.9.

  • verbose (int) – Verbosity level. Default 0.

name: str = 'realspace/correlation'
__init__(data=None, model=None, scaler=None, mask_solvent=True, solvent_radius=1.1, erosion_radius=0.9, verbose=0)[source]

Initialize data target.

Parameters:
  • data (ReflectionData, optional) – Reference to the ReflectionData object. Required for forward().

  • model (Model or ModelFT, optional) – Reference to Model object for F_calc computation. If None, F_calc must be provided when calling forward().

  • scaler (Scaler, optional) – Reference to the Scaler object.

  • verbose (int, optional) – Verbosity level. Default is 0.

forward()[source]

Compute 1 - RSCC loss.

Returns:

Scalar loss value (1 - RSCC).

Return type:

torch.Tensor

stats()[source]

Get statistics for the correlation target.

Returns:

Dictionary with loss, rscc, and n_voxels.

Return type:

dict

class torchref.refinement.targets.RealSpaceDifferenceTarget(data=None, model=None, scaler=None, mask_solvent=True, solvent_radius=1.1, erosion_radius=0.9, verbose=0)[source]

Bases: RealSpaceTarget

Real-space Fo-Fc difference density target.

Computes the mean squared Fo-Fc difference density within the molecular mask. This penalizes unexplained features in the difference map.

The |Fcalc| component retains gradients while phases are detached, providing direct gradient signal for model refinement.

Parameters:
  • data (ReflectionData) – Observed reflection data.

  • model (ModelFT) – Model for computing Fcalc.

  • scaler (Scaler, optional) – Scaler for Fcalc.

  • mask_solvent (bool) – Whether to apply molecular mask. Default True.

  • solvent_radius (float) – Probe radius for mask in Angstroms. Default 1.1.

  • erosion_radius (float) – Radius for mask erosion in Angstroms. Default 0.9.

  • verbose (int) – Verbosity level. Default 0.

name: str = 'realspace/difference'
__init__(data=None, model=None, scaler=None, mask_solvent=True, solvent_radius=1.1, erosion_radius=0.9, verbose=0)[source]

Initialize data target.

Parameters:
  • data (ReflectionData, optional) – Reference to the ReflectionData object. Required for forward().

  • model (Model or ModelFT, optional) – Reference to Model object for F_calc computation. If None, F_calc must be provided when calling forward().

  • scaler (Scaler, optional) – Reference to the Scaler object.

  • verbose (int, optional) – Verbosity level. Default is 0.

forward()[source]

Compute mean squared Fo-Fc difference density.

Returns:

Scalar loss value (mean squared difference density).

Return type:

torch.Tensor

stats()[source]

Get statistics for the difference target.

Returns:

Dictionary with loss, rms_diff, mean_abs_diff, peak values, and n_voxels.

Return type:

dict
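The loss and the listed statistics can be sketched over a flat list of Fo-Fc voxel values inside the mask. The key names max_diff and min_diff are hypothetical, since the docstring only says "peak values"; computing rms_diff as the square root of the loss is also an assumption.

```python
import math


def difference_density_stats(rho_diff, mask=None):
    """Mean squared Fo-Fc density over masked voxels plus simple summary values."""
    if mask is None:
        mask = [True] * len(rho_diff)
    vals = [r for r, keep in zip(rho_diff, mask) if keep]
    n = len(vals)
    loss = sum(v * v for v in vals) / n  # documented loss: mean squared difference density
    return {
        "loss": loss,
        "rms_diff": math.sqrt(loss),
        "mean_abs_diff": sum(abs(v) for v in vals) / n,
        "max_diff": max(vals),  # hypothetical key for the positive peak value
        "min_diff": min(vals),  # hypothetical key for the negative peak value
        "n_voxels": n,
    }
```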

class torchref.refinement.targets.RealSpaceExtrapolatedTarget(dataset_collection, model_dark=None, model_light=None, model_mixed=None, scaler_dark=None, scaler_mixed=None, scaler_light=None, mask_solvent=True, solvent_radius=1.1, erosion_radius=0.9, verbose=0)[source]

Bases: RealSpaceTarget

Real-space correlation target using extrapolated pure-light density.

Computes the RSCC between an extrapolated pure-light electron density map and the light model’s Fcalc density within the molecular mask. The loss is 1 - RSCC.

The extrapolation combines observed dark/light amplitudes with model-derived phases:

F_extra = (F_light * exp(i*phi_mixed) - w_dark * F_dark * exp(i*phi_dark)) / w_light

where w_dark, w_light are population fractions from the mixed model.
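Per reflection, the extrapolation is a complex linear combination. The sketch below implements the formula above for scalar inputs and checks it on a synthetic case: if the light data are exactly a 70/30 mixture of a dark structure factor and a pure-light one (an idealized assumption), the extrapolated coefficient recovers the pure-light value. The function name is hypothetical.

```python
import cmath


def extrapolated_coefficient(f_light, phi_mixed, f_dark, phi_dark, w_dark, w_light):
    """F_extra = (F_light*exp(i*phi_mixed) - w_dark*F_dark*exp(i*phi_dark)) / w_light."""
    if w_light <= 0:
        raise ValueError("w_light must be positive to extrapolate")
    mixed = f_light * cmath.exp(1j * phi_mixed)  # observed light amplitude, mixed-model phase
    dark = f_dark * cmath.exp(1j * phi_dark)     # observed dark amplitude, dark-model phase
    return (mixed - w_dark * dark) / w_light


# Synthetic check: the light "observation" is an exact 70/30 dark/pure-light mixture.
f_dark_c, f_pure_c = 2.0 + 0.0j, 0.0 + 1.0j
f_mixed_c = 0.7 * f_dark_c + 0.3 * f_pure_c
f_extra = extrapolated_coefficient(abs(f_mixed_c), cmath.phase(f_mixed_c),
                                   abs(f_dark_c), cmath.phase(f_dark_c), 0.7, 0.3)
```

Dividing by w_light scales any error in the light or dark terms by 1/w_light, so extrapolated maps become noisier as the light-state fraction shrinks.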

Parameters:
  • dataset_collection (DatasetCollection) – Collection containing ‘dark’ and ‘light’ datasets (aligned HKL).

  • model_dark (ModelFT) – Dark-state model (for dark phases).

  • model_light (ModelFT) – Light-state model (gradients flow through this model’s density).

  • model_mixed (MixedModel) – Mixed model (for mixed-state phases and population fractions).

  • scaler_dark (Scaler, optional) – Scaler for dark Fcalc.

  • scaler_mixed (Scaler, optional) – Scaler for mixed Fcalc.

  • scaler_light (Scaler, optional) – Scaler for light model Fcalc (model density side).

  • mask_solvent (bool, optional) – Whether to apply molecular mask. Default True.

  • solvent_radius (float, optional) – Probe radius for mask in Angstroms. Default 1.1.

  • erosion_radius (float, optional) – Radius for mask erosion in Angstroms. Default 0.9.

  • verbose (int, optional) – Verbosity level. Default 0.

name: str = 'realspace_extrapolated'
__init__(dataset_collection, model_dark=None, model_light=None, model_mixed=None, scaler_dark=None, scaler_mixed=None, scaler_light=None, mask_solvent=True, solvent_radius=1.1, erosion_radius=0.9, verbose=0)[source]

Initialize data target.

Parameters:
  • data (ReflectionData, optional) – Reference to the ReflectionData object. Required for forward().

  • model (Model or ModelFT, optional) – Reference to Model object for F_calc computation. If None, F_calc must be provided when calling forward().

  • scaler (Scaler, optional) – Reference to the Scaler object.

  • verbose (int, optional) – Verbosity level. Default is 0.

forward()[source]

Compute 1 - RSCC between extrapolated map and model density.

Returns:

Scalar loss value (1 - RSCC).

Return type:

torch.Tensor

stats()[source]

Get statistics for the extrapolated real-space target.

Returns:

Dictionary with loss, rscc, and n_voxels.

Return type:

dict

class torchref.refinement.targets.CoordinateSimilarityTarget(model_dark=None, model_light=None, alpha=2.0, verbose=0)[source]

Bases: Target

Spike-and-slab similarity restraint between dark and light models.

For each atom, two hypotheses are considered:

  • Static (prob 1-p): atom did not move, displacement is noise

  • Moved (prob p): atom genuinely displaced

The loss is the negative log marginal likelihood:

L(d) = -logsumexp(-d^2/(2*sigma^2) + alpha, 0) = -log(1 + exp(alpha - d^2/(2*sigma^2)))

where d = ||xyz_light - xyz_dark|| and sigma = sqrt(B / (8*pi^2)) is the per-atom coordinate uncertainty from B-factors.

Gradient: dL/dd = d/sigma^2 * sigmoid(-d^2/(2*sigma^2) + alpha). This is an L2 restraint weighted by the posterior probability that the atom is static.

Behavior:

  • d << sigma: up to a constant, ~0.5 * sigmoid(alpha) * d^2 / sigma^2 (quadratic, tight restraint)

  • d >> sigma: plateaus completely (no penalty for genuine moves)

  • Crossover at d ~ sigma * sqrt(2*alpha)

Parameters:
  • model_dark (Model) – Dark (ground state) model. B-factors and coordinates are detached.

  • model_light (Model) – Light (excited state) model. Coordinates carry gradients.

  • alpha (float, optional) – Log prior odds of the static hypothesis. Higher values mean stronger denoising. Default is 2.0 (crossover at ~2*sigma).

  • verbose (int, optional) – Verbosity level. Default is 0.
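The loss and gradient formulas can be checked for a single atom pair in plain Python. The sketch below evaluates L(d) with a numerically stable softplus, compares the stated gradient with a finite difference, and converts a B-factor to sigma. The function names are hypothetical; the real target works on tensors of matched atoms and averages the loss over them.

```python
import math


def _sigmoid(a):
    # Numerically stable logistic function (avoids overflow in math.exp).
    if a >= 0:
        return 1.0 / (1.0 + math.exp(-a))
    e = math.exp(a)
    return e / (1.0 + e)


def coord_sigma_from_b(b_factor):
    """Per-atom coordinate uncertainty sigma = sqrt(B / (8*pi^2))."""
    return math.sqrt(b_factor / (8.0 * math.pi ** 2))


def spike_slab_loss(d, sigma, alpha=2.0):
    """L(d) = -logsumexp([alpha - d^2/(2 sigma^2), 0]) = -softplus(alpha - d^2/(2 sigma^2))."""
    a = alpha - d * d / (2.0 * sigma * sigma)
    return -(max(a, 0.0) + math.log1p(math.exp(-abs(a))))  # stable softplus


def spike_slab_grad(d, sigma, alpha=2.0):
    """dL/dd = d / sigma^2 * sigmoid(alpha - d^2/(2 sigma^2))."""
    a = alpha - d * d / (2.0 * sigma * sigma)
    return d / (sigma * sigma) * _sigmoid(a)


# Finite-difference check of the gradient formula at one distance.
d0, s0, h = 1.3, 0.7, 1e-6
fd = (spike_slab_loss(d0 + h, s0) - spike_slab_loss(d0 - h, s0)) / (2.0 * h)
```

At the crossover d = sigma·sqrt(2·alpha) the sigmoid equals 0.5, so the restraint there carries half the weight of a plain L2 term; well beyond it the loss approaches 0 and the gradient vanishes.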

name: str = 'similarity'
__init__(model_dark=None, model_light=None, alpha=2.0, verbose=0)[source]

Initialize target.

Parameters:

verbose (int, optional) – Verbosity level. Default is 0.

property model_dark: Model

Get dark model.

property model_light: Model

Get light model.

property alpha: float

Get alpha as float.

forward()[source]

Compute spike-and-slab similarity loss.

Returns:

Scalar mean loss over all matched atom pairs.

Return type:

torch.Tensor

stats()[source]

Get similarity restraint statistics.

Subpackages

Submodules