"""
LossState - Hierarchical loss computation with lazy evaluation.
Design:
- Targets stored as callables, evaluated only on aggregation
- Hierarchical naming via '/' separator (e.g., 'geometry/bond', 'adp/simu')
- Weights computed at initialization, not recalculated each pass
- Internal history logging for debugging/analysis
- Aggregation handled directly in this class
- Loss functions with zeroed weight are not evaluated at all.
- Implements optimization closures and data handling
- Automatically detects parameter root tensors and disables requires_grad on any that the loss touches but the optimizer wasn't built to update, enabling dynamic caching.
Example:
    state = LossState(device)
    state.register_target('xray/work', xray_work_target)
    state.register_target('geometry/bond', bond_target)
    state.set_weight('xray', 1.0)
    state.set_weight('geometry', 0.5)
    state.set_weight('geometry/bond', 1.0)
    total = state.aggregate()  # Evaluates targets, applies hierarchical weights
    state.step(optimizer)  # Runs one optimizer step with a closure that validates the loss and auto-freezes parameters depending on the loss graph.
"""
import warnings
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Set
import torch
from torch import nn
from torchref.config import get_default_device
from torchref.utils.autograd_introspection import collect_loss_leaves, _iter_roots
from torchref.utils.device_mixin import DeviceMovementMixin
from torchref.utils.loss_validation import validate_loss
class LossStateWarning(UserWarning):
    """Warning category for performance hints raised by :class:`LossState`.

    Deriving from ``UserWarning`` keeps the warning visible under the
    default filters, while the dedicated subclass lets callers silence or
    escalate it on its own through the :mod:`warnings` filter machinery.
    """
[docs]
@dataclass
class LossState(DeviceMovementMixin):
"""
Hierarchical loss state with lazy evaluation.

Targets are stored as zero-argument callables and are only evaluated when
:meth:`aggregate` or :meth:`cache_losses` runs. Weights are looked up by
hierarchical name, using ``'/'`` as the separator.

Attributes
----------
device : torch.device
    Computation device. Defaults to ``get_default_device()``.
targets : Dict[str, Callable]
    Target functions keyed by hierarchical name (e.g., 'geometry/bond').
weights : Dict[str, float]
    Weights keyed by name. Can be group weights ('geometry') or
    component weights ('geometry/bond'). Missing entries count as 1.0
    in :meth:`get_effective_weight`.
history : List[Dict]
    Log of computed values per aggregation call.
meta : Dict[str, Any]
    Model-level data (rwork, rfree, n_atoms, etc.) populated by refinement.
"""
device: torch.device = field(default_factory=get_default_device)
# Targets as callables - only evaluated on aggregate() / cache_losses()
targets: Dict[str, Callable] = field(default_factory=dict)
# Weights keyed by hierarchical name; filled through set_weight()/set_weights()
weights: Dict[str, float] = field(default_factory=dict)
# History log: one dict per entry, appended by new_entry() and filled by log()
history: List[Dict[str, Any]] = field(default_factory=list)
# Cache for computed per-target losses (cleared at the start of each aggregate())
_losses: Dict[str, torch.Tensor] = field(default_factory=dict, repr=False)
# Set of target keys marked as compilable
_compilable: Set[str] = field(default_factory=set, repr=False)
# Cached compiled callable; None until compile_aggregate() is called, and
# reset to None whenever targets, weights or the compilable set change
_compiled_aggregate: Optional[Callable] = field(default=None, repr=False)
# Union of leaf nn.Parameters that registered targets' backward will
# accumulate into. Populated incrementally during register_target via a
# one-shot probe forward + autograd graph walk. Used by step()/run() to
# diff against the optimizer's intent and disable requires_grad on the
# leaves the loss touches but the optimizer wasn't built to update.
_loss_leaves: Set[nn.Parameter] = field(default_factory=set, repr=False)
# Submodules attached to registered targets that expose a reset_cache
# method (e.g. ModelFT and its CachedForwardMixin wrappers). Collected
# once at registration time; reset_caches() walks this list and is called
# by run() before the optimizer steps, so that validate_loss-rejected
# closures or stale forward-cache entries can't silently poison the next
# forward.
_resettable_modules: List[nn.Module] = field(default_factory=list, repr=False)
# Model-level data for weighting schemes; also served by __getitem__/get()
meta: Dict[str, Any] = field(default_factory=dict)
# =========================================================================
# Item Access (meta and _losses)
# =========================================================================
[docs]
def __getitem__(self, key: str) -> Any:
    """
    Look ``key`` up in ``meta`` first and in the cached losses second.

    Parameters
    ----------
    key : str
        Name to resolve.

    Returns
    -------
    Any
        The stored value; an entry in ``meta`` shadows one in ``_losses``.

    Raises
    ------
    KeyError
        If neither mapping contains ``key``.
    """
    # meta takes precedence, so it is probed first.
    for store in (self.meta, self._losses):
        if key in store:
            return store[key]
    raise KeyError(f"Key '{key}' not found in meta or _losses")
