"""
Centralized loss finiteness validator for refinement closures.
Exports
-------
validate_loss
    Check that a loss tensor (and optionally the gradients of a set of
    parameters) is finite. On failure, dumps a per-target breakdown via
    ``LossState.format_breakdown`` when a ``state`` is supplied.
    Supports two modes:

    - ``raise_on_fail=True`` (default): raise :class:`NonFiniteLossError`.
      Use in contexts where a non-finite loss means the run is broken.
    - ``raise_on_fail=False``: print a warning (full diagnostic on the
      first few failures, one-line summary thereafter) and return ``False``.
      The caller is expected to reject the step — e.g. inside an LBFGS
      closure, zero out gradients and return ``+inf`` so the strong-Wolfe
      line search backtracks.

NonFiniteLossError
    Typed exception so callers can distinguish numerical blow-ups from
    other ``RuntimeError``s.
Usage
-----
Soft-failure closure (recommended for production LBFGS refinement)::

    from torchref.utils import validate_loss

    def closure():
        optimizer.zero_grad()
        loss = state.aggregate()
        loss.backward()
        ok = validate_loss(
            loss, state=state, parameters=params,
            context="lbfgs", raise_on_fail=False,
        )
        if not ok:
            # Tell LBFGS this step is invalid: zero grads + return +inf,
            # so the strong-Wolfe line search backtracks.
            for p in params:
                if p.grad is not None:
                    p.grad.zero_()
            return torch.full_like(loss, float("inf"))
        return loss

The fast path costs one GPU→CPU sync on ``torch.isfinite(loss)``. When
``check_grads=True`` and parameters are given, it adds one more sync on a
single reduction over all parameter gradients. The diagnostic path only
runs on failure.
"""
import math
from collections import defaultdict
from typing import TYPE_CHECKING, Iterable, List, Optional

import torch

if TYPE_CHECKING:
    from torchref.refinement.loss_state import LossState
class NonFiniteLossError(RuntimeError):
    """Signal that a refinement step hit NaN/Inf in its loss, grads, or params.

    Derives from :class:`RuntimeError`, so existing ``except RuntimeError``
    handlers keep catching it. Callers that care can catch this type
    specifically to tell numerical blow-ups apart from other failures.
    """
# Per-context diagnostic budget so a pathological run doesn't flood logs.
# Maps the ``context`` label given to ``validate_loss`` to the number of
# failures seen so far; an empty context is stored under "<unnamed>".
_DIAGNOSTIC_COUNTS: "defaultdict[str, int]" = defaultdict(int)


def reset_diagnostic_budget(context: Optional[str] = None) -> None:
    """Reset the failure counter used to stride full diagnostics.

    Parameters
    ----------
    context : str, optional
        Reset a single context's counter. An empty string resets the
        counter for unnamed contexts (the one ``validate_loss`` uses when
        called with its default ``context=""``). If omitted (``None``),
        reset every counter.
    """
    if context is None:
        _DIAGNOSTIC_COUNTS.clear()
        return
    # Mirror validate_loss's keying: an empty label is counted under
    # "<unnamed>", so popping "" would silently do nothing.
    _DIAGNOSTIC_COUNTS.pop(context or "<unnamed>", None)
def _is_finite_scalar(t: torch.Tensor) -> bool:
    """Return ``True`` iff every element of ``t`` is finite (one host sync).

    Meant for the scalar loss, but any shape is accepted: the element-wise
    mask is reduced with ``all()`` before the single ``.item()`` transfer.
    """
    finite_mask = torch.isfinite(t)
    every_entry_ok = finite_mask.all()
    return bool(every_entry_ok.item())
def _any_nonfinite_grads(parameters: List[torch.Tensor]) -> bool:
    """Return ``True`` if any parameter gradient holds a NaN/Inf entry.

    Parameters whose ``.grad`` is ``None`` are skipped; with no gradients
    at all the result is ``False``. Each gradient is reduced to one boolean
    flag, the flags are gathered on the first gradient's device and stacked,
    and a single ``.item()`` call performs the only host sync.
    """
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return False
    home = grads[0].device
    # ``.to(home)`` is a no-op for same-device grads; it matters only when
    # parameters live on different devices, where combining the flags
    # directly can raise a device-mismatch error.
    flags = [torch.isfinite(g).all().to(home) for g in grads]
    return not bool(torch.stack(flags).all().item())
