"""
Central electron density building with automatic backend selection.
Dispatches to the optimal implementation based on device and available backends:
GPU backends (``ISO_MAP_ENGINE_GPU``):
"separable_triton" — Tier 0, separable 1-D table lookups
"fused_triton" — Tier 1, fused Triton kernel
"original" — Tier 2, find_relevant_voxels + vectorized_add_to_map
"auto" — try fused → separable → original (default)
CPU backends (``ISO_MAP_ENGINE_CPU``):
"separable" — separable Gaussian splatting (default)
"separable_compiled" — torch.compile'd separable
"fused" — fused fractional-space kernel
"original" — find_relevant_voxels + vectorized_add_to_map (same as GPU Tier 2)
Override at import time via environment variables::
TORCHREF_ISO_MAP_ENGINE_GPU=auto
TORCHREF_ISO_MAP_ENGINE_CPU=fused
or at runtime by setting the module-level variables directly::
import torchref.base.electron_density.main as ed
ed.ISO_MAP_ENGINE_GPU = "original"
ed.ISO_MAP_ENGINE_CPU = "fused"
"""
import math
import os
from typing import Optional
import torch
from torchref.config import dtypes, get_float_dtype
# ---------------------------------------------------------------------------
# Engine selection — set via env vars or overwrite at runtime
# ---------------------------------------------------------------------------
# The environment variables are read once, when this module is imported.
# The dispatch helpers read the module-level names on every call, so
# assigning to them afterwards changes the backend immediately.
ISO_MAP_ENGINE_GPU: str = os.environ.get("TORCHREF_ISO_MAP_ENGINE_GPU", "auto")
ISO_MAP_ENGINE_CPU: str = os.environ.get("TORCHREF_ISO_MAP_ENGINE_CPU", "separable")
# MPS engines:
#   "single"             — single-pass (no chunking), one math + one scatter
#                          call. Default — avoids per-chunk autograd.Function
#                          overhead and is required for the Metal scatter to
#                          win over PyTorch scatter_add_ (chunking dilutes
#                          that win on Apple Silicon, see
#                          profiling_mps/ANALYSIS.md).
#   "separable_compiled" — CPU separable_compiled (multi-chunk, compiled math).
#   "separable"          — CPU separable (multi-chunk, eager math).
ISO_MAP_ENGINE_MPS: str = os.environ.get("TORCHREF_ISO_MAP_ENGINE_MPS", "single")
# Legacy env var (backward compat) — only consulted when GPU engine is "auto".
# The values "jit" and "simple" make the "auto" engine skip both Triton kernels.
_GPU_MODE = os.environ.get("TORCHREF_ATOM_PLACEMENT_GPU_MODE", "triton")
# Lazy-loaded Triton backends. Each ``*_fn`` holds the imported kernel (or
# None); each ``*_checked`` flag records that the import was already attempted.
_fused_fn = None
_fused_checked = False
_separable_fn = None
_separable_checked = False
def _get_fused_triton():
    """Return the fused Triton kernel, or None if it cannot be imported.

    The import is attempted at most once per process; the outcome is
    memoised in the module-level ``_fused_fn`` / ``_fused_checked`` pair.
    """
    global _fused_fn, _fused_checked
    if _fused_checked:
        return _fused_fn
    try:
        from torchref.base.kernels.triton_kernel import (
            fused_find_and_place_atoms as kernel,
        )
    except ImportError:
        # Kernel module (or one of its dependencies) is missing: leave
        # _fused_fn untouched so callers see "unavailable".
        pass
    else:
        _fused_fn = kernel
    _fused_checked = True
    return _fused_fn
def _get_separable_triton():
    """Return the separable Triton kernel, or None if it cannot be imported.

    The import is attempted at most once per process; the outcome is
    memoised in the module-level ``_separable_fn`` / ``_separable_checked``
    pair.
    """
    global _separable_fn, _separable_checked
    if _separable_checked:
        return _separable_fn
    try:
        from torchref.base.kernels.separable_triton_kernel import (
            separable_density_gpu as kernel,
        )
    except ImportError:
        # Kernel module (or one of its dependencies) is missing: leave
        # _separable_fn untouched so callers see "unavailable".
        pass
    else:
        _separable_fn = kernel
    _separable_checked = True
    return _separable_fn
# Lazily resolved C++ parallel scatter for CPU (stays None when unavailable)
_cpp_scatter_fn = None
_cpp_scatter_checked = False


def _get_cpp_scatter():
    """Return the C++ parallel scatter_add, or None if unavailable.

    The first call imports the extension wrapper and asks it for the
    compiled module, so build failures (missing ninja, unsupported compiler
    flags, etc.) surface here rather than mid-calculation. Any exception
    raised while doing so is treated as "not available". The result is
    memoised, so later calls are a plain global lookup.
    """
    global _cpp_scatter_fn, _cpp_scatter_checked
    if _cpp_scatter_checked:
        return _cpp_scatter_fn
    try:
        from torchref.base.kernels.cpu_scatter import structured_scatter_add, _get_module

        # _get_module() triggers compilation and returns None on failure.
        compiled_module = _get_module()
        if compiled_module is not None:
            _cpp_scatter_fn = structured_scatter_add
    except Exception:
        # Deliberate best-effort: callers fall back to torch's scatter_add_.
        pass
    _cpp_scatter_checked = True
    return _cpp_scatter_fn