[docs]
def __contains__(self, key: str) -> bool:
    """Return True when ``key`` is present in ``meta`` or in ``_losses``."""
    return any(key in store for store in (self.meta, self._losses))
[docs]
def get(self, key: str, default: Any = None) -> Any:
    """
    Dictionary-style lookup that never raises.

    Parameters
    ----------
    key : str
        Name to resolve.
    default : Any
        Returned when ``key`` is in neither ``meta`` nor ``_losses``.

    Returns
    -------
    Any
        The ``meta`` entry if present, otherwise the ``_losses`` entry,
        otherwise ``default``.
    """
    for store in (self.meta, self._losses):
        if key in store:
            return store[key]
    return default
[docs]
def cache_losses(self, force: bool = False) -> "LossState":
    """
    Evaluate registered targets and store their results in ``_losses``.

    Parameters
    ----------
    force : bool
        When True the cache is emptied first, so every target is
        re-evaluated. When False only targets without a cached value
        are called.

    Returns
    -------
    LossState
        Self for chaining.
    """
    if force:
        self._losses.clear()
    for key, fn in self.targets.items():
        if key in self._losses:
            continue  # already cached, do not call the target again
        self._losses[key] = fn()
    return self
# =========================================================================
# Target Registration
# =========================================================================
[docs]
def register_target(
    self,
    name: str,
    target: Callable,
    prefix: str = None,
    compile: bool = False,
    probe: bool = True,
) -> "LossState":
    """
    Register one target, expanding combined targets into their parts.

    A target that exposes a callable ``items`` attribute is treated as a
    combined target and handed to :meth:`register_targets`, with the full
    key of this call used as the prefix for its components.

    Parameters
    ----------
    name : str
        Hierarchical name (e.g., 'geometry/bond', 'adp/simu').
    target : Callable
        Callable returning a loss tensor, or a combined target with an
        ``items()`` method.
    prefix : str, optional
        Prepended to ``name`` with a '/' separator
        (e.g., 'model1' -> 'model1/geometry/bond').
    compile : bool
        If True, the target (or every sub-target of a combined target) is
        added to the set used by :meth:`compile_aggregate`.
    probe : bool
        If True (default), the target is called once right away and the
        leaves of its autograd graph are merged into ``_loss_leaves``,
        so everything the target needs must already be in place. With
        ``probe=False`` no leaves are recorded for this target.

    Returns
    -------
    LossState
        Self for chaining.
    """
    # Any previously compiled closure no longer matches the target set.
    self._compiled_aggregate = None

    full_key = f"{prefix}/{name}" if prefix else name

    # Dictionary-like targets are expanded; their own key becomes the
    # prefix so the hierarchy is kept ("geometry" -> "geometry/bond").
    if callable(getattr(target, "items", None)):
        return self.register_targets(
            target, prefix=full_key, compile=compile, probe=probe
        )

    self.targets[full_key] = target
    if compile:
        self._compilable.add(full_key)
    if probe:
        self._probe_and_merge_leaves(target)
    self._collect_resettable_modules(target)
    return self
