# Source code for torchref.utils.autograd_ops

"""Shared autograd helpers that wrap common ``tensor[indices]`` gathers
with cheap, sort-free backwards.

PyTorch's default backward for a 1-D ``tensor[indices]`` lowers to
``aten::_index_put_impl_(accumulate=True)``. On CUDA that path uses
``cub::DeviceRadixSortOnesweepKernel`` to sort the indices and then runs
a deduplicated scatter — a meaningful constant overhead that shows up
prominently in profiles even for small accumulator buffers (e.g.
``log_scale[bins]`` indexing into a 20-element vector).

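Wrapping the gather in this autograd op routes the backward through
``index_add_`` instead: a single atomic-accumulating scatter with no
radix sort. The forward output is identical to ``buffer[indices]``.
Note that on CUDA the atomic accumulation order is not fixed, so
floating-point gradients for duplicated indices may differ in the last
bits from run to run (``index_add_`` is documented as nondeterministic
on CUDA).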

Use via :func:`gather_with_index_add` (drop-in replacement for
``buffer[indices]`` in differentiable code paths).
"""

from __future__ import annotations

import torch


class _GatherWithIndexAdd(torch.autograd.Function):
    """Autograd op: forward is ``buffer[indices]``, backward uses ``index_add_``.

    The forward gathers along dim 0 exactly like plain integer indexing.
    The backward scatters ``grad_out`` into a zero tensor shaped like
    ``buffer`` with :meth:`torch.Tensor.index_add_`. That call accumulates
    duplicate indices without the sort used by the default
    ``index_put_(accumulate=True)`` backward.
    """

    @staticmethod
    def forward(ctx, buffer, indices):
        ctx.save_for_backward(indices)
        ctx.buffer_shape = buffer.shape
        ctx.buffer_dtype = buffer.dtype
        ctx.buffer_device = buffer.device
        return buffer[indices]

    @staticmethod
    def backward(ctx, grad_out):
        (indices,) = ctx.saved_tensors
        shape = ctx.buffer_shape
        grad_buffer = torch.zeros(
            shape, dtype=ctx.buffer_dtype, device=ctx.buffer_device,
        )
        # index_add_ needs the index on the same device as ``grad_buffer``.
        # Plain indexing tolerates e.g. a CPU index into a CUDA buffer.
        flat_idx = indices.reshape(-1).to(ctx.buffer_device)
        # Plain indexing wraps negative indices; index_add_ rejects them.
        flat_idx = torch.where(flat_idx < 0, flat_idx + shape[0], flat_idx)
        # ``buffer[indices]`` has shape ``indices.shape + buffer.shape[1:]``,
        # but index_add_ only accepts a 1-D index whose length matches the
        # source's dim 0. Collapse the leading index dims to match. numel()
        # is used instead of -1 so that an empty index with a zero-sized
        # trailing dim still reshapes unambiguously.
        grad_src = grad_out.reshape(flat_idx.numel(), *shape[1:])
        # Atomic accumulating scatter along dim 0. This replaces the default
        # ``index_put_(accumulate=True)`` backward.
        grad_buffer.index_add_(0, flat_idx, grad_src)
        return grad_buffer, None


def gather_with_index_add(
    buffer: torch.Tensor,
    indices: torch.Tensor,
) -> torch.Tensor:
    """``buffer[indices]`` with a fast ``index_add_`` backward.

    Drop-in replacement for the integer-gather pattern ``buffer[indices]``
    in differentiable code paths where ``indices`` may contain duplicates.

    Parameters
    ----------
    buffer : torch.Tensor
        Source tensor with at least one dimension, gathered along dim 0.
    indices : torch.Tensor
        Integer tensor (typically int64) of any shape, indexing ``buffer``
        along dim 0. Negative indices wrap as with plain indexing. Boolean
        masks are not supported, because ``index_add_`` needs integer
        indices.

    Returns
    -------
    torch.Tensor
        ``buffer[indices]``, with shape ``indices.shape + buffer.shape[1:]``.
        The forward is identical to plain indexing; the backward path is
        cheaper.
    """
    return _GatherWithIndexAdd.apply(buffer, indices)