torchref.refinement package
Refinement module for crystallographic structure refinement.
This module provides the core refinement framework, including:
- Refinement classes for running optimization
- Target (loss) functions for X-ray, geometry, and ADP restraints
- Weighting schemes for balancing loss components
- Loss aggregation and state tracking
Example
Basic refinement:
from torchref.refinement import LBFGSRefinement
refinement = LBFGSRefinement(
    data_file='reflections.mtz',
    pdb='structure.pdb',
)
refinement.refine_everything(macro_cycles=5)
Access targets and weighting schemes:
from torchref.refinement.targets import XrayTarget, BondTarget
from torchref.refinement.weighting import ComponentWeighting
- class torchref.refinement.Refinement(data_file=None, pdb=None, cif=None, verbose=1, max_res=None, device=None, nbins=10, manual_weights=None, component_weights=None, column_names=None)[source]
Bases: DeviceMixin, DebugMixin, Module
Refinement class to handle the overall crystallographic refinement process.
Supports two initialization patterns:
Empty initialization (for state_dict loading):
refinement = Refinement()  # Creates empty shell with submodules
refinement.load_state_dict(torch.load('refinement.pt'))
Full initialization with file paths:
refinement = Refinement(data_file='data.mtz', pdb='model.pdb')
- Parameters:
data_file (str, optional) – Path to MTZ or CIF file containing reflection data.
pdb (str, optional) – Path to PDB or CIF file containing initial model.
cif (str, optional) – Path to CIF file for restraints.
verbose (int, optional) – Verbosity level. Default is 1.
max_res (float, optional) – Maximum resolution for reflections.
device (torch.device, optional) – Computation device. Defaults to the configured device.current.
weighter (LossWeightingModule, optional) – Loss weighting module. Creates default if None.
nbins (int, optional) – Number of resolution bins. Default is 10.
- device
Computation device.
- Type:
- reflection_data
Reflection data container.
- Type:
- weighter
Loss weighting module.
- Type:
LossWeightingModule
- __init__(data_file=None, pdb=None, cif=None, verbose=1, max_res=None, device=None, nbins=10, manual_weights=None, component_weights=None, column_names=None)[source]
Initialize Refinement.
If data_file and pdb are provided, fully initializes the refinement. If not provided (empty init), creates a shell with empty submodules ready for load_state_dict().
- Parameters:
data_file (str, optional) – Path to MTZ or CIF file containing reflection data.
pdb (str, optional) – Path to PDB or CIF file containing initial model.
cif (str, optional) – Path to CIF file for restraints.
verbose (int, optional) – Verbosity level. Default is 1.
max_res (float, optional) – Maximum resolution for reflections.
device (torch.device, optional) – Computation device. Defaults to the configured device.current.
weighter (LossWeightingModule, optional) – Loss weighting module. Creates default if None.
nbins (int, optional) – Number of resolution bins. Default is 10.
- set_xray_target_mode(mode)[source]
Change the X-ray target mode.
- Parameters:
mode (str) – X-ray target mode. Options are ‘gaussian’, ‘ls’, or ‘ml’.
- property data
Expose reflection_data as ‘data’ for weighting module compatibility.
- Returns:
The reflection data container.
- Return type:
- property loss_state: LossState
Get or create the persistent LossState.
The LossState is created once and reused across refinement cycles. Targets are registered once; weights are updated each cycle.
- Returns:
The persistent loss state with targets registered.
- Return type:
- property logger: Logger
Get or create the Logger for this refinement.
- Returns:
Logger instance linked to the persistent LossState.
- Return type:
- reset_loss_state()[source]
Reset the persistent LossState and Logger.
Call this if targets need to be re-registered (e.g., after changing target modes or reinitializing targets).
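The get-or-create behaviour of loss_state, together with reset_loss_state(), can be pictured as a lazily built attribute. The following is a minimal standalone sketch, not torchref's implementation: LazyOwner, _build_state, and the dict standing in for a LossState are hypothetical names chosen for illustration.

```python
class LazyOwner:
    """Hypothetical stand-in for an object that owns one persistent state."""

    def __init__(self):
        self._state = None      # nothing built yet
        self.build_count = 0    # how many times the state was (re)built

    def _build_state(self):
        # Stand-in for registering targets once on a fresh state object.
        self.build_count += 1
        return {"targets": ["xray/work", "geometry/bond"], "weights": {}}

    @property
    def state(self):
        # Build on first access, then hand back the same object every time.
        if self._state is None:
            self._state = self._build_state()
        return self._state

    def reset_state(self):
        # Drop the cached object so the next access rebuilds it.
        self._state = None


owner = LazyOwner()
first = owner.state
second = owner.state            # same object, no rebuild
owner.reset_state()
third = owner.state             # rebuilt after the reset
print(first is second, first is third, owner.build_count)
```

The reset is the only way to force a rebuild in this sketch, which mirrors the documented advice to call reset_loss_state() after changing target modes.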
- parameters(recurse=True)[source]
Return unique parameters from this module and all submodules.
Uses the default Module.parameters() to gather parameters, then removes duplicates while preserving order to avoid passing the same tensor multiple times to the optimizer.
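Order-preserving removal of duplicates by object identity might look like the sketch below. It is an illustration under assumptions, not the library's code: unique_by_identity is a hypothetical helper, and plain lists stand in for parameter tensors.

```python
def unique_by_identity(items):
    """Drop repeated objects (same identity) while keeping first-seen order."""
    seen = set()
    unique = []
    for item in items:
        if id(item) not in seen:    # compare identity, not value
            seen.add(id(item))
            unique.append(item)
    return unique


shared = [1.0]      # stand-in for a parameter reachable via two submodules
other = [1.0]       # equal value, but a distinct object
params = unique_by_identity([shared, other, shared])
print(len(params))
```

Identity rather than equality is used so that two distinct parameters holding equal values are both kept.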
- adp_loss()[source]
Compute total ADP loss using TotalADPTarget.
This combines:
Bond-based similarity (SIMU-like)
Spread control (tighter than KL)
Bounds penalty
- Returns:
Total ADP loss value.
- Return type:
- nll_xray()[source]
Compute X-ray negative log-likelihood for work and test sets.
- Returns:
Tuple of (work_nll, test_nll) tensors.
- Return type:
- xray_loss_work()[source]
Compute X-ray loss on work set using instantiated target.
- Returns:
X-ray loss on work set.
- Return type:
- xray_loss_test()[source]
Compute X-ray loss on test set using instantiated target.
- Returns:
X-ray loss on test set.
- Return type:
- bond_loss()[source]
Compute bond length NLL via geometry_target.
- Returns:
Bond length NLL loss.
- Return type:
- torsion_loss()[source]
Compute torsion angle NLL via geometry_target.
- Returns:
Torsion angle NLL loss.
- Return type:
- geometry_loss()[source]
Compute total geometry NLL using TotalGeometryTarget.
- Returns:
Total geometry NLL loss.
- Return type:
- loss()[source]
Compute total loss using LossState pipeline.
Creates a LossState, populates meta, caches losses, updates weights, and returns the aggregated weighted loss.
- Returns:
Total weighted loss.
- Return type:
- setup_component_weighting()[source]
Set up component weighting with ResolutionWeighting + OverfittingWeighting.
- populate_state_meta(state)[source]
Populate LossState.meta with all model-level data.
Called once per macro cycle before weighting schemes are applied. This is the single location where refinement data is extracted into state.
- update_weights(state, multiply=False)[source]
Compute weights from component_weighting and update state. Weights are clipped to [0.01, 100.0] to avoid extreme values.
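The clipping step can be sketched in plain Python as follows. This is an assumption-labelled illustration only: clip_weights is a hypothetical helper, and the behaviour of the multiply flag is not modelled because it is not described here.

```python
def clip_weights(weights, lo=0.01, hi=100.0):
    """Clamp every weight into [lo, hi] (hypothetical helper)."""
    return {name: min(max(value, lo), hi) for name, value in weights.items()}


raw = {"xray": 1e-5, "geometry/bond": 5.0, "adp/simu": 1e4}
clipped = clip_weights(raw)
print(clipped)
```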
- create_loss_state()[source]
Create a configured LossState for optimization.
Deprecated: Use the loss_state property instead for the persistent state. This method is kept for backwards compatibility.
Sets up a LossState with all targets registered as callables with hierarchical naming (e.g., ‘geometry/bond’, ‘adp/simu’). Weights are applied from component_weighting.
- Usage:
import torch
from torchref.utils import validate_loss

state = refinement.create_loss_state()
params = list(refinement.parameters())

# Log initial state
state.aggregate(log_values=True)

# `optimizer` is assumed to be an LBFGS optimizer built over `params`.
# In an LBFGS closure, wrap with validate_loss so non-finite
# losses warn + reject the step instead of poisoning the run.
def closure():
    optimizer.zero_grad()
    loss = state.aggregate()
    loss.backward()
    ok = validate_loss(
        loss, state=state, parameters=params,
        context="my_refinement", raise_on_fail=False,
    )
    if not ok:
        # Zero the gradients and report +inf so the step is rejected.
        for p in params:
            if p.grad is not None:
                p.grad.zero_()
        return torch.full_like(loss.detach(), float("inf"))
    return loss

optimizer.step(closure)

# Log final state
state.new_entry()
state.aggregate(log_values=True)
- Returns:
Configured LossState with targets and weights.
- Return type:
- complete_loss_state()[source]
Update and return the persistent LossState.
Updates the persistent LossState with current meta, target info, cached losses, and weights. The state is reused across cycles.
The cached active-parameter leaf set is not refreshed here. Stale leaves are not a correctness hazard: a leaf that is in the set but whose Parameter object was replaced externally (e.g. by Model.freeze) is simply ignored by _freeze_graph_extras, which costs a marginal amount of wasted backward work but never produces wrong answers. If you do call Model.freeze / Model.unfreeze between LossState creation and a refinement step, call state.refresh_loss_leaves() explicitly.
- Returns:
Complete LossState with targets, meta, losses, and weights.
- Return type:
- restraints_loss()[source]
Compute total geometry restraints loss.
- Returns:
Total geometry restraints loss.
- Return type:
- collect_metrics()[source]
Collect all metrics from component_weighting.stats().
This is the standard method for gathering refinement metrics for logging. Uses the centralized component_weighting module for all statistics. Returns full unfiltered stats - filtering is done at display time.
- Returns:
Dictionary with all metrics (unfiltered, with StatEntry objects).
- Return type:
- add_target_info_to_state(state)[source]
Add target information from geometry and ADP targets to LossState.meta.
Deprecated: This method is no longer needed. Use complete_loss_state() instead, which handles all state setup in one call.
- collect_deposition_metadata(metadata=None)[source]
Collect refinement statistics into a RefinementMetadata object.
Reuses existing statistics from collect_metrics(), get_rfactor(), and reflection data attributes.
- Parameters:
metadata (RefinementMetadata, optional) – Existing metadata to merge with (e.g. from input file pass-through). Refinement statistics take precedence over pass-through values.
- Returns:
Metadata populated with final refinement statistics.
- Return type:
- write_out_pdb(out_pdb_path='refined_output.pdb', metadata=None)[source]
Write refined PDB with optional metadata header.
- Parameters:
out_pdb_path (str) – Output PDB file path.
metadata (RefinementMetadata, optional) – Metadata for PDB header. If None, auto-collected from refinement.
- write_out_cif(out_cif_path='refined_output.cif', metadata=None)[source]
Write refined coordinates as mmCIF with metadata.
- Parameters:
out_cif_path (str) – Output mmCIF file path.
metadata (RefinementMetadata, optional) – Metadata for mmCIF categories. If None, auto-collected from refinement.
- save_state(path)[source]
Save the complete state of the refinement to a file.
- Parameters:
path (str) – Path to save the state dictionary to.
- classmethod create_from_state_dict(state_dict, device=device(type='cpu'), verbose=1)[source]
Create a fully initialized Refinement from a state dictionary.
This is the recommended way to restore a Refinement from a saved state. It creates the proper submodules using their respective create_from_state_dict methods, then calls PyTorch’s default load_state_dict.
- Parameters:
state_dict (dict) – State dictionary from torch.save(refinement.state_dict(), …) or from loading a checkpoint file.
device (torch.device, optional) – Device to place tensors on. Defaults to the configured device.current.
verbose (int, optional) – Verbosity level. Default is 1.
- Returns:
Fully initialized instance with restored state.
- Return type:
Examples
Save and load refinement state:
# Save
torch.save(refinement.state_dict(), 'refinement.pt')

# Load
state = torch.load('refinement.pt')
refinement = Refinement.create_from_state_dict(state)

# Continue refinement
rwork, rfree = refinement.get_rfactor()
print(f"Restored at R-work={rwork:.4f}, R-free={rfree:.4f}")
- class torchref.refinement.LBFGSRefinement(*args, target_mode='bhattacharyya', sigma_m_scale=1.0, use_lossstate_scaler=True, **kwargs)[source]
Bases: Refinement
LBFGS-based refinement subclass using the L-BFGS optimizer for fast convergence.
L-BFGS (limited-memory BFGS) is a quasi-Newton optimization method that approximates the inverse Hessian from a short history of gradient and parameter updates, which typically gives faster convergence than first-order methods such as Adam.
Key advantages:
Converges in 1-2 macro cycles (vs 5+ for Adam)
Better final R-factors
More stable convergence
Automatically handles step size via line search
- Parameters:
target_mode (str, optional) – X-ray target mode (‘gaussian’, ‘ls’, ‘ml’, or ‘bhattacharyya’). Default is ‘bhattacharyya’.
*args – Passed to parent Refinement class.
**kwargs – Passed to parent Refinement class.
Examples
Basic usage:
from torchref.refinement import LBFGSRefinement

refinement = LBFGSRefinement(
    data_file='data.mtz',
    pdb='model.pdb',
    target_mode='ml',
)
refinement.refine(macro_cycles=2)
- LBFGS_DEFAULTS = {'history_size': 100, 'line_search_fn': 'strong_wolfe', 'lr': 1.0, 'max_iter': 20}
- __init__(*args, target_mode='bhattacharyya', sigma_m_scale=1.0, use_lossstate_scaler=True, **kwargs)[source]
Initialize LBFGS refinement.
- Parameters:
target_mode (str, optional) – X-ray target mode (‘gaussian’, ‘ls’, ‘ml’, ‘bhattacharyya’). Default is ‘bhattacharyya’.
sigma_m_scale (float, optional) – Global multiplier for σ_m in the Bhattacharyya target only. Ignored for other target modes. Default 1.0.
use_lossstate_scaler (bool, optional) – If True (default), refine_scaler() uses the full LossState with the body’s x-ray target, so scaler and body steps share one consistent loss. If False, falls back to Scaler.refine_lbfgs, which minimises a standalone nll_xray and can pull scales in a different direction than the body optimization.
*args – Passed to parent Refinement class.
**kwargs – Passed to parent Refinement class.
- xray_loss()[source]
Compute X-ray loss using the instantiated target.
- Returns:
X-ray loss on work set.
- Return type:
- refine_scaler()[source]
Refine scaler parameters against the full refinement loss.
Builds the body LossState via complete_loss_state(), constructs a fresh LBFGS optimizer over list(self.scaler.parameters()), and delegates to LossState.step(). Because state.step disables requires_grad on every loss leaf outside the optimizer’s intent set, xyz / adp / u / occupancy are pinned for the duration; only scaler parameters move.
The critical property is that the x-ray target used here is the same one the body refine_xyz() and refine_adp() see. The legacy Scaler.refine_lbfgs() minimises a standalone nll_xray + U^2 penalty, which can pull scales in a different direction than a bhattacharyya or ml body loss and leaves the body to chase a scaler that disagrees with its own objective.
When use_lossstate_scaler is False, fall back to the legacy Scaler.refine_lbfgs() path.
- refine_xyz()[source]
Refine Cartesian coordinates jointly with scaler parameters.
Scaler parameters (log_scale, U, solvent terms) are included in the same LBFGS call as xyz. The joint curvature lets xyz steps see the scaler as an anchor (residuals the scaler can absorb do not have to be chased by atomic motion), and the adp/scaler_U and adp/scaler_log_scale priors bite on every step, so nothing in the scaler drifts between refine_xyz and refine_adp calls.
- Returns:
State with history containing before/after loss values.
- Return type:
- refine_adp()[source]
Refine ADP / U / occupancy jointly with scaler parameters.
Scaler parameters (log_scale, U, solvent terms) are included in the same LBFGS call as the ADP-block body parameters so the joint curvature can slide along the atomic-B / scaler-U degeneracy ridge together with the adp/scaler_U regularizer. XYZ is left frozen.
- Returns:
State with history containing before/after loss values.
- Return type:
- refine_joint()[source]
Joint LBFGS over every refinable parameter in one step.
Optimizes xyz, adp, u, occupancy, and every scaler parameter (log_scale, anisotropic U, solvent terms) in a single LBFGS call. The joint curvature couples all of them through the same x-ray target and through the adp/scaler_U / adp/scaler_log_scale priors. Unlike alternating refine_xyz → refine_adp, there’s no “frozen partner” in either half that could lock the step into a locally bad direction.
- Returns:
State with history containing before/after loss values.
- Return type:
- run_training_trajectory(policy_weighting, n_steps=10, pdb_id='', structure_path='', sf_path='', seed=None, policy_version=None)[source]
Run a training trajectory with policy-guided refinement.
This method runs a sequence of refinement steps using a policy to select component weights. It records state-action-reward tuples for training the policy with AWR or similar algorithms.
- Parameters:
policy_weighting (PolicyComponentWeighting) – Policy weighting scheme (should be in training mode with sampling).
n_steps (int, optional) – Number of refinement steps in the trajectory (default: 10).
pdb_id (str, optional) – PDB identifier for recording.
structure_path (str, optional) – Path to structure file for recording.
sf_path (str, optional) – Path to structure factors file for recording.
seed (int, optional) – Random seed for reproducibility.
policy_version (str, optional) – Version identifier of the policy being used.
- Returns:
Complete trajectory with state-action-reward tuples.
- Return type:
- run_training_trajectory_joint(policy_weighting, n_steps=10, pdb_id='', structure_path='', sf_path='', seed=None, policy_version=None)[source]
Run a training trajectory with joint XYZ+ADP refinement.
Similar to run_training_trajectory() but refines xyz, adp, u, and occupancy together in each step. The LBFGS curvature history is reset at the start of each policy step because the weight updates invalidate any prior Hessian approximation.
- Parameters:
policy_weighting (PolicyComponentWeighting) – Policy weighting scheme (should be in training mode).
n_steps (int, optional) – Number of refinement steps (default: 10).
pdb_id (str, optional) – PDB identifier for trajectory recording.
structure_path (str, optional) – Path to structure file for trajectory recording.
sf_path (str, optional) – Path to structure factors file for trajectory recording.
seed (int, optional) – Random seed for reproducibility.
policy_version (str, optional) – Policy version identifier.
- Returns:
Complete trajectory with state-action-reward tuples.
- Return type:
- class torchref.refinement.LossState(device=<factory>, targets=<factory>, weights=<factory>, history=<factory>, _losses=<factory>, _compilable=<factory>, _compiled_aggregate=None, _loss_leaves=<factory>, _resettable_modules=<factory>, meta=<factory>)[source]
Bases: DeviceMixin
Hierarchical loss state with lazy evaluation.
- device
Computation device.
- Type:
- targets
Target functions keyed by hierarchical name (e.g., ‘geometry/bond’).
- Type:
Dict[str, Callable]
- weights
Weights keyed by name. Can be group weights (‘geometry’) or component weights (‘geometry/bond’).
- history
Log of computed values per aggregation call.
- Type:
List[Dict]
- get(key, default=None)[source]
Get value with default fallback.
- Parameters:
key (str) – Key to look up.
default (Any) – Value to return if key not found.
- Returns:
Value from meta, _losses, or default.
- Return type:
Any
- cache_losses(force=False)[source]
Cache all target losses.
Evaluates all registered targets and stores results in _losses.
- register_target(name, target, prefix=None, compile=False, probe=True)[source]
Register a target function.
Automatically detects combined targets (like TotalGeometryTarget, TotalADPTarget) and expands them into their component targets.
- Parameters:
name (str) – Hierarchical name (e.g., ‘geometry/bond’, ‘adp/simu’).
target (Callable) – Function that returns a loss tensor when called. Can also be a combined target with .items() method, which will be auto-expanded.
prefix (str, optional) – Prefix to prepend to the name (e.g., ‘model1’ -> ‘model1/geometry/bond’). Useful for registering targets from multiple models in the same state.
compile (bool) – If True, mark this target (or all its sub-targets if combined) as eligible for the compiled aggregate closure built by compile_aggregate().
probe (bool) – If True (default), run the target’s forward once, walk the autograd graph, and merge the resulting leaf set into self._loss_leaves. The target’s dependencies (model loaded, data attached, etc.) must therefore be in place before registration. Set probe=False to skip; the leaf-set entry for this target will then be empty, so step() / run() will not auto-disable any leaves on its account. Useful only for targets whose forward genuinely cannot be called at registration time.
- Returns:
Self for chaining.
- Return type:
- register_targets(targets, prefix=None, compile=False, probe=True)[source]
Register multiple targets from a component target or dict.
For targets with a .name attribute, uses target.name as the key. For plain callables, uses the dict key.
- Parameters:
targets (dict) – Dictionary of name -> target mappings.
prefix (str, optional) – Prefix to prepend to all target names.
compile (bool) – If True, propagate the compile flag to all sub-targets.
probe (bool) – Forwarded to register_target().
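Hierarchical naming, prefixing, and the expansion of combined targets can be pictured with a small registry. This sketch is written under assumptions and is not torchref's implementation: TinyRegistry is hypothetical, a combined target is modelled as anything with an .items() method, and the way expanded names are composed here ("name/sub_name") is a guess for illustration.

```python
class TinyRegistry:
    """Hypothetical registry that mimics hierarchical target naming."""

    def __init__(self):
        self.targets = {}

    def register_target(self, name, target, prefix=None):
        # A "combined" target is assumed to expose .items() yielding
        # (sub_name, callable) pairs; expand it into its components.
        if hasattr(target, "items"):
            for sub_name, sub_target in target.items():
                self.register_target(f"{name}/{sub_name}", sub_target, prefix)
            return self
        full_name = f"{prefix}/{name}" if prefix else name
        self.targets[full_name] = target
        return self  # returning self allows chained registration


geometry = {"bond": lambda: 1.5, "angle": lambda: 0.5}  # a dict has .items()
registry = (
    TinyRegistry()
    .register_target("xray/work", lambda: 10.0)
    .register_target("geometry", geometry, prefix="model1")
)
print(sorted(registry.targets))
```

A prefix keeps targets from several models apart in a single flat dictionary without changing how each target is evaluated.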
- get_effective_weight(name)[source]
Get effective weight for a target, including group weights.
For ‘geometry/bond’, returns: weights[‘geometry’] * weights[‘geometry/bond’] Missing weights default to 1.0.
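The group-times-component rule, and the weighted sum built on top of it, can be sketched in plain Python. This is an illustration under assumptions rather than the library's code: effective_weight and weighted_total are hypothetical helpers, floats stand in for loss tensors, and only the top-level group is considered.

```python
def effective_weight(weights, name):
    """Group weight times component weight; missing entries count as 1.0."""
    group = name.split("/", 1)[0]
    weight = weights.get(name, 1.0)
    if group != name:               # avoid applying a flat name twice
        weight *= weights.get(group, 1.0)
    return weight


def weighted_total(targets, weights):
    """Evaluate every target callable and sum the weighted losses."""
    return sum(effective_weight(weights, n) * fn() for n, fn in targets.items())


weights = {"geometry": 2.0, "geometry/bond": 0.5}
targets = {
    "xray/work": lambda: 10.0,      # no weight registered: defaults to 1.0
    "geometry/bond": lambda: 4.0,   # 2.0 * 0.5
    "geometry/angle": lambda: 3.0,  # group weight only: 2.0
}
print(effective_weight(weights, "geometry/bond"))
print(weighted_total(targets, weights))
```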
- mark_compilable(names)[source]
Mark already-registered targets as eligible for the compiled aggregate.
- compile_aggregate(**compile_kwargs)[source]
Build and cache a torch.compile’d closure over all compilable targets.
Must be called after all targets and weights have been registered. Re-call if weights or compilable targets change (or call reset_compiled_aggregate()).
- Parameters:
**compile_kwargs – Keyword arguments forwarded to torch.compile. Defaults to fullgraph=False so partial-graph fallback is allowed.
- Returns:
Self for chaining.
- Return type:
- reset_compiled_aggregate()[source]
Clear the cached compiled closure (e.g. after changing weights).
- log(name, value)[source]
Log a value to the current history entry.
Creates a new history entry if needed.
- Parameters:
name (str) – Key for the logged value.
value (Any) – Value to log. Tensors are converted to Python floats.
- aggregate(log_values=False)[source]
Evaluate all targets and compute weighted sum.
When compile_aggregate() has been called and log_values=False, the compilable targets are evaluated through a single torch.compile’d closure for improved performance. With log_values=True all targets run eagerly so per-target losses are available in _losses.
- Parameters:
log_values (bool) – If True, log all losses, weights, and total to history.
- Returns:
Total weighted loss.
- Return type:
- get_loss(name)[source]
Get a cached loss value (after aggregate() was called).
- Parameters:
name (str) – Target name.
- Returns:
Cached loss, or None if not computed.
- Return type:
torch.Tensor or None
- active_parameters()[source]
Return the set of leaf nn.Parameter objects that registered targets’ backward passes will accumulate gradient into.
Populated incrementally by register_target() via a one-shot probe forward + autograd graph walk; calling this method does not run any forward, walk any graph, or evaluate any target. The result is conservative: a target whose weight is later set to 0 still contributes its leaves here, which is harmless for the freezing logic in step() (it can only over-freeze, never under-freeze).
- refresh_loss_leaves()[source]
Re-probe every registered target and rebuild _loss_leaves and the resettable-modules cache.
Use this after external code has replaced parameter identity on the underlying model, for example after Model.freeze() / Model.unfreeze() (which rebuild refinable_params tensors). Under normal step() / run() usage no parameter identity ever changes, so this method is rarely needed.
- reset_caches()[source]
Call reset_cache() on every registered target’s submodules that expose one. Invoked automatically at the end of step().
- restore_loss_leaf_grads()[source]
Unconditionally re-enable requires_grad on every leaf in self._loss_leaves. Called at the end of step() so the next call sees a clean, fully-differentiable model regardless of what state the previous step (or external code) left things in.
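The freeze-then-restore pattern can be illustrated without PyTorch. The sketch below is a hedged stand-in, not the library's code: Leaf and step_with_freezing are hypothetical, and a plain boolean attribute plays the role of requires_grad.

```python
class Leaf:
    """Hypothetical stand-in for a tensor leaf with a requires_grad flag."""

    def __init__(self, name):
        self.name = name
        self.requires_grad = True


def step_with_freezing(loss_leaves, optimizer_leaves, do_step):
    intent = {id(leaf) for leaf in optimizer_leaves}
    try:
        # Pin every leaf the loss touches that the optimizer does not own.
        for leaf in loss_leaves:
            if id(leaf) not in intent:
                leaf.requires_grad = False
        return do_step()
    finally:
        # Unconditionally restore, even if the step raised.
        for leaf in loss_leaves:
            leaf.requires_grad = True


xyz, adp, scale = Leaf("xyz"), Leaf("adp"), Leaf("log_scale")
leaves = [xyz, adp, scale]
during = step_with_freezing(
    leaves, [scale], lambda: [l.name for l in leaves if l.requires_grad]
)
after = [l.name for l in leaves if l.requires_grad]
print(during, after)
```

Restoring in a finally block is what keeps a failed step from leaving leaves frozen for the next refinement method.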
- run(optimizer, log=False, nsteps=1, *, context='loss_state.step')[source]
Run a single optimizer.step(closure).
Builds the closure, validates each loss for finiteness via torchref.utils.validate_loss(), and on failure zeros the gradients and returns +inf so the strong-Wolfe line search backtracks. Automatically disables requires_grad on any leaf that the loss touches but the optimizer was not constructed with; autograd then prunes those subgraphs from the backward pass.
In principle this should work with any PyTorch optimizer that supports closures, but it has only been tested with LBFGS so far. The closure is built to be as general as possible, so a custom optimizer that supports closures should work with this method.
Every collected reset_cache-bearing submodule is reset before the optimizer step so the closure’s first forward sees a clean cache (a previous rejected closure may have stored a NaN/inf forward result that the fingerprint would happily serve again if parameter values haven’t changed).
After the run, maintenance() is called on all targets.
On exit, requires_grad=True is unconditionally re-enabled on every leaf in self._loss_leaves, defending against state bleeding between successive refinement methods.
- Parameters:
optimizer (torch.optim.Optimizer) – Optimizer to step. Its param_groups define the intent: the leaves the caller actually wants to update.
log (bool) – If True, calls aggregate(log_values=True) before and after the optimization loop.
nsteps (int) – Number of steps to run (default 1). Only the first step’s closure caching is enabled between multiple steps. If you want to run truly independent steps, call this method multiple times with nsteps=1. This adds overhead but might be desirable if the overhead is negligible anyway.
context (str) – Diagnostic label forwarded to validate_loss.
- Returns:
The loss tensor from the last accepted closure call, or None if no closure call succeeded (every call produced non-finite loss).
- Return type:
torch.Tensor or None
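Why returning +inf makes a line search back off can be seen in a toy one-dimensional example. The sketch below uses a plain backtracking search, not the strong-Wolfe search used by LBFGS, and every name in it (guarded, backtracking_step, raw_loss) is hypothetical.

```python
import math


def guarded(loss_fn):
    """Wrap a loss so non-finite values are reported as +inf."""
    def closure(x):
        value = loss_fn(x)
        return value if math.isfinite(value) else math.inf
    return closure


def backtracking_step(closure, x, direction, step=1.0, shrink=0.5, max_tries=20):
    """Halve the step until the loss decreases; +inf can never be accepted."""
    current = closure(x)
    for _ in range(max_tries):
        candidate = x + step * direction
        if closure(candidate) < current:
            return candidate
        step *= shrink
    return x  # no acceptable step found; keep the old point


def raw_loss(x):
    # Defined only for x > 0; returns NaN elsewhere, like a failed forward.
    return x - math.log(x) if x > 0 else math.nan


loss = guarded(raw_loss)
x0 = 2.0
x1 = backtracking_step(loss, x0, direction=-4.0)  # full step lands at x = -2
print(x1, loss(x1) < loss(x0))
```

The first two trial points fall outside the domain and are rejected as +inf; the third, at x = 1.0, lowers the loss and is accepted.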
- step(optimizer, *args, **kwargs)[source]
Convenience method that calls run() with 1 step.
- Parameters:
optimizer (torch.optim.Optimizer) – Optimizer to run.
*args – Forwarded to run().
**kwargs – Forwarded to run().
- get_breakdown()[source]
Get breakdown of losses by group.
- Returns:
Nested dict: {group: {component: {‘loss’: …, ‘weight’: …, ‘weighted’: …}}}
- Return type:
Dict
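The nested shape of such a breakdown can be sketched from flat name-to-loss and name-to-weight dictionaries. This is a hypothetical helper for illustration, not the library's implementation; floats stand in for tensors, and using the group name as the component key for flat names is an assumption.

```python
def breakdown(losses, weights):
    """Group flat 'group/component' entries into a nested dictionary."""
    result = {}
    for name, loss in losses.items():
        group, _, component = name.partition("/")
        weight = weights.get(name, 1.0)
        if component:
            weight *= weights.get(group, 1.0)
        else:
            component = group   # flat name: reuse it as the component key
        result.setdefault(group, {})[component] = {
            "loss": loss,
            "weight": weight,
            "weighted": weight * loss,
        }
    return result


table = breakdown(
    {"xray/work": 10.0, "geometry/bond": 4.0, "geometry/angle": 3.0},
    {"geometry": 2.0, "geometry/bond": 0.5},
)
print(table["geometry"]["bond"])
```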
- format_breakdown()[source]
Return per-target loss / weight / weighted / finite as a string.
One row per target currently in self._losses (populated by the most recent eager aggregate() call). Used by both summary() and torchref.utils.validate_loss() so the diagnostic format does not drift.
- to(*args, **kwargs)[source]
Move via DeviceMixin; honour an explicit device when no tensors exist yet.
- __init__(device=<factory>, targets=<factory>, weights=<factory>, history=<factory>, _losses=<factory>, _compilable=<factory>, _compiled_aggregate=None, _loss_leaves=<factory>, _resettable_modules=<factory>, meta=<factory>)
- class torchref.refinement.Logger(state, verbose=1, pattern='.*', _records=<factory>, _labels=<factory>)[source]
Bases: object
Refinement logging with verbosity-aware stat reporting.
Integrates with LossState to record, compare, and display refinement stats. Supports regex-based filtering of target names.
- Parameters:
Examples
Basic usage:
from torchref.refinement import LossState, Logger

state = LossState(device=device)
state.register_target("xray/work", xray_target)
state.register_target("geometry/bond", bond_target)

logger = Logger(state, verbose=1)

# Record before refinement
logger.record(label="before_xyz")

# ... run refinement ...

# Record after and compare
logger.record(label="after_xyz")
logger.compare(title="XYZ Refinement")
Filtering by pattern:
# Show only X-ray targets
logger.current(pattern="xray.*")

# Create a logger that only tracks geometry
geom_logger = Logger(state, pattern="geometry.*")
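How a regex pattern selects hierarchical target names can be tried out with the standard re module. The sketch assumes patterns are matched from the start of each name (re.match); whether Logger uses match, search, or fullmatch is not stated here, and filter_targets is a hypothetical helper.

```python
import re


def filter_targets(names, pattern=".*"):
    """Keep names whose beginning matches the regular expression."""
    regex = re.compile(pattern)
    return [name for name in names if regex.match(name)]


names = ["xray/work", "xray/test", "geometry/bond", "adp/simu"]
print(filter_targets(names, "xray.*"))
print(filter_targets(names, "geometry.*|adp.*"))
```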
- record(label=None)[source]
Record current refinement state.
Gathers stats from all targets in LossState, stores in history. Uses the instance’s pattern filter.
- compare(label_before=None, label_after=None, pattern=None, title='Refinement Comparison')[source]
Print comparison between two recorded states.
If labels not provided, compares last two records.
- Parameters:
label_before (str, optional) – Label of “before” state. Default: second-to-last record.
label_after (str, optional) – Label of “after” state. Default: last record.
pattern (str, optional) – Regex to filter which targets to display. Default: use instance pattern.
title (str, optional) – Title for the comparison output. Default: “Refinement Comparison”.
- current(pattern=None, title='Current State')[source]
Print the current refinement state.
Uses latest recorded state, or records new one if none exist.
- __init__(state, verbose=1, pattern='.*', _records=<factory>, _labels=<factory>)
- class torchref.refinement.Target(verbose=0, **kwargs)[source]
Bases: DeviceMixin, Module
Abstract base class for all target functions.
All tunable parameters should be registered as buffers using register_buffer() so they can be accessed/modified via state_dict notation.
Supports empty initialization for state_dict loading:
target = Target()  # Creates empty shell
target.load_state_dict(torch.load('target.pt'))
- LossState Integration:
Targets can work with LossState for the new pipeline:
state = target.add_to_state(state) # Adds loss to state
- Parameters:
verbose (int, optional) – Verbosity level. Default is 0.
- __init__(verbose=0, **kwargs)[source]
Initialize target.
- Parameters:
verbose (int, optional) – Verbosity level. Default is 0.
- add_to_state(state)[source]
Compute loss and add it to the LossState.
This method enables the new LossState pipeline pattern where targets receive a state object, compute their loss, add it to the state, and return the state for chaining.
- maintenance()[source]
Between-step housekeeping hook (no-op by default).
LossState calls this on every registered target after each successful outer optimizer step returns. Targets override this to rebuild stale internal state (VDW pair lists, solvent masks, etc.) based on how far parameters have drifted since the last refresh.
Contract
Must be idempotent: calling it multiple times in a row on an unchanged model should not mutate the target.
Fast path first: cheap staleness check up front, expensive rebuild only when strictly necessary. LossState calls this every outer step, so the happy-path cost is paid every time.
Must not raise on routine drift. If a rebuild fails, let the exception propagate; that’s a real bug.
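A maintenance hook that follows this contract might be shaped like the sketch below. It is a toy under stated assumptions, not a torchref target: PairListTarget, its one-dimensional positions, and the max_drift threshold are all hypothetical.

```python
class PairListTarget:
    """Hypothetical target that caches a neighbour list for 1-D positions."""

    def __init__(self, positions, cutoff=1.0, max_drift=0.5):
        self.positions = positions
        self.cutoff = cutoff
        self.max_drift = max_drift
        self.rebuilds = 0
        self._rebuild()

    def _rebuild(self):
        # Expensive path: recompute all pairs closer than the cutoff.
        pos = self.positions
        self.pairs = [
            (i, j)
            for i in range(len(pos))
            for j in range(i + 1, len(pos))
            if abs(pos[i] - pos[j]) < self.cutoff
        ]
        self._reference = list(pos)   # snapshot used by the drift check
        self.rebuilds += 1

    def maintenance(self):
        # Fast path: cheap drift check; rebuild only when it is exceeded.
        drift = max(abs(a - b) for a, b in zip(self.positions, self._reference))
        if drift > self.max_drift:
            self._rebuild()


target = PairListTarget([0.0, 0.8, 3.0])
target.maintenance()
target.maintenance()              # idempotent: nothing moved, no rebuild
target.positions[2] = 1.5         # atom 2 drifts by 1.5 > max_drift
target.maintenance()
print(target.rebuilds, target.pairs)
```

Repeated calls on an unchanged model hit only the drift check, which is the happy-path cost paid on every outer step.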
- class torchref.refinement.DataTarget(data=None, model=None, scaler=None, verbose=0, **kwargs)[source]
Bases: Target
Base class for targets that need ReflectionData and optionally Model/Scaler.
This class provides a flexible interface for X-ray targets that can work in two modes:
With Model: Computes F_calc from the model on each forward pass
Without Model: Uses pre-computed F_calc passed directly
This decoupling allows targets to be used for:
- Standard refinement (with model)
- Analysis/scoring of pre-computed structure factors (without model)
- Testing and validation workflows
All objects (model, data, scaler) are registered as proper submodules, allowing PyTorch to handle device movement and state_dict operations.
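The two modes can be pictured with a small dual-mode class. This is a hedged sketch, not DataTarget itself: DualModeTarget is hypothetical, lists of floats stand in for structure factors, and a plain sum of squared differences stands in for the real X-ray loss.

```python
class DualModeTarget:
    """Hypothetical target that accepts a model or pre-computed values."""

    def __init__(self, observed, model=None):
        self.observed = observed
        self.model = model      # optional callable returning calculated values

    def get_fcalc(self):
        if self.model is None:
            raise RuntimeError("no model set; pass fcalc explicitly")
        return self.model()

    def forward(self, fcalc=None):
        # Pre-computed values take precedence and skip the model entirely.
        calc = fcalc if fcalc is not None else self.get_fcalc()
        return sum((o - c) ** 2 for o, c in zip(self.observed, calc))


observed = [1.0, 2.0, 3.0]
with_model = DualModeTarget(observed, model=lambda: [1.0, 2.0, 2.0])
without_model = DualModeTarget(observed)
print(with_model.forward())                          # model supplies the values
print(without_model.forward(fcalc=[1.0, 1.0, 3.0]))  # values passed directly
```

Calling forward() with neither a model nor pre-computed values raises RuntimeError in this sketch, mirroring the documented behaviour of get_fcalc().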
- Parameters:
data (ReflectionData, optional) – Reference to the ReflectionData object. Required for forward().
model (Model or ModelFT, optional) – Reference to a Model object for F_calc computation. If None, F_calc must be provided to forward().
scaler (Scaler, optional) – Reference to the Scaler object for scaling F_calc.
verbose (int, optional) – Verbosity level. Default is 0.
target_value (float, optional) – Target value for this loss. Default is 0.0.
sigma (float, optional) – Sigma parameter for weighting. Default is 0.5.
- _data
Reference to the reflection data object (registered as submodule).
- Type:
- __init__(data=None, model=None, scaler=None, verbose=0, **kwargs)[source]
Initialize data target.
- Parameters:
data (ReflectionData, optional) – Reference to the ReflectionData object. Required for forward().
model (Model or ModelFT, optional) – Reference to Model object for F_calc computation. If None, F_calc must be provided when calling forward().
scaler (Scaler, optional) – Reference to the Scaler object.
verbose (int, optional) – Verbosity level. Default is 0.
- property data: ReflectionData
Access the reflection data object.
- get_fcalc(hkl=None, recalc=False)[source]
Compute structure factors from model.
- Parameters:
hkl (torch.Tensor, optional) – Miller indices. If None, uses data’s hkl.
recalc (bool, optional) – Force recalculation. Default is False.
- Returns:
Complex structure factors.
- Return type:
- Raises:
RuntimeError – If no model is set.
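A `recalc` flag of this kind usually implies that a previous result is reused unless recalculation is forced. The sketch below illustrates that pattern under that assumption. `CachedFcalc` and its `n_evaluations` counter are hypothetical, and this page does not state where torchref actually keeps any cached values.

```python
from typing import Callable, List, Optional


class CachedFcalc:
    """Hypothetical cache wrapper illustrating a `recalc` flag."""

    def __init__(self, model: Optional[Callable[[], List[complex]]] = None):
        self.model = model
        self._cache: Optional[List[complex]] = None
        self.n_evaluations = 0  # counts how often the model is evaluated

    def get_fcalc(self, recalc: bool = False) -> List[complex]:
        if self.model is None:
            raise RuntimeError("No model set.")
        # Evaluate the model only on first use or when forced.
        if recalc or self._cache is None:
            self._cache = self.model()
            self.n_evaluations += 1
        return self._cache


source = CachedFcalc(model=lambda: [1 + 1j, 2 - 1j])
source.get_fcalc()
source.get_fcalc()                  # served from the cache
count_cached = source.n_evaluations
source.get_fcalc(recalc=True)       # forces a second evaluation
count_forced = source.n_evaluations
```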
- get_fcalc_scaled(hkl=None, recalc=False, fcalc=None)[source]
Compute or scale structure factors.
- Parameters:
hkl (torch.Tensor, optional) – Miller indices. If None, uses data’s hkl.
recalc (bool, optional) – Force recalculation. Default is False.
fcalc (torch.Tensor, optional) – Pre-computed structure factors. If provided, skips model computation.
- Returns:
Scaled complex structure factors.
- Return type:
- get_F_calc_scaled(hkl=None, recalc=False, fcalc=None)[source]
Compute scaled structure factor amplitudes.
- Parameters:
hkl (torch.Tensor, optional) – Miller indices. If None, uses data’s hkl.
recalc (bool, optional) – Force recalculation. Default is False.
fcalc (torch.Tensor, optional) – Pre-computed structure factors. If provided, skips model computation.
- Returns:
Scaled structure factor amplitudes |F_calc|.
- Return type:
- get_rfactor()[source]
Compute R-factors using scaler.
- Returns:
(R_work, R_free) values.
- Return type:
- Raises:
RuntimeError – If no scaler is set.
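For orientation, the conventional crystallographic R-factor is R = Σ| |F_obs| − |F_calc| | / Σ|F_obs|, evaluated separately over the work and free (test) reflections. The sketch below computes both in plain Python. It assumes F_calc has already been scaled and that a flag of `True` marks a free reflection. The function names are illustrative, and torchref's scaler may compute these values differently.

```python
from typing import Sequence, Tuple, Union

Number = Union[float, complex]


def r_factor(f_obs: Sequence[float], f_calc: Sequence[Number]) -> float:
    """Conventional R = sum(| |F_obs| - |F_calc| |) / sum(|F_obs|)."""
    numerator = sum(abs(abs(fo) - abs(fc)) for fo, fc in zip(f_obs, f_calc))
    denominator = sum(abs(fo) for fo in f_obs)
    if denominator == 0:
        raise ValueError("Sum of |F_obs| is zero; R-factor is undefined.")
    return numerator / denominator


def r_work_free(
    f_obs: Sequence[float],
    f_calc: Sequence[Number],
    free_flags: Sequence[bool],
) -> Tuple[float, float]:
    """Split reflections by flag (assumed: True = free set) and score each."""
    triples = list(zip(f_obs, f_calc, free_flags))
    work = [(fo, fc) for fo, fc, free in triples if not free]
    test = [(fo, fc) for fo, fc, free in triples if free]
    r_work = r_factor([fo for fo, _ in work], [fc for _, fc in work])
    r_free = r_factor([fo for fo, _ in test], [fc for _, fc in test])
    return r_work, r_free


f_obs = [10.0, 20.0, 30.0, 40.0]
f_calc = [9.0, 22.0, 30.0, 36.0]          # already scaled (assumption)
free_flags = [False, False, False, True]

r_work, r_free = r_work_free(f_obs, f_calc, free_flags)
```

With these toy numbers the work set gives (1 + 2 + 0) / 60 = 0.05 and the single free reflection gives 4 / 40 = 0.1.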
- class torchref.refinement.ModelTarget(model=None, verbose=0, **kwargs)[source]
Bases:
Target
Base class for targets that only need a Model reference.
This class provides a simpler interface for geometry and ADP targets that don’t need access to reflection data or refinement machinery. Targets inherit from this class when they only need the atomic model.
The model is registered as a proper submodule, allowing PyTorch to handle device movement and state_dict operations automatically.
- Parameters:
- property restraints
Access model’s restraints (built lazily on first access).
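Lazy construction of this kind is commonly written with a cached property, so the expensive build runs once and later accesses reuse the result. The sketch below shows the idea only. `ToyModel`, `ToyModelTarget`, and the "bond consecutive atoms" rule are hypothetical and say nothing about how torchref builds its restraints.

```python
from functools import cached_property
from typing import List, Sequence, Tuple


class ToyModel:
    """Hypothetical model whose restraints are built on first access."""

    def __init__(self, atom_names: Sequence[str]):
        self.atom_names = list(atom_names)
        self.build_count = 0  # how many times restraints were built

    @cached_property
    def restraints(self) -> List[Tuple[str, str]]:
        self.build_count += 1
        # Illustrative rule: restrain each pair of consecutive atoms.
        return list(zip(self.atom_names, self.atom_names[1:]))


class ToyModelTarget:
    """Hypothetical target that forwards to the model's restraints."""

    def __init__(self, model: ToyModel):
        self.model = model

    @property
    def restraints(self) -> List[Tuple[str, str]]:
        return self.model.restraints


model = ToyModel(["N", "CA", "C"])
target = ToyModelTarget(model)

count_before = model.build_count   # nothing built yet
first = target.restraints          # triggers the build
second = target.restraints         # reuses the cached result
count_after = model.build_count
```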
Subpackages
- torchref.refinement.optimizers package
AdamWithAdaptiveNoise, ExploratoryLBFGS, LangevinSA, MomentumStochasticSA
- Submodules
- torchref.refinement.optimizers.adam_noise module
- torchref.refinement.optimizers.exploratory_lbfgs module
- torchref.refinement.optimizers.internal_coord_sa module
- torchref.refinement.optimizers.langevin_sa module
- torchref.refinement.optimizers.momentum_sa module
- torchref.refinement.optimizers.simulated_annealing module
- torchref.refinement.targets package
Target, ModelTarget, DataTarget, DataTarget.name, DataTarget._model, DataTarget._data, DataTarget._scaler, DataTarget.verbose, DataTarget.__init__(), DataTarget.model, DataTarget.data, DataTarget.scaler, DataTarget.has_model, DataTarget.get_fcalc(), DataTarget.get_fcalc_scaled(), DataTarget.get_F_calc_scaled(), DataTarget.get_rfactor()
gaussian_nll(), von_mises_nll(), adp_similarity_nll(), XrayTarget, GaussianXrayTarget, MaximumLikelihoodXrayTarget, LeastSquaresXrayTarget, create_xray_target(), DifferenceXrayTarget, DifferenceXrayTarget.name, DifferenceXrayTarget.__init__(), DifferenceXrayTarget.dataset_collection, DifferenceXrayTarget.data_light, DifferenceXrayTarget.data_dark, DifferenceXrayTarget.model_light, DifferenceXrayTarget.model_dark, DifferenceXrayTarget.scaler_light, DifferenceXrayTarget.scaler_dark, DifferenceXrayTarget.hkl, DifferenceXrayTarget.get_delta_F_obs(), DifferenceXrayTarget.get_delta_F_calc(), DifferenceXrayTarget.forward(), DifferenceXrayTarget.stats()
PhaseInformedDifferenceTarget, RiceDifferenceTarget, TaylorCorrectedDifferenceTarget, GeometryTarget, BondTarget, AngleTarget, TorsionTarget, PlanarityTarget, ChiralTarget, NonBondedTarget, NonBondedHTarget, RamachandranTarget, ADPTarget, ADPSimilarityTarget, RigidBondTarget, ADPEntropyTarget, ADPLocalityTarget, CombinedTargets, CombinedTargets._targets, CombinedTargets.__init__(), CombinedTargets.targets(), CombinedTargets.__getitem__(), CombinedTargets.__contains__(), CombinedTargets.keys(), CombinedTargets.values(), CombinedTargets.items(), CombinedTargets.target_losses(), CombinedTargets.forward(), CombinedTargets.stats(), CombinedTargets.get(), CombinedTargets.add_to_state()
TotalGeometryTarget, TotalADPTarget, ForceFieldTarget, AmberTarget, OccupancyFloorDiagnostic, NegativeDensityPenalty, DisplacementRegularizer, DifferenceAmplitudeRegularizer, SampledMLPhaseTarget, SampledMLPhaseTarget.name, SampledMLPhaseTarget.__init__(), SampledMLPhaseTarget.n_samples, SampledMLPhaseTarget.sigma_model_log, SampledMLPhaseTarget.use_analytical, SampledMLPhaseTarget.use_antithetic, SampledMLPhaseTarget.french_wilson_moments(), SampledMLPhaseTarget.compute_sigma_phi(), SampledMLPhaseTarget.forward(), SampledMLPhaseTarget.stats()
SampledMLDifferenceTarget, create_sampled_ml_target(), create_sampled_ml_difference_target(), RealSpaceTarget, RealSpaceCorrelationTarget, RealSpaceDifferenceTarget, RealSpaceExtrapolatedTarget, CoordinateSimilarityTarget
- Subpackages
- Submodules
- torchref.refinement.targets.amber_target module
- torchref.refinement.targets.base module
- torchref.refinement.targets.combined module
- torchref.refinement.targets.difference module
- torchref.refinement.targets.forcefield_target module
- torchref.refinement.targets.occupancy_floor_diagnostic module
- torchref.refinement.targets.realspace module
- torchref.refinement.targets.sampled_ml_phase_target module
- torchref.refinement.targets.similarity module
- torchref.refinement.weighting package
BaseWeighting, WeightingScheme, ResolutionWeighting, OverfittingWeighting, ManualWeighting, ComponentWeighting, ComponentWeighting.schemes, ComponentWeighting.__init__(), ComponentWeighting.__getitem__(), ComponentWeighting.__contains__(), ComponentWeighting.keys(), ComponentWeighting.values(), ComponentWeighting.items(), ComponentWeighting.add_scheme(), ComponentWeighting.forward(), ComponentWeighting.total_loss_from_state(), ComponentWeighting.stats()
- Submodules
Submodules
- torchref.refinement.base_refinement module
Refinement, Refinement.device, Refinement.verbose, Refinement.reflection_data, Refinement.model, Refinement.scaler, Refinement.weighter, Refinement.__init__(), Refinement.set_xray_target_mode(), Refinement.data, Refinement.loss_state, Refinement.logger, Refinement.reset_loss_state(), Refinement.get_scales(), Refinement.setup_scaler(), Refinement.parameters(), Refinement.get_fcalc(), Refinement.get_fcalc_scaled(), Refinement.adp_loss(), Refinement.get_F_calc(), Refinement.get_F_calc_scaled(), Refinement.nll_xray(), Refinement.xray_loss_work(), Refinement.xray_loss_test(), Refinement.bond_loss(), Refinement.angle_loss(), Refinement.torsion_loss(), Refinement.geometry_loss(), Refinement.loss(), Refinement.setup_component_weighting(), Refinement.populate_state_meta(), Refinement.update_weights(), Refinement.create_loss_state(), Refinement.complete_loss_state(), Refinement.xray_loss(), Refinement.restraints_loss(), Refinement.collect_metrics(), Refinement.add_target_info_to_state(), Refinement.get_rfactor(), Refinement.update_outliers(), Refinement.plot_fcalc_vs_fobs(), Refinement.write_out_mtz(), Refinement.collect_deposition_metadata(), Refinement.write_out_pdb(), Refinement.write_out_cif(), Refinement.save_state(), Refinement.load_state(), Refinement.create_from_state_dict()
- torchref.refinement.lbfgs_refinement module
LBFGSRefinement, LBFGSRefinement.target_mode, LBFGSRefinement.LBFGS_DEFAULTS, LBFGSRefinement.__init__(), LBFGSRefinement.xray_loss(), LBFGSRefinement.refine_scaler(), LBFGSRefinement.refine_xyz(), LBFGSRefinement.refine_adp(), LBFGSRefinement.refine_joint(), LBFGSRefinement.run_training_trajectory(), LBFGSRefinement.run_training_trajectory_joint(), LBFGSRefinement.refine(), LBFGSRefinement.refine_everything()
- torchref.refinement.logger module
- torchref.refinement.loss_state module
LossState, LossState.device, LossState.targets, LossState.weights, LossState.history, LossState.meta, LossState.__getitem__(), LossState.__contains__(), LossState.get(), LossState.cache_losses(), LossState.update_meta(), LossState.register_target(), LossState.register_targets(), LossState.set_weight(), LossState.set_weights(), LossState.get_weight(), LossState.get_effective_weight(), LossState.mark_compilable(), LossState.compile_aggregate(), LossState.reset_compiled_aggregate(), LossState.log(), LossState.new_entry(), LossState.get_history(), LossState.aggregate(), LossState.get_loss(), LossState.active_parameters(), LossState.refresh_loss_leaves(), LossState.reset_caches(), LossState.restore_loss_leaf_grads(), LossState.run(), LossState.step(), LossState.get_breakdown(), LossState.get_group_totals(), LossState.format_breakdown(), LossState.summary(), LossState.to(), LossState.clear(), LossState.clear_history(), LossState.__init__()
create_loss_state()