torchref.utils package

Utility functions and classes for TorchRef.

This module provides:

  • TensorMasks and TensorDict for managing tensor collections

  • Debugging utilities and mixins

  • Statistics formatting and tracking

  • Hyperparameter management

  • Gradient norm computation

  • PDB/selection parsing utilities

Example

from torchref.utils import TensorMasks, DebugMixin, gradnorm

# Create tensor masks for parameter selection
masks = TensorMasks()
masks['backbone'] = backbone_mask

# Use debugging mixin in your class
class MyRefinement(DebugMixin):
    pass

# Compute gradient norm
grad_norm = gradnorm(loss, model.parameters())
class torchref.utils.ParameterFingerprint(params=())[source]

Bases: object

Lightweight fingerprint for detecting parameter changes.

Captures (data_ptr, _version, numel) per tensor. Comparison is O(n_params) integer comparisons — much cheaper than SHA-1 hashing.

__init__(params=())[source]
matches(params)[source]

Return True if params have the same fingerprint.
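
The fingerprint idea can be sketched with plain PyTorch. This is only an illustration under the assumption that the triple described above is all that is compared; the helper name `fingerprint` is hypothetical and is not part of torchref.

```python
import torch

def fingerprint(params):
    """Hypothetical helper: one (data_ptr, _version, numel) triple per tensor."""
    return tuple((p.data_ptr(), p._version, p.numel()) for p in params)

w = torch.zeros(4)
before = fingerprint([w])

unchanged = fingerprint([w]) == before   # nothing has touched the tensor
w.add_(1.0)                              # an in-place update bumps _version
changed = fingerprint([w]) != before
```

Because only integers are compared, the check never reads the tensor contents, which is why it stays cheap even for large parameters.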

class torchref.utils.CachedForwardMixin[source]

Bases: object

Mixin that caches forward() results with automatic invalidation.

Overrides __call__ to return a cached result when the module’s parameters, buffers, and call arguments have not changed since the last invocation — and no backward pass has propagated through the cached output.

Cache invalidation triggers:

  • Any parameter or buffer data_ptr or _version change (covers optimizer in-place updates and mask/parameter replacement).

  • Input tensor data_ptr or _version change, or non-tensor argument value change.

  • A backward pass through the cached output (increments generation counter via a gradient hook).

The cached tensor retains its autograd graph — gradients flow correctly on the first backward pass, after which the cache is invalidated.

__call__(*args, recalc=False, **kwargs)[source]

Return cached forward() result, or recompute on cache miss.

Parameters:

recalc (bool, optional) – If True, invalidate the cache and force recomputation. Not forwarded to forward().

reset_forward_cache()[source]

Manually invalidate the forward cache.
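
The invalidation rules above can be illustrated with a small stand-alone module. This is a rough sketch, not the mixin's actual code: the class `CachedLinear`, its attributes, and the single-tensor call signature are all hypothetical, and buffers and non-tensor arguments are ignored for brevity.

```python
import torch
from torch import nn

class CachedLinear(nn.Module):
    """Hypothetical module illustrating fingerprint-based forward caching."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(3))
        self.n_forward = 0       # counts real forward() executions
        self._key = None
        self._cached = None
        self._generation = 0     # bumped by the gradient hook

    def _bump(self, grad):
        self._generation += 1    # a backward pass went through the cached output
        return None              # leave the gradient unchanged

    def forward(self, x):
        self.n_forward += 1
        return (self.weight * x).sum()

    def __call__(self, x, recalc=False):
        tensors = list(self.parameters()) + [x]
        key = tuple((t.data_ptr(), t._version) for t in tensors) + (self._generation,)
        if recalc or key != self._key:
            self._cached = super().__call__(x)
            self._cached.register_hook(self._bump)
            self._key = key
        return self._cached

m = CachedLinear()
x = torch.arange(3.0)
a = m(x)
b = m(x)                 # cache hit: same tensor object, forward() not re-run
hit = (a is b) and m.n_forward == 1
a.backward()             # the hook increments the generation counter
c = m(x)                 # cache miss: forward() runs again
```

The second call returns the very same tensor, so its autograd graph is still attached; once `backward()` has consumed that graph, the generation counter changes and the next call recomputes.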

class torchref.utils.DeviceMixin[source]

Bases: object

Unified device/dtype movement.

Inherit alongside nn.Module (place before nn.Module in the MRO):

class Foo(DeviceMixin, nn.Module):
    ...

Or use on a plain Python class / dataclass:

@dataclass
class Bar(DeviceMixin):
    data: torch.Tensor

All of .to(), .cuda(), .cpu(), .float(), .double(), .half() route through _apply(), which:

  1. invokes nn.Module._apply when applicable so parameters, buffers and child modules are moved by the standard PyTorch path,

  2. walks self.__dict__ to pick up plain tensor attributes, nested containers and non-Module sub-objects,

  3. calls reset_forward_cache() and reset_cache() if either is defined.
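
Steps 2 and 3 can be approximated on a plain dataclass as follows. This is a simplified sketch: the helper `apply_to_tensors` and the body of `_apply` are assumptions made for illustration and do not reproduce the DeviceMixin implementation.

```python
from dataclasses import dataclass, field
import torch

def apply_to_tensors(obj, fn):
    """Hypothetical helper: apply fn to tensors found in nested containers."""
    if isinstance(obj, torch.Tensor):
        return fn(obj)
    if isinstance(obj, dict):
        return {k: apply_to_tensors(v, fn) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(apply_to_tensors(v, fn) for v in obj)
    return obj  # non-tensor leaves are returned untouched

@dataclass
class Bar:
    data: torch.Tensor
    extras: dict = field(default_factory=dict)

    def _apply(self, fn):
        # Step 2: walk the instance dict and convert every tensor found.
        for name, value in list(vars(self).items()):
            setattr(self, name, apply_to_tensors(value, fn))
        # Step 3: invalidate caches if the class defines these hooks.
        for hook in ("reset_forward_cache", "reset_cache"):
            if callable(getattr(self, hook, None)):
                getattr(self, hook)()
        return self

    def double(self):
        # Mirror PyTorch: only floating-point tensors change dtype.
        return self._apply(lambda t: t.double() if t.is_floating_point() else t)

bar = Bar(torch.zeros(2), {"weights": [torch.ones(1)], "index": torch.arange(3)}).double()
```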

to(*args, **kwargs)[source]
cuda(device=None)[source]
cpu()[source]
torchref.utils.DeviceMovementMixin

alias of DeviceMixin

torchref.utils.resolve_device(*modules, device=None)[source]

Resolve a single device from N device-bearing modules.

Each module must expose .device and accept .to(device) (satisfied by torch.nn.Module and by torchref.utils.DeviceMixin non-Module subclasses such as Cell). None entries are skipped silently so empty-init paths can pass through optional submodules — resolve_device(model, data) works whether or not data is None.

Resolution order

  1. If device is given, every non-None module is moved to it and it is returned. No warning is emitted (the caller has made an explicit choice).

  2. Otherwise, after dropping None entries, if no modules remain, torchref.config.get_default_device() is returned.

  3. The first remaining module’s device is the target. Any other module on a different device is moved to the target and a UserWarning is emitted once for the call.

The “first module wins” rule is intentional: callers express precedence by argument order.

Parameters:
  • *modules – Device-bearing modules. None entries are skipped.

  • device (torch.device or str, optional) – Explicit override. If provided, all non-None modules are moved to it and it is returned.

Returns:

The resolved device.

Return type:

torch.device

Examples

Empty call returns the configured default:

>>> resolve_device()
device(type='cpu')

Explicit override moves everything:

>>> resolve_device(model, data, device='cpu')
device(type='cpu')

Auto-reconcile with first-wins precedence:

>>> resolve_device(cuda_model, cpu_data)
device(type='cuda')  # cpu_data has been moved to cuda
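
The resolution order can be mimicked without PyTorch. The sketch below is illustrative only: `Stub`, `resolve`, and the `default` argument are hypothetical stand-ins, devices are plain strings, and the warning text is invented.

```python
import warnings

class Stub:
    """Hypothetical device-bearing object; devices are plain strings here."""
    def __init__(self, device):
        self.device = device

    def to(self, device):
        self.device = device
        return self

def resolve(*modules, device=None, default="cpu"):
    present = [m for m in modules if m is not None]
    if device is not None:            # rule 1: explicit override, no warning
        for m in present:
            m.to(device)
        return device
    if not present:                   # rule 2: nothing to inspect
        return default
    target = present[0].device        # rule 3: first module wins
    moved = [m for m in present[1:] if m.device != target]
    for m in moved:
        m.to(target)
    if moved:                         # one warning per call, not per module
        warnings.warn(f"moved {len(moved)} module(s) to {target}", UserWarning)
    return target

model, data = Stub("cuda"), Stub("cpu")
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    target = resolve(model, None, data)
```

Swapping the argument order to `resolve(data, model)` would instead pull `model` onto the CPU, which is how callers express precedence.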
class torchref.utils.TensorMasks(data=None, device=None)[source]

Bases: DeviceMixin, dict

A dictionary for managing boolean mask tensors with device support.

This is a lightweight dict subclass that:

  • Ensures all tensors are boolean dtype

  • Supports device movement via to(), cuda(), cpu()

  • Provides combined mask via __call__()

Parameters:
  • data (dict, optional) – Initial mask data.

  • device (str or torch.device, optional) – Device for tensors. Defaults to the configured device.current.

Examples

masks = TensorMasks(device='cuda')
masks['valid'] = torch.ones(100, dtype=torch.bool)
masks['rfree'] = rfree_flags > 0
combined = masks()  # Get combined mask (AND of all)
masks.cpu()  # Move all to CPU
__init__(data=None, device=None)[source]
__setitem__(key, tensor)[source]

Set mask tensor, ensuring boolean dtype and correct device.

reset_cache()[source]

Invalidate the cached combined mask.

__call__()[source]

Return combined mask (AND of all masks).

Returns:

Combined boolean mask, or None if no masks.

Return type:

torch.Tensor
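
The "AND of all masks" behaviour corresponds to the following plain-PyTorch computation. This is a sketch on an ordinary dict and says nothing about how TensorMasks caches or computes the result internally.

```python
import torch

masks = {
    "valid": torch.tensor([True, True, False, True]),
    "rfree": torch.tensor([1, 0, 1, 1]) > 0,   # comparison yields a bool tensor
}

# Stack to shape (n_masks, n) and require True in every mask per position.
combined = torch.stack(list(masks.values())).all(dim=0)
```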

class torchref.utils.TensorDict(initial_dict=None)[source]

Bases: Module

A dictionary-like container for PyTorch tensors that:

  • Supports standard dict syntax

  • Automatically moves with the module

  • Registers tensors as buffers so they are included in state_dict

__init__(initial_dict=None)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

keys()[source]
values()[source]
items()[source]
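
The buffer-backed design can be illustrated with a few lines of PyTorch. `BufferDict` below is a hypothetical stand-in written for this example, not the TensorDict source; it relies only on `nn.Module.register_buffer`.

```python
import torch
from torch import nn

class BufferDict(nn.Module):
    """Hypothetical stand-in: dict-style access backed by registered buffers."""

    def __setitem__(self, key, tensor):
        self.register_buffer(key, tensor)

    def __getitem__(self, key):
        return self._buffers[key]

    def keys(self):
        return self._buffers.keys()

d = BufferDict()
d["scale"] = torch.ones(2)
d = d.double()                      # buffers follow module-level conversions
saved_keys = list(d.state_dict())   # buffers appear in the state dict
```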
class torchref.utils.ModuleReference(module)[source]

Bases: object

A wrapper class to hold references to PyTorch modules without registering them.

When you assign an nn.Module to an attribute of another nn.Module, PyTorch automatically registers it as a submodule, which adds its parameters to the parent’s parameter tree. This wrapper prevents that automatic registration.

This is useful when you want to:

  • Hold references to modules without including their parameters

  • Avoid circular dependencies in the module tree

  • Reference external modules that should be managed separately

_wrapped_module

The wrapped PyTorch module.

Type:

torch.nn.Module

Examples

model = MyModel()
scaler = Scaler()
scaler._model = ModuleReference(model)  # Won't register as submodule
# Access the module via .module property
output = scaler._model.module(input_data)
__init__(module)[source]

Wrap a module to prevent automatic registration.

Parameters:

module (torch.nn.Module) – The PyTorch module to wrap.

property module

Access the wrapped module.

__getattr__(name)[source]

Forward attribute access to the wrapped module.

__call__(*args, **kwargs)[source]

Forward calls to the wrapped module.
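
The effect of wrapping can be checked directly by counting parameters. The `Ref` class below is a minimal hypothetical wrapper used only for this comparison; it omits the attribute and call forwarding that ModuleReference provides.

```python
import torch
from torch import nn

class Ref:
    """Hypothetical minimal wrapper; it is not an nn.Module, so it is never registered."""
    def __init__(self, module):
        self._wrapped_module = module

    @property
    def module(self):
        return self._wrapped_module

model = nn.Linear(2, 2)

direct = nn.Module()
direct.model = model            # plain assignment registers a submodule

wrapped = nn.Module()
wrapped.model = Ref(model)      # stored as an ordinary attribute

n_direct = len(list(direct.parameters()))    # weight and bias of the Linear
n_wrapped = len(list(wrapped.parameters()))  # nothing registered
```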

torchref.utils.save_map(array, cell, filename)[source]

Save a 3D map to a CCP4 file.

Parameters:
Returns:

True if save was successful.

Return type:

bool

torchref.utils.sanitize_pdb_dataframe(pdb, verbose=0)[source]

Sanitize a PDB DataFrame to ensure unique atom identifiers.

This function fixes common issues in PDB/CIF files:

  1. HETATM records (especially waters) with duplicate resseq values (e.g., all 0)

  2. Residue names longer than 3 characters (truncates to 3)

  3. Ensures unique (chainid, resseq, name, altloc) combinations

Parameters:
  • pdb (pandas.DataFrame) – DataFrame with PDB data (must have columns: ATOM, chainid, resseq, name, altloc, resname, serial).

  • verbose (int, default 0) – Verbosity level (0=silent, 1=info, 2=debug).

Returns:

Sanitized DataFrame with unique atom identifiers.

Return type:

pandas.DataFrame

Examples

from torchref.model import Model
from torchref.utils import sanitize_pdb_dataframe
model = Model()
model.load_cif('structure.cif')
model.pdb = sanitize_pdb_dataframe(model.pdb, verbose=1)
torchref.utils.parse_phenix_selection(selection_string, pdb_df)[source]

Parse Phenix-style atom selection syntax and return a boolean mask.

Supports common Phenix selection keywords:

  • chain <id>: Select atoms by chain ID (e.g., “chain A”)

  • resseq <num>: Select atoms by residue sequence number (e.g., “resseq 10”)

  • resseq <start>:<end>: Select residue range (e.g., “resseq 10:20”)

  • resname <name>: Select atoms by residue name (e.g., “resname ALA”)

  • name <atom>: Select atoms by atom name (e.g., “name CA”)

  • element <elem>: Select atoms by element (e.g., “element C”)

  • altloc <id>: Select atoms by alternate location (e.g., “altloc A”)

  • all: Select all atoms

  • not <selection>: Negate selection

  • <sel1> and <sel2>: Intersection of selections

  • <sel1> or <sel2>: Union of selections

  • Parentheses for grouping: (selection)

Parameters:
  • selection_string (str) – Phenix-style selection string.

  • pdb_df (pandas.DataFrame) – DataFrame containing atomic data with columns: ‘chainid’, ‘resseq’, ‘resname’, ‘name’, ‘element’, ‘altloc’.

Returns:

Boolean tensor of shape (n_atoms,) where True indicates selected atoms.

Return type:

torch.Tensor

Raises:

ValueError – If selection syntax is invalid.

Examples

# Select chain A
mask = parse_phenix_selection("chain A", pdb_df)

# Select residues 10-20 in chain A
mask = parse_phenix_selection("chain A and resseq 10:20", pdb_df)

# Select all CA atoms
mask = parse_phenix_selection("name CA", pdb_df)

# Select backbone atoms
mask = parse_phenix_selection("name CA or name C or name N or name O", pdb_df)

# Select everything except water
mask = parse_phenix_selection("not resname HOH", pdb_df)

# Use parentheses for grouping
mask = parse_phenix_selection("chain A and (name CA or name CB)", pdb_df)
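
How such a selection grammar can be evaluated is sketched below with a small recursive-descent parser. It is an illustration, not the torchref parser: `select` is a hypothetical function, atoms are plain dicts rather than DataFrame rows, the result is a list of bools rather than a tensor, and the precedence (not binds tighter than and, and tighter than or) is an assumption.

```python
import re

def select(selection, atoms):
    """Hypothetical mini-parser; atoms is a list of dicts, one per atom."""
    tokens = re.findall(r"\(|\)|[^\s()]+", selection)
    pos = 0

    def peek():
        return tokens[pos] if pos < len(tokens) else None

    def take():
        nonlocal pos
        if pos >= len(tokens):
            raise ValueError("unexpected end of selection")
        pos += 1
        return tokens[pos - 1]

    def parse_or():
        mask = parse_and()
        while peek() == "or":
            take()
            rhs = parse_and()
            mask = [a or b for a, b in zip(mask, rhs)]
        return mask

    def parse_and():
        mask = parse_not()
        while peek() == "and":
            take()
            rhs = parse_not()
            mask = [a and b for a, b in zip(mask, rhs)]
        return mask

    def parse_not():
        if peek() == "not":
            take()
            return [not a for a in parse_not()]
        return parse_atom()

    def parse_atom():
        tok = take()
        if tok == "(":
            mask = parse_or()
            if take() != ")":
                raise ValueError("missing closing parenthesis")
            return mask
        if tok == "all":
            return [True] * len(atoms)
        if tok == "resseq":
            lo, _, hi = take().partition(":")   # "10" or "10:20"
            lo, hi = int(lo), int(hi or lo)
            return [lo <= a["resseq"] <= hi for a in atoms]
        if tok in ("chain", "resname", "name", "element", "altloc"):
            column = "chainid" if tok == "chain" else tok
            value = take()
            return [a[column] == value for a in atoms]
        raise ValueError(f"unknown keyword: {tok!r}")

    mask = parse_or()
    if pos != len(tokens):
        raise ValueError(f"unexpected token: {tokens[pos]!r}")
    return mask

atoms = [
    {"chainid": "A", "resseq": 10, "resname": "ALA", "name": "CA", "element": "C", "altloc": ""},
    {"chainid": "A", "resseq": 10, "resname": "ALA", "name": "CB", "element": "C", "altloc": ""},
    {"chainid": "A", "resseq": 25, "resname": "GLY", "name": "N", "element": "N", "altloc": ""},
    {"chainid": "B", "resseq": 12, "resname": "HOH", "name": "O", "element": "O", "altloc": ""},
]
grouped = select("chain A and (name CA or name CB)", atoms)
no_water = select("not resname HOH", atoms)
in_range = select("resseq 10:20", atoms)
```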
torchref.utils.create_selection_mask(selection_string, pdb_df, current_mask=None, mode='set')[source]

Create or modify a refinable mask based on a Phenix-style selection.

This function allows you to update refinable masks by selecting specific atoms using Phenix-style syntax. You can either replace the current mask, add to it, or remove from it.

Parameters:
  • selection_string (str) – Phenix-style selection string.

  • pdb_df (pandas.DataFrame) – DataFrame containing atomic data.

  • current_mask (torch.Tensor, optional) – Current refinable mask. If None, starts with all False.

  • mode (str, default 'set') –

    How to combine with current mask:

    • ’set’: Replace mask with selection (default)

    • ’add’: Add selection to current mask (OR operation)

    • ’remove’: Remove selection from current mask (AND NOT operation)

Returns:

Updated boolean mask of shape (n_atoms,).

Return type:

torch.Tensor

Raises:

ValueError – If mode is not one of ‘set’, ‘add’, ‘remove’.

Examples

# Create new mask selecting chain A
mask = create_selection_mask("chain A", pdb_df, mode='set')

# Add residues 10-20 to existing mask
mask = create_selection_mask("resseq 10:20", pdb_df, current_mask=mask, mode='add')

# Remove water from mask
mask = create_selection_mask("resname HOH", pdb_df, current_mask=mask, mode='remove')
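
The three modes reduce to simple boolean tensor operations. The function `combine` below is a hypothetical sketch that takes an already-computed selection mask instead of a selection string.

```python
import torch

def combine(selection, current=None, mode="set"):
    """Hypothetical helper showing the set / add / remove semantics."""
    if current is None:
        current = torch.zeros_like(selection)   # start from all False
    if mode == "set":
        return selection.clone()                # replace
    if mode == "add":
        return current | selection              # OR
    if mode == "remove":
        return current & ~selection             # AND NOT
    raise ValueError(f"mode must be 'set', 'add' or 'remove', got {mode!r}")

chain_a = torch.tensor([True, True, False, False])
residues = torch.tensor([False, True, True, False])
water = torch.tensor([True, False, False, False])

mask = combine(chain_a, mode="set")
mask = combine(residues, current=mask, mode="add")
mask = combine(water, current=mask, mode="remove")
```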
class torchref.utils.DebugMixin[source]

Bases: object

Mixin class that adds debugging capabilities to modules.

When an error occurs, call print_debug_summary() to get a comprehensive overview of the module’s state including:

  • All attributes and their types

  • Tensor shapes, dtypes, and devices

  • DataFrame/array shapes

  • Other object information

print_debug_summary(title=None, file=sys.stderr)[source]

Print a comprehensive debug summary of this module’s state.

Parameters:
  • title (str, optional) – Title for the summary.

  • file (file-like, default sys.stderr) – File to write output to.

debug_on_error(error, context='', recursive=True)[source]

Print debug summary when an error occurs, recursively printing submodules.

Parameters:
  • error (Exception) – The exception that was caught.

  • context (str, default "") – Additional context string to print.

  • recursive (bool, default True) – If True, recursively print debug info for all submodules.

torchref.utils.print_module_summary(module, title=None, file=sys.stderr)[source]

Print debug summary for any module.

Standalone function to print debug information for modules that may or may not have the DebugMixin.

Parameters:
  • module (object) – The module to inspect.

  • title (str, optional) – Title for the summary.

  • file (file-like, default sys.stderr) – File to write output to.

class torchref.utils.StatEntry(value, verbosity=1)[source]

Bases: object

A statistics entry with value and verbosity level.

JSON serializable - when serialized, only the value is written.

value

The statistic value.

Type:

Any

verbosity

Verbosity level required to show this stat.

Type:

int

value: Any
verbosity: int = 1
__json__()[source]

Return JSON-serializable representation (just the value).

__init__(value, verbosity=1)
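
One way a `__json__` method can be honoured by the standard library is through the `default` hook of `json.dumps`. The sketch below uses a hypothetical `Entry` dataclass and `encode` function; how torchref itself wires `__json__` into serialization is not stated here.

```python
import json
from dataclasses import dataclass
from typing import Any

@dataclass
class Entry:
    """Hypothetical stand-in for StatEntry."""
    value: Any
    verbosity: int = 1

    def __json__(self):
        return self.value

def encode(obj):
    # json.dumps calls this hook for objects it cannot serialize itself.
    if hasattr(obj, "__json__"):
        return obj.__json__()
    raise TypeError(f"not JSON serializable: {type(obj).__name__}")

stats = {"rwork": Entry(0.21), "n_atoms": Entry(1500, verbosity=2)}
text = json.dumps(stats, default=encode)   # verbosity is dropped, only values remain
```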
torchref.utils.stat(value, verbosity=1)[source]

Create a StatEntry with given value and verbosity.

Parameters:
  • value (Any) – The statistic value.

  • verbosity (int, optional) – Verbosity level. Default is VERBOSITY_STANDARD.

Returns:

A statistics entry object.

Return type:

StatEntry

torchref.utils.filter_stats(stats, max_verbosity)[source]

Filter stats dictionary to only include entries at or below max_verbosity.

Parameters:
  • stats (dict) – Stats dictionary with StatEntry values or nested dicts.

  • max_verbosity (int) – Maximum verbosity level to include.

Returns:

Filtered stats with raw values (StatEntry unwrapped).

Return type:

dict

torchref.utils.flatten_stats(stats, prefix='')[source]

Flatten nested stats dict into flat dict with dotted keys.

Parameters:
  • stats (dict) – Nested stats dictionary.

  • prefix (str, optional) – Prefix for keys. Default is ‘’.

Returns:

Flattened dictionary with dotted keys.

Return type:

dict
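
The filtering and flattening steps can be sketched in pure Python. The names `Entry`, `filter_by_verbosity`, and `flatten` are hypothetical, and two behaviours are assumptions: empty nested dicts are dropped, and values that are not wrapped in an entry pass through unchanged.

```python
from dataclasses import dataclass
from typing import Any

@dataclass
class Entry:
    """Hypothetical stand-in for StatEntry."""
    value: Any
    verbosity: int = 1

def filter_by_verbosity(stats, max_verbosity):
    out = {}
    for key, item in stats.items():
        if isinstance(item, dict):
            nested = filter_by_verbosity(item, max_verbosity)
            if nested:                       # assumption: drop empty sub-dicts
                out[key] = nested
        elif isinstance(item, Entry):
            if item.verbosity <= max_verbosity:
                out[key] = item.value        # unwrap to the raw value
        else:
            out[key] = item                  # assumption: raw values pass through
    return out

def flatten(stats, prefix=""):
    flat = {}
    for key, item in stats.items():
        name = f"{prefix}.{key}" if prefix else key
        if isinstance(item, dict):
            flat.update(flatten(item, name))  # recurse with the dotted prefix
        else:
            flat[name] = item
    return flat

stats = {
    "rwork": Entry(0.21),
    "geometry": {"bonds": Entry(0.01), "debug": Entry(3, verbosity=3)},
}
visible = filter_by_verbosity(stats, max_verbosity=1)
flat = flatten(visible)
```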

torchref.utils.format_stats_table(stats, title='', indent=2)[source]

Format stats dictionary as a printable table.

Parameters:
  • stats (dict) – Stats dictionary (already filtered by verbosity).

  • title (str, optional) – Title for the table.

  • indent (int, optional) – Indentation spaces. Default is 2.

Returns:

Formatted table string.

Return type:

str

class torchref.utils.HyperparameterMixin[source]

Bases: object

Mixin class providing hyperparameter registration and tracking.

Use this with nn.Module to add hyperparameter tracking capabilities. Hyperparameters are stored as buffers but tracked separately.

Examples

class MyTarget(HyperparameterMixin, nn.Module):
    def __init__(self, sigma=1.0, target_value=0.0):
        nn.Module.__init__(self)
        HyperparameterMixin.__init__(self)
        self.register_hyperparameter('sigma', sigma)
        self.register_hyperparameter('target_value', target_value)

target = MyTarget(sigma=2.5)
dict(target.hyperparameters())
# {'sigma': tensor(2.5), 'target_value': tensor(0.0)}
__init__()[source]

Initialize hyperparameter tracking.

register_hyperparameter(name, value, persistent=True)[source]

Register a hyperparameter.

Hyperparameters are stored as buffers (so they move with .to(device)) but tracked separately for easy access and state_dict operations.

Parameters:
  • name (str) – Name of the hyperparameter. Will be prefixed with ‘_hp_’ internally.

  • value (float) – Initial value of the hyperparameter.

  • persistent (bool, optional) – If True, hyperparameter will be part of module’s state_dict. Default is True.

Examples

self.register_hyperparameter('sigma', 2.0)
self.sigma  # Access via property (if defined) or _hp_sigma
# 2.0
get_hyperparameter(name)[source]

Get a hyperparameter value.

Parameters:

name (str) – Name of the hyperparameter.

Returns:

The hyperparameter tensor.

Return type:

torch.Tensor

set_hyperparameter(name, value)[source]

Set a hyperparameter value.

Parameters:
  • name (str) – Name of the hyperparameter.

  • value (float) – New value.

hyperparameters(recurse=True)[source]

Return an iterator over module hyperparameters.

Parameters:

recurse (bool, optional) – If True, include hyperparameters from submodules. Default is True.

Yields:

tuple – (name, tensor) pairs for each hyperparameter.

Examples

for name, hp in module.hyperparameters():
    print(f"{name}: {hp.item()}")
named_hyperparameters(prefix='', recurse=True)[source]

Return an iterator over module hyperparameters with names.

Same as hyperparameters() but allows prefix specification.

Parameters:
  • prefix (str, optional) – Prefix to prepend to names. Default is ‘’.

  • recurse (bool, optional) – If True, include hyperparameters from submodules. Default is True.

Yields:

tuple – (name, tensor) pairs for each hyperparameter.

hyperparameter_state_dict(prefix='')[source]

Return a dictionary containing only hyperparameters.

Unlike state_dict() which includes all buffers and parameters, this only returns the hyperparameters.

Parameters:

prefix (str, optional) – Prefix to prepend to names. Default is ‘’.

Returns:

Dictionary mapping hyperparameter names to tensors.

Return type:

dict

Examples

hp_state = module.hyperparameter_state_dict()
torch.save(hp_state, 'hyperparameters.pt')
load_hyperparameter_state_dict(state_dict, strict=True)[source]

Load hyperparameters from a state dict.

Parameters:
  • state_dict (dict) – Dictionary of hyperparameter name -> tensor.

  • strict (bool, optional) – If True, raise error on missing/unexpected keys. Default is True.

Examples

hp_state = torch.load('hyperparameters.pt')
module.load_hyperparameter_state_dict(hp_state)
hyperparameter_dict()[source]

Return hyperparameters as a simple Python dict of floats.

Useful for logging, serialization to JSON, etc.

Returns:

Dictionary mapping hyperparameter names to float values.

Return type:

dict

Examples

params = module.hyperparameter_dict()
import json
json.dumps(params)  # JSON serializable
print_hyperparameters(prefix='')[source]

Print all hyperparameters in a formatted way.

Parameters:

prefix (str, optional) – Prefix for indentation. Default is ‘’.
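
The "stored as buffers but tracked separately" design can be illustrated as follows. The class `Target` and its methods `register_hp` and `hp_dict` are hypothetical and only mimic the `_hp_` prefix convention described above.

```python
import torch
from torch import nn

class Target(nn.Module):
    """Hypothetical module mimicking the '_hp_' buffer convention."""

    def __init__(self, sigma=1.0):
        super().__init__()
        self._hp_names = []                  # separate bookkeeping of names
        self.register_hp("sigma", sigma)

    def register_hp(self, name, value, persistent=True):
        # A buffer moves with .to(device) and, if persistent, is saved in state_dict.
        self.register_buffer(f"_hp_{name}", torch.tensor(float(value)), persistent=persistent)
        self._hp_names.append(name)

    def hp_dict(self):
        # Plain floats, convenient for logging or JSON output.
        return {n: getattr(self, f"_hp_{n}").item() for n in self._hp_names}

target = Target(sigma=2.5)
state_keys = list(target.state_dict())
as_floats = target.hp_dict()
```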

torchref.utils.convert_to_serializable(obj)[source]

Convert tensors and numpy arrays to JSON-serializable types.

Recursively walks dicts, lists, and tuples, converting torch.Tensor, numpy.ndarray, and numpy scalar types to plain Python objects that json.dump can handle.

Parameters:

obj (object) – Arbitrary Python object (tensor, array, dict, list, scalar, …).

Returns:

A JSON-serializable equivalent.

Return type:

object

torchref.utils.gradnorm(loss, parameters)[source]

Compute the gradient norm of a loss with respect to given parameters.

Performs a backward pass with graph retention and computes the RMS (root mean square) of all gradients concatenated together.

Parameters:
  • loss (torch.Tensor) – The loss tensor to backpropagate.

  • parameters (iterable) – Iterable of model parameters (typically from model.parameters()).

Returns:

The computed RMS gradient norm.

Return type:

float

Notes

Uses retain_graph=True to allow subsequent backward passes. Only includes parameters that have gradients (skips None grads).

Examples

loss = model(input)
grad_norm = gradnorm(loss, model.parameters())
print(f"Gradient norm: {grad_norm:.4f}")
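
The RMS computation described above can be reproduced on a toy problem. `rms_grad` is a hypothetical re-implementation for illustration; whether gradnorm clears existing gradients before the backward pass is not stated, so this sketch starts from parameters without gradients.

```python
import torch

def rms_grad(loss, parameters):
    """Hypothetical sketch: RMS over all gradient entries, skipping None grads."""
    params = list(parameters)
    loss.backward(retain_graph=True)             # keep the graph for later passes
    grads = [p.grad.flatten() for p in params if p.grad is not None]
    return torch.cat(grads).pow(2).mean().sqrt().item()

w = torch.tensor([3.0, 4.0], requires_grad=True)
unused = torch.zeros(2, requires_grad=True)      # no gradient reaches this tensor
loss = 0.5 * (w ** 2).sum()                      # d(loss)/dw = w = [3, 4]
norm = rms_grad(loss, [w, unused])               # sqrt((9 + 16) / 2)
```

Note that the RMS divides by the number of gradient entries, so it differs from the Euclidean norm (5.0 in this toy case) by a factor of the square root of that count.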
torchref.utils.validate_loss(loss, *, state=None, parameters=None, check_grads=True, context='', raise_on_fail=True, max_full_diagnostics=3)[source]

Check that loss (and optionally grads / parameters) are finite.

Parameters:
  • loss (torch.Tensor) – The scalar loss returned by the closure. Must be a zero-dim or one-element tensor.

  • state (LossState, optional) – If provided, the diagnostic path re-runs state.aggregate(log_values=True) to repopulate per-target losses and formats them via state.format_breakdown(). Safe to omit for closures that don’t use a LossState (e.g. scalers, alignment).

  • parameters (iterable of torch.Tensor, optional) – Parameters to inspect. When check_grads=True, their gradients are checked for finiteness on the fast path. On failure, both parameters and their grads are reported with non-finite entry counts.

  • check_grads (bool, default True) – Check parameter gradients for finiteness after backward(). This is the usual pathology (backward produces NaN even when forward was finite), so leave it on unless a hot benchmark proves it’s costly.

  • context (str, default "") – Short label written into the diagnostic header and returned in the warning / exception message (e.g. "collection_difference_refine"). Also keys the per-context diagnostic budget.

  • raise_on_fail (bool, default True) – If True, raise NonFiniteLossError on failure (strict mode). If False, print a warning and return False — the caller is responsible for rejecting the LBFGS step (e.g. by zeroing grads and returning +inf so strong-Wolfe backtracks).

  • max_full_diagnostics (int, default 3) – Per-context budget for the full per-target breakdown. After this many failures in the same context, only a compact one-line warning is printed. Prevents log flooding when LBFGS bounces around a persistent NaN region. Pass 0 to always print compact.

Returns:

True if everything is finite (happy path), False otherwise. When raise_on_fail=True, a False result raises instead of returning.

Return type:

bool

Raises:

NonFiniteLossError – If raise_on_fail=True and any of loss / grads / params is non-finite.
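
Why gradients are checked in addition to the loss can be seen on a two-element example. The helper `all_finite` is a hypothetical reduction of the check to `torch.isfinite`; it performs none of the diagnostics, budgeting, or LossState handling described above.

```python
import torch

def all_finite(loss, parameters=()):
    """Hypothetical check: loss and any existing gradients contain no NaN/Inf."""
    if not bool(torch.isfinite(loss).all()):
        return False
    return all(bool(torch.isfinite(p.grad).all()) for p in parameters if p.grad is not None)

w = torch.tensor([0.0, 1.0], requires_grad=True)
loss = torch.sqrt(w).sum()       # forward is finite: 0 + 1
loss.backward()                  # d/dw sqrt(w) = 0.5 / sqrt(w), infinite at w = 0

loss_ok = all_finite(loss)       # True: the loss alone looks healthy
step_ok = all_finite(loss, [w])  # False: the gradient contains inf
```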

exception torchref.utils.NonFiniteLossError[source]

Bases: RuntimeError

Raised when a refinement step produces non-finite loss, grads, or params.

torchref.utils.reset_diagnostic_budget(context=None)[source]

Reset the per-context failure counter that limits how many full diagnostics are printed.

Parameters:

context (str, optional) – Reset a single context’s counter. If omitted, reset all.

torchref.utils.collect_loss_leaves(losses)[source]

Return the set of leaf nn.Parameters that gradient will accumulate into when backward() is called on the given loss(es).

Walks the autograd graph from each root tensor’s grad_fn and finds every AccumulateGrad node, collecting its .variable when it is an nn.Parameter.

Multiple roots are unioned via a single shared traversal so that shared subgraphs (e.g. two losses both depending on the same model forward) are walked exactly once.

Parameters:

losses (Tensor | Iterable[Tensor] | Mapping[str, Tensor]) – One or more loss tensors.

Returns:

Leaf parameters that backward would accumulate gradient into. A leaf with requires_grad=False does not appear (no AccumulateGrad node is created for it). Detached subtrees contribute nothing.

Return type:

set of nn.Parameter
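
The graph walk can be sketched for a single loss tensor using `grad_fn.next_functions` and the `variable` attribute of AccumulateGrad nodes. `loss_leaves` is a hypothetical, single-root simplification and not the torchref function.

```python
import torch
from torch import nn

def loss_leaves(loss):
    """Hypothetical sketch: leaf nn.Parameters reachable from loss.grad_fn."""
    seen, leaves, stack = set(), set(), [loss.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:     # None marks inputs that need no gradient
            continue
        seen.add(fn)
        var = getattr(fn, "variable", None)   # present on AccumulateGrad nodes
        if isinstance(var, nn.Parameter):
            leaves.add(var)
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return leaves

a = nn.Parameter(torch.ones(2))
b = nn.Parameter(torch.ones(2))
frozen = nn.Parameter(torch.ones(2), requires_grad=False)

# b is detached and frozen does not require grad, so only a should be found.
loss = (a * frozen).sum() + (b.detach() * 2).sum()
leaves = loss_leaves(loss)
found_ids = {id(p) for p in leaves}
```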

Submodules