# Source code for torchref.base.electron_density.solvent_mask

"""
Solvent mask generation functions.

Functions for creating solvent masks and identifying void regions
in crystallographic unit cells.
"""

import numpy as np
import torch

from torchref.config import dtypes
from torchref.base.coordinates.periodic_boundary import smallest_diff
from .map_building import scatter_add_nd


def add_to_solvent_mask(
    surrounding_coords, voxel_indices, mask, xyz, radius, inv_frac_matrix, frac_matrix
):
    """
    Create solvent mask by placing spheres around atom positions.

    Every voxel whose minimum-image distance to its atom is at most ``radius``
    is switched on; voxels that were already non-zero in ``mask`` stay on.

    Parameters
    ----------
    surrounding_coords : torch.Tensor
        Coordinates of voxels around each atom of shape (N_atoms, N_voxels, 3).
    voxel_indices : torch.Tensor
        Indices of voxels in the map of shape (N_atoms, N_voxels, 3).
    mask : torch.Tensor
        Solvent mask to be updated of shape (nx, ny, nz).
    xyz : torch.Tensor
        Atom positions of shape (N_atoms, 3).
    radius : float
        Radius of the sphere around each atom in Angstroms.
    inv_frac_matrix : torch.Tensor
        Inverse fractionalization matrix of shape (3, 3).
    frac_matrix : torch.Tensor
        Fractionalization matrix of shape (3, 3).

    Returns
    -------
    torch.Tensor
        Updated solvent mask as boolean tensor (True wherever any count,
        pre-existing or newly added, is non-zero).
    """
    # NOTE(review): if ``mask`` already has dtype ``dtypes.int``, ``.to`` returns
    # the same tensor; whether scatter_add_nd then updates it in place depends
    # on that helper -- confirm before relying on the caller's tensor.
    mask = mask.to(dtype=dtypes.int)

    # The result of smallest_diff is treated as a squared minimum-image
    # distance for each (atom, voxel) pair: it is compared with radius**2.
    diff_coords_squared = smallest_diff(
        surrounding_coords - xyz.unsqueeze(1), inv_frac_matrix, frac_matrix
    )
    within_sphere = diff_coords_squared <= radius**2  # (N_atoms, N_voxels)

    values_to_add = within_sphere.to(dtype=mask.dtype).flatten()
    # Index tensors use torch.long, the same dtype add_to_phenix_mask passes
    # to scatter_add_nd and PyTorch's canonical index dtype.
    voxel_indices_flat = voxel_indices.reshape(-1, 3).to(torch.long)

    mask = scatter_add_nd(values_to_add, voxel_indices_flat, mask)

    # Counts may exceed 1 where spheres overlap; any non-zero count means
    # "inside", so a direct bool cast yields the binary mask.
    return mask.to(torch.bool)