def _do_structured_scatter(
    density_cube: torch.Tensor,
    wa: torch.Tensor,
    wbwc: torch.Tensor,
    density_flat: torch.Tensor,
    map_size: int,
) -> torch.Tensor:
    """Accumulate per-atom density cubes into the flat map.

    When ``density_cube`` lives on the CPU and the custom C++ kernel
    (``cpu_scatter``) is available, that kernel computes the scattered
    contribution and the sum ``density_flat + contribution`` is returned as
    a new tensor. In every other case (non-CPU device, or CPU without the
    extension) PyTorch ``scatter_add_`` is applied to ``density_flat`` in
    place, using int64 indices built as ``wa[c, i] + wbwc[c, j, k]``.

    Callers must use the returned tensor: depending on the path taken it is
    either a fresh tensor or ``density_flat`` itself.
    """
    on_cpu = density_cube.device.type == "cpu"
    native_scatter = _get_cpp_scatter() if on_cpu else None
    if native_scatter is not None:
        contribution = native_scatter(density_cube, wa, wbwc, map_size)
        return density_flat + contribution

    # Fallback: broadcast the x-index against the combined yz-index to get
    # one flat voxel index per cube element; scatter_add_ needs int64.
    voxel_index = wa[:, :, None, None] + wbwc[:, None, :, :]
    density_flat.scatter_add_(
        0,
        voxel_index.reshape(-1).to(torch.int64),
        density_cube.reshape(-1),
    )
    return density_flat
[docs]
def build_electron_density(
    real_space_grid: torch.Tensor,
    xyz_iso: torch.Tensor,
    adp_iso: torch.Tensor,
    occ_iso: torch.Tensor,
    A_iso: torch.Tensor,
    B_iso: torch.Tensor,
    inv_frac_matrix: torch.Tensor,
    frac_matrix: torch.Tensor,
    radius_angstrom: float,
    voxel_size: torch.Tensor,
    xyz_aniso: Optional[torch.Tensor] = None,
    u_aniso: Optional[torch.Tensor] = None,
    occ_aniso: Optional[torch.Tensor] = None,
    A_aniso: Optional[torch.Tensor] = None,
    B_aniso: Optional[torch.Tensor] = None,
    dtype: torch.dtype = get_float_dtype(),
) -> torch.Tensor:
    """
    Build an electron density map from atomic parameters.

    A zero map is allocated on the device of ``real_space_grid``. Isotropic
    atoms are added through ``_add_isotropic``, which picks a backend from
    the device type and the ``ISO_MAP_ENGINE_*`` settings; anisotropic
    atoms, if any, are added through ``_add_anisotropic``.

    Parameters
    ----------
    real_space_grid : torch.Tensor
        Coordinate grid, shape (nx, ny, nz, 3).
    xyz_iso : torch.Tensor
        Isotropic atom positions, shape (n_iso, 3).
    adp_iso : torch.Tensor
        Isotropic B-factors, shape (n_iso,).
    occ_iso : torch.Tensor
        Isotropic occupancies, shape (n_iso,).
    A_iso, B_iso : torch.Tensor
        ITC92 coefficients, shape (n_iso, 5).
    inv_frac_matrix : torch.Tensor
        Cartesian-to-fractional matrix, shape (3, 3).
    frac_matrix : torch.Tensor
        Fractional-to-Cartesian matrix, shape (3, 3).
    radius_angstrom : float
        Radius around each atom in Angstroms.
    voxel_size : torch.Tensor
        Voxel dimensions, shape (3,). Only used by the isotropic path.
    xyz_aniso : torch.Tensor, optional
        Anisotropic atom positions, shape (n_aniso, 3).
    u_aniso : torch.Tensor, optional
        Anisotropic U parameters, shape (n_aniso, 6).
    occ_aniso : torch.Tensor, optional
        Anisotropic occupancies, shape (n_aniso,).
    A_aniso, B_aniso : torch.Tensor, optional
        ITC92 coefficients for anisotropic atoms, shape (n_aniso, 5).
    dtype : torch.dtype, optional
        Float dtype of the allocated map. The default is whatever
        ``get_float_dtype()`` returned when this function was defined.

    Returns
    -------
    torch.Tensor
        Electron density map, shape (nx, ny, nz).
    """
    map_shape = real_space_grid.shape[:-1]
    density_map = torch.zeros(
        map_shape, dtype=dtype, device=real_space_grid.device,
    )

    # --- isotropic contribution (skipped when there are no such atoms) ---
    if len(xyz_iso) > 0:
        density_map = _add_isotropic(
            real_space_grid, density_map,
            xyz_iso, adp_iso, occ_iso, A_iso, B_iso,
            inv_frac_matrix, frac_matrix,
            radius_angstrom, voxel_size,
        )

    # --- anisotropic contribution (optional) ---
    if xyz_aniso is None or len(xyz_aniso) == 0:
        return density_map
    return _add_anisotropic(
        real_space_grid, density_map,
        xyz_aniso, u_aniso, occ_aniso, A_aniso, B_aniso,
        inv_frac_matrix, frac_matrix,
        radius_angstrom,
    )
# =========================================================================
# Internal dispatch helpers
# =========================================================================
def _add_isotropic(
    real_space_grid, density_map, xyz, adp, occ, A, B,
    inv_frac_matrix, frac_matrix, radius_angstrom, voxel_size,
):
    """Route isotropic atoms to the dispatcher matching the map's device.

    ``cuda`` maps go to the GPU dispatcher, ``mps`` maps to the MPS
    dispatcher, and every other device type to the CPU dispatcher. All
    arguments are forwarded unchanged and the dispatcher's result is
    returned.
    """
    per_device = {
        "cuda": _add_isotropic_gpu,
        "mps": _add_isotropic_mps,
    }
    dispatcher = per_device.get(density_map.device.type, _add_isotropic_cpu)
    return dispatcher(
        real_space_grid, density_map, xyz, adp, occ, A, B,
        inv_frac_matrix, frac_matrix, radius_angstrom, voxel_size,
    )
