torchref.base.electron_density.solvent_mask module

Solvent mask generation functions.

Functions for creating solvent masks and identifying void regions in crystallographic unit cells.

torchref.base.electron_density.solvent_mask.add_to_solvent_mask(surrounding_coords, voxel_indices, mask, xyz, radius, inv_frac_matrix, frac_matrix)

Update a solvent mask by marking the voxels that fall within a sphere of the given radius around each atom position.

Parameters:
  • surrounding_coords (torch.Tensor) – Coordinates of voxels around each atom of shape (N_atoms, N_voxels, 3).

  • voxel_indices (torch.Tensor) – Indices of voxels in the map of shape (N_atoms, N_voxels, 3).

  • mask (torch.Tensor) – Solvent mask of shape (nx, ny, nz) to be updated.

  • xyz (torch.Tensor) – Atom positions of shape (N_atoms, 3).

  • radius (float) – Radius of the sphere around each atom in Angstroms.

  • inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix of shape (3, 3).

  • frac_matrix (torch.Tensor) – Fractionalization matrix of shape (3, 3).

Returns:

The updated solvent mask as a boolean tensor.

Return type:

torch.Tensor
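A minimal sketch of the sphere-placement step, for illustration only. The name sphere_mask_sketch is hypothetical and this is not the library's implementation. It assumes that surrounding_coords and xyz are Cartesian (Angstrom), that frac_matrix maps Cartesian to fractional coordinates in the column-vector convention (so row vectors are multiplied by its transpose), that inv_frac_matrix maps back, and that True marks voxels inside an atom sphere. Displacements are wrapped with the minimum-image convention before the distance test.

```python
import torch


def sphere_mask_sketch(surrounding_coords, voxel_indices, mask, xyz, radius,
                       inv_frac_matrix, frac_matrix):
    # ASSUMPTION: Cartesian inputs; frac_matrix: Cartesian -> fractional,
    # inv_frac_matrix: fractional -> Cartesian (column-vector convention).
    diff = surrounding_coords - xyz[:, None, :]        # (N_atoms, N_voxels, 3)
    diff_frac = diff @ frac_matrix.T                   # displacements in fractional units
    diff_frac = diff_frac - torch.round(diff_frac)     # minimum-image wrap into [-0.5, 0.5]
    diff_cart = diff_frac @ inv_frac_matrix.T          # back to Cartesian
    inside = (diff_cart ** 2).sum(dim=-1) <= radius ** 2   # (N_atoms, N_voxels)
    idx = voxel_indices[inside]                        # (K, 3) grid indices inside a sphere
    out = mask.clone()                                 # leave the caller's mask untouched
    out[idx[:, 0], idx[:, 1], idx[:, 2]] = True        # ASSUMPTION: True = atom/non-solvent
    return out
```

This sketch writes into a clone, so the input mask is left unchanged.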

torchref.base.electron_density.solvent_mask.add_to_phenix_mask(surrounding_coords, voxel_indices, xyz, vdw_radii, solvent_radius, inv_frac_matrix, frac_matrix, grid_shape, device)

Create a Phenix-style three-valued mask by placing spheres around atom positions.

This is a vectorized implementation that processes all atoms and voxels at once. It creates two binary masks:

  • protein_mask: 1 where a voxel lies inside the VdW radius of an atom (protein core).

  • boundary_mask: 1 where a voxel lies between the VdW radius and VdW + solvent_radius (accessible surface).

The final three-valued mask is derived from these two masks:

  • 0: protein_mask == 1 (protein core).

  • -1: boundary_mask == 1 and protein_mask == 0 (accessible surface).

  • 1: both masks == 0 (bulk solvent).
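The combination rule above can be written directly with boolean indexing. combine_phenix_masks is a hypothetical helper, not part of the documented API; it assumes both inputs are boolean tensors of the same shape.

```python
import torch


def combine_phenix_masks(protein_mask, boundary_mask):
    # Hypothetical helper: 0 = protein core, -1 = accessible surface, 1 = bulk solvent.
    out = torch.ones_like(protein_mask, dtype=torch.int8)   # start as bulk solvent
    out[boundary_mask & ~protein_mask] = -1                  # surface, but not core
    out[protein_mask] = 0                                    # core wins over everything
    return out
```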

Parameters:
  • surrounding_coords (torch.Tensor) – Fractional coordinates of voxels around each atom of shape (N_atoms, N_voxels, 3).

  • voxel_indices (torch.Tensor) – Grid indices of voxels in the map of shape (N_atoms, N_voxels, 3).

  • xyz (torch.Tensor) – Atom positions in fractional coordinates of shape (N_atoms, 3).

  • vdw_radii (torch.Tensor) – VdW radius for each atom in Angstroms of shape (N_atoms,).

  • solvent_radius (float) – Probe radius in Angstroms (added to VdW to get accessible surface).

  • inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix, used for distance calculations, of shape (3, 3).

  • frac_matrix (torch.Tensor) – Fractionalization matrix, used for distance calculations, of shape (3, 3).

  • grid_shape (tuple) – Shape of the output mask (nx, ny, nz).

  • device (torch.device) – Device for tensor operations.

Returns:

  • protein_mask (torch.Tensor) – Boolean mask for protein core of shape grid_shape.

  • boundary_mask (torch.Tensor) – Boolean mask for accessible surface of shape grid_shape.
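The vectorized distance test can be sketched as below. phenix_masks_sketch is a hypothetical name and this is not the library's implementation: it omits frac_matrix and assumes fractional inputs, with inv_frac_matrix mapping fractional displacements to Cartesian (Angstrom) after the minimum-image convention is applied.

```python
import torch


def phenix_masks_sketch(surrounding_coords, voxel_indices, xyz, vdw_radii,
                        solvent_radius, inv_frac_matrix, grid_shape, device):
    # ASSUMPTION: fractional inputs; inv_frac_matrix maps fractional -> Cartesian.
    diff = surrounding_coords - xyz[:, None, :]             # (N_atoms, N_voxels, 3)
    diff = diff - torch.round(diff)                         # minimum-image wrap
    dist = torch.linalg.norm(diff @ inv_frac_matrix.T, dim=-1)   # distances in Angstrom
    r_vdw = vdw_radii[:, None]                              # broadcast over voxels
    in_core = dist <= r_vdw
    in_shell = (dist > r_vdw) & (dist <= r_vdw + solvent_radius)
    protein_mask = torch.zeros(grid_shape, dtype=torch.bool, device=device)
    boundary_mask = torch.zeros(grid_shape, dtype=torch.bool, device=device)
    core_idx = voxel_indices[in_core]                       # (K_core, 3)
    shell_idx = voxel_indices[in_shell]                     # (K_shell, 3)
    protein_mask[core_idx[:, 0], core_idx[:, 1], core_idx[:, 2]] = True
    boundary_mask[shell_idx[:, 0], shell_idx[:, 1], shell_idx[:, 2]] = True
    return protein_mask, boundary_mask
```

Because each atom contributes independently, a voxel can land in the core of one atom and the shell of another; that is why the three-valued rule treats the accessible surface as boundary_mask == 1 and protein_mask == 0.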

torchref.base.electron_density.solvent_mask.find_solvent_voids(mask, periodic=True)

Identify void regions in a 3D boolean tensor using connected component analysis.

A void is defined as a connected region of False values (solvent). With periodic boundary conditions, voids can wrap around the edges of the array (as in a crystallographic unit cell). Without periodic boundaries, only voids that are fully enclosed are detected.

Parameters:
  • mask (torch.Tensor or numpy.ndarray) – Boolean tensor of shape (nx, ny, nz) where True indicates solid regions (e.g., protein) and False indicates empty regions (e.g., solvent). Can be either a PyTorch tensor or a NumPy array.

  • periodic (bool, optional) – If True, apply periodic boundary conditions (voids can wrap around edges). If False, only detect voids that are completely enclosed and don’t touch the boundaries. Default is True.

Returns:

A dictionary mapping the volume of each void (its number of voxels in the original array, as an int) to a boolean mask (torch.Tensor or numpy.ndarray) with the same shape as the input that is True only within that void. Returns an empty dict if no voids are found.

Return type:

dict
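For periodic=False, the behaviour described above (keep only voids that do not touch the array boundary) can be sketched with scipy.ndimage.label. enclosed_voids_sketch is a hypothetical name; it is an illustration of that rule, not the library's implementation, and it returns a list of masks instead of a volume-keyed dict.

```python
import numpy as np
from scipy import ndimage


def enclosed_voids_sketch(solid):
    # Empty (False) voxels are the candidate void regions.
    empty = ~np.asarray(solid, dtype=bool)
    structure = np.ones((3, 3, 3), dtype=bool)          # 26-connectivity
    labels, n = ndimage.label(empty, structure=structure)
    # Any label that appears on one of the six faces touches the boundary.
    faces = [labels[0], labels[-1], labels[:, 0], labels[:, -1],
             labels[:, :, 0], labels[:, :, -1]]
    touching = set(np.unique(np.concatenate([f.ravel() for f in faces])).tolist())
    # Keep only fully enclosed components.
    return [labels == lab for lab in range(1, n + 1) if lab not in touching]
```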

Examples

import torch
from torchref.base.electron_density.solvent_mask import find_solvent_voids

# Create a simple 5x5x5 grid with a void in the center
mask = torch.ones(5, 5, 5, dtype=torch.bool)
mask[2, 2, 2] = False  # Single void voxel
voids = find_solvent_voids(mask)
print(voids)

{1: tensor([[[False, False, …]], dtype=torch.bool)}

Notes

  • Uses scipy.ndimage.label for connected component analysis.

  • Connectivity is 26-connected (face, edge, and corner neighbors).

  • With periodic=True, the array is padded by wrapping to detect cross-boundary voids.

  • Time complexity is O(n), where n is the total number of voxels.

  • With periodic boundaries, large voids that percolate through the entire cell are still detected.
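Periodic connected-component labeling can be sketched as follows. The Notes say the function pads the array by wrapping; this sketch swaps in a different, commonly used technique instead: label without periodicity, then merge labels that touch across the boundary by comparing each voxel with its 26 neighbours through np.roll (which wraps) and joining them with scipy.sparse.csgraph.connected_components. periodic_voids_sketch is a hypothetical name, and it returns a list rather than a volume-keyed dict so that voids of equal volume do not overwrite each other.

```python
import numpy as np
from scipy import ndimage
from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import connected_components


def periodic_voids_sketch(solid):
    """Return one boolean mask per void (connected False region), with periodic wrap."""
    empty = ~np.asarray(solid, dtype=bool)
    structure = np.ones((3, 3, 3), dtype=bool)            # 26-connectivity
    labels, n = ndimage.label(empty, structure=structure)
    if n == 0:
        return []
    # Collect pairs of labels that are neighbours once the edges wrap around.
    src, dst = [], []
    for dx in (-1, 0, 1):
        for dy in (-1, 0, 1):
            for dz in (-1, 0, 1):
                if (dx, dy, dz) == (0, 0, 0):
                    continue
                shifted = np.roll(labels, (dx, dy, dz), axis=(0, 1, 2))
                link = (labels > 0) & (shifted > 0) & (labels != shifted)
                src.append(labels[link])
                dst.append(shifted[link])
    src = np.concatenate(src)
    dst = np.concatenate(dst)
    # Graph over labels 0..n; label 0 (solid) has no edges and stays isolated.
    graph = coo_matrix((np.ones(src.size), (src, dst)), shape=(n + 1, n + 1))
    _, comp = connected_components(graph, directed=False)
    merged = comp[labels]                                  # merged component id per voxel
    return [merged == c for c in np.unique(merged[empty])]
```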