def _collect_resettable_modules(self, target: Callable) -> None:
    """Record submodules of ``target`` that expose a callable
    ``reset_cache`` attribute.

    Only ``nn.Module`` targets are inspected; anything else is ignored.
    Modules already present in ``_resettable_modules`` are skipped, so
    registering overlapping targets does not create duplicates.
    :meth:`reset_caches` later walks the collected list.
    """
    if not isinstance(target, nn.Module):
        return

    # Deduplicate by object identity against what was collected earlier.
    known = {id(existing) for existing in self._resettable_modules}
    for sub in target.modules():
        if not callable(getattr(sub, "reset_cache", None)) or id(sub) in known:
            continue
        known.add(id(sub))
        self._resettable_modules.append(sub)
def _probe_and_merge_leaves(self, target: Callable) -> None:
    """Run ``target()`` once with grad enabled, walk the autograd graph,
    and union the resulting leaves into ``self._loss_leaves``.

    ``target()`` may return a Tensor, a tuple/list of tensors, or a dict
    of tensors — :func:`collect_loss_leaves` is handed the result as-is.

    Convention: targets are expected to be probed while every parameter
    the loss should track has ``requires_grad=True``. If the probe finds
    zero leaves — meaning every root is either constant
    (``grad_fn is None``) or only depends on tensors with
    ``requires_grad=False`` — a :class:`LossStateWarning` is emitted.
    This is almost never what the caller intended (a target with no
    trainable parameters contributes nothing to gradient-based
    optimization). It is *not* an error: missing leaves only cost a bit
    of extra backward work inside the optimization loop, never
    correctness.

    Parameters
    ----------
    target : Callable
        Zero-argument callable producing the loss root(s) to inspect.
    """
    # Force grad mode on so the probe builds a graph even when the caller
    # is inside a ``torch.no_grad()`` block.
    with torch.enable_grad():
        roots = target()
    new_leaves = collect_loss_leaves(roots)

    if not new_leaves:
        # Best-effort label for the message; targets may or may not carry
        # a ``name`` attribute.
        label = (
            getattr(target, "name", None)
            or getattr(target, "__name__", None)
            or type(target).__name__
        )
        warnings.warn(
            f"Probing target '{label}' found no trainable leaf parameters; "
            "it will not contribute gradients and cannot take part in "
            "automatic parameter freezing.",
            LossStateWarning,
            stacklevel=3,
        )

    self._loss_leaves |= new_leaves
[docs]
def register_targets(
    self,
    targets,
    prefix: str = None,
    compile: bool = False,
    probe: bool = True,
) -> "LossState":
    """Register every entry of a mapping-like collection of targets.

    Each entry is passed to :meth:`register_target`. An entry's own
    ``name`` attribute, when it has one, is used as its name; otherwise
    the mapping key is used.

    Parameters
    ----------
    targets : dict
        Mapping (or any object with ``items()``) of name -> target.
    prefix : str, optional
        Prefix to prepend to all target names.
    compile : bool
        Compile flag forwarded for every entry.
    probe : bool
        Forwarded to :meth:`register_target`.

    Returns
    -------
    LossState
        Self for chaining.
    """
    for key, item in targets.items():
        self.register_target(
            getattr(item, "name", key),
            item,
            prefix=prefix,
            compile=compile,
            probe=probe,
        )
    return self
# =========================================================================
# Weight Management
# =========================================================================
[docs]
def set_weight(self, name: str, weight: float) -> "LossState":
    """
    Store a single weight under ``name``.

    Parameters
    ----------
    name : str
        Either a group key ('geometry') or a component key
        ('geometry/bond').
    weight : float
        Value to store.

    Returns
    -------
    LossState
        Self for chaining.
    """
    self.weights.update({name: weight})
    # Effective weights are captured when the closure is built, so a
    # compiled aggregate is dropped whenever a weight changes.
    self._compiled_aggregate = None
    return self
