torchref.utils.loss_validation module
Centralized loss finiteness validator for refinement closures.
Exports
- validate_loss
Check that a loss tensor (and optionally its gradients / parameters) is finite. On failure, dumps a per-target breakdown via LossState.format_breakdown. Supports two modes:
  - raise_on_fail=True (default): raise NonFiniteLossError. Use this where a non-finite loss means the run is broken.
  - raise_on_fail=False: print a warning (full diagnostic on the first few failures, a one-line summary thereafter) and return False. The caller is expected to reject the step, e.g. inside an LBFGS closure, zero out the gradients and return +inf so that the strong-Wolfe line search backtracks.
- NonFiniteLossError
Typed exception so callers can distinguish numerical blow-ups from other ``RuntimeError``s.
Usage
Soft-failure closure (recommended for production LBFGS refinement):
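```python
import torch

from torchref.utils import validate_loss

# `optimizer` (a torch.optim.LBFGS with line_search_fn="strong_wolfe"),
# `state` (a LossState) and `params` come from the surrounding refinement
# code. Keep `params` a list: it is iterated by validate_loss and again below.

def closure():
    optimizer.zero_grad()
    loss = state.aggregate()
    loss.backward()
    ok = validate_loss(
        loss, state=state, parameters=params,
        context="lbfgs", raise_on_fail=False,
    )
    if not ok:
        # Tell LBFGS this step is invalid: zero grads + return +inf,
        # so the strong-Wolfe line search backtracks.
        for p in params:
            if p.grad is not None:
                p.grad.zero_()
        return torch.full_like(loss, float("inf"))
    return loss

optimizer.step(closure)
```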
The fast path costs one GPU→CPU sync for torch.isfinite(loss), plus a single
reduction over the parameter gradients when check_grads=True. The diagnostic
path runs only on failure.
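To make that cost model concrete, here is a minimal sketch of a fast-path check. It is an assumption about how such a check can be written, not torchref's implementation: the helper name all_finite is hypothetical, and this variant folds the gradient reduction into the loss check so the host reads back a single boolean.

```python
import torch


def all_finite(loss, parameters=None, check_grads=True):
    """Hypothetical fast-path check: True if loss (and grads) are finite."""
    # Keep everything on the device until the final bool() call.
    ok = torch.isfinite(loss).all()
    if check_grads and parameters is not None:
        grads = [p.grad for p in parameters if p.grad is not None]
        if grads:
            # Count non-finite gradient entries across all tensors.
            n_bad = torch.stack([(~torch.isfinite(g)).sum() for g in grads]).sum()
            ok = ok & (n_bad == 0)
    # bool() forces the single device-to-host sync.
    return bool(ok)
```

Deferring the bool() conversion to the end is what keeps the happy path to one synchronization; the expensive per-target breakdown is only worth computing once this returns False.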
- exception torchref.utils.loss_validation.NonFiniteLossError
Bases: RuntimeError
Raised when a refinement step produces a non-finite loss, gradients, or parameters.
- torchref.utils.loss_validation.reset_diagnostic_budget(context=None)
Reset the failure counter used to stride full diagnostics.
- Parameters:
context (str, optional) – Reset a single context’s counter. If omitted, reset the counters for all contexts.
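The striding behind reset_diagnostic_budget and max_full_diagnostics can be pictured as a per-context failure counter. The sketch below is a hypothetical stand-in (the class DiagnosticBudget and its method names are invented for illustration); it shows only the counting logic, not torchref's internals.

```python
from collections import defaultdict


class DiagnosticBudget:
    """Hypothetical per-context counter deciding full vs. compact diagnostics."""

    def __init__(self):
        # context label -> number of failures seen so far
        self._failures = defaultdict(int)

    def use_full(self, context, max_full_diagnostics=3):
        """Record one failure; return True if a full breakdown should print."""
        self._failures[context] += 1
        return self._failures[context] <= max_full_diagnostics

    def reset(self, context=None):
        """Reset one context's counter, or all counters if context is None."""
        if context is None:
            self._failures.clear()
        else:
            self._failures.pop(context, None)
```

With max_full_diagnostics=0 the first failure already exceeds the budget, so only the compact form is ever chosen, which matches the documented behaviour of passing 0.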
- torchref.utils.loss_validation.validate_loss(loss, *, state=None, parameters=None, check_grads=True, context='', raise_on_fail=True, max_full_diagnostics=3)
Check that loss (and optionally the gradients / parameters) are finite.
- Parameters:
loss (torch.Tensor) – The scalar loss returned by the closure. Must be a zero-dim or one-element tensor.
state (LossState, optional) – If provided, the diagnostic path re-runs state.aggregate(log_values=True) to repopulate the per-target losses and formats them via state.format_breakdown(). Safe to omit for closures that don’t use a LossState (e.g. scalers, alignment).
parameters (iterable of torch.Tensor, optional) – Parameters to inspect. When check_grads=True, their gradients are checked for finiteness on the fast path. On failure, both the parameters and their gradients are reported with counts of non-finite entries.
check_grads (bool, default True) – Check parameter gradients for finiteness after backward(). This is the usual pathology (backward produces NaN even when the forward pass was finite), so leave it on unless benchmarking a hot loop shows it is costly.
context (str, default "") – Short label written into the diagnostic header and included in the warning / exception message (e.g. "collection_difference_refine"). Also keys the per-context diagnostic budget.
raise_on_fail (bool, default True) – If True, raise NonFiniteLossError on failure (strict mode). If False, print a warning and return False; the caller is then responsible for rejecting the LBFGS step (e.g. by zeroing grads and returning +inf so strong-Wolfe backtracks).
max_full_diagnostics (int, default 3) – Per-context budget for the full per-target breakdown. After this many failures in the same context, only a compact one-line warning is printed. This prevents log flooding when LBFGS bounces around a persistent NaN region. Pass 0 to always print the compact form.
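The failure report for parameters lists both the tensors and their gradients with non-finite entry counts. A minimal sketch of such a report follows; the function name nonfinite_report and its output format are assumptions for illustration, not torchref's actual diagnostic text.

```python
import torch


def nonfinite_report(named_params):
    """Hypothetical helper: describe tensors that contain NaN/Inf entries."""
    lines = []
    for name, p in named_params:
        # Count non-finite entries in the parameter itself.
        n_param = int((~torch.isfinite(p.detach())).sum())
        # Count non-finite entries in its gradient, if one exists.
        n_grad = 0 if p.grad is None else int((~torch.isfinite(p.grad)).sum())
        if n_param or n_grad:
            lines.append(f"{name}: {n_param}/{p.numel()} non-finite values, "
                         f"{n_grad} non-finite grads")
    return lines
```

For an nn.Module, model.named_parameters() yields the (name, tensor) pairs this helper expects; healthy tensors are skipped so the report stays short.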
- Returns:
True if everything is finite (happy path), False otherwise. When raise_on_fail=True, a failure raises instead of returning False.
- Return type:
bool
- Raises:
NonFiniteLossError – If raise_on_fail=True and any of the loss / grads / params is non-finite.
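In strict mode, the typed exception lets a caller separate numerical blow-ups from unrelated failures. The sketch below is self-contained, so it defines a stand-in NonFiniteLossError and a hypothetical strict_check helper instead of importing torchref; in real code you would import the exception from torchref.utils.loss_validation and call validate_loss(..., raise_on_fail=True).

```python
import torch


class NonFiniteLossError(RuntimeError):
    """Stand-in mirroring the documented exception (a RuntimeError subclass)."""


def strict_check(loss, context=""):
    # Hypothetical strict-mode check: raise instead of returning False.
    if not bool(torch.isfinite(loss).all()):
        raise NonFiniteLossError(f"[{context}] non-finite loss: {loss.item()}")
    return True


def refine_step(loss):
    try:
        return strict_check(loss, context="demo")
    except NonFiniteLossError as exc:
        # Catch the numerical blow-up first; it is a RuntimeError subclass.
        return f"blow-up: {exc}"
    except RuntimeError:
        # Any other runtime failure is not a numerics problem: re-raise it.
        raise
```

The handler order matters: because NonFiniteLossError subclasses RuntimeError, a generic RuntimeError clause placed first would swallow it.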