torchref.base.kernels.triton_kernel module

Triton GPU kernel for fused electron density computation.

Fuses periodic boundary condition (PBC) wrapping, r² calculation, 5-Gaussian evaluation, and scatter-add into a single GPU kernel, eliminating ~14 separate kernel launches and ~500 MB of intermediate memory allocations.

Provides full autograd support for refinement of xyz, b, and occ.
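
To give a sense of scale for the intermediate allocations mentioned above, the sketch below estimates the memory held by two tensors of shape (N_atoms, N_voxels, 3), one for coordinates and one for grid indices. The atom and voxel counts are hypothetical, and the float32 and int64 element sizes are assumptions. It illustrates the arithmetic only and is not a reconstruction of the ~500 MB figure.

```python
def intermediate_bytes(n_atoms, n_voxels, coord_itemsize=4, index_itemsize=8):
    """Bytes held by two (N_atoms, N_voxels, 3) tensors.

    coord_itemsize=4 assumes float32 coordinates and index_itemsize=8
    assumes int64 indices; both are assumptions, not documented facts.
    """
    elements = n_atoms * n_voxels * 3
    coords = elements * coord_itemsize    # coordinate tensor
    indices = elements * index_itemsize   # grid-index tensor
    return coords, indices


# Hypothetical problem size, chosen only for illustration.
coords, indices = intermediate_bytes(n_atoms=5_000, n_voxels=1_000)
total_mb = (coords + indices) / 1e6
print(f"{total_mb:.0f} MB")  # 180 MB
```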

torchref.base.kernels.triton_kernel.fused_add_to_map_gpu(surrounding_coords, voxel_indices, density_map, xyz, b, inv_frac_matrix, frac_matrix, A, B, occ)

Fused GPU density computation using Triton.

Drop-in replacement for the JIT GPU kernel with full autograd support. Fuses PBC wrapping, r² computation, 5-Gaussian evaluation, and scatter-add into a single GPU kernel launch.

Parameters:
  • surrounding_coords (torch.Tensor) – Cartesian coordinates of voxels, shape (N_atoms, N_voxels, 3).

  • voxel_indices (torch.Tensor) – Grid indices of voxels, shape (N_atoms, N_voxels, 3).

  • density_map (torch.Tensor) – Electron density map to update in place, shape (nx, ny, nz).

  • xyz (torch.Tensor) – Atom positions in Cartesian space, shape (N_atoms, 3).

  • b (torch.Tensor) – Isotropic B-factors, shape (N_atoms,).

  • inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix, shape (3, 3).

  • frac_matrix (torch.Tensor) – Fractionalization matrix, shape (3, 3).

  • A (torch.Tensor) – ITC92 amplitude coefficients, shape (N_atoms, 5).

  • B (torch.Tensor) – ITC92 width coefficients, shape (N_atoms, 5).

  • occ (torch.Tensor) – Atomic occupancies, shape (N_atoms,).

Returns:

Updated density map (modified in place).

Return type:

torch.Tensor
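
The four fused steps (PBC wrapping, r², 5-Gaussian evaluation, scatter-add) can be sketched for a single atom in plain Python. This is a reference sketch, not the Triton kernel. It assumes the standard isotropic form in which each Gaussian term has width B_i + b, and it assumes that frac_matrix maps a Cartesian column vector to fractional coordinates. The helper names and the coefficient values are placeholders, not ITC92 data.

```python
import math


def matvec(m, v):
    """3x3 matrix times 3-vector (column-vector convention, an assumption)."""
    return [sum(m[i][j] * v[j] for j in range(3)) for i in range(3)]


def density_at_voxel(voxel_xyz, atom_xyz, b, occ, A, B, frac_matrix, inv_frac_matrix):
    # 1. PBC wrapping: minimum-image displacement via fractional coordinates.
    delta = [p - q for p, q in zip(voxel_xyz, atom_xyz)]
    frac = [f - round(f) for f in matvec(frac_matrix, delta)]
    delta = matvec(inv_frac_matrix, frac)
    # 2. Squared distance between voxel and atom.
    r2 = sum(d * d for d in delta)
    # 3. Sum of five Gaussians, each broadened by the B-factor b (assumed form).
    rho = 0.0
    for a_i, b_i in zip(A, B):
        width = b_i + b
        rho += a_i * (4.0 * math.pi / width) ** 1.5 * math.exp(-4.0 * math.pi ** 2 * r2 / width)
    return occ * rho


def add_atom_to_map(density_map, coords, indices, atom_xyz, b, occ, A, B, frac, inv_frac):
    # 4. Scatter-add: accumulate each voxel's contribution into the map in place.
    for voxel_xyz, (i, j, k) in zip(coords, indices):
        density_map[i][j][k] += density_at_voxel(
            voxel_xyz, atom_xyz, b, occ, A, B, frac, inv_frac
        )
    return density_map


# Hypothetical 10 Å cubic cell sampled on a 10 x 1 x 1 grid.
frac = [[0.1, 0.0, 0.0], [0.0, 0.1, 0.0], [0.0, 0.0, 0.1]]
inv_frac = [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]
density_map = [[[0.0] for _ in range(1)] for _ in range(10)]

A = [1.0, 1.0, 1.0, 1.0, 1.0]    # placeholder amplitudes, not ITC92 values
B = [1.0, 2.0, 4.0, 8.0, 16.0]   # placeholder widths, not ITC92 values

# One voxel 1 Å away directly, one 1 Å away across the periodic boundary.
coords = [(1.5, 0.5, 0.5), (9.5, 0.5, 0.5)]
indices = [(1, 0, 0), (9, 0, 0)]
add_atom_to_map(density_map, coords, indices, (0.5, 0.5, 0.5), 20.0, 1.0, A, B, frac, inv_frac)
```

Because of the minimum-image step, the voxel at index 9 receives the same density as the voxel at index 1: both lie 1 Å from the atom once the periodic boundary is taken into account.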

torchref.base.kernels.triton_kernel.fused_find_and_place_atoms(real_space_grid, density_map, xyz, b, inv_frac_matrix, frac_matrix, A, B, occ, radius_angstrom, voxel_size)

Fused voxel-finding + density computation using Triton.

Eliminates the separate find_relevant_voxels step and the large surrounding_coords / voxel_indices intermediate tensors. Computes center grid indices and spherical voxel offsets directly inside the GPU kernel.

Parameters:
  • real_space_grid (torch.Tensor) – Real space coordinate grid, shape (nx, ny, nz, 3).

  • density_map (torch.Tensor) – Density map to update, shape (nx, ny, nz).

  • xyz (torch.Tensor) – Same as fused_add_to_map_gpu.

  • b (torch.Tensor) – Same as fused_add_to_map_gpu.

  • inv_frac_matrix (torch.Tensor) – Same as fused_add_to_map_gpu.

  • frac_matrix (torch.Tensor) – Same as fused_add_to_map_gpu.

  • A (torch.Tensor) – Same as fused_add_to_map_gpu.

  • B (torch.Tensor) – Same as fused_add_to_map_gpu.

  • occ (torch.Tensor) – Same as fused_add_to_map_gpu.

  • radius_angstrom (float) – Radius around each atom in Angstroms.

  • voxel_size (torch.Tensor) – Voxel dimensions, shape (3,).

Returns:

Updated density map.

Return type:

torch.Tensor
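
The idea of deriving a center grid index and a set of spherical voxel offsets from radius_angstrom and voxel_size can be illustrated in plain Python. This is a sketch under simplifying assumptions: an orthogonal cell with the grid origin at zero and axes aligned with the cell edges. The helper names spherical_offsets and center_index are hypothetical and are not part of the module; the actual kernel may compute these quantities differently.

```python
import math


def spherical_offsets(radius_angstrom, voxel_size):
    """Integer offsets (di, dj, dk) whose physical length is within the radius."""
    reach = [math.ceil(radius_angstrom / s) for s in voxel_size]
    r2_max = radius_angstrom ** 2
    offsets = []
    for di in range(-reach[0], reach[0] + 1):
        for dj in range(-reach[1], reach[1] + 1):
            for dk in range(-reach[2], reach[2] + 1):
                d2 = (
                    (di * voxel_size[0]) ** 2
                    + (dj * voxel_size[1]) ** 2
                    + (dk * voxel_size[2]) ** 2
                )
                if d2 <= r2_max:
                    offsets.append((di, dj, dk))
    return offsets


def center_index(atom_xyz, voxel_size, grid_shape):
    """Index of the voxel containing the atom, wrapped into the grid."""
    return tuple(
        math.floor(x / s) % n for x, s, n in zip(atom_xyz, voxel_size, grid_shape)
    )


# Hypothetical 10 x 10 x 10 grid with 1 Å voxels.
grid_shape = (10, 10, 10)
voxel_size = (1.0, 1.0, 1.0)

offsets = spherical_offsets(1.0, voxel_size)
center = center_index((0.2, 9.7, 4.5), voxel_size, grid_shape)

# Apply each offset to the center and wrap periodically into the grid.
wrapped = [
    tuple((c + d) % n for c, d, n in zip(center, off, grid_shape))
    for off in offsets
]
```

With a 1 Å radius on a 1 Å grid, the sphere covers the center voxel and its six face neighbors. Offsets that step below index 0 wrap to the far side of the grid, which mirrors the periodic treatment of the density map.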