[docs]
def set_weights(self, weights: Dict[str, float]) -> "LossState":
    """Apply :meth:`set_weight` to every ``name -> weight`` pair and return self."""
    for pair in weights.items():
        self.set_weight(*pair)
    return self
[docs]
def get_weight(self, name: str, default: float = 1.0) -> float:
    """Return the weight stored under ``name``, or ``default`` when unset."""
    if name in self.weights:
        return self.weights[name]
    return default
[docs]
def get_effective_weight(self, name: str) -> float:
    """
    Multiply the weights found along the hierarchy of ``name``.

    For 'geometry/bond' the result is
    ``weights['geometry'] * weights['geometry/bond']``; any level without
    an entry in ``weights`` contributes a factor of 1.0.

    Parameters
    ----------
    name : str
        Target name (e.g., 'geometry/bond').

    Returns
    -------
    float
        Product of the weights at every level of the path.
    """
    lookup = self.weights.get
    product = 1.0
    prefix = ""
    # Grow the key one segment at a time: 'a', then 'a/b', then 'a/b/c'.
    for segment in name.split("/"):
        prefix = f"{prefix}/{segment}" if prefix else segment
        product *= lookup(prefix, 1.0)
    return product
# =========================================================================
# Compiled Aggregate
# =========================================================================
[docs]
def mark_compilable(self, names: List[str]) -> "LossState":
    """
    Flag registered targets for inclusion in the compiled aggregate.

    Parameters
    ----------
    names : List[str]
        Target keys to flag. Keys that are not registered are ignored.

    Returns
    -------
    LossState
        Self for chaining.
    """
    self._compilable.update(key for key in names if key in self.targets)
    # The compilable set may have changed, so drop any compiled closure.
    self._compiled_aggregate = None
    return self
[docs]
def compile_aggregate(self, **compile_kwargs) -> "LossState":
    """
    Build and cache a ``torch.compile``'d closure over the compilable targets.

    Only targets that are flagged compilable and whose effective weight is
    non-zero take part. Effective weights are read once, here, so this
    must be called after targets and weights are in place and re-called
    when they change (or use :meth:`reset_compiled_aggregate`). When no
    target qualifies the cached closure is cleared instead.

    Parameters
    ----------
    **compile_kwargs
        Keyword arguments forwarded to torch.compile. ``fullgraph``
        defaults to False so partial-graph fallback is allowed.

    Returns
    -------
    LossState
        Self for chaining.
    """
    compile_kwargs.setdefault("fullgraph", False)

    fns, weights = [], []
    for key, fn in self.targets.items():
        if key not in self._compilable:
            continue
        w = self.get_effective_weight(key)
        if w != 0.0:
            fns.append(fn)
            weights.append(w)

    if not fns:
        self._compiled_aggregate = None
        return self

    device = self.device

    def _compiled_fn():
        # Weighted sum of the selected targets, in registration order.
        total = torch.tensor(0.0, device=device)
        for fn, w in zip(fns, weights):
            total = total + w * fn()
        return total

    self._compiled_aggregate = torch.compile(_compiled_fn, **compile_kwargs)
    return self
[docs]
def reset_compiled_aggregate(self) -> "LossState":
    """Drop the cached compiled closure (e.g. after changing weights) and return self."""
    setattr(self, "_compiled_aggregate", None)
    return self
# =========================================================================
# History Logging
# =========================================================================
[docs]
def log(self, name: str, value: Any) -> None:
    """
    Record ``value`` under ``name`` in the most recent history entry.

    An entry is created first when the history is still empty.

    Parameters
    ----------
    name : str
        Key for the logged value.
    value : Any
        Value to log. Tensors are detached and converted with ``.item()``;
        everything else is stored unchanged.
    """
    entries = self.history
    if len(entries) == 0:
        entries.append({})
    scalar = value.detach().item() if isinstance(value, torch.Tensor) else value
    entries[-1][name] = scalar
[docs]
def new_entry(self) -> None:
    """Append an empty dict to ``history``; later :meth:`log` calls fill it."""
    self.history.append(dict())