def _add_isotropic_mps(
    real_space_grid, density_map, xyz, adp, occ, A, B,
    inv_frac_matrix, frac_matrix, radius_angstrom, voxel_size,
):
    """MPS dispatch for isotropic atoms, controlled by ``ISO_MAP_ENGINE_MPS``.

    Recognised engines are ``"single"``, ``"separable_compiled"`` and
    ``"separable"``; the latter two reuse the CPU separable implementations.

    Raises
    ------
    ValueError
        If ``ISO_MAP_ENGINE_MPS`` names none of the engines above.
    """
    engine = ISO_MAP_ENGINE_MPS
    grid_shape_tuple = real_space_grid.shape[:3]

    # All MPS engines share one call signature, so a small table suffices.
    known_engines = (
        ("single", _add_isotropic_mps_single),
        ("separable_compiled", _add_isotropic_cpu_separable_compiled),
        ("separable", _add_isotropic_cpu_separable),
    )
    for engine_name, implementation in known_engines:
        if engine == engine_name:
            return implementation(
                density_map, xyz, adp, occ, A, B,
                inv_frac_matrix, frac_matrix,
                grid_shape_tuple, voxel_size, radius_angstrom,
            )

    raise ValueError(
        f"Unknown ISO_MAP_ENGINE_MPS={engine!r}. "
        f"Choose from: single, separable_compiled, separable"
    )
def _add_isotropic_mps_single(
    density_map, xyz, adp, occ, A, B,
    inv_frac_matrix, frac_matrix,
    grid_shape_tuple, voxel_size, radius_angstrom,
):
    """Single-pass splat: one density evaluation and one scatter for all atoms.

    Unlike the chunked separable paths, every atom is processed in a single
    call to the torch.compile'd ``_separable_density`` followed by a single
    ``_do_structured_scatter``. The original author's profiling notes
    (profiling_mps/ANALYSIS.md) give the removal of per-chunk op overhead as
    the motivation. Note that the intermediate density cube covers all atoms
    at once, shape (N, n_axis, n_axis, n_axis).

    Returns the updated map with the same shape as ``density_map``.
    """
    device = density_map.device
    grid_dims_f = torch.tensor(grid_shape_tuple, device=device).float()

    # --- Per-axis integer offsets of the splat box (cached helper) ---
    offsets, _ = _get_box_radius(voxel_size, radius_angstrom, device)
    offsets = offsets.to(dtypes.int)
    offsets_row = offsets.unsqueeze(0)  # (1, n_axis)

    # --- Scalar constants ---
    pi = math.pi
    alpha_scale = pi * pi              # pi^2
    amp_scale = pi * math.sqrt(pi)     # pi^1.5

    # --- Cell metric and grid geometry ---
    metric = frac_matrix.T @ frac_matrix
    inv_grid = 1.0 / grid_dims_f
    nx_val = grid_shape_tuple[0]
    ny_val = grid_shape_tuple[1]
    nz_val = grid_shape_tuple[2]
    plane_stride = ny_val * nz_val
    map_size = density_map.numel()

    # --- Fractional positions; wrapping is used for the grid index only ---
    frac_pos = xyz @ inv_frac_matrix.T
    nearest = torch.round((frac_pos % 1.0) * grid_dims_f).to(dtypes.int)

    # --- Per-component width, amplitude and exponent coefficient ---
    b_eff = ((B + adp[:, None]) * 0.25).clamp(min=0.1)
    amplitude = A * occ[:, None] * amp_scale / (b_eff * torch.sqrt(b_eff))
    alpha = alpha_scale / b_eff

    # --- Which off-diagonal metric terms are large enough to matter ---
    tol = 1e-3 * torch.norm(torch.diagonal(metric))
    has_ab = bool(torch.abs(metric[0, 1]) > tol)
    has_ac = bool(torch.abs(metric[0, 2]) > tol)
    has_bc = bool(torch.abs(metric[1, 2]) > tol)

    # --- Fractional distance from every atom to every box plane, per axis ---
    offsets_frac = offsets.float().unsqueeze(0) * inv_grid.unsqueeze(1)  # (3, n_axis)
    nearest_frac = nearest.float() * inv_grid
    residual = frac_pos - nearest_frac
    d_frac = offsets_frac.unsqueeze(0) - residual.unsqueeze(2)
    d_frac = d_frac - torch.round(d_frac)  # periodic wrap

    # --- Density for all atoms in one compiled call ---
    evaluate = _get_compiled_separable_density()
    density_cube = evaluate(d_frac, alpha, amplitude, metric, has_ab, has_ac, has_bc)

    # --- Structured flat indices: x part, and combined y/z part ---
    idx_a = (nearest[:, 0:1] + offsets_row) % nx_val * plane_stride
    idx_b = (nearest[:, 1:2] + offsets_row) % ny_val * nz_val
    idx_c = (nearest[:, 2:3] + offsets_row) % nz_val
    idx_bc = idx_b.unsqueeze(2) + idx_c.unsqueeze(1)

    # --- One scatter into the flat view of the map ---
    density_flat = _do_structured_scatter(
        density_cube, idx_a, idx_bc, density_map.view(-1), map_size,
    )
    return density_flat.view(density_map.shape)
