Source code for torchref.symmetry.map_symmetry

"""
Map-level symmetry operations for electron density maps.

This module provides efficient symmetry operations applied directly to density maps,
which is much faster than applying symmetry to individual atoms.

This module uses a factory pattern: calling MapSymmetry() will automatically
return either MapSymmetryDirect (fast, no interpolation) or MapSymmetryInterpolation
(fallback with interpolation) depending on grid compatibility.

Space groups can be specified as strings, integers (1-230), or gemmi.SpaceGroup objects.
"""

import torch
import torch.nn as nn

from torchref.config import get_default_device, get_float_dtype
from torchref.symmetry.spacegroup import SpaceGroup, SpaceGroupLike
from torchref.utils.device_mixin import DeviceMixin


def MapSymmetry(
    space_group: SpaceGroupLike,
    map_shape,
    cell_params,
    dtype_float=None,
    verbose=1,
    device=None,
):
    """
    Factory function to create the appropriate MapSymmetry implementation.

    This function checks if the grid size is compatible with direct indexing
    (no interpolation needed). If compatible, returns MapSymmetryDirect for
    maximum performance. Otherwise, returns MapSymmetryInterpolation as
    fallback.

    Parameters
    ----------
    space_group : str, int, or gemmi.SpaceGroup
        Space group specification (e.g., 'P21', 4, gemmi.SpaceGroup('P 21')).
    map_shape : tuple of int
        Shape of the density map (nx, ny, nz).
    cell_params : torch.Tensor, shape (6,)
        Unit cell parameters [a, b, c, alpha, beta, gamma] in Å and degrees.
        Forwarded unchanged to the selected implementation.
    dtype_float : torch.dtype, optional
        Floating point precision to use. ``None`` (the default) means the
        value returned by ``get_float_dtype()`` at call time.
    verbose : int, default 1
        Verbosity level (0=silent, 1=info, 2=debug). Any value > 0 prints
        which implementation was selected.
    device : torch.device, optional
        Device to use for computation. ``None`` (the default) means the
        value returned by ``get_default_device()`` at call time.

    Returns
    -------
    MapSymmetryDirect or MapSymmetryInterpolation
        The appropriate implementation based on grid compatibility.
    """
    # Resolve defaults on every call. Calling the config getters in the
    # signature would freeze their values at import time, so a later change
    # of the configured dtype/device would be silently ignored.
    if dtype_float is None:
        dtype_float = get_float_dtype()
    if device is None:
        device = get_default_device()

    # Check grid compatibility
    symmetry = SpaceGroup(space_group, dtype=dtype_float, device=device)
    compat = symmetry.check_grid_compatibility(map_shape)

    if compat["can_use_direct_indexing"]:
        # Use fast direct indexing implementation
        if verbose > 0:
            print(
                f"MapSymmetry: Using direct indexing (no interpolation) for {space_group}"
            )
        return MapSymmetryDirect(
            space_group, map_shape, cell_params, dtype_float, verbose, device
        )

    # Use interpolation-based fallback
    if verbose > 0:
        print("MapSymmetry: Grid not compatible with direct indexing")
        print(f" Using interpolation-based fallback for {space_group}")
        if compat["issues"]:
            for issue in compat["issues"]:
                print(f" - {issue}")
        # The suggestion is only needed for the diagnostic message.
        suggested = symmetry.suggest_grid_size(map_shape, make_fft_friendly=True)
        print(f" Suggested grid for direct indexing: {suggested}")

    # Imported lazily, only when the fallback is actually needed.
    from torchref.symmetry.map_symmetry_interpolation import (
        MapSymmetry as MapSymmetryInterpolation,
    )

    return MapSymmetryInterpolation(
        space_group, map_shape, cell_params, dtype_float, verbose, device
    )
