# Source code for torchref.base.electron_density.map_building

"""
Electron density map building functions.

Functions for adding atomic contributions to electron density maps
using ITC92 Gaussian parameterization.
"""

import numpy as np
import torch

from torchref.base.coordinates.periodic_boundary import (
    smallest_diff,
    smallest_diff_aniso,
)


def scatter_add_nd_super_slow(source, index, map):
    """
    Accumulate ``source`` into ``map`` one element at a time.

    Loop-based reference for :func:`scatter_add_nd`. It performs one
    Python-level indexing operation per element, so it is slow, but its
    behaviour is easy to audit, which makes it useful for cross-checking the
    vectorized version. Repeated indices accumulate.

    Parameters
    ----------
    source : torch.Tensor
        Values to accumulate, shape (N,).
    index : torch.Tensor
        Target positions, shape (N, ndim); row ``i`` is the n-dimensional
        index that receives ``source[i]``.
    map : torch.Tensor
        N-dimensional tensor, updated in place.

    Returns
    -------
    torch.Tensor
        The same ``map`` object after all values have been added.
    """
    for position, value in enumerate(source):
        # A tuple of ints addresses exactly one element; passing the index
        # row as a tensor would trigger advanced indexing instead.
        target = tuple(index[position].tolist())
        map[target] += value
    return map
def scatter_add_nd(source, index, map):
    """
    Vectorized n-dimensional scatter add operation.

    Adds ``source[i]`` to ``map[tuple(index[i])]`` for every ``i``; repeated
    indices accumulate. The update is done in place on ``map`` by converting
    each n-dimensional index to a row-major flat offset and calling
    ``scatter_add_`` on a 1-D view of the map.

    Parameters
    ----------
    source : torch.Tensor
        Values to add to the map of shape (N,). Must have the same dtype and
        device as ``map`` (``scatter_add_`` requirement).
    index : torch.Tensor
        Integer indices where values should be added of shape (N, ndim),
        with ``ndim == map.ndim``. Indices are NOT bounds-checked per axis:
        a value outside ``[0, d_k)`` on any axis other than the first lands
        silently in a different element, so callers must wrap indices
        beforehand.
    map : torch.Tensor
        N-dimensional tensor of shape (d1, d2, ..., dn) to add values into.
        It must be viewable as 1-D without a copy (e.g. contiguous), since
        the scatter writes through that view.

    Returns
    -------
    torch.Tensor
        ``map`` itself, modified in place.

    Raises
    ------
    ValueError
        If ``index`` does not have exactly one column per map dimension, or
        if the number of index rows differs from the number of source values.
    RuntimeError
        If ``scatter_add_`` fails (dtype/device mismatch, out-of-range flat
        offset, ...). The message includes shapes, devices and dtypes of all
        three operands, and the original error is chained as the cause.
    """
    ndim = map.dim()
    # Without this check an (N, 1) index would broadcast against the stride
    # vector and silently scatter to the wrong voxels.
    if index.dim() == 0 or index.shape[-1] != ndim:
        raise ValueError(
            "index must have one column per map dimension: got index shape "
            f"{tuple(index.shape)} for a map with {ndim} dimension(s)"
        )

    # Row-major (C-order) strides as plain Python ints: one small tensor is
    # created at the end instead of one device write per axis.
    strides = [1] * ndim
    for axis in range(ndim - 2, -1, -1):
        strides[axis] = strides[axis + 1] * map.shape[axis + 1]
    stride_tensor = torch.tensor(strides, device=index.device, dtype=torch.int64)

    # flat = i0 * (d1*...*dn) + i1 * (d2*...*dn) + ... + in
    # (int64 strides promote narrower integer index dtypes to int64.)
    index_flat = torch.sum(index * stride_tensor.unsqueeze(0), dim=-1)

    # scatter_add_ tolerates an index shorter than src and silently ignores
    # the surplus source values; treat any length mismatch as an error.
    if index_flat.shape != source.shape:
        raise ValueError(
            f"source shape {tuple(source.shape)} does not match the number of "
            f"index rows (flattened index shape {tuple(index_flat.shape)})"
        )

    # view (not reshape) so the in-place update writes through to ``map``.
    map_flat = map.view(-1)
    try:
        map_flat.scatter_add_(0, index_flat, source)
    except RuntimeError as err:
        raise RuntimeError(
            "scatter_add_nd failed: "
            f"source shape={tuple(source.shape)} device={source.device} "
            f"dtype={source.dtype}; "
            f"index shape={tuple(index.shape)} device={index.device} "
            f"dtype={index.dtype}; "
            f"map shape={tuple(map.shape)} device={map.device} dtype={map.dtype}"
        ) from err
    return map
