torchref.base.kernels.triton_kernel module
Triton GPU kernel for fused electron density computation.
Fuses periodic boundary condition (PBC) wrapping, r² calculation, 5-Gaussian evaluation, and scatter-add into a single GPU kernel, eliminating ~14 separate kernel launches and ~500 MB of intermediate memory allocations.
Provides full autograd support for refinement of xyz, b, and occ.
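The per-voxel arithmetic that the kernel fuses can be sketched in plain Python for a single atom and a single voxel. This is an illustration only, not the kernel's code: it assumes that `frac_matrix` maps Cartesian to fractional coordinates under a column-vector convention, that wrapping uses the minimal-image convention, and that the density is the usual real-space form of a 5-Gaussian form factor with the isotropic B-factor added to each width. The helper names `matvec` and `voxel_density` and the coefficient values are hypothetical.

```python
import math

def matvec(m, v):
    """Multiply a 3x3 matrix (nested lists) by a length-3 vector."""
    return [sum(m[i][j] * v[j] for j in range(3)) for i in range(3)]

def voxel_density(voxel_xyz, atom_xyz, b_iso, occ, A, B, frac_matrix, inv_frac_matrix):
    """Density contribution of one atom at one voxel (illustrative formula)."""
    # 1. PBC wrapping: go to fractional space and wrap each component
    #    into [-0.5, 0.5) so the nearest periodic image is used.
    diff = [voxel_xyz[k] - atom_xyz[k] for k in range(3)]
    frac = matvec(frac_matrix, diff)
    frac = [f - math.floor(f + 0.5) for f in frac]
    cart = matvec(inv_frac_matrix, frac)

    # 2. Squared distance to the nearest image.
    r2 = sum(c * c for c in cart)

    # 3. Sum of five Gaussians, each broadened by the B-factor (assumed form).
    rho = 0.0
    for a_i, b_i in zip(A, B):
        width = b_i + b_iso
        rho += a_i * (4.0 * math.pi / width) ** 1.5 * math.exp(-4.0 * math.pi ** 2 * r2 / width)
    return occ * rho

# Cubic 10 Å cell; placeholder coefficients, NOT real ITC92 values.
cell = 10.0
frac_matrix = [[1.0 / cell, 0.0, 0.0], [0.0, 1.0 / cell, 0.0], [0.0, 0.0, 1.0 / cell]]
inv_frac_matrix = [[cell, 0.0, 0.0], [0.0, cell, 0.0], [0.0, 0.0, cell]]
A_demo = [2.0, 1.0, 1.5, 1.0, 0.5]
B_demo = [20.0, 10.0, 0.5, 50.0, 1.0]
atom = [0.5, 0.0, 0.0]

args = (20.0, 1.0, A_demo, B_demo, frac_matrix, inv_frac_matrix)
on_atom = voxel_density(atom, atom, *args)
near = voxel_density([1.5, 0.0, 0.0], atom, *args)    # 1 Å away, same cell
across = voxel_density([9.5, 0.0, 0.0], atom, *args)  # 1 Å away through the boundary
```

Here `near` and `across` agree to rounding error, because the voxel at x = 9.5 Å is only 1 Å from the atom once the periodic image is taken into account.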
- torchref.base.kernels.triton_kernel.fused_add_to_map_gpu(surrounding_coords, voxel_indices, density_map, xyz, b, inv_frac_matrix, frac_matrix, A, B, occ)
Fused GPU density computation using Triton.
Drop-in replacement for the JIT GPU kernel with full autograd support. Fuses PBC wrapping, r² computation, 5-Gaussian evaluation, and scatter-add into a single GPU kernel launch.
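The scatter-add step is where contributions from different atoms meet: two atoms whose voxel neighborhoods overlap must both add into the same map entry. A minimal pure-Python sketch of that accumulation follows, assuming a row-major flattening of the (nx, ny, nz) map. The function `scatter_add` and the demo values are hypothetical and stand in for what the GPU kernel does in parallel.

```python
def scatter_add(density_flat, grid_shape, voxel_indices, values):
    """Accumulate per-atom voxel values into a flattened map, in place.

    voxel_indices[a][v] is an (ix, iy, iz) triple and values[a][v] is the
    density that atom a contributes to that voxel.
    """
    nx, ny, nz = grid_shape
    for atom_indices, atom_values in zip(voxel_indices, values):
        for (ix, iy, iz), value in zip(atom_indices, atom_values):
            # Row-major offset of (ix, iy, iz) in a contiguous (nx, ny, nz) array.
            density_flat[(ix * ny + iy) * nz + iz] += value
    return density_flat

# Two atoms on a 2x2x2 grid; both touch voxel (0, 0, 1).
grid_shape = (2, 2, 2)
density = [0.0] * 8
voxel_indices = [
    [(0, 0, 0), (0, 0, 1)],  # atom 0
    [(0, 0, 1), (1, 1, 1)],  # atom 1
]
values = [
    [0.5, 0.25],
    [0.125, 1.0],
]
result = scatter_add(density, grid_shape, voxel_indices, values)
```

The shared voxel ends up holding the sum of both contributions, and `result` is the same object as `density`, mirroring the in-place update described for `density_map`. On a GPU the same accumulation has to tolerate many threads writing to one address at once, which is why a plain indexed assignment would not be enough.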
- Parameters:
surrounding_coords (torch.Tensor) – Cartesian coordinates of voxels, shape (N_atoms, N_voxels, 3).
voxel_indices (torch.Tensor) – Grid indices of voxels, shape (N_atoms, N_voxels, 3).
density_map (torch.Tensor) – Electron density map to update in-place, shape (nx, ny, nz).
xyz (torch.Tensor) – Atom positions in Cartesian space, shape (N_atoms, 3).
b (torch.Tensor) – Isotropic B-factors, shape (N_atoms,).
inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix, shape (3, 3).
frac_matrix (torch.Tensor) – Fractionalization matrix, shape (3, 3).
A (torch.Tensor) – ITC92 amplitude coefficients, shape (N_atoms, 5).
B (torch.Tensor) – ITC92 width coefficients, shape (N_atoms, 5).
occ (torch.Tensor) – Atomic occupancies, shape (N_atoms,).
- Returns:
Updated density map (modified in-place).
- Return type:
torch.Tensor
- torchref.base.kernels.triton_kernel.fused_find_and_place_atoms(real_space_grid, density_map, xyz, b, inv_frac_matrix, frac_matrix, A, B, occ, radius_angstrom, voxel_size)
Fused voxel-finding + density computation using Triton.
Eliminates the separate find_relevant_voxels step and the large surrounding_coords / voxel_indices intermediate tensors. Computes center grid indices and spherical voxel offsets directly inside the GPU kernel.
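What "center grid indices and spherical voxel offsets" means can be sketched in plain Python. The sketch below is a simplification under stated assumptions: an orthogonal cell with axis-aligned voxels and the grid origin at (0, 0, 0), whereas the real kernel receives `frac_matrix` and so presumably handles general cells. The helpers `center_index`, `spherical_offsets`, and `voxels_for_atom` are hypothetical names.

```python
import math

def center_index(xyz, voxel_size, grid_shape):
    """Nearest grid point to an atom, wrapped into the grid (orthogonal cell assumed)."""
    return tuple(
        math.floor(x / s + 0.5) % n
        for x, s, n in zip(xyz, voxel_size, grid_shape)
    )

def spherical_offsets(radius_angstrom, voxel_size):
    """Integer (di, dj, dk) offsets whose displacement lies within the radius."""
    reach = [math.ceil(radius_angstrom / s) for s in voxel_size]
    r2_max = radius_angstrom ** 2
    offsets = []
    for di in range(-reach[0], reach[0] + 1):
        for dj in range(-reach[1], reach[1] + 1):
            for dk in range(-reach[2], reach[2] + 1):
                d2 = (di * voxel_size[0]) ** 2 + (dj * voxel_size[1]) ** 2 + (dk * voxel_size[2]) ** 2
                if d2 <= r2_max:
                    offsets.append((di, dj, dk))
    return offsets

def voxels_for_atom(xyz, radius_angstrom, voxel_size, grid_shape):
    """Wrapped grid indices of all voxels within the radius of one atom."""
    center = center_index(xyz, voxel_size, grid_shape)
    return [
        tuple((c + d) % n for c, d, n in zip(center, offset, grid_shape))
        for offset in spherical_offsets(radius_angstrom, voxel_size)
    ]

voxel_size = [1.0, 1.0, 1.0]
grid_shape = (10, 10, 10)
center = center_index([9.6, 0.2, -0.6], voxel_size, grid_shape)
offsets = spherical_offsets(1.0, voxel_size)
touched = voxels_for_atom([9.6, 0.2, -0.6], 1.0, voxel_size, grid_shape)
```

Because the offsets depend only on `radius_angstrom` and `voxel_size`, they are the same for every atom. Only the center index changes per atom, so nothing of shape (N_atoms, N_voxels, 3) needs to be stored.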
- Parameters:
real_space_grid (torch.Tensor) – Real space coordinate grid, shape (nx, ny, nz, 3).
density_map (torch.Tensor) – Density map to update, shape (nx, ny, nz).
xyz (torch.Tensor) – Same as fused_add_to_map_gpu.
b (torch.Tensor) – Same as fused_add_to_map_gpu.
inv_frac_matrix (torch.Tensor) – Same as fused_add_to_map_gpu.
frac_matrix (torch.Tensor) – Same as fused_add_to_map_gpu.
A (torch.Tensor) – Same as fused_add_to_map_gpu.
B (torch.Tensor) – Same as fused_add_to_map_gpu.
occ (torch.Tensor) – Same as fused_add_to_map_gpu.
radius_angstrom (float) – Radius around each atom in Angstroms.
voxel_size (torch.Tensor) – Voxel dimensions, shape (3,).
- Returns:
Updated density map.
- Return type:
torch.Tensor