torchref.utils.autograd_introspection module
Autograd graph introspection.
Walk a loss tensor’s autograd graph backward to discover the leaf
``nn.Parameter`` objects that gradient will accumulate into. Used by
:class:`torchref.refinement.loss_state.LossState` to record, at target
registration time, which leaves each loss touches, so that the per-step
optimization path can automatically disable ``requires_grad`` on parameters
that the loss depends on but that were not passed to the optimizer.
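The paragraph above describes how the collected leaves are consumed, but only in prose. The following is a minimal sketch of that idea, not ``LossState``'s actual code: ``freeze_unowned_leaves`` is a hypothetical helper, and the optimizer is assumed to follow the ``torch.optim`` layout, where ``param_groups`` is a list of dicts whose ``"params"`` entry lists the parameters.

```python
from typing import Any, Iterable, List


def freeze_unowned_leaves(loss_leaves: Iterable[Any], optimizer: Any) -> List[Any]:
    """Hypothetical sketch: disable grad on leaves the optimizer does not own.

    Assumes a torch.optim-style ``optimizer.param_groups`` (a list of dicts,
    each with a ``"params"`` list). Returns the parameters that were frozen
    so the caller can restore ``requires_grad`` afterwards.
    """
    # Compare by identity: tensor ``==`` is elementwise, so avoid value comparisons.
    owned = {id(p) for group in optimizer.param_groups for p in group["params"]}
    frozen = []
    for p in loss_leaves:
        if id(p) not in owned and p.requires_grad:
            p.requires_grad_(False)  # flip the flag in place on the leaf
            frozen.append(p)
    return frozen
```

Returning the frozen parameters, rather than discarding that information, leaves it to the caller to decide when to re-enable gradients.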
- torchref.utils.autograd_introspection.collect_loss_leaves(losses)
Return the set of leaf ``nn.Parameter`` objects that gradient will accumulate into when ``backward()`` is called on the given loss(es).

Walks the autograd graph from each root tensor’s ``grad_fn`` and finds every ``AccumulateGrad`` node, collecting its ``.variable`` when it is an ``nn.Parameter``. Multiple roots are unioned via a single shared traversal, so shared subgraphs (e.g. two losses that both depend on the same model forward pass) are walked exactly once.
- Parameters:
losses (Tensor | Iterable[Tensor] | Mapping[str, Tensor]) – One or more loss tensors.
- Returns:
Leaf parameters that ``backward()`` would accumulate gradient into. A leaf with
``requires_grad=False`` does not appear (no ``AccumulateGrad`` node is created for it). Detached subtrees contribute nothing.
- Return type:
set of nn.Parameter