torchref.base.kernels.jit_kernel_vectorized_add_to_map module
Optimized vectorized_add_to_map with automatic CPU/GPU path selection.
This module provides optimized implementations for adding atoms to a density map using ITC92 Gaussian parameterization. It automatically selects the best implementation based on the device (CPU or GPU) and uses JIT-compiled kernels for optimal performance.
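Architecture:
- CPU: JIT-scripted kernel using einsum with the metric tensor (efficient on CPU).
- GPU: JIT-scripted kernel using batched matmul (efficient on GPU). When Triton is available, vectorized_add_to_map uses a fused Triton kernel on GPU by default instead.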
Both implementations are fully differentiable and compile on import for minimal first-call overhead.
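The two r² strategies above can be compared with a small standalone sketch. It is not torchref's kernel code: r2_einsum and r2_bmm are hypothetical names, and the sketch assumes frac_matrix maps fractional differences to Cartesian ones (cart = frac_matrix @ frac), the convention implied by G = frac_matrix.T @ frac_matrix in compute_metric_tensor. Both paths should yield the same squared distances; they differ only in how the work is laid out.

```python
import torch

torch.manual_seed(0)

# Hypothetical fractional->Cartesian matrix (assumed: cart = frac_matrix @ frac).
frac_matrix = torch.tensor([[10.0, 2.0, 1.0],
                            [0.0, 12.0, 3.0],
                            [0.0, 0.0, 15.0]], dtype=torch.float64)
n_atoms, n_vox = 4, 27
# Fractional differences between each atom and its surrounding voxels.
diff_frac = torch.randn(n_atoms, n_vox, 3, dtype=torch.float64) * 0.1


def r2_einsum(diff_frac: torch.Tensor, frac_matrix: torch.Tensor) -> torch.Tensor:
    """CPU-style path: contract fractional differences with the metric tensor."""
    G = frac_matrix.T @ frac_matrix  # metric tensor, shape (3, 3)
    return torch.einsum("avi,ij,avj->av", diff_frac, G, diff_frac)


def r2_bmm(diff_frac: torch.Tensor, frac_matrix: torch.Tensor) -> torch.Tensor:
    """GPU-style path: map differences to Cartesian with one batched matmul."""
    n = diff_frac.shape[0]
    m_t = frac_matrix.T.expand(n, 3, 3).contiguous()  # (N_atoms, 3, 3)
    diff_cart = torch.bmm(diff_frac, m_t)             # row vectors: d_cart = d_frac @ M^T
    return (diff_cart ** 2).sum(dim=-1)               # (N_atoms, N_voxels)


r2_a = r2_einsum(diff_frac, frac_matrix)
r2_b = r2_bmm(diff_frac, frac_matrix)
```

The einsum path never materializes Cartesian differences, while the batched-matmul path performs one (N_voxels, 3) x (3, 3) product per atom, a shape that presumably suits GPU batched GEMM.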
- Usage:
```python
from torchref.base.kernels import vectorized_add_to_map

# Automatically selects the CPU or GPU implementation based on the tensor device
density_map = vectorized_add_to_map(
    surrounding_coords, voxel_indices, density_map, xyz, b,
    inv_frac_matrix, frac_matrix, A, B, occ,
)
```
- torchref.base.kernels.jit_kernel_vectorized_add_to_map.vectorized_add_to_map(surrounding_coords, voxel_indices, density_map, xyz, b, inv_frac_matrix, frac_matrix, A, B, occ)
Add atoms to density map using ITC92 Gaussian parameterization.
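Automatically selects the optimal implementation based on the device. On CPU, the JIT-scripted einsum kernel is used. On GPU, the default is a fused Triton kernel (3-6x faster), which falls back to the JIT kernel if Triton is unavailable. To override the GPU path, set the environment variable TORCHREF_ATOM_PLACEMENT_GPU_MODE to jit or simple.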
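The device-based selection described above can be sketched as a small dispatch function. This is a hypothetical illustration, not torchref's code: select_path and its returned labels are invented names, and treating an unset or unrecognized TORCHREF_ATOM_PLACEMENT_GPU_MODE as "use the Triton default" is an assumption.

```python
import importlib.util
import os
from typing import Mapping, Optional


def select_path(device_type: str, env: Optional[Mapping[str, str]] = None) -> str:
    """Hypothetical dispatch: return a kernel-path label for a device type."""
    env = os.environ if env is None else env
    if device_type != "cuda":
        return "jit"  # CPU: JIT-scripted einsum kernel
    mode = env.get("TORCHREF_ATOM_PLACEMENT_GPU_MODE", "").strip().lower()
    if mode in ("jit", "simple"):
        return mode  # explicit override via the environment variable
    # Assumed default: fused Triton kernel, falling back to JIT without Triton.
    if importlib.util.find_spec("triton") is not None:
        return "triton"
    return "jit"
```

Passing env explicitly (instead of reading os.environ) keeps the sketch easy to test without mutating the process environment.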
- Parameters:
surrounding_coords (torch.Tensor) – Cartesian coordinates of voxels, shape (N_atoms, N_voxels, 3).
voxel_indices (torch.Tensor) – Indices of voxels in the map, shape (N_atoms, N_voxels, 3).
density_map (torch.Tensor) – Electron density map to update, shape (nx, ny, nz).
xyz (torch.Tensor) – Atom positions in Cartesian coordinates, shape (N_atoms, 3).
b (torch.Tensor) – Isotropic B-factors, shape (N_atoms,).
inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix, shape (3, 3).
frac_matrix (torch.Tensor) – Fractionalization matrix, shape (3, 3).
A (torch.Tensor) – ITC92 amplitude coefficients, shape (N_atoms, 5).
B (torch.Tensor) – ITC92 width coefficients, shape (N_atoms, 5).
occ (torch.Tensor) – Atomic occupancies, shape (N_atoms,).
- Returns:
Updated electron density map (modified in-place).
- Return type:
torch.Tensor
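A naive reference for what this function computes may help when reading the parameter list. The sketch assumes each column pair (A[:, k], B[:, k]) is a Gaussian term a_k exp(-b_k s²) with s = sin θ / λ, all five columns treated as Gaussians, and the isotropic B-factor added to each width. The real-space form of such a term is a_k (4π/(b_k + B))^{3/2} exp(-4π² r² / (b_k + B)). It uses plain Cartesian distances with no periodic wrapping, unlike the fractional/metric-tensor kernels described above; add_atoms_reference is a hypothetical name, not a torchref function.

```python
import math
import torch


def add_atoms_reference(surrounding_coords, voxel_indices, density_map,
                        xyz, b, A, B, occ):
    """Hypothetical naive reference: Cartesian distances, no periodic wrapping."""
    # Squared distance from each atom to each of its voxels: (N_atoms, N_voxels)
    r2 = ((surrounding_coords - xyz[:, None, :]) ** 2).sum(dim=-1)
    # Total width per Gaussian term: ITC92 width plus isotropic B, (N_atoms, 5)
    bt = B + b[:, None]
    # Real-space Gaussian: a_k (4*pi/bt)^{3/2} exp(-4*pi^2 r^2 / bt)
    norm = A * (4.0 * math.pi / bt) ** 1.5                                 # (N_atoms, 5)
    expo = torch.exp(-4.0 * math.pi ** 2 * r2[..., None] / bt[:, None, :])  # (N_atoms, N_voxels, 5)
    rho = occ[:, None] * (norm[:, None, :] * expo).sum(dim=-1)             # (N_atoms, N_voxels)
    # Scatter-add into the map in place; accumulate=True sums overlapping atoms
    idx = voxel_indices.reshape(-1, 3).long()
    density_map.index_put_((idx[:, 0], idx[:, 1], idx[:, 2]), rho.reshape(-1),
                           accumulate=True)
    return density_map


# Tiny example: one atom at the origin, two voxels 1 Å apart along x.
dtype = torch.float64
density = torch.zeros(4, 4, 4, dtype=dtype)
out = add_atoms_reference(
    surrounding_coords=torch.tensor([[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]], dtype=dtype),
    voxel_indices=torch.tensor([[[0, 0, 0], [1, 0, 0]]]),
    density_map=density,
    xyz=torch.zeros(1, 3, dtype=dtype),
    b=torch.tensor([20.0], dtype=dtype),
    A=torch.tensor([[1.0, 0.0, 0.0, 0.0, 0.0]], dtype=dtype),
    B=torch.full((1, 5), 10.0, dtype=dtype),
    occ=torch.tensor([1.0], dtype=dtype),
)
```

Because index_put_ writes into density_map in place, the returned tensor is the same object that was passed in, matching the "modified in-place" note above.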
- torchref.base.kernels.jit_kernel_vectorized_add_to_map.build_electron_density(surrounding_coords, voxel_indices, density_map, xyz, b, inv_frac_matrix, frac_matrix, A, B, occ)
Build electron density map from atomic parameters.
This is an alias for vectorized_add_to_map for semantic clarity.
- Parameters:
surrounding_coords (torch.Tensor) – Cartesian coordinates of voxels, shape (N_atoms, N_voxels, 3).
voxel_indices (torch.Tensor) – Indices of voxels in the map, shape (N_atoms, N_voxels, 3).
density_map (torch.Tensor) – Electron density map to update, shape (nx, ny, nz).
xyz (torch.Tensor) – Atom positions in Cartesian coordinates, shape (N_atoms, 3).
b (torch.Tensor) – Isotropic B-factors, shape (N_atoms,).
inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix, shape (3, 3).
frac_matrix (torch.Tensor) – Fractionalization matrix, shape (3, 3).
A (torch.Tensor) – ITC92 amplitude coefficients, shape (N_atoms, 5).
B (torch.Tensor) – ITC92 width coefficients, shape (N_atoms, 5).
occ (torch.Tensor) – Atomic occupancies, shape (N_atoms,).
- Returns:
Updated electron density map.
- Return type:
torch.Tensor
- torchref.base.kernels.jit_kernel_vectorized_add_to_map.compute_metric_tensor(frac_matrix)
Compute the metric tensor used to calculate r² from fractional coordinates.
The metric tensor G gives squared Cartesian distances directly from fractional coordinate differences:
r² = diff_frac @ G @ diff_frac.T
where diff_frac is a fractional coordinate difference written as a row vector.
- Parameters:
frac_matrix (torch.Tensor) – Fractionalization matrix, shape (3, 3).
- Returns:
Metric tensor G = frac_matrix.T @ frac_matrix, shape (3, 3).
- Return type:
torch.Tensor
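The identity above can be checked numerically. This sketch assumes the columns of frac_matrix are the Cartesian cell vectors (so cart = frac_matrix @ frac), the convention under which G = frac_matrix.T @ frac_matrix holds; compute_metric_tensor_sketch and the monoclinic cell values are hypothetical. Under that convention the diagonal of G holds a², b², c² and the off-diagonal entries hold dot products such as a·c = ac cos β.

```python
import math
import torch


def compute_metric_tensor_sketch(frac_matrix: torch.Tensor) -> torch.Tensor:
    """Hypothetical re-implementation of G = frac_matrix.T @ frac_matrix."""
    return frac_matrix.T @ frac_matrix


# Hypothetical monoclinic cell: a=10, b=12, c=15 Å, beta=100°.
# Assumed convention: columns are the Cartesian cell vectors a, b, c.
a_len, b_len, c_len, beta = 10.0, 12.0, 15.0, math.radians(100.0)
frac_matrix = torch.tensor([
    [a_len, 0.0, c_len * math.cos(beta)],
    [0.0, b_len, 0.0],
    [0.0, 0.0, c_len * math.sin(beta)],
], dtype=torch.float64)

G = compute_metric_tensor_sketch(frac_matrix)
diff_frac = torch.tensor([0.1, -0.05, 0.2], dtype=torch.float64)
r2_metric = diff_frac @ G @ diff_frac            # 1-D vector, so no transpose needed
r2_cart = ((frac_matrix @ diff_frac) ** 2).sum()  # same distance via Cartesian coords
```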
- torchref.base.kernels.jit_kernel_vectorized_add_to_map.precompute_fractional_coords(coords_cart, inv_frac_matrix)
Convert Cartesian voxel coordinates to fractional coordinates.
- Parameters:
coords_cart (torch.Tensor) – Cartesian coordinates, shape (N_atoms, N_voxels, 3).
inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix, shape (3, 3).
- Returns:
Fractional coordinates, shape (N_atoms, N_voxels, 3).
- Return type:
torch.Tensor
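The conversion can be expressed as a single broadcasted matmul over the last axis. This is a minimal sketch, assuming the column-vector convention frac = inv_frac_matrix @ cart; precompute_fractional_coords_sketch and the example matrix are hypothetical, and the real function may differ in details such as wrapping into the unit cell.

```python
import torch


def precompute_fractional_coords_sketch(coords_cart: torch.Tensor,
                                        inv_frac_matrix: torch.Tensor) -> torch.Tensor:
    """Hypothetical re-implementation: Cartesian -> fractional, shape preserved."""
    # frac = inv_frac_matrix @ cart, written for row vectors on the last axis;
    # matmul broadcasts over the leading (N_atoms, N_voxels) dimensions.
    return coords_cart @ inv_frac_matrix.T


# Round trip with an assumed fractional->Cartesian matrix (cart = frac_matrix @ frac).
frac_matrix = torch.tensor([[10.0, 0.0, -2.6],
                            [0.0, 12.0, 0.0],
                            [0.0, 0.0, 14.8]], dtype=torch.float64)
inv_frac_matrix = torch.linalg.inv(frac_matrix)
coords_cart = torch.randn(2, 5, 3, dtype=torch.float64)
coords_frac = precompute_fractional_coords_sketch(coords_cart, inv_frac_matrix)
coords_back = coords_frac @ frac_matrix.T  # map back to Cartesian
```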
- torchref.base.kernels.jit_kernel_vectorized_add_to_map.warmup(device='auto')
Pre-compile kernels to avoid compilation overhead during first use.
- Parameters:
device (str) – Device to warm up: "cpu", "cuda", or "auto" (default).