Quick Start
This guide walks you through a basic crystallographic refinement with TorchRef.
For interactive examples, see the Jupyter notebooks in example_notebooks/.
For more detailed explanations, check out these Google Colab notebooks:
Basic Usage - Getting started tutorial
Code Examples - Common patterns and recipes
Target Exploration - Exploring refinement targets
Structure Factor Calculation - FFT-based F_calc
Basic Refinement
The simplest way to set up a refinement and check its starting R-factors:
from torchref import LBFGSRefinement, ROOT_TORCHREF
# Initialize refinement with the bundled 1DAW example data and model.
# Both paths point into the package's example_notebooks/ directory so the
# snippet runs as-is from any working directory.
refinement = LBFGSRefinement(
    data_file=f"{ROOT_TORCHREF}/example_notebooks/1DAW.mtz",
    pdb=f"{ROOT_TORCHREF}/example_notebooks/1DAW.pdb",
)
# Check initial R-factors (unpacked here as Rwork and Rfree)
rwork, rfree = refinement.get_rfactor()
print(f"Initial Rwork: {rwork:.3f}")
...
Initial Rwork: 0...
Loading Data
TorchRef supports multiple file formats:
MTZ files (reflection data):
from torchref import ReflectionData, ROOT_TORCHREF
# Build the path to the bundled example MTZ, then read it into a fresh
# ReflectionData container.
mtz_file = f"{ROOT_TORCHREF}/example_notebooks/1DAW.mtz"
reflections = ReflectionData()
reflections.load_mtz(mtz_file)
# Calling the container yields a 4-tuple, unpacked here as Miller indices,
# observed amplitudes, their sigmas and the R-free flags.
hkl, F, sigF, rfree_flags = reflections()
n_reflections = len(F)
print(f"Number of reflections: {n_reflections}")
...
Number of reflections: ...
PDB files (atomic models):
from torchref import ModelFT, ROOT_TORCHREF
# load_pdb returns the model, so construction and loading can be chained
model = ModelFT().load_pdb(f"{ROOT_TORCHREF}/example_notebooks/1DAW.pdb")
n_atoms = len(model.pdb)
print(f"Number of atoms: {n_atoms}")
...
Number of atoms: ...
Parameter Freezing
Model parameters can be selectively frozen during refinement:
from torchref import ModelFT, ROOT_TORCHREF
model = ModelFT().load_pdb(f"{ROOT_TORCHREF}/example_notebooks/1DAW.pdb")
# Whole parameter groups are addressed by name: 'b' for B-factors and
# 'xyz' for coordinates. Each freeze below is immediately undone.
model.freeze('b')      # B-factors fixed
model.unfreeze('b')    # B-factors refinable again
model.freeze('xyz')    # coordinates fixed
model.unfreeze('xyz')  # coordinates refinable again
# Subsets of atoms are addressed with a phenix-style selection string
selection = "chain A and resseq 10:20"
model.freeze_selection(selection)
n_refinable = model.xyz.refinable_params.shape[0]
print(f"Refinable xyz atoms after freezing selection: {n_refinable}")
# The selection "all" releases every atom again
model.unfreeze_selection("all")
n_refinable = model.xyz.refinable_params.shape[0]
print(f"Refinable xyz atoms after unfreezing all: {n_refinable}")
...
Refinable xyz atoms after freezing selection: ...
Refinable xyz atoms after unfreezing all: ...
Computing Structure Factors
Use the model to compute structure factors for a given set of Miller indices, scale them to the observed data, and compute R-factors:
import torch
from torchref import ModelFT, ReflectionData, Scaler, ROOT_TORCHREF
from torchref.math_functions.math_torch import get_rfactors
# Load data and model
data = ReflectionData().load_mtz(f"{ROOT_TORCHREF}/example_notebooks/1DAW.mtz")
model = ModelFT().load_pdb(f"{ROOT_TORCHREF}/example_notebooks/1DAW.pdb")
# Unpack the reflections. The last element is the R-free flag array; it is
# named rfree_flags so it is not confused with the Rfree value computed below.
hkl, F, sigF, rfree_flags = data()
# Calculate structure factors for these Miller indices
fcalc = model(hkl)
# Fit the scale model against the observed data, then apply it to F_calc
scaler = Scaler(model, data)
scaler.initialize()
scaler.refine_lbfgs()
scaled_fcalc = scaler(fcalc)
# Amplitudes of the scaled structure factors, comparable to F
Fcalc_abs = torch.abs(scaled_fcalc)
# Compute work and free R-factors
rwork, rfree = get_rfactors(F, Fcalc_abs, rfree_flags)
print(f"Rwork: {rwork:.3f}, Rfree: {rfree:.3f}")
...
Rwork: 0..., Rfree: 0...
Running Refinement
Run coordinate refinement with geometry restraints:
from torchref import LBFGSRefinement, ROOT_TORCHREF
example_dir = f"{ROOT_TORCHREF}/example_notebooks"
refinement = LBFGSRefinement(
    data_file=f"{example_dir}/1DAW.mtz",
    pdb=f"{example_dir}/1DAW.pdb",
)
# Perturb the starting coordinates to simulate model errors;
# 0.1 is the perturbation magnitude handed to shake_coords.
refinement.model.shake_coords(0.1)
# Coordinate (xyz) refinement against the data plus geometry restraints
refinement.refine_xyz()
rwork, rfree = refinement.get_rfactor()
print(f"After refinement - Rwork: {rwork:.3f}")
...
After refinement - Rwork: 0...
Writing Output Files
Save refined structures and structure factors:
import os
import tempfile
from torchref import LBFGSRefinement, ROOT_TORCHREF
refinement = LBFGSRefinement(
    data_file=f"{ROOT_TORCHREF}/example_notebooks/1DAW.mtz",
    pdb=f"{ROOT_TORCHREF}/example_notebooks/1DAW.pdb",
)
# Write to a temporary directory for this example. Everything that touches
# tmpdir must stay inside the with-block: the directory and its contents
# are deleted as soon as the block exits.
with tempfile.TemporaryDirectory() as tmpdir:
    pdb_path = os.path.join(tmpdir, "refined.pdb")
    mtz_path = os.path.join(tmpdir, "refined.mtz")
    # Write refined structure
    refinement.write_out_pdb(pdb_path)
    print(f"PDB written: {os.path.exists(pdb_path)}")
    # Write structure factors with map coefficients
    refinement.write_out_mtz(mtz_path)
    print(f"MTZ written: {os.path.exists(mtz_path)}")
