torchref.base.kernels.cpu_scatter module
Parallel partitioned scatter_add for structured (wa + wbwc) indices.
This kernel is float32-only by design: the C++ kernel hardcodes
torch::kFloat32 for the density/output buffers. The Python wrapper
auto-casts non-float32 inputs to float32 and emits a one-time warning per
input dtype — so callers running with dtypes.float = torch.float64 keep
working, just at float32 precision through this kernel.
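The auto-cast behavior can be sketched in plain Python. This is a minimal illustration, not the module's wrapper. Dtypes are represented as strings so the sketch runs without PyTorch, and the cast uses the standard-library array module with typecode 'f' (a C float), so the float32 rounding is real. The function name cast_for_float32_kernel and the registry _WARNED_DTYPES are hypothetical.

```python
import warnings
from array import array

# Hypothetical registry: input dtypes that have already produced a warning.
_WARNED_DTYPES: set[str] = set()


def cast_for_float32_kernel(values: list[float], dtype: str) -> list[float]:
    """Prepare values for a float32-only kernel, warning once per input dtype.

    `dtype` is a string stand-in for a torch dtype so the sketch runs without
    PyTorch; a real wrapper would call tensor.to(torch.float32) instead.
    """
    if dtype == "float32":
        return values  # already the kernel's dtype: no cast, no warning
    if dtype not in _WARNED_DTYPES:
        _WARNED_DTYPES.add(dtype)
        warnings.warn(
            f"float32-only kernel: casting {dtype} input to float32",
            stacklevel=2,
        )
    # array('f') stores C floats, so this rounds every value to float32 precision.
    return list(array("f", values))
```

A second float64 call is silent, and the result carries float32 rounding (0.1 becomes 0.10000000149011612), which is the precision loss described above.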
Uses a C++ kernel compiled via torch.utils.cpp_extension.load_inline. Each thread owns a contiguous partition of the output and accumulates only the scatter elements whose flat index falls in its range, so threads never write to the same voxel and no atomics are needed. With atoms sorted by their 1D center, a per-thread early exit skips roughly (T-1)/T of the atoms (T = number of threads), giving near-linear scaling.
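The partitioning scheme can be illustrated with a pure-Python sketch. It is not the C++ kernel. The function name is hypothetical, each "atom" is simply a list of flat indices with matching values, and Python threads share the GIL, so the sketch shows the ownership structure rather than the speedup. One deliberate deviation: atoms are sorted by their smallest flat index instead of their 1D center, because that ordering makes the early break provably safe in a short example.

```python
import threading


def partitioned_scatter_add(
    atoms: list[tuple[list[int], list[float]]],
    map_size: int,
    num_threads: int,
) -> list[float]:
    """Scatter-add every atom's (flat index, value) pairs into a map of map_size voxels.

    Each thread owns one contiguous slice [lo, hi) of the output, so no two
    threads write the same element and no locks or atomics are required.
    """
    # Sort by smallest flat index (assumption: the real kernel sorts by 1D center)
    # so a thread can stop as soon as atoms start past its slice.
    ordered = sorted(atoms, key=lambda atom: min(atom[0]))
    out = [0.0] * map_size
    chunk = -(-map_size // num_threads)  # ceiling division

    def worker(lo: int, hi: int) -> None:
        for indices, values in ordered:
            if min(indices) >= hi:
                break  # every later atom also starts at or beyond hi
            if max(indices) < lo:
                continue  # atom ends before this slice begins
            for i, v in zip(indices, values):
                if lo <= i < hi:  # accumulate only elements this thread owns
                    out[i] += v

    threads = [
        threading.Thread(target=worker, args=(t * chunk, min((t + 1) * chunk, map_size)))
        for t in range(num_threads)
    ]
    for th in threads:
        th.start()
    for th in threads:
        th.join()
    return out
```

The result matches a serial scatter for any thread count, because every output element is written by exactly one thread.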
- Threading backend is selected at compile time:
  - OpenMP on Linux (and any compiler that accepts -fopenmp)
  - std::thread fallback otherwise (notably Apple Clang on macOS, whose -fopenmp is rejected because the OpenMP runtime is not bundled)
- Two index dtypes are exposed:
  - int32: the default fast path. Fits any realistic crystallographic grid (nx*ny*nz < 2**31, i.e. up to about 1290^3 voxels for a cubic grid) and halves the index bandwidth of the inner loop (4 bytes per index instead of 8).
  - int64: fallback for the rare edge cases that exceed INT32_MAX (e.g. very large cryo-EM grids).
The Python wrapper picks the binding from wa.dtype.
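The int32 limit can be checked with exact integer arithmetic. The helper names below are hypothetical; they illustrate how a caller might choose the index dtype when building wa and wbwc, while the module itself only reads wa.dtype. The map_size <= INT32_MAX rule follows the map_size parameter description for structured_scatter_add.

```python
INT32_MAX = 2**31 - 1


def index_dtype_for(map_size: int) -> str:
    """Pick the narrower index dtype that can address a flat map of map_size voxels."""
    return "int32" if map_size <= INT32_MAX else "int64"


def largest_cubic_edge(limit: int = INT32_MAX) -> int:
    """Largest n with n**3 <= limit, corrected with exact integer comparisons."""
    n = round(limit ** (1 / 3))  # float estimate, then fix any rounding error
    while n**3 > limit:
        n -= 1
    while (n + 1) ** 3 <= limit:
        n += 1
    return n
```

For the default limit this gives an edge of 1290 voxels: 1290^3 = 2,146,689,000 fits in int32, while 1291^3 = 2,151,685,171 does not.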
Autograd: the forward pass is the custom C++ scatter; the backward pass is a standard gather (embarrassingly parallel, so no custom kernel is needed).
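A gather is the right backward because scatter_add is linear in the scattered values: its vector-Jacobian product gives each input element the gradient of the output slot it was added to. The pure-Python sketch below (hypothetical names, flat indices only) states both directions and can be checked with the adjoint identity <scatter(x), g> = <x, gather(g)>, which holds even when indices repeat.

```python
def scatter_add(values: list[float], index: list[int], map_size: int) -> list[float]:
    """Forward: out[index[k]] += values[k]; repeated indices accumulate."""
    out = [0.0] * map_size
    for i, v in zip(index, values):
        out[i] += v
    return out


def gather(grad_out: list[float], index: list[int]) -> list[float]:
    """Backward: d(loss)/d(values[k]) = grad_out[index[k]].

    Each entry reads one output slot and writes only its own position,
    so every k can be computed independently.
    """
    return [grad_out[i] for i in index]
```

Because the gather has no write conflicts, it parallelizes trivially, which is why no custom backward kernel is required.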
- torchref.base.kernels.cpu_scatter.structured_scatter_add(density_cube, wa, wbwc, map_size)
Differentiable parallel scatter_add using structured indices.
Dispatches to the int32 or int64 kernel based on wa.dtype. int32 is the default fast path (halves index bandwidth); int64 is available for grids larger than INT32_MAX voxels.
- Parameters:
density_cube (Tensor (C, nx, ny, nz) float32) – Values to scatter (from _separable_density).
wa (Tensor (C, nx) int32 or int64) – Precomputed x-axis scatter indices.
wbwc (Tensor (C, ny, nz) same dtype as wa) – Precomputed yz-plane scatter indices.
map_size (int) – Total number of voxels in flat density map. For int32 indices, must be <= INT32_MAX.
- Returns:
Scattered result. Differentiable w.r.t. density_cube.
- Return type:
Tensor (map_size,) float32
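For testing or debugging, a serial reference for structured_scatter_add fits in a few lines. This sketch assumes, from the "(wa + wbwc)" naming and the shapes listed above, that density_cube[c, i, j, k] is added at flat index wa[c, i] + wbwc[c, j, k]; the function name is hypothetical and nested lists stand in for tensors. In PyTorch the same assumption would correspond to broadcasting wa[:, :, None, None] + wbwc[:, None, :, :], flattening, and calling Tensor.index_add_ on a zero map.

```python
def structured_scatter_add_reference(
    density_cube: list[list[list[list[float]]]],
    wa: list[list[int]],
    wbwc: list[list[list[int]]],
    map_size: int,
) -> list[float]:
    """Serial reference: density_cube[c][i][j][k] is added at wa[c][i] + wbwc[c][j][k].

    The index composition is an assumption inferred from the (wa + wbwc) naming.
    """
    out = [0.0] * map_size
    for c, cube in enumerate(density_cube):      # cube: (nx, ny, nz)
        for i, plane in enumerate(cube):         # plane: (ny, nz)
            x_offset = wa[c][i]                  # x-axis contribution
            for j, row in enumerate(plane):      # row: (nz,)
                for k, value in enumerate(row):
                    out[x_offset + wbwc[c][j][k]] += value
    return out
```

For example, in a 2 x 2 x 2 map with row-major strides (4, 2, 1), an atom covering x = 0, 1, y = 1 and z = 0, 1 would use wa = [0, 4] and wbwc = [[2, 3]], so its four values land at flat indices 2, 3, 6 and 7; a second atom with the same tables accumulates into those same voxels.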