Source code for torchref.utils.caching

"""
Caching utilities for TorchRef modules.

Provides ``ParameterFingerprint`` for lightweight parameter-change detection
and ``CachedForwardMixin`` for automatic caching of ``forward()`` results
with invalidation on parameter mutation or backward propagation.
"""

import torch


class ParameterFingerprint:
    """Cheap change detector for an ordered collection of tensors.

    Every tensor contributes one ``(data_ptr, _version, numel)`` triple.
    Two fingerprints agree exactly when their triples agree position by
    position, so a check costs a handful of integer comparisons per tensor
    instead of hashing tensor contents (e.g. with SHA-1).
    """

    __slots__ = ("_entries",)

    def __init__(self, params=()):
        """Record the fingerprint of *params* (any iterable of tensors)."""
        self._entries = self._snapshot(params)

    @staticmethod
    def _snapshot(tensors):
        """Return the tuple of ``(data_ptr, _version, numel)`` triples."""
        triples = []
        for tensor in tensors:
            triples.append((tensor.data_ptr(), tensor._version, tensor.numel()))
        return tuple(triples)

    def matches(self, params) -> bool:
        """Return True if *params* have the same fingerprint."""
        current = self._snapshot(params)
        return current == self._entries

    def __bool__(self):
        # An empty fingerprint (no tensors recorded) is falsy.
        return bool(self._entries)

class CachedForwardMixin:
    """Mixin that caches ``forward()`` results with automatic invalidation.

    Overrides ``__call__`` to return a cached result when the module's
    parameters, buffers, and call arguments have not changed since the last
    invocation — and no backward pass has propagated through the cached
    output. ``__call__`` invokes ``self.forward`` directly.

    Cache invalidation triggers:

    * Any parameter or buffer ``data_ptr`` or ``_version`` change
      (covers optimizer in-place updates and mask/parameter replacement).
    * Input tensor ``data_ptr`` or ``_version`` change, or non-tensor
      argument value change.
    * A backward pass through the cached output (increments a generation
      counter via a gradient hook). Hooks are attached to every tensor with
      a ``grad_fn`` found in the output, including tensors nested inside
      tuples, lists and dicts.
    * The cached result was computed with autograd disabled and the current
      call has autograd enabled (the cached tensor would carry no graph).

    The cached tensor retains its autograd graph — gradients flow correctly
    on the first backward pass, after which the cache is invalidated.

    Notes
    -----
    Non-tensor arguments are stored by reference in the fingerprint and
    compared with ``==``; an in-place mutation of a mutable argument
    (e.g. a list) between calls is therefore not detected.
    """

    # ---- internal helpers ------------------------------------------------

    def _fingerprint_state(self):
        """Fingerprint all parameters and buffers by ``(data_ptr, _version)``."""
        entries = [(t.data_ptr(), t._version) for t in self.parameters()]
        entries.extend((t.data_ptr(), t._version) for t in self.buffers())
        return tuple(entries)

    @staticmethod
    def _fingerprint_inputs(args, kwargs):
        """Fingerprint call arguments (tensor ptr/version, else by value)."""
        entries = []
        for a in args:
            if isinstance(a, torch.Tensor):
                entries.append((a.data_ptr(), a._version))
            else:
                entries.append(a)
        # Sort keyword names so the fingerprint does not depend on call order.
        for k in sorted(kwargs):
            v = kwargs[k]
            if isinstance(v, torch.Tensor):
                entries.append((k, v.data_ptr(), v._version))
            else:
                entries.append((k, v))
        return tuple(entries)

    @staticmethod
    def _iter_output_tensors(output):
        """Yield each tensor in *output* (a tensor or nested tuple/list/dict)."""
        if isinstance(output, torch.Tensor):
            yield output
        elif isinstance(output, (tuple, list)):
            for item in output:
                yield from CachedForwardMixin._iter_output_tensors(item)
        elif isinstance(output, dict):
            for item in output.values():
                yield from CachedForwardMixin._iter_output_tensors(item)

    # ---- public API ------------------------------------------------------

    def __call__(self, *args, recalc=False, **kwargs):
        """Return cached ``forward()`` result, or recompute on cache miss.

        Parameters
        ----------
        *args, **kwargs
            Forwarded unchanged to ``forward()`` on a cache miss.
        recalc : bool, optional
            If True, invalidate the cache and force recomputation.
            Not forwarded to ``forward()``.

        Returns
        -------
        The value returned by ``forward()``; on a cache hit, the very same
        object that was returned by the call that populated the cache.
        """
        if recalc:
            self.reset_forward_cache()

        grad_enabled = torch.is_grad_enabled()
        cached = getattr(self, "_fwd_cached_output", None)
        if cached is not None:
            state_fp = self._fingerprint_state()
            input_fp = self._fingerprint_inputs(args, kwargs)
            gen = getattr(self, "_fwd_current_gen", 0)
            # A result produced with autograd disabled has no graph, so it
            # must not be handed to a caller that runs with autograd enabled.
            # The opposite direction is safe and keeps the cache useful.
            cached_with_grad = bool(
                getattr(self, "_fwd_cached_grad_enabled", False)
            )
            grad_ok = cached_with_grad or not grad_enabled
            if (
                grad_ok
                and state_fp == self._fwd_cached_state_fp
                and input_fp == self._fwd_cached_input_fp
                and gen == self._fwd_cache_gen
            ):
                return cached

        # Cache miss — recompute
        result = self.forward(*args, **kwargs)

        # Register backward hooks to invalidate the cache after gradient
        # consumption. Each hook returns None, leaving the gradient unchanged.
        def _bump_gen(grad, ref=self):
            ref._fwd_current_gen = getattr(ref, "_fwd_current_gen", 0) + 1

        for tensor in self._iter_output_tensors(result):
            if tensor.grad_fn is not None:
                tensor.register_hook(_bump_gen)

        # Store cache state. Fingerprints are taken *after* forward() so that
        # in-place updates performed by forward() itself are part of the
        # recorded state.
        self._fwd_cached_output = result
        self._fwd_cached_state_fp = self._fingerprint_state()
        self._fwd_cached_input_fp = self._fingerprint_inputs(args, kwargs)
        self._fwd_cached_grad_enabled = grad_enabled
        if not hasattr(self, "_fwd_current_gen"):
            self._fwd_current_gen = 0
        self._fwd_cache_gen = self._fwd_current_gen
        return result

    def reset_forward_cache(self):
        """Manually invalidate the forward cache."""
        self._fwd_cached_output = None
        self._fwd_cached_state_fp = None
        self._fwd_cached_input_fp = None
        self._fwd_cached_grad_enabled = None
        self._fwd_cache_gen = 0
        self._fwd_current_gen = 0