# Source code for torchref.utils.gradnorm

"""
Gradient norm utilities for optimization monitoring.

This module provides functions to compute gradient norms for monitoring
training stability and debugging optimization issues.
"""

from collections.abc import Iterable

import torch


def gradnorm(
    loss: torch.Tensor, parameters: Iterable[torch.Tensor]
) -> torch.Tensor:
    """
    Compute the gradient norm of a loss with respect to given parameters.

    Performs a backward pass with graph retention and computes the RMS
    (root mean square) of all gradients concatenated together.

    Parameters
    ----------
    loss : torch.Tensor
        The loss tensor to backpropagate.
    parameters : iterable of torch.Tensor
        Iterable of model parameters (typically from ``model.parameters()``).
        It is iterated exactly once, so a generator is fine.

    Returns
    -------
    torch.Tensor
        0-dimensional tensor holding the RMS gradient norm. Use
        ``float(...)`` or ``.item()`` to obtain a Python float. If none of
        the given parameters has a gradient, a zero tensor with the dtype
        and device of ``loss`` is returned.

    Notes
    -----
    Uses ``retain_graph=True`` to allow subsequent backward passes.

    ``loss.backward()`` accumulates into any existing ``.grad`` values, so
    gradients already present before the call are included in the norm.
    Zero them first if only this loss's gradient should be measured. The
    populated gradients are left in place for the caller.

    Only includes parameters that have gradients (skips None grads).

    Examples
    --------
    ::

        loss = model(input)
        grad_norm = gradnorm(loss, model.parameters())
        print(f"Gradient norm: {grad_norm:.4f}")
    """
    loss.backward(retain_graph=True)

    # reshape rather than view: a gradient may be non-contiguous (e.g. when
    # it mirrors the strides of a non-contiguous parameter), and view(-1)
    # raises in that case.
    flat_grads = [p.grad.reshape(-1) for p in parameters if p.grad is not None]

    if not flat_grads:
        # torch.cat([]) raises on an empty list; report a zero norm instead.
        return loss.new_zeros(())

    return torch.mean(torch.cat(flat_grads) ** 2) ** 0.5