"""
Target Functions for Crystallographic Refinement
This module provides target (loss) functions for crystallographic refinement.
Each target is instantiated once with a reference to the refinement object,
then evaluated on each iteration by calling the target.
Target Types:
- X-ray targets: Least Squares, Maximum Likelihood, Gaussian NLL
- Geometry restraint targets: Bonds, Angles, Torsions
- ADP restraint targets: Similarity (SIMU), Rigid Bond (DELU)
LossState Integration:
- Targets can optionally receive a LossState and add their loss to it
"""
from typing import TYPE_CHECKING, Dict, Tuple
import numpy as np
import torch
from torch import nn
from torch.special import i0
from torchref.utils.device_mixin import DeviceMixin
from torchref.utils.stats import (
VERBOSITY_DEBUG,
VERBOSITY_DETAILED,
VERBOSITY_STANDARD,
StatEntry,
stat,
)
if TYPE_CHECKING:
from torchref.io import ReflectionData
from torchref.io.datasets.collection import DatasetCollection
from torchref.model.model import Model
from torchref.model.model_ft import ModelFT
from torchref.refinement.loss_state import LossState
from torchref.scaling.scaler_base import Scaler
# =============================================================================
# Base Target Class
# =============================================================================
[docs]
class Target(DeviceMixin, nn.Module):
    """
    Root class shared by every refinement target (loss term).

    A concrete target overrides :meth:`forward` so that calling it yields
    the loss tensor, and overrides :attr:`name` so the loss can be stored
    under a unique key.

    Project convention: tunable parameters of a target should be registered
    with ``register_buffer()`` so they are reachable through ``state_dict``.
    The constructor takes no required arguments, so an empty shell can be
    created first and populated afterwards::

        target = Target()
        target.load_state_dict(torch.load('target.pt'))

    LossState pipeline::

        state = target.add_to_state(state)

    Parameters
    ----------
    verbose : int, optional
        Verbosity level. Default is 0.
    **kwargs
        Accepted so subclasses can forward arbitrary keyword arguments;
        nothing in this class reads them.

    Attributes
    ----------
    name : str
        Key under which the loss is stored by :meth:`add_to_state`.
    verbose : int
        Verbosity level given at construction.
    """

    # Class-level loss key; concrete target types are expected to override it.
    name: str = "base_target"

    def __init__(self, verbose: int = 0, **kwargs):
        """
        Set up the module machinery and remember the verbosity level.

        Parameters
        ----------
        verbose : int, optional
            Verbosity level. Default is 0.
        **kwargs
            Ignored here (see class docstring).
        """
        super().__init__()
        self.verbose = verbose

    def forward(self) -> torch.Tensor:
        """Return the loss tensor; the base class has none, so this raises.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def add_to_state(self, state: "LossState") -> "LossState":
        """
        Evaluate this target and record the result in ``state``.

        The loss is obtained by calling :meth:`forward` directly and is then
        handed to ``state.add_loss`` under :attr:`name`.

        Parameters
        ----------
        state : LossState
            Loss accumulator; only its ``add_loss(name, loss)`` method is
            used here.

        Returns
        -------
        LossState
            The very same ``state`` object, returned to allow chaining.
        """
        loss_value = self.forward()
        state.add_loss(self.name, loss_value)
        return state

    def maintenance(self) -> None:
        """Hook for between-step housekeeping; does nothing by default.

        Per the project's design notes,
        :class:`~torchref.refinement.loss_state.LossState` is meant to call
        this on every registered target once an outer optimizer step has
        completed successfully. Subclasses override it to refresh internal
        state that goes stale as parameters move (VDW pair lists, solvent
        masks, ...).

        Contract for overrides
        ----------------------
        - Idempotent: repeated calls on an unchanged model must not mutate
          the target.
        - Cheap by default: check staleness first and rebuild only when
          required, because the hook runs on every outer step.
        - Routine parameter drift must not raise. A failing rebuild is a
          genuine bug and its exception should propagate.
        """
        return None
# =============================================================================
# Model-Only Target Base Class
# =============================================================================
[docs]
class ModelTarget(Target):
    """
    Base class for targets whose only dependency is the atomic model.

    Intended for geometry and ADP restraint targets, which need neither
    reflection data nor a scaler. The model is attached with
    ``add_module`` under the name ``_model``, so it is tracked as a child
    module of the target.

    Parameters
    ----------
    model : Model, optional
        The model to attach. ``None`` creates an empty shell.
    verbose : int, optional
        Verbosity level. Default is 0.
    **kwargs
        Accepted for subclass convenience. This class neither reads them
        nor forwards them to the parent constructor.

    Attributes
    ----------
    name : str
        Key under which the loss is stored in a LossState.
    _model : Model or None
        The attached model (child module entry ``_model``).
    verbose : int
        Verbosity level.
    """

    name: str = "model_target"

    def __init__(self, model: "Model" = None, verbose: int = 0, **kwargs):
        """
        Create the target and attach ``model``.

        Parameters
        ----------
        model : Model, optional
            The model to attach; may be ``None`` for an empty shell.
        verbose : int, optional
            Verbosity level. Default is 0.
        """
        super().__init__(verbose=verbose)
        # add_module accepts None, which keeps the "empty shell" construction
        # path working while still creating the ``_model`` child entry.
        self.add_module("_model", model)

    @property
    def model(self) -> "Model":
        """The attached model (``None`` if none was given)."""
        return self._model

    @property
    def restraints(self):
        """``model.restraints``, or ``None`` when no model is attached."""
        attached = self._model
        return None if attached is None else attached.restraints
# =============================================================================
# Data Target Base Class (for X-ray targets)
# =============================================================================
[docs]
class DataTarget(Target):
    """
    Base class for targets built on reflection data (X-ray targets).

    Two usage modes are supported:

    1. With a model: ``F_calc`` is obtained by calling the model.
    2. Without a model: pre-computed ``F_calc`` is passed to the helper
       methods through their ``fcalc`` argument.

    This decoupling lets the same target serve standard refinement,
    scoring of pre-computed structure factors, and tests.

    ``model`` and ``scaler`` are attached with ``add_module`` (which accepts
    ``None``); ``data`` is stored by plain attribute assignment.

    Parameters
    ----------
    data : ReflectionData, optional
        Reflection data object. Needed whenever Miller indices have to be
        looked up (``hkl=None`` in :meth:`get_fcalc`).
    model : Model or ModelFT, optional
        Model used to compute ``F_calc``. If ``None``, ``fcalc`` must be
        supplied by the caller.
    scaler : Scaler, optional
        Scaler applied to ``F_calc``. If ``None``, ``F_calc`` is returned
        unscaled.
    verbose : int, optional
        Verbosity level. Default is 0.
    **kwargs
        Accepted for subclass convenience. This class neither reads them
        nor forwards them to the parent constructor.

    Attributes
    ----------
    name : str
        Key under which the loss is stored in a LossState.
    _model : Model or None
        Attached model.
    _data : ReflectionData or None
        Attached reflection data.
    _scaler : Scaler or None
        Attached scaler.
    verbose : int
        Verbosity level.
    """

    name: str = "data_target"

    def __init__(
        self,
        data: "ReflectionData" = None,
        model: "Model" = None,
        scaler: "Scaler" = None,
        verbose: int = 0,
        **kwargs,
    ):
        """
        Create the target and attach data, model and scaler.

        Parameters
        ----------
        data : ReflectionData, optional
            Reflection data object.
        model : Model or ModelFT, optional
            Model used for ``F_calc`` computation. If ``None``, ``fcalc``
            must be supplied when calling the helper methods.
        scaler : Scaler, optional
            Scaler object.
        verbose : int, optional
            Verbosity level. Default is 0.
        """
        super().__init__(verbose=verbose)
        # add_module tolerates None, so empty-shell construction works.
        self.add_module("_model", model)
        # Plain assignment: nn.Module registers the value as a child module
        # only if it is itself an nn.Module; None stays an ordinary attribute.
        self._data = data
        self.add_module("_scaler", scaler)

    @property
    def model(self) -> "Model":
        """The attached model (``None`` if none was given)."""
        return self._model

    @property
    def data(self) -> "ReflectionData":
        """The attached reflection data (``None`` if none was given)."""
        return self._data

    @property
    def scaler(self) -> "Scaler":
        """The attached scaler (``None`` if none was given)."""
        return self._scaler

    @property
    def has_model(self) -> bool:
        """``True`` when a model is attached and ``F_calc`` can be computed."""
        return self._model is not None

    def get_fcalc(self, hkl=None, recalc=False):
        """
        Compute structure factors by calling the attached model.

        Parameters
        ----------
        hkl : torch.Tensor, optional
            Miller indices. If ``None``, the data object is called and the
            first of its four return values is used.
        recalc : bool, optional
            Forwarded to the model as ``recalc``. Default is False.

        Returns
        -------
        torch.Tensor
            Whatever ``model(hkl, recalc=recalc)`` returns (complex
            structure factors).

        Raises
        ------
        RuntimeError
            If no model is set, or if ``hkl`` is ``None`` and no reflection
            data is set to take the indices from.
        """
        if self._model is None:
            raise RuntimeError(
                "Cannot compute F_calc: no model set. "
                "Either provide a model or pass fcalc directly."
            )
        if hkl is None:
            # Without this guard a missing data object surfaces as the
            # unhelpful "'NoneType' object is not callable".
            if self._data is None:
                raise RuntimeError(
                    "Cannot compute F_calc: no hkl given and no reflection "
                    "data set. Either provide data or pass hkl explicitly."
                )
            hkl, _, _, _ = self._data()
        return self._model(hkl, recalc=recalc)

    def get_fcalc_scaled(self, hkl=None, recalc=False, fcalc=None):
        """
        Return (optionally scaled) complex structure factors.

        Parameters
        ----------
        hkl : torch.Tensor, optional
            Miller indices; only used when ``fcalc`` is ``None``.
        recalc : bool, optional
            Forwarded to :meth:`get_fcalc`. Default is False.
        fcalc : torch.Tensor, optional
            Pre-computed structure factors. When given, the model is not
            evaluated.

        Returns
        -------
        torch.Tensor
            ``scaler(fcalc)`` if a scaler is attached, otherwise ``fcalc``
            unchanged.

        Raises
        ------
        RuntimeError
            Propagated from :meth:`get_fcalc` when ``fcalc`` is ``None`` and
            the structure factors cannot be computed.
        """
        if fcalc is None:
            fcalc = self.get_fcalc(hkl, recalc=recalc)
        if self._scaler is not None:
            return self._scaler(fcalc)
        return fcalc

    def get_F_calc_scaled(self, hkl=None, recalc=False, fcalc=None):
        """
        Return the amplitudes of the (optionally scaled) structure factors.

        Parameters
        ----------
        hkl : torch.Tensor, optional
            Miller indices; only used when ``fcalc`` is ``None``.
        recalc : bool, optional
            Forwarded to :meth:`get_fcalc_scaled`. Default is False.
        fcalc : torch.Tensor, optional
            Pre-computed structure factors. When given, the model is not
            evaluated.

        Returns
        -------
        torch.Tensor
            ``|F_calc|``, i.e. ``torch.abs`` of :meth:`get_fcalc_scaled`.
        """
        return torch.abs(self.get_fcalc_scaled(hkl, recalc=recalc, fcalc=fcalc))

    def get_rfactor(self):
        """
        Compute R-factors through the attached scaler.

        Returns
        -------
        tuple
            Result of ``scaler.rfactor()``; expected to be
            ``(R_work, R_free)``.

        Raises
        ------
        RuntimeError
            If no scaler is set.
        """
        if self._scaler is None:
            raise RuntimeError("Cannot compute R-factor: no scaler set.")
        return self._scaler.rfactor()
# =============================================================================
# Utility Functions for NLL Computation
# =============================================================================
[docs]
def gaussian_nll(deviations: torch.Tensor, sigmas: torch.Tensor) -> torch.Tensor:
    """
    Element-wise negative log-likelihood of a normal distribution.

    NLL = 0.5 * ((x - μ) / σ)² + log(σ) + 0.5 * log(2π)

    Parameters
    ----------
    deviations : torch.Tensor
        Residuals ``x - μ``.
    sigmas : torch.Tensor
        Standard deviations σ; also determines the device and dtype of the
        constant term.

    Returns
    -------
    torch.Tensor
        Per-element NLL (no reduction is applied).
    """
    # 2π as a tensor living next to ``sigmas`` (same device, same dtype).
    two_pi = torch.tensor(2.0 * np.pi, device=sigmas.device, dtype=sigmas.dtype)

    quadratic_term = 0.5 * torch.square(deviations / sigmas)
    log_sigma_term = torch.log(sigmas)
    constant_term = 0.5 * torch.log(two_pi)

    return quadratic_term + log_sigma_term + constant_term
[docs]
def von_mises_nll(
    deviations_rad: torch.Tensor, sigmas_deg: torch.Tensor
) -> torch.Tensor:
    """
    Element-wise von Mises negative log-likelihood for angular data.

    NLL = -κ*cos(θ) + log(I₀(κ)) + log(2π)

    with the concentration taken as κ = 1/σ² (σ in radians) and clamped to
    the range [1e-3, 1e4].

    Parameters
    ----------
    deviations_rad : torch.Tensor
        Angular deviations θ in radians.
    sigmas_deg : torch.Tensor
        Standard deviations σ in degrees.

    Returns
    -------
    torch.Tensor
        Per-element NLL (no reduction is applied).

    Notes
    -----
    ``log(I₀(κ))`` is evaluated as ``log(i0e(κ)) + κ`` where
    ``i0e(κ) = exp(-|κ|) * I₀(κ)`` is the exponentially scaled Bessel
    function. This is exact over the whole clamped κ range and cannot
    overflow, so no switch to an asymptotic approximation (and hence no
    discontinuity at the switch point) is needed.
    """
    sigmas_rad = sigmas_deg * (np.pi / 180.0)
    kappa = torch.clamp(1.0 / (sigmas_rad**2), min=1e-3, max=1e4)

    # kappa > 0 after clamping, so |kappa| == kappa in the i0e identity.
    log_i0_kappa = torch.log(torch.special.i0e(kappa)) + kappa

    # Plain float constant: follows the dtype of the tensor expression
    # instead of being truncated when ``sigmas_deg`` has an integer dtype.
    log_2pi = float(np.log(2.0 * np.pi))

    log_prob = kappa * torch.cos(deviations_rad) - log_i0_kappa - log_2pi
    return -log_prob
[docs]
def adp_similarity_nll(adp_diffs: torch.Tensor, sigma: float = 2.0) -> torch.Tensor:
    """
    Element-wise Gaussian NLL of ADP differences (SIMU restraint).

    NLL = 0.5 * (Δ / σ)² + log(σ) + 0.5 * log(2π)

    Parameters
    ----------
    adp_diffs : torch.Tensor
        ADP differences Δ between bonded atoms.
    sigma : float or torch.Tensor, optional
        Target standard deviation σ. Default is 2.0 Ų. A tensor (scalar or
        broadcastable to ``adp_diffs``) is also accepted; it is moved to the
        device of ``adp_diffs`` and stays part of the autograd graph.

    Returns
    -------
    torch.Tensor
        Per-element NLL (no reduction is applied).
    """
    if torch.is_tensor(sigma):
        # Keep the caller's tensor (and its gradient history); only make
        # sure it sits on the same device as the data.
        sigma_t = sigma.to(adp_diffs.device)
    else:
        float_dtype = (
            adp_diffs.dtype
            if adp_diffs.is_floating_point()
            else torch.get_default_dtype()
        )
        sigma_t = torch.as_tensor(
            sigma, dtype=float_dtype, device=adp_diffs.device
        )

    # torch.log instead of np.log: works for tensors on any device and for
    # tensors that require grad.
    half_log_2pi = 0.5 * float(np.log(2.0 * np.pi))
    return 0.5 * (adp_diffs / sigma_t) ** 2 + torch.log(sigma_t) + half_log_2pi
[docs]
def detach_phases(fcalc: torch.Tensor) -> torch.Tensor:
    """
    Return the phase angles of complex structure factors, cut off from autograd.

    Parameters
    ----------
    fcalc : torch.Tensor
        Complex structure factors.

    Returns
    -------
    torch.Tensor
        Phase angles in radians (``torch.angle``), detached so that no
        gradient flows back into ``fcalc`` through them.
    """
    phases = torch.angle(fcalc)
    return phases.detach()