torchref.utils.autograd_ops module
Shared autograd helpers that wrap common tensor[indices] gathers
with cheap, deterministic backwards.
PyTorch’s default backward for a 1-D tensor[indices] lowers to
aten::_index_put_impl_(accumulate=True). On CUDA that path uses
cub::DeviceRadixSortOnesweepKernel to sort the indices and then runs
a deduplicated scatter. The sort adds a fixed overhead that shows up
prominently in profiles even for small accumulator buffers (e.g.
log_scale[bins] indexing into a 20-element vector).
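To see why the backward has to accumulate at all, here is a minimal sketch with a hypothetical 20-element log_scale and made-up bins. It shows only the arithmetic: duplicate indices each contribute to the same gradient slot. It runs on CPU and does not exercise the CUDA kernels named above.

```python
import torch

# Hypothetical 20-element accumulator and bin indices containing duplicates.
log_scale = torch.zeros(20, requires_grad=True)
bins = torch.tensor([3, 3, 7, 3, 19, 7])

# Plain gather in the forward; autograd scatter-accumulates in the backward.
log_scale[bins].sum().backward()

# Each slot's gradient equals the number of times that slot was gathered.
expected = torch.bincount(bins, minlength=log_scale.numel()).to(log_scale.dtype)
```

Here log_scale.grad matches expected, so slot 3 receives three contributions and slot 7 receives two.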
Wrapping the gather in this autograd op routes the backward through
index_add_ instead: a single atomic-accumulating scatter with no
radix sort. The forward output is identical to buffer[indices].
Use via gather_with_index_add() (drop-in replacement for
buffer[indices] in differentiable code paths).
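The module's source is not reproduced here, so the following is only a sketch of how an op like this could be written with torch.autograd.Function. The names _GatherIndexAddSketch and gather_sketch are hypothetical, and the real implementation may differ in details such as shape handling.

```python
import torch


class _GatherIndexAddSketch(torch.autograd.Function):
    """Hypothetical sketch: forward is buffer[indices], backward uses index_add_."""

    @staticmethod
    def forward(ctx, buffer, indices):
        ctx.save_for_backward(indices)
        ctx.buffer_shape = buffer.shape
        return buffer[indices]

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        # Flatten so index_add_ sees a 1-D index and a matching source tensor.
        flat_idx = indices.reshape(-1)
        flat_grad = grad_output.reshape(flat_idx.numel(), *ctx.buffer_shape[1:])
        grad_buffer = grad_output.new_zeros(ctx.buffer_shape)
        grad_buffer.index_add_(0, flat_idx, flat_grad)
        # No gradient flows to the integer indices.
        return grad_buffer, None


def gather_sketch(buffer, indices):
    """Hypothetical stand-in for gather_with_index_add()."""
    return _GatherIndexAddSketch.apply(buffer, indices)


# Compare against plain indexing on a small example with duplicate indices.
torch.manual_seed(0)
buf = torch.randn(20, requires_grad=True)
idx = torch.tensor([0, 5, 5, 19, 5, 0])
weights = torch.randn(6)

(gather_sketch(buf, idx) * weights).sum().backward()
grad_sketch = buf.grad.clone()

buf.grad = None
(buf[idx] * weights).sum().backward()
grad_plain = buf.grad.clone()
```

With this sketch the forward values equal buf[idx], and grad_sketch agrees with grad_plain up to floating-point rounding.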
- torchref.utils.autograd_ops.gather_with_index_add(buffer, indices)
buffer[indices] with a fast index_add_ backward.
Drop-in replacement for the 1-D gather pattern when the forward is differentiable and the indices may contain duplicates.
- Parameters:
buffer (torch.Tensor) – Source tensor (1-D, or higher-D with indexing on dim 0).
indices (torch.Tensor) – Integer LongTensor of indices into buffer along dim 0.
- Returns:
buffer[indices], identical to plain indexing in forward, with a cheaper backward path.
- Return type: