Source code for torchref.base.kernels.optimized_ops

"""
Optimized versions of map building functions with kernel fusion
to reduce CPU-GPU synchronization overhead.
"""

import torch


@torch.jit.script
def fused_gaussian_density(
    diff_coords_squared: torch.Tensor,
    B: torch.Tensor,
    b: torch.Tensor,
    A: torch.Tensor,
    occ: torch.Tensor,
) -> torch.Tensor:
    """
    Evaluate an isotropic sum-of-Gaussians density for every atom/voxel pair.

    The width, normalization, amplitude weighting, Gaussian evaluation and
    the sum over Gaussian components are written as one short chain of
    tensor expressions inside a scripted function, with the intent of
    letting PyTorch fuse them and launch fewer kernels.

    For each component ``k`` the effective width is
    ``w_k = max((B_k + b) / 4, 0.1)`` and the contribution is
    ``A_k * occ * (pi / w_k)^1.5 * exp(-pi^2 * r^2 / w_k)``.

    Parameters
    ----------
    diff_coords_squared : torch.Tensor
        Squared distances from atoms to voxels, shape (N_atoms, N_voxels).
    B : torch.Tensor
        ITC92 B parameters, shape (N_atoms, 5).
    b : torch.Tensor
        Atomic B-factors, shape (N_atoms,).
    A : torch.Tensor
        ITC92 A parameters, shape (N_atoms, 5).
    occ : torch.Tensor
        Occupancies, shape (N_atoms,).

    Returns
    -------
    torch.Tensor
        Density contribution of each atom at each of its voxels,
        shape (N_atoms, N_voxels).
    """
    pi = 3.14159265359
    pi_sq = 9.86960440109

    # Per-component width; the floor of 0.1 keeps the division and the
    # power below finite when B + b is tiny or negative.
    width = ((B + b.unsqueeze(1)) / 4.0).clamp(min=0.1)

    # Amplitude scaled by occupancy and the (pi / width)^(3/2) prefactor.
    weight = A * occ.unsqueeze(1) * (pi / width).pow(1.5)

    # Broadcast to (N_atoms, N_voxels, N_components) and evaluate the
    # Gaussian fall-off for every component at once.
    r_sq = diff_coords_squared.unsqueeze(2)
    falloff = torch.exp(-pi_sq * r_sq / width.unsqueeze(1))

    # Collapse the component axis.
    return (weight.unsqueeze(1) * falloff).sum(dim=2)


@torch.jit.script
def fused_aniso_gaussian_density(
    diff_coords: torch.Tensor,
    U_matrix: torch.Tensor,
    A: torch.Tensor,
    occ: torch.Tensor,
) -> torch.Tensor:
    """
    Fused anisotropic Gaussian density calculation.

    For every atom, voxel and Gaussian component the quadratic form
    ``q = r^T U r`` is evaluated and the density is
    ``sum_k A_k * occ * exp(-2 * pi^2 * q_k)``.

    Parameters
    ----------
    diff_coords : torch.Tensor
        Distance vectors, shape (N_atoms, N_voxels, 3).
    U_matrix : torch.Tensor
        Anisotropic U tensors for each Gaussian component,
        shape (N_atoms, 4, 3, 3).
    A : torch.Tensor
        ITC92 A parameters, shape (N_atoms, 4).
    occ : torch.Tensor
        Occupancies, shape (N_atoms,).

    Returns
    -------
    torch.Tensor
        Computed density values, shape (N_atoms, N_voxels).

    Notes
    -----
    No normalization prefactor is applied in this function.
    NOTE(review): presumably any such scaling is folded into ``A`` or
    ``U_matrix`` by the caller -- confirm.
    """
    two_pi_sq = 19.7392088022

    # Row vectors r^T as 1x3 matrices: (N_atoms, N_voxels, 1, 1, 3).
    # Two singleton axes are required: one broadcasts over the component
    # axis of U, the other is the matrix row dimension for matmul. With
    # only one, matmul would try to broadcast the voxel axis against the
    # component axis and fail (or silently mis-align them).
    r_row = diff_coords.unsqueeze(2).unsqueeze(3)

    # Column vectors r as 3x1 matrices: (N_atoms, N_voxels, 1, 3, 1).
    r_col = diff_coords.unsqueeze(2).unsqueeze(-1)

    # One U per component, broadcast over voxels: (N_atoms, 1, C, 3, 3).
    U_expanded = U_matrix.unsqueeze(1)

    # r^T U -> (N_atoms, N_voxels, C, 1, 3)
    rT_U = torch.matmul(r_row, U_expanded)

    # (r^T U) r -> (N_atoms, N_voxels, C, 1, 1) -> (N_atoms, N_voxels, C)
    quad_form = torch.matmul(rT_U, r_col).squeeze(-1).squeeze(-1)

    # Gaussian for every component.
    gaussian = torch.exp(-two_pi_sq * quad_form)

    # Weight by amplitude and occupancy, then sum over components.
    A_occ = A * occ.unsqueeze(1)  # (N_atoms, C)
    density = torch.sum(A_occ.unsqueeze(1) * gaussian, dim=2)

    return density


