torchref.utils.autograd_ops module

Shared autograd helpers that wrap common tensor[indices] gathers with cheap, sort-free backwards.

PyTorch’s default backward for a 1-D tensor[indices] lowers to aten::_index_put_impl_(accumulate=True). On CUDA, that path sorts the indices with cub::DeviceRadixSortOnesweepKernel and then runs a deduplicated scatter. This adds a meaningful constant overhead that shows up prominently in profiles even for small accumulator buffers (e.g. log_scale[bins] indexing into a 20-element vector).
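The gradient that this default path computes can be checked on CPU. The sketch below uses illustrative values and compares gradient values only; it does not inspect which kernel runs. It backpropagates through plain indexing and compares the result with an explicit index_put_ using accumulate=True, where each duplicate index sums the upstream gradients that selected it.

```python
import torch

# Illustrative accumulator buffer and indices with duplicates.
buffer = torch.arange(5, dtype=torch.float64, requires_grad=True)
indices = torch.tensor([0, 2, 2, 4, 0, 0])

# Gradient through plain advanced indexing (PyTorch's default backward).
out = buffer[indices]
grad_out = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=torch.float64)
out.backward(grad_out)

# The same values written as an explicit accumulate=True index_put_:
# every position receives the sum of the upstream gradients that selected it.
expected = torch.zeros(5, dtype=torch.float64).index_put_(
    (indices,), grad_out, accumulate=True
)

# Values 12, 0, 5, 0, 4: index 0 gets 1 + 5 + 6, index 2 gets 2 + 3, index 4 gets 4.
print(buffer.grad)
```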

Wrapping the gather in this autograd op routes the backward through index_add_ instead: a single atomic-accumulating scatter with no radix sort. The forward output is identical to buffer[indices].
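A minimal sketch of such an op, assuming a standard torch.autograd.Function. The names _GatherIndexAdd and gather_with_index_add_sketch are hypothetical, and this is not necessarily how torchref implements gather_with_index_add(). The forward is plain indexing. The backward allocates a zero gradient shaped like buffer and scatters into it with a single index_add_ along dim 0. Flattening the indices so that multi-dimensional index tensors also work is an extra assumption beyond the documented 1-D case.

```python
import torch


class _GatherIndexAdd(torch.autograd.Function):
    """Hypothetical sketch: forward is buffer[indices], backward uses index_add_."""

    @staticmethod
    def forward(ctx, buffer, indices):
        # Keep the indices and the buffer's shape; the buffer values are not needed.
        ctx.save_for_backward(indices)
        ctx.buffer_shape = buffer.shape
        return buffer[indices]

    @staticmethod
    def backward(ctx, grad_out):
        (indices,) = ctx.saved_tensors
        grad_buffer = None
        if ctx.needs_input_grad[0]:
            # index_add_ takes a 1-D index, so flatten the indices and the matching
            # leading dims of grad_out (assumption: also covers N-D index tensors).
            flat_idx = indices.reshape(-1)
            flat_grad = grad_out.reshape(flat_idx.numel(), *ctx.buffer_shape[1:])
            # A single accumulating scatter along dim 0; duplicate indices are summed.
            grad_buffer = grad_out.new_zeros(ctx.buffer_shape).index_add_(
                0, flat_idx, flat_grad
            )
        # Integer indices get no gradient.
        return grad_buffer, None


def gather_with_index_add_sketch(buffer, indices):
    """Hypothetical stand-in for gather_with_index_add()."""
    return _GatherIndexAdd.apply(buffer, indices)


# Illustrative use: a 20-element accumulator indexed by repeated bins.
log_scale = torch.zeros(20, requires_grad=True)
bins = torch.tensor([3, 3, 7, 0, 19, 3])
gather_with_index_add_sketch(log_scale, bins).sum().backward()
print(log_scale.grad[3])  # bin 3 appears three times, so its gradient is 3
```

Only the backward differs from plain indexing. One caveat: on CUDA, atomic floating-point accumulation can sum duplicate contributions in a varying order, so gradients may differ in the last bits from run to run.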

Use it through gather_with_index_add(), a drop-in replacement for buffer[indices] in differentiable code paths.

torchref.utils.autograd_ops.gather_with_index_add(buffer, indices)

buffer[indices] with a fast index_add_ backward.

Drop-in replacement for the 1-D gather pattern when the forward is differentiable and the indices may contain duplicates.

Parameters:
  • buffer (torch.Tensor) – Source tensor (1-D, or higher-dimensional with indexing along dim 0).

  • indices (torch.Tensor) – LongTensor (int64) of indices into buffer along dim 0.
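For a higher-dimensional buffer, the gather selects whole rows along dim 0, and the matching index_add_ scatters whole rows back. The following small check uses illustrative shapes: a (4, 3) buffer and int64 indices with a repeated row. It compares autograd's gradient for plain indexing with a manual row-wise index_add_.

```python
import torch

torch.manual_seed(0)
# Illustrative (N, D) = (4, 3) buffer; indices select rows along dim 0.
buffer = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)
indices = torch.tensor([1, 3, 1, 1], dtype=torch.long)  # int64, row 1 repeated

rows = buffer[indices]             # shape (4, 3): one gathered row per index
grad_rows = torch.ones_like(rows)  # stand-in upstream gradient
rows.backward(grad_rows)

# Row-wise accumulating scatter: each row of grad_rows is added into the
# buffer row it was gathered from.
manual = torch.zeros_like(buffer).index_add_(0, indices, grad_rows)

print(buffer.grad[1])  # row 1 was gathered three times, so every entry is 3
```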

Returns:

buffer[indices]: identical to plain indexing in the forward pass, with a cheaper backward.

Return type:

torch.Tensor