def vectorized_add_to_map(
    surrounding_coords,
    voxel_indices,
    map,
    xyz,
    b,
    inv_frac_matrix,
    frac_matrix,
    A,
    B,
    occ,
):
    """
    Add isotropic atoms to a density map using ITC92 Gaussian coefficients.

    For every atom and every Gaussian component ``k`` the contribution at a
    voxel a distance ``r`` away is::

        B_total = (B_k + b_atom) / 4              (clamped to >= 0.1)
        rho_k   = occ * A_k * (pi / B_total)^(3/2) * exp(-pi^2 r^2 / B_total)

    The components are summed per voxel and accumulated into ``map`` with
    :func:`scatter_add_nd`, so overlapping atoms add up.

    Parameters
    ----------
    surrounding_coords : torch.Tensor
        Coordinates of voxels around each atom of shape (N_atoms, N_voxels, 3).
    voxel_indices : torch.Tensor
        Indices of voxels in the map of shape (N_atoms, N_voxels, 3).
    map : torch.Tensor
        Electron density map of shape (nx, ny, nz); updated in place.
    xyz : torch.Tensor
        Atom positions of shape (N_atoms, 3).
    b : torch.Tensor
        B-factors (thermal parameters) in Angstroms squared of shape (N_atoms,).
    inv_frac_matrix : torch.Tensor
        Inverse fractionalization matrix of shape (3, 3).
    frac_matrix : torch.Tensor
        Fractionalization matrix of shape (3, 3).
    A : torch.Tensor
        ITC92 amplitude coefficients, shape (N_atoms, n_gauss) (5 columns per
        atom in this module's convention).
    B : torch.Tensor
        ITC92 width coefficients in Angstroms squared, same shape as ``A``.
    occ : torch.Tensor
        Occupancies for each atom of shape (N_atoms,).

    Returns
    -------
    torch.Tensor
        The updated electron density map.
    """
    # Squared minimum-image distances atom -> voxel, shape (N_atoms, N_voxels);
    # smallest_diff already returns r^2, not a displacement vector.
    offsets = surrounding_coords - xyz.unsqueeze(1)
    r_squared = smallest_diff(offsets, inv_frac_matrix, frac_matrix)

    # Effective width per (atom, component), floored at 0.1 so the
    # normalization and the exponent stay finite.
    width = ((B + b.unsqueeze(1)) / 4).clamp(min=1e-1)  # (N_atoms, n_gauss)

    # Peak height of each component: occ * A_k * (pi / width)^(3/2).
    amplitude = A * occ.unsqueeze(1) * (np.pi / width) ** 1.5

    # Broadcast to (N_atoms, N_voxels, n_gauss) and evaluate every Gaussian.
    exponent = -(np.pi**2) * r_squared.unsqueeze(2) / width.unsqueeze(1)
    per_component = amplitude.unsqueeze(1) * torch.exp(exponent)

    # Sum across all Gaussian components, then flatten atoms x voxels.
    density = torch.sum(per_component, dim=2).flatten()
    flat_indices = voxel_indices.reshape(-1, 3)
    return scatter_add_nd(density, flat_indices, map)
