torchref.utils.loss_validation module

Centralized loss finiteness validator for refinement closures.

Exports

validate_loss

Checks that a loss tensor (and optionally gradients and parameters) is finite. On failure, dumps a per-target breakdown via LossState.format_breakdown when a state is supplied. Supports two modes:

  • raise_on_fail=True (default): raise NonFiniteLossError. Use in contexts where a non-finite loss means the run is broken.

  • raise_on_fail=False: print a warning (full diagnostic on the first few failures, one-line summary thereafter) and return False. The caller is expected to reject the step — e.g. inside an LBFGS closure, zero out gradients and return +inf so the strong-Wolfe line search backtracks.

NonFiniteLossError

Typed exception so callers can distinguish numerical blow-ups from other RuntimeError instances.
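Because NonFiniteLossError derives from RuntimeError, the order of except clauses matters when a caller wants to treat numerical blow-ups separately. The sketch below uses a local stand-in class with the same name and base so that it runs without torchref installed; run_step and classify are hypothetical helpers for illustration only.

```python
class NonFiniteLossError(RuntimeError):
    """Local stand-in mirroring the documented base class (RuntimeError)."""


def run_step(kind):
    # Hypothetical step function that can fail in two different ways.
    if kind == "nan":
        raise NonFiniteLossError("loss is NaN")
    if kind == "other":
        raise RuntimeError("unrelated runtime failure")
    return "ok"


def classify(kind):
    try:
        return run_step(kind)
    except NonFiniteLossError:
        # Listed first: a preceding `except RuntimeError` would swallow it.
        return "numerical blow-up"
    except RuntimeError:
        return "other runtime error"
```

With the real library, the same pattern would apply to the exception imported from torchref.utils.loss_validation.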

Usage

Soft-failure closure (recommended for production LBFGS refinement):

import torch

from torchref.utils import validate_loss

def closure():
    optimizer.zero_grad()
    loss = state.aggregate()
    loss.backward()
    ok = validate_loss(
        loss, state=state, parameters=params,
        context="lbfgs", raise_on_fail=False,
    )
    if not ok:
        # Tell LBFGS this step is invalid: zero grads + return +inf,
        # the strong-Wolfe line search will backtrack.
        for p in params:
            if p.grad is not None:
                p.grad.zero_()
        return torch.full_like(loss, float("inf"))
    return loss

The fast path costs one GPU→CPU sync on torch.isfinite(loss), plus a single reduction over the parameter gradients when check_grads=True. The diagnostic path runs only on failure.
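The source does not show how the gradient reduction is implemented. One common way to test many values with a single reduction is to sum them, since any NaN or infinity propagates into the sum. The sketch below illustrates that idea with plain Python floats standing in for tensors; all_finite_by_sum is a hypothetical helper, not part of torchref. A sum of very large finite values can overflow to infinity, so a check of this style can raise a false alarm that a detailed per-entry diagnostic would then resolve.

```python
import math


def all_finite_by_sum(grads):
    """Return True if the reduced sum over all gradient entries is finite.

    `grads` is a list of lists of floats, a stand-in for parameter gradients.
    """
    total = 0.0
    for g in grads:
        for x in g:
            total += x  # NaN or +/-inf in any entry propagates into total
    # One finiteness test on the reduced value instead of one per entry.
    return math.isfinite(total)
```

For example, all_finite_by_sum([[1.0, -2.0], [0.5]]) is True, while any list containing float("nan") makes it False.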

exception torchref.utils.loss_validation.NonFiniteLossError

Bases: RuntimeError

Raised when a refinement step produces non-finite loss, grads, or params.

torchref.utils.loss_validation.reset_diagnostic_budget(context=None)

Reset the per-context failure counter that limits how many full diagnostics are printed (see max_full_diagnostics in validate_loss).

Parameters:

context (str, optional) – Reset a single context’s counter. If omitted, reset all.

torchref.utils.loss_validation.validate_loss(loss, *, state=None, parameters=None, check_grads=True, context='', raise_on_fail=True, max_full_diagnostics=3)

Check that loss (and optionally grads / parameters) are finite.

Parameters:
  • loss (torch.Tensor) – The scalar loss returned by the closure. Must be a zero-dim or one-element tensor.

  • state (LossState, optional) – If provided, the diagnostic path re-runs state.aggregate(log_values=True) to repopulate per-target losses and formats them via state.format_breakdown(). Safe to omit for closures that don’t use a LossState (e.g. scalers, alignment).

  • parameters (iterable of torch.Tensor, optional) – Parameters to inspect. When check_grads=True, their gradients are checked for finiteness on the fast path. On failure, both parameters and their grads are reported with non-finite entry counts.

  • check_grads (bool, default True) – Check parameter gradients for finiteness after backward(). Non-finite gradients are the usual pathology (backward produces NaN even when forward was finite), so leave this on unless a hot benchmark proves it’s costly.

  • context (str, default "") – Short label written into the diagnostic header and included in the warning / exception message (e.g. "collection_difference_refine"). Also keys the per-context diagnostic budget.

  • raise_on_fail (bool, default True) – If True, raise NonFiniteLossError on failure (strict mode). If False, print a warning and return False — the caller is responsible for rejecting the LBFGS step (e.g. by zeroing grads and returning +inf so strong-Wolfe backtracks).

  • max_full_diagnostics (int, default 3) – Per-context budget for the full per-target breakdown. After this many failures in the same context, only a compact one-line warning is printed. This prevents log flooding when LBFGS bounces around a persistent NaN region. Pass 0 to always print the compact form.

Returns:

True if everything checked is finite (happy path). False if something is non-finite and raise_on_fail=False. With raise_on_fail=True the function never returns False; it raises NonFiniteLossError instead.

Return type:

bool

Raises:

NonFiniteLossError – If raise_on_fail=True and any of loss / grads / params is non-finite.
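The return/raise contract above can be illustrated with a float-only analogue. This is a sketch under stated assumptions, not the library's implementation: validate_scalar is a hypothetical function, the message format is invented for illustration, and NonFiniteLossError is redefined locally so the snippet runs on its own.

```python
import math


class NonFiniteLossError(RuntimeError):
    """Local stand-in for the documented exception."""


def validate_scalar(value, *, context="", raise_on_fail=True):
    """Hypothetical float-only analogue of the documented contract."""
    if math.isfinite(value):
        return True  # happy path
    message = f"[{context}] non-finite loss: {value}"  # invented format
    if raise_on_fail:
        raise NonFiniteLossError(message)  # strict mode never returns False
    print("warning:", message)  # soft mode: warn and let the caller reject
    return False
```

In soft mode the caller checks the boolean and rejects the step; in strict mode the only possible return value is True.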