torchref.utils.gradnorm module
Gradient norm utilities for optimization monitoring.
This module provides functions to compute gradient norms for monitoring training stability and debugging optimization issues.
- torchref.utils.gradnorm.gradnorm(loss, parameters)
Compute the gradient norm of a loss with respect to given parameters.
Performs a backward pass with graph retention and computes the RMS (root mean square) of all gradients concatenated together.
- Parameters:
loss (torch.Tensor) – The loss tensor to backpropagate.
parameters (iterable) – Iterable of model parameters (typically from model.parameters()).
- Returns:
The computed RMS gradient norm.
- Return type:
Notes
Calls backward with retain_graph=True so that later backward passes through the same graph remain possible. Parameters whose gradient is None are excluded from the norm.
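The behavior described above (backward pass with graph retention, concatenation of all available gradients, RMS over the result) can be sketched as follows. This is a minimal illustration, not the torchref source: the name rms_grad_norm_sketch is hypothetical, and both the Python float return value and the 0.0 result when no gradients exist are assumptions, since the documented return type is not given.

```python
import torch


def rms_grad_norm_sketch(loss, parameters):
    """Hypothetical re-implementation for illustration; not the torchref source."""
    # Materialize the iterable so a generator such as model.parameters()
    # can still be traversed after the backward pass.
    params = list(parameters)
    # Backpropagate while keeping the graph for later backward passes.
    loss.backward(retain_graph=True)
    # Flatten every available gradient; parameters with grad None are skipped.
    grads = [p.grad.reshape(-1) for p in params if p.grad is not None]
    if not grads:
        return 0.0  # assumption: report zero when no gradient exists
    flat = torch.cat(grads)
    # RMS = sqrt(mean(g_i ** 2)) over all concatenated gradient entries.
    return flat.pow(2).mean().sqrt().item()


# Tiny check: loss = sum(w * x), so d(loss)/dw = x = [3, 4].
w = torch.tensor([1.0, 2.0], requires_grad=True)
unused = torch.tensor([5.0], requires_grad=True)  # not in the graph, grad stays None
x = torch.tensor([3.0, 4.0])
loss = (w * x).sum()
norm = rms_grad_norm_sketch(loss, [w, unused])
```

Here norm equals sqrt((9 + 16) / 2), the RMS of the two gradient entries, and the unused parameter does not contribute. Because backward accumulates into .grad, a sketch like this reports the accumulated gradients if earlier gradients were not cleared first (for example with optimizer.zero_grad()).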
Examples
from torchref.utils.gradnorm import gradnorm
loss = model(input)
grad_norm = gradnorm(loss, model.parameters())
print(f"Gradient norm: {grad_norm:.4f}")