torchref.base.electron_density package
Electron density map building functions.
This submodule provides functions for:
- Building electron density maps from atomic models
- Finding relevant voxels around atoms
- Solvent mask generation
- Scatter operations for map building
- torchref.base.electron_density.vectorized_add_to_map(surrounding_coords, voxel_indices, density_map, xyz, b, inv_frac_matrix, frac_matrix, A, B, occ)
Add atoms to density map using ITC92 Gaussian parameterization.
Automatically selects the optimal implementation based on the device. On GPU, the default is the fused Triton kernel (3-6x faster; it falls back to the JIT kernel if Triton is unavailable). To override, set the environment variable TORCHREF_ATOM_PLACEMENT_GPU_MODE to jit or simple.
- Parameters:
surrounding_coords (torch.Tensor) – Cartesian coordinates of voxels, shape (N_atoms, N_voxels, 3).
voxel_indices (torch.Tensor) – Indices of voxels in the map, shape (N_atoms, N_voxels, 3).
density_map (torch.Tensor) – Electron density map to update, shape (nx, ny, nz).
xyz (torch.Tensor) – Atom positions in Cartesian coordinates, shape (N_atoms, 3).
b (torch.Tensor) – Isotropic B-factors, shape (N_atoms,).
inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix, shape (3, 3).
frac_matrix (torch.Tensor) – Fractionalization matrix, shape (3, 3).
A (torch.Tensor) – ITC92 amplitude coefficients, shape (N_atoms, 5).
B (torch.Tensor) – ITC92 width coefficients, shape (N_atoms, 5).
occ (torch.Tensor) – Atomic occupancies, shape (N_atoms,).
- Returns:
Updated electron density map (modified in-place).
- Return type:
torch.Tensor
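As a rough illustration of the per-voxel arithmetic, the sketch below evaluates the isotropic ITC92 density of one atom at a set of squared distances, following the convention written out for vectorized_add_to_map_aniso (B_total = (b_k + B)/4 for each Gaussian term k). It assumes the five terms are summed and that occupancy scales the result linearly. `iso_density_sketch` is a hypothetical helper, not part of torchref, and it skips both the minimum-image distance and the scatter into the map.

```python
import math
import torch

def iso_density_sketch(r2, b_atom, A, B, occ):
    """Hypothetical helper: ITC92 isotropic density at squared distances r2.

    r2:     squared distances from the atom centre, shape (N_voxels,), in A^2.
    b_atom: isotropic B-factor of the atom (float, A^2).
    A, B:   ITC92 amplitude and width coefficients, shape (5,).
    occ:    occupancy (float).
    """
    # B_total per Gaussian term: (b_k + B_atom) / 4, shape (5,)
    b_total = (B + b_atom) / 4.0
    # Normalisation (pi / B_total)^(3/2), shape (5,)
    norm = (math.pi / b_total) ** 1.5
    # exp(-pi^2 r^2 / B_total), broadcast to (N_voxels, 5)
    gauss = torch.exp(-(math.pi ** 2) * r2[:, None] / b_total[None, :])
    # Sum over the five Gaussian terms and scale by occupancy
    return occ * (A * norm * gauss).sum(dim=-1)
```

Because each normalized Gaussian integrates to one, summing the returned density over a fine grid and multiplying by the voxel volume should give roughly occ × ΣA_k, which makes a handy sanity check.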
- torchref.base.electron_density.vectorized_add_to_map_aniso(surrounding_coords, voxel_indices, map, xyz, U, inv_frac_matrix, frac_matrix, A, B, occ)
Add anisotropic atoms to density map using ITC92 Gaussian parameterization.
Uses the same convention as the isotropic case for consistency:
- B_total = (B_itc92 + B_atomic) / 4
- rho = A × (π/B_total)^(3/2) × exp(-π² r² / B_total)
For anisotropic atoms, this generalizes to:
- B_atomic_ij = 8π² × U_atomic_ij (standard crystallographic conversion)
- B_total_ij = (B_itc92 × δ_ij + 8π² × U_atomic_ij) / 4
- Normalization: (π³ / det(B_total))^(1/2)
- Exponent: exp(-π² × r^T × B_total^(-1) × r)
- Parameters:
surrounding_coords (torch.Tensor) – Coordinates of the voxels around each atom, shape (N_atoms, N_voxels, 3).
voxel_indices (torch.Tensor) – Indices of the voxels in the map, shape (N_atoms, N_voxels, 3).
map (torch.Tensor) – Electron density map, shape (nx, ny, nz).
xyz (torch.Tensor) – Atom positions in Cartesian coordinates, shape (N_atoms, 3).
U (torch.Tensor) – Anisotropic displacement parameters (u11, u22, u33, u12, u13, u23) in Angstroms squared, shape (N_atoms, 6).
inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix, shape (3, 3).
frac_matrix (torch.Tensor) – Fractionalization matrix, shape (3, 3).
A (torch.Tensor) – ITC92 amplitude coefficients for each atom, shape (N_atoms, 5).
B (torch.Tensor) – ITC92 width coefficients (b parameters) in Angstroms squared for each atom, shape (N_atoms, 5).
occ (torch.Tensor) – Occupancy of each atom, shape (N_atoms,).
- Returns:
Updated electron density map.
- Return type:
torch.Tensor
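The anisotropic formulas above can be checked with a small sketch that builds B_total for each ITC92 term and evaluates the normalized Gaussian. `aniso_density_sketch` is hypothetical; it assumes U is given in Cartesian Angstroms squared in the (u11, u22, u33, u12, u13, u23) order listed above, that the five terms are summed, and that occupancy scales the result.

```python
import math
import torch

def aniso_density_sketch(r, U6, A, B, occ):
    """Hypothetical helper: anisotropic ITC92 density at displacement vectors r.

    r:    Cartesian displacements (voxel - atom), shape (N_voxels, 3), in A.
    U6:   (u11, u22, u33, u12, u13, u23) in A^2, shape (6,).
    A, B: ITC92 amplitude and width coefficients, shape (5,).
    occ:  occupancy (float).
    """
    u11, u22, u33, u12, u13, u23 = U6
    # Symmetric 3x3 displacement tensor
    U = torch.stack([
        torch.stack([u11, u12, u13]),
        torch.stack([u12, u22, u23]),
        torch.stack([u13, u23, u33]),
    ])
    eye = torch.eye(3, dtype=r.dtype)
    # B_total_k = (b_k * I + 8 pi^2 U) / 4, shape (5, 3, 3)
    b_total = (B[:, None, None] * eye + 8 * math.pi ** 2 * U) / 4.0
    # Normalisation (pi^3 / det(B_total_k))^(1/2), shape (5,)
    norm = torch.sqrt(math.pi ** 3 / torch.linalg.det(b_total))
    b_inv = torch.linalg.inv(b_total)
    # Quadratic form r^T B_total_k^{-1} r for every voxel and term, shape (N, 5)
    quad = torch.einsum("ni,kij,nj->nk", r, b_inv, r)
    return occ * (A * norm * torch.exp(-(math.pi ** 2) * quad)).sum(dim=-1)
```

For an isotropic U (u11 = u22 = u33 = u, off-diagonal terms zero) the result should reduce to the isotropic formula with B = 8π²u, which is a useful consistency test.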
- torchref.base.electron_density.scatter_add_nd(source, index, map)
Vectorized n-dimensional scatter add operation.
- Parameters:
source (torch.Tensor) – Values to add, shape (N,).
index (torch.Tensor) – Indices at which each value is added, shape (N, ndim).
map (torch.Tensor) – N-dimensional tensor of shape (d1, d2, …, dn) to add values into.
- Returns:
Modified map with values added.
- Return type:
torch.Tensor
- torchref.base.electron_density.scatter_add_nd_super_slow(source, index, map)
Non-vectorized n-dimensional scatter add operation (slow reference implementation).
- Parameters:
source (torch.Tensor) – Values to add, shape (N,).
index (torch.Tensor) – Indices at which each value is added, shape (N, ndim).
map (torch.Tensor) – N-dimensional tensor to add values into.
- Returns:
Modified map with values added.
- Return type:
torch.Tensor
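Both scatter-add variants above must accumulate values at repeated indices, which matters because neighboring atoms share voxels. A minimal sketch of each, assuming index holds integer coordinates with one column per map dimension; the helper names are hypothetical. The vectorized version relies on torch.Tensor.index_put_ with accumulate=True, since plain advanced-index assignment (out[idx] += src) does not sum duplicates.

```python
import torch

def scatter_add_nd_sketch(source, index, out):
    """Hypothetical vectorised n-D scatter-add (accumulates duplicates)."""
    # Split the (N, ndim) index into ndim 1-D index tensors and let
    # index_put_ sum every value that lands on the same position.
    out.index_put_(tuple(index.unbind(dim=1)), source, accumulate=True)
    return out

def scatter_add_nd_loop(source, index, out):
    """Hypothetical slow reference: one Python-level addition per value."""
    for value, idx in zip(source, index):
        out[tuple(idx.tolist())] += value
    return out
```

The loop version is useful only as a correctness oracle for the vectorized one on small inputs.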
- torchref.base.electron_density.find_relevant_voxels(real_space_grid, xyz, radius_angstrom=4, inv_frac_matrix=None)
Identify the voxels surrounding each atom in a real space grid.
This is a vectorized function that finds all voxels within a spherical radius around each atom position.
- Parameters:
real_space_grid (torch.Tensor) – Real space grid holding the xyz coordinates of each grid point, shape (nx, ny, nz, 3).
xyz (torch.Tensor) – Atom coordinates in real space (Cartesian coordinates), shape (N, 3) or (3,).
radius_angstrom (float, optional) – Radius around each atom in Angstroms. Default is 4.
inv_frac_matrix (torch.Tensor, optional) – Matrix converting Cartesian to fractional coordinates, shape (3, 3). Required for proper handling of non-orthogonal cells.
- Returns:
surrounding_coords (torch.Tensor) – Coordinates of the surrounding voxels for each atom, shape (N, R, 3), where R is the number of voxels within the radius.
voxel_indices_wrapped (torch.Tensor) – Wrapped voxel indices, shape (N, R, 3).
Notes
Atom coordinates are NOT wrapped here; periodic boundary conditions are handled in smallest_diff(), which computes the minimum-image distance. Only the voxel indices are wrapped, so that they are valid array indices.
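A sketch of the minimum-image idea mentioned in the Notes, assuming the column-vector convention frac = inv_frac_matrix @ cart (so batched row vectors are multiplied by the transpose). `min_image_diff` and `voxels_within_radius` are hypothetical helpers, not smallest_diff() itself. Rounding in fractional space gives the exact minimum image for orthogonal cells and is a common approximation for mildly oblique ones.

```python
import torch

def min_image_diff(diff, inv_frac_matrix, frac_matrix):
    """Hypothetical minimum-image helper for Cartesian differences (..., 3).

    Assumes frac = inv_frac_matrix @ cart and cart = frac_matrix @ frac.
    """
    frac = diff @ inv_frac_matrix.T      # Cartesian -> fractional
    frac = frac - torch.round(frac)      # wrap each component into [-0.5, 0.5]
    return frac @ frac_matrix.T          # fractional -> Cartesian

def voxels_within_radius(grid_coords, atom, radius, inv_frac_matrix, frac_matrix):
    """Hypothetical: boolean mask of grid points within `radius` A of `atom`."""
    d = min_image_diff(grid_coords - atom, inv_frac_matrix, frac_matrix)
    return (d ** 2).sum(-1) <= radius ** 2
```

Because the atom position itself is never wrapped, the same helper works whether or not the input coordinates lie inside the unit cell.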
- torchref.base.electron_density.excise_angstrom_radius_around_coord(real_space_grid, start_indices, radius_angstrom=4.0)
Identify voxel indices within an Angstrom radius around specified grid positions.
- Parameters:
real_space_grid (torch.Tensor) – Real space grid of shape (nx, ny, nz, 3) containing xyz coordinates.
start_indices (torch.Tensor) – Starting grid indices of shape (N, 3) or (3,).
radius_angstrom (float, optional) – Radius in Angstroms. Default is 4.0.
- Returns:
Wrapped voxel indices of shape (N, R, 3), where R is the number of voxels within the radius.
- Return type:
torch.Tensor
Notes
Periodic boundary conditions are handled by wrapping the indices to ensure they’re valid array indices.
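The index-space part of this operation can be sketched as a fixed stencil of integer offsets that is added to every start index and wrapped with a modulo. This hypothetical `excise_sketch` assumes an orthogonal grid with a known per-axis voxel spacing in Angstroms, which is simpler than reading distances from real_space_grid as the documented function does.

```python
import math
import torch

def excise_sketch(start_indices, grid_shape, spacing, radius):
    """Hypothetical: wrapped voxel indices within `radius` A of each start index.

    Assumes an orthogonal grid with per-axis voxel spacing `spacing` (A).
    start_indices: shape (N, 3) or (3,). Returns shape (N, R, 3).
    """
    start = torch.as_tensor(start_indices).reshape(-1, 3)
    # Half-width of the bounding box in voxels along each axis
    half = [math.ceil(radius / s) for s in spacing]
    axes = [torch.arange(-h, h + 1) for h in half]
    offsets = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1).reshape(-1, 3)
    # Keep only offsets whose Cartesian length is within the radius
    step = torch.tensor(spacing, dtype=torch.float64)
    dist2 = ((offsets.double() * step) ** 2).sum(-1)
    offsets = offsets[dist2 <= radius ** 2]
    # Broadcast to (N, R, 3) and wrap periodically into valid array indices
    idx = start[:, None, :] + offsets[None, :, :]
    return torch.remainder(idx, torch.tensor(grid_shape))
```

torch.remainder takes the sign of the divisor, so negative offsets near index 0 wrap to the far side of the grid.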
- torchref.base.electron_density.add_to_solvent_mask(surrounding_coords, voxel_indices, mask, xyz, radius, inv_frac_matrix, frac_matrix)
Create a solvent mask by placing spheres around atom positions.
- Parameters:
surrounding_coords (torch.Tensor) – Coordinates of the voxels around each atom, shape (N_atoms, N_voxels, 3).
voxel_indices (torch.Tensor) – Indices of the voxels in the map, shape (N_atoms, N_voxels, 3).
mask (torch.Tensor) – Solvent mask to update, shape (nx, ny, nz).
xyz (torch.Tensor) – Atom positions, shape (N_atoms, 3).
radius (float) – Radius of the sphere around each atom in Angstroms.
inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix, shape (3, 3).
frac_matrix (torch.Tensor) – Fractionalization matrix, shape (3, 3).
- Returns:
Updated solvent mask as a boolean tensor.
- Return type:
torch.Tensor
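A minimal sketch of the sphere-placement step, assuming squared atom-voxel distances have already been computed (for example with a minimum-image difference) and that True marks voxels inside an atomic sphere; the documentation above does not state the mask polarity, and `add_to_mask_sketch` is hypothetical.

```python
import torch

def add_to_mask_sketch(mask, voxel_indices, dist2, radius):
    """Hypothetical: set mask voxels within `radius` A of any atom to True.

    mask: bool tensor (nx, ny, nz); voxel_indices: (N_atoms, N_voxels, 3);
    dist2: squared atom-voxel distances, shape (N_atoms, N_voxels).
    """
    inside = dist2 <= radius ** 2
    hits = voxel_indices[inside]          # (M, 3) indices inside some sphere
    # Repeated indices are harmless: every write stores the same value (True)
    mask[hits[:, 0], hits[:, 1], hits[:, 2]] = True
    return mask
```

Only True values are written, so overlapping spheres that hit the same voxel cannot overwrite each other with conflicting results.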
- torchref.base.electron_density.add_to_phenix_mask(surrounding_coords, voxel_indices, xyz, vdw_radii, solvent_radius, inv_frac_matrix, frac_matrix, grid_shape, device)
Create the two boolean masks that define a Phenix-style three-valued mask by placing spheres around atom positions.
This is a vectorized implementation that processes all atoms and voxels at once. It creates two binary masks:
- protein_mask: True inside the VdW radius (protein core)
- boundary_mask: True between the VdW radius and VdW + solvent_radius (accessible surface)
The final three-valued mask is derived from them as:
- 0: protein_mask is True (protein core)
- -1: boundary_mask is True and protein_mask is False (accessible surface)
- 1: both masks are False (bulk solvent)
- Parameters:
surrounding_coords (torch.Tensor) – Fractional coordinates of the voxels around each atom, shape (N_atoms, N_voxels, 3).
voxel_indices (torch.Tensor) – Grid indices of the voxels in the map, shape (N_atoms, N_voxels, 3).
xyz (torch.Tensor) – Atom positions in fractional coordinates, shape (N_atoms, 3).
vdw_radii (torch.Tensor) – VdW radius of each atom in Angstroms, shape (N_atoms,).
solvent_radius (float) – Probe radius in Angstroms (added to the VdW radius to get the accessible surface).
inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix used for distance calculations, shape (3, 3).
frac_matrix (torch.Tensor) – Fractionalization matrix used for distance calculations, shape (3, 3).
grid_shape (tuple) – Shape of the output mask (nx, ny, nz).
device (torch.device) – Device for tensor operations.
- Returns:
protein_mask (torch.Tensor) – Boolean mask for protein core of shape grid_shape.
boundary_mask (torch.Tensor) – Boolean mask for accessible surface of shape grid_shape.
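Combining the two returned masks into the three-valued convention listed above takes only a couple of masked assignments. A sketch with a hypothetical helper name; the int8 output dtype is an arbitrary choice.

```python
import torch

def three_valued_mask(protein_mask, boundary_mask):
    """Hypothetical: combine into 0 = core, -1 = surface, 1 = bulk solvent."""
    # Start with bulk solvent everywhere, on the same device as the inputs
    out = torch.ones(protein_mask.shape, dtype=torch.int8, device=protein_mask.device)
    out[boundary_mask & ~protein_mask] = -1   # accessible surface
    out[protein_mask] = 0                     # protein core takes precedence
    return out
```

Assigning the core last makes it win wherever both masks are True, matching the rule that -1 applies only where protein_mask is False.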
- torchref.base.electron_density.find_solvent_voids(mask, periodic=True)
Identify void regions in a 3D boolean tensor using connected component analysis.
A void is defined as a connected region of False values (solvent). With periodic boundary conditions, voids can wrap around the edges of the array (like in a crystallographic unit cell). Without periodic boundaries, only enclosed voids are detected.
- Parameters:
mask (torch.Tensor or numpy.ndarray) – Boolean tensor of shape (nx, ny, nz) where True indicates solid regions (e.g., protein) and False indicates empty regions (e.g., solvent). Can be either a PyTorch tensor or a NumPy array.
periodic (bool, optional) – If True, apply periodic boundary conditions (voids can wrap around edges). If False, only detect voids that are completely enclosed and don’t touch the boundaries. Default is True.
- Returns:
Dictionary whose keys are the int volumes (number of voxels) of each void in the original array and whose values are boolean masks (torch.Tensor or numpy.ndarray) of the same shape as the input, True only for that specific void region. Returns an empty dict if no voids are found.
- Return type:
dict
Examples
>>> import torch
>>> from torchref.base.electron_density import find_solvent_voids
>>> # Create a simple 5x5x5 grid with a void in the center
>>> mask = torch.ones(5, 5, 5, dtype=torch.bool)
>>> mask[2, 2, 2] = False  # Single void voxel
>>> voids = find_solvent_voids(mask)
>>> print(voids)
{1: tensor([[[False, False, …]], dtype=torch.bool)}
Notes
Uses scipy.ndimage.label for connected component analysis.
Connectivity is 26-connected (face, edge, and corner neighbors).
With periodic=True, the array is padded by wrapping to detect cross-boundary voids.
Performance is O(n) where n is the total number of voxels.
With periodic boundaries, large percolating voids are still detected.
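The periodic case can be illustrated with scipy.ndimage.label. Note that this sketch uses a different technique from the wrap-padding described in the Notes: it labels the solvent once, then merges labels that touch across each of the 26 periodic neighbor offsets (via np.roll) with a small union-find. `periodic_voids_sketch` is hypothetical and returns a list of (volume, mask) pairs, so two voids of equal volume stay distinct.

```python
import numpy as np
from scipy import ndimage

def periodic_voids_sketch(mask):
    """Hypothetical: 26-connected solvent components with periodic wrapping.

    mask: bool array, True = solid. Returns a list of (volume, bool mask)
    pairs, one per void, sorted by decreasing volume.
    """
    labels, n = ndimage.label(~mask, structure=np.ones((3, 3, 3), dtype=bool))
    parent = list(range(n + 1))               # union-find over component labels

    def find(a):
        while parent[a] != a:
            parent[a] = parent[parent[a]]     # path halving
            a = parent[a]
        return a

    # np.roll wraps across the cell faces, so comparing each voxel with its
    # rolled neighbour finds components that touch through the boundary.
    for dx in (-1, 0, 1):
        for dy in (-1, 0, 1):
            for dz in (-1, 0, 1):
                if (dx, dy, dz) == (0, 0, 0):
                    continue
                shifted = np.roll(labels, (dx, dy, dz), axis=(0, 1, 2))
                both = (labels > 0) & (shifted > 0) & (labels != shifted)
                for a, b in set(zip(labels[both].tolist(), shifted[both].tolist())):
                    ra, rb = find(a), find(b)
                    if ra != rb:
                        parent[ra] = rb
    roots = np.array([find(i) for i in range(n + 1)])
    merged = roots[labels]                    # relabel voxels by their root (0 stays solid)
    voids = [(int((merged == r).sum()), merged == r)
             for r in np.unique(merged[labels > 0])]
    return sorted(voids, key=lambda v: -v[0])
```

Two solvent voxels on opposite faces of the cell end up in one void, which is the behavior the periodic=True option describes.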
- torchref.base.electron_density.build_electron_density(real_space_grid, xyz_iso, adp_iso, occ_iso, A_iso, B_iso, inv_frac_matrix, frac_matrix, radius_angstrom, voxel_size, xyz_aniso=None, u_aniso=None, occ_aniso=None, A_aniso=None, B_aniso=None, dtype=torch.float32)
Build an electron density map from atomic parameters.
Selects the fastest available backend automatically. On CUDA, it first tries the fused Triton kernel (which skips the separate find_relevant_voxels step), then falls back to the two-step Triton path or the JIT kernel. On CPU, it uses the JIT kernel.
- Parameters:
real_space_grid (torch.Tensor) – Coordinate grid, shape (nx, ny, nz, 3).
xyz_iso (torch.Tensor) – Isotropic atom positions, shape (n_iso, 3).
adp_iso (torch.Tensor) – Isotropic B-factors, shape (n_iso,).
occ_iso (torch.Tensor) – Isotropic occupancies, shape (n_iso,).
A_iso (torch.Tensor) – ITC92 amplitude coefficients, shape (n_iso, 5).
B_iso (torch.Tensor) – ITC92 width coefficients, shape (n_iso, 5).
inv_frac_matrix (torch.Tensor) – Cartesian-to-fractional matrix, shape (3, 3).
frac_matrix (torch.Tensor) – Fractional-to-Cartesian matrix, shape (3, 3).
radius_angstrom (float) – Radius around each atom in Angstroms.
voxel_size (torch.Tensor) – Voxel dimensions, shape (3,).
xyz_aniso (torch.Tensor, optional) – Anisotropic atom positions, shape (n_aniso, 3).
u_aniso (torch.Tensor, optional) – Anisotropic U parameters, shape (n_aniso, 6).
occ_aniso (torch.Tensor, optional) – Anisotropic occupancies, shape (n_aniso,).
A_aniso (torch.Tensor, optional) – ITC92 amplitude coefficients for anisotropic atoms, shape (n_aniso, 5).
B_aniso (torch.Tensor, optional) – ITC92 width coefficients for anisotropic atoms, shape (n_aniso, 5).
dtype (torch.dtype, optional) – Float dtype for the density map. Default torch.float32.
- Returns:
Electron density map, shape (nx, ny, nz).
- Return type:
torch.Tensor
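The automatic backend selection described for build_electron_density behaves like an ordered fallback chain. The following is only an illustration of that pattern: `run_with_fallback`, the backend names, and the choice to catch ImportError and RuntimeError are assumptions, not torchref's actual dispatch code.

```python
from typing import Any, Callable, Sequence, Tuple

def run_with_fallback(
    backends: Sequence[Tuple[str, Callable[..., Any]]], *args: Any
) -> Tuple[str, Any]:
    """Hypothetical: try each (name, fn) pair in order, return the first success."""
    errors = []
    for name, fn in backends:
        try:
            return name, fn(*args)
        except (ImportError, RuntimeError) as exc:
            # Record why this backend failed, then try the next one
            errors.append(f"{name}: {exc}")
    raise RuntimeError("all backends failed: " + "; ".join(errors))
```

Ordering the list from fastest to most portable keeps the common path fast while still producing a map on machines without Triton.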