def _nonfinite_counts(parameters: List[torch.Tensor]) -> List[tuple]:
    """Tabulate non-finite entries per parameter and per gradient.

    Diagnostic-only helper: every count triggers a device-to-host sync.
    Returns a flat list of ``(label, count)`` pairs. Each parameter yields a
    ``param[i] shape=(...)`` row, followed immediately by a
    ``.grad shape=(...)`` row when that parameter has a gradient.
    """

    def count_bad(tensor: torch.Tensor) -> int:
        return int(torch.logical_not(torch.isfinite(tensor)).sum().item())

    rows: List[tuple] = []
    for idx, param in enumerate(parameters):
        rows.append((f"param[{idx}] shape={tuple(param.shape)}", count_bad(param)))
        grad = param.grad
        if grad is None:
            continue
        rows.append((f" .grad shape={tuple(grad.shape)}", count_bad(grad)))
    return rows
def validate_loss(
    loss: torch.Tensor,
    *,
    state: Optional["LossState"] = None,
    parameters: Optional[Iterable[torch.Tensor]] = None,
    check_grads: bool = True,
    context: str = "",
    raise_on_fail: bool = True,
    max_full_diagnostics: int = 3,
) -> bool:
    """Check that ``loss`` (and optionally parameter gradients) are finite.

    Parameters
    ----------
    loss : torch.Tensor
        The loss returned by the closure, normally a zero-dim or
        one-element tensor. Every element is checked, so a multi-element
        tensor passes only if all of its entries are finite.
    state : LossState, optional
        If provided, the diagnostic path re-runs
        ``state.aggregate(log_values=True)`` to repopulate per-target losses
        and formats them via ``state.format_breakdown()``. Safe to omit for
        closures that don't use a LossState (e.g. scalers, alignment).
    parameters : iterable of torch.Tensor, optional
        Parameters whose gradients are checked when ``check_grads=True``.
        The parameter *values* are not checked on the fast path; on
        failure, both values and gradients are reported with non-finite
        entry counts. The iterable is consumed once, so a generator such
        as ``model.parameters()`` is fine.
    check_grads : bool, default True
        Check parameter gradients for finiteness after backward(). This is
        the usual pathology (backward produces NaN even when forward was
        finite), so leave it on unless a hot benchmark proves it's costly.
    context : str, default ""
        Short label written into the diagnostic header and into the
        warning / exception message (e.g. ``"collection_difference_refine"``).
        Also keys the per-context diagnostic budget; an empty label is
        counted under ``"<unnamed>"``.
    raise_on_fail : bool, default True
        If True, raise :class:`NonFiniteLossError` on failure (strict mode).
        If False, print a warning and return ``False`` — the caller is
        responsible for rejecting the LBFGS step (e.g. by zeroing grads and
        returning +inf so strong-Wolfe backtracks).
    max_full_diagnostics : int, default 3
        Per-context budget for the full per-target breakdown. After this
        many failures in the same context, only a compact one-line warning
        is printed. Prevents log flooding when LBFGS bounces around a
        persistent NaN region. Pass ``0`` to always print compact.

    Returns
    -------
    bool
        ``True`` if the loss (and, when checked, every gradient) is finite;
        ``False`` otherwise. When ``raise_on_fail=True``, a failure raises
        instead of returning.

    Raises
    ------
    TypeError
        If ``loss`` is not a tensor.
    NonFiniteLossError
        If ``raise_on_fail=True`` and the loss, or a checked gradient,
        is non-finite.
    """
    if not torch.is_tensor(loss):
        raise TypeError(f"validate_loss: expected tensor, got {type(loss)!r}")

    # Materialize once: the list is read on both the fast and slow paths.
    params_list: List[torch.Tensor] = (
        list(parameters) if parameters is not None else []
    )
    want_grads = check_grads and bool(params_list)

    # ---- Fast path: scalar finiteness check(s) ----
    loss_finite = _is_finite_scalar(loss)
    grads_finite = True
    if loss_finite and want_grads:
        grads_finite = not _any_nonfinite_grads(params_list)
    if loss_finite and grads_finite:
        return True  # happy path

    # ---- Diagnostic path (failures only; extra syncs are acceptable) ----
    if not loss_finite and want_grads:
        # The fast path skips the gradient check once the loss is known to
        # be bad; evaluate it now so the report doesn't claim
        # ``grads_finite=True`` for gradients that were never inspected.
        grads_finite = not _any_nonfinite_grads(params_list)

    budget_key = context or "<unnamed>"
    _DIAGNOSTIC_COUNTS[budget_key] += 1
    count = _DIAGNOSTIC_COUNTS[budget_key]

    _emit_diagnostic(
        loss=loss,
        state=state,
        parameters=params_list,
        context=context,
        loss_finite=loss_finite,
        grads_finite=grads_finite,
        full=count <= max_full_diagnostics,
        count=count,
    )

    short = _short_message(context, loss_finite, grads_finite)
    if raise_on_fail:
        raise NonFiniteLossError(short)
    return False