def _add_isotropic_gpu(
    real_space_grid, density_map, xyz, adp, occ, A, B,
    inv_frac_matrix, frac_matrix, radius_angstrom, voxel_size,
):
    """GPU dispatch for isotropic atoms, controlled by ISO_MAP_ENGINE_GPU.

    Engines
    -------
    "separable_triton"
        Separable Triton kernel; RuntimeError if it cannot be imported.
    "fused_triton"
        Fused Triton kernel; RuntimeError if it cannot be imported.
    "original"
        find_relevant_voxels + vectorized_add_to_map.
    "auto"
        Try fused, then separable, then "original". A Triton kernel that is
        unavailable is skipped; one that raises is reported with a
        ``RuntimeWarning`` and the next backend is tried. Both Triton
        kernels are skipped when the legacy ``_GPU_MODE`` is "jit" or
        "simple".

    Raises
    ------
    RuntimeError
        An explicitly requested Triton kernel is not available.
    ValueError
        ``ISO_MAP_ENGINE_GPU`` names none of the engines above.
    """
    import warnings  # local import: only used on the "auto" fallback path

    engine = ISO_MAP_ENGINE_GPU

    if engine == "separable_triton":
        fn = _get_separable_triton()
        if fn is None:
            raise RuntimeError("Separable Triton kernel not available")
        return fn(
            density_map, xyz, adp,
            inv_frac_matrix, frac_matrix, A, B, occ,
            radius_angstrom,
        )

    if engine == "fused_triton":
        fn = _get_fused_triton()
        if fn is None:
            raise RuntimeError("Fused Triton kernel not available")
        return fn(
            real_space_grid, density_map, xyz, adp,
            inv_frac_matrix, frac_matrix, A, B, occ,
            radius_angstrom, voxel_size,
        )

    if engine == "original":
        return _add_isotropic_original(
            real_space_grid, density_map, xyz, adp, occ, A, B,
            inv_frac_matrix, frac_matrix, radius_angstrom,
        )

    if engine == "auto":
        # Order is fused → separable → original. Per the original author's
        # A100/1DAW benchmarking, fused was slightly faster fwd+bw; separable
        # is kept as a robustness fallback for grids where fused trips.
        if _GPU_MODE not in ("jit", "simple"):
            candidates = (
                (
                    "Fused",
                    _get_fused_triton,
                    (
                        real_space_grid, density_map, xyz, adp,
                        inv_frac_matrix, frac_matrix, A, B, occ,
                        radius_angstrom, voxel_size,
                    ),
                ),
                (
                    "Separable",
                    _get_separable_triton,
                    (
                        density_map, xyz, adp,
                        inv_frac_matrix, frac_matrix, A, B, occ,
                        radius_angstrom,
                    ),
                ),
            )
            for label, load_kernel, kernel_args in candidates:
                kernel = load_kernel()  # lazy: only loaded when reached
                if kernel is None:
                    continue
                try:
                    return kernel(*kernel_args)
                except Exception as exc:
                    # Still best-effort, but no longer silent: a failing
                    # kernel means every call pays for the slower backend.
                    # NOTE(review): if a kernel can write into density_map
                    # before raising, the fallback would add on top of
                    # partial results — confirm kernels fail before writing.
                    warnings.warn(
                        f"{label} Triton kernel failed "
                        f"({type(exc).__name__}: {exc}); "
                        f"falling back to the next isotropic density backend.",
                        RuntimeWarning,
                        stacklevel=2,
                    )
        return _add_isotropic_original(
            real_space_grid, density_map, xyz, adp, occ, A, B,
            inv_frac_matrix, frac_matrix, radius_angstrom,
        )

    raise ValueError(
        f"Unknown ISO_MAP_ENGINE_GPU={engine!r}. "
        f"Choose from: auto, separable_triton, fused_triton, original"
    )
def _add_isotropic_original(
    real_space_grid, density_map, xyz, adp, occ, A, B,
    inv_frac_matrix, frac_matrix, radius_angstrom,
):
    """Two-step reference path: find_relevant_voxels, then vectorized_add_to_map.

    The helpers are imported inside the function. The voxel search result is
    handed straight to the map-building kernel, whose return value is
    returned unchanged.
    """
    from torchref.base.electron_density.voxel_utils import find_relevant_voxels
    from torchref.base.kernels import vectorized_add_to_map

    neighbourhood = find_relevant_voxels(
        real_space_grid,
        xyz,
        radius_angstrom=radius_angstrom,
        inv_frac_matrix=inv_frac_matrix,
    )
    surrounding_coords, voxel_indices = neighbourhood
    return vectorized_add_to_map(
        surrounding_coords, voxel_indices, density_map,
        xyz, adp, inv_frac_matrix, frac_matrix, A, B, occ,
    )
def _add_isotropic_cpu(
    real_space_grid, density_map, xyz, adp, occ, A, B,
    inv_frac_matrix, frac_matrix, radius_angstrom, voxel_size,
):
    """CPU dispatch for isotropic atoms, controlled by ISO_MAP_ENGINE_CPU.

    ``"separable"``, ``"separable_compiled"`` and ``"fused"`` work from the
    grid shape and voxel size; ``"original"`` takes the coordinate grid
    itself.

    Raises
    ------
    ValueError
        If ``ISO_MAP_ENGINE_CPU`` names none of the four engines.
    """
    engine = ISO_MAP_ENGINE_CPU
    grid_shape_tuple = real_space_grid.shape[:3]

    # Engines that share the (grid_shape, voxel_size, radius) signature.
    grid_shape_engines = (
        ("separable", _add_isotropic_cpu_separable),
        ("separable_compiled", _add_isotropic_cpu_separable_compiled),
        ("fused", _add_isotropic_cpu_fused),
    )
    for engine_name, implementation in grid_shape_engines:
        if engine == engine_name:
            return implementation(
                density_map, xyz, adp, occ, A, B,
                inv_frac_matrix, frac_matrix,
                grid_shape_tuple, voxel_size, radius_angstrom,
            )

    if engine == "original":
        return _add_isotropic_original(
            real_space_grid, density_map, xyz, adp, occ, A, B,
            inv_frac_matrix, frac_matrix, radius_angstrom,
        )

    raise ValueError(
        f"Unknown ISO_MAP_ENGINE_CPU={engine!r}. "
        f"Choose from: separable, separable_compiled, fused, original"
    )
# =========================================================================
# Fused CPU implementation
# =========================================================================
# Single-entry cache of the spherical offset table (recomputed when the key changes)
_cached_radius_offsets = None
_cached_radius_key = None