[docs]
def get_history(self, name: str) -> List[Any]:
    """Collect the value logged under ``name`` from every history entry that has it."""
    return [entry[name] for entry in self.history if name in entry]
# =========================================================================
# Aggregation
# =========================================================================
[docs]
def aggregate(self, log_values: bool = False) -> torch.Tensor:
    """
    Evaluate all targets and compute the weighted sum.

    When compile_aggregate() has been called and log_values=False, the
    compilable targets are evaluated through the single torch.compile'd
    closure. With log_values=True (or without a compiled closure) every
    target runs eagerly so per-target losses are available in ``_losses``.

    Eager evaluation always follows registration order (the order of
    ``self.targets``): compilable targets first, then the rest. This
    matches the order used by :meth:`compile_aggregate` and keeps the
    floating-point summation order and the logged key order reproducible
    between runs, instead of depending on set iteration order.

    Targets whose effective weight is 0 are not called at all.

    Parameters
    ----------
    log_values : bool
        If True, start a new history entry and log each loss, weight,
        weighted loss and the total to it.

    Returns
    -------
    torch.Tensor
        Total weighted loss.
    """
    if log_values:
        self.new_entry()

    self._losses.clear()
    total = torch.tensor(0.0, device=self.device)

    compilable = self._compilable
    # The fused closure does not expose per-target losses, so it is
    # bypassed whenever values have to be logged.
    use_compiled = self._compiled_aggregate is not None and not log_values

    if use_compiled:
        total = total + self._compiled_aggregate()
        eager_names = [n for n in self.targets if n not in compilable]
    else:
        # Compilable group first, then the eager group, both in
        # registration order.
        eager_names = [n for n in self.targets if n in compilable]
        eager_names += [n for n in self.targets if n not in compilable]

    for name in eager_names:
        weight = self.get_effective_weight(name)
        if weight == 0.0:
            continue  # zero-weighted targets are never evaluated

        loss = self.targets[name]()
        self._losses[name] = loss
        weighted = weight * loss
        total = total + weighted

        if log_values:
            self.log(f"loss/{name}", loss)
            self.log(f"weight/{name}", weight)
            self.log(f"weighted/{name}", weighted)

    if log_values:
        self.log("total", total)

    return total
[docs]
def get_loss(self, name: str) -> Optional[torch.Tensor]:
    """
    Fetch a per-target loss cached by a previous evaluation.

    Parameters
    ----------
    name : str
        Target name.

    Returns
    -------
    torch.Tensor or None
        The cached loss, or None when nothing is cached under ``name``.
    """
    if name in self._losses:
        return self._losses[name]
    return None
# =========================================================================
# Optimization (step / run / active-parameter introspection)
# =========================================================================
[docs]
def active_parameters(self) -> Set[nn.Parameter]:
    """Return the cached set of leaf ``nn.Parameter``s collected from the
    registered targets.

    The set is filled by the probe performed in :meth:`register_target`
    (and rebuilt by :meth:`refresh_loss_leaves`); this accessor only hands
    back the stored set and evaluates nothing. The set object itself is
    returned, not a copy. Changing a weight does not remove entries, so a
    target whose weight is later set to 0 still contributes its leaves.
    """
    leaves = self._loss_leaves
    return leaves
[docs]
def refresh_loss_leaves(self) -> "LossState":
    """Rebuild ``_loss_leaves`` and ``_resettable_modules`` from scratch.

    Both caches are emptied and every registered target is probed and
    scanned again. Intended for cases where external code has replaced
    the parameter objects of the underlying model — the original notes
    mention :meth:`Model.freeze` / :meth:`Model.unfreeze`, which rebuild
    ``refinable_params`` tensors.

    Returns
    -------
    LossState
        Self for chaining.
    """
    self._loss_leaves, self._resettable_modules = set(), []
    for registered in self.targets.values():
        self._probe_and_merge_leaves(registered)
        self._collect_resettable_modules(registered)
    return self
