"""Parallel partitioned scatter_add for structured (wa + wbwc) indices.
This kernel is float32-only by design: the C++ kernel hardcodes
``torch::kFloat32`` for the density/output buffers. The Python wrapper
auto-casts non-float32 inputs to float32 and emits a one-time warning per
input dtype — so callers running with ``dtypes.float = torch.float64`` keep
working, just at float32 precision through this kernel.
Uses a C++ kernel compiled via torch.utils.cpp_extension.load_inline.
Each thread owns a contiguous output partition and only accumulates scatter
elements that fall in its range. With atoms sorted by 1D center, per-thread
early-exit skips ~(T-1)/T of atoms — giving near-linear scaling.
Threading backend is selected at compile time:
* OpenMP on Linux (and any compiler that accepts -fopenmp)
* std::thread fallback otherwise (notably Apple Clang on macOS, whose
-fopenmp is rejected because the OpenMP runtime is not bundled)
Two index dtypes are exposed:
* int32 — default fast path. Fits any realistic crystallographic grid
(nx*ny*nz <= 2**31 - 1, i.e. up to roughly 1290^3 voxels) and halves the
index bandwidth of the inner loop.
* int64 — fallback for the rare edge cases that exceed INT32_MAX
(e.g. very large cryo-EM grids).
The Python wrapper picks the binding from wa.dtype.
Autograd: forward is the custom C++ partitioned scatter; backward is a C++
structured gather parallelized over atoms (embarrassingly parallel: each
atom's gradient slice is independent, so no write partitioning is needed).
"""
import warnings
import torch
from torch.utils.cpp_extension import load_inline
_WARNED_CAST_DTYPES: set = set()
_CPP_SRC = r"""
#include <torch/extension.h>
#include <cstdint>
#ifdef _OPENMP
#include <omp.h>
#else
#include <thread>
#include <vector>
#endif
// Parallel partitioned scatter_add from structured indices.
//
// Output space is divided among threads. Each thread only writes to its own
// [lo, hi) segment — zero synchronization.
// Atoms sorted by 1D center give efficient early-exit per thread.
//
// IdxT is int32_t (default fast path) or int64_t (for huge grids).
template <typename IdxT>
torch::Tensor structured_scatter_add_impl(
torch::Tensor output,
torch::Tensor wa,
torch::Tensor wbwc,
torch::Tensor values,
int64_t nx, int64_t ny, int64_t nz)
{
TORCH_CHECK(output.is_contiguous() && output.scalar_type() == torch::kFloat32);
TORCH_CHECK(values.is_contiguous() && values.scalar_type() == torch::kFloat32);
TORCH_CHECK(wa.is_contiguous());
TORCH_CHECK(wbwc.is_contiguous());
const int64_t M = output.size(0);
const int64_t C = wa.size(0);
const int64_t ny_nz = ny * nz;
const int64_t nxyz = nx * ny_nz;
float* __restrict__ out_p = output.data_ptr<float>();
const IdxT* __restrict__ wa_p = wa.data_ptr<IdxT>();
const IdxT* __restrict__ wbwc_p = wbwc.data_ptr<IdxT>();
const float* __restrict__ val_p = values.data_ptr<float>();
// Global wbwc bounds (for atom-level early exit)
IdxT wbwc_min = wbwc_p[0], wbwc_max = wbwc_p[0];
for (int64_t i = 1; i < C * ny_nz; i++) {
IdxT v = wbwc_p[i];
if (v < wbwc_min) wbwc_min = v;
if (v > wbwc_max) wbwc_max = v;
}
auto worker = [&](int64_t lo, int64_t hi) {
for (int64_t c = 0; c < C; c++) {
const IdxT* wa_row = wa_p + c * nx;
// Atom-level early exit: check wa range vs partition
IdxT amin = wa_row[0], amax = wa_row[0];
for (int64_t i = 1; i < nx; i++) {
IdxT v = wa_row[i];
if (v < amin) amin = v;
if (v > amax) amax = v;
}
if ((int64_t)amax + wbwc_max < lo || (int64_t)amin + wbwc_min >= hi) continue;
const IdxT* wbwc_row = wbwc_p + c * ny_nz;
const float* val_base = val_p + c * nxyz;
for (int64_t ix = 0; ix < nx; ix++) {
IdxT wa_val = wa_row[ix];
// x-offset early exit
if ((int64_t)wa_val + wbwc_max < lo || (int64_t)wa_val + wbwc_min >= hi) continue;
const float* val_ix = val_base + ix * ny_nz;
for (int64_t iyz = 0; iyz < ny_nz; iyz++) {
int64_t idx = (int64_t)wa_val + (int64_t)wbwc_row[iyz];
if (idx >= lo && idx < hi) {
out_p[idx] += val_ix[iyz];
}
}
}
}
};
#ifdef _OPENMP
#pragma omp parallel
{
int tid = omp_get_thread_num();
int nth = omp_get_num_threads();
int64_t lo = (int64_t)tid * M / nth;
int64_t hi = (int64_t)(tid + 1) * M / nth;
worker(lo, hi);
}
#else
int nth = (int)std::thread::hardware_concurrency();
if (nth < 1) nth = 1;
std::vector<std::thread> threads;
threads.reserve(nth);
for (int tid = 0; tid < nth; tid++) {
int64_t lo = (int64_t)tid * M / nth;
int64_t hi = (int64_t)(tid + 1) * M / nth;
threads.emplace_back(worker, lo, hi);
}
for (auto& t : threads) t.join();
#endif
return output;
}
// Parallel structured gather (backward of scatter_add).
//
// grad_cube[c, ix, iyz] = grad_output[wa[c,ix] + wbwc[c,iyz]]
//
// Each atom's output is independent — parallelize over atoms directly.
// No index tensor allocation, no fancy indexing overhead.
template <typename IdxT>
torch::Tensor structured_gather_impl(
torch::Tensor grad_output,
torch::Tensor wa,
torch::Tensor wbwc,
int64_t nx, int64_t ny, int64_t nz)
{
TORCH_CHECK(grad_output.is_contiguous() && grad_output.scalar_type() == torch::kFloat32);
TORCH_CHECK(wa.is_contiguous());
TORCH_CHECK(wbwc.is_contiguous());
const int64_t C = wa.size(0);
const int64_t ny_nz = ny * nz;
const int64_t nxyz = nx * ny_nz;
auto grad_cube = torch::empty({C, nxyz}, grad_output.options());
const float* __restrict__ go_p = grad_output.data_ptr<float>();
const IdxT* __restrict__ wa_p = wa.data_ptr<IdxT>();
const IdxT* __restrict__ wbwc_p = wbwc.data_ptr<IdxT>();
float* __restrict__ gc_p = grad_cube.data_ptr<float>();
auto worker_gather = [&](int64_t c_lo, int64_t c_hi) {
for (int64_t c = c_lo; c < c_hi; c++) {
const IdxT* wa_row = wa_p + c * nx;
const IdxT* wbwc_row = wbwc_p + c * ny_nz;
float* out_row = gc_p + c * nxyz;
for (int64_t ix = 0; ix < nx; ix++) {
IdxT wa_val = wa_row[ix];
float* dst = out_row + ix * ny_nz;
for (int64_t iyz = 0; iyz < ny_nz; iyz++) {
dst[iyz] = go_p[(int64_t)wa_val + (int64_t)wbwc_row[iyz]];
}
}
}
};
#ifdef _OPENMP
#pragma omp parallel for schedule(static)
for (int64_t c = 0; c < C; c++) {
const IdxT* wa_row = wa_p + c * nx;
const IdxT* wbwc_row = wbwc_p + c * ny_nz;
float* out_row = gc_p + c * nxyz;
for (int64_t ix = 0; ix < nx; ix++) {
IdxT wa_val = wa_row[ix];
float* dst = out_row + ix * ny_nz;
for (int64_t iyz = 0; iyz < ny_nz; iyz++) {
dst[iyz] = go_p[(int64_t)wa_val + (int64_t)wbwc_row[iyz]];
}
}
}
#else
int nth = (int)std::thread::hardware_concurrency();
if (nth < 1) nth = 1;
std::vector<std::thread> threads;
threads.reserve(nth);
for (int tid = 0; tid < nth; tid++) {
int64_t c_lo = (int64_t)tid * C / nth;
int64_t c_hi = (int64_t)(tid + 1) * C / nth;
threads.emplace_back(worker_gather, c_lo, c_hi);
}
for (auto& t : threads) t.join();
#endif
return grad_cube;
}
// Wrappers that enforce the matching dtype, so a wrong-dtype tensor gets a
// clear error instead of being reinterpreted by data_ptr<IdxT>().
torch::Tensor structured_scatter_add_i32(
torch::Tensor output, torch::Tensor wa, torch::Tensor wbwc, torch::Tensor values,
int64_t nx, int64_t ny, int64_t nz)
{
TORCH_CHECK(wa.scalar_type() == torch::kInt32, "wa must be int32");
TORCH_CHECK(wbwc.scalar_type() == torch::kInt32, "wbwc must be int32");
return structured_scatter_add_impl<int32_t>(output, wa, wbwc, values, nx, ny, nz);
}
torch::Tensor structured_scatter_add_i64(
torch::Tensor output, torch::Tensor wa, torch::Tensor wbwc, torch::Tensor values,
int64_t nx, int64_t ny, int64_t nz)
{
TORCH_CHECK(wa.scalar_type() == torch::kInt64, "wa must be int64");
TORCH_CHECK(wbwc.scalar_type() == torch::kInt64, "wbwc must be int64");
return structured_scatter_add_impl<int64_t>(output, wa, wbwc, values, nx, ny, nz);
}
torch::Tensor structured_gather_i32(
torch::Tensor grad_output, torch::Tensor wa, torch::Tensor wbwc,
int64_t nx, int64_t ny, int64_t nz)
{
TORCH_CHECK(wa.scalar_type() == torch::kInt32, "wa must be int32");
TORCH_CHECK(wbwc.scalar_type() == torch::kInt32, "wbwc must be int32");
return structured_gather_impl<int32_t>(grad_output, wa, wbwc, nx, ny, nz);
}
torch::Tensor structured_gather_i64(
torch::Tensor grad_output, torch::Tensor wa, torch::Tensor wbwc,
int64_t nx, int64_t ny, int64_t nz)
{
TORCH_CHECK(wa.scalar_type() == torch::kInt64, "wa must be int64");
TORCH_CHECK(wbwc.scalar_type() == torch::kInt64, "wbwc must be int64");
return structured_gather_impl<int64_t>(grad_output, wa, wbwc, nx, ny, nz);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("structured_scatter_add_i32", &structured_scatter_add_i32,
"Partitioned scatter_add with int32 structured (wa, wbwc) indices");
m.def("structured_scatter_add_i64", &structured_scatter_add_i64,
"Partitioned scatter_add with int64 structured (wa, wbwc) indices");
m.def("structured_gather_i32", &structured_gather_i32,
"Structured gather (backward) with int32 indices");
m.def("structured_gather_i64", &structured_gather_i64,
"Structured gather (backward) with int64 indices");
}
"""
# ---------------------------------------------------------------------------
# Lazy compilation with POSIX lockf locking (works across cluster nodes).
#
# PyTorch's load_inline uses FileBaton (file-existence lock) which stays
# behind when a process is killed mid-compile, blocking all future imports.
# We wrap it with fcntl.lockf (POSIX record locks) which:
# - are enforced by the filesystem → work across NFS/GPFS cluster nodes
# - are released by the kernel on process death (even SIGKILL)
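# The first process to take the lock compiles; later processes (on the same
# node or another node with the same CPU model, since each CPU model gets its
# own build directory) reuse the cached build.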
# ---------------------------------------------------------------------------
_module = None
_module_failed = False
# Captured diagnostic from the most recent failed compile attempt. Populated
# by _get_module() on failure so test/debug code can surface the underlying
# error instead of just "module not available". Format: (error_str, traceback_str).
_module_error: "tuple[str, str] | None" = None
def _prepare_toolchain_env():
    """Put pip-installed ninja and a new-enough GCC where torch's builder finds them."""
    import os
    import sys

    # Ensure ninja (installed via pip next to the interpreter) is on PATH for
    # compute nodes. Compare whole PATH entries rather than substrings, so
    # ".../venv/bin" is not mistaken as present because ".../venv/bin/x" is.
    bin_dir = os.path.dirname(sys.executable)
    path = os.environ.get("PATH", "")
    if bin_dir not in path.split(os.pathsep):
        os.environ["PATH"] = bin_dir + os.pathsep + path

    # Need GCC >= 9 for PyTorch C++ extensions; prefer the newest gcc-toolset
    # installed under /opt/rh.
    for toolset in ("14", "13", "12"):
        gcc = f"/opt/rh/gcc-toolset-{toolset}/root/usr/bin/g++"
        if os.path.isfile(gcc):
            os.environ["CXX"] = gcc
            os.environ["CC"] = gcc.replace("g++", "gcc")
            break


def _cpu_tag():
    """Return a single-path-component tag naming this host's CPU model."""
    import os
    import platform

    tag = platform.machine()
    try:
        # Use the CPU model to distinguish microarchitectures.
        with open("/proc/cpuinfo") as f:
            for line in f:
                if line.startswith("model name"):
                    # e.g. "EPYC_7443P" or "Xeon_Gold_6248". Split on the first
                    # ":" only, so a colon inside the model string is kept.
                    tag = line.split(":", 1)[1].strip().replace(" ", "_")
                    break
    except OSError:
        pass  # no /proc/cpuinfo (e.g. macOS): fall back to platform.machine()
    # Guarantee the tag cannot introduce extra directory levels.
    return tag.replace(os.sep, "_")


def _build_dir():
    """Return the per-microarchitecture build directory for the extension.

    Separate directories prevent Illegal Instruction errors when cluster nodes
    with different CPUs (e.g. AMD vs Intel) share a home directory, since the
    kernel is compiled with -march=native.
    """
    import os

    root = os.environ.get(
        "TORCH_EXTENSIONS_DIR",
        os.path.join(os.path.expanduser("~"), ".cache", "torch_extensions"),
    )
    return os.path.join(root, f"cpu_scatter_{_cpu_tag()}")


def _compile_flags():
    """Return (extra_cflags, extra_ldflags) for this platform's compiler."""
    import sys

    extra_cflags = ["-O3", "-march=native"]
    extra_ldflags: list[str] = []
    # Apple Clang on macOS rejects -fopenmp (no bundled OpenMP runtime). The
    # kernel falls back to std::thread via #ifndef _OPENMP, so no libomp,
    # Homebrew, or other external dependency is required.
    if sys.platform != "darwin":
        extra_cflags.append("-fopenmp")
        extra_ldflags.append("-fopenmp")
    return extra_cflags, extra_ldflags


def _locked_load_inline(build_dir, extra_cflags, extra_ldflags):
    """Build (or reuse) and load the extension while holding an exclusive lockf lock."""
    import fcntl
    import os

    # fcntl.lockf takes POSIX record locks, which the kernel releases on
    # process death (even SIGKILL), unlike PyTorch's file-existence baton.
    lock_fd = os.open(os.path.join(build_dir, "compile.lock"), os.O_CREAT | os.O_RDWR)
    try:
        fcntl.lockf(lock_fd, fcntl.LOCK_EX)
        try:
            # Clear any stale PyTorch FileBaton lock left by a killed process.
            try:
                os.unlink(os.path.join(build_dir, "lock"))
            except FileNotFoundError:
                pass
            return load_inline(
                name="cpu_scatter",
                cpp_sources=[_CPP_SRC],
                extra_cflags=extra_cflags,
                extra_ldflags=extra_ldflags,
                build_directory=build_dir,
                verbose=False,
            )
        finally:
            fcntl.lockf(lock_fd, fcntl.LOCK_UN)
    finally:
        # Always release the descriptor, even if unlocking itself fails.
        os.close(lock_fd)


def _get_module():
    """Return the compiled ``cpu_scatter`` extension, compiling it on first use.

    The result is cached in ``_module``. On any failure, ``_module_failed`` is
    set, so later calls return ``None`` without retrying. ``_module_error``
    records ``(error_str, traceback_str)`` for diagnostics.

    Returns
    -------
    module or None
        The loaded extension, or ``None`` if it could not be built (including
        on platforms without ``fcntl``).
    """
    global _module, _module_failed, _module_error
    if _module is not None:
        return _module
    if _module_failed:
        return None

    import traceback

    try:
        import fcntl  # noqa: F401  (availability probe; used by _locked_load_inline)
    except ImportError:
        # fcntl is not available on non-POSIX platforms (e.g. Windows)
        _module_failed = True
        _module_error = ("fcntl unavailable (non-POSIX platform)", "")
        return None

    try:
        import os

        _prepare_toolchain_env()
        build_dir = _build_dir()
        os.makedirs(build_dir, exist_ok=True)
        extra_cflags, extra_ldflags = _compile_flags()
        _module = _locked_load_inline(build_dir, extra_cflags, extra_ldflags)
    except Exception as e:
        _module_failed = True
        _module_error = (f"{type(e).__name__}: {e}", traceback.format_exc())
        return None
    return _module
# ---------------------------------------------------------------------------
# Autograd wrapper
# ---------------------------------------------------------------------------
class _StructuredScatterAdd(torch.autograd.Function):
"""scatter_add with structured (wa, wbwc) indices.
Forward: C++ partitioned scatter (parallel, no conflicts)
Backward: standard gather (embarrassingly parallel)
"""
@staticmethod
def forward(ctx, density_cube, wa, wbwc, map_size):
# density_cube: (C, nx, ny, nz) — requires grad, must be float32
# wa: (C, nx) — int32 or int64, no grad
# wbwc: (C, ny, nz) — same dtype as wa
if density_cube.dtype != torch.float32:
if density_cube.dtype not in _WARNED_CAST_DTYPES:
warnings.warn(
f"cpu_scatter is float32-only; casting density_cube from "
f"{density_cube.dtype} to float32 (precision will be reduced).",
stacklevel=3,
)
_WARNED_CAST_DTYPES.add(density_cube.dtype)
density_cube = density_cube.to(torch.float32)
if wa.dtype != wbwc.dtype:
raise TypeError(
f"wa.dtype ({wa.dtype}) and wbwc.dtype ({wbwc.dtype}) must match"
)
if wa.dtype == torch.int32:
if map_size > _INT32_MAX:
raise RuntimeError(
f"map_size {map_size} exceeds INT32_MAX ({_INT32_MAX}); "
"pass int64 indices for grids this large."
)
scatter_fn_name = "structured_scatter_add_i32"
gather_fn_name = "structured_gather_i32"
elif wa.dtype == torch.int64:
scatter_fn_name = "structured_scatter_add_i64"
gather_fn_name = "structured_gather_i64"
else:
raise TypeError(
f"wa.dtype must be int32 or int64, got {wa.dtype}"
)
C, nx, ny, nz = density_cube.shape
ctx.save_for_backward(wa, wbwc)
ctx.cube_shape = density_cube.shape
ctx.gather_fn_name = gather_fn_name
mod = _get_module()
if mod is None:
err = _module_error[0] if _module_error else "unknown reason"
raise RuntimeError(
f"C++ cpu_scatter module not available ({err}). "
"See torchref.base.kernels.cpu_scatter._module_error for the full traceback."
)
result = torch.zeros(map_size, dtype=density_cube.dtype,
device=density_cube.device)
getattr(mod, scatter_fn_name)(
result,
wa.contiguous(),
wbwc.reshape(C, ny * nz).contiguous(),
density_cube.reshape(C, nx * ny * nz).contiguous(),
nx, ny, nz,
)
return result
@staticmethod
def backward(ctx, grad_output):
wa, wbwc = ctx.saved_tensors
C, nx, ny, nz = ctx.cube_shape
mod = _get_module()
if mod is None:
err = _module_error[0] if _module_error else "unknown reason"
raise RuntimeError(
f"C++ cpu_scatter module not available ({err}). "
"See torchref.base.kernels.cpu_scatter._module_error for the full traceback."
)
grad_cube = getattr(mod, ctx.gather_fn_name)(
grad_output.contiguous(),
wa.contiguous(),
wbwc.reshape(C, ny * nz).contiguous(),
nx, ny, nz,
)
return grad_cube.reshape(C, nx, ny, nz), None, None, None
_INT32_MAX = 2**31 - 1
def structured_scatter_add(density_cube, wa, wbwc, map_size):
    """Differentiable parallel scatter_add using structured indices.

    Computes ``out[wa[c, ix] + wbwc[c, iy, iz]] += density_cube[c, ix, iy, iz]``
    over all atoms ``c``, returning a flat map of ``map_size`` voxels.

    Dispatches to the int32 or int64 kernel based on ``wa.dtype``. int32 is
    the default fast path (halves index bandwidth); int64 is available for
    grids larger than INT32_MAX voxels.

    Parameters
    ----------
    density_cube : Tensor (C, nx, ny, nz) float32
        Values to scatter (from _separable_density). Other floating dtypes
        are cast to float32 with a one-time warning per dtype.
    wa : Tensor (C, nx) int32 or int64
        Precomputed x-axis scatter indices.
    wbwc : Tensor (C, ny, nz) same dtype as wa
        Precomputed yz-plane scatter indices.
    map_size : int
        Total number of voxels in flat density map. For int32 indices,
        must be <= INT32_MAX.

    Returns
    -------
    Tensor (map_size,) float32
        Scattered result. Differentiable w.r.t. density_cube.

    Raises
    ------
    TypeError
        If ``wa`` and ``wbwc`` dtypes differ or are not int32/int64.
    RuntimeError
        If int32 indices are used with ``map_size > INT32_MAX``, or if the
        C++ extension could not be compiled/loaded.
    """
    return _StructuredScatterAdd.apply(density_cube, wa, wbwc, map_size)