def _get_radius_offsets(voxel_size, radius_angstrom, device):
    """Return the integer voxel offsets lying within ``radius_angstrom``.

    A cube of half-width ``ceil(radius / min(voxel_size))`` is enumerated and
    filtered by the distance obtained from scaling each offset by
    ``voxel_size`` per axis. The result, shape (R, 3), is cached for the most
    recent ``(voxel_size, radius, device)`` combination.
    """
    global _cached_radius_offsets, _cached_radius_key

    lookup_key = (tuple(voxel_size.tolist()), float(radius_angstrom), device)
    cache_hit = _cached_radius_offsets is not None and _cached_radius_key == lookup_key
    if cache_hit:
        return _cached_radius_offsets

    # Half-width in voxels, sized by the finest axis so the sphere fits.
    half_width = int(math.ceil(radius_angstrom / voxel_size.min().item()))
    steps = torch.arange(-half_width, half_width + 1, device=device)
    lattice = torch.stack(
        torch.meshgrid(steps, steps, steps, indexing="ij"), dim=-1,
    )

    # Distance of each lattice offset after per-axis scaling by voxel_size.
    scaled = lattice.float() * voxel_size
    distance = torch.sqrt((scaled ** 2).sum(-1))
    inside = lattice[distance <= radius_angstrom]  # (R, 3) integer offsets

    _cached_radius_offsets = inside
    _cached_radius_key = lookup_key
    return inside
# =========================================================================
# Separable Gaussian splatting
# =========================================================================
# Single-entry cache of the separable splat box (recomputed when the key changes)
_cached_separable_data = None
_cached_separable_key = None


def _get_box_radius(voxel_size, radius_angstrom, device):
    """Return the per-axis offsets and side length of the splat box (cached).

    The half-width is ``ceil(radius_angstrom / min(voxel_size))`` voxels.

    Returns
    -------
    axis_offsets : torch.Tensor
        Integer offsets [-box_radius, ..., box_radius], shape (n_axis,).
    n_axis : int
        Cube side length (2 * box_radius + 1).
    """
    global _cached_separable_data, _cached_separable_key

    lookup_key = (tuple(voxel_size.tolist()), float(radius_angstrom), device)
    cache_hit = _cached_separable_data is not None and _cached_separable_key == lookup_key
    if cache_hit:
        return _cached_separable_data

    half_width = int(math.ceil(radius_angstrom / voxel_size.min().item()))
    side = 2 * half_width + 1
    axis_steps = torch.arange(-half_width, half_width + 1, device=device)

    _cached_separable_data = (axis_steps, side)
    _cached_separable_key = lookup_key
    return _cached_separable_data
def _separable_density(
    d_frac: torch.Tensor,
    alpha: torch.Tensor,
    A_norm: torch.Tensor,
    G: torch.Tensor,
    has_ab: bool,
    has_ac: bool,
    has_bc: bool,
) -> torch.Tensor:
    """Evaluate sum_g A_g * exp(-alpha_g * r^T G r) on a per-atom cube.

    ``r`` is the fractional displacement whose three components are taken
    independently from the three rows of ``d_frac``, so the quadratic form
    splits into per-axis squares plus pairwise cross terms. Only the cross
    terms flagged by ``has_ab`` / ``has_ac`` / ``has_bc`` are included.

    Cross-term exponents are always added to the matching diagonal exponents
    *before* ``exp()`` is taken, so no factor of the form exp(+large) is
    ever produced on its own (which could overflow in float32 and give
    0 * inf = NaN).

    Code paths:
      - no cross terms: three 1-D exponentials combined by einsum
      - ab only: one 2-D (a, b) exponential times a 1-D c exponential
      - ac only: one 2-D (a, c) exponential times a 1-D b exponential
      - anything else: full 3-D exponent, looped over the components

    Parameters
    ----------
    d_frac : (C, 3, n_axis) — fractional distances per axis, PBC-wrapped.
    alpha : (C, N_comp) — exponent coefficients (pi^2 / B_total at call sites).
    A_norm : (C, N_comp) — weighted amplitudes.
    G : (3, 3) — metric tensor frac_matrix.T @ frac_matrix.
    has_ab : bool — include the G[0,1] cross term.
    has_ac : bool — include the G[0,2] cross term.
    has_bc : bool — include the G[1,2] cross term.

    Returns
    -------
    (C, n_axis, n_axis, n_axis) density cube.
    """
    # --- Scale fractional distances by the cell edge lengths ---
    edge_len = torch.sqrt(torch.diagonal(G))
    scaled = d_frac * edge_len[None, :, None]  # (C, 3, n)
    u = scaled[:, 0, :]
    v = scaled[:, 1, :]
    w = scaled[:, 2, :]

    # --- Per-axis (diagonal) exponents, each <= 0 ---
    neg_alpha = -alpha.unsqueeze(2)  # (C, Nc, 1)
    log_a = neg_alpha * (u ** 2).unsqueeze(1)  # (C, Nc, n)
    log_b = neg_alpha * (v ** 2).unsqueeze(1)
    log_c = neg_alpha * (w ** 2).unsqueeze(1)

    if not has_ab and not has_ac and not has_bc:
        # ---- No cross terms: purely separable ----
        plane_ab = torch.exp(log_a).unsqueeze(3) * torch.exp(log_b).unsqueeze(2)
        return torch.einsum(
            "cg,cgij,cgk->cijk", A_norm, plane_ab, torch.exp(log_c)
        )

    # --- Cosines of the inter-axial angles, from the metric tensor ---
    cos_gamma = G[0, 1] / (edge_len[0] * edge_len[1])
    cos_beta = G[0, 2] / (edge_len[0] * edge_len[2])
    cos_alpha = G[1, 2] / (edge_len[1] * edge_len[2])

    alpha_4d = alpha[:, :, None, None]  # (C, Nc, 1, 1)

    only_ab = has_ab and not (has_ac or has_bc)
    only_ac = has_ac and not (has_ab or has_bc)

    if only_ab:
        # ---- ab cross term only: 2-D (a, b) block times 1-D c ----
        uv = u.unsqueeze(2) * v.unsqueeze(1)
        cross_ab = -2.0 * alpha_4d * cos_gamma * uv[:, None, :, :]
        log_ab = log_a[:, :, :, None] + log_b[:, :, None, :] + cross_ab
        return torch.einsum(
            "cg,cgij,cgk->cijk", A_norm, torch.exp(log_ab), torch.exp(log_c)
        )

    if only_ac:
        # ---- ac cross term only: 2-D (a, c) block times 1-D b ----
        uw = u.unsqueeze(2) * w.unsqueeze(1)
        cross_ac = -2.0 * alpha_4d * cos_beta * uw[:, None, :, :]
        log_ac = log_a[:, :, :, None] + log_c[:, :, None, :] + cross_ac
        block_ac = torch.exp(log_ac)  # (C, Nc, n_a, n_c)
        return torch.einsum(
            "cg,cgj,cgik->cijk", A_norm, torch.exp(log_b), block_ac
        )

    # ---- General path: bc alone, or more than one cross term ----
    # Each entry is (cosine, pairwise product, axis to insert for broadcast),
    # kept in ab, ac, bc order so the summation order is fixed.
    pair_terms = []
    if has_ab:
        pair_terms.append((cos_gamma, u.unsqueeze(2) * v.unsqueeze(1), 3))
    if has_ac:
        pair_terms.append((cos_beta, u.unsqueeze(2) * w.unsqueeze(1), 2))
    if has_bc:
        pair_terms.append((cos_alpha, v.unsqueeze(2) * w.unsqueeze(1), 1))

    n_atoms = d_frac.shape[0]
    n_side = d_frac.shape[2]
    density_cube = d_frac.new_zeros(n_atoms, n_side, n_side, n_side)

    # One component at a time keeps the largest intermediate at (C, n, n, n).
    for g in range(alpha.shape[1]):
        alpha_g = alpha[:, g, None, None]
        exponent = (log_a[:, g, :, None, None]
                    + log_b[:, g, None, :, None]
                    + log_c[:, g, None, None, :])
        for cosine, pair_product, new_axis in pair_terms:
            exponent = exponent + (
                -2.0 * alpha_g * cosine * pair_product
            ).unsqueeze(new_axis)
        density_cube.add_(A_norm[:, g, None, None, None] * torch.exp(exponent))

    return density_cube
