torchref.base.kernels package

Optimized kernels for electron density computation.

This submodule provides optimized implementations for compute-intensive operations in crystallographic calculations:

  • JIT-compiled PyTorch kernels for CPU and GPU

  • Triton CUDA kernels for fused operations

  • Optimized fused operations with reduced kernel launches

torchref.base.kernels.vectorized_add_to_map(surrounding_coords, voxel_indices, density_map, xyz, b, inv_frac_matrix, frac_matrix, A, B, occ)

Add atoms to density map using ITC92 Gaussian parameterization.

Automatically selects the optimal implementation based on the device. On GPU, the default is the Triton fused kernel (3-6x faster); if Triton is unavailable, the JIT kernel is used instead. To override the choice, set the environment variable TORCHREF_ATOM_PLACEMENT_GPU_MODE to jit or simple.
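The selection rule above can be pictured as a small dispatch function. This is a hypothetical sketch, not torchref's code: the function names choose_gpu_mode and triton_installed, and the choice to ignore unrecognized override values, are assumptions. Only the variable name, the jit and simple overrides, and the Triton-then-JIT fallback come from the description.

```python
import importlib.util
import os
from typing import Mapping, Optional

# Name of the override variable, as given in the documentation.
GPU_MODE_ENV = "TORCHREF_ATOM_PLACEMENT_GPU_MODE"
# Documented override values; anything else is ignored here (an assumption).
VALID_OVERRIDES = {"jit", "simple"}


def choose_gpu_mode(triton_available: bool,
                    env: Optional[Mapping[str, str]] = None) -> str:
    """Hypothetical dispatch rule for the GPU atom-placement backend."""
    env = os.environ if env is None else env
    override = env.get(GPU_MODE_ENV, "")
    if override in VALID_OVERRIDES:
        return override      # an explicit override wins
    if triton_available:
        return "triton"      # documented GPU default
    return "jit"             # documented fallback when Triton is missing


def triton_installed() -> bool:
    """True if the triton package can be found in this environment."""
    return importlib.util.find_spec("triton") is not None


if __name__ == "__main__":
    print(choose_gpu_mode(triton_installed()))
```

In practice the override would be set from the shell before the process starts, for example TORCHREF_ATOM_PLACEMENT_GPU_MODE=simple python your_script.py (a placeholder script name); whether torchref reads the variable once at import time or on every call is not stated here.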

Parameters:
  • surrounding_coords (torch.Tensor) – Cartesian coordinates of voxels, shape (N_atoms, N_voxels, 3).

  • voxel_indices (torch.Tensor) – Indices of voxels in the map, shape (N_atoms, N_voxels, 3).

  • density_map (torch.Tensor) – Electron density map to update, shape (nx, ny, nz).

  • xyz (torch.Tensor) – Atom positions in Cartesian coordinates, shape (N_atoms, 3).

  • b (torch.Tensor) – Isotropic B-factors, shape (N_atoms,).

  • inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix; maps Cartesian to fractional coordinates, shape (3, 3).

  • frac_matrix (torch.Tensor) – Fractionalization matrix; maps fractional to Cartesian coordinates, shape (3, 3).

  • A (torch.Tensor) – ITC92 amplitude coefficients, shape (N_atoms, 5).

  • B (torch.Tensor) – ITC92 width coefficients, shape (N_atoms, 5).

  • occ (torch.Tensor) – Atomic occupancies, shape (N_atoms,).

Returns:

Updated electron density map (modified in-place).

Return type:

torch.Tensor
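The docstring names the ITC92 Gaussian parameterization but not the real-space formula. Under the usual convention, each of the five (a_i, b_i) pairs is combined with the atom's isotropic B, and Fourier transforming a_i·exp(-(b_i + B)s²/4), with s = 2 sin θ / λ, gives a 3D Gaussian. The plain-Python sketch below evaluates that for one atom; treat it as an illustration of this assumed convention, not as the kernel torchref actually runs (for example, the handling of any constant term is not described here). The function names are illustrative.

```python
import math
from typing import Sequence


def itc92_density(r2: float, a: Sequence[float], b: Sequence[float],
                  b_iso: float, occ: float = 1.0) -> float:
    """Electron density of one atom at squared distance r2 (in Å^2).

    Each term a_i * exp(-(b_i + B) * s^2 / 4) in reciprocal space
    (s = 2 sin(theta) / lambda) transforms to
    a_i * (4*pi / beta)^1.5 * exp(-4*pi^2 * r2 / beta) with beta = b_i + B.
    """
    rho = 0.0
    for a_i, b_i in zip(a, b):
        beta = b_i + b_iso
        rho += a_i * (4.0 * math.pi / beta) ** 1.5 * math.exp(-4.0 * math.pi ** 2 * r2 / beta)
    return occ * rho


def radial_integral(a, b, b_iso, occ, r_max=10.0, dr=1e-3):
    """Midpoint-rule integral of 4*pi*r^2*rho(r); should approach occ * sum(a)."""
    total = 0.0
    for i in range(int(r_max / dr)):
        r = (i + 0.5) * dr
        total += 4.0 * math.pi * r * r * itc92_density(r * r, a, b, b_iso, occ) * dr
    return total


if __name__ == "__main__":
    # Illustrative coefficients only, not taken from the ITC tables.
    a = [2.0, 1.5, 1.0, 0.5, 0.2]
    b = [10.0, 5.0, 20.0, 1.0, 40.0]
    print(radial_integral(a, b, b_iso=20.0, occ=0.5))  # close to 0.5 * 5.2 = 2.6
```

Each Gaussian integrates to exactly a_i, so the total electron count placed in the map is occ·Σa_i regardless of B; a larger B only spreads the same density over more voxels.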

torchref.base.kernels.build_electron_density(surrounding_coords, voxel_indices, density_map, xyz, b, inv_frac_matrix, frac_matrix, A, B, occ)

Build electron density map from atomic parameters.

This is an alias for vectorized_add_to_map for semantic clarity.

Parameters:
  • surrounding_coords (torch.Tensor) – Cartesian coordinates of voxels, shape (N_atoms, N_voxels, 3).

  • voxel_indices (torch.Tensor) – Indices of voxels in the map, shape (N_atoms, N_voxels, 3).

  • density_map (torch.Tensor) – Electron density map to update, shape (nx, ny, nz).

  • xyz (torch.Tensor) – Atom positions in Cartesian coordinates, shape (N_atoms, 3).

  • b (torch.Tensor) – Isotropic B-factors, shape (N_atoms,).

  • inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix; maps Cartesian to fractional coordinates, shape (3, 3).

  • frac_matrix (torch.Tensor) – Fractionalization matrix; maps fractional to Cartesian coordinates, shape (3, 3).

  • A (torch.Tensor) – ITC92 amplitude coefficients, shape (N_atoms, 5).

  • B (torch.Tensor) – ITC92 width coefficients, shape (N_atoms, 5).

  • occ (torch.Tensor) – Atomic occupancies, shape (N_atoms,).

Returns:

Updated electron density map (modified in-place, as for vectorized_add_to_map).

Return type:

torch.Tensor

torchref.base.kernels.compute_metric_tensor(frac_matrix)

Compute the metric tensor for calculating r² in fractional coordinates.

The metric tensor G allows squared Cartesian distances to be computed directly from fractional coordinate differences. For a single difference diff_frac written as a row vector of shape (1, 3):

r² = diff_frac @ G @ diff_frac.T
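A small numerical check of this identity, in plain Python rather than PyTorch. It assumes frac_matrix maps fractional to Cartesian coordinates (as stated for separable_density_gpu) and builds such a matrix with the common crystallographic convention (a along x, b in the xy-plane); the helper names are illustrative, and the identity itself holds for any fractional-to-Cartesian matrix M.

```python
import math


def orthogonalization_matrix(a, b, c, alpha, beta, gamma):
    """Fractional -> Cartesian matrix M (angles in degrees), a along x, b in the xy-plane."""
    al, be, ga = (math.radians(x) for x in (alpha, beta, gamma))
    ca, cb, cg, sg = math.cos(al), math.cos(be), math.cos(ga), math.sin(ga)
    volume = a * b * c * math.sqrt(1 - ca**2 - cb**2 - cg**2 + 2 * ca * cb * cg)
    return [
        [a, b * cg, c * cb],
        [0.0, b * sg, c * (ca - cb * cg) / sg],
        [0.0, 0.0, volume / (a * b * sg)],
    ]


def metric_tensor(m):
    """G = M^T M, so that r^2 = d^T G d for a fractional difference d."""
    return [[sum(m[k][i] * m[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def r2_from_fractional(d, g):
    """Squared distance using only the metric tensor."""
    return sum(d[i] * g[i][j] * d[j] for i in range(3) for j in range(3))


def r2_from_cartesian(d, m):
    """Squared distance after explicitly converting d to Cartesian coordinates."""
    cart = [sum(m[i][k] * d[k] for k in range(3)) for i in range(3)]
    return sum(x * x for x in cart)


if __name__ == "__main__":
    M = orthogonalization_matrix(50.0, 60.0, 70.0, 90.0, 105.0, 90.0)  # monoclinic cell
    d = [0.1, -0.05, 0.2]
    print(r2_from_fractional(d, metric_tensor(M)), r2_from_cartesian(d, M))
```

The entries of G are the familiar cell quantities (a², b², c² on the diagonal and ab cos γ, ac cos β, bc cos α off it), which is why a kernel can skip the per-voxel Cartesian conversion.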

Parameters:

frac_matrix (torch.Tensor) – Fractionalization matrix; maps fractional to Cartesian coordinates, shape (3, 3).

Returns:

Metric tensor G = frac_matrix.T @ frac_matrix, shape (3, 3).

Return type:

torch.Tensor

torchref.base.kernels.precompute_fractional_coords(coords_cart, inv_frac_matrix)

Convert Cartesian voxel coordinates to fractional coordinates.

Parameters:
  • coords_cart (torch.Tensor) – Cartesian coordinates, shape (N_atoms, N_voxels, 3).

  • inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix; maps Cartesian to fractional coordinates, shape (3, 3).

Returns:

Fractional coordinates, shape (N_atoms, N_voxels, 3).

Return type:

torch.Tensor
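Assuming inv_frac_matrix is applied to each coordinate as a column vector (in PyTorch, for a tensor of shape (N_atoms, N_voxels, 3), this would typically be written coords_cart @ inv_frac_matrix.T), the conversion amounts to the plain-Python sketch below. The helper names are illustrative, and the cofactor inverse is just one way to obtain inv_frac_matrix from frac_matrix.

```python
from typing import List

Matrix = List[List[float]]


def inverse_3x3(m: Matrix) -> Matrix:
    """Inverse of a 3x3 matrix via the adjugate (transposed cofactor) formula."""
    (a, b, c), (d, e, f), (g, h, i) = m
    det = a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g)
    if det == 0:
        raise ValueError("singular matrix")
    return [
        [(e * i - f * h) / det, (c * h - b * i) / det, (b * f - c * e) / det],
        [(f * g - d * i) / det, (a * i - c * g) / det, (c * d - a * f) / det],
        [(d * h - e * g) / det, (b * g - a * h) / det, (a * e - b * d) / det],
    ]


def to_fractional(coords_cart, inv_frac):
    """Apply inv_frac to every (x, y, z) in an (N_atoms, N_voxels, 3) nested list.

    Row-vector form: frac = cart @ inv_frac.T, i.e. frac_i = sum_k inv_frac[i][k] * cart_k.
    """
    return [
        [[sum(inv_frac[i][k] * p[k] for k in range(3)) for i in range(3)] for p in atom]
        for atom in coords_cart
    ]
```

Doing this once per batch, as the function name suggests, lets later steps work purely in fractional space, where periodic wrapping is a simple subtraction of the nearest integer.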

torchref.base.kernels.warmup(device='auto')

Pre-compile kernels to avoid compilation overhead during first use.

Parameters:

device (str) – Device to warm up: “cpu”, “cuda”, or “auto” (default).

torchref.base.kernels.get_cache_dir()

Return the path to the JIT kernel cache directory.

torchref.base.kernels.clear_cache()

Clear the JIT kernel cache.

torchref.base.kernels.warmup_cuda_operations(device='cuda')

Warm up CUDA kernels to avoid lazy loading overhead.

This function runs dummy operations to trigger CUDA kernel compilation and loading, so subsequent operations don’t incur this overhead.

Call this once after moving the model to the GPU.

Parameters:

device (str) – Device to warm up. Default is “cuda”.

class torchref.base.kernels.CachedRadiusMask

Bases: object

Cache the radius-mask computation so it is not repeated for every atom batch.

This eliminates redundant computation when processing multiple atoms with the same voxel size and radius.

Usage

>>> import torch
>>> from torchref.base.kernels import CachedRadiusMask
>>> cache = CachedRadiusMask()
>>> offsets = cache.get_offsets(torch.tensor([0.5, 0.5, 0.5]), 2.0, torch.device("cpu"))
Attributes:

  • _cache (dict) – Internal cache storing computed offsets.

__init__()

get_offsets(voxel_size, radius_angstrom, device)

Get cached offset grid for given parameters.

Parameters:
  • voxel_size (torch.Tensor) – Voxel dimensions, shape (3,).

  • radius_angstrom (float) – Radius in Angstroms.

  • device (torch.device) – Device for the output tensor.

Returns:

Voxel offsets within radius, shape (N_voxels, 3).

Return type:

torch.Tensor
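The caching idea can be sketched in plain Python: compute the offset list once per (voxel size, radius, device) key and return the stored result afterwards. The class name RadiusOffsetCache is hypothetical, and the selection rule (keep integer offsets whose scaled length is within the radius, treating the grid axes as orthogonal) is an assumption; the real get_offsets returns a torch.Tensor of shape (N_voxels, 3).

```python
import math
from typing import Dict, List, Tuple

Offset = Tuple[int, int, int]


class RadiusOffsetCache:
    """Hypothetical stand-in for CachedRadiusMask, using plain Python lists.

    Assumes an orthogonal grid: offset (i, j, k) is kept when
    (i*vx)^2 + (j*vy)^2 + (k*vz)^2 <= radius^2.
    """

    def __init__(self) -> None:
        self._cache: Dict[tuple, List[Offset]] = {}

    def get_offsets(self, voxel_size, radius_angstrom, device="cpu") -> List[Offset]:
        # Convert to plain floats so the key is hashable by value.
        key = (tuple(float(v) for v in voxel_size), float(radius_angstrom), str(device))
        if key not in self._cache:
            self._cache[key] = self._compute(key[0], key[1])
        return self._cache[key]

    @staticmethod
    def _compute(voxel_size, radius) -> List[Offset]:
        vx, vy, vz = voxel_size
        # Search the smallest box of whole voxels that contains the sphere.
        nx, ny, nz = (math.ceil(radius / v) for v in (vx, vy, vz))
        r2 = radius * radius
        return [
            (i, j, k)
            for i in range(-nx, nx + 1)
            for j in range(-ny, ny + 1)
            for k in range(-nz, nz + 1)
            if (i * vx) ** 2 + (j * vy) ** 2 + (k * vz) ** 2 <= r2
        ]
```

A tensor is hashable only by identity, not by value, so a key derived from voxel_size has to be converted to plain numbers first, as done here. The cached list is shared between callers and should be treated as read-only.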

torchref.base.kernels.get_cached_radius_offsets(voxel_size, radius_angstrom, device)

Get cached radius offsets to avoid recomputation.

This eliminates redundant computation when processing multiple atoms with the same voxel size and radius.

Parameters:
  • voxel_size (torch.Tensor) – Voxel dimensions, shape (3,).

  • radius_angstrom (float) – Radius in Angstroms.

  • device (torch.device) – Device for the output tensor.

Returns:

Voxel offsets within radius, shape (N_voxels, 3).

Return type:

torch.Tensor

torchref.base.kernels.vectorized_add_to_map_optimized(surrounding_coords, voxel_indices, map, xyz, b, inv_frac_matrix, frac_matrix, A, B, occ)

Optimized version of vectorized_add_to_map using fused Gaussian calculation.

This is a drop-in replacement that uses the fused_gaussian_density function to reduce kernel launches.

Parameters:
  • surrounding_coords (torch.Tensor) – Cartesian coordinates of voxels, shape (N_atoms, N_voxels, 3).

  • voxel_indices (torch.Tensor) – Indices of voxels in the map, shape (N_atoms, N_voxels, 3).

  • map (torch.Tensor) – Electron density map to update, shape (nx, ny, nz).

  • xyz (torch.Tensor) – Atom positions in Cartesian coordinates, shape (N_atoms, 3).

  • b (torch.Tensor) – Isotropic B-factors, shape (N_atoms,).

  • inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix; maps Cartesian to fractional coordinates, shape (3, 3).

  • frac_matrix (torch.Tensor) – Fractionalization matrix; maps fractional to Cartesian coordinates, shape (3, 3).

  • A (torch.Tensor) – ITC92 amplitude coefficients, shape (N_atoms, 5).

  • B (torch.Tensor) – ITC92 width coefficients, shape (N_atoms, 5).

  • occ (torch.Tensor) – Atomic occupancies, shape (N_atoms,).

Returns:

Updated electron density map.

Return type:

torch.Tensor

torchref.base.kernels.fused_add_to_map_gpu(surrounding_coords, voxel_indices, density_map, xyz, b, inv_frac_matrix, frac_matrix, A, B, occ)

Fused GPU density computation using Triton.

Drop-in replacement for the JIT GPU kernel, with full autograd support. Fuses periodic boundary condition (PBC) wrapping, r² computation, 5-Gaussian evaluation, and scatter-add into a single GPU kernel launch.

Parameters:
  • surrounding_coords (torch.Tensor) – Cartesian coordinates of voxels, shape (N_atoms, N_voxels, 3).

  • voxel_indices (torch.Tensor) – Grid indices of voxels, shape (N_atoms, N_voxels, 3).

  • density_map (torch.Tensor) – Electron density map to update in-place, shape (nx, ny, nz).

  • xyz (torch.Tensor) – Atom positions in Cartesian space, shape (N_atoms, 3).

  • b (torch.Tensor) – Isotropic B-factors, shape (N_atoms,).

  • inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix; maps Cartesian to fractional coordinates, shape (3, 3).

  • frac_matrix (torch.Tensor) – Fractionalization matrix; maps fractional to Cartesian coordinates, shape (3, 3).

  • A (torch.Tensor) – ITC92 amplitude coefficients, shape (N_atoms, 5).

  • B (torch.Tensor) – ITC92 width coefficients, shape (N_atoms, 5).

  • occ (torch.Tensor) – Atomic occupancies, shape (N_atoms,).

Returns:

Updated density map (modified in-place).

Return type:

torch.Tensor
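To make the fused steps concrete, here is an unfused, plain-Python reference loop that performs them one at a time for every (atom, voxel) pair. It is a sketch for reading or testing, not the Triton kernel, and it has no autograd. Assumptions: the flat density_map with an explicit grid_shape argument is a simplification of the (nx, ny, nz) tensor; reducing the fractional difference to the nearest periodic image is a guess at what the PBC step covers; and each Gaussian term uses beta = b_i + B, the common ITC92 real-space convention.

```python
import math
from typing import List, Sequence

Vec = Sequence[float]
Mat = Sequence[Sequence[float]]


def _matvec(m: Mat, v: Vec) -> List[float]:
    return [sum(m[i][k] * v[k] for k in range(3)) for i in range(3)]


def reference_add_to_map(surrounding_coords, voxel_indices, density_map, grid_shape,
                         xyz, b, inv_frac, frac, A, B, occ):
    """Pure-Python reference for the steps the fused kernel combines.

    density_map is a flat, row-major list of length nx*ny*nz, updated in place.
    """
    nx, ny, nz = grid_shape
    # Metric tensor G = frac^T frac (frac maps fractional -> Cartesian).
    G = [[sum(frac[k][i] * frac[k][j] for k in range(3)) for j in range(3)] for i in range(3)]
    for n, atom_pos in enumerate(xyz):
        for p, (ix, iy, iz) in zip(surrounding_coords[n], voxel_indices[n]):
            # 1. PBC: wrap the grid index back into the unit cell.
            flat = ((ix % nx) * ny + (iy % ny)) * nz + (iz % nz)
            # 2. Fractional difference, reduced to the nearest image (assumption).
            d = _matvec(inv_frac, [p[k] - atom_pos[k] for k in range(3)])
            d = [x - round(x) for x in d]
            # 3. Squared Cartesian distance via the metric tensor.
            r2 = sum(d[i] * G[i][j] * d[j] for i in range(3) for j in range(3))
            # 4. Five-Gaussian sum with the isotropic B added to each width.
            rho = 0.0
            for a_i, b_i in zip(A[n], B[n]):
                beta = b_i + b[n]
                rho += a_i * (4 * math.pi / beta) ** 1.5 * math.exp(-4 * math.pi ** 2 * r2 / beta)
            # 5. Scatter-add the occupancy-weighted value into the map.
            density_map[flat] += occ[n] * rho
    return density_map
```

On a GPU, running these five steps as separate operations would launch several kernels and write intermediates to memory between them; fusing them keeps each (atom, voxel) value in registers until the final atomic add.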

torchref.base.kernels.fused_find_and_place_atoms(real_space_grid, density_map, xyz, b, inv_frac_matrix, frac_matrix, A, B, occ, radius_angstrom, voxel_size)

Fused voxel-finding + density computation using Triton.

Eliminates the separate find_relevant_voxels step and the large surrounding_coords / voxel_indices intermediate tensors. Computes center grid indices and spherical voxel offsets directly inside the GPU kernel.

Parameters:
  • real_space_grid (torch.Tensor) – Real space coordinate grid, shape (nx, ny, nz, 3).

  • density_map (torch.Tensor) – Density map to update, shape (nx, ny, nz).

  • xyz (torch.Tensor) – Same as in fused_add_to_map_gpu.

  • b (torch.Tensor) – Same as in fused_add_to_map_gpu.

  • inv_frac_matrix (torch.Tensor) – Same as in fused_add_to_map_gpu.

  • frac_matrix (torch.Tensor) – Same as in fused_add_to_map_gpu.

  • A (torch.Tensor) – Same as in fused_add_to_map_gpu.

  • B (torch.Tensor) – Same as in fused_add_to_map_gpu.

  • occ (torch.Tensor) – Same as in fused_add_to_map_gpu.

  • radius_angstrom (float) – Radius around each atom in Angstroms.

  • voxel_size (torch.Tensor) – Voxel dimensions, shape (3,).

Returns:

Updated density map.

Return type:

torch.Tensor
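Computing voxel indices on the fly, instead of materializing surrounding_coords and voxel_indices, can be sketched with a generator: find the grid point nearest the atom, add each spherical offset, and wrap periodically. For scale, with a hypothetical 10,000 atoms and 500 voxels per sphere, each of the avoided intermediates would hold 10,000 × 500 × 3 = 15,000,000 values. Rounding to the nearest grid point and placing grid point i at fractional coordinate i / n are assumptions, and the function name is illustrative.

```python
from typing import Iterator, Sequence, Tuple

Index3 = Tuple[int, int, int]


def iter_sphere_indices(atom_frac: Sequence[float],
                        grid_shape: Sequence[int],
                        offsets: Sequence[Index3]) -> Iterator[Index3]:
    """Yield wrapped grid indices around one atom without storing them all."""
    # Nearest grid point, assuming grid point i sits at fractional coordinate i / n.
    center = [round(f * n) for f, n in zip(atom_frac, grid_shape)]
    for off in offsets:
        # Python's % returns a non-negative result for positive n, so
        # negative indices wrap to the far side of the cell.
        yield tuple((c + o) % n for c, o, n in zip(center, off, grid_shape))


if __name__ == "__main__":
    # The centre voxel plus its six face neighbours (a radius of one voxel).
    offsets = [(0, 0, 0), (1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]
    print(list(iter_sphere_indices((0.0, 0.5, 0.9), (10, 10, 10), offsets)))
```

Inside a GPU kernel the same arithmetic would run per thread, so only the small offset table and the atom parameters need to be read from memory.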

torchref.base.kernels.separable_density_gpu(density_map, xyz, b, inv_frac_matrix, frac_matrix, A, B, occ, radius_angstrom)

Separable Gaussian density splatting on GPU via Triton.

Eliminates the real_space_grid tensor and the PBC matrix operations by working directly in fractional space with the metric tensor. Precomputes 1D Gaussian tables per atom and gathers from them for each voxel in the atom's sphere.

Parameters:
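  • density_map (torch.Tensor) – Density grid to update, shape (nx, ny, nz). Not modified in-place.

  • xyz (torch.Tensor) – Atom positions in Cartesian coordinates, shape (N_atoms, 3).

  • b (torch.Tensor) – Isotropic B-factors, shape (N_atoms,).

  • inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix; maps Cartesian to fractional coordinates, shape (3, 3).

  • frac_matrix (torch.Tensor) – Fractionalization matrix; maps fractional to Cartesian coordinates, shape (3, 3).

  • A (torch.Tensor) – ITC92 amplitude coefficients, shape (N_atoms, 5).

  • B (torch.Tensor) – ITC92 width coefficients, shape (N_atoms, 5).

  • occ (torch.Tensor) – Atomic occupancies, shape (N_atoms,).

  • radius_angstrom (float) – Cutoff radius in Angstroms.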

Returns:

Updated density map.

Return type:

torch.Tensor
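Why 1D tables help: when the metric tensor is diagonal (an orthogonal cell), r² is a sum of per-axis terms, so each Gaussian exp(-4π² r²/β) factors into a product of three 1D exponentials, and a block of (2h+1)³ voxels needs only 3(2h+1) exponential evaluations per Gaussian term. The plain-Python sketch below shows that factorization for one term, assuming β = b_i + B, grid point i at fractional coordinate i / n, and a block centred on the nearest grid point; the function names are illustrative. For oblique cells the off-diagonal entries of G couple the axes, and how separable_density_gpu handles that is not described here, so the sketch does not attempt it.

```python
import math
from typing import List, Sequence


def gaussian_table(center_frac: float, n: int, axis_len: float, beta: float,
                   half_width: int) -> List[float]:
    """1D factors exp(-4*pi^2 * (axis_len * d)^2 / beta) for grid offsets -hw..hw."""
    c = round(center_frac * n)  # nearest grid point along this axis
    return [
        math.exp(-4 * math.pi ** 2 * (axis_len * ((c + k) / n - center_frac)) ** 2 / beta)
        for k in range(-half_width, half_width + 1)
    ]


def separable_term(atom_frac: Sequence[float], grid_shape: Sequence[int],
                   cell_lengths: Sequence[float], a_i: float, beta: float,
                   half_width: int) -> List[List[List[float]]]:
    """One Gaussian term on the (2hw+1)^3 block around an atom, orthogonal cell only."""
    tx, ty, tz = (gaussian_table(f, n, L, beta, half_width)
                  for f, n, L in zip(atom_frac, grid_shape, cell_lengths))
    amp = a_i * (4 * math.pi / beta) ** 1.5
    # Each voxel value is a product of three table lookups: no exp() in this loop.
    return [[[amp * x * y * z for z in tz] for y in ty] for x in tx]
```

A full five-Gaussian atom would need one set of three tables per term (five β values), which is still far fewer exponentials than evaluating every voxel of the sphere directly.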

Submodules