def add_to_phenix_mask(
    surrounding_coords,
    voxel_indices,
    xyz,
    vdw_radii,
    solvent_radius,
    inv_frac_matrix,
    frac_matrix,
    grid_shape,
    device,
):
    """
    Build the two binary masks behind a Phenix-style three-valued mask.

    All atom/voxel pairs are classified at once (vectorized):

    - core: distance < VdW radius of the atom
    - shell: VdW radius <= distance < VdW radius + ``solvent_radius``

    Each classification is scattered onto its own grid, and a voxel is set
    when at least one pair flags it. The two masks are computed independently,
    so a voxel may appear in both (e.g. core of one atom, shell of another).

    The three-valued mask is meant to be assembled from them as:

    - 0: protein_mask is set (protein core)
    - -1: boundary_mask is set and protein_mask is not (accessible surface)
    - 1: neither mask is set (bulk solvent)

    Parameters
    ----------
    surrounding_coords : torch.Tensor
        Fractional coordinates of voxels around each atom,
        shape (N_atoms, N_voxels, 3).
    voxel_indices : torch.Tensor
        Grid indices of those voxels, shape (N_atoms, N_voxels, 3).
    xyz : torch.Tensor
        Atom positions in fractional coordinates, shape (N_atoms, 3).
    vdw_radii : torch.Tensor
        Per-atom VdW radius in Angstroms, shape (N_atoms,).
    solvent_radius : float
        Probe radius in Angstroms added to each VdW radius.
    inv_frac_matrix : torch.Tensor
        Inverse fractional matrix used for distances, shape (3, 3).
    frac_matrix : torch.Tensor
        Fractional matrix used for distances, shape (3, 3).
    grid_shape : tuple
        Output mask shape (nx, ny, nz).
    device : torch.device
        Device on which the output grids are allocated.

    Returns
    -------
    protein_mask : torch.Tensor
        Boolean core mask of shape ``grid_shape``.
    boundary_mask : torch.Tensor
        Boolean accessible-surface mask of shape ``grid_shape``.
    """
    int_dtype = dtypes.int

    def _scatter_to_binary_grid(per_pair_flags):
        """Scatter-add per-pair flags onto a zeroed grid; True where count > 0."""
        counts = torch.zeros(grid_shape, dtype=int_dtype, device=device)
        counts = scatter_add_nd(
            per_pair_flags.flatten().to(dtype=int_dtype), flat_indices, counts
        )
        return counts > 0

    # Minimum-image distance from every atom to each of its surrounding voxels;
    # smallest_diff's result is square-rooted, i.e. used as a squared distance.
    squared = smallest_diff(
        surrounding_coords - xyz.unsqueeze(1), inv_frac_matrix, frac_matrix
    )
    dist = torch.sqrt(squared)  # (N_atoms, N_voxels)

    inner = vdw_radii.unsqueeze(1)  # (N_atoms, 1), broadcasts over voxels
    outer = inner + solvent_radius

    core_flags = dist < inner
    shell_flags = (dist >= inner) & (dist < outer)

    flat_indices = voxel_indices.reshape(-1, 3).to(torch.long)

    # Core grid is built first, then the shell grid (tuple items evaluate left-to-right).
    return _scatter_to_binary_grid(core_flags), _scatter_to_binary_grid(shell_flags)
def find_solvent_voids(mask, periodic=True):
    """
    Identify void regions in a 3D boolean grid using connected component analysis.

    A void is a connected region of False values (solvent). With periodic
    boundary conditions the grid is treated as a 3-torus, so voids may wrap
    across any face, edge or corner of the array (as in a crystallographic
    unit cell). Without periodic boundaries, only voids that do not touch the
    array boundary are reported.

    Parameters
    ----------
    mask : torch.Tensor or numpy.ndarray
        Tensor/array of shape (nx, ny, nz) where True (non-zero) indicates
        solid regions (e.g., protein) and False (zero) indicates empty regions
        (e.g., solvent). Non-boolean input is converted with ``astype(bool)``.
    periodic : bool, optional
        If True, apply periodic boundary conditions (voids can wrap around
        edges). If False, only report voids that are completely enclosed and
        don't touch the boundaries. Default is True.

    Returns
    -------
    dict
        Maps the volume (number of voxels) of each void to a boolean mask of
        the input's shape that is True only on that void. Masks are
        ``torch.bool`` tensors on the input's device for torch input and
        NumPy boolean arrays otherwise. Voids are visited in increasing
        component-label order; the first void of volume ``v`` is stored under
        the int key ``v`` and later voids with the same volume get string keys
        ``"v_1"``, ``"v_2"``, ... Returns an empty dict if no voids are found.

    Raises
    ------
    ValueError
        If ``mask`` is not 3-dimensional.

    Examples
    --------
    ::

        import torch
        # Create a simple 5x5x5 grid with a void in the center
        mask = torch.ones(5, 5, 5, dtype=torch.bool)
        mask[2, 2, 2] = False  # Single void voxel
        voids = find_solvent_voids(mask)
        print(list(voids))
        [1]

    Notes
    -----
    - Uses scipy.ndimage.label for connected component analysis.
    - Connectivity is 26-connected (face, edge, and corner neighbors).
    - With periodic=True, components are labelled on the original grid and
      then merged whenever two of them are 26-neighbours across a periodic
      seam (see ``_merge_periodic_labels``).
    - Labelling and merging are vectorised and O(n) in the number of voxels;
      building the output costs one full-size comparison per returned void.
    - With periodic boundaries, large percolating voids are still detected.
    """
    from scipy import ndimage

    is_torch = isinstance(mask, torch.Tensor)
    if is_torch:
        device = mask.device
        # scipy works on NumPy arrays; detach so tensors that require grad convert.
        mask_np = mask.detach().cpu().numpy()
    else:
        mask_np = np.asarray(mask)

    # ``~`` below is a logical NOT only for bool arrays. On integer input it is
    # a bitwise NOT (~0 == -1, ~1 == -2), which would mark every voxel as solvent.
    mask_np = mask_np.astype(bool, copy=False)
    if mask_np.ndim != 3:
        raise ValueError(
            f"mask must be 3-dimensional (nx, ny, nz), got shape {mask_np.shape}"
        )

    # Label the False (solvent) voxels with 26-connectivity.
    structure = ndimage.generate_binary_structure(3, 3)
    labeled_array, num_features = ndimage.label(~mask_np, structure=structure)

    if num_features == 0:
        return {}

    if periodic:
        labeled_array, num_features = _merge_periodic_labels(
            labeled_array, num_features
        )
        excluded = set()
    else:
        # Any component with a voxel on one of the six faces is not enclosed.
        faces = (
            labeled_array[0, :, :],
            labeled_array[-1, :, :],
            labeled_array[:, 0, :],
            labeled_array[:, -1, :],
            labeled_array[:, :, 0],
            labeled_array[:, :, -1],
        )
        excluded = set(
            np.unique(np.concatenate([face.ravel() for face in faces])).tolist()
        )

    volumes = np.bincount(labeled_array.ravel(), minlength=num_features + 1)

    voids_dict = {}
    for label_id in range(1, num_features + 1):
        if label_id in excluded:
            continue

        volume = int(volumes[label_id])
        void_mask = labeled_array == label_id
        if is_torch:
            void_mask = torch.from_numpy(void_mask).to(device=device, dtype=torch.bool)

        # First void of a given volume keeps the int key; later ones get
        # "<volume>_<n>" string keys so no void is overwritten.
        key = volume
        counter = 1
        while key in voids_dict:
            key = f"{volume}_{counter}"
            counter += 1
        voids_dict[key] = void_mask

    return voids_dict


