torchref.utils.autograd_introspection module

Autograd graph introspection.

Walk a loss tensor’s autograd graph backward to discover the leaf nn.Parameter objects that gradient will accumulate into. Used by torchref.refinement.loss_state.LossState to record, at target-registration time, which leaves each loss touches. The per-step optimization path can then automatically disable requires_grad on parameters that the loss depends on but that the optimizer was not constructed with.
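The freezing step described above amounts to a set difference between the leaves a loss touches and the parameters the optimizer owns. The following is a minimal sketch under that assumption, not LossState's actual code. params_to_freeze and freeze are hypothetical helper names, and the sketch assumes a standard torch.optim optimizer that exposes param_groups.

```python
import torch
from torch import nn


def params_to_freeze(leaves, optimizer):
    """Hypothetical helper: leaves the loss touches that the optimizer does not own."""
    # Every parameter registered with the optimizer, across all param groups.
    owned = {p for group in optimizer.param_groups for p in group["params"]}
    # Tensors hash by identity, so this compares objects, not tensor values.
    return set(leaves) - owned


def freeze(params):
    """Hypothetical helper: stop tracking gradients for the given parameters."""
    for p in params:
        # Graphs built after this call get no AccumulateGrad node for p.
        p.requires_grad_(False)


# Usage: the loss depends on both a and b, but the optimizer only owns a.
a = nn.Parameter(torch.ones(3))
b = nn.Parameter(torch.ones(3))
opt = torch.optim.SGD([a], lr=0.1)
stray = params_to_freeze({a, b}, opt)
freeze(stray)
```

Freezing the strays, rather than letting backward() fill their .grad, keeps parameters outside the optimizer from silently accumulating gradient step after step.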

torchref.utils.autograd_introspection.collect_loss_leaves(losses)

Return the set of leaf nn.Parameter objects that gradient will accumulate into when backward() is called on the given loss(es).

Walks the autograd graph from each root tensor’s grad_fn and finds every AccumulateGrad node, collecting its .variable when it is an nn.Parameter.

Multiple roots are unioned via a single shared traversal so that shared subgraphs (e.g. two losses both depending on the same model forward) are walked exactly once.
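The traversal described above can be sketched as an iterative depth-first walk over grad_fn.next_functions, with one visited set shared by every root. This is an illustrative sketch of the documented behaviour, not the module's source. leaves_from_roots is a hypothetical name, and it takes an already-flattened list of root tensors.

```python
import torch
from torch import nn


def leaves_from_roots(roots):
    """Sketch: nn.Parameter leaves that backward() on any of the roots would reach."""
    found = set()
    visited = set()  # shared by all roots, so a common subgraph is walked once
    # Roots without a grad_fn (constants, detached tensors) contribute nothing.
    stack = [r.grad_fn for r in roots if r.grad_fn is not None]
    while stack:
        fn = stack.pop()
        if fn in visited:
            continue
        visited.add(fn)
        # AccumulateGrad nodes expose the leaf tensor they write into as .variable.
        var = getattr(fn, "variable", None)
        if isinstance(var, nn.Parameter):
            found.add(var)
        for next_fn, _ in fn.next_functions:
            # None marks an input that needs no gradient (e.g. requires_grad=False).
            if next_fn is not None:
                stack.append(next_fn)
    return found


# Usage: two losses share the forward pass through w; only the second reaches v.
w = nn.Parameter(torch.ones(2))
v = nn.Parameter(torch.ones(2))
shared = (w * 2).sum()
loss1 = shared * 3
loss2 = shared + (v ** 2).sum()
leaves = leaves_from_roots([loss1, loss2])
```

An explicit stack instead of recursion keeps very deep graphs from hitting Python's recursion limit.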

Parameters:

losses (Tensor | Iterable[Tensor] | Mapping[str, Tensor]) – One or more loss tensors.
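Accepting Tensor | Iterable[Tensor] | Mapping[str, Tensor] means the input must be flattened into a list of root tensors before the traversal. The following is a minimal sketch, with as_roots as a hypothetical helper name. The order of the isinstance checks is the part that matters.

```python
from collections.abc import Iterable, Mapping

import torch


def as_roots(losses):
    """Hypothetical helper: flatten Tensor | Iterable[Tensor] | Mapping[str, Tensor]."""
    # Check Tensor first: a Tensor is itself iterable (over its first dimension).
    if isinstance(losses, torch.Tensor):
        return [losses]
    # Check Mapping before Iterable: iterating a dict yields its keys, not its values.
    if isinstance(losses, Mapping):
        return list(losses.values())
    if isinstance(losses, Iterable):
        return list(losses)
    raise TypeError(
        f"expected a Tensor, iterable, or mapping of Tensors, got {type(losses).__name__}"
    )


# Usage: a named dict of losses, as a caller might pass per-target losses.
t = torch.tensor(1.0)
roots = as_roots({"xray": t, "geometry": t * 2})
```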

Returns:

Leaf parameters that backward would accumulate gradient into. A leaf with requires_grad=False does not appear (no AccumulateGrad node is created for it). Detached subtrees contribute nothing.
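Both exclusions in the Returns description can be observed directly on a small PyTorch graph. The snippet below is a standalone illustration rather than part of this module. It inspects next_functions by hand.

```python
import torch
from torch import nn

a = nn.Parameter(torch.ones(2))
frozen = nn.Parameter(torch.ones(2), requires_grad=False)

# y = sum(a * frozen): SumBackward0 -> MulBackward0 -> one slot per input.
y = (a * frozen).sum()
mul_fn = y.grad_fn.next_functions[0][0]
acc_a, _ = mul_fn.next_functions[0]       # AccumulateGrad node writing into a
acc_frozen, _ = mul_fn.next_functions[1]  # None: no node for the frozen leaf

# A detached input cuts the branch off; here nothing in z requires grad,
# so z has no grad_fn at all and there is nothing to walk.
z = (a.detach() * 3).sum()
```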

Return type:

set of nn.Parameter