_compiled_separable_density = None


def _get_compiled_separable_density():
    """Return the torch.compile'd wrapper of ``_separable_density``.

    The wrapper is created on first use and stored in a module global so
    that every caller shares the same compiled function.
    """
    global _compiled_separable_density
    compiled = _compiled_separable_density
    if compiled is None:
        compiled = torch.compile(_separable_density)
        _compiled_separable_density = compiled
    return compiled
# Fixed atom-chunk sizes used by the compiled separable path, tried largest
# first; atoms left over after the smallest size are evaluated eagerly.
_CHUNK_SIZES = (4096, 2048, 1024, 512)
def _add_isotropic_cpu_separable(
    density_map, xyz, adp, occ, A, B,
    inv_frac_matrix, frac_matrix,
    grid_shape_tuple, voxel_size, radius_angstrom,
):
    """Separable Gaussian splatting for isotropic atoms (eager, chunked).

    Each atom's Gaussian is evaluated on a small cube of voxels around its
    nearest grid point by ``_separable_density``, which factorises the 3-D
    Gaussian along the fractional axes and adds cross-term corrections when
    the metric tensor has significant off-diagonal elements. Atoms are
    sorted by the linear index of their centre voxel and processed 1024 at a
    time, with each chunk accumulated through ``_do_structured_scatter``.

    Returns the updated map with the same shape as ``density_map``.
    """
    device = density_map.device
    grid_shape = torch.tensor(grid_shape_tuple, device=device)
    grid_dims_f = grid_shape.float()

    # --- Per-axis integer offsets of the splat box (cached helper) ---
    offsets, _ = _get_box_radius(voxel_size, radius_angstrom, device)
    # Indices are kept in the project's int dtype (the original notes this is
    # int32, matching the C++ scatter kernel and halving index bandwidth).
    offsets = offsets.to(dtypes.int)
    offsets_row = offsets.unsqueeze(0)  # (1, n_axis)

    # --- Scalar constants ---
    pi = math.pi
    alpha_scale = pi * pi              # pi^2
    amp_scale = pi * math.sqrt(pi)     # pi^1.5

    # --- Cell metric and grid geometry ---
    metric = frac_matrix.T @ frac_matrix
    inv_grid = 1.0 / grid_dims_f
    nx_val = int(grid_shape[0])
    ny_val = int(grid_shape[1])
    nz_val = int(grid_shape[2])
    plane_stride = ny_val * nz_val

    # --- Fractional positions (unwrapped) and nearest grid point ---
    # Wrapping is applied for the index only; distances use frac_pos itself.
    frac_pos = xyz @ inv_frac_matrix.T  # (N, 3)
    nearest = torch.round((frac_pos % 1.0) * grid_dims_f).to(dtypes.int)  # (N, 3)

    # --- Per-component width, amplitude and exponent coefficient ---
    b_eff = ((B + adp[:, None]) * 0.25).clamp(min=0.1)  # (N, 5)
    amplitude = A * occ[:, None] * amp_scale / (b_eff * torch.sqrt(b_eff))  # (N, 5)
    alpha = alpha_scale / b_eff  # (N, 5)

    # --- Which off-diagonal metric terms are large enough to matter ---
    tol = 1e-3 * torch.norm(torch.diagonal(metric))
    has_ab = bool(torch.abs(metric[0, 1]) > tol)
    has_ac = bool(torch.abs(metric[0, 2]) > tol)
    has_bc = bool(torch.abs(metric[1, 2]) > tol)

    # --- Box offsets in fractional units: [dim, i] = offsets[i] / grid[dim] ---
    offsets_frac = offsets.float().unsqueeze(0) * inv_grid.unsqueeze(1)  # (3, n_axis)

    # --- Sort atoms by centre voxel so the scatter touches nearby memory ---
    linear_center = (nearest[:, 0] * plane_stride
                     + nearest[:, 1] * nz_val
                     + nearest[:, 2])
    order = torch.argsort(linear_center)
    frac_pos = frac_pos[order]
    nearest = nearest[order]
    alpha = alpha[order]
    amplitude = amplitude[order]

    # --- Structured flat indices for all atoms, computed once ---
    idx_a = (nearest[:, 0:1] + offsets_row) % nx_val * plane_stride  # (N, n)
    idx_b = (nearest[:, 1:2] + offsets_row) % ny_val * nz_val
    idx_c = (nearest[:, 2:3] + offsets_row) % nz_val
    idx_bc = idx_b.unsqueeze(2) + idx_c.unsqueeze(1)  # (N, n, n) yz-plane index

    # --- Chunked evaluation and accumulation ---
    n_atoms = xyz.shape[0]
    chunk_len = 1024
    map_size = density_map.numel()
    density_flat = density_map.view(-1)

    for lo in range(0, n_atoms, chunk_len):
        hi = min(lo + chunk_len, n_atoms)
        window = slice(lo, hi)

        # Displacement of each atom from its nearest grid point (fractional)
        nearest_frac = nearest[window].float() * inv_grid  # (C, 3)
        residual = frac_pos[window] - nearest_frac  # (C, 3)

        # Per-axis fractional distance to every box plane: (C, 3, n_axis)
        d_frac = offsets_frac.unsqueeze(0) - residual.unsqueeze(2)
        d_frac = d_frac - torch.round(d_frac)  # periodic wrap

        density_cube = _separable_density(
            d_frac,
            alpha[window],
            amplitude[window],
            metric,
            has_ab, has_ac, has_bc,
        )

        # Rebind: the scatter helper may return a new tensor.
        density_flat = _do_structured_scatter(
            density_cube,
            idx_a[window],
            idx_bc[window],
            density_flat,
            map_size,
        )

    return density_flat.view(density_map.shape)