class MapSymmetryDirect(DeviceMixin, nn.Module):
    """
    Fast direct-indexing implementation of crystallographic symmetry operations.

    Computes symmetry mates one operation at a time (streaming) so that
    memory usage is O(grid) regardless of the number of symmetry operations,
    rather than O(n_ops * grid) for storing precomputed index grids.

    NOTE: Do not instantiate this class directly. Use the MapSymmetry()
    factory function instead, which will automatically select the
    appropriate implementation.
    """

    def __init__(
        self,
        space_group,
        map_shape,
        cell_params,
        dtype_float=None,
        verbose=1,
        device=None,
    ):
        """Set up the symmetry operators for a fixed map shape.

        Parameters
        ----------
        space_group : str, int, or gemmi.SpaceGroup
            Space group specification, passed on to ``SpaceGroup``.
        map_shape : sequence of int
            Grid shape (nx, ny, nz); stored as a tuple.
        cell_params : torch.Tensor
            Unit cell parameters. Stored on the instance but not read by
            any method of this class.
        dtype_float : torch.dtype, optional
            Floating point precision. ``None`` (the default) means the value
            returned by ``get_float_dtype()`` at construction time.
        verbose : int, default 1
            Any value > 0 prints a short summary after initialization.
        device : torch.device, optional
            Device for the symmetry operators. ``None`` (the default) means
            the value returned by ``get_default_device()`` at construction
            time.
        """
        super().__init__()

        # Resolve defaults per call; evaluating the config getters in the
        # signature would freeze their values at import time.
        if dtype_float is None:
            dtype_float = get_float_dtype()
        if device is None:
            device = get_default_device()

        self.dtype_float = dtype_float
        self.space_group = space_group
        self.map_shape = tuple(map_shape)
        self.cell_params = cell_params
        self.verbose = verbose
        self.device = device

        self.symmetry = SpaceGroup(
            space_group, dtype=self.dtype_float, device=self.device
        )
        self.n_ops = self.symmetry.matrices.shape[0]
        self.can_use_direct_indexing = True

        if self.verbose > 0:
            print(f"MapSymmetryDirect initialized for {space_group}")
            print(f" Number of symmetry operations: {self.n_ops}")
            print(f" Map shape: {self.map_shape}")

    # ------------------------------------------------------------------
    # Core: compute index grid for a single symmetry operation
    # ------------------------------------------------------------------
    def _compute_index_grid(self, op_index: int) -> torch.Tensor:
        """Compute the integer index grid for one symmetry operation.

        Returns shape (nx, ny, nz, 3) int64. The result is temporary and
        not stored as a buffer.
        """
        nx, ny, nz = self.map_shape
        device = self.symmetry.matrices.device

        # Fractional coordinates of every grid point along each axis.
        fx = torch.arange(nx, dtype=self.dtype_float, device=device) / nx
        fy = torch.arange(ny, dtype=self.dtype_float, device=device) / ny
        fz = torch.arange(nz, dtype=self.dtype_float, device=device) / nz
        gx, gy, gz = torch.meshgrid(fx, fy, fz, indexing="ij")
        grid_flat = torch.stack([gx, gy, gz], dim=-1).reshape(-1, 3)

        # x' = R @ x + t, then wrap into [0, 1).
        transformed = torch.matmul(self.symmetry.matrices[op_index], grid_flat.T).T
        transformed = transformed + self.symmetry.translations[op_index]
        transformed = transformed - torch.floor(transformed)

        # Back to grid units; rounding snaps to the nearest integer index.
        shape_t = torch.tensor([nx, ny, nz], dtype=self.dtype_float, device=device)
        indices = torch.round(transformed * shape_t).to(torch.int64)

        # A value that rounds up to exactly n must wrap around to index 0.
        indices[:, 0] %= nx
        indices[:, 1] %= ny
        indices[:, 2] %= nz
        return indices.reshape(nx, ny, nz, 3)

    def _gather_mate(self, density_map, op_index):
        """Gather ``density_map`` at the index grid of operation ``op_index``."""
        ig = self._compute_index_grid(op_index)
        return density_map[ig[..., 0], ig[..., 1], ig[..., 2]]

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def get_symmetry_mate(self, density_map, operation_index):
        """Apply a single symmetry operation via direct indexing.

        Raises
        ------
        ValueError
            If ``operation_index`` is outside ``[0, n_ops - 1]`` or the map
            shape differs from the shape given at construction.
        """
        if operation_index < 0 or operation_index >= self.n_ops:
            raise ValueError(
                f"Operation index {operation_index} out of range [0, {self.n_ops-1}]"
            )
        if density_map.shape != self.map_shape:
            raise ValueError(
                f"Map shape {density_map.shape} doesn't match expected {self.map_shape}"
            )
        return self._gather_mate(density_map, operation_index)

    def forward(self, density_map, apply_symmetry=True, combine_mode="sum"):
        """Apply symmetry operations to density map.

        Computes one symmetry mate at a time and accumulates into the
        result, so peak memory is only 1 index grid + 2 density maps
        regardless of the number of symmetry operations.

        Parameters
        ----------
        density_map : torch.Tensor
            Map to symmetrize; indexed along its first three dimensions.
        apply_symmetry : bool, default True
            If False, the input is returned unchanged.
        combine_mode : {"sum", "max"}, default "sum"
            How the symmetry mates are combined: element-wise sum or
            element-wise maximum.

        Returns
        -------
        torch.Tensor
            The combined map, or ``density_map`` itself when symmetry is
            disabled or only one operation exists.

        Raises
        ------
        ValueError
            If ``combine_mode`` is neither "sum" nor "max".
        """
        # With a single operation (presumably the identity) the input is
        # returned as-is, without indexing.
        if not apply_symmetry or self.n_ops == 1:
            return density_map

        # Validate the mode before building any index grid so an invalid
        # argument fails without doing O(grid) work first.
        if combine_mode == "sum":
            combine = torch.add
        elif combine_mode == "max":
            combine = torch.max
        else:
            raise ValueError(
                f"Unknown combine_mode: {combine_mode}. Use 'sum' or 'max'."
            )

        result = self._gather_mate(density_map, 0)
        for op_index in range(1, self.n_ops):
            result = combine(result, self._gather_mate(density_map, op_index))
        return result

    def __call__(self, density_map, apply_symmetry=True, combine_mode="sum"):
        """Make the class callable like a PyTorch module.

        Delegates to the inherited ``__call__`` rather than invoking
        ``forward`` directly, so hooks registered on the module still run.
        """
        return super().__call__(
            density_map, apply_symmetry=apply_symmetry, combine_mode=combine_mode
        )

    def get_symmetry_info(self):
        """Get information about symmetry operations.

        Returns a dict with keys "space_group", "n_operations", "matrices"
        and "translations".
        """
        return {
            "space_group": self.space_group,
            "n_operations": self.n_ops,
            "matrices": self.symmetry.matrices,
            "translations": self.symmetry.translations,
        }

    def __repr__(self):
        return (
            f"MapSymmetryDirect(space_group='{self.space_group}', "
            f"n_ops={self.n_ops}, map_shape={self.map_shape})"
        )