Target Functions

Target functions (loss functions) drive the refinement optimization. TorchRef provides several built-in targets and makes it easy to create custom ones.

Targets reference the refinement object to access model parameters, reflection data, and restraints.

Targets are registered in a LossState object, which manages:

  • Target instances

  • Weights for each target

  • Metadata for refinement monitoring

Targets can be grouped into composite targets for geometry and ADP restraints or custom combinations.
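As a rough illustration of grouping, a composite target can be modeled as a weighted sum over named sub-targets. The sketch below is plain Python and uses hypothetical names (`CompositeTarget`, `bond_loss`, `angle_loss`); it assumes each sub-target is a zero-argument callable returning a loss, and it is not how TorchRef's composite classes are actually implemented.

```python
from typing import Callable, Dict, Optional


# Hypothetical sketch: not TorchRef's composite implementation.
class CompositeTarget:
    """Combines several sub-targets into one weighted sum."""

    def __init__(self, targets: Dict[str, Callable[[], float]],
                 weights: Optional[Dict[str, float]] = None):
        self.targets = targets
        # Every sub-target defaults to weight 1.0 unless overridden
        self.weights = {name: 1.0 for name in targets}
        if weights:
            self.weights.update(weights)

    def __call__(self) -> float:
        # Total loss is the weighted sum of the sub-target losses
        return sum(self.weights[name] * fn() for name, fn in self.targets.items())

    def stats(self) -> Dict[str, float]:
        # Unweighted per-term values, useful for monitoring
        return {name: fn() for name, fn in self.targets.items()}


def bond_loss() -> float:
    return 2.0  # placeholder value standing in for a real restraint term


def angle_loss() -> float:
    return 3.0  # placeholder value standing in for a real restraint term


geometry = CompositeTarget({"bond": bond_loss, "angle": angle_loss},
                           weights={"angle": 0.5})
print(geometry())        # 3.5
print(geometry.stats())  # {'bond': 2.0, 'angle': 3.0}
```

Keeping per-term values available through stats() alongside the single weighted total mirrors the idea that one composite can drive optimization while its parts remain individually visible for monitoring.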

Built-in Targets

X-ray Targets

  • Least Squares: Traditional least squares refinement

  • Maximum Likelihood: ML-based refinement with a proper error model. This is the default.

  • Gaussian NLL: Gaussian negative log-likelihood
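To show how the two simpler X-ray targets relate, here is a sketch using the textbook forms of an unweighted least-squares residual and a Gaussian negative log-likelihood on structure-factor amplitudes. These are generic formulas assumed for illustration; TorchRef's targets may add scaling, weighting, or other terms, and the maximum-likelihood target uses a more involved error model that is not reproduced here.

```python
import math
from typing import Sequence


def least_squares(f_obs: Sequence[float], f_calc: Sequence[float],
                  scale: float = 1.0) -> float:
    """Unweighted least-squares residual: sum of (|Fobs| - k*|Fcalc|)^2."""
    return sum((fo - scale * fc) ** 2 for fo, fc in zip(f_obs, f_calc))


def gaussian_nll(f_obs: Sequence[float], f_calc: Sequence[float],
                 sigma: Sequence[float]) -> float:
    """Negative log-likelihood of |Fobs| under N(|Fcalc|, sigma^2), summed over reflections."""
    total = 0.0
    for fo, fc, s in zip(f_obs, f_calc, sigma):
        total += 0.5 * ((fo - fc) / s) ** 2 + 0.5 * math.log(2.0 * math.pi * s ** 2)
    return total


f_obs = [10.0, 20.0]
f_calc = [11.0, 18.0]
print(least_squares(f_obs, f_calc))             # 5.0
print(gaussian_nll(f_obs, f_calc, [1.0, 2.0]))
```

With a constant sigma, the Gaussian NLL is the least-squares residual scaled by 1/(2σ²) plus a constant, so the practical difference between the two lies in how each reflection is weighted by its uncertainty.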

Geometry Targets

  • Bond restraints: Penalize deviations from ideal bond lengths

  • Angle restraints: Penalize deviations from ideal bond angles

  • Torsion restraints: Penalize deviations from ideal torsion angles

  • Planarity restraints: Penalize out-of-plane deviations

  • Chirality restraints: Maintain correct stereochemistry

  • VDW restraints: Prevent atomic overlaps
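Most of these restraints share the same basic shape: a squared deviation from an ideal value, scaled by an uncertainty. The sketch below assumes a harmonic bond-length restraint of the form Σ((d − d_ideal)/σ)²; the function name and input layout are hypothetical, and TorchRef's own restraint terms may be formulated differently.

```python
import math
from typing import List, Sequence, Tuple

Coord = Tuple[float, float, float]


def bond_restraint(coords: Sequence[Coord],
                   bonds: List[Tuple[int, int, float, float]]) -> float:
    """Harmonic bond restraint: sum over bonds of ((d - d_ideal) / sigma)^2."""
    loss = 0.0
    for i, j, d_ideal, sigma in bonds:
        d = math.dist(coords[i], coords[j])  # Euclidean distance (Python 3.8+)
        loss += ((d - d_ideal) / sigma) ** 2
    return loss


coords = [(0.0, 0.0, 0.0), (1.6, 0.0, 0.0)]
bonds = [(0, 1, 1.5, 0.02)]  # atom i, atom j, ideal length (Å), sigma (Å)
print(bond_restraint(coords, bonds))  # approximately 25.0: 0.1 Å off with sigma 0.02 Å is 5 sigma
```

Angle and planarity terms follow the same pattern with a different geometric quantity in place of d; torsions additionally need the deviation wrapped to account for periodicity.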

ADP Targets

  • SIMU: Similar ADPs for nearby atoms

  • Locality restraint: Smooth ADP variations among neighbors

  • DELU: Rigid bond restraints

  • Population restraint: Restrain ADPs towards a target distribution
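The "similar ADPs for nearby atoms" idea can be sketched with isotropic B-factors: for every atom pair within a distance cutoff, penalize the squared difference in B. This is a simplified, hypothetical version (`similar_adp_restraint`, its `cutoff` and `sigma` defaults are illustrative); the classic SIMU restraint operates on anisotropic Uij tensors, and TorchRef's implementation may differ.

```python
import math
from typing import Sequence, Tuple


def similar_adp_restraint(coords: Sequence[Tuple[float, float, float]],
                          b_factors: Sequence[float],
                          cutoff: float = 2.0, sigma: float = 1.0) -> float:
    """Sum of ((B_i - B_j) / sigma)^2 over atom pairs closer than `cutoff`."""
    loss = 0.0
    n = len(coords)
    # Simple O(n^2) pair loop for clarity; real code would use a neighbor list
    for i in range(n):
        for j in range(i + 1, n):
            if math.dist(coords[i], coords[j]) <= cutoff:
                loss += ((b_factors[i] - b_factors[j]) / sigma) ** 2
    return loss


coords = [(0.0, 0.0, 0.0), (1.5, 0.0, 0.0), (10.0, 0.0, 0.0)]
b = [20.0, 22.0, 80.0]
print(similar_adp_restraint(coords, b))  # 4.0: only the first two atoms are neighbors
```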

Stats Functionality

Each target implements a stats() method that returns a dictionary of metrics. These metrics are collected during refinement for monitoring and logging, and they summarize the current state of the model, such as the R-work and R-free values reported by the X-ray target.

# Collect stats from each target (targets created as in Using Targets below)
xray_stats = xray_target.stats()
geom_stats = geom_target.stats()
adp_stats = adp_target.stats()

# Example: print R-factors from X-ray target
print(f"R-work: {xray_stats['rwork']}, R-free: {xray_stats['rfree']}")
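For reference, the R-factors in such a stats dictionary follow the standard definition R = Σ| |Fobs| − |Fcalc| | / Σ|Fobs|, computed separately over the working set and the free (test) set. The plain-Python sketch below assumes Fcalc is already scaled to Fobs; the helper names (`r_factor`, `xray_stats`) are hypothetical, and only the `rwork`/`rfree` keys are taken from the example above.

```python
from typing import Dict, List, Sequence, Tuple


def r_factor(pairs: List[Tuple[float, float]]) -> float:
    """R = sum | |Fobs| - |Fcalc| | / sum |Fobs| (Fcalc assumed already scaled)."""
    num = sum(abs(fo - fc) for fo, fc in pairs)
    den = sum(abs(fo) for fo, _ in pairs)
    return num / den


def xray_stats(f_obs: Sequence[float], f_calc: Sequence[float],
               free_flags: Sequence[bool]) -> Dict[str, float]:
    """Split reflections into work/free sets and report an R-factor for each."""
    work = [(fo, fc) for fo, fc, free in zip(f_obs, f_calc, free_flags) if not free]
    test = [(fo, fc) for fo, fc, free in zip(f_obs, f_calc, free_flags) if free]
    return {"rwork": r_factor(work), "rfree": r_factor(test)}


stats = xray_stats([100.0, 200.0, 50.0, 100.0],
                   [90.0, 210.0, 55.0, 100.0],
                   [False, False, True, False])
print(f"R-work: {stats['rwork']}, R-free: {stats['rfree']}")
```

Because free reflections never contribute to the working residual, a widening gap between R-work and R-free is the usual warning sign of overfitting.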

Using Targets

from torchref.refinement.targets import (
    create_xray_target,
    TotalGeometryTarget,
    TotalADPTarget
)

# Create targets
xray_target = create_xray_target(data, model, target_type='ml')
geom_target = TotalGeometryTarget(model)
adp_target = TotalADPTarget(model)

# Compute losses
xray_loss = xray_target()
geom_loss = geom_target()
adp_loss = adp_target()

Custom Targets

Create custom targets by subclassing Target:

from torchref.refinement.targets import Target
import torch

class EntropyRegularization(Target):
    """Entropy regularization for B-factors."""

    def __init__(self, model, weight=0.01):
        super().__init__(model)
        self.weight = weight  # scale factor applied to the regularization term

    def forward(self):
        b_factors = self.model.b()
        # Entropy-based regularization; the small epsilon keeps log() finite as B -> 0
        entropy = -torch.sum(b_factors * torch.log(b_factors + 1e-8))
        return self.weight * entropy

The key advantage is that you only define the forward pass: PyTorch’s autograd automatically computes all gradients needed for optimization.

LossState

The LossState class manages multiple targets and their weights. It is used to track targets during refinement and supplies information for weight calculation.

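import torch
from torchref.refinement import LossState

# Run on the GPU when one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
state = LossState(device=device)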

# Register targets
state.register_target('xray', xray_target)
state.register_target('geometry', geom_target)
state.register_target('adp', adp_target)

# Set weights
state.set_weight('xray', 1.0)
state.set_weight('geometry', 0.1)
state.set_weight('adp', 0.1)

# Compute total loss
total_loss = state.aggregate()

# Compute backward pass
total_loss.backward()


# Example optimization step with L-BFGS using the LossState.
# LossState prunes redundant loss evaluations and automatically tracks which terms need to be recomputed.
from torch.optim import LBFGS

optimizer = LBFGS(model.parameters(), lr=1.0, max_iter=100)
state.run(optimizer, n_steps=1)  # equivalent to state.step(optimizer)
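To make the register/weight/aggregate bookkeeping concrete, here is a toy, plain-Python stand-in. `ToyLossState`, its `invalidate()` method, and the `n_evals` counter are hypothetical names invented for this sketch; it assumes "pruning" means skipping zero-weight targets and reusing cached values until they are invalidated. TorchRef's LossState works on tensors and its actual caching rules may differ.

```python
from typing import Callable, Dict, Optional


class ToyLossState:
    """Hypothetical, simplified stand-in for LossState (not TorchRef's implementation)."""

    def __init__(self) -> None:
        self.targets: Dict[str, Callable[[], float]] = {}
        self.weights: Dict[str, float] = {}
        self._cache: Dict[str, float] = {}
        self.n_evals = 0  # counts how many target evaluations actually ran

    def register_target(self, name: str, target: Callable[[], float]) -> None:
        self.targets[name] = target
        self.weights.setdefault(name, 1.0)

    def set_weight(self, name: str, weight: float) -> None:
        self.weights[name] = weight

    def invalidate(self, name: Optional[str] = None) -> None:
        """Mark one target (or all) as stale, e.g. after parameters change."""
        if name is None:
            self._cache.clear()
        else:
            self._cache.pop(name, None)

    def aggregate(self) -> float:
        total = 0.0
        for name, target in self.targets.items():
            weight = self.weights[name]
            if weight == 0.0:
                continue  # prune: zero-weight targets are never evaluated
            if name not in self._cache:
                self._cache[name] = target()  # recompute only stale targets
                self.n_evals += 1
            total += weight * self._cache[name]
        return total


state = ToyLossState()
state.register_target("xray", lambda: 10.0)
state.register_target("geometry", lambda: 4.0)
state.register_target("adp", lambda: 2.0)
state.set_weight("geometry", 0.5)
state.set_weight("adp", 0.0)
print(state.aggregate())  # 12.0
print(state.aggregate())  # 12.0, served from the cache
print(state.n_evals)      # 2
```

With real tensors, any cached loss must be invalidated whenever the model parameters are updated; otherwise the optimizer would see stale values and gradients.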