Source code for torchref.base.electron_density.voxel_utils

"""
Voxel utility functions for electron density calculations.

Functions for finding relevant voxels around atoms for map building.
"""

import torch

from torchref.config import dtypes


def find_relevant_voxels(real_space_grid, xyz, radius_angstrom=4, inv_frac_matrix=None):
    """
    Locate the grid voxels that lie within a spherical radius of each atom.

    Each atom is first snapped to its nearest grid index. The same spherical
    neighbourhood of voxel offsets is then gathered around every atom, so
    all atoms receive the same number of voxels.

    Parameters
    ----------
    real_space_grid : torch.Tensor
        Grid of Cartesian xyz coordinates, of shape (nx, ny, nz, 3).
    xyz : torch.Tensor
        Cartesian atom coordinates, of shape (N, 3) or (3,).
    radius_angstrom : float, optional
        Sphere radius in Angstroms. Default is 4.
    inv_frac_matrix : torch.Tensor, optional
        (3, 3) matrix mapping Cartesian to fractional coordinates. When
        given, atoms are located on the grid through their fractional
        coordinates, which is required for non-orthogonal cells. When
        omitted, an axis-aligned spacing is estimated from
        ``real_space_grid[3, 3, 3] - real_space_grid[2, 2, 2]``. That
        estimate is only appropriate for orthogonal grids and needs at
        least 4 points per axis.

    Returns
    -------
    surrounding_coords : torch.Tensor
        Grid coordinates of the selected voxels, of shape (N, R, 3).
    voxel_indices_wrapped : torch.Tensor
        Grid indices of the selected voxels, wrapped into the grid bounds,
        of shape (N, R, 3).

    Notes
    -----
    The input ``xyz`` is not modified or wrapped. Fractional coordinates are
    wrapped into [0, 1) only to pick the nearest grid index. Minimum-image
    distances are left to downstream code (``smallest_diff``).
    """
    atoms = xyz.unsqueeze(0) if xyz.ndim == 1 else xyz
    n_voxels = torch.tensor(real_space_grid.shape[:3], device=atoms.device)

    if inv_frac_matrix is None:
        # Orthogonal fallback: convert Cartesian offsets from the first grid
        # point into (fractional) voxel units using a per-axis spacing.
        origin = real_space_grid[0, 0, 0]
        spacing = real_space_grid[3, 3, 3] - real_space_grid[2, 2, 2]
        voxel_units = (atoms - origin.unsqueeze(0)) / spacing.unsqueeze(0)
    else:
        # Cartesian -> fractional, wrapped into the unit cell, then scaled by
        # the grid dimensions so that atoms outside the cell map correctly.
        fractional = (inv_frac_matrix @ atoms.T).T % 1.0
        voxel_units = fractional * n_voxels.unsqueeze(0)
    nearest_voxel = torch.round(voxel_units).to(torch.int64)

    voxel_indices_wrapped = excise_angstrom_radius_around_coord(
        real_space_grid, nearest_voxel, radius_angstrom
    )

    # Gather the Cartesian coordinate stored at every selected voxel.
    ix, iy, iz = voxel_indices_wrapped.unbind(dim=-1)
    surrounding_coords = real_space_grid[ix, iy, iz]
    return surrounding_coords, voxel_indices_wrapped
def excise_angstrom_radius_around_coord(
    real_space_grid, start_indices, radius_angstrom=4.0
):
    """
    Identify voxel indices within an Angstrom radius around specified grid positions.

    Distances are measured in true Cartesian space using the grid's per-axis
    step vectors. The selected neighbourhood is therefore a sphere for both
    orthogonal and non-orthogonal (skewed) grids.

    Parameters
    ----------
    real_space_grid : torch.Tensor
        Real space grid of shape (nx, ny, nz, 3) containing xyz coordinates.
        Assumed to be a uniform grid: moving one index along an axis always
        adds the same Cartesian step vector. At least 4 points per axis are
        required, because the step vectors are sampled between indices 2
        and 3.
    start_indices : torch.Tensor
        Starting grid indices of shape (N, 3) or (3,).
    radius_angstrom : float, optional
        Radius in Angstroms. Default is 4.0.

    Returns
    -------
    torch.Tensor
        Wrapped voxel indices of shape (N, R, 3). R is the number of voxel
        offsets whose Cartesian length is at most ``radius_angstrom``; it is
        the same for every start index. Offsets appear in lexicographic
        (i, j, k) order.

    Notes
    -----
    Periodic boundary conditions are handled by wrapping the indices modulo
    the grid shape, so every returned index is a valid array index.
    """
    if start_indices.ndim == 1:
        start_indices = start_indices.unsqueeze(0)
    device = start_indices.device
    grid_shape = torch.tensor(real_space_grid.shape[:3], device=device)

    # Rows of `steps` are the Cartesian step vectors along the i, j and k
    # axes, so an index offset (i, j, k) sits at (i, j, k) @ steps. They are
    # sampled between indices 2 and 3 so that, on an orthogonal grid, each
    # step equals the matching component of grid[3,3,3] - grid[2,2,2]
    # exactly.
    anchor = real_space_grid[2, 2, 2]
    steps = torch.stack(
        (
            real_space_grid[3, 2, 2] - anchor,
            real_space_grid[2, 3, 2] - anchor,
            real_space_grid[2, 2, 3] - anchor,
        )
    )

    # Half-width of the smallest cube of index offsets that encloses the
    # sphere. The largest index reachable along axis a within radius r is
    # r * ||column a of steps^-1|| (r / spacing_a on an orthogonal grid).
    # A skewed grid can need a much wider box than r / spacing, and a
    # diagonal estimate can even go negative. The inverse is taken in
    # float64 so the bound does not depend on the grid's storage precision.
    inv_steps = torch.linalg.inv(steps.to(torch.float64))
    reach = torch.linalg.vector_norm(inv_steps, dim=0).max()
    box_radius = int(torch.ceil(radius_angstrom * reach).item())

    offsets_1d = torch.arange(-box_radius, box_radius + 1, device=device)
    x, y, z = torch.meshgrid(offsets_1d, offsets_1d, offsets_1d, indexing="ij")
    coords = torch.stack((x, y, z), dim=-1)

    # True Cartesian length of every candidate offset.
    cartesian_offsets = coords.to(steps.dtype) @ steps
    distance_map = torch.sqrt(torch.sum(cartesian_offsets**2, dim=-1))
    within_radius_mask = distance_map <= radius_angstrom
    local_offsets = coords[within_radius_mask]  # (R, 3), lexicographic order

    voxel_indices = local_offsets.unsqueeze(0) + start_indices.unsqueeze(1)
    # torch's `%` takes the sign of the divisor, so negative indices wrap
    # into [0, n).
    return voxel_indices % grid_shape.unsqueeze(0).unsqueeze(0)