torchref.utils.caching module

Caching utilities for TorchRef modules.

Provides ParameterFingerprint for lightweight parameter-change detection and CachedForwardMixin for automatic caching of forward() results, with the cache invalidated when parameters change or a backward pass runs through the cached output.

class torchref.utils.caching.ParameterFingerprint(params=())

Bases: object

Lightweight fingerprint for detecting parameter changes.

Captures (data_ptr, _version, numel) per tensor. Comparison costs O(n_params) integer comparisons, which is much cheaper than SHA-1 hashing.
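To illustrate why these three integers are enough to detect the common kinds of change, the following sketch uses plain PyTorch rather than TorchRef code (the helper tensor_signature is hypothetical). It shows that an in-place update keeps the same data_ptr but bumps _version, while a freshly allocated replacement tensor has a different data_ptr:

```python
import torch


def tensor_signature(t: torch.Tensor) -> tuple:
    """Hypothetical helper: the (data_ptr, _version, numel) triple for one tensor."""
    return (t.data_ptr(), t._version, t.numel())


w = torch.zeros(4)
sig0 = tensor_signature(w)

# In-place update, similar to an optimizer step: same storage, version counter bumps.
with torch.no_grad():
    w.add_(0.1)
sig1 = tensor_signature(w)

# Replacement: a new tensor allocated while w is alive lives at a different address.
w_replacement = torch.zeros(4)
sig2 = tensor_signature(w_replacement)
```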

__init__(params=())
matches(params)

Return True if params have the same fingerprint as the one captured at construction.
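A minimal sketch of a class with this behavior, assuming only the documented (data_ptr, _version, numel) triples. It reuses the documented class and method names for readability, but it is an illustration, not TorchRef's actual source; the private names _signature and _capture are hypothetical:

```python
from typing import Iterable, Tuple

import torch

Signature = Tuple[Tuple[int, int, int], ...]


class ParameterFingerprint:
    """Sketch of the documented behavior; not TorchRef's implementation."""

    def __init__(self, params: Iterable[torch.Tensor] = ()):
        self._signature = self._capture(params)

    @staticmethod
    def _capture(params: Iterable[torch.Tensor]) -> Signature:
        # One (data_ptr, _version, numel) triple per tensor, in iteration order.
        return tuple((p.data_ptr(), p._version, p.numel()) for p in params)

    def matches(self, params: Iterable[torch.Tensor]) -> bool:
        # Tuple equality: O(n_params) integer comparisons, no hashing of tensor data.
        return self._signature == self._capture(params)
```

Because the triples are compared in order, a sketch like this assumes the caller passes the parameters in a stable order, such as the order returned by Module.parameters().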

class torchref.utils.caching.CachedForwardMixin

Bases: object

Mixin that caches forward() results with automatic invalidation.

Overrides __call__ to return a cached result when the module’s parameters, buffers, and call arguments have not changed since the last invocation, and no backward pass has propagated through the cached output.

Cache invalidation triggers:

  • A change in the data_ptr or _version of any parameter or buffer (this covers in-place optimizer updates and mask or parameter replacement).

  • A change in the data_ptr or _version of any input tensor, or in the value of any non-tensor argument.

  • A backward pass through the cached output (a gradient hook increments a generation counter).

The cached tensor retains its autograd graph, so gradients flow correctly on the first backward pass; after that pass, the cache is invalidated.
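The backward-pass trigger can be sketched with a tensor hook. BackwardGeneration is a hypothetical helper, not part of TorchRef; the sketch assumes the hook only needs to count backward passes, and it returns None so the gradient reaching the parameters is unchanged:

```python
import torch


class BackwardGeneration:
    """Hypothetical helper: counts backward passes through watched outputs."""

    def __init__(self) -> None:
        self.generation = 0

    def watch(self, out: torch.Tensor) -> torch.Tensor:
        # Hooks can only be registered on tensors that require grad.
        if out.requires_grad:
            out.register_hook(self._bump)
        return out

    def _bump(self, grad: torch.Tensor) -> None:
        self.generation += 1
        return None  # returning None leaves the gradient unchanged


w = torch.tensor(2.0, requires_grad=True)
gen = BackwardGeneration()
cached = gen.watch(w * 3.0)  # stands in for a cached forward() output
generation_before = gen.generation
cached.backward()  # gradients still reach w through the retained graph
```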

__call__(*args, recalc=False, **kwargs)

Return cached forward() result, or recompute on cache miss.

Parameters:

recalc (bool, optional) – If True, invalidate the cache and force recomputation. Defaults to False. This keyword is consumed by __call__ and is not forwarded to forward().

reset_forward_cache()

Manually invalidate the forward cache.
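Combining the documented triggers, a minimal sketch of a mixin with this behavior might look like the following. CachedForwardMixinSketch, the example Scale module, and the private attribute names are hypothetical; the sketch assumes non-tensor arguments can be compared with ==, and it is not TorchRef's implementation. As with any mixin that overrides __call__, it must precede nn.Module in the base-class list:

```python
from typing import Any, Dict, Tuple

import torch
from torch import nn


class CachedForwardMixinSketch:
    """Hypothetical re-implementation of the documented caching behavior."""

    def _generation(self) -> int:
        return self.__dict__.get("_fc_generation", 0)

    def _on_backward(self, grad: torch.Tensor) -> None:
        # Gradient hook: a backward pass reached the cached output.
        self.__dict__["_fc_generation"] = self._generation() + 1
        return None  # leave the gradient unchanged

    @staticmethod
    def _arg_signature(value: Any) -> Tuple:
        if isinstance(value, torch.Tensor):
            return ("tensor", value.data_ptr(), value._version, tuple(value.shape))
        # Assumption: non-tensor arguments are compared with ==.
        return ("value", value)

    def _cache_key(self, args: Tuple, kwargs: Dict[str, Any]) -> Tuple:
        state = tuple((t.data_ptr(), t._version, t.numel())
                      for t in (*self.parameters(), *self.buffers()))
        arg_sig = tuple(self._arg_signature(a) for a in args)
        kw_sig = tuple((k, self._arg_signature(kwargs[k])) for k in sorted(kwargs))
        return (state, arg_sig, kw_sig)

    def reset_forward_cache(self) -> None:
        self.__dict__.pop("_fc_entry", None)

    def __call__(self, *args, recalc: bool = False, **kwargs):
        if recalc:
            self.reset_forward_cache()  # recalc is consumed here, never passed to forward()
        key = self._cache_key(args, kwargs)
        entry = self.__dict__.get("_fc_entry")
        if entry is not None and entry[0] == key and entry[1] == self._generation():
            return entry[2]  # cache hit
        out = super().__call__(*args, **kwargs)  # nn.Module.__call__ -> forward()
        if isinstance(out, torch.Tensor) and out.requires_grad:
            out.register_hook(self._on_backward)
        # Keep the inputs alive so their storage addresses cannot be recycled
        # by a different tensor while this entry is cached.
        self.__dict__["_fc_entry"] = (key, self._generation(), out, args, kwargs)
        return out


class Scale(CachedForwardMixinSketch, nn.Module):
    """Hypothetical example module that counts forward() calls."""

    def __init__(self) -> None:
        super().__init__()
        self.w = nn.Parameter(torch.tensor(2.0))
        self.calls = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.calls += 1
        return self.w * x


m = Scale()
x = torch.tensor(3.0)
y1 = m(x)
y2 = m(x)  # cache hit: forward() is not called again
```

Keeping references to the cached inputs is a design choice of this sketch: without it, a freed input's storage address could be reused by a new tensor with the same version and shape, which a data_ptr-based key cannot distinguish.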