def _add_isotropic_cpu_separable_compiled(
    density_map, xyz, adp, occ, A, B,
    inv_frac_matrix, frac_matrix,
    grid_shape_tuple, voxel_size, radius_angstrom,
):
    """Compiled variant of separable Gaussian splatting.

    The atom list is consumed front to back in chunks whose sizes come from
    ``_CHUNK_SIZES`` (largest first), each evaluated with the torch.compile'd
    ``_separable_density`` so the compiled function only ever sees those
    fixed chunk lengths. Whatever is left after the smallest size is
    evaluated with the eager ``_separable_density``. Atoms are not re-sorted
    on this path.

    Returns the updated map with the same shape as ``density_map``.
    """
    device = density_map.device
    grid_shape = torch.tensor(grid_shape_tuple, device=device)
    grid_dims_f = grid_shape.float()

    # --- Per-axis integer offsets of the splat box (cached helper) ---
    offsets, _ = _get_box_radius(voxel_size, radius_angstrom, device)
    # Kept in the project's int dtype; the scatter helper widens to int64
    # itself when it has to fall back to torch's scatter_add_.
    offsets = offsets.to(dtypes.int)

    # --- Scalar constants ---
    pi = math.pi
    alpha_scale = pi * pi              # pi^2
    amp_scale = pi * math.sqrt(pi)     # pi^1.5

    # --- Cell metric and grid geometry ---
    metric = frac_matrix.T @ frac_matrix
    inv_grid = 1.0 / grid_dims_f
    nx_val = int(grid_shape[0])
    ny_val = int(grid_shape[1])
    nz_val = int(grid_shape[2])
    plane_stride = ny_val * nz_val
    map_size = density_map.numel()

    # --- Fractional positions (unwrapped) and nearest grid point ---
    frac_pos = xyz @ inv_frac_matrix.T  # (N, 3)
    nearest = torch.round((frac_pos % 1.0) * grid_dims_f).to(dtypes.int)  # (N, 3)

    # --- Per-component width, amplitude and exponent coefficient ---
    b_eff = ((B + adp[:, None]) * 0.25).clamp(min=0.1)  # (N, 5)
    amplitude = A * occ[:, None] * amp_scale / (b_eff * torch.sqrt(b_eff))  # (N, 5)
    alpha = alpha_scale / b_eff  # (N, 5)

    # --- Cross-term flags: plain Python bools passed to the density function ---
    tol = 1e-3 * torch.norm(torch.diagonal(metric))
    has_ab = bool(torch.abs(metric[0, 1]) > tol)
    has_ac = bool(torch.abs(metric[0, 2]) > tol)
    has_bc = bool(torch.abs(metric[1, 2]) > tol)

    # --- Box offsets in fractional units, shared by every chunk ---
    offsets_frac = offsets.float().unsqueeze(0) * inv_grid.unsqueeze(1)  # (3, n_axis)

    n_atoms = xyz.shape[0]
    compiled_fn = _get_compiled_separable_density()

    def splat(lo, hi, flat, density_fn):
        # Thin wrapper so the long shared argument list is written once.
        return _splat_chunk(
            lo, hi, nearest, frac_pos, offsets_frac, inv_grid,
            alpha, amplitude, metric, has_ab, has_ac, has_bc,
            offsets, nx_val, ny_val, nz_val, plane_stride, map_size,
            flat, density_fn,
        )

    # --- Fixed-size chunks through the compiled function ---
    density_flat = density_map.view(-1)
    cursor = 0
    for width in _CHUNK_SIZES:
        while n_atoms - cursor >= width:
            density_flat = splat(cursor, cursor + width, density_flat, compiled_fn)
            cursor += width

    # --- Tail shorter than the smallest chunk: eager, no new compiled shape ---
    if cursor < n_atoms:
        density_flat = splat(cursor, n_atoms, density_flat, _separable_density)

    return density_flat.view(density_map.shape)