[docs]
def reset_caches(self) -> None:
    """Call ``reset_cache()`` on every module collected in
    ``_resettable_modules``.

    :meth:`run` invokes this once, before the optimizer starts stepping.
    """
    for cached_module in self._resettable_modules:
        cached_module.reset_cache()
[docs]
def restore_loss_leaf_grads(self) -> None:
    """Switch ``requires_grad`` back on for every leaf in
    ``self._loss_leaves`` that currently has it off.

    :meth:`run` calls this from its ``finally`` clause so the next call
    starts from a fully differentiable set of leaves, whatever the
    previous step (or external code) did to them.
    """
    for leaf in self._loss_leaves:
        if leaf.requires_grad:
            continue
        leaf.requires_grad_(True)
[docs]
def run(
self,
optimizer: torch.optim.Optimizer,
log=False,
nsteps: int = 1,
*,
context: str = "loss_state.step",
) -> Optional[torch.Tensor]:
"""Run ``nsteps`` calls of ``optimizer.step(closure)``.

Builds the closure, validates each loss for finiteness via
:func:`torchref.utils.validate_loss`, and on failure zeros the
gradients and returns ``+inf`` so the strong-Wolfe line search
backtracks. Automatically disables ``requires_grad`` on any leaf
that the loss touches but the optimizer was not constructed with —
autograd then prunes those subgraphs from the backward pass.

Technically this should work with all optimizers in PyTorch that
support closures, but it has only been tested with LBFGS so far. The
closure is built to be as general as possible, so a custom optimizer
that supports closures should "just work" with this method.

Every collected ``reset_cache``-bearing submodule is reset
**before** the optimizer steps so the closure's first forward sees
a clean cache (a previous rejected closure may have stored a
NaN/inf forward result that the fingerprint would happily serve
again if parameter values haven't changed).

After the optimizer steps, ``maintenance()`` is called on every
registered target that exposes a callable attribute of that name.

On exit, ``requires_grad=True`` is unconditionally re-enabled on
every leaf in ``self._loss_leaves`` — defending against state
bleeding between successive refinement methods.

Parameters
----------
optimizer : torch.optim.Optimizer
    Optimizer to step. Its ``param_groups`` define the *intent* —
    the leaves the caller actually wants to update.
log : bool
    If True, calls ``aggregate(log_values=True)`` before and after
    the optimization loop.
nsteps : int
    Number of ``optimizer.step(closure)`` calls to make (default 1).
    Forward caches are reset once before the loop and not between
    steps. For truly independent steps, call this method repeatedly
    with ``nsteps=1``; that adds overhead, which may be acceptable
    when the overhead is negligible anyway.
context : str
    Diagnostic label forwarded to ``validate_loss``.

Returns
-------
torch.Tensor or None
    The loss tensor from the last accepted closure call, or ``None``
    if no closure call succeeded (every call produced non-finite
    loss).
"""
# Parameters the optimizer owns; used for validation and grad zeroing.
params = list(_optimizer_param_set(optimizer))
# Mutable holder so the nested closure can publish the last accepted loss.
last_loss: Dict[str, Optional[torch.Tensor]] = {"val": None}
def closure():
optimizer.zero_grad()
loss = self.aggregate()
loss.backward()
# raise_on_fail=False: a rejected loss is reported through the return
# value instead of an exception.
ok = validate_loss(
loss,
state=self,
parameters=params,
context=context,
raise_on_fail=False,
)
if not ok:
# Rejected loss: wipe the gradients just computed and report +inf.
for p in params:
if p.grad is not None:
p.grad.zero_()
return torch.full_like(loss.detach(), float("inf"))
last_loss["val"] = loss
return loss
if log:
self.aggregate(log_values=True)
# Clear forward caches BEFORE the step so the closure's first
# forward starts from a known-clean state — a previous rejected
# closure may have left a NaN/inf cached fcalc that the fingerprint
# would otherwise serve again unchanged. This helps with robustness but "should" not be necessary.
self.reset_caches()
try:
with _freeze_graph_extras(self, optimizer):
for i in range(nsteps):
optimizer.step(closure)
finally:
# Re-enable grads on every loss leaf regardless of how the
# step exited. Defends against state bleeding between
# successive refinement methods.
self.restore_loss_leaf_grads()
# Post-step maintenance hook: each target decides whether its
# internal state is stale (e.g. NonBondedTarget rebuilds the VDW
# pair list when atoms have drifted too far since the last
# build). Targets that don't care inherit the no-op default
# from ``Target.maintenance``.
# NOTE(review): indentation was lost in this listing, so it cannot be
# told here whether this loop sits inside the ``finally`` clause or
# after it — confirm against the original file.
for target in self.targets.values():
maint = getattr(target, "maintenance", None)
if callable(maint):
maint()
if log:
self.aggregate(log_values=True)
return last_loss["val"]
[docs]
def step(
    self,
    optimizer: torch.optim.Optimizer,
    *args, **kwargs
) -> Optional[torch.Tensor]:
    """Convenience method that calls :meth:`run` with exactly one step.

    Parameters
    ----------
    optimizer : torch.optim.Optimizer
        Optimizer to run.
    *args, **kwargs
        Forwarded to :meth:`run`. ``nsteps`` is fixed to 1 here, so
        passing it as well raises ``TypeError``.

    Returns
    -------
    torch.Tensor or None
        Whatever :meth:`run` returns: the last accepted loss, or ``None``
        when no closure call was accepted. This is not the ``LossState``
        itself, so the call cannot be chained.
    """
    return self.run(optimizer, *args, nsteps=1, **kwargs)
