torchref.base.kernels.cpu_scatter module

Parallel partitioned scatter_add for structured (wa + wbwc) indices.

This kernel is float32-only by design: the C++ kernel hardcodes torch::kFloat32 for the density and output buffers. The Python wrapper casts non-float32 inputs to float32 and emits a one-time warning per input dtype, so callers running with dtypes.float = torch.float64 keep working, but at float32 precision through this kernel.
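The cast-and-warn-once behaviour can be sketched as follows. `as_float32` and `_WARNED_DTYPES` are hypothetical names chosen for illustration, not the module's actual helpers.

```python
import warnings

import torch

# Hypothetical module-level record of input dtypes we have already warned about.
_WARNED_DTYPES: set = set()


def as_float32(t: torch.Tensor) -> torch.Tensor:
    """Cast `t` to float32 for the float32-only kernel, warning once per input dtype."""
    if t.dtype == torch.float32:
        return t
    if t.dtype not in _WARNED_DTYPES:
        _WARNED_DTYPES.add(t.dtype)
        warnings.warn(
            f"cpu_scatter kernel is float32-only; casting {t.dtype} input to float32",
            stacklevel=2,
        )
    # .to() is differentiable, so gradients flow back to the original dtype.
    return t.to(torch.float32)
```

Tracking warned dtypes explicitly makes the "once per dtype" guarantee independent of the active `warnings` filters. Because `.to()` is differentiable, a float64 leaf still receives a float64 gradient, though its values carry only float32 precision.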

Uses a C++ kernel compiled via torch.utils.cpp_extension.load_inline. Each thread owns a contiguous partition of the output and accumulates only the scatter elements whose target index falls in that partition, so no two threads ever write to the same voxel. With atoms sorted by 1D center, each of the T threads can early-exit past roughly (T-1)/T of the atoms, giving near-linear scaling.
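The partitioning scheme can be illustrated with a serial Python simulation in which each loop iteration over `t` plays the role of one thread. This is a sketch, not the C++ kernel: atoms are given as hypothetical `(flat_idx, values)` pairs, and they are sorted by their minimum flat index as a stand-in for the 1D-center ordering, because that ordering makes the early `break` exact.

```python
import torch


def partitioned_scatter_add(atoms, map_size: int, num_threads: int) -> torch.Tensor:
    """Serial simulation of a partitioned scatter: 'thread' t owns out[lo:hi].

    atoms: list of (flat_idx, values) pairs, one per atom; flat_idx is int64.
    """
    # Sort atoms by smallest target index so each partition can stop early.
    order = sorted(range(len(atoms)), key=lambda a: int(atoms[a][0].min()))
    out = torch.zeros(map_size, dtype=torch.float32)
    chunk = -(-map_size // num_threads)  # ceiling division
    for t in range(num_threads):  # in a real kernel these run concurrently
        lo, hi = t * chunk, min((t + 1) * chunk, map_size)
        for a in order:
            idx, vals = atoms[a]
            if int(idx.min()) >= hi:
                break  # every remaining atom starts past this partition
            if int(idx.max()) < lo:
                continue  # atom lies entirely before this partition
            mask = (idx >= lo) & (idx < hi)
            # Only this partition's slots are touched, so no atomics are needed.
            out.index_add_(0, idx[mask], vals[mask])
    return out
```

The result matches a plain `index_add_` over all atoms up to float32 summation-order rounding. Disjoint partitions are what remove the need for atomics; the sort is what lets most atoms be skipped per partition.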

Threading backend is selected at compile time:
  • OpenMP on Linux (and any compiler that accepts -fopenmp)

  • std::thread fallback otherwise (notably Apple Clang on macOS, whose -fopenmp is rejected because the OpenMP runtime is not bundled)
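How the module decides whether a compiler "accepts -fopenmp" is not documented here. One plausible approach is to probe the compiler with a tiny OpenMP program and pick `load_inline` flags accordingly. Both helper names below are hypothetical, as is the idea that the C++ source branches on the standard `_OPENMP` macro.

```python
import os
import shutil
import subprocess
import tempfile
from typing import Optional

OPENMP_PROBE = "#include <omp.h>\nint main() { return omp_get_max_threads() > 0 ? 0 : 1; }\n"


def compiler_accepts_openmp(cxx: Optional[str] = None) -> bool:
    """Return True if `cxx` can compile and link a trivial OpenMP program."""
    cxx = cxx or os.environ.get("CXX") or shutil.which("c++")
    if cxx is None:
        return False
    with tempfile.TemporaryDirectory() as tmp:
        src = os.path.join(tmp, "probe.cpp")
        with open(src, "w") as f:
            f.write(OPENMP_PROBE)
        try:
            proc = subprocess.run(
                [cxx, "-fopenmp", src, "-o", os.path.join(tmp, "probe")],
                capture_output=True,
                timeout=60,
            )
        except (OSError, subprocess.SubprocessError):
            return False  # compiler missing, not executable, or hung
        # Apple Clang fails here because it rejects -fopenmp without a bundled runtime.
        return proc.returncode == 0


def threading_build_flags(use_openmp: bool) -> dict:
    """Hypothetical keyword arguments for torch.utils.cpp_extension.load_inline."""
    if use_openmp:
        # -fopenmp also defines _OPENMP, which the C++ source could test with #ifdef.
        return {"extra_cflags": ["-O3", "-fopenmp"], "extra_ldflags": ["-fopenmp"]}
    # std::thread fallback: without -fopenmp, an #ifdef _OPENMP branch is skipped.
    return {"extra_cflags": ["-O3"], "extra_ldflags": []}
```

The result of `threading_build_flags(compiler_accepts_openmp())` could then be passed as `**kwargs` to `load_inline`, which does accept `extra_cflags` and `extra_ldflags`.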

Two index dtypes are exposed:
  • int32: the default fast path. Fits any realistic crystallographic grid (fewer than 2**31 voxels in the full map, roughly 1290^3) and halves the index bandwidth of the inner loop.

  • int64: fallback for the rare cases that exceed INT32_MAX (e.g. very large cryo-EM grids).

The Python wrapper picks the binding from wa.dtype.
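Since the binding follows `wa.dtype`, a caller building indices has to choose that dtype from the map size. A minimal sketch of that choice follows, using the `map_size <= INT32_MAX` bound given for the `map_size` parameter. The helper name `index_dtype_for` is hypothetical.

```python
import torch

INT32_MAX = torch.iinfo(torch.int32).max  # 2**31 - 1 = 2147483647


def index_dtype_for(map_size: int) -> torch.dtype:
    """Pick the narrowest index dtype that can address every voxel of the flat map."""
    # int32 indices are 4 bytes versus 8 for int64, which is where the
    # halved index bandwidth comes from.
    return torch.int32 if map_size <= INT32_MAX else torch.int64
```

For a cubic grid the int32 limit falls between 1290^3 (2,146,689,000 voxels) and 1291^3 (2,151,685,171 voxels).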

Autograd: forward is the custom C++ scatter, backward is a standard gather (embarrassingly parallel, no custom kernel needed).
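The forward/backward pair can be written as a pure-PyTorch reference `torch.autograd.Function`. It assumes the flat target of `density_cube[c, i, j, k]` is `wa[c, i] + wbwc[c, j, k]`, as the "(wa + wbwc)" naming suggests. Forward uses `index_add_` in place of the C++ kernel. The class name is hypothetical, and indices are cast to int64 for portability.

```python
import torch


class ReferenceStructuredScatterAdd(torch.autograd.Function):
    """Reference semantics: forward scatters with index_add_, backward gathers."""

    @staticmethod
    def forward(ctx, density_cube, wa, wbwc, map_size):
        # Broadcast (C, nx, 1, 1) + (C, 1, ny, nz) -> (C, nx, ny, nz) flat indices.
        idx = (wa[:, :, None, None] + wbwc[:, None, :, :]).long()
        ctx.save_for_backward(idx)
        out = density_cube.new_zeros(map_size)
        out.index_add_(0, idx.reshape(-1), density_cube.reshape(-1))
        return out

    @staticmethod
    def backward(ctx, grad_out):
        (idx,) = ctx.saved_tensors
        # Each input element adds to exactly one output slot with weight 1,
        # so its gradient is simply the output gradient at that slot (a gather).
        grad_density = grad_out[idx]
        return grad_density, None, None, None
```

The backward pass needs no reduction, since every input element reads exactly one output gradient. That is why it parallelizes trivially and needs no custom kernel.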

torchref.base.kernels.cpu_scatter.structured_scatter_add(density_cube, wa, wbwc, map_size)

Differentiable parallel scatter_add using structured indices.

Dispatches to the int32 or int64 kernel based on wa.dtype. int32 is the default fast path (halves index bandwidth); int64 is available for grids larger than INT32_MAX voxels.

Parameters:
  • density_cube (Tensor (C, nx, ny, nz) float32) – Values to scatter (from _separable_density). Other dtypes are cast to float32 with a one-time warning.

  • wa (Tensor (C, nx) int32 or int64) – Precomputed x-axis scatter indices.

  • wbwc (Tensor (C, ny, nz) same dtype as wa) – Precomputed yz-plane scatter indices.

  • map_size (int) – Total number of voxels in the flat density map. For int32 indices, must be <= INT32_MAX.

Returns:

Scattered result. Differentiable w.r.t. density_cube.

Return type:

Tensor (map_size,) float32
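How `wa` and `wbwc` are precomputed is not specified here. The sketch below shows one way to build them, assuming each atom's (nx, ny, nz) box sits at an integer corner on a periodic grid with row-major flattening `x * NY * NZ + y * NZ + z`. The function `structured_indices` and its arguments are hypothetical.

```python
import torch


def structured_indices(origins: torch.Tensor, box: tuple, grid: tuple, dtype=torch.int32):
    """Build (wa, wbwc) for boxes of shape `box` at integer `origins` on a periodic grid.

    origins: (C, 3) integer box corners; box = (nx, ny, nz); grid = (NX, NY, NZ).
    Assumes row-major flattening: flat = x * NY * NZ + y * NZ + z.
    """
    NX, NY, NZ = grid
    nx, ny, nz = box
    ox, oy, oz = (origins[:, d].long() for d in range(3))
    x = (ox[:, None] + torch.arange(nx)) % NX  # (C, nx), wrapped periodically
    y = (oy[:, None] + torch.arange(ny)) % NY  # (C, ny)
    z = (oz[:, None] + torch.arange(nz)) % NZ  # (C, nz)
    wa = x * (NY * NZ)                         # (C, nx): x-axis contribution
    wbwc = y[:, :, None] * NZ + z[:, None, :]  # (C, ny, nz): yz-plane contribution
    return wa.to(dtype), wbwc.to(dtype)
```

Under these assumptions, `structured_scatter_add(density_cube, wa, wbwc, NX * NY * NZ)` would be expected to equal the flat `index_add_` of `density_cube` at `wa[:, :, None, None] + wbwc[:, None, :, :]`, up to float32 rounding.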