def _splat_chunk(
    start, end, center_idx, xyz_frac, axis_offsets_frac, inv_grid,
    alpha, A_norm, G, has_ab, has_ac, has_bc,
    axis_offsets, nx_val, ny_val, nz_val, ny_nz, map_size,
    density_flat, density_fn,
):
    """Evaluate atoms ``start:end`` with ``density_fn`` and scatter them.

    Returns the flat density tensor produced by ``_do_structured_scatter``.
    That tensor may be a new object rather than ``density_flat`` itself, so
    callers have to rebind it between chunks.
    """
    window = slice(start, end)
    centers = center_idx[window]
    offsets_row = axis_offsets.unsqueeze(0)  # (1, n)

    # Displacement of each atom from its nearest grid point (fractional)
    residual = xyz_frac[window] - centers.float() * inv_grid  # (C, 3)

    # Per-axis fractional distance to every box plane: (C, 3, n_axis)
    d_frac = axis_offsets_frac.unsqueeze(0) - residual.unsqueeze(2)
    d_frac = d_frac - torch.round(d_frac)  # periodic wrap

    # Density cube: (C, n_axis, n_axis, n_axis)
    density_cube = density_fn(
        d_frac, alpha[window], A_norm[window],
        G, has_ab, has_ac, has_bc,
    )

    # Structured flat indices: x part, and combined y/z part
    idx_a = (centers[:, 0:1] + offsets_row) % nx_val * ny_nz  # (C, n)
    idx_b = (centers[:, 1:2] + offsets_row) % ny_val * nz_val
    idx_c = (centers[:, 2:3] + offsets_row) % nz_val
    idx_bc = idx_b.unsqueeze(2) + idx_c.unsqueeze(1)  # (C, n, n)

    return _do_structured_scatter(
        density_cube, idx_a, idx_bc, density_flat, map_size,
    )
def _add_isotropic_cpu_fused(
    density_map, xyz, adp, occ, A, B,
    inv_frac_matrix, frac_matrix,
    grid_shape_tuple, voxel_size, radius_angstrom,
):
    """Fused path for isotropic atoms, working in fractional space.

    For each atom the voxels inside the cached spherical offset table are
    visited. Voxel fractional coordinates are derived directly from the
    wrapped integer indices, the squared distance is evaluated with the
    metric tensor, and the summed Gaussians are accumulated into
    ``density_map`` in place with ``scatter_add_``. Atoms are processed 1024
    at a time to bound the size of the intermediates.

    Returns ``density_map`` (mutated in place).
    """
    device = density_map.device
    grid_shape = torch.tensor(grid_shape_tuple, device=device)
    grid_dims_f = grid_shape.float()

    # --- Spherical offset table (cached helper) ---
    sphere_offsets = _get_radius_offsets(voxel_size, radius_angstrom, device)

    # --- Scalar constants ---
    pi = math.pi
    alpha_scale = pi * pi              # pi^2
    amp_scale = pi * math.sqrt(pi)     # pi^1.5

    # --- Cell metric and row-major strides of the flat map ---
    metric = frac_matrix.T @ frac_matrix
    inv_grid = 1.0 / grid_dims_f
    plane_stride = int(grid_shape[1]) * int(grid_shape[2])
    row_stride = int(grid_shape[2])
    strides = torch.tensor(
        [plane_stride, row_stride, 1], device=device, dtype=torch.long,
    )

    # --- Fractional positions (unwrapped) and nearest grid point ---
    # Wrapping is applied for the index only; distances use frac_pos itself.
    frac_pos = xyz @ inv_frac_matrix.T  # (N, 3)
    nearest = torch.round((frac_pos % 1.0) * grid_dims_f).long()  # (N, 3)

    # --- Per-component width and amplitude ---
    b_eff = ((B + adp[:, None]) * 0.25).clamp(min=0.1)  # (N, 5)
    amplitude = A * occ[:, None] * amp_scale / (b_eff * torch.sqrt(b_eff))  # (N, 5)

    n_atoms = xyz.shape[0]
    chunk_len = 1024
    for lo in range(0, n_atoms, chunk_len):
        hi = min(lo + chunk_len, n_atoms)
        window = slice(lo, hi)

        # Wrapped voxel indices visited by each atom: (C, R, 3)
        voxel_idx = (
            nearest[window].unsqueeze(1) + sphere_offsets.unsqueeze(0)
        ) % grid_shape

        # Voxel-minus-atom displacement in fractional units, minimum image
        delta = voxel_idx.float() * inv_grid - frac_pos[window].unsqueeze(1)
        delta = delta - torch.round(delta)

        # Squared Cartesian distance through the metric tensor
        r_sq = torch.einsum("avi,ij,avj->av", delta, metric, delta)

        # Sum of Gaussians over the components
        exponents = -alpha_scale * r_sq.unsqueeze(2) / b_eff[window].unsqueeze(1)
        density = torch.einsum(
            "ag,avg->av", amplitude[window], torch.exp(exponents)
        )

        # Accumulate into the flat view of the map
        flat_index = (voxel_idx.to(torch.long) * strides).sum(-1).view(-1)
        density_map.view(-1).scatter_add_(0, flat_index, density.reshape(-1))

    return density_map
def _add_anisotropic(
    real_space_grid, density_map, xyz, u, occ, A, B,
    inv_frac_matrix, frac_matrix, radius_angstrom,
):
    """Add anisotropic atoms via the two-step path (no Triton kernel yet).

    The helpers are imported inside the function. ``find_relevant_voxels``
    supplies the surrounding coordinates and voxel indices, which are passed
    to ``vectorized_add_to_map_aniso``; its return value is returned
    unchanged.
    """
    from torchref.base.electron_density.voxel_utils import find_relevant_voxels
    from torchref.base.electron_density.map_building import vectorized_add_to_map_aniso

    neighbourhood = find_relevant_voxels(
        real_space_grid,
        xyz,
        radius_angstrom=radius_angstrom,
        inv_frac_matrix=inv_frac_matrix,
    )
    surrounding_coords, voxel_indices = neighbourhood
    return vectorized_add_to_map_aniso(
        surrounding_coords, voxel_indices, density_map,
        xyz, u, inv_frac_matrix, frac_matrix, A, B, occ,
    )