# =========================================================================
# Breakdown / Analysis
# =========================================================================
[docs]
def get_breakdown(self) -> Dict[str, Dict[str, Any]]:
    """
    Arrange the cached losses into a two-level report.

    The text before the first '/' of a target name is the group and the
    remainder is the component; names without a '/' are filed under the
    group "root".

    Returns
    -------
    Dict
        Nested dict: {group: {component: {'loss': ..., 'weight': ..., 'weighted': ...}}}
        Tensor losses are converted with ``.item()``.
    """
    report: Dict[str, Dict[str, Any]] = {}
    for name, loss in self._losses.items():
        head, sep, tail = name.partition("/")
        group, component = (head, tail) if sep else ("root", head)
        weight = self.get_effective_weight(name)
        if isinstance(loss, torch.Tensor):
            raw = loss.item()
            scaled = (weight * loss).item()
        else:
            raw = loss
            scaled = weight * loss
        report.setdefault(group, {})[component] = {
            "loss": raw,
            "weight": weight,
            "weighted": scaled,
        }
    return report
[docs]
def get_group_totals(self) -> Dict[str, float]:
    """
    Sum the weighted cached losses per top-level group.

    The group of a target is the part of its name before the first '/'
    (the whole name when there is none).

    Returns
    -------
    Dict[str, float]
        {group_name: total_weighted_loss}
    """
    totals: Dict[str, float] = {}
    for name, loss in self._losses.items():
        group = name.split("/")[0]
        scaled = self.get_effective_weight(name) * loss
        if isinstance(loss, torch.Tensor):
            scaled = scaled.item()
        totals[group] = totals.get(group, 0.0) + scaled
    return totals
[docs]
def summary(self) -> None:
    """Print a per-target loss breakdown to stdout.

    If the instance provides a callable ``format_breakdown`` its result
    is printed, as before. No such method is defined in this class, so
    when it is missing the report is built from :meth:`get_breakdown`
    instead of failing with ``AttributeError``: one ``[group]`` header per
    group followed by one line per component.
    """
    print("LossState Summary:")

    formatter = getattr(self, "format_breakdown", None)
    if callable(formatter):
        print(formatter())
        return

    def _fmt(value) -> str:
        # Compact form for numbers; anything else is shown as-is.
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return f"{value:.6g}"
        return str(value)

    lines = []
    for group, components in self.get_breakdown().items():
        lines.append(f"  [{group}]")
        for component, entry in components.items():
            lines.append(
                f"    {component}: loss={_fmt(entry['loss'])} "
                f"weight={_fmt(entry['weight'])} "
                f"weighted={_fmt(entry['weighted'])}"
            )
    print("\n".join(lines) if lines else "  (no cached losses)")
