torchref.base.kernels.jit_kernel_vectorized_add_to_map module

Optimized vectorized_add_to_map with automatic CPU/GPU path selection.

This module provides optimized implementations for adding atoms to a density map using ITC92 Gaussian parameterization. It automatically selects the best implementation based on the device (CPU or GPU) and uses JIT-compiled kernels for optimal performance.

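Architecture:
  • CPU: JIT-scripted kernel using einsum with the metric tensor (efficient on CPU).
  • GPU: JIT-scripted kernel using batched matmul (efficient on GPU). On GPU, vectorized_add_to_map prefers a Triton fused kernel by default and uses this JIT kernel as the fallback when Triton is unavailable.

Both JIT-scripted kernels are fully differentiable and compile on import to minimize first-call overhead.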

Usage:

    from torchref.base.kernels import vectorized_add_to_map

    # Automatically selects CPU or GPU implementation based on tensor device
    density_map = vectorized_add_to_map(
        surrounding_coords, voxel_indices, density_map, xyz, b,
        inv_frac_matrix, frac_matrix, A, B, occ
    )
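The call above expects, for every atom, a neighbourhood of voxels: voxel_indices says where each voxel lives in density_map, and surrounding_coords gives its Cartesian position. The module does not document how these are built, so the following is only a pure-Python sketch under stated assumptions: a periodic grid with grid_shape points along each cell edge, a cube of ±half_width grid points around the grid point nearest the atom, indices wrapped into the cell, and Cartesian coordinates taken from the unwrapped positions so that distances to the atom stay correct. The helper name voxel_neighbourhood and the matrix name frac_to_cart (fractional to Cartesian, applied to column vectors) are hypothetical.

```python
from itertools import product


def mat_vec(m, v):
    """Multiply a 3x3 matrix (list of rows) by a length-3 column vector."""
    return [sum(m[i][j] * v[j] for j in range(3)) for i in range(3)]


def voxel_neighbourhood(atom_frac, grid_shape, frac_to_cart, half_width):
    """Hypothetical helper: voxels in a cube around one atom.

    Returns (indices, coords), each a list of N_voxels triples:
    - indices: integer grid indices wrapped into [0, n) (periodic cell assumed)
    - coords:  Cartesian coordinates of the *unwrapped* grid points
    """
    # Nearest grid point to the atom (Python's round() rounds halves to even).
    centre = [round(f * n) for f, n in zip(atom_frac, grid_shape)]
    indices, coords = [], []
    for offset in product(range(-half_width, half_width + 1), repeat=3):
        grid_pt = [c + o for c, o in zip(centre, offset)]
        # Wrap into the unit cell to index the map.
        indices.append([g % n for g, n in zip(grid_pt, grid_shape)])
        # Keep the unwrapped position for the distance calculation.
        frac = [g / n for g, n in zip(grid_pt, grid_shape)]
        coords.append(mat_vec(frac_to_cart, frac))
    return indices, coords


# Example: orthorhombic 10 x 20 x 30 Å cell sampled at 1 Å, atom at the origin.
cell = [[10.0, 0.0, 0.0], [0.0, 20.0, 0.0], [0.0, 0.0, 30.0]]
idx, xyz_vox = voxel_neighbourhood([0.0, 0.0, 0.0], (10, 20, 30), cell, half_width=1)
print(len(idx), idx[0], xyz_vox[0])
```

Stacking the per-atom lists over all atoms gives arrays of shape (N_atoms, N_voxels, 3), matching the documented shapes; a cube of ±half_width points gives N_voxels = (2 * half_width + 1) ** 3.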

torchref.base.kernels.jit_kernel_vectorized_add_to_map.vectorized_add_to_map(surrounding_coords, voxel_indices, density_map, xyz, b, inv_frac_matrix, frac_matrix, A, B, occ)

Add atoms to density map using ITC92 Gaussian parameterization.

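Automatically selects the implementation based on the device of the input tensors. On GPU the default is a Triton fused kernel (3-6x faster than the JIT kernel), which falls back to the JIT kernel if Triton is unavailable. Override the GPU choice by setting TORCHREF_ATOM_PLACEMENT_GPU_MODE=jit or TORCHREF_ATOM_PLACEMENT_GPU_MODE=simple.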
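The dispatch rule described above can be restated as a small decision function. This is a hypothetical pure-Python restatement for illustration only: the function name and its return labels are invented, it does not reflect how the module actually detects Triton or reads the variable, and it simply ignores values other than jit and simple.

```python
import os
from typing import Mapping, Optional

GPU_MODE_VAR = "TORCHREF_ATOM_PLACEMENT_GPU_MODE"


def select_atom_placement_path(device_type: str, triton_available: bool,
                               env: Optional[Mapping[str, str]] = None) -> str:
    """Hypothetical restatement of the documented CPU/GPU dispatch rule.

    Returns an illustrative label: "cpu-jit", "triton", "jit" or "simple".
    """
    if env is None:
        env = os.environ
    if device_type == "cpu":
        # CPU tensors take the JIT-scripted einsum kernel.
        return "cpu-jit"
    override = env.get(GPU_MODE_VAR)
    if override in ("jit", "simple"):
        # Explicit override of the GPU default.
        return override
    # GPU default: Triton fused kernel, falling back to JIT without Triton.
    return "triton" if triton_available else "jit"


print(select_atom_placement_path("cuda", triton_available=False, env={}))
```

In practice the override can be set in the shell before launching Python, for example TORCHREF_ATOM_PLACEMENT_GPU_MODE=jit python your_script.py; doing it there avoids depending on whether the module reads the variable at import time or on each call.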

Parameters:
  • surrounding_coords (torch.Tensor) – Cartesian coordinates of voxels, shape (N_atoms, N_voxels, 3).

  • voxel_indices (torch.Tensor) – Indices of voxels in the map, shape (N_atoms, N_voxels, 3).

  • density_map (torch.Tensor) – Electron density map to update, shape (nx, ny, nz).

  • xyz (torch.Tensor) – Atom positions in Cartesian coordinates, shape (N_atoms, 3).

  • b (torch.Tensor) – Isotropic B-factors, shape (N_atoms,).

  • inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix, shape (3, 3).

  • frac_matrix (torch.Tensor) – Fractionalization matrix, shape (3, 3).

  • A (torch.Tensor) – ITC92 amplitude coefficients, shape (N_atoms, 5).

  • B (torch.Tensor) – ITC92 width coefficients, shape (N_atoms, 5).

  • occ (torch.Tensor) – Atomic occupancies, shape (N_atoms,).

Returns:

Updated electron density map (modified in-place).

Return type:

torch.Tensor
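For checking results on small inputs, a slow reference of what an ITC92 placement evaluates can help. The sketch below assumes the common convention f(s) = Σ aᵢ exp(-bᵢ s²) with s = sin θ / λ, whose Fourier transform, after adding an isotropic B-factor, gives the real-space density ρ(r) = Σ aᵢ (4π / (bᵢ + B))^(3/2) exp(-4π² r² / (bᵢ + B)), scaled by the occupancy. It has not been checked against the kernel source (for instance, how the fifth coefficient pair is used), and the function names are hypothetical. It handles one atom and a nested-list map, and adds into the map in place, mirroring the in-place update documented above.

```python
import math


def itc92_density(r2, a, b_coef, b_iso):
    """Electron density (e/Å^3) at squared distance r2 (Å^2) from one atom.

    a, b_coef: ITC92-style amplitude and width coefficients (same length).
    b_iso: isotropic B-factor (Å^2); each b_i + b_iso must be positive.
    """
    rho = 0.0
    for a_i, b_i in zip(a, b_coef):
        beta = b_i + b_iso  # Gaussian width after isotropic smearing
        rho += a_i * (4.0 * math.pi / beta) ** 1.5 * math.exp(-4.0 * math.pi ** 2 * r2 / beta)
    return rho


def add_atom_reference(density_map, voxel_indices, voxel_coords, xyz, b_iso, a, b_coef, occ):
    """Add one atom's density to density_map (nested lists) in place."""
    for (i, j, k), point in zip(voxel_indices, voxel_coords):
        r2 = sum((p - x) ** 2 for p, x in zip(point, xyz))
        density_map[i][j][k] += occ * itc92_density(r2, a, b_coef, b_iso)
    return density_map


# One voxel sitting exactly on the atom, in a 1 x 1 x 1 map.
grid = [[[0.0]]]
add_atom_reference(grid, [(0, 0, 0)], [(1.0, 2.0, 3.0)], (1.0, 2.0, 3.0),
                   b_iso=20.0, a=[1.0], b_coef=[10.0], occ=0.5)
print(grid[0][0][0])
```

A sanity check follows from the same convention: integrating ρ over all space returns Σ aᵢ, the electron count f(0), so summing occ · ρ · voxel_volume over a fine, large enough neighbourhood should approach occ · Σ aᵢ.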

torchref.base.kernels.jit_kernel_vectorized_add_to_map.build_electron_density(surrounding_coords, voxel_indices, density_map, xyz, b, inv_frac_matrix, frac_matrix, A, B, occ)

Build electron density map from atomic parameters.

This is an alias for vectorized_add_to_map for semantic clarity.

Parameters:
  • surrounding_coords (torch.Tensor) – Cartesian coordinates of voxels, shape (N_atoms, N_voxels, 3).

  • voxel_indices (torch.Tensor) – Indices of voxels in the map, shape (N_atoms, N_voxels, 3).

  • density_map (torch.Tensor) – Electron density map to update, shape (nx, ny, nz).

  • xyz (torch.Tensor) – Atom positions in Cartesian coordinates, shape (N_atoms, 3).

  • b (torch.Tensor) – Isotropic B-factors, shape (N_atoms,).

  • inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix, shape (3, 3).

  • frac_matrix (torch.Tensor) – Fractionalization matrix, shape (3, 3).

  • A (torch.Tensor) – ITC92 amplitude coefficients, shape (N_atoms, 5).

  • B (torch.Tensor) – ITC92 width coefficients, shape (N_atoms, 5).

  • occ (torch.Tensor) – Atomic occupancies, shape (N_atoms,).

Returns:

Updated electron density map (modified in-place).

Return type:

torch.Tensor

torchref.base.kernels.jit_kernel_vectorized_add_to_map.compute_metric_tensor(frac_matrix)

Compute the metric tensor for calculating r² in fractional coordinates.

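The metric tensor G allows squared Cartesian distances to be computed directly from fractional coordinate differences:

r² = diff_frac @ G @ diff_frac.T

where diff_frac is the (1, 3) row vector of fractional coordinate differences between a voxel and an atom.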

Parameters:

frac_matrix (torch.Tensor) – Fractionalization matrix, shape (3, 3).

Returns:

Metric tensor G = frac_matrix.T @ frac_matrix, shape (3, 3).

Return type:

torch.Tensor
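A pure-Python check of the metric-tensor identity. It assumes, consistently with G = frac_matrix.T @ frac_matrix and with precompute_fractional_coords using inv_frac_matrix to go from Cartesian to fractional, that frac_matrix maps fractional column vectors to Cartesian ones; here it is called frac_to_cart, and all helper names are hypothetical. The example uses a monoclinic cell (β ≠ 90°) so that G has an off-diagonal term.

```python
import math


def transpose(m):
    return [list(row) for row in zip(*m)]


def matmul(x, y):
    """Plain nested-list matrix product x @ y."""
    return [[sum(x[i][k] * y[k][j] for k in range(len(y))) for j in range(len(y[0]))]
            for i in range(len(x))]


def metric_tensor(frac_to_cart):
    """G = M^T M, with M mapping fractional column vectors to Cartesian."""
    return matmul(transpose(frac_to_cart), frac_to_cart)


def r2_from_frac(diff_frac, g):
    """Squared Cartesian length of a fractional difference: d @ G @ d.T."""
    return sum(diff_frac[i] * g[i][j] * diff_frac[j] for i in range(3) for j in range(3))


# Monoclinic cell a=10, b=12, c=15 Å, beta=100°: columns are the cell vectors.
a, b, c, beta = 10.0, 12.0, 15.0, math.radians(100.0)
m = [[a, 0.0, c * math.cos(beta)],
     [0.0, b, 0.0],
     [0.0, 0.0, c * math.sin(beta)]]
g = metric_tensor(m)

d = [0.1, -0.05, 0.2]  # fractional difference, voxel minus atom
cart = [sum(m[i][j] * d[j] for j in range(3)) for i in range(3)]
print(r2_from_frac(d, g), sum(x * x for x in cart))  # equal up to rounding
```

One plausible motivation for this design is that a single 3x3 matrix serves every atom and voxel, so r² can be formed from fractional differences without converting each voxel back to Cartesian space.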

torchref.base.kernels.jit_kernel_vectorized_add_to_map.precompute_fractional_coords(coords_cart, inv_frac_matrix)

Convert Cartesian voxel coordinates to fractional coordinates.

Parameters:
  • coords_cart (torch.Tensor) – Cartesian coordinates, shape (N_atoms, N_voxels, 3).

  • inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix, shape (3, 3).

Returns:

Fractional coordinates, shape (N_atoms, N_voxels, 3).

Return type:

torch.Tensor
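A nested-list sketch of the conversion for the documented (N_atoms, N_voxels, 3) layout, assuming inv_frac_matrix is applied to column vectors (frac = inv_frac_matrix @ cart). Under that convention the batched PyTorch equivalent would be coords_cart @ inv_frac_matrix.T, since torch.matmul broadcasts the (3, 3) matrix over the leading dimensions. The function name to_fractional is hypothetical.

```python
def to_fractional(coords_cart, cart_to_frac):
    """Map every Cartesian point in an (N_atoms, N_voxels, 3) nested list to fractional.

    Applies frac = cart_to_frac @ cart to each point (column-vector convention).
    """
    return [
        [[sum(cart_to_frac[i][j] * p[j] for j in range(3)) for i in range(3)]
         for p in atom_voxels]
        for atom_voxels in coords_cart
    ]


# Orthorhombic 10 x 20 x 30 Å cell: the inverse is a diagonal of reciprocal edges.
inv_cell = [[1 / 10, 0.0, 0.0], [0.0, 1 / 20, 0.0], [0.0, 0.0, 1 / 30]]
coords = [[[5.0, 10.0, 15.0], [0.0, 0.0, 0.0]]]  # 1 atom, 2 voxels
print(to_fractional(coords, inv_cell))
```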

torchref.base.kernels.jit_kernel_vectorized_add_to_map.warmup(device='auto')

Pre-compile kernels to avoid compilation overhead during first use.

Parameters:

device (str) – Device to warm up: “cpu”, “cuda”, or “auto” (default).

torchref.base.kernels.jit_kernel_vectorized_add_to_map.get_cache_dir()

Return the path to the JIT kernel cache directory.

torchref.base.kernels.jit_kernel_vectorized_add_to_map.clear_cache()

Clear the JIT kernel cache.