def vectorized_add_to_map_aniso(
    surrounding_coords,
    voxel_indices,
    map,
    xyz,
    U,
    inv_frac_matrix,
    frac_matrix,
    A,
    B,
    occ,
):
    """
    Add anisotropic atoms to a density map using ITC92 Gaussian coefficients.

    This generalizes the isotropic rule ``B_total = (B_itc92 + B_atomic) / 4``
    to a 3x3 matrix per atom and Gaussian component::

        B_atomic_ij = 8 pi^2 U_ij
        M_ij        = (B_itc92 * delta_ij + B_atomic_ij) / 4
        rho         = occ * A * (pi^3 / det M)^(1/2) * exp(-pi^2 r^T M^-1 r)

    Only the diagonal of ``M`` is clamped (to >= 0.1), and ``det M`` is
    clamped to >= 1e-10.

    NOTE(review): neither clamp guarantees ``M`` is positive definite. A
    non-positive-definite ``U`` may therefore yield a very large
    normalization and a negative quadratic form, i.e. exp() > 1.

    Parameters
    ----------
    surrounding_coords : torch.Tensor
        Coordinates of voxels around each atom of shape (N_atoms, N_voxels, 3).
    voxel_indices : torch.Tensor
        Indices of voxels in the map of shape (N_atoms, N_voxels, 3).
    map : torch.Tensor
        Electron density map of shape (nx, ny, nz); updated in place.
    xyz : torch.Tensor
        Atom positions in Cartesian coordinates of shape (N_atoms, 3).
    U : torch.Tensor
        Anisotropic displacement parameters in Angstroms squared, ordered
        (u11, u22, u33, u12, u13, u23), shape (N_atoms, 6). Presumably in
        the same Cartesian frame as ``xyz`` -- confirm against callers.
    inv_frac_matrix : torch.Tensor
        Inverse fractionalization matrix of shape (3, 3).
    frac_matrix : torch.Tensor
        Fractionalization matrix of shape (3, 3).
    A : torch.Tensor
        ITC92 amplitude coefficients, shape (N_atoms, n_gauss) (5 columns per
        atom in this module's convention).
    B : torch.Tensor
        ITC92 width coefficients in Angstroms squared, same shape as ``A``.
    occ : torch.Tensor
        Occupancies for each atom of shape (N_atoms,).

    Returns
    -------
    torch.Tensor
        The updated electron density map.
    """
    # Minimum-image displacement vectors atom -> voxel, (N_atoms, N_voxels, 3).
    offsets = smallest_diff_aniso(
        surrounding_coords - xyz.unsqueeze(1), inv_frac_matrix, frac_matrix
    )
    n_atoms = B.shape[0]
    n_gauss = B.shape[1]

    # U -> B with the crystallographic factor 8 pi^2, split into diagonal
    # (u11, u22, u33) and off-diagonal (u12, u13, u23) parts.
    eight_pi_sq = 8 * np.pi**2
    b_atom_diag = eight_pi_sq * U[:, :3]  # (N_atoms, 3)
    b_atom_off = eight_pi_sq * U[:, 3:]  # (N_atoms, 3)

    # ITC92 widths are isotropic, so they only add to the diagonal.
    total_diag = (B.unsqueeze(2) + b_atom_diag.unsqueeze(1)) / 4
    total_diag = total_diag.clamp(min=0.1)  # (N_atoms, n_gauss, 3)
    total_off = b_atom_off.unsqueeze(1).expand(-1, n_gauss, -1) / 4

    # Assemble the symmetric width matrix; writing into a B-typed buffer
    # casts every entry to B's dtype.
    width = torch.zeros(n_atoms, n_gauss, 3, 3, device=B.device, dtype=B.dtype)
    for axis in range(3):
        width[:, :, axis, axis] = total_diag[:, :, axis]
    # Off-diagonal columns follow the (u12, u13, u23) layout of U.
    for slot, (row, col) in enumerate(((0, 1), (0, 2), (1, 2))):
        width[:, :, row, col] = total_off[:, :, slot]
        width[:, :, col, row] = total_off[:, :, slot]

    width_inv = torch.linalg.inv(width)  # (N_atoms, n_gauss, 3, 3)
    width_det = torch.linalg.det(width).clamp(min=1e-10)  # (N_atoms, n_gauss)

    # 3-D analogue of (pi / B)^(3/2), folded into the per-component amplitude.
    amplitude = A * occ.unsqueeze(1) * (np.pi**3 / width_det) ** 0.5

    # Quadratic form r^T M^-1 r; the size-1 axes broadcast inside einsum.
    r_col = offsets.unsqueeze(2)  # (N_atoms, N_voxels, 1, 3)
    inv_times_r = torch.einsum(
        "naijk,namk->naij", width_inv.unsqueeze(1), r_col
    )  # (N_atoms, N_voxels, n_gauss, 3)
    quad = torch.einsum("namk,namk->nam", r_col, inv_times_r)

    gaussians = torch.exp(-np.pi**2 * quad)  # (N_atoms, N_voxels, n_gauss)

    # Sum the components, flatten atoms x voxels, and accumulate.
    density = torch.sum(amplitude.unsqueeze(1) * gaussians, dim=2).flatten()
    return scatter_add_nd(density, voxel_indices.reshape(-1, 3), map)