Target Functions
Target functions (loss functions) drive the refinement optimization. TorchRef provides several built-in targets and makes it easy to create custom ones.
Targets reference the refinement object to access model parameters, reflection data, and restraints.
Targets are registered in a LossState object, which manages:
Target instances
Weights for each target
Metadata for refinement monitoring
Targets can be grouped into composite targets, such as the combined geometry and ADP restraint targets, or into custom combinations.
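One way to picture a composite target is as a weighted sum of sub-targets whose stats are merged. The plain-Python sketch below assumes that model. `SimpleTarget` and `CompositeTarget` are hypothetical names, not TorchRef classes, and the weights are illustrative.

```python
class SimpleTarget:
    """Hypothetical stand-in for a target: calling it evaluates forward()."""

    def __init__(self, name, loss_fn):
        self.name = name
        self.loss_fn = loss_fn

    def __call__(self):
        # Mirror the nn.Module convention: target() dispatches to forward()
        return self.forward()

    def forward(self):
        return self.loss_fn()

    def stats(self):
        return {self.name: self.forward()}


class CompositeTarget(SimpleTarget):
    """Hypothetical composite target: a weighted sum of sub-targets."""

    def __init__(self, name, components):
        # components: list of (target, weight) pairs
        self.name = name
        self.components = components

    def forward(self):
        return sum(weight * target() for target, weight in self.components)

    def stats(self):
        # Merge the per-component stats and add the composite total
        merged = {}
        for target, _ in self.components:
            merged.update(target.stats())
        merged[self.name] = self.forward()
        return merged


bond = SimpleTarget("bond", lambda: 2.0)
angle = SimpleTarget("angle", lambda: 0.5)
geometry = CompositeTarget("geometry", [(bond, 1.0), (angle, 2.0)])
print(geometry())         # 1.0 * 2.0 + 2.0 * 0.5
print(geometry.stats())
```

Because the composite exposes the same call and stats() interface as its parts, it can be registered and monitored exactly like a single target.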
Built-in Targets
X-ray Targets
Least Squares: Traditional least-squares refinement
Maximum Likelihood: ML-based refinement with a proper error model (the default)
Gaussian NLL: Gaussian negative log-likelihood
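To make the difference between the least-squares and Gaussian NLL targets concrete, here is a minimal pure-Python sketch on structure-factor amplitudes. The functional forms are standard textbook choices assumed for illustration (a scale factor `k` for least squares, per-reflection sigmas for the NLL). TorchRef's exact implementations, and its maximum-likelihood target with its more elaborate error model, are not reproduced here.

```python
import math


def least_squares(f_obs, f_calc, k=1.0):
    """Sum of squared residuals between observed and scaled calculated amplitudes."""
    return sum((fo - k * fc) ** 2 for fo, fc in zip(f_obs, f_calc))


def gaussian_nll(f_obs, f_calc, sigma):
    """Negative log-likelihood of f_obs under N(f_calc, sigma^2), summed over reflections."""
    total = 0.0
    for fo, fc, s in zip(f_obs, f_calc, sigma):
        total += 0.5 * ((fo - fc) / s) ** 2 + math.log(s) + 0.5 * math.log(2 * math.pi)
    return total


f_obs = [10.0, 20.0, 30.0]
f_calc = [11.0, 18.0, 30.0]
sigma = [1.0, 2.0, 1.0]
print(least_squares(f_obs, f_calc))       # 1 + 4 + 0
print(gaussian_nll(f_obs, f_calc, sigma))
```

The second reflection has twice the sigma of the others, so its squared term contributes 0.5 to the NLL but 4 to the unweighted least-squares sum: the NLL down-weights poorly measured data.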
Geometry Targets
Bond restraints: Penalize deviations from ideal bond lengths
Angle restraints: Penalize deviations from ideal bond angles
Torsion restraints: Penalize deviations from ideal torsion angles
Planarity restraints: Penalize out-of-plane deviations
Chirality restraints: Maintain correct stereochemistry
VDW restraints: Prevent atomic overlaps
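Geometry restraints are commonly written as harmonic penalties scaled by an expected standard deviation, for example ((d - d0) / sigma)^2 for a bond of length d with ideal length d0. A minimal pure-Python sketch for bond and angle restraints, assuming that functional form; the ideal values and sigmas below are illustrative, and TorchRef's restraint library and weighting are not shown.

```python
import math


def angle_deg(a, b, c):
    """Angle a-b-c in degrees, with b as the vertex."""
    u = [ai - bi for ai, bi in zip(a, b)]
    v = [ci - bi for ci, bi in zip(c, b)]
    cos_t = sum(x * y for x, y in zip(u, v)) / (math.hypot(*u) * math.hypot(*v))
    # Clamp to [-1, 1] to guard against rounding before acos
    return math.degrees(math.acos(max(-1.0, min(1.0, cos_t))))


def bond_penalty(coords, bonds):
    """bonds: list of (i, j, ideal_length, sigma)."""
    return sum(((math.dist(coords[i], coords[j]) - d0) / s) ** 2 for i, j, d0, s in bonds)


def angle_penalty(coords, angles):
    """angles: list of (i, j, k, ideal_degrees, sigma_degrees); j is the vertex."""
    return sum(
        ((angle_deg(coords[i], coords[j], coords[k]) - t0) / s) ** 2
        for i, j, k, t0, s in angles
    )


coords = [(0.0, 0.0, 0.0), (1.5, 0.0, 0.0), (1.5, 1.5, 0.0)]
bonds = [(0, 1, 1.53, 0.02), (1, 2, 1.53, 0.02)]
angles = [(0, 1, 2, 109.5, 2.0)]
print(bond_penalty(coords, bonds))    # two bonds, each 0.03 Å short
print(angle_penalty(coords, angles))  # a 90° angle against a 109.5° ideal
```

Dividing by sigma puts all restraint types on a common dimensionless scale, so bond and angle deviations can be summed into one geometry loss.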
ADP Targets
SIMU: Similar ADPs for nearby atoms
Locality restraint: Smooth ADP variations among neighbors
DELU: Rigid-bond restraints
Population restraint: Restrain ADPs towards a target distribution
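The SIMU and DELU restraints have common textbook forms. SIMU compares the displacement parameters of atoms closer than a cutoff. DELU, the rigid-bond restraint, requires two bonded atoms to have equal mean-square displacement along the bond direction. The pure-Python sketch below uses isotropic B values for SIMU for brevity. The cutoff and sigma defaults are illustrative assumptions, not TorchRef's settings.

```python
import math


def simu_penalty(coords, b_iso, cutoff=2.0, sigma=1.0):
    """Isotropic SIMU-style penalty: neighbours within `cutoff` should have similar B."""
    total = 0.0
    for i in range(len(coords)):
        for j in range(i + 1, len(coords)):
            if math.dist(coords[i], coords[j]) <= cutoff:
                total += ((b_iso[i] - b_iso[j]) / sigma) ** 2
    return total


def project(u, n):
    """Mean-square displacement n^T U n along unit vector n for a 3x3 tensor U."""
    return sum(n[a] * u[a][b] * n[b] for a in range(3) for b in range(3))


def delu_penalty(coords, u_aniso, bonds, sigma=0.01):
    """Rigid-bond (DELU-style) penalty: equal displacement along each bond axis."""
    total = 0.0
    for i, j in bonds:
        d = [cj - ci for ci, cj in zip(coords[i], coords[j])]
        length = math.hypot(*d)
        n = [x / length for x in d]
        total += ((project(u_aniso[i], n) - project(u_aniso[j], n)) / sigma) ** 2
    return total


coords = [(0.0, 0.0, 0.0), (1.5, 0.0, 0.0), (5.0, 0.0, 0.0)]
# Only atoms 0 and 1 are within the cutoff, so only their B difference counts
print(simu_penalty(coords, [20.0, 22.0, 40.0]))

u0 = [[0.02, 0.0, 0.0], [0.0, 0.05, 0.0], [0.0, 0.0, 0.05]]
u1 = [[0.02, 0.0, 0.0], [0.0, 0.01, 0.0], [0.0, 0.0, 0.01]]
# Equal displacement along the x-directed bond: no rigid-bond violation
print(delu_penalty(coords, [u0, u1], [(0, 1)]))
```

The two restraints are complementary: DELU constrains only the component along the bond, while SIMU ties whole displacement parameters of nearby atoms together.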
Stats Functionality
Each target implements a stats() method that returns a dictionary of metrics. These metrics are collected during refinement for monitoring and logging and describe the current state of the model; for example, the X-ray target reports R-work and R-free.
# Collect stats from each target (the targets are created as shown under Using Targets)
xray_stats = xray_target.stats()
geom_stats = geom_target.stats()
adp_stats = adp_target.stats()
# Example: print R-factors from X-ray target
print(f"R-work: {xray_stats['rwork']}, R-free: {xray_stats['rfree']}")
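R-work and R-free are conventionally defined as R = sum|F_obs - k * F_calc| / sum F_obs, computed separately over the work reflections and the free (test) reflections. A pure-Python sketch, assuming a boolean free-set flag per reflection and a fixed scale k; how TorchRef scales and flags reflections internally is not shown.

```python
def r_factor(f_obs, f_calc, k=1.0):
    """Conventional crystallographic R = sum|Fo - k*Fc| / sum Fo."""
    num = sum(abs(fo - k * fc) for fo, fc in zip(f_obs, f_calc))
    return num / sum(f_obs)


def r_work_free(f_obs, f_calc, free_flags, k=1.0):
    """Split reflections by free-set flag and compute R-work and R-free."""
    work = [(fo, fc) for fo, fc, free in zip(f_obs, f_calc, free_flags) if not free]
    test = [(fo, fc) for fo, fc, free in zip(f_obs, f_calc, free_flags) if free]
    r_work = r_factor([fo for fo, _ in work], [fc for _, fc in work], k)
    r_free = r_factor([fo for fo, _ in test], [fc for _, fc in test], k)
    return {"rwork": r_work, "rfree": r_free}


f_obs = [100.0, 50.0, 80.0, 70.0]
f_calc = [90.0, 55.0, 80.0, 63.0]
free = [False, False, False, True]
stats = r_work_free(f_obs, f_calc, free)
print(f"R-work: {stats['rwork']:.4f}, R-free: {stats['rfree']:.4f}")
# prints "R-work: 0.0652, R-free: 0.1000"
```

Because the free reflections never enter the refinement target, a widening gap between R-free and R-work is a standard sign of overfitting.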
Using Targets
from torchref.refinement.targets import (
create_xray_target,
TotalGeometryTarget,
TotalADPTarget
)
# Create targets
xray_target = create_xray_target(data, model, target_type='ml')
geom_target = TotalGeometryTarget(model)
adp_target = TotalADPTarget(model)
# Compute losses
xray_loss = xray_target()
geom_loss = geom_target()
adp_loss = adp_target()
Custom Targets
Create custom targets by subclassing Target:
from torchref.refinement.targets import Target
import torch
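class EntropyRegularization(Target):
    """Entropy regularization for B-factors."""
    def __init__(self, model):
        # The weight of this target is set through LossState.set_weight
        super().__init__(model)
    def forward(self):
        b_factors = self.model.b()
        # Normalize the B-factors into a probability distribution
        p = b_factors / b_factors.sum()
        # Shannon entropy H = -sum(p * log p); clamping avoids log(0)
        entropy = -torch.sum(p * torch.log(p.clamp_min(1e-8)))
        # Return -H so that minimizing the loss maximizes the entropy
        return -entropy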
The key advantage is that you only define the forward pass; PyTorch's autograd computes the gradients of the loss with respect to all model parameters automatically.
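The idea behind autograd can be illustrated with a toy scalar reverse-mode differentiator. Each operation records its inputs together with its local derivatives, and backward() applies the chain rule from the output back to the inputs in reverse topological order. This is a greatly simplified teaching sketch in plain Python, not PyTorch code; the `Scalar` class is hypothetical.

```python
import math


class Scalar:
    """Hypothetical scalar node for reverse-mode automatic differentiation."""

    def __init__(self, value, parents=()):
        self.value = value
        self.grad = 0.0
        # parents: (input node, local derivative d(self)/d(input)) pairs
        self.parents = parents

    def __add__(self, other):
        other = other if isinstance(other, Scalar) else Scalar(other)
        return Scalar(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        other = other if isinstance(other, Scalar) else Scalar(other)
        return Scalar(self.value * other.value, ((self, other.value), (other, self.value)))

    def log(self):
        return Scalar(math.log(self.value), ((self, 1.0 / self.value),))

    def backward(self):
        # Order nodes so that every node comes after its inputs
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent, _ in node.parents:
                    visit(parent)
                order.append(node)

        visit(self)
        # Propagate gradients from the output back to the leaves (chain rule)
        self.grad = 1.0
        for node in reversed(order):
            for parent, local in node.parents:
                parent.grad += local * node.grad


b1, b2 = Scalar(2.0), Scalar(3.0)
loss = b1 * b1.log() + b2 * b2.log()  # sum of b * log(b)
loss.backward()
print(b1.grad, b2.grad)  # analytic gradient: log(b) + 1
```

Only the forward expression is written; the gradient of every input falls out of the recorded graph, which is the same division of labour a custom Target relies on.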
LossState
The LossState class manages multiple targets and their weights. It tracks the targets during refinement and supplies the information needed for weight calculation.
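import torch
from torchref.refinement import LossState
# Run on the GPU when one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
state = LossState(device=device)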
# Register targets
state.register_target('xray', xray_target)
state.register_target('geometry', geom_target)
state.register_target('adp', adp_target)
# Set weights
state.set_weight('xray', 1.0)
state.set_weight('geometry', 0.1)
state.set_weight('adp', 0.1)
# Compute total loss
total_loss = state.aggregate()
# Compute backward pass
total_loss.backward()
# Example optimization step with LBFGS driven by the LossState.
# LossState prunes redundant loss evaluations and automatically tracks which quantities need to be recomputed.
from torch.optim import LBFGS
optimizer = LBFGS(model.parameters(), lr=1.0, max_iter=100)
state.run(optimizer, n_steps=1) # equivalent to state.step(optimizer)
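The bookkeeping described above can be illustrated with a minimal pure-Python sketch of a LossState-like registry. `MiniLossState` and its `mark_changed` method are hypothetical: they show one way to combine weighted targets, keep a monitoring history, and skip re-evaluating targets whose inputs have not changed. This is not TorchRef's actual implementation, and here invalidation is explicit rather than automatic to keep the sketch short.

```python
class MiniLossState:
    """Hypothetical registry of named targets, weights, and cached losses."""

    def __init__(self):
        self.targets = {}
        self.weights = {}
        self.cache = {}
        self.history = []  # per-call record of weighted losses, for monitoring

    def register_target(self, name, target):
        self.targets[name] = target
        self.weights.setdefault(name, 1.0)

    def set_weight(self, name, weight):
        if name not in self.targets:
            raise KeyError(f"unknown target: {name}")
        self.weights[name] = weight

    def mark_changed(self, *names):
        # Invalidate cached values for targets whose inputs changed
        for name in names:
            self.cache.pop(name, None)

    def aggregate(self):
        total = 0.0
        record = {}
        for name, target in self.targets.items():
            if name not in self.cache:
                self.cache[name] = target()  # only re-evaluate invalidated targets
            record[name] = self.weights[name] * self.cache[name]
            total += record[name]
        self.history.append(record)
        return total


calls = {"xray": 0, "geometry": 0}


def make_target(name, value):
    """Build a dummy target that counts how often it is evaluated."""
    def target():
        calls[name] += 1
        return value
    return target


demo_state = MiniLossState()
demo_state.register_target("xray", make_target("xray", 10.0))
demo_state.register_target("geometry", make_target("geometry", 4.0))
demo_state.set_weight("geometry", 0.1)
print(demo_state.aggregate())  # 10.0 + 0.1 * 4.0
demo_state.mark_changed("xray")
demo_state.aggregate()         # re-evaluates the X-ray target only
print(calls)
```

Caching per target pays off when one term, typically the X-ray target, dominates the cost while others are unaffected by a given parameter update.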