torchref.utils.caching module
Caching utilities for TorchRef modules.
Provides ParameterFingerprint for lightweight parameter-change detection
and CachedForwardMixin for automatic caching of forward() results
with invalidation on parameter mutation or backward propagation.
- class torchref.utils.caching.ParameterFingerprint(params=())
Bases: object
Lightweight fingerprint for detecting parameter changes.
Captures a (data_ptr, _version, numel) triple per tensor. Comparing two fingerprints costs O(n_params) integer comparisons, which is much cheaper than SHA-1 hashing.
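To make the fingerprint idea concrete, here is a minimal sketch assuming a PyTorch environment. FingerprintSketch is a hypothetical stand-in, not torchref's ParameterFingerprint. It stores the (data_ptr, _version, numel) triple for each tensor and compares the resulting tuples. It relies on Tensor._version, a private PyTorch counter that in-place operations increment.

```python
import torch


class FingerprintSketch:
    """Hypothetical stand-in for ParameterFingerprint (not torchref's code):
    one (data_ptr, _version, numel) triple per tensor."""

    def __init__(self, params=()):
        # Snapshot metadata only; tensor contents are never read or hashed.
        self._key = tuple((p.data_ptr(), p._version, p.numel()) for p in params)

    def __eq__(self, other):
        if not isinstance(other, FingerprintSketch):
            return NotImplemented
        # O(n_params) comparison of small integer tuples.
        return self._key == other._key

    def __hash__(self):
        return hash(self._key)


w = torch.nn.Parameter(torch.randn(3))
before = FingerprintSketch([w])
with torch.no_grad():
    w.add_(1.0)  # in-place update, like an optimizer step; bumps w._version
after = FingerprintSketch([w])
```

Because only metadata is read, an in-place update changes _version and swapping in a freshly allocated Parameter changes data_ptr, so both register as a mismatch without touching the tensor values.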
- class torchref.utils.caching.CachedForwardMixin
Bases: object
Mixin that caches forward() results with automatic invalidation.
Overrides __call__ to return a cached result when the module's parameters, buffers, and call arguments have not changed since the last invocation, and no backward pass has propagated through the cached output.
Cache invalidation triggers:
- A change in any parameter's or buffer's data_ptr or _version (covers in-place optimizer updates and mask/parameter replacement).
- A change in any input tensor's data_ptr or _version, or in the value of a non-tensor argument.
- A backward pass through the cached output (a gradient hook increments a generation counter).
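The first two triggers amount to comparing a key built from module state and call arguments. The following is a minimal sketch assuming PyTorch. cache_key_sketch and _tensor_sig are hypothetical helpers, not part of torchref. Non-tensor arguments are assumed to be simple values (numbers, strings) that compare meaningfully with ==, and nested containers of tensors are not handled.

```python
import torch
import torch.nn as nn


def _tensor_sig(t):
    # Identity and version metadata only; tensor values are never read.
    return (t.data_ptr(), t._version, t.numel())


def cache_key_sketch(module, args, kwargs):
    """Hypothetical helper (not torchref API): a comparable key covering
    parameters, buffers, and call arguments."""
    # Trigger 1: every parameter and buffer of the module.
    state = tuple(_tensor_sig(t) for t in (*module.parameters(), *module.buffers()))

    def arg_sig(a):
        # Trigger 2: tensors by (data_ptr, _version, numel), others by value.
        if torch.is_tensor(a):
            return ("tensor", _tensor_sig(a))
        return ("value", a)

    call_args = tuple(arg_sig(a) for a in args)
    call_kwargs = tuple(sorted((k, arg_sig(v)) for k, v in kwargs.items()))
    return state, call_args, call_kwargs


lin = nn.Linear(4, 2)
x = torch.randn(1, 4)
key = cache_key_sketch(lin, (x,), {"scale": 1.0})
# A caching __call__ would reuse its stored output only while a freshly
# computed key still equals the key saved alongside that output.
```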
The cached tensor retains its autograd graph: gradients flow correctly on the first backward pass, after which the cache is invalidated.
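The backward-pass trigger and the retained autograd graph can be illustrated with a tensor hook. In this sketch, which assumes PyTorch, BackwardInvalidatedCache is a hypothetical class and not torchref's implementation. store() registers a hook with Tensor.register_hook that bumps a generation counter when the gradient for the cached output is computed, and get() treats any counter change since store() as invalidation.

```python
import torch


class BackwardInvalidatedCache:
    """Hypothetical sketch: cache one output and drop it once a backward
    pass reaches it, using a generation counter bumped by a tensor hook."""

    def __init__(self):
        self.generation = 0
        self._value = None
        self._gen_at_store = None

    def _on_grad(self, grad):
        # Called when autograd computes the gradient w.r.t. the cached output.
        self.generation += 1
        return grad  # pass the gradient through unchanged

    def store(self, out):
        if out.requires_grad:
            out.register_hook(self._on_grad)
        self._value, self._gen_at_store = out, self.generation
        return out

    def get(self):
        # Valid only if no backward pass has reached the output since store().
        if self._value is not None and self._gen_at_store == self.generation:
            return self._value
        return None


w = torch.ones(3, requires_grad=True)
cache = BackwardInvalidatedCache()
out = cache.store(w * 2)
assert cache.get() is out  # hit: the same tensor, graph still attached
out.sum().backward()       # gradient flows through the cached graph once
```

Returning the cached tensor itself is what lets the first backward pass use its original graph. One plausible reason to invalidate afterwards is that PyTorch frees a graph's saved tensors after backward by default (retain_graph=False), so a second backward through the same cached output would fail.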