torchref.refinement.targets.base module

Target Functions for Crystallographic Refinement

This module provides target (loss) functions for crystallographic refinement. Each target is instantiated once with a reference to the refinement object, then evaluated on each iteration by calling the target.

Target Types:

  • X-ray targets: Least Squares, Maximum Likelihood, Gaussian NLL

  • Geometry restraint targets: Bonds, Angles, Torsions

  • ADP restraint targets: Similarity (SIMU), Rigid Bond (DELU)

LossState Integration:

  • Targets can optionally receive a LossState and add their loss to it.

class torchref.refinement.targets.base.Target(verbose=0, **kwargs)

Bases: DeviceMixin, Module

Abstract base class for all target functions.

All tunable parameters should be registered as buffers using register_buffer(), so that they are included in state_dict() and can be saved, loaded, and modified through it.

Supports empty initialization for state_dict loading:

import torch
from torchref.refinement.targets.base import Target

target = Target()  # Creates an empty shell
target.load_state_dict(torch.load('target.pt'))
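To make the buffer convention concrete, here is a minimal sketch using a hypothetical ToyTarget built directly on torch.nn.Module (not the real Target, which also mixes in DeviceMixin). Values registered with register_buffer() appear in state_dict(), so a default-constructed shell picks up saved values through load_state_dict(). An in-memory buffer stands in for target.pt.

```python
import io

import torch
from torch import nn


class ToyTarget(nn.Module):
    """Hypothetical target: tunable parameters are buffers, so they live in state_dict()."""

    name = "toy_target"

    def __init__(self, sigma: float = 0.5, target_value: float = 0.0):
        super().__init__()
        # Buffers are saved/loaded with state_dict() and moved by .to(device).
        self.register_buffer("sigma", torch.tensor(sigma))
        self.register_buffer("target_value", torch.tensor(target_value))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (((x - self.target_value) / self.sigma) ** 2).mean()


configured = ToyTarget(sigma=0.2, target_value=1.5)

# Save to an in-memory file (stands in for 'target.pt').
checkpoint = io.BytesIO()
torch.save(configured.state_dict(), checkpoint)
checkpoint.seek(0)

shell = ToyTarget()  # empty shell with default values
shell.load_state_dict(torch.load(checkpoint))
```

Because the values are buffers rather than plain attributes, the shell needs no constructor arguments to be restored.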
LossState Integration:

Targets plug into the LossState pipeline: each target receives the state, adds its loss, and returns the state.

state = target.add_to_state(state)  # Adds loss to state
Parameters:

verbose (int, optional) – Verbosity level. Default is 0.

name

Unique name for this target (used as loss key in LossState).

Type:

str

verbose

Verbosity level.

Type:

int

name: str = 'base_target'
__init__(verbose=0, **kwargs)

Initialize target.

Parameters:

verbose (int, optional) – Verbosity level. Default is 0.

forward()

Compute and return the loss. Override in subclasses.

add_to_state(state)

Compute loss and add it to the LossState.

This method implements the LossState pipeline pattern: the target receives a state object, computes its loss, adds it to the state, and returns the state so that calls can be chained.

Parameters:

state (LossState) – Current loss state with computed data.

Returns:

State with this target’s loss added.

Return type:

LossState
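A minimal sketch of the add_to_state pattern, assuming a hypothetical ToyLossState that holds shared inputs and a dict of named losses (the real LossState's fields and methods are not documented here). Each target writes its loss under its own name and returns the state, so targets can be applied one after another.

```python
from dataclasses import dataclass, field


@dataclass
class ToyLossState:
    """Hypothetical stand-in for LossState: shared inputs plus named losses."""

    inputs: dict
    losses: dict = field(default_factory=dict)

    def total(self) -> float:
        return sum(self.losses.values())


class ToyTarget:
    """Hypothetical target following the add_to_state pattern."""

    def __init__(self, name: str, weight: float = 1.0):
        self.name = name  # used as the loss key, like Target.name
        self.weight = weight

    def forward(self, state: ToyLossState) -> float:
        deviations = state.inputs[self.name]
        return self.weight * sum(d * d for d in deviations)

    def add_to_state(self, state: ToyLossState) -> ToyLossState:
        state.losses[self.name] = self.forward(state)
        return state  # returning the state allows chaining


state = ToyLossState(inputs={"bonds": [0.01, -0.02], "angles": [1.0, -2.0]})
for target in (ToyTarget("bonds", weight=100.0), ToyTarget("angles")):
    state = target.add_to_state(state)
```

Keying losses by name keeps each contribution inspectable after the pass, while total() gives the combined objective.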

maintenance()

Between-step housekeeping hook (no-op by default).

LossState calls this on every registered target after each successful outer optimizer step returns. Targets override this to rebuild stale internal state (VDW pair lists, solvent masks, etc.) based on how far parameters have drifted since the last refresh.

Contract

  • Must be idempotent: after the first call, calling it again on an unchanged model must not mutate the target further.

  • Fast path first: do a cheap staleness check up front and run the expensive rebuild only when strictly necessary. LossState calls this after every outer step, so the happy-path cost is paid every time.

  • Must not raise on routine drift; handling drift is what the rebuild is for. If a rebuild itself fails, let the exception propagate, since that indicates a real bug.
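The contract can be illustrated with a hypothetical target that caches a neighbour pair list, in the spirit of the VDW pair lists mentioned above. This is a sketch only: the skin-based rule (rebuild once any atom has moved more than half the skin distance, as in Verlet neighbour lists) is an assumption, not necessarily what torchref's targets use.

```python
import math


class ToyPairListTarget:
    """Hypothetical target caching a neighbour pair list that goes stale as atoms move."""

    def __init__(self, positions, cutoff=4.0, skin=0.5):
        self.positions = positions  # live coordinates, updated in place by the optimizer
        self.cutoff = cutoff
        self.skin = skin            # assumed rule: rebuild once any atom drifts > skin / 2
        self.rebuild_count = 0
        self._rebuild()

    def _rebuild(self):
        # Expensive O(N^2) step: snapshot positions and rebuild the pair list.
        self.reference = [tuple(p) for p in self.positions]
        reach = self.cutoff + self.skin
        n = len(self.positions)
        self.pairs = [
            (i, j)
            for i in range(n)
            for j in range(i + 1, n)
            if math.dist(self.positions[i], self.positions[j]) < reach
        ]
        self.rebuild_count += 1

    def maintenance(self):
        # Fast path: an O(N) drift check against the last snapshot.
        max_drift = max(math.dist(p, r) for p, r in zip(self.positions, self.reference))
        if max_drift > self.skin / 2:
            self._rebuild()  # no try/except: a failing rebuild is a real bug


positions = [[0.0, 0.0, 0.0], [3.0, 0.0, 0.0], [10.0, 0.0, 0.0]]
target = ToyPairListTarget(positions)
target.maintenance()
target.maintenance()   # unchanged model: idempotent, no rebuild
positions[2][0] = 6.0  # atom 2 drifts by 4.0 Å, well past skin / 2
target.maintenance()   # stale: rebuilds once
target.maintenance()   # fresh again: no further rebuild
```

After the rebuild the snapshot matches the live coordinates, so the next call takes the fast path, which is what makes repeated calls idempotent.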

class torchref.refinement.targets.base.ModelTarget(model=None, verbose=0, **kwargs)

Bases: Target

Base class for targets that only need a Model reference.

This class provides a simpler interface for geometry and ADP targets that don’t need access to reflection data or refinement machinery. Targets inherit from this class when they only need the atomic model.

The model is registered as a proper submodule, allowing PyTorch to handle device movement and state_dict operations automatically.

Parameters:
  • model (Model, optional) – Reference to the Model object.

  • verbose (int, optional) – Verbosity level. Default is 0.

  • target_value (float, optional) – Target value for this loss. Default is 0.0.

  • sigma (float, optional) – Sigma parameter for weighting. Default is 0.5.

name

Unique name for this target (used as loss key in LossState).

Type:

str

_model

Reference to the model object (registered as submodule).

Type:

Model

verbose

Verbosity level.

Type:

int

name: str = 'model_target'
__init__(model=None, verbose=0, **kwargs)

Initialize model target.

Parameters:
  • model (Model, optional) – Reference to the Model object (optional for empty init).

  • verbose (int, optional) – Verbosity level. Default is 0.

property model: Model

Access the model object.

property restraints

Access model’s restraints (built lazily on first access).

class torchref.refinement.targets.base.DataTarget(data=None, model=None, scaler=None, verbose=0, **kwargs)

Bases: Target

Base class for targets that need ReflectionData and optionally Model/Scaler.

This class provides a flexible interface for X-ray targets that can work in two modes:

  1. With Model: Computes F_calc from the model on each forward pass

  2. Without Model: Uses pre-computed F_calc passed directly

This decoupling allows targets to be used for:

  • Standard refinement (with a model)

  • Analysis and scoring of pre-computed structure factors (without a model)

  • Testing and validation workflows
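The two modes can be sketched with a hypothetical ToyDataTarget whose model is any callable returning complex structure factors. The amplitude least-squares loss is only illustrative; the point is the dispatch: forward() uses a supplied fcalc as-is, otherwise computes it from the model, and raises RuntimeError when neither is available (mirroring get_fcalc below).

```python
import cmath


class ToyDataTarget:
    """Hypothetical DataTarget-like class with the two F_calc modes."""

    def __init__(self, f_obs, model=None):
        self.f_obs = f_obs
        self.model = model  # callable returning complex F_calc, or None

    @property
    def has_model(self):
        return self.model is not None

    def get_fcalc(self):
        if not self.has_model:
            raise RuntimeError("No model set; pass fcalc to forward() instead.")
        return self.model()

    def forward(self, fcalc=None):
        if fcalc is None:
            fcalc = self.get_fcalc()  # mode 1: compute F_calc from the model
        # mode 2: a pre-computed fcalc is used exactly as given
        return sum((fo - abs(fc)) ** 2 for fo, fc in zip(self.f_obs, fcalc))


f_obs = [10.0, 5.0]
precomputed = [cmath.rect(9.0, 0.3), cmath.rect(5.0, 1.2)]

scoring = ToyDataTarget(f_obs)                              # no model: scoring only
refining = ToyDataTarget(f_obs, model=lambda: precomputed)  # model-backed
```

Both objects evaluate the same loss; only where F_calc comes from differs.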

All objects (model, data, scaler) are registered as proper submodules, allowing PyTorch to handle device movement and state_dict operations.

Parameters:
  • data (ReflectionData, optional) – Reference to the ReflectionData object. Required for forward().

  • model (Model or ModelFT, optional) – Reference to a Model object for F_calc computation. If None, F_calc must be provided to forward().

  • scaler (Scaler, optional) – Reference to the Scaler object for scaling F_calc.

  • verbose (int, optional) – Verbosity level. Default is 0.

  • target_value (float, optional) – Target value for this loss. Default is 0.0.

  • sigma (float, optional) – Sigma parameter for weighting. Default is 0.5.

name

Unique name for this target (used as loss key in LossState).

Type:

str

_model

Reference to the model object (registered as submodule).

Type:

Model

_data

Reference to the reflection data object (registered as submodule).

Type:

ReflectionData

_scaler

Reference to the scaler object (registered as submodule).

Type:

Scaler

verbose

Verbosity level.

Type:

int

name: str = 'data_target'
__init__(data=None, model=None, scaler=None, verbose=0, **kwargs)

Initialize data target.

Parameters:
  • data (ReflectionData, optional) – Reference to the ReflectionData object. Required for forward().

  • model (Model or ModelFT, optional) – Reference to Model object for F_calc computation. If None, F_calc must be provided when calling forward().

  • scaler (Scaler, optional) – Reference to the Scaler object.

  • verbose (int, optional) – Verbosity level. Default is 0.

property model: Model

Access the model object.

property data: ReflectionData

Access the reflection data object.

property scaler: Scaler

Access the scaler object.

property has_model: bool

Check if a model is available for F_calc computation.

get_fcalc(hkl=None, recalc=False)

Compute structure factors from model.

Parameters:
  • hkl (torch.Tensor, optional) – Miller indices. If None, uses data’s hkl.

  • recalc (bool, optional) – Force recalculation. Default is False.

Returns:

Complex structure factors.

Return type:

torch.Tensor

Raises:

RuntimeError – If no model is set.

get_fcalc_scaled(hkl=None, recalc=False, fcalc=None)

Return scaled complex structure factors, computing F_calc from the model unless fcalc is provided.

Parameters:
  • hkl (torch.Tensor, optional) – Miller indices. If None, uses data’s hkl.

  • recalc (bool, optional) – Force recalculation. Default is False.

  • fcalc (torch.Tensor, optional) – Pre-computed structure factors. If provided, skips model computation.

Returns:

Scaled complex structure factors.

Return type:

torch.Tensor

get_F_calc_scaled(hkl=None, recalc=False, fcalc=None)

Compute scaled structure factor amplitudes.

Parameters:
  • hkl (torch.Tensor, optional) – Miller indices. If None, uses data’s hkl.

  • recalc (bool, optional) – Force recalculation. Default is False.

  • fcalc (torch.Tensor, optional) – Pre-computed structure factors. If provided, skips model computation.

Returns:

Scaled structure factor amplitudes |F_calc|.

Return type:

torch.Tensor
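As a sketch of what scaled amplitudes means, the snippet below applies a hypothetical overall scale k·exp(-B/(4d²)) to complex F_calc and then takes |F_calc|. This isotropic form is a common scale model in crystallography but is an assumption here; the actual Scaler may use a different parameterization. Because the scale is real and positive, it changes amplitudes but leaves phases untouched.

```python
import cmath
import math


def scale_fcalc(fcalc, d_spacings, k=1.0, b_iso=0.0):
    """Hypothetical isotropic scale: k * exp(-B / (4 d^2)) applied to complex F_calc."""
    return [k * math.exp(-b_iso / (4.0 * d * d)) * f for f, d in zip(fcalc, d_spacings)]


def amplitudes(fcalc_scaled):
    # |F_calc|: magnitude of each scaled complex structure factor
    return [abs(f) for f in fcalc_scaled]


fcalc = [cmath.rect(100.0, 0.5), cmath.rect(40.0, -2.0)]  # complex F_calc
d = [4.0, 2.0]                                            # resolution (Å) per reflection

scaled = scale_fcalc(fcalc, d, k=0.5, b_iso=20.0)
F = amplitudes(scaled)
```

Higher-resolution reflections (smaller d) are damped more strongly by the B term, which is why the second amplitude shrinks far more than the first.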

get_rfactor()

Compute R-factors using scaler.

Returns:

(R_work, R_free) values.

Return type:

tuple

Raises:

RuntimeError – If no scaler is set.
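For reference, the conventional crystallographic R-factor is R = Σ|F_obs - |F_calc|| / ΣF_obs, evaluated separately over the working set and the free (test) set. Below is a stdlib sketch assuming the amplitudes are already on a common scale, which is the scaler's job here; the source does not state whether get_rfactor uses exactly this formula.

```python
def r_factor(f_obs, f_calc):
    """Conventional R = sum |F_obs - F_calc| / sum F_obs over (already scaled) amplitudes."""
    return sum(abs(fo - fc) for fo, fc in zip(f_obs, f_calc)) / sum(f_obs)


def r_work_free(f_obs, f_calc, free_flags):
    """Split reflections by their free-set flag and return (R_work, R_free)."""
    work = [(fo, fc) for fo, fc, free in zip(f_obs, f_calc, free_flags) if not free]
    test = [(fo, fc) for fo, fc, free in zip(f_obs, f_calc, free_flags) if free]
    return r_factor(*zip(*work)), r_factor(*zip(*test))


r_work, r_free = r_work_free(
    [100.0, 80.0, 60.0, 40.0],   # F_obs
    [90.0, 85.0, 60.0, 50.0],    # |F_calc| after scaling
    [False, False, False, True], # True marks a free-set reflection
)
```

Here R_work = 15/240 and R_free = 10/40; a free R well above the working R is the usual sign that the model fits the working data better than held-out data.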

torchref.refinement.targets.base.gaussian_nll(deviations, sigmas)

Compute Gaussian negative log-likelihood.

NLL = 0.5 * ((x - μ) / σ)² + log(σ) + 0.5 * log(2π)

Parameters:
Returns:

Tensor of NLL values (same shape as input).

Return type:

torch.Tensor
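A scalar stdlib sketch of the formula above, one value per restraint, using plain lists in place of tensors. Each value equals -log of the normal density N(0, σ²) evaluated at the deviation. The log(σ) term is constant when σ is fixed, so it only matters for gradients if σ itself is refined.

```python
import math


def gaussian_nll(deviations, sigmas):
    """Elementwise NLL = 0.5 * (d / s)^2 + log(s) + 0.5 * log(2*pi)."""
    half_log_2pi = 0.5 * math.log(2 * math.pi)
    return [0.5 * (d / s) ** 2 + math.log(s) + half_log_2pi
            for d, s in zip(deviations, sigmas)]


# Bond-length deviations (x - mu) with per-restraint sigmas, both in Å
nll = gaussian_nll([0.03, -0.01], [0.02, 0.02])
```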

torchref.refinement.targets.base.von_mises_nll(deviations_rad, sigmas_deg)

Compute von Mises negative log-likelihood for angular data.

NLL = -κ*cos(θ) + log(I₀(κ)) + log(2π) where κ = 1/σ²

Parameters:
  • deviations_rad (torch.Tensor) – Angular deviations in radians.

  • sigmas_deg (torch.Tensor) – Standard deviations in degrees.

Returns:

Tensor of NLL values (same shape as input).

Return type:

torch.Tensor
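A scalar stdlib sketch of the von Mises NLL. Python's math module has no modified Bessel function, so log I₀(κ) is computed by a hypothetical helper, log_i0, from its power series in log space. Converting σ from degrees to radians before forming κ = 1/σ² is an assumption about how sigmas_deg is used. For small σ the result approaches the Gaussian NLL, but unlike a Gaussian it is periodic in the deviation.

```python
import math


def log_i0(kappa: float, extra_terms: int = 60) -> float:
    """log I0(kappa) from I0(k) = sum_m (k/2)^(2m) / (m!)^2, summed in log space."""
    if kappa == 0.0:
        return 0.0  # I0(0) = 1
    log_half = math.log(kappa / 2.0)
    # Terms peak near m = kappa/2 and decay at least geometrically beyond m = kappa.
    log_terms = [2 * m * log_half - 2 * math.lgamma(m + 1)
                 for m in range(int(kappa) + extra_terms)]
    peak = max(log_terms)
    return peak + math.log(sum(math.exp(t - peak) for t in log_terms))


def von_mises_nll(deviation_rad: float, sigma_deg: float) -> float:
    """NLL = -kappa*cos(theta) + log(I0(kappa)) + log(2*pi), with kappa = 1/sigma^2."""
    sigma_rad = math.radians(sigma_deg)  # assumption: sigma in radians when forming kappa
    kappa = 1.0 / sigma_rad ** 2
    return -kappa * math.cos(deviation_rad) + log_i0(kappa) + math.log(2 * math.pi)


nll_zero = von_mises_nll(0.0, 10.0)
nll_pos = von_mises_nll(0.2, 10.0)
nll_wrapped = von_mises_nll(0.2 + 2 * math.pi, 10.0)  # same angle, one full turn later
```

Periodicity is the reason to prefer von Mises for torsions: a deviation of 359° is treated as a 1° deviation rather than a huge outlier.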

torchref.refinement.targets.base.adp_similarity_nll(adp_diffs, sigma=2.0)

Compute ADP similarity NLL (SIMU restraint).

Parameters:
  • adp_diffs (torch.Tensor) – ADP differences between bonded atoms.

  • sigma (float, optional) – Target standard deviation. Default is 2.0 Ų.

Returns:

Tensor of NLL values (same shape as input).

Return type:

torch.Tensor
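A sketch of how SIMU-style inputs might be formed and scored, assuming isotropic B values per atom and a Gaussian NLL with one shared σ; neither assumption is confirmed by the source, and the atom names and values are hypothetical.

```python
import math


def adp_similarity_nll(adp_diffs, sigma=2.0):
    """Assumed form: Gaussian NLL of each bonded-pair ADP difference with one shared sigma."""
    half_log_2pi = 0.5 * math.log(2 * math.pi)
    return [0.5 * (d / sigma) ** 2 + math.log(sigma) + half_log_2pi for d in adp_diffs]


# Hypothetical isotropic B values (Å^2) and bonded pairs
b_iso = {"CA": 20.0, "CB": 23.0, "CG": 30.0}
bonds = [("CA", "CB"), ("CB", "CG")]

adp_diffs = [b_iso[a] - b_iso[b] for a, b in bonds]  # [-3.0, -7.0]
nll = adp_similarity_nll(adp_diffs)
```

The larger jump between CB and CG is penalized more heavily, which is the intent of a similarity restraint: bonded atoms should have comparable displacement parameters.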

torchref.refinement.targets.base.detach_phases(fcalc)

Extract phases from complex structure factors with gradient detachment.

Parameters:

fcalc (torch.Tensor) – Complex structure factors.

Returns:

Detached phase angles in radians.

Return type:

torch.Tensor
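A minimal torch sketch of phase detachment; the function body is an assumption about the implementation. torch.angle returns each complex entry's argument in radians, and .detach() removes the result from the autograd graph. Recombining the detached phases with amplitudes, for example to hold model phases fixed, lets gradients reach the amplitudes only.

```python
import torch


def detach_phases(fcalc: torch.Tensor) -> torch.Tensor:
    # Assumed implementation: phase angle in radians, cut out of the autograd graph.
    return torch.angle(fcalc).detach()


amp = torch.tensor([3.0, 4.0], requires_grad=True)
fcalc = torch.polar(amp, torch.tensor([0.25, -1.0]))  # complex F_calc that depends on amp
phases = detach_phases(fcalc)                          # plain tensor, requires_grad=False

# Recombine with fixed phases: gradients reach the amplitudes only.
f_model = amp * torch.exp(1j * phases)
loss = f_model.abs().pow(2).sum()
loss.backward()
```

Since |amp·e^{iφ}|² = amp², the gradient with respect to amp is simply 2·amp, with no contribution through the phases.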