def warmup_cuda_operations(device: str = "cuda") -> None:
    """
    Warm up CUDA kernels to avoid lazy loading overhead.

    Runs a set of dummy operations (elementwise math, reductions, matmul,
    FFT and ``index_add_``) so that kernel compilation/loading happens here
    rather than on the first real call. Call this once after moving the
    model to the GPU.

    Parameters
    ----------
    device : str
        Device to warm up. Default is "cuda". A ``torch.device`` is accepted
        as well; indexed devices such as "cuda:1" are synchronized
        individually. Nothing is done for a CPU device.
    """
    # Normalize so that "cpu", torch.device("cpu") and indexed CUDA devices
    # are all handled uniformly.
    target = torch.device(device)
    if target.type == "cpu":
        return

    # Dummy tensors with shapes that exercise broadcasting.
    dummy_a = torch.randn(1000, 100, device=target)
    dummy_b = torch.randn(1000, 100, device=target)
    dummy_c = torch.randn(1000, device=target)
    dummy_d = torch.randn(1000, 5, device=target)

    # Trigger common operations.
    _ = dummy_a + dummy_b
    _ = dummy_a * dummy_c.unsqueeze(1)
    _ = torch.exp(dummy_a)
    _ = torch.sum(dummy_a, dim=1)
    _ = dummy_a / dummy_b.clamp(min=0.1)
    _ = torch.matmul(dummy_d, dummy_d.T)

    # FFT operations.
    dummy_3d = torch.randn(64, 64, 64, device=target, dtype=torch.complex64)
    _ = torch.fft.fftn(dummy_3d)
    _ = torch.fft.ifftn(dummy_3d)

    # Scatter operations: accumulate values into a flattened 100^3 map
    # using row-major linear indices.
    dummy_map = torch.zeros(100, 100, 100, device=target)
    dummy_indices = torch.randint(0, 100, (1000, 3), device=target)
    dummy_values = torch.randn(1000, device=target)
    dummy_map.view(-1).index_add_(
        0,
        dummy_indices[:, 0] * 10000 + dummy_indices[:, 1] * 100 + dummy_indices[:, 2],
        dummy_values,
    )

    # Wait for the warmed-up device specifically (not just the current CUDA
    # device); skip for non-CUDA accelerators, where the CUDA API does not
    # apply.
    if target.type == "cuda":
        torch.cuda.synchronize(target)
@torch.jit.script
def compute_smallest_diff_squared(
    diff: torch.Tensor, inv_frac_matrix: torch.Tensor, frac_matrix: torch.Tensor
) -> torch.Tensor:
    """
    Squared length of difference vectors under periodic boundary conditions.

    The vectors are mapped to fractional coordinates, shifted by the nearest
    integer along each axis (leaving components in [-0.5, 0.5]), mapped back
    to Cartesian coordinates and reduced to a squared norm, all inside one
    scripted function.

    Parameters
    ----------
    diff : torch.Tensor
        Difference vectors in Cartesian coordinates, last axis of size 3.
    inv_frac_matrix : torch.Tensor
        Inverse fractionalization matrix, shape (3, 3). Its transpose is
        right-multiplied onto ``diff`` to obtain fractional coordinates.
    frac_matrix : torch.Tensor
        Fractionalization matrix, shape (3, 3). Its transpose is
        right-multiplied onto the wrapped fractional vectors to return to
        Cartesian coordinates.

    Returns
    -------
    torch.Tensor
        Squared distances with periodic boundary conditions applied; the
        last axis of ``diff`` is reduced away.
    """
    # NOTE(review): the matrix names read as if the roles were the other way
    # round (inverse -> Cartesian); the usage below is what the code does --
    # confirm the naming convention against the callers.
    frac = diff.matmul(inv_frac_matrix.T)

    # Subtract the nearest lattice translation.
    nearest_image = frac - frac.round()

    cart = nearest_image.matmul(frac_matrix.T)
    return (cart * cart).sum(dim=-1)
