torchref.refinement.loss_state module

LossState - Hierarchical loss computation with lazy evaluation.

Design:

- Targets stored as callables, evaluated only on aggregation.
- Hierarchical naming via ‘/’ separator (e.g., ‘geometry/bond’, ‘adp/simu’).
- Weights computed at initialization, not recalculated each pass.
- Internal history logging for debugging/analysis.
- Aggregation handled directly in this class.
- Loss functions with zeroed weight are not evaluated at all.
- Implements optimization closures and data handling.
- Automatically detects parameter root tensors and disables requires_grad on any that the loss touches but the optimizer wasn’t built to update, enabling dynamic caching.

Example:

state = LossState(device)
state.register_target('xray/work', xray_work_target)
state.register_target('geometry/bond', bond_target)
state.set_weight('xray', 1.0)
state.set_weight('geometry', 0.5)
state.set_weight('geometry/bond', 1.0)

total = state.aggregate() # Evaluates targets, applies hierarchical weights

state.step(optimizer) # Runs optimizer step with closure that validates loss and auto-freezes parameters depending on the loss graph.

class torchref.refinement.loss_state.LossState(device=<factory>, targets=<factory>, weights=<factory>, history=<factory>, _losses=<factory>, _compilable=<factory>, _compiled_aggregate=None, _loss_leaves=<factory>, _resettable_modules=<factory>, meta=<factory>)[source]

Bases: DeviceMixin

Hierarchical loss state with lazy evaluation.

device

Computation device.

Type:

torch.device

targets

Target functions keyed by hierarchical name (e.g., ‘geometry/bond’).

Type:

Dict[str, Callable]

weights

Weights keyed by name. Can be group weights (‘geometry’) or component weights (‘geometry/bond’).

Type:

Dict[str, float]

history

Log of computed values per aggregation call.

Type:

List[Dict]

meta

Model-level data (rwork, rfree, n_atoms, etc.) populated by refinement.

Type:

Dict[str, Any]

device: device
targets: Dict[str, Callable]
weights: Dict[str, float]
history: List[Dict[str, Any]]
meta: Dict[str, Any]
__getitem__(key)[source]

Get value from meta or _losses by key.

Parameters:

key (str) – Key to look up. Checks meta first, then _losses.

Returns:

Value from meta or _losses.

Return type:

Any

Raises:

KeyError – If key not found in either dict.

__contains__(key)[source]

Check if key exists in meta or _losses.

get(key, default=None)[source]

Get value with default fallback.

Parameters:
  • key (str) – Key to look up.

  • default (Any) – Value to return if key not found.

Returns:

Value from meta, _losses, or default.

Return type:

Any
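
The lookup order described for __getitem__, __contains__, and get() can be sketched with two plain dicts. This is only an illustration under stated assumptions, not the torchref implementation: the class name MetaLossLookup is hypothetical, and plain floats stand in for tensors.

```python
from typing import Any, Dict


class MetaLossLookup:
    """Hypothetical stand-in showing 'meta first, then _losses' lookup."""

    def __init__(self) -> None:
        self.meta: Dict[str, Any] = {}
        self._losses: Dict[str, Any] = {}

    def __getitem__(self, key: str) -> Any:
        # meta takes precedence over cached losses
        if key in self.meta:
            return self.meta[key]
        if key in self._losses:
            return self._losses[key]
        raise KeyError(key)

    def __contains__(self, key: str) -> bool:
        return key in self.meta or key in self._losses

    def get(self, key: str, default: Any = None) -> Any:
        # Same lookup as __getitem__, but falls back to a default
        try:
            return self[key]
        except KeyError:
            return default


lookup = MetaLossLookup()
lookup.meta["rwork"] = 0.21
lookup._losses["geometry/bond"] = 3.5
lookup._losses["rwork"] = 99.0  # shadowed by the meta entry of the same name
```

With this ordering, a key present in both dicts resolves to the meta value, so model-level data is never masked by a cached loss that happens to share its name.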

cache_losses(force=False)[source]

Cache all target losses.

Evaluates all registered targets and stores results in _losses.

Parameters:

force (bool) – If True, re-evaluate all targets even if already cached.

Returns:

Self for chaining.

Return type:

LossState
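
A rough sketch of what the force flag implies: cached entries are reused unless force=True. The function below is a hypothetical free-standing version operating on plain dicts and float-returning callables; the real method works on the state's own targets and _losses.

```python
from typing import Callable, Dict


def cache_losses_sketch(
    targets: Dict[str, Callable[[], float]],
    cache: Dict[str, float],
    force: bool = False,
) -> Dict[str, float]:
    """Evaluate targets into cache, skipping cached entries unless forced."""
    for name, fn in targets.items():
        if force or name not in cache:
            cache[name] = fn()
    return cache


counter = {"n": 0}


def bond_target() -> float:
    counter["n"] += 1  # count how often the target is actually evaluated
    return 2.0


targets = {"geometry/bond": bond_target}
cache: Dict[str, float] = {}
cache_losses_sketch(targets, cache)
cache_losses_sketch(targets, cache)            # served from cache
calls_without_force = counter["n"]
cache_losses_sketch(targets, cache, force=True)  # re-evaluated
calls_with_force = counter["n"]
```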

update_meta(data)[source]

Update meta dict with model-level data.

Parameters:

data (Dict[str, Any]) – Data to add to meta.

Returns:

Self for chaining.

Return type:

LossState

register_target(name, target, prefix=None, compile=False, probe=True)[source]

Register a target function.

Automatically detects combined targets (like TotalGeometryTarget, TotalADPTarget) and expands them into their component targets.

Parameters:
  • name (str) – Hierarchical name (e.g., ‘geometry/bond’, ‘adp/simu’).

  • target (Callable) – Function that returns a loss tensor when called. Can also be a combined target with .items() method, which will be auto-expanded.

  • prefix (str, optional) – Prefix to prepend to the name (e.g., ‘model1’ -> ‘model1/geometry/bond’). Useful for registering targets from multiple models in the same state.

  • compile (bool) – If True, mark this target (or all its sub-targets if combined) as eligible for the compiled aggregate closure built by compile_aggregate().

  • probe (bool) – If True (default), run the target’s forward once, walk the autograd graph, and merge the resulting leaf set into self._loss_leaves. The target’s dependencies (model loaded, data attached, etc.) must therefore be in place before registration. Set probe=False to skip — the leaf-set entry for this target will be empty, so step()/run() will not auto-disable any leaves on its account. Useful only for targets whose forward genuinely cannot be called at registration time.

Returns:

Self for chaining.

Return type:

LossState
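
The naming rules above (optional prefix, auto-expansion of combined targets exposing .items()) might look roughly like the following. The helper expand_target is hypothetical, and the assumption that sub-targets are registered as ‘<name>/<component>’ is ours, not confirmed by the source.

```python
from typing import Any, Callable, Dict, Optional


def expand_target(
    name: str, target: Any, prefix: Optional[str] = None
) -> Dict[str, Callable]:
    """Hypothetical helper: map one registration to {full_name: callable}."""
    full = f"{prefix}/{name}" if prefix else name
    if hasattr(target, "items"):
        # Combined target: assume each component goes under the parent name.
        return {f"{full}/{sub}": fn for sub, fn in target.items()}
    return {full: target}


# A dict stands in for a combined target such as TotalGeometryTarget.
combined = {"bond": lambda: 1.0, "angle": lambda: 2.0}
registered = expand_target("geometry", combined, prefix="model1")
single = expand_target("xray/work", lambda: 0.5)
```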

register_targets(targets, prefix=None, compile=False, probe=True)[source]

Register multiple targets from a component target or dict.

For targets with a .name attribute, uses target.name as the key. For plain callables, uses the dict key.

Parameters:
  • targets (dict) – Dictionary of name -> target mappings.

  • prefix (str, optional) – Prefix to prepend to all target names.

  • compile (bool) – If True, propagate the compile flag to all sub-targets.

  • probe (bool) – Forwarded to register_target().

set_weight(name, weight)[source]

Set a weight value.

Parameters:
  • name (str) – Weight name. Can be a group (‘geometry’) or component (‘geometry/bond’).

  • weight (float) – Weight value.

Returns:

Self for chaining.

Return type:

LossState

set_weights(weights)[source]

Set multiple weights.

get_weight(name, default=1.0)[source]

Get a weight value.

get_effective_weight(name)[source]

Get effective weight for a target, including group weights.

For ‘geometry/bond’, returns weights['geometry'] * weights['geometry/bond']. Missing weights default to 1.0.

Parameters:

name (str) – Target name (e.g., ‘geometry/bond’).

Returns:

Product of all hierarchical weights.

Return type:

float
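
As a worked example of the rule above, the product can be taken over every ‘/’-separated prefix of the target name. This is a sketch assuming deeper names (e.g. ‘model1/geometry/bond’) are handled the same way; the function name effective_weight is hypothetical.

```python
from typing import Dict


def effective_weight(weights: Dict[str, float], name: str) -> float:
    """Multiply the weights of every hierarchical prefix of name."""
    parts = name.split("/")
    result = 1.0
    for depth in range(1, len(parts) + 1):
        key = "/".join(parts[:depth])
        result *= weights.get(key, 1.0)  # missing weights default to 1.0
    return result


weights = {"geometry": 0.5, "geometry/bond": 2.0}
bond_w = effective_weight(weights, "geometry/bond")    # 0.5 * 2.0
angle_w = effective_weight(weights, "geometry/angle")  # 0.5 * 1.0 (default)
xray_w = effective_weight(weights, "xray/work")        # no weights set
```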

mark_compilable(names)[source]

Mark already-registered targets as eligible for the compiled aggregate.

Parameters:

names (List[str]) – Target keys to mark (must already be registered).

Returns:

Self for chaining.

Return type:

LossState

compile_aggregate(**compile_kwargs)[source]

Build and cache a torch.compile’d closure over all compilable targets.

Must be called after all targets and weights have been registered. Re-call if weights or compilable targets change (or call reset_compiled_aggregate()).

Parameters:

**compile_kwargs – Keyword arguments forwarded to torch.compile. Defaults to fullgraph=False so partial-graph fallback is allowed.

Returns:

Self for chaining.

Return type:

LossState

reset_compiled_aggregate()[source]

Clear the cached compiled closure (e.g. after changing weights).

log(name, value)[source]

Log a value to the current history entry.

Creates a new history entry if needed.

Parameters:
  • name (str) – Key for the logged value.

  • value (Any) – Value to log. Tensors are converted to Python floats.

new_entry()[source]

Start a new history entry.

get_history(name)[source]

Get all logged values for a key across history.
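
The interplay of log(), new_entry(), and get_history() can be sketched as a list of dicts. The class HistoryLog is hypothetical; skipping entries that lack the key in get_history() is an assumption, and objects with an .item() method stand in for tensors.

```python
from typing import Any, Dict, List


class HistoryLog:
    """Hypothetical stand-in for the history part of the loss state."""

    def __init__(self) -> None:
        self.history: List[Dict[str, Any]] = []

    def new_entry(self) -> None:
        self.history.append({})

    def log(self, name: str, value: Any) -> None:
        if not self.history:
            self.new_entry()  # create the first entry on demand
        # Tensor-like values are reduced to Python scalars via .item()
        if hasattr(value, "item"):
            value = value.item()
        self.history[-1][name] = value

    def get_history(self, name: str) -> List[Any]:
        # Assumption: entries without the key are skipped
        return [entry[name] for entry in self.history if name in entry]


log = HistoryLog()
log.log("total", 10.0)  # implicitly creates entry 0
log.new_entry()
log.log("total", 7.5)
log.log("rwork", 0.2)
totals = log.get_history("total")
```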

aggregate(log_values=False)[source]

Evaluate all targets and compute weighted sum.

When compile_aggregate() has been called and log_values=False, the compilable targets are evaluated through a single torch.compile’d closure for improved performance. With log_values=True all targets run eagerly so per-target losses are available in _losses.

Parameters:

log_values (bool) – If True, log all losses, weights, and total to history.

Returns:

Total weighted loss.

Return type:

torch.Tensor
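
A simplified, eager-only sketch of the weighted sum, including the design rule that targets with zero effective weight are never called. It uses floats instead of tensors and ignores the compiled path; the function names are hypothetical.

```python
from typing import Callable, Dict, List, Tuple


def _effective_weight(weights: Dict[str, float], name: str) -> float:
    parts = name.split("/")
    w = 1.0
    for i in range(1, len(parts) + 1):
        w *= weights.get("/".join(parts[:i]), 1.0)
    return w


def aggregate_sketch(
    targets: Dict[str, Callable[[], float]], weights: Dict[str, float]
) -> Tuple[float, Dict[str, float]]:
    losses: Dict[str, float] = {}
    total = 0.0
    for name, fn in targets.items():
        w = _effective_weight(weights, name)
        if w == 0.0:
            continue  # zero-weight targets are not evaluated at all
        losses[name] = fn()  # lazy: the callable runs only here
        total += w * losses[name]
    return total, losses


calls: List[str] = []


def make_target(name: str, value: float) -> Callable[[], float]:
    def target() -> float:
        calls.append(name)  # record which targets were actually evaluated
        return value
    return target


targets = {
    "xray/work": make_target("xray/work", 4.0),
    "geometry/bond": make_target("geometry/bond", 2.0),
    "adp/simu": make_target("adp/simu", 100.0),
}
weights = {"geometry": 0.5, "adp": 0.0}
total, losses = aggregate_sketch(targets, weights)
```

Here the group weight ‘adp’ of 0.0 switches off every target beneath it without touching the registered callables.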

get_loss(name)[source]

Get a cached loss value (after aggregate() was called).

Parameters:

name (str) – Target name.

Returns:

Cached loss, or None if not computed.

Return type:

torch.Tensor or None

active_parameters()[source]

Return the set of leaf nn.Parameter objects that registered targets’ backward passes will accumulate gradient into.

Populated incrementally by register_target() via a one-shot probe forward + autograd graph walk — calling this method does not run any forward, walk any graph, or evaluate any target. The result is conservative: a target whose weight is later set to 0 still contributes its leaves here, which is harmless for the freezing logic in step() (it can only over-freeze, never under-freeze).

refresh_loss_leaves()[source]

Re-probe every registered target and rebuild _loss_leaves and the resettable-modules cache.

Use this after external code has replaced parameter identity on the underlying model — for example after Model.freeze() / Model.unfreeze() (which rebuild refinable_params tensors). Under normal step()/run() usage no parameter identity ever changes, so this method is rarely needed.

reset_caches()[source]

Call reset_cache() on every registered target’s submodules that expose one. Invoked automatically at the end of step().

restore_loss_leaf_grads()[source]

Unconditionally re-enable requires_grad on every leaf in self._loss_leaves. Called at the end of step() so the next call sees a clean, fully-differentiable model regardless of what state the previous step (or external code) left things in.
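
The freeze-then-restore cycle can be illustrated without PyTorch. In this sketch the Leaf class is a hypothetical stand-in for a leaf tensor that only tracks requires_grad, and both helper functions are illustrative names, not torchref API.

```python
from typing import Iterable, List


class Leaf:
    """Hypothetical stand-in for a leaf tensor."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.requires_grad = True


def freeze_unoptimized(
    loss_leaves: Iterable[Leaf], optimizer_params: Iterable[Leaf]
) -> List[Leaf]:
    # Compare by identity: real tensors overload == elementwise.
    optimized = {id(p) for p in optimizer_params}
    frozen = []
    for leaf in loss_leaves:
        if id(leaf) not in optimized and leaf.requires_grad:
            leaf.requires_grad = False
            frozen.append(leaf)
    return frozen


def restore_all(loss_leaves: Iterable[Leaf]) -> None:
    # Unconditional: does not depend on what was frozen earlier
    for leaf in loss_leaves:
        leaf.requires_grad = True


xyz, adp, occ = Leaf("xyz"), Leaf("adp"), Leaf("occ")
loss_leaves = [xyz, adp, occ]
frozen = freeze_unoptimized(loss_leaves, optimizer_params=[xyz])
frozen_names = [leaf.name for leaf in frozen]
state_during = [leaf.requires_grad for leaf in loss_leaves]
restore_all(loss_leaves)
state_after = [leaf.requires_grad for leaf in loss_leaves]
```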

run(optimizer, log=False, nsteps=1, *, context='loss_state.step')[source]

Run optimizer.step(closure) for nsteps steps (a single step by default).

Builds the closure, validates each loss for finiteness via torchref.utils.validate_loss(), and on failure zeros the gradients and returns +inf so the strong-Wolfe line search backtracks. Automatically disables requires_grad on any leaf that the loss touches but the optimizer was not constructed with — autograd then prunes those subgraphs from the backward pass.

Technically this should work with all optimizers in PyTorch that support closures, but it has only been tested with LBFGS so far. The closure is built to be as general as possible, so if you have a custom optimizer that supports closures it should “just work” with this method.

Every collected reset_cache-bearing submodule is reset before the optimizer step so the closure’s first forward sees a clean cache (a previous rejected closure may have stored a NaN/inf forward result that the fingerprint would happily serve again if parameter values haven’t changed).

After the run we call maintenance on all targets.

On exit, requires_grad=True is unconditionally re-enabled on every leaf in self._loss_leaves — defending against state bleeding between successive refinement methods.

Parameters:
  • optimizer (torch.optim.Optimizer) – Optimizer to step. Its param_groups define the intent — the leaves the caller actually wants to update.

  • log (bool) – If True, calls aggregate(log_values=True) before and after the optimization loop.

  • nsteps (int) – Number of steps to run (default 1). Only the first step’s closure caching is enabled between multiple steps. If you want to run truly independent steps, call this method multiple times with nsteps=1. This adds overhead but might be desirable if the overhead is negligible anyway.

  • context (str) – Diagnostic label forwarded to validate_loss.

Returns:

The loss tensor from the last accepted closure call, or None if no closure call succeeded (every call produced non-finite loss).

Return type:

torch.Tensor or None
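
The validation behaviour of the closure (non-finite loss, zero the gradients, return +inf) can be sketched with floats. The names make_closure, compute_loss, and zero_grad are hypothetical; the real closure also runs the backward pass and delegates the finiteness check to validate_loss.

```python
import math
from typing import Callable, List


def make_closure(
    compute_loss: Callable[[], float], zero_grad: Callable[[], None]
) -> Callable[[], float]:
    """Wrap a loss function so non-finite values are reported as +inf."""

    def closure() -> float:
        loss = compute_loss()
        if not math.isfinite(loss):
            zero_grad()      # discard any gradients from the bad evaluation
            return math.inf  # signals the line search to backtrack
        return loss

    return closure


events: List[str] = []
values = iter([float("nan"), 2.5])  # first evaluation fails, second succeeds
closure = make_closure(lambda: next(values), lambda: events.append("zero_grad"))
first = closure()
second = closure()
```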

step(optimizer, *args, **kwargs)[source]

Convenience method that calls run() with 1 step.

get_breakdown()[source]

Get breakdown of losses by group.

Returns:

Nested dict: {group: {component: {‘loss’: …, ‘weight’: …, ‘weighted’: …}}}

Return type:

Dict

get_group_totals()[source]

Get total weighted loss per group.

Returns:

{group_name: total_weighted_loss}

Return type:

Dict[str, float]
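
The two return shapes can be illustrated as follows. This sketch assumes the group is the first ‘/’-separated segment of the target name and the component is the remainder; build_breakdown and group_totals are hypothetical names, and floats stand in for tensors.

```python
from collections import defaultdict
from typing import Dict


def build_breakdown(
    losses: Dict[str, float], effective_weights: Dict[str, float]
) -> Dict[str, Dict[str, Dict[str, float]]]:
    out: Dict[str, Dict[str, Dict[str, float]]] = defaultdict(dict)
    for name, loss in losses.items():
        group, _, component = name.partition("/")  # split at the first '/'
        w = effective_weights.get(name, 1.0)
        out[group][component] = {"loss": loss, "weight": w, "weighted": w * loss}
    return dict(out)


def group_totals(
    breakdown: Dict[str, Dict[str, Dict[str, float]]]
) -> Dict[str, float]:
    # Sum the weighted contributions of all components in each group
    return {
        group: sum(entry["weighted"] for entry in components.values())
        for group, components in breakdown.items()
    }


losses = {"geometry/bond": 2.0, "geometry/angle": 4.0, "xray/work": 3.0}
effective = {"geometry/bond": 0.5, "geometry/angle": 0.5}
breakdown = build_breakdown(losses, effective)
totals = group_totals(breakdown)
```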

format_breakdown()[source]

Return per-target loss / weight / weighted / finite as a string.

One row per target currently in self._losses (populated by the most recent eager aggregate() call). Used by both summary() and torchref.utils.validate_loss() so the diagnostic format does not drift.

summary()[source]

Print a per-target loss breakdown to stdout.

to(*args, **kwargs)[source]

Move via DeviceMixin; honour an explicit device when no tensors exist yet.

clear()[source]

Clear cached losses (not targets or weights).

clear_history()[source]

Clear history log.

__init__(device=<factory>, targets=<factory>, weights=<factory>, history=<factory>, _losses=<factory>, _compilable=<factory>, _compiled_aggregate=None, _loss_leaves=<factory>, _resettable_modules=<factory>, meta=<factory>)
torchref.refinement.loss_state.create_loss_state(device, targets=None, weights=None)[source]

Factory function to create a LossState.

Parameters:
  • device (torch.device) – Computation device.

  • targets (Dict[str, Callable], optional) – Initial targets to register.

  • weights (Dict[str, float], optional) – Initial weights to set.

Returns:

Configured LossState instance.

Return type:

LossState