# =========================================================================
# Device Management
# =========================================================================
[docs]
def to(self, *args, **kwargs):
    """Move via the parent class's ``to``; honour an explicit device when
    that call left ``device`` unset.

    If the object returned by ``super().to`` does not carry a
    ``torch.device`` in ``device``, the device parsed from the call
    arguments (when one was given) is stored there so subsequent
    allocations land correctly.
    """
    moved = super().to(*args, **kwargs)
    if isinstance(moved.device, torch.device):
        return moved

    # Imported lazily, as in the original, rather than at module level.
    from torchref.utils.device_mixin import _parse_to_args

    requested, _ = _parse_to_args(args, kwargs)
    if requested is not None:
        if isinstance(requested, torch.device):
            moved.device = requested
        else:
            moved.device = torch.device(requested)
    return moved
# =========================================================================
# Utility
# =========================================================================
[docs]
def clear(self) -> "LossState":
    """Empty the cached losses; targets and weights are left untouched."""
    cached = self._losses
    cached.clear()
    return self
[docs]
def clear_history(self) -> "LossState":
    """Remove every entry from the history log and return self."""
    entries = self.history
    entries.clear()
    return self
def __repr__(self) -> str:
    """Compact representation: the device plus the sizes of the main containers."""
    sizes = (
        ("targets", len(self.targets)),
        ("weights", len(self.weights)),
        ("meta", len(self.meta)),
        ("history", len(self.history)),
    )
    counts = ", ".join(f"{label}={count}" for label, count in sizes)
    return f"LossState(device={self.device}, {counts})"
def _optimizer_param_set(optimizer: torch.optim.Optimizer) -> Set[nn.Parameter]:
    """Collect the parameters of all of an optimizer's ``param_groups`` into one set."""
    collected: Set[nn.Parameter] = set()
    for group in optimizer.param_groups:
        collected.update(group["params"])
    return collected
@contextmanager
def _freeze_graph_extras(state: "LossState", optimizer: torch.optim.Optimizer):
    """Turn off ``requires_grad`` on leaves that ``state`` touches but
    that are not among ``optimizer``'s parameters.

    The leaf set comes from :meth:`LossState.active_parameters`, so no
    probe forward is run here. Nothing is restored when the context
    exits: :meth:`LossState.run` re-enables ``requires_grad`` on every
    leaf in ``_loss_leaves`` from its ``finally`` clause, which also
    covers leaves that were already frozen before this context was
    entered.
    """
    owned = _optimizer_param_set(optimizer)
    extras = [
        leaf
        for leaf in state.active_parameters()
        if leaf not in owned and leaf.requires_grad
    ]
    for leaf in extras:
        leaf.requires_grad_(False)
    yield
[docs]
def create_loss_state(
    device: torch.device,
    targets: Dict[str, Callable] = None,
    weights: Dict[str, float] = None,
) -> LossState:
    """
    Build a :class:`LossState` and optionally pre-populate it.

    Parameters
    ----------
    device : torch.device
        Computation device.
    targets : Dict[str, Callable], optional
        Targets handed to ``register_targets`` when non-empty.
    weights : Dict[str, float], optional
        Weights handed to ``set_weights`` when non-empty.

    Returns
    -------
    LossState
        Configured LossState instance.
    """
    loss_state = LossState(device=device)
    # Targets go in before weights, as in the original call order.
    for payload, install in (
        (targets, loss_state.register_targets),
        (weights, loss_state.set_weights),
    ):
        if payload:
            install(payload)
    return loss_state
# Public API. LossStateWarning is exported so callers can filter or
# escalate the warning category, as its docstring describes.
__all__ = ["LossState", "LossStateWarning", "create_loss_state"]