# Source code for torchref.utils.autograd_introspection

"""
Autograd graph introspection.

Walk a loss tensor's autograd graph backward to discover the leaf
``nn.Parameter``s that gradient will accumulate into. Used by
:class:`torchref.refinement.loss_state.LossState` to record, at target
registration time, which leaves each loss touches — so the per-step
optimization path can automatically disable ``requires_grad`` on parameters
the loss depends on but the optimizer wasn't constructed with.
"""

from typing import Iterable, Mapping, Set, Union

import torch
from torch import nn

# Type alias for the "loss" argument: a single tensor, an iterable of
# tensors, or a string-keyed mapping whose values are tensors.
LossLike = Union[
    torch.Tensor,
    Iterable[torch.Tensor],
    Mapping[str, torch.Tensor],
]


def _iter_roots(losses: LossLike):
    """Flatten ``losses`` into a stream of root tensors.

    Accepts a single ``Tensor``, a tuple/list (or any other iterable) of
    tensors, a mapping whose values are tensors, or any nested combination
    thereof. Non-tensor entries (``None``, Python scalars, strings, bytes,
    etc.) are silently skipped.

    Parameters
    ----------
    losses : Tensor | Iterable[Tensor] | Mapping[str, Tensor]
        The (possibly nested) container of loss tensors.

    Yields
    ------
    torch.Tensor
        Each tensor found, in container iteration order. Tensors are
        yielded as-is and are never iterated into.
    """
    # Tensors are themselves iterable, so this check must come first.
    if isinstance(losses, torch.Tensor):
        yield losses
        return
    # Strings are iterables of one-character strings, each of which is again
    # an iterable string: recursing into them never terminates. Treat text
    # and raw byte buffers as opaque non-tensor entries instead.
    if isinstance(losses, (str, bytes, bytearray)):
        return
    if isinstance(losses, Mapping):
        # Only the values can hold tensors; keys are labels.
        for value in losses.values():
            yield from _iter_roots(value)
        return
    if isinstance(losses, Iterable):
        for item in losses:
            yield from _iter_roots(item)
        return
    # Non-tensor, non-iterable (None, int, float, ...): ignore.


def collect_loss_leaves(losses: LossLike) -> Set[nn.Parameter]:
    """Return the set of leaf ``nn.Parameter``s that gradient will accumulate
    into when ``backward()`` is called on the given loss(es).

    Walks the autograd graph from each root tensor's ``grad_fn`` and finds
    every ``AccumulateGrad`` node, collecting its ``.variable`` when it is an
    :class:`nn.Parameter`. Multiple roots are unioned via a single shared
    traversal so that shared subgraphs (e.g. two losses both depending on the
    same model forward) are walked exactly once.

    A root that is itself a leaf parameter (``grad_fn is None`` and
    ``requires_grad=True``) is included directly: it has no graph to walk,
    but ``backward()`` on it still accumulates into its own ``.grad``.

    Parameters
    ----------
    losses : Tensor | Iterable[Tensor] | Mapping[str, Tensor]
        One or more loss tensors.

    Returns
    -------
    set of nn.Parameter
        Leaf parameters that backward would accumulate gradient into. A leaf
        with ``requires_grad=False`` does not appear (no ``AccumulateGrad``
        node is created for it). Detached subtrees contribute nothing.
    """
    # Use the grad_fn object itself as the seen key. grad_fn instances are
    # hashable and equality-compares to the same underlying C++ Node, so this
    # is correct. We deliberately avoid id()-keying because Python wrapper
    # objects returned by ``next_functions`` are short-lived — once popped
    # off the stack and out of scope, their id can be reused for an unrelated
    # wrapper, causing the seen set to incorrectly skip live nodes.
    seen = set()
    leaves: Set[nn.Parameter] = set()
    stack = []

    # Seed the traversal with every root's grad_fn.
    for loss in _iter_roots(losses):
        grad_fn = loss.grad_fn
        if grad_fn is not None:
            stack.append(grad_fn)
        elif isinstance(loss, nn.Parameter) and loss.requires_grad:
            # The root is a bare leaf parameter: there is no graph behind
            # it, yet backward() on it accumulates into the parameter itself.
            leaves.add(loss)

    # Iterative depth-first walk over the backward graph.
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)

        # Only AccumulateGrad nodes expose ``.variable``; other nodes fall
        # through with ``None``.
        var = getattr(fn, "variable", None)
        if isinstance(var, nn.Parameter):
            leaves.add(var)

        # ``next_functions`` holds (node, input_index) pairs; the node is
        # None for inputs that do not require grad.
        for next_fn, _ in getattr(fn, "next_functions", ()):
            if next_fn is not None:
                stack.append(next_fn)

    return leaves