def _merge_periodic_labels(labeled, num_features):
    """
    Merge component labels that are 26-neighbours across periodic boundaries.

    Parameters
    ----------
    labeled : numpy.ndarray
        Integer label grid of shape (nx, ny, nz) as produced by
        ``scipy.ndimage.label`` (0 = background, 1..num_features = components).
    num_features : int
        Number of labels in ``labeled``.

    Returns
    -------
    tuple of (numpy.ndarray, int)
        The relabelled grid and the number of merged components K. When a
        merge happens, ids are consecutive 1..K numbered by the C-order
        position of each merged component's first voxel. When nothing needs
        merging, the input labels are returned unchanged.
    """
    from scipy.sparse import csr_matrix
    from scipy.sparse.csgraph import connected_components

    # One offset from each +/- pair of the 26 neighbour offsets. Adjacency is
    # symmetric, so the mirrored offsets would only repeat the same edges.
    offsets = [
        (dx, dy, dz)
        for dx in (-1, 0, 1)
        for dy in (-1, 0, 1)
        for dz in (-1, 0, 1)
        if (dx, dy, dz) > (0, 0, 0)
    ]

    # np.roll wraps on every axis, so comparing the grid with a shifted copy
    # compares each voxel with its periodic neighbour. Non-wrapped neighbours
    # already share a label after ndimage.label, so differing non-zero pairs
    # only arise across the periodic seams (faces, edges and corners).
    edge_chunks = []
    for offset in offsets:
        neighbour = np.roll(labeled, shift=offset, axis=(0, 1, 2))
        touching = (labeled != 0) & (neighbour != 0) & (labeled != neighbour)
        if touching.any():
            edge_chunks.append(
                np.stack((labeled[touching], neighbour[touching]), axis=1)
            )

    if not edge_chunks:
        return labeled, num_features

    edges = np.unique(np.concatenate(edge_chunks), axis=0)
    graph = csr_matrix(
        (np.ones(len(edges), dtype=np.int8), (edges[:, 0], edges[:, 1])),
        shape=(num_features + 1, num_features + 1),
    )
    _, component_of = connected_components(graph, directed=False)

    # Renumber merged components by their first voxel in C (row-major) order;
    # boolean indexing yields elements in that order.
    foreground = labeled != 0
    merged = component_of[labeled[foreground]]
    roots, first_seen = np.unique(merged, return_index=True)
    ordered_roots = roots[np.argsort(first_seen)]

    new_id = np.zeros(component_of.max() + 1, dtype=labeled.dtype)
    new_id[ordered_roots] = np.arange(1, ordered_roots.size + 1, dtype=labeled.dtype)

    relabeled = np.zeros_like(labeled)
    relabeled[foreground] = new_id[merged]
    return relabeled, int(ordered_roots.size)