torchref.refinement.loss_state module
LossState - Hierarchical loss computation with lazy evaluation.
Design:
- Targets are stored as callables and evaluated only on aggregation.
- Hierarchical naming via the ‘/’ separator (e.g., ‘geometry/bond’, ‘adp/simu’).
- Weights are computed at initialization, not recalculated each pass.
- Internal history logging for debugging and analysis.
- Aggregation is handled directly in this class.
- Loss functions with zeroed weight are not evaluated at all.
- Implements optimization closures and data handling.
- Automatically detects parameter root tensors and disables requires_grad on any that the loss touches but the optimizer wasn’t built to update, enabling dynamic caching.
Example:
    state = LossState(device)
    state.register_target('xray/work', xray_work_target)
    state.register_target('geometry/bond', bond_target)
    state.set_weight('xray', 1.0)
    state.set_weight('geometry', 0.5)
    state.set_weight('geometry/bond', 1.0)
    total = state.aggregate()  # Evaluates targets, applies hierarchical weights
    state.step(optimizer)  # Runs an optimizer step with a closure that validates the loss and auto-freezes parameters depending on the loss graph
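The lazy-evaluation idea from the design notes can be illustrated with a small stand-in class. This is only a sketch in plain Python, not torchref's implementation: `ToyLossState` is a hypothetical name, losses are plain floats rather than tensors, and group weights are ignored here for brevity.

```python
from typing import Callable, Dict


class ToyLossState:
    """Hypothetical stand-in for LossState; plain floats instead of tensors."""

    def __init__(self) -> None:
        self.targets: Dict[str, Callable[[], float]] = {}
        self.weights: Dict[str, float] = {}

    def register_target(self, name: str, target: Callable[[], float]) -> "ToyLossState":
        self.targets[name] = target  # stored only; not called here
        return self

    def set_weight(self, name: str, value: float) -> "ToyLossState":
        self.weights[name] = value
        return self

    def aggregate(self) -> float:
        total = 0.0
        for name, target in self.targets.items():
            weight = self.weights.get(name, 1.0)
            if weight == 0.0:
                continue  # zero-weight targets are never evaluated
            total += weight * target()  # evaluation happens only here
        return total


calls = []


def bond() -> float:
    calls.append("bond")
    return 2.0


def angle() -> float:
    calls.append("angle")
    return 10.0


state = ToyLossState()
state.register_target("geometry/bond", bond).register_target("geometry/angle", angle)
state.set_weight("geometry/bond", 0.5).set_weight("geometry/angle", 0.0)
total = state.aggregate()
```

After `aggregate()`, `calls` contains only `"bond"`: the zero-weight `angle` target was skipped entirely, and nothing ran at registration time.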
- class torchref.refinement.loss_state.LossState(device=<factory>, targets=<factory>, weights=<factory>, history=<factory>, _losses=<factory>, _compilable=<factory>, _compiled_aggregate=None, _loss_leaves=<factory>, _resettable_modules=<factory>, meta=<factory>)[source]
Bases: DeviceMixin
Hierarchical loss state with lazy evaluation.
- device
Computation device.
- Type:
- targets
Target functions keyed by hierarchical name (e.g., ‘geometry/bond’).
- Type:
Dict[str, Callable]
- weights
Weights keyed by name. Can be group weights (‘geometry’) or component weights (‘geometry/bond’).
- history
Log of computed values per aggregation call.
- Type:
List[Dict]
- get(key, default=None)[source]
Get value with default fallback.
- Parameters:
key (str) – Key to look up.
default (Any) – Value to return if key not found.
- Returns:
Value from meta, _losses, or default.
- Return type:
Any
- cache_losses(force=False)[source]
Cache all target losses.
Evaluates all registered targets and stores results in _losses.
- register_target(name, target, prefix=None, compile=False, probe=True)[source]
Register a target function.
Automatically detects combined targets (like TotalGeometryTarget, TotalADPTarget) and expands them into their component targets.
- Parameters:
name (str) – Hierarchical name (e.g., ‘geometry/bond’, ‘adp/simu’).
target (Callable) – Function that returns a loss tensor when called. Can also be a combined target with .items() method, which will be auto-expanded.
prefix (str, optional) – Prefix to prepend to the name (e.g., ‘model1’ -> ‘model1/geometry/bond’). Useful for registering targets from multiple models in the same state.
compile (bool) – If True, mark this target (or all its sub-targets if combined) as eligible for the compiled aggregate closure built by compile_aggregate().
probe (bool) – If True (default), run the target’s forward once, walk the autograd graph, and merge the resulting leaf set into self._loss_leaves. The target’s dependencies (model loaded, data attached, etc.) must therefore be in place before registration. Set probe=False to skip the probe; the leaf-set entry for this target will then be empty, so step()/run() will not auto-disable any leaves on its account. Useful only for targets whose forward genuinely cannot be called at registration time.
- Returns:
Self for chaining.
- Return type:
LossState
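The auto-expansion of combined targets and the prefix handling might look roughly like the following. This is a hedged sketch: `expand_target` is a hypothetical helper, a plain dict stands in for a combined target exposing `.items()`, and the choice to nest sub-targets under the combined target's name is an assumption rather than documented behaviour.

```python
from typing import Callable, Dict, Optional


def expand_target(name: str, target, prefix: Optional[str] = None) -> Dict[str, Callable]:
    """Return {full_name: callable}; expand anything that exposes .items()."""
    base = f"{prefix}/{name}" if prefix else name
    if hasattr(target, "items"):
        # Combined target: one entry per component (naming scheme is assumed).
        return {f"{base}/{sub}": fn for sub, fn in target.items()}
    return {base: target}


# A dict stands in for a combined target such as TotalGeometryTarget.
combined = {"bond": lambda: 1.0, "angle": lambda: 2.0}
expanded = expand_target("geometry", combined, prefix="model1")
single = expand_target("xray/work", lambda: 3.0)
```

Here `expanded` has the keys `model1/geometry/bond` and `model1/geometry/angle`, while `single` keeps the one name `xray/work`.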
- register_targets(targets, prefix=None, compile=False, probe=True)[source]
Register multiple targets from a component target or dict.
For targets with a .name attribute, uses target.name as the key. For plain callables, uses the dict key.
- Parameters:
targets (dict) – Dictionary of name -> target mappings.
prefix (str, optional) – Prefix to prepend to all target names.
compile (bool) – If True, propagate the compile flag to all sub-targets.
probe (bool) – Forwarded to register_target().
- get_effective_weight(name)[source]
Get effective weight for a target, including group weights.
For ‘geometry/bond’, returns: weights[‘geometry’] * weights[‘geometry/bond’] Missing weights default to 1.0.
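The rule above can be sketched as a product over the name and its ancestors. `effective_weight` is a hypothetical free function operating on a plain dict; for two-level names it reproduces the documented formula, while its treatment of deeper names (multiplying every ancestor prefix) is an assumption.

```python
from typing import Dict


def effective_weight(weights: Dict[str, float], name: str) -> float:
    """Multiply the weights of every prefix of `name`; missing keys count as 1.0."""
    parts = name.split("/")
    result = 1.0
    for depth in range(1, len(parts) + 1):
        key = "/".join(parts[:depth])  # 'geometry', then 'geometry/bond'
        result *= weights.get(key, 1.0)
    return result


weights = {"xray": 1.0, "geometry": 0.5, "geometry/bond": 2.0}
bond_w = effective_weight(weights, "geometry/bond")    # 0.5 * 2.0
angle_w = effective_weight(weights, "geometry/angle")  # 0.5 * 1.0 (component weight missing)
work_w = effective_weight(weights, "xray/work")        # 1.0 * 1.0
```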
- mark_compilable(names)[source]
Mark already-registered targets as eligible for the compiled aggregate.
- compile_aggregate(**compile_kwargs)[source]
Build and cache a torch.compile’d closure over all compilable targets.
Must be called after all targets and weights have been registered. Re-call if weights or compilable targets change (or call reset_compiled_aggregate()).
- Parameters:
**compile_kwargs – Keyword arguments forwarded to torch.compile. Defaults to fullgraph=False so partial-graph fallback is allowed.
- Returns:
Self for chaining.
- Return type:
LossState
- reset_compiled_aggregate()[source]
Clear the cached compiled closure (e.g. after changing weights).
- log(name, value)[source]
Log a value to the current history entry.
Creates a new history entry if needed.
- Parameters:
name (str) – Key for the logged value.
value (Any) – Value to log. Tensors are converted to Python floats.
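A minimal sketch of this logging behaviour, under stated assumptions: `history` is a module-level list here rather than an attribute, `FakeScalar` is a hypothetical stand-in for a 0-dimensional tensor, and a new entry is created only when the history is empty (the real rule for when an entry is "needed" is not documented here).

```python
from typing import Any, Dict, List

history: List[Dict[str, Any]] = []


def log(name: str, value: Any) -> None:
    """Write `value` into the current (last) history entry."""
    if not history:
        history.append({})  # assumed rule: start an entry when none exists
    if hasattr(value, "item"):
        value = value.item()  # tensor-like scalars become Python numbers
    history[-1][name] = value


class FakeScalar:
    """Hypothetical stand-in for a scalar tensor exposing .item()."""

    def __init__(self, v: float) -> None:
        self._v = v

    def item(self) -> float:
        return self._v


log("geometry/bond", FakeScalar(1.5))
log("total", 3)
```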
- aggregate(log_values=False)[source]
Evaluate all targets and compute weighted sum.
When compile_aggregate() has been called and log_values=False, the compilable targets are evaluated through a single torch.compile’d closure for improved performance. With log_values=True all targets run eagerly so per-target losses are available in _losses.
- Parameters:
log_values (bool) – If True, log all losses, weights, and total to history.
- Returns:
Total weighted loss.
- Return type:
torch.Tensor
- get_loss(name)[source]
Get a cached loss value (after aggregate() was called).
- Parameters:
name (str) – Target name.
- Returns:
Cached loss, or None if not computed.
- Return type:
torch.Tensor or None
- active_parameters()[source]
Return the set of leaf nn.Parameter objects that registered targets’ backward passes will accumulate gradient into.
Populated incrementally by register_target() via a one-shot probe forward + autograd graph walk; calling this method does not run any forward, walk any graph, or evaluate any target. The result is conservative: a target whose weight is later set to 0 still contributes its leaves here, which is harmless for the freezing logic in step() (it can only over-freeze, never under-freeze).
- refresh_loss_leaves()[source]
Re-probe every registered target and rebuild _loss_leaves and the resettable-modules cache.
Use this after external code has replaced parameter identity on the underlying model, for example after Model.freeze()/Model.unfreeze() (which rebuild refinable_params tensors). Under normal step()/run() usage no parameter identity ever changes, so this method is rarely needed.
- reset_caches()[source]
Call reset_cache() on every submodule of a registered target that exposes one. Invoked automatically at the end of step().
- restore_loss_leaf_grads()[source]
Unconditionally re-enable requires_grad on every leaf in self._loss_leaves. Called at the end of step() so the next call sees a clean, fully-differentiable model regardless of what state the previous step (or external code) left things in.
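The freeze-then-restore cycle can be sketched without PyTorch. Everything here is a toy: `Leaf` is a hypothetical stand-in for a leaf tensor that only tracks `requires_grad`, and `freeze_unoptimized` / `restore` are illustrative names. Only the `param_groups` layout (a list of dicts with a `"params"` key) mirrors the real `torch.optim.Optimizer` attribute.

```python
class Leaf:
    """Hypothetical stand-in for a leaf tensor; tracks requires_grad only."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.requires_grad = True


def freeze_unoptimized(loss_leaves, param_groups):
    """Disable grad on leaves the loss touches but the optimizer does not own."""
    optimized = {id(p) for group in param_groups for p in group["params"]}
    frozen = [leaf for leaf in loss_leaves if id(leaf) not in optimized]
    for leaf in frozen:
        leaf.requires_grad = False
    return frozen


def restore(loss_leaves):
    """Unconditionally re-enable grad on every known loss leaf."""
    for leaf in loss_leaves:
        leaf.requires_grad = True


xyz, adp, scale = Leaf("xyz"), Leaf("adp"), Leaf("scale")
loss_leaves = [xyz, adp, scale]
param_groups = [{"params": [xyz]}]  # the optimizer was built for xyz only

frozen_names = [leaf.name for leaf in freeze_unoptimized(loss_leaves, param_groups)]
during = [leaf.requires_grad for leaf in loss_leaves]
restore(loss_leaves)
after = [leaf.requires_grad for leaf in loss_leaves]
```

Restoring over the full leaf set, rather than only the leaves frozen in this call, is what makes the reset robust against state left behind by earlier steps or external code.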
- run(optimizer, log=False, nsteps=1, *, context='loss_state.step')[source]
Run optimizer.step(closure), nsteps times (once by default).
Builds the closure, validates each loss for finiteness via torchref.utils.validate_loss(), and on failure zeros the gradients and returns +inf so the strong-Wolfe line search backtracks. Automatically disables requires_grad on any leaf that the loss touches but the optimizer was not constructed with; autograd then prunes those subgraphs from the backward pass.
In principle this should work with every PyTorch optimizer that supports closures, but it has only been tested with LBFGS so far. The closure is built to be as general as possible, so a custom optimizer that supports closures should “just work” with this method.
Every collected reset_cache-bearing submodule is reset before the optimizer step so the closure’s first forward sees a clean cache (a previous rejected closure may have stored a NaN/inf forward result that the fingerprint would happily serve again if parameter values haven’t changed).
After the run, maintenance is called on all targets.
On exit, requires_grad=True is unconditionally re-enabled on every leaf in self._loss_leaves, defending against state bleeding between successive refinement methods.
- Parameters:
optimizer (torch.optim.Optimizer) – Optimizer to step. Its param_groups define the intent: the leaves the caller actually wants to update.
log (bool) – If True, calls aggregate(log_values=True) before and after the optimization loop.
nsteps (int) – Number of steps to run (default 1). Only the first step’s closure caching is enabled between multiple steps. To run truly independent steps, call this method multiple times with nsteps=1. This adds overhead, which may be acceptable when it is negligible.
context (str) – Diagnostic label forwarded to validate_loss.
- Returns:
The loss tensor from the last accepted closure call, or None if no closure call succeeded (every call produced non-finite loss).
- Return type:
torch.Tensor or None
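The closure's handling of non-finite losses can be sketched in plain Python. This is a simplified illustration, not the actual closure: `make_closure` is a hypothetical helper, losses are floats, `math.isfinite` stands in for `torchref.utils.validate_loss()`, and the backward pass is omitted.

```python
import math


def make_closure(compute_loss, zero_grad, accepted):
    """Build a closure that rejects non-finite losses by returning +inf."""

    def closure():
        loss = compute_loss()
        if not math.isfinite(loss):
            zero_grad()      # discard any gradients from the bad evaluation
            return math.inf  # signals the line search to backtrack
        accepted.append(loss)
        return loss

    return closure


values = iter([4.0, float("nan"), 1.0])  # simulated successive loss evaluations
zeroed = []
accepted = []
closure = make_closure(lambda: next(values), lambda: zeroed.append(True), accepted)

results = [closure(), closure(), closure()]
last_accepted = accepted[-1] if accepted else None  # None if every call failed
```

The second call hits the NaN, zeros the gradients once, and reports `+inf`; the value returned to the caller is the last loss that passed validation.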
- step(optimizer, *args, **kwargs)[source]
Convenience method that calls run() with 1 step.
- Parameters:
optimizer (torch.optim.Optimizer) – Optimizer to run.
*args – Forwarded to run().
**kwargs – Forwarded to run().
- get_breakdown()[source]
Get breakdown of losses by group.
- Returns:
Nested dict: {group: {component: {‘loss’: …, ‘weight’: …, ‘weighted’: …}}}
- Return type:
Dict
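The nested shape of the returned dict can be illustrated as follows. This is a sketch under assumptions: `breakdown` is a hypothetical free function over plain dicts of floats, names are split at the first ‘/’, and ‘weight’ is taken to be the effective weight (group times component), which the description above does not specify.

```python
from typing import Dict


def breakdown(losses: Dict[str, float], weights: Dict[str, float]) -> Dict[str, Dict[str, Dict[str, float]]]:
    """Group per-target losses as {group: {component: {...}}}."""
    out: Dict[str, Dict[str, Dict[str, float]]] = {}
    for name, loss in losses.items():
        group, sep, component = name.partition("/")
        if sep:
            weight = weights.get(group, 1.0) * weights.get(name, 1.0)
        else:
            component = group  # ungrouped name: use it as its own component
            weight = weights.get(name, 1.0)
        out.setdefault(group, {})[component] = {
            "loss": loss,
            "weight": weight,
            "weighted": weight * loss,
        }
    return out


losses = {"geometry/bond": 2.0, "geometry/angle": 4.0, "xray/work": 10.0}
weights = {"geometry": 0.5, "geometry/bond": 2.0}
result = breakdown(losses, weights)
```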
- format_breakdown()[source]
Return per-target loss / weight / weighted / finite as a string.
One row per target currently in self._losses (populated by the most recent eager aggregate() call). Used by both summary() and torchref.utils.validate_loss() so the diagnostic format does not drift.
- to(*args, **kwargs)[source]
Move via DeviceMixin; honour an explicit device when no tensors exist yet.
- __init__(device=<factory>, targets=<factory>, weights=<factory>, history=<factory>, _losses=<factory>, _compilable=<factory>, _compiled_aggregate=None, _loss_leaves=<factory>, _resettable_modules=<factory>, meta=<factory>)
- torchref.refinement.loss_state.create_loss_state(device, targets=None, weights=None)[source]
Factory function to create a LossState.
- Parameters:
device (torch.device) – Computation device.
targets (Dict[str, Callable], optional) – Initial targets to register.
weights (Dict[str, float], optional) – Initial weights to set.
- Returns:
Configured LossState instance.
- Return type:
LossState