torchref.utils package
Utility functions and classes for TorchRef.
This module provides:
- TensorMasks and TensorDict for managing tensor collections
- Debugging utilities and mixins
- Statistics formatting and tracking
- Hyperparameter management
- Gradient norm computation
- PDB/selection parsing utilities
Example
from torchref.utils import TensorMasks, DebugMixin, gradnorm
# Create tensor masks for parameter selection
masks = TensorMasks()
masks['backbone'] = backbone_mask
# Use debugging mixin in your class
class MyRefinement(DebugMixin):
    pass
# Compute gradient norm
grad_norm = gradnorm(loss, model.parameters())
- class torchref.utils.ParameterFingerprint(params=())[source]
Bases: object
Lightweight fingerprint for detecting parameter changes.
Captures (data_ptr, _version, numel) per tensor. Comparison is O(n_params) integer comparisons — much cheaper than SHA-1 hashing.
- class torchref.utils.CachedForwardMixin[source]
Bases: object
Mixin that caches forward() results with automatic invalidation.
Overrides __call__ to return a cached result when the module’s parameters, buffers, and call arguments have not changed since the last invocation, and no backward pass has propagated through the cached output.
Cache invalidation triggers:
Any parameter or buffer data_ptr or _version change (covers optimizer in-place updates and mask/parameter replacement).
Input tensor data_ptr or _version change, or non-tensor argument value change.
A backward pass through the cached output (increments generation counter via a gradient hook).
The cached tensor retains its autograd graph, so gradients flow correctly on the first backward pass, after which the cache is invalidated.
- class torchref.utils.DeviceMixin[source]
Bases: object
Unified device/dtype movement.
Inherit alongside nn.Module (place before nn.Module in the MRO):
class Foo(DeviceMixin, nn.Module): ...
Or use on a plain Python class / dataclass:
@dataclass
class Bar(DeviceMixin):
    data: torch.Tensor
All of .to(), .cuda(), .cpu(), .float(), .double(), .half() route through _apply(), which:
invokes nn.Module._apply when applicable so parameters, buffers and child modules are moved by the standard PyTorch path,
walks self.__dict__ to pick up plain tensor attributes, nested containers and non-Module sub-objects,
calls reset_forward_cache() and reset_cache() if either is defined.
- torchref.utils.DeviceMovementMixin
alias of DeviceMixin
- torchref.utils.resolve_device(*modules, device=None)[source]
Resolve a single device from N device-bearing modules.
Each module must expose .device and accept .to(device) (satisfied by torch.nn.Module and by torchref.utils.DeviceMixin non-Module subclasses such as Cell). None entries are skipped silently so empty-init paths can pass through optional submodules: resolve_device(model, data) works whether or not data is None.
Resolution order
1. If device is given, every non-None module is moved to it and it is returned. No warning is emitted (the caller has made an explicit choice).
2. Otherwise, after dropping None entries, if no modules remain, torchref.config.get_default_device() is returned.
3. The first remaining module’s device is the target. Any other module on a different device is moved to the target and a UserWarning is emitted once for the call.
The “first module wins” rule is intentional: callers express precedence by argument order.
- param *modules:
Device-bearing modules. None entries are skipped.
- param device:
Explicit override. If provided, all non-None modules are moved to it and it is returned.
- type device:
torch.device or str, optional
- returns:
The resolved device.
- rtype:
torch.device
Examples
Empty call returns the configured default:
>>> resolve_device()
device(type='cpu')
Explicit override moves everything:
>>> resolve_device(model, data, device='cpu')
device(type='cpu')
Auto-reconcile with first-wins precedence:
>>> resolve_device(cuda_model, cpu_data)
device(type='cuda')
# cpu_data has been moved to cuda
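The resolution order above can be sketched in plain Python. This is an illustrative sketch, not TorchRef's implementation: `Box` is a hypothetical stand-in for a device-bearing module, devices are plain strings rather than `torch.device` objects, and the `default` argument stands in for the configured default device.

```python
import warnings


class Box:
    """Hypothetical stand-in for a module exposing .device and .to(device)."""

    def __init__(self, device):
        self.device = device

    def to(self, device):
        self.device = device
        return self


def resolve_device_sketch(*modules, device=None, default="cpu"):
    """Follow the documented resolution order (sketch only)."""
    present = [m for m in modules if m is not None]  # None entries are skipped
    if device is not None:
        # 1. Explicit override: move everything, emit no warning.
        for m in present:
            m.to(device)
        return device
    if not present:
        # 2. Nothing to inspect: fall back to the default device.
        return default
    # 3. First module wins; mismatched modules are moved, one warning per call.
    target = present[0].device
    moved = [m for m in present[1:] if m.device != target]
    for m in moved:
        m.to(target)
    if moved:
        warnings.warn(f"moved {len(moved)} module(s) to {target}", UserWarning)
    return target


model, data = Box("cuda"), Box("cpu")
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    resolved = resolve_device_sketch(model, None, data)
```

Because precedence follows argument order, swapping `model` and `data` in the call would instead move `model` to the CPU.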
- class torchref.utils.TensorMasks(data=None, device=None)[source]
Bases: DeviceMixin, dict
A dictionary for managing boolean mask tensors with device support.
This is a lightweight dict subclass that:
- Ensures all tensors are boolean dtype
- Supports device movement via to(), cuda(), cpu()
- Provides combined mask via __call__()
- Parameters:
data (dict, optional) – Initial mask data.
device (str or torch.device, optional) – Device for tensors. Defaults to the configured device.current.
Examples
masks = TensorMasks(device='cuda')
masks['valid'] = torch.ones(100, dtype=torch.bool)
masks['rfree'] = rfree_flags > 0
combined = masks()  # Get combined mask (AND of all)
masks.cpu()  # Move all to CPU
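To make "AND of all" concrete, here is a minimal sketch of the combination step using plain Python lists instead of boolean tensors. The function name `combine_masks` and the behaviour for an empty collection are assumptions made for illustration.

```python
def combine_masks(masks):
    """Element-wise AND over every mask stored in the mapping."""
    if not masks:
        return []  # assumption: an empty collection yields an empty mask
    return [all(flags) for flags in zip(*masks.values())]


masks = {
    "valid": [True, True, False, True],
    "rfree": [True, False, False, True],
}
combined = combine_masks(masks)
```

An element survives only if every named mask keeps it, so adding a mask can only shrink the combined selection.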
- class torchref.utils.TensorDict(initial_dict=None)[source]
Bases: Module
A dictionary-like container for PyTorch tensors that:
- Supports standard dict syntax
- Automatically moves with the module
- Registers tensors as buffers so they are included in state_dict
- class torchref.utils.ModuleReference(module)[source]
Bases: object
A wrapper class to hold references to PyTorch modules without registering them.
When you assign an nn.Module to an attribute of another nn.Module, PyTorch automatically registers it as a submodule, which adds its parameters to the parent’s parameter tree. This wrapper prevents that automatic registration.
This is useful when you want to:
Hold references to modules without including their parameters
Avoid circular dependencies in the module tree
Reference external modules that should be managed separately
- _wrapped_module
The wrapped PyTorch module.
- Type:
Examples
model = MyModel()
scaler = Scaler()
scaler._model = ModuleReference(model)  # Won't register as submodule
# Access the module via .module property
output = scaler._model.module(input_data)
- __init__(module)[source]
Wrap a module to prevent automatic registration.
- Parameters:
module (torch.nn.Module) – The PyTorch module to wrap.
- property module
Access the wrapped module.
- torchref.utils.save_map(array, cell, filename)[source]
Save a 3D map to a CCP4 file.
- Parameters:
array (numpy.ndarray or torch.Tensor) – 3D array representing the map.
cell (list, tuple, numpy.ndarray, torch.Tensor, or gemmi.UnitCell) – Unit cell parameters [a, b, c, alpha, beta, gamma].
filename (str) – Output CCP4 file name.
- Returns:
True if save was successful.
- Return type:
bool
- torchref.utils.sanitize_pdb_dataframe(pdb, verbose=0)[source]
Sanitize a PDB DataFrame to ensure unique atom identifiers.
This function fixes common issues in PDB/CIF files:
HETATM records (especially waters) with duplicate resseq values (e.g., all 0)
Residue names longer than 3 characters (truncates to 3)
Ensures unique (chainid, resseq, name, altloc) combinations
- Parameters:
pdb (pandas.DataFrame) – DataFrame with PDB data (must have columns: ATOM, chainid, resseq, name, altloc, resname, serial).
verbose (int, default 0) – Verbosity level (0=silent, 1=info, 2=debug).
- Returns:
Sanitized DataFrame with unique atom identifiers.
- Return type:
pandas.DataFrame
Examples
from torchref.model import Model
from torchref.utils import sanitize_pdb_dataframe
model = Model()
model.load_cif('structure.cif')
model.pdb = sanitize_pdb_dataframe(model.pdb, verbose=1)
- torchref.utils.parse_phenix_selection(selection_string, pdb_df)[source]
Parse Phenix-style atom selection syntax and return a boolean mask.
Supports common Phenix selection keywords:
chain <id>: Select atoms by chain ID (e.g., “chain A”)
resseq <num>: Select atoms by residue sequence number (e.g., “resseq 10”)
resseq <start>:<end>: Select residue range (e.g., “resseq 10:20”)
resname <name>: Select atoms by residue name (e.g., “resname ALA”)
name <atom>: Select atoms by atom name (e.g., “name CA”)
element <elem>: Select atoms by element (e.g., “element C”)
altloc <id>: Select atoms by alternate location (e.g., “altloc A”)
all: Select all atoms
not <selection>: Negate selection
<sel1> and <sel2>: Intersection of selections
<sel1> or <sel2>: Union of selections
Parentheses for grouping: (selection)
- Parameters:
selection_string (str) – Phenix-style selection string.
pdb_df (pandas.DataFrame) – DataFrame containing atomic data with columns: ‘chainid’, ‘resseq’, ‘resname’, ‘name’, ‘element’, ‘altloc’.
- Returns:
Boolean tensor of shape (n_atoms,) where True indicates selected atoms.
- Return type:
torch.Tensor
- Raises:
ValueError – If selection syntax is invalid.
Examples
# Select chain A
mask = parse_phenix_selection("chain A", pdb_df)
# Select residues 10-20 in chain A
mask = parse_phenix_selection("chain A and resseq 10:20", pdb_df)
# Select all CA atoms
mask = parse_phenix_selection("name CA", pdb_df)
# Select backbone atoms
mask = parse_phenix_selection("name CA or name C or name N or name O", pdb_df)
# Select everything except water
mask = parse_phenix_selection("not resname HOH", pdb_df)
# Use parentheses for grouping
mask = parse_phenix_selection("chain A and (name CA or name CB)", pdb_df)
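For readers curious how such a grammar can be evaluated, the following is a small recursive-descent sketch over a list of atom dictionaries. It is not TorchRef's parser: it returns a Python list rather than a tensor, matches lowercase keywords only, and assumes the usual precedence (not binds tighter than and, which binds tighter than or). The name `parse_selection` is hypothetical.

```python
# Maps selection keywords to the column they test.
KEYWORDS = {"chain": "chainid", "resname": "resname", "name": "name",
            "element": "element", "altloc": "altloc"}


def parse_selection(selection, atoms):
    """Evaluate a selection string against a list of atom dicts."""
    tokens = selection.replace("(", " ( ").replace(")", " ) ").split()
    pos = 0

    def peek():
        return tokens[pos] if pos < len(tokens) else None

    def take():
        nonlocal pos
        if pos >= len(tokens):
            raise ValueError("unexpected end of selection")
        pos += 1
        return tokens[pos - 1]

    def parse_or():
        left = parse_and()
        while peek() == "or":
            take()
            right = parse_and()
            left = (lambda a, l=left, r=right: l(a) or r(a))
        return left

    def parse_and():
        left = parse_not()
        while peek() == "and":
            take()
            right = parse_not()
            left = (lambda a, l=left, r=right: l(a) and r(a))
        return left

    def parse_not():
        if peek() == "not":
            take()
            inner = parse_not()
            return lambda a: not inner(a)
        return parse_atom()

    def parse_atom():
        tok = take()
        if tok == "(":
            inner = parse_or()
            if take() != ")":
                raise ValueError("missing closing parenthesis")
            return inner
        if tok == "all":
            return lambda a: True
        if tok == "resseq":
            lo, _, hi = take().partition(":")  # "10" or "10:20"
            lo, hi = int(lo), int(hi or lo)
            return lambda a: lo <= a["resseq"] <= hi
        if tok in KEYWORDS:
            column, value = KEYWORDS[tok], take()
            return lambda a: a[column] == value
        raise ValueError(f"unknown keyword: {tok!r}")

    predicate = parse_or()
    if pos != len(tokens):
        raise ValueError(f"unexpected token: {tokens[pos]!r}")
    return [predicate(a) for a in atoms]


atoms = [
    {"chainid": "A", "resseq": 10, "resname": "ALA", "name": "CA", "element": "C", "altloc": ""},
    {"chainid": "A", "resseq": 11, "resname": "GLY", "name": "N", "element": "N", "altloc": ""},
    {"chainid": "B", "resseq": 5, "resname": "HOH", "name": "O", "element": "O", "altloc": ""},
]
mask = parse_selection("chain A and (name CA or name CB)", atoms)
```

Each grammar level gets its own function, so operator precedence falls out of the call order and parentheses simply restart at the lowest-precedence level.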
- torchref.utils.create_selection_mask(selection_string, pdb_df, current_mask=None, mode='set')[source]
Create or modify a refinable mask based on a Phenix-style selection.
This function allows you to update refinable masks by selecting specific atoms using Phenix-style syntax. You can either replace the current mask, add to it, or remove from it.
- Parameters:
selection_string (str) – Phenix-style selection string.
pdb_df (pandas.DataFrame) – DataFrame containing atomic data.
current_mask (torch.Tensor, optional) – Current refinable mask. If None, starts with all False.
mode (str, default 'set') –
How to combine with current mask:
‘set’: Replace mask with selection (default)
‘add’: Add selection to current mask (OR operation)
‘remove’: Remove selection from current mask (AND NOT operation)
- Returns:
Updated boolean mask of shape (n_atoms,).
- Return type:
torch.Tensor
- Raises:
ValueError – If mode is not one of ‘set’, ‘add’, ‘remove’.
Examples
# Create new mask selecting chain A
mask = create_selection_mask("chain A", pdb_df, mode='set')
# Add residues 10-20 to existing mask
mask = create_selection_mask("resseq 10:20", pdb_df, current_mask=mask, mode='add')
# Remove water from mask
mask = create_selection_mask("resname HOH", pdb_df, current_mask=mask, mode='remove')
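The three modes reduce to simple boolean algebra. A minimal sketch with Python lists standing in for boolean tensors follows; `update_mask` is a hypothetical name, and the selection is passed in already evaluated rather than parsed from a string.

```python
def update_mask(selection, current=None, mode="set"):
    """Combine an evaluated selection with an existing mask."""
    if current is None:
        current = [False] * len(selection)  # documented default: all False
    if mode == "set":
        return list(selection)                                    # replace
    if mode == "add":
        return [c or s for c, s in zip(current, selection)]       # OR
    if mode == "remove":
        return [c and not s for c, s in zip(current, selection)]  # AND NOT
    raise ValueError(f"mode must be 'set', 'add' or 'remove', got {mode!r}")


chain_a = [True, True, False, False]
waters = [False, True, False, True]

mask = update_mask(chain_a)                      # start from chain A
mask = update_mask(waters, mask, mode="add")     # union with waters
mask = update_mask(waters, mask, mode="remove")  # then drop the waters again
```

Note that 'remove' also clears atoms that were selected before the 'add', which is why the second atom ends up deselected here.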
- class torchref.utils.DebugMixin[source]
Bases: object
Mixin class that adds debugging capabilities to modules.
When an error occurs, call print_debug_summary() to get a comprehensive overview of the module’s state including:
All attributes and their types
Tensor shapes, dtypes, and devices
DataFrame/array shapes
Other object information
- print_debug_summary(title=None, file=sys.stderr)[source]
Print a comprehensive debug summary of this module’s state.
- Parameters:
title (str, optional) – Title for the summary.
file (file-like, default sys.stderr) – File to write output to.
- torchref.utils.print_module_summary(module, title=None, file=sys.stderr)[source]
Print debug summary for any module.
Standalone function to print debug information for modules that may or may not have the DebugMixin.
- class torchref.utils.StatEntry(value, verbosity=1)[source]
Bases: object
A statistics entry with value and verbosity level.
JSON serializable - when serialized, only the value is written.
- value
The statistic value.
- Type:
Any
- __init__(value, verbosity=1)
- torchref.utils.filter_stats(stats, max_verbosity)[source]
Filter stats dictionary to only include entries at or below max_verbosity.
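As a rough illustration of verbosity filtering, the sketch below keeps entries whose level is at or below the threshold. `Entry` is a hypothetical stand-in for StatEntry, and the sketch assumes a flat dictionary in which every value is such an entry. How the real function treats nested dictionaries or bare values is not stated here.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class Entry:
    """Hypothetical stand-in for StatEntry: a value plus a verbosity level."""
    value: Any
    verbosity: int = 1


def filter_by_verbosity(stats, max_verbosity):
    """Keep only entries whose verbosity is at or below max_verbosity."""
    return {key: e for key, e in stats.items() if e.verbosity <= max_verbosity}


stats = {
    "rwork": Entry(0.21, verbosity=0),
    "rfree": Entry(0.25, verbosity=1),
    "per_shell": Entry([0.2, 0.3], verbosity=2),
}
shown = filter_by_verbosity(stats, max_verbosity=1)
```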
- torchref.utils.flatten_stats(stats, prefix='')[source]
Flatten nested stats dict into flat dict with dotted keys.
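A minimal sketch of dotted-key flattening, written independently of TorchRef. The way `prefix` is joined to the first key is an assumption, and the name `flatten` is illustrative only.

```python
def flatten(stats, prefix=""):
    """Flatten nested dicts into one level, joining keys with dots."""
    flat = {}
    for key, value in stats.items():
        name = f"{prefix}.{key}" if prefix else str(key)
        if isinstance(value, dict):
            flat.update(flatten(value, name))  # recurse with the extended prefix
        else:
            flat[name] = value
    return flat


nested = {"xray": {"rwork": 0.21, "rfree": 0.25}, "step": 3}
flat = flatten(nested)
```

Flat keys of this form are convenient for CSV columns or for logging backends that expect scalar-valued metrics.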
- torchref.utils.format_stats_table(stats, title='', indent=2)[source]
Format stats dictionary as a printable table.
- class torchref.utils.HyperparameterMixin[source]
Bases: object
Mixin class providing hyperparameter registration and tracking.
Use this with nn.Module to add hyperparameter tracking capabilities. Hyperparameters are stored as buffers but tracked separately.
Examples
class MyTarget(HyperparameterMixin, nn.Module):
    def __init__(self, sigma=1.0, target_value=0.0):
        nn.Module.__init__(self)
        HyperparameterMixin.__init__(self)
        self.register_hyperparameter('sigma', sigma)
        self.register_hyperparameter('target_value', target_value)

target = MyTarget(sigma=2.5)
dict(target.hyperparameters())
# {'sigma': tensor(2.5), 'target_value': tensor(0.0)}
- register_hyperparameter(name, value, persistent=True)[source]
Register a hyperparameter.
Hyperparameters are stored as buffers (so they move with .to(device)) but tracked separately for easy access and state_dict operations.
- Parameters:
Examples
self.register_hyperparameter('sigma', 2.0)
self.sigma  # Access via property (if defined) or _hp_sigma
# 2.0
- get_hyperparameter(name)[source]
Get a hyperparameter value.
- Parameters:
name (str) – Name of the hyperparameter.
- Returns:
The hyperparameter tensor.
- Return type:
- hyperparameters(recurse=True)[source]
Return an iterator over module hyperparameters.
- Parameters:
recurse (bool, optional) – If True, include hyperparameters from submodules. Default is True.
- Yields:
tuple – (name, tensor) pairs for each hyperparameter.
Examples
for name, hp in module.hyperparameters():
    print(f"{name}: {hp.item()}")
- named_hyperparameters(prefix='', recurse=True)[source]
Return an iterator over module hyperparameters with names.
Same as hyperparameters() but allows prefix specification.
- hyperparameter_state_dict(prefix='')[source]
Return a dictionary containing only hyperparameters.
Unlike state_dict() which includes all buffers and parameters, this only returns the hyperparameters.
- Parameters:
prefix (str, optional) – Prefix to prepend to names. Default is ‘’.
- Returns:
Dictionary mapping hyperparameter names to tensors.
- Return type:
Examples
hp_state = module.hyperparameter_state_dict()
torch.save(hp_state, 'hyperparameters.pt')
- load_hyperparameter_state_dict(state_dict, strict=True)[source]
Load hyperparameters from a state dict.
- Parameters:
Examples
hp_state = torch.load('hyperparameters.pt')
module.load_hyperparameter_state_dict(hp_state)
- hyperparameter_dict()[source]
Return hyperparameters as a simple Python dict of floats.
Useful for logging, serialization to JSON, etc.
- Returns:
Dictionary mapping hyperparameter names to float values.
- Return type:
Examples
import json
params = module.hyperparameter_dict()
json.dumps(params)  # JSON serializable
- torchref.utils.convert_to_serializable(obj)[source]
Convert tensors and numpy arrays to JSON-serializable types.
Recursively walks dicts, lists, and tuples, converting torch.Tensor, numpy.ndarray, and numpy scalar types to plain Python objects that json.dump can handle.
- torchref.utils.gradnorm(loss, parameters)[source]
Compute the gradient norm of a loss with respect to given parameters.
Performs a backward pass with graph retention and computes the RMS (root mean square) of all gradients concatenated together.
- Parameters:
loss (torch.Tensor) – The loss tensor to backpropagate.
parameters (iterable) – Iterable of model parameters (typically from model.parameters()).
- Returns:
The computed RMS gradient norm.
- Return type:
Notes
Uses retain_graph=True to allow subsequent backward passes. Only includes parameters that have gradients (skips None grads).
Examples
loss = model(input)
grad_norm = gradnorm(loss, model.parameters())
print(f"Gradient norm: {grad_norm:.4f}")
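The RMS reduction itself is the square root of the mean squared gradient entry, taken over all parameters at once. A small sketch with nested Python lists standing in for gradient tensors follows; the function name and the zero returned when no gradients exist are assumptions.

```python
import math


def rms_gradient_norm(grads):
    """RMS over all gradient entries concatenated; None gradients are skipped."""
    flat = [g for grad in grads if grad is not None for g in grad]
    if not flat:
        return 0.0  # assumption: report zero when nothing has a gradient
    return math.sqrt(sum(g * g for g in flat) / len(flat))


# Three "parameters": one with gradient [3, 4], one without, one with zeros.
grads = [[3.0, 4.0], None, [0.0, 0.0]]
norm = rms_gradient_norm(grads)  # sqrt((9 + 16) / 4) = 2.5
```

Unlike the Euclidean (L2) norm, which would be 5.0 here, the RMS divides by the number of entries, so values stay comparable between models of different size.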
- torchref.utils.validate_loss(loss, *, state=None, parameters=None, check_grads=True, context='', raise_on_fail=True, max_full_diagnostics=3)[source]
Check that loss (and optionally grads / parameters) are finite.
- Parameters:
loss (torch.Tensor) – The scalar loss returned by the closure. Must be a zero-dim or one-element tensor.
state (LossState, optional) – If provided, the diagnostic path re-runs state.aggregate(log_values=True) to repopulate per-target losses and formats them via state.format_breakdown(). Safe to omit for closures that don’t use a LossState (e.g. scalers, alignment).
parameters (iterable of torch.Tensor, optional) – Parameters to inspect. When check_grads=True, their gradients are checked for finiteness on the fast path. On failure, both parameters and their grads are reported with non-finite entry counts.
check_grads (bool, default True) – Check parameter gradients for finiteness after backward(). This is the usual pathology (backward produces NaN even when forward was finite), so leave it on unless a hot benchmark proves it’s costly.
context (str, default "") – Short label written into the diagnostic header and returned in the warning / exception message (e.g. "collection_difference_refine"). Also keys the per-context diagnostic budget.
raise_on_fail (bool, default True) – If True, raise NonFiniteLossError on failure (strict mode). If False, print a warning and return False; the caller is then responsible for rejecting the LBFGS step (e.g. by zeroing grads and returning +inf so strong-Wolfe backtracks).
max_full_diagnostics (int, default 3) – Per-context budget for the full per-target breakdown. After this many failures in the same context, only a compact one-line warning is printed. Prevents log flooding when LBFGS bounces around a persistent NaN region. Pass 0 to always print compact.
- Returns:
True if everything is finite (happy path), False otherwise. When raise_on_fail=True, a False result raises instead of returning.
- Return type:
bool
- Raises:
NonFiniteLossError – If raise_on_fail=True and any of loss / grads / params is non-finite.
- exception torchref.utils.NonFiniteLossError[source]
Bases: RuntimeError
Raised when a refinement step produces non-finite loss, grads, or params.
- torchref.utils.reset_diagnostic_budget(context=None)[source]
Reset the failure counter used to stride full diagnostics.
- Parameters:
context (str, optional) – Reset a single context’s counter. If omitted, reset all.
- torchref.utils.collect_loss_leaves(losses)[source]
Return the set of leaf nn.Parameter objects that gradient will accumulate into when backward() is called on the given loss(es).
Walks the autograd graph from each root tensor’s grad_fn and finds every AccumulateGrad node, collecting its .variable when it is an nn.Parameter.
Multiple roots are unioned via a single shared traversal so that shared subgraphs (e.g. two losses both depending on the same model forward) are walked exactly once.
- Parameters:
losses (Tensor | Iterable[Tensor] | Mapping[str, Tensor]) – One or more loss tensors.
- Returns:
Leaf parameters that backward would accumulate gradient into. A leaf with requires_grad=False does not appear (no AccumulateGrad node is created for it). Detached subtrees contribute nothing.
- Return type:
set of nn.Parameter
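The "single shared traversal" idea can be shown on a generic graph. In the sketch below a dictionary of child lists stands in for the autograd graph, where the edges would instead come from each node's `next_functions`. The node names and the `collect_leaves` helper are hypothetical.

```python
def collect_leaves(roots, children):
    """Union of leaves reachable from any root, visiting each node once."""
    seen, leaves = set(), set()
    stack = list(roots)  # every root shares one stack and one visited set
    while stack:
        node = stack.pop()
        if node in seen:
            continue  # shared subgraph already walked via another root
        seen.add(node)
        kids = children.get(node, ())
        if not kids:
            leaves.add(node)
        stack.extend(kids)
    return leaves


graph = {
    "loss_a": ["forward"],
    "loss_b": ["forward", "restraint"],
    "forward": ["w1", "w2"],
    "restraint": ["w2"],
}
leaves = collect_leaves(["loss_a", "loss_b"], graph)
```

Because the visited set is shared, the `forward` subgraph is expanded once even though both losses depend on it.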
Submodules
- torchref.utils.autograd_introspection module
- torchref.utils.autograd_ops module
- torchref.utils.caching module
- torchref.utils.debug_utils module
- torchref.utils.device_mixin module
- torchref.utils.device_resolution module
- torchref.utils.gradnorm module
- torchref.utils.hyperparameters module
HyperparameterMixin
HyperparameterMixin.__init__()
HyperparameterMixin.register_hyperparameter()
HyperparameterMixin.get_hyperparameter()
HyperparameterMixin.set_hyperparameter()
HyperparameterMixin.hyperparameters()
HyperparameterMixin.named_hyperparameters()
HyperparameterMixin.hyperparameter_state_dict()
HyperparameterMixin.load_hyperparameter_state_dict()
HyperparameterMixin.hyperparameter_dict()
HyperparameterMixin.print_hyperparameters()
create_hyperparameter_property()
- torchref.utils.loss_validation module
- torchref.utils.pse module
- torchref.utils.serialization module
- torchref.utils.stats module
- torchref.utils.timing module
- torchref.utils.utils module