class CachedRadiusMask:
    """
    Cache of integer voxel offsets that lie within a radius of the origin.

    Avoids recomputing the offset grid for every atom batch when the voxel
    dimensions, radius and device are unchanged.

    Examples
    --------
    >>> cache = CachedRadiusMask()
    >>> offsets = cache.get_offsets(voxel_size, radius_angstrom, device)

    Attributes
    ----------
    _cache : dict
        Maps ``(device, radius_angstrom, rounded voxel dimensions)`` to the
        computed offset tensor.
    """

    def __init__(self):
        self._cache = {}

    def get_offsets(
        self, voxel_size: torch.Tensor, radius_angstrom: float, device: torch.device
    ) -> torch.Tensor:
        """
        Get the cached offset grid for the given parameters.

        Parameters
        ----------
        voxel_size : torch.Tensor
            Voxel dimensions, shape (3,).
        radius_angstrom : float
            Radius in Angstroms.
        device : torch.device
            Device for the output tensor.

        Returns
        -------
        torch.Tensor
            Voxel offsets within radius, shape (N_voxels, 3). The same tensor
            object is returned on repeated calls with an equal key, so
            callers should not modify it in place.
        """
        # The mask below scales each axis by its own voxel dimension, so the
        # key must carry all of them. Keying on the minimum alone would hand
        # back offsets computed for a differently shaped voxel.
        voxel_dims = tuple(round(v, 6) for v in voxel_size.detach().tolist())
        key = (device, radius_angstrom, voxel_dims)

        if key not in self._cache:
            # Half-width of the search box in voxels, sized by the finest
            # axis so the box covers the radius along every axis.
            min_box_radius = int(torch.ceil(radius_angstrom / voxel_size.min()).item())

            # Integer offset grid, shape (n, n, n, 3) with n = 2 * radius + 1.
            axis = torch.arange(-min_box_radius, min_box_radius + 1, device=device)
            x, y, z = torch.meshgrid(axis, axis, axis, indexing="ij")
            coords = torch.stack((x, y, z), dim=-1)

            # Distance of every offset from the origin, with each axis
            # scaled by its voxel dimension.
            distance_map = torch.sqrt(
                torch.sum((coords * voxel_size.unsqueeze(0)) ** 2, dim=-1)
            )
            within_radius_mask = distance_map <= radius_angstrom

            self._cache[key] = coords[within_radius_mask].contiguous()

        return self._cache[key]
# Module-level cache instance, created once at import time and used by
# get_cached_radius_offsets().
_radius_mask_cache: CachedRadiusMask = CachedRadiusMask()
def get_cached_radius_offsets(
    voxel_size: torch.Tensor, radius_angstrom: float, device: torch.device
) -> torch.Tensor:
    """
    Look up radius offsets in the module-level cache.

    Thin convenience wrapper that forwards to ``get_offsets`` on the
    module's ``_radius_mask_cache`` instance, so repeated requests with the
    same parameters can reuse an earlier result.

    Parameters
    ----------
    voxel_size : torch.Tensor
        Voxel dimensions, shape (3,).
    radius_angstrom : float
        Radius in Angstroms.
    device : torch.device
        Device for the output tensor.

    Returns
    -------
    torch.Tensor
        Voxel offsets within radius, shape (N_voxels, 3).
    """
    offsets = _radius_mask_cache.get_offsets(
        voxel_size=voxel_size,
        radius_angstrom=radius_angstrom,
        device=device,
    )
    return offsets
def vectorized_add_to_map_optimized(
    surrounding_coords: torch.Tensor,
    voxel_indices: torch.Tensor,
    map: torch.Tensor,
    xyz: torch.Tensor,
    b: torch.Tensor,
    inv_frac_matrix: torch.Tensor,
    frac_matrix: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    occ: torch.Tensor,
) -> torch.Tensor:
    """
    Add isotropic atomic densities to a map using the fused Gaussian kernel.

    Described as a drop-in replacement for ``vectorized_add_to_map``: the
    per-voxel density is computed with ``fused_gaussian_density`` to reduce
    kernel launches and then scattered into the map.

    Parameters
    ----------
    surrounding_coords : torch.Tensor
        Cartesian coordinates of voxels, shape (N_atoms, N_voxels, 3).
    voxel_indices : torch.Tensor
        Indices of voxels in the map, shape (N_atoms, N_voxels, 3).
    map : torch.Tensor
        Electron density map to update, shape (nx, ny, nz).
    xyz : torch.Tensor
        Atom positions in Cartesian coordinates, shape (N_atoms, 3).
    b : torch.Tensor
        Isotropic B-factors, shape (N_atoms,).
    inv_frac_matrix : torch.Tensor
        Inverse fractionalization matrix, shape (3, 3).
    frac_matrix : torch.Tensor
        Fractionalization matrix, shape (3, 3).
    A : torch.Tensor
        ITC92 amplitude coefficients, shape (N_atoms, 5).
    B : torch.Tensor
        ITC92 width coefficients, shape (N_atoms, 5).
    occ : torch.Tensor
        Atomic occupancies, shape (N_atoms,).

    Returns
    -------
    torch.Tensor
        Updated electron density map, as returned by ``scatter_add_nd``.
    """
    # Imported locally, as in the original implementation.
    from torchref.base.coordinates import smallest_diff
    from torchref.base.electron_density import scatter_add_nd

    # Atom-to-voxel displacement for every voxel around every atom.
    displacement = surrounding_coords - xyz.unsqueeze(1)

    # NOTE(review): the result is passed on as squared distances, so
    # smallest_diff presumably returns squared minimum-image distances of
    # shape (N_atoms, N_voxels) -- confirm against torchref.base.coordinates.
    r_squared = smallest_diff(displacement, inv_frac_matrix, frac_matrix)

    # Per-atom, per-voxel density from the fused kernel.
    per_voxel_density = fused_gaussian_density(r_squared, B, b, A, occ)

    # Flatten values and their (i, j, k) targets, then accumulate into the map.
    return scatter_add_nd(
        per_voxel_density.flatten(),
        voxel_indices.reshape(-1, 3),
        map,
    )