def _short_message(context: str, loss_finite: bool, grads_finite: bool) -> str:
    """Build the one-line failure summary used for warnings and exceptions.

    The ``in '<context>'`` fragment appears only for a non-empty context.
    The two finiteness flags are always echoed.
    """
    where = f" in '{context}'" if context else ""
    flags = f" (loss_finite={loss_finite}, grads_finite={grads_finite})."
    hint = " See diagnostic block above for per-target breakdown."
    return "".join(("Non-finite loss detected", where, flags, hint))
def _emit_diagnostic(
    *,
    loss: torch.Tensor,
    state: Optional["LossState"],
    parameters: List[torch.Tensor],
    context: str,
    loss_finite: bool,
    grads_finite: bool,
    full: bool,
    count: int,
) -> None:
    """Print a diagnostic block (full or compact) without raising.

    When ``full=False`` (budget exhausted for this context), emit a single
    WARN line so the log doesn't flood when LBFGS bounces around a
    pathological region for many consecutive line-search probes.

    Each stage that touches caller-supplied objects is guarded on its own.
    Those stages are reading the loss value, re-aggregating and formatting
    ``state``, and counting non-finite parameter entries. A failure in one
    stage is reported under its own label and never escapes; an exception
    here would turn a soft ``raise_on_fail=False`` rejection into a hard
    error.
    """
    # NaN when the loss can't be read as one float (e.g. multi-element).
    loss_val = _as_float(loss)

    if not full:
        label = context or "<unnamed>"
        print(
            f"WARN[{count:>4}] non-finite loss in '{label}' "
            f"(loss={loss_val!r}, loss_finite={loss_finite}, "
            f"grads_finite={grads_finite})",
            flush=True,
        )
        return

    sep = "=" * 78
    header = f"NON-FINITE LOSS DETECTED (occurrence #{count})"
    if context:
        header += f" context='{context}'"
    lines: List[str] = [
        "",
        sep,
        header,
        sep,
        f"total loss   : {loss_val!r}",
        f"loss finite  : {loss_finite}",
        f"grads finite : {grads_finite}",
    ]

    if state is not None:
        lines.extend(_state_breakdown_lines(state))

    if parameters:
        lines.extend(["", "parameter / gradient inspection:"])
        try:
            for label, n_bad in _nonfinite_counts(parameters):
                marker = "  <-- NaN/Inf" if n_bad else ""
                lines.append(f"  {label}: non_finite={n_bad}{marker}")
        except Exception as exc:  # diagnostics must never raise
            lines.append(f"  (parameter inspection raised: {exc!r})")

    lines.append(sep)
    print("\n".join(lines), flush=True)


def _as_float(value: object) -> float:
    """Convert a number or one-element tensor to ``float``; NaN if impossible."""
    try:
        return float(value)
    except Exception:
        return float("nan")


def _state_breakdown_lines(state: "LossState") -> List[str]:
    """Re-aggregate ``state`` eagerly and render its per-target losses.

    Emits ``state.format_breakdown()`` first. It then lists every entry of
    ``state._losses`` (if present) with its effective weight, its value,
    and a ``yes`` / ``NO`` finiteness flag. A failing stage adds a note
    and does not abort the rest of the report.
    """
    lines = ["", "per-target breakdown (eager re-aggregation):"]
    try:
        with torch.no_grad():
            state.aggregate(log_values=True)
    except Exception as exc:
        lines.append(f"  (state.aggregate raised during diagnostic: {exc!r})")
        return lines

    try:
        # str() so a non-string return can't break the final join.
        lines.append(str(state.format_breakdown()))
    except AttributeError:
        lines.append("  (LossState.format_breakdown unavailable)")
    except Exception as exc:
        lines.append(f"  (state.format_breakdown raised during diagnostic: {exc!r})")

    try:
        losses = getattr(state, "_losses", None) or {}
        for name, term in losses.items():
            try:
                weight = _as_float(state.get_effective_weight(name))
            except Exception:
                weight = float("nan")
            # float() accepts both plain numbers and one-element tensors.
            value = _as_float(term)
            flag = "yes" if math.isfinite(value) else "NO"
            lines.append(
                f"  {str(name):<32} w={weight:>8.4g} loss={value:>14.6g} {flag}"
            )
    except Exception as exc:
        lines.append(f"  (per-target listing raised during diagnostic: {exc!r})")
    return lines