...
PDB written: True
...
MTZ written: True
Geometry Restraints
Access and inspect geometry restraints:
from torchref import LBFGSRefinement, ROOT_TORCHREF
example_dir = f"{ROOT_TORCHREF}/example_notebooks"
refinement = LBFGSRefinement(
    data_file=f"{example_dir}/1DAW.mtz",
    pdb=f"{example_dir}/1DAW.pdb",
)
# Restraints are grouped as restraints[<type>]['intra']; each group carries
# an 'indices' array whose first dimension is the number of restraints.
restraint_table = refinement.restraints.restraints
bonds, angles = restraint_table['bond']['intra'], restraint_table['angle']['intra']
print(f"Bond restraints: {bonds['indices'].shape[0]}")
print(f"Angle restraints: {angles['indices'].shape[0]}")
# Per-term geometry losses, keyed by restraint type
geometry_losses = refinement.geometry_target.target_losses()
print(f"Bond loss: {geometry_losses['bond'].item():.4f}")
print(f"Angle loss: {geometry_losses['angle'].item():.4f}")
...
Bond restraints: ...
Angle restraints: ...
Bond loss: ...
Angle loss: ...
Custom Target Functions
Define custom refinement targets with automatic gradient computation:
import torch
from torchref import LBFGSRefinement, ROOT_TORCHREF
from torchref.refinement.targets import Target
class CustomTarget(Target):
    """Least-squares amplitude target; autograd supplies the gradients."""
    name = 'custom_lsq'
    def __init__(self, refinement):
        # Nothing extra to set up; extend this for target-specific state.
        super().__init__(refinement)
    def forward(self):
        # Define your loss - gradients are computed automatically by autograd.
        F_calc = self.refinement.model.get_F_calc()
        F_obs = self.refinement.reflection_data.F
        # NOTE(review): |F_calc| is compared to F_obs without the Scaler step
        # used in "Computing Structure Factors", and F_obs appears to include
        # the R-free test reflections. A production target would presumably
        # scale F_calc and restrict the sum to the work set.
        loss = torch.mean((torch.abs(F_calc) - F_obs) ** 2)
        return loss
# Create refinement and register custom target
refinement = LBFGSRefinement(
    data_file=f"{ROOT_TORCHREF}/example_notebooks/1DAW.mtz",
    pdb=f"{ROOT_TORCHREF}/example_notebooks/1DAW.pdb",
)
loss_state = refinement.create_loss_state()
custom_target = CustomTarget(refinement)
loss_state.register_target(custom_target.name, custom_target)
loss_state.set_weight(custom_target.name, 3.0)
print(f"Registered targets: {len(loss_state.targets)}")
# Look the weight up by the target's own name rather than repeating the literal
print(f"Custom target weight: {loss_state.weights[custom_target.name]}")
...
Registered targets: ...
Custom target weight: 3.0
LossState Workflow
The LossState object manages targets, weights, and metadata for refinement:
from torchref import LBFGSRefinement, ROOT_TORCHREF
example_dir = f"{ROOT_TORCHREF}/example_notebooks"
refinement = LBFGSRefinement(
    data_file=f"{example_dir}/1DAW.mtz",
    pdb=f"{example_dir}/1DAW.pdb",
)
# Build a LossState and fill it in three steps: targets, metadata, weights
loss_state = refinement.create_loss_state()
refinement.add_target_info_to_state(loss_state)
refinement.populate_state_meta(loss_state)
refinement.update_weights(loss_state)
# Aggregate the registered targets into a single scalar loss
total_loss = loss_state.aggregate()
print(f"Total loss: {total_loss.item():.2f}")
# Backpropagate; autograd fills .grad on the refinable coordinates
total_loss.backward()
xyz_grad = refinement.model.xyz.refinable_params.grad
print(f"Gradients computed for xyz: {xyz_grad is not None}")
...
Total loss: ...
Gradients computed for xyz: True
Saving and Loading State
Capture the complete refinement state so it can be saved and continued later:
import torch
from torchref import LBFGSRefinement, ROOT_TORCHREF
example_dir = f"{ROOT_TORCHREF}/example_notebooks"
refinement = LBFGSRefinement(
    data_file=f"{example_dir}/1DAW.mtz",
    pdb=f"{example_dir}/1DAW.pdb",
)
# state_dict() captures the refinement state; the resulting dict can be
# written to disk with torch.save(state, path) for later continuation.
state = refinement.state_dict()
n_entries = len(state)
print(f"State dict keys: {n_entries} entries")
...
State dict keys: ... entries
GPU Acceleration
Move computations to the GPU for faster refinement:
import torch
from torchref import LBFGSRefinement, ROOT_TORCHREF
refinement = LBFGSRefinement(
    data_file=f"{ROOT_TORCHREF}/example_notebooks/1DAW.mtz",
    pdb=f"{ROOT_TORCHREF}/example_notebooks/1DAW.pdb",
)
# Only move to the GPU when CUDA is actually present; otherwise stay on CPU.
if torch.cuda.is_available():
    refinement.cuda()
    # model.xyz is a parameter wrapper; its refinable_params attribute is the
    # underlying torch tensor, which is what carries a .device.
    print(f"Model on GPU: {refinement.model.xyz.refinable_params.device}")
else:
    print("CUDA not available, using CPU")
...
Command Line Interface
For quick refinements from the command line:
# Refine structure.pdb against reflections.mtz and write the model to refined.pdb.
# NOTE(review): -f / -s presumably take the reflection file and the starting
# model (inferred from the file names) — confirm against the CLI's help output.
torchref-refine -f reflections.mtz -s structure.pdb --output refined.pdb
Next Steps
See example_notebooks/basic_usage.ipynb for a complete tutorial.
See example_notebooks/code_examples.ipynb for more code patterns.
See example_notebooks/target_exploration.ipynb for custom targets.