# Source code for torchref.model.sf_fft

"""
SfFFT - Structure Factor calculation via FFT (Fast Fourier Transform).

This module provides a PyTorch nn.Module that handles:
- Grid setup for real-space electron density calculations
- Building electron density maps from atomic parameters
- FFT-based conversion to structure factors

The SfFFT class can be used standalone (stateless) or with stored grid state
(stateful), making it flexible for various use cases.

Note: FFT is a deprecated backward-compatibility alias for SfFFT (calling it emits a DeprecationWarning); use SfFFT directly.
"""

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from torchref.base.fourier import get_real_grid, ifft
from torchref.base.reciprocal import extract_structure_factor_from_grid
from torchref.config import dtypes, get_default_device

from torchref.symmetry import Cell, SpaceGroup
from torchref.symmetry.map_symmetry import MapSymmetry
from torchref.symmetry.spacegroup import SpaceGroupLike
from torchref.utils.device_mixin import DeviceMovementMixin


class SfFFT(DeviceMovementMixin, nn.Module):
    """
    Structure Factor calculator using FFT (Fast Fourier Transform).

    This module encapsulates all FFT-related functionality for computing
    electron density maps and structure factors. It is initialized with a
    Cell and optionally a SpaceGroup, which are used for grid calculations.

    Parameters
    ----------
    cell : Cell, optional
        Unit cell object containing cell parameters. If None, it must be set
        later via ``set_cell_and_spacegroup()``.
    spacegroup : SpaceGroupLike, optional
        Space group specification (string, int, or gemmi.SpaceGroup).
        If None (and a cell is given), defaults to P1.
    max_res : float, optional
        Maximum resolution for grid spacing in Angstroms. Default is 1.5.
    radius_angstrom : float, optional
        Radius in Angstroms for density calculation around each atom.
        Default is 3.0.
    dtype_float : torch.dtype, optional
        Data type for floating point tensors. Default is dtypes.float.
    device : torch.device, optional
        Computation device. Defaults to the cell's device, or to
        ``get_default_device()`` when no cell is given.
    verbose : int, optional
        Verbosity level for logging. Default is 0.
    use_late_symmetry : bool, optional
        If True (default), apply symmetry in reciprocal space after the FFT
        whenever the grid is compatible.

    Attributes
    ----------
    cell : Cell
        Unit cell object.
    spacegroup : SpaceGroup
        Space group object (SpaceGroup nn.Module with matrices and translations).
    symmetry : SpaceGroup
        Alias for spacegroup (backward compatibility).
    max_res : float
        Maximum resolution for grid spacing.
    radius_angstrom : float
        Radius for density calculation around each atom.
    gridsize : torch.Tensor or None
        Grid dimensions (nx, ny, nz) when grid is set up.
    real_space_grid : torch.Tensor or None
        Real-space coordinate grid with shape (nx, ny, nz, 3).
    voxel_size : torch.Tensor or None
        Voxel dimensions.
    map_symmetry : MapSymmetry or None
        Symmetry operator for map calculations.

    Examples
    --------
    Standalone usage::

        from torchref.symmetry import Cell
        cell = Cell([50, 60, 70, 90, 90, 90])
        sf_fft = SfFFT(cell, spacegroup='P212121', max_res=1.5)
        sf_fft.setup_grid()

        # End-to-end (chooses early or late symmetry automatically):
        sf, density_map = sf_fft.compute_structure_factors(hkl, xyz, b, occ, A, B)

        # Step by step: apply symmetry exactly once, either to the map
        # (early) or in reciprocal space (late), not both.
        density_map = sf_fft.build_density_map(xyz, b, occ, A, B)
        sf = sf_fft.map_to_structure_factors(density_map, hkl, apply_symmetry=False)

    With ModelFT (composition)::

        model = ModelFT()
        model.load_pdb('structure.pdb')
        sf = model.get_structure_factor(hkl)  # Uses internal SfFFT instance
    """

    def __init__(
        self,
        cell: Optional[Cell] = None,
        spacegroup: SpaceGroupLike = None,
        max_res: float = 1.5,
        radius_angstrom: float = 3.0,
        dtype_float: torch.dtype = dtypes.float,
        device: Optional[torch.device] = None,
        verbose: int = 0,
        use_late_symmetry: bool = True,
    ):
        """
        Initialize the SfFFT module with cell and spacegroup.

        Parameters
        ----------
        cell : Cell, optional
            Unit cell object. If None, must be set later via
            ``set_cell_and_spacegroup()``.
        spacegroup : SpaceGroupLike, optional
            Space group specification. If None, defaults to P1.
        max_res : float, optional
            Maximum resolution for grid spacing in Angstroms. Default is 1.5.
        radius_angstrom : float, optional
            Radius in Angstroms for density calculation. Default is 3.0.
        dtype_float : torch.dtype, optional
            Data type for floating point tensors. Default is dtypes.float.
        device : torch.device, optional
            Computation device. Default is None, meaning the cell's device,
            or ``get_default_device()`` if cell is also None.
        verbose : int, optional
            Verbosity level for logging. Default is 0.
        use_late_symmetry : bool, optional
            If True (default), apply symmetry in reciprocal space after FFT
            ("late symmetry") for faster structure factor calculation
            (~5x speedup). If False, apply symmetry to density map before
            FFT ("early symmetry").
        """
        super().__init__()
        self.max_res = max_res
        self.radius_angstrom = radius_angstrom
        self.dtype_float = dtype_float
        self.device = (
            device
            if device is not None
            else cell.device
            if cell is not None
            else get_default_device()
        )
        self.verbose = verbose
        self.use_late_symmetry = use_late_symmetry

        # Store cell and spacegroup
        self._cell = cell
        self._spacegroup = None
        if spacegroup is not None or cell is not None:
            # Use the *resolved* device (the raw ``device`` argument may be
            # None) so the space group lives next to the rest of the module.
            self._spacegroup = SpaceGroup(
                spacegroup, dtype=dtype_float, device=self.device
            )

        # Buffers (registered during setup_grid)
        self.register_buffer("gridsize", None)
        self.register_buffer("real_space_grid", None)
        self.register_buffer("voxel_size", None)

        # Map symmetry operator (set during setup_grid)
        self.map_symmetry: Optional[MapSymmetry] = None

        # Late symmetry compatibility flag (set during setup_grid)
        self._late_symmetry_compatible: Optional[bool] = None

        # Cached reciprocal symmetry extractor (precomputed flat indices) and
        # the key it was built for (hkl object identity + device).
        self._sym_extractor = None
        self._sym_extractor_hkl_id: Optional[int] = None
        self._sym_extractor_hkl: Optional[torch.Tensor] = None
        self._sym_extractor_device: Optional[torch.device] = None

    # =========================================================================
    # Cell and SpaceGroup properties
    # =========================================================================

    @property
    def cell(self) -> Optional[Cell]:
        """Unit cell object."""
        return self._cell

    @cell.setter
    def cell(self, value: Cell):
        """Set unit cell."""
        self._cell = value

    @property
    def spacegroup(self) -> Optional[SpaceGroup]:
        """Space group object (SpaceGroup nn.Module)."""
        return self._spacegroup

    @spacegroup.setter
    def spacegroup(self, value: SpaceGroupLike):
        """Set space group (None clears it)."""
        # NOTE: nn.Module.__setattr__ registers Module values as submodules
        # instead of calling this setter, so internal code must call
        # _set_spacegroup() directly rather than assigning ``self.spacegroup``.
        self._set_spacegroup(value)

    def _set_spacegroup(self, value: SpaceGroupLike) -> None:
        """Build and store a SpaceGroup from ``value`` (None clears it)."""
        if value is not None:
            self._spacegroup = SpaceGroup(
                value, dtype=self.dtype_float, device=self.device
            )
        else:
            self._spacegroup = None
        # The cached reciprocal-space extractor was built from the old group.
        self.reset_cache()

    @property
    def symmetry(self) -> Optional[SpaceGroup]:
        """Symmetry operations handler (alias for spacegroup)."""
        return self._spacegroup

    @property
    def fractional_matrix(self) -> Optional[torch.Tensor]:
        """Get fractionalization matrix from cell, on this module's device/dtype."""
        if self._cell is not None:
            return self._cell.fractional_matrix.to(
                device=self.device, dtype=self.dtype_float
            )
        return None

    @property
    def inv_fractional_matrix(self) -> Optional[torch.Tensor]:
        """Get orthogonalization matrix from cell, on this module's device/dtype."""
        if self._cell is not None:
            return self._cell.inv_fractional_matrix.to(
                device=self.device, dtype=self.dtype_float
            )
        return None

    def set_cell_and_spacegroup(self, cell: Cell, spacegroup: SpaceGroupLike = None):
        """
        Set cell and spacegroup for this SfFFT instance.

        Does not rebuild an existing grid; call ``setup_grid()`` afterwards
        if the grid should reflect the new cell/space group.

        Parameters
        ----------
        cell : Cell
            Unit cell object.
        spacegroup : SpaceGroupLike, optional
            Space group specification. If None, the stored space group is
            cleared.
        """
        self._cell = cell
        # Call the helper directly: assigning a SpaceGroup module through
        # ``self.spacegroup = ...`` would bypass the property setter.
        self._set_spacegroup(spacegroup)

    # =========================================================================
    # Grid Setup Methods
    # =========================================================================

    def compute_optimal_gridsize(self, max_res: Optional[float] = None) -> tuple:
        """
        Compute optimal grid dimensions using the stored cell and spacegroup.

        Uses ``Cell.compute_grid_size()`` for the base calculation and
        ``suggest_grid_size()`` (torchref.symmetry.spacegroup) for symmetry
        and FFT-friendliness optimization.

        Parameters
        ----------
        max_res : float, optional
            Maximum resolution in Angstroms. If None, uses self.max_res.

        Returns
        -------
        tuple of int
            Optimal grid dimensions (nx, ny, nz).

        Raises
        ------
        RuntimeError
            If cell has not been set.
        """
        if self._cell is None:
            raise RuntimeError("Cell not set. Call set_cell_and_spacegroup() first.")

        resolution = max_res if max_res is not None else self.max_res

        from torchref.symmetry.spacegroup import suggest_grid_size

        # Use Cell's method for base grid size calculation
        gridsize_initial = self._cell.compute_grid_size(resolution)

        if self.verbose > 1:
            print(f"Initial grid size from cell: {gridsize_initial}")

        # Optimize for symmetry and FFT-friendliness
        gridsize_optimized = suggest_grid_size(
            gridsize_initial, self._spacegroup, make_fft_friendly=True
        )

        if self.verbose > 1 and gridsize_optimized != gridsize_initial:
            print(
                f"Optimized grid size from {gridsize_initial} to {gridsize_optimized} "
                f"(symmetry + FFT friendly)"
            )

        return gridsize_optimized

    @staticmethod
    def compute_real_space_grid(
        fractional_matrix: torch.Tensor,
        gridsize: torch.Tensor,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Generate the real-space coordinate grid.

        Parameters
        ----------
        fractional_matrix : torch.Tensor
            Fractionalization matrix of the unit cell.
        gridsize : torch.Tensor
            Grid dimensions (nx, ny, nz).
        device : torch.device, optional
            Target device. If None, ``get_default_device()`` is queried at
            call time (not once at import time).

        Returns
        -------
        torch.Tensor
            Real-space grid with shape (nx, ny, nz, 3).
        """
        if device is None:
            device = get_default_device()
        return get_real_grid(
            fractional_matrix=fractional_matrix, gridsize=gridsize, device=device
        )

    def setup_grid(
        self,
        gridsize: Optional[Tuple[int, int, int]] = None,
        max_res: Optional[float] = None,
    ):
        """
        Setup the real-space grid for electron density calculation.

        This method initializes and stores the grid state for subsequent
        density map calculations. Uses the stored cell and spacegroup.

        Parameters
        ----------
        gridsize : tuple of int, optional
            Explicit grid size (nx, ny, nz). If None, computed automatically
            via ``compute_optimal_gridsize()``.
        max_res : float, optional
            Maximum resolution in Angstroms. If None, uses self.max_res.

        Raises
        ------
        RuntimeError
            If cell has not been set.
        """
        if self._cell is None:
            raise RuntimeError("Cell not set. Call set_cell_and_spacegroup() first.")

        if max_res is not None:
            self.max_res = max_res

        if self.verbose > 1:
            print(f"Setting up grids with max_res={self.max_res} Å")

        # Compute or use provided grid size
        if gridsize is None:
            gridsize = self.compute_optimal_gridsize(self.max_res)
        self.gridsize = torch.tensor(gridsize, dtype=dtypes.int, device=self.device)

        # Compute real space grid
        self.real_space_grid = self.compute_real_space_grid(
            self._cell.fractional_matrix, self.gridsize, self.device
        )

        # Voxel size: offset between two diagonally adjacent grid points.
        self.voxel_size = (
            self.real_space_grid[2, 2, 2] - self.real_space_grid[1, 1, 1]
        )

        # Initialize map symmetry operator if space group is set
        if self._spacegroup is not None:
            self.map_symmetry = MapSymmetry(
                space_group=self._spacegroup,
                map_shape=self.real_space_grid.shape[:-1],
                cell_params=self._cell.data,
                verbose=self.verbose,
                device=self.device,
            )
            # Check late symmetry compatibility
            self._late_symmetry_compatible = self._check_late_symmetry_compatible()
            if self.use_late_symmetry and self._late_symmetry_compatible:
                if self.verbose > 0:
                    print("SfFFT: Using late symmetry (reciprocal space) for ~5x speedup")
            elif self.use_late_symmetry and not self._late_symmetry_compatible:
                if self.verbose > 0:
                    print(
                        "SfFFT: Late symmetry disabled - grid not compatible "
                        "(falling back to early symmetry)"
                    )
        else:
            self.map_symmetry = None
            self._late_symmetry_compatible = False

        # Invalidate cached symmetry extractor (grid shape changed)
        self.reset_cache()

        if self.verbose > 2:
            print(f"Grid shape: {self.real_space_grid.shape[:-1]}")
            print(f"Voxel size: {self.voxel_size}")

    def _check_late_symmetry_compatible(self) -> bool:
        """
        Check if late symmetry can be used (all equiv HKLs land on grid).

        Late symmetry requires that symmetry-equivalent HKL indices map to
        integer grid points. This is taken to hold when the MapSymmetry
        factory produced a MapSymmetryDirect (direct indexing without
        interpolation).

        Returns
        -------
        bool
            True if late symmetry is compatible, False otherwise.
        """
        if self.map_symmetry is None:
            return False

        from torchref.symmetry.map_symmetry import MapSymmetryDirect

        return isinstance(self.map_symmetry, MapSymmetryDirect)

    # =========================================================================
    # Density Map Building Methods
    # =========================================================================

    def build_density_map(
        self,
        xyz_iso: torch.Tensor,
        adp_iso: torch.Tensor,
        occ_iso: torch.Tensor,
        A_iso: torch.Tensor,
        B_iso: torch.Tensor,
        xyz_aniso: Optional[torch.Tensor] = None,
        u_aniso: Optional[torch.Tensor] = None,
        occ_aniso: Optional[torch.Tensor] = None,
        A_aniso: Optional[torch.Tensor] = None,
        B_aniso: Optional[torch.Tensor] = None,
        apply_symmetry: bool = True,
    ) -> torch.Tensor:
        """
        Build electron density map from atomic parameters.

        If the grid has not been set up yet, ``setup_grid()`` is called
        automatically with the current settings.

        Parameters
        ----------
        xyz_iso : torch.Tensor
            Isotropic atom coordinates with shape (n_iso, 3).
        adp_iso : torch.Tensor
            Isotropic ADPs (atomic displacement parameters) with shape (n_iso,).
        occ_iso : torch.Tensor
            Isotropic occupancies with shape (n_iso,).
        A_iso : torch.Tensor
            ITC92 A parameters for isotropic atoms with shape (n_iso, 5).
        B_iso : torch.Tensor
            ITC92 B parameters for isotropic atoms with shape (n_iso, 5).
        xyz_aniso : torch.Tensor, optional
            Anisotropic atom coordinates with shape (n_aniso, 3).
        u_aniso : torch.Tensor, optional
            Anisotropic U parameters with shape (n_aniso, 6).
        occ_aniso : torch.Tensor, optional
            Anisotropic occupancies with shape (n_aniso,).
        A_aniso : torch.Tensor, optional
            ITC92 A parameters for anisotropic atoms with shape (n_aniso, 5).
        B_aniso : torch.Tensor, optional
            ITC92 B parameters for anisotropic atoms with shape (n_aniso, 5).
        apply_symmetry : bool, optional
            If True, apply crystallographic symmetry to the map (when a map
            symmetry operator exists). Default is True.

        Returns
        -------
        torch.Tensor
            Electron density map with shape (nx, ny, nz).

        Raises
        ------
        RuntimeError
            If the grid must be set up lazily but no cell has been set.
        """
        if self.real_space_grid is None:
            self.setup_grid()

        from torchref.base.electron_density.main import build_electron_density

        density_map = build_electron_density(
            real_space_grid=self.real_space_grid,
            xyz_iso=xyz_iso,
            adp_iso=adp_iso,
            occ_iso=occ_iso,
            A_iso=A_iso,
            B_iso=B_iso,
            inv_frac_matrix=self.inv_fractional_matrix,
            frac_matrix=self.fractional_matrix,
            radius_angstrom=self.radius_angstrom,
            voxel_size=self.voxel_size,
            xyz_aniso=xyz_aniso,
            u_aniso=u_aniso,
            occ_aniso=occ_aniso,
            A_aniso=A_aniso,
            B_aniso=B_aniso,
            dtype=self.dtype_float,
        )

        # Apply symmetry if requested
        if apply_symmetry and self.map_symmetry is not None:
            density_map = self.map_symmetry(density_map)

        return density_map

    # =========================================================================
    # Structure Factor Methods
    # =========================================================================

    def map_to_structure_factors(
        self,
        density_map: torch.Tensor,
        hkl: torch.Tensor,
        apply_symmetry: bool = True,
    ) -> torch.Tensor:
        """
        Convert density map to structure factors via FFT.

        Parameters
        ----------
        density_map : torch.Tensor
            Electron density map with shape (nx, ny, nz). If
            apply_symmetry=True, this should be a P1 density map (do not also
            apply map symmetry, or symmetry is applied twice).
        hkl : torch.Tensor
            Miller indices with shape (n_reflections, 3).
        apply_symmetry : bool, optional
            If True, apply symmetry in reciprocal space ("late symmetry")
            through a ReciprocalSymmetryExtractor built from ``hkl``, the
            space group and the grid shape. This flag does not check grid
            compatibility itself; ``compute_structure_factors()`` only passes
            True when the grid is compatible. If False, structure factors are
            read directly from the reciprocal grid. Default is True.

        Returns
        -------
        torch.Tensor
            Complex structure factors with shape (n_reflections,).

        Notes
        -----
        The extractor is cached while the *same* ``hkl`` tensor object and
        device are used. In-place edits to ``hkl`` are not detected; call
        ``reset_cache()`` after modifying it.
        """
        reciprocal_space_grid = ifft(density_map, self.cell.volume)

        if not apply_symmetry:
            return extract_structure_factor_from_grid(reciprocal_space_grid, hkl)

        target_device = reciprocal_space_grid.device
        # Compare by object identity while holding a reference to hkl: a bare
        # id() can be reused by a new tensor after the old one is freed.
        # getattr() tolerates instances created before these attributes existed.
        cache_is_stale = (
            self._sym_extractor is None
            or getattr(self, "_sym_extractor_hkl", None) is not hkl
            or getattr(self, "_sym_extractor_device", None) != target_device
        )
        if cache_is_stale:
            from torchref.base.reciprocal import ReciprocalSymmetryExtractor

            grid_shape = tuple(int(x) for x in self.gridsize)
            self._sym_extractor = ReciprocalSymmetryExtractor(
                hkl,
                self.spacegroup,
                grid_shape,
                device=target_device,
            )
            self._sym_extractor_hkl = hkl
            self._sym_extractor_hkl_id = id(hkl)
            self._sym_extractor_device = target_device
        return self._sym_extractor.extract_from_grid(reciprocal_space_grid)

    def compute_structure_factors(
        self,
        hkl: torch.Tensor,
        xyz_iso: torch.Tensor,
        adp_iso: torch.Tensor,
        occ_iso: torch.Tensor,
        A_iso: torch.Tensor,
        B_iso: torch.Tensor,
        xyz_aniso: Optional[torch.Tensor] = None,
        u_aniso: Optional[torch.Tensor] = None,
        occ_aniso: Optional[torch.Tensor] = None,
        A_aniso: Optional[torch.Tensor] = None,
        B_aniso: Optional[torch.Tensor] = None,
        apply_symmetry: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute structure factors from atomic parameters (end-to-end).

        This is a convenience method that builds the density map and computes
        structure factors in one call. When use_late_symmetry=True (default)
        and the grid is compatible, symmetry is applied in reciprocal space
        after FFT for ~5x speedup. Otherwise, symmetry is applied to the
        density map before FFT. The grid is set up first if needed, so the
        strategy is chosen correctly even on the very first call.

        Parameters
        ----------
        hkl : torch.Tensor
            Miller indices with shape (n_reflections, 3).
        xyz_iso : torch.Tensor
            Isotropic atom coordinates with shape (n_iso, 3).
        adp_iso : torch.Tensor
            Isotropic ADPs (atomic displacement parameters) with shape (n_iso,).
        occ_iso : torch.Tensor
            Isotropic occupancies with shape (n_iso,).
        A_iso : torch.Tensor
            ITC92 A parameters for isotropic atoms.
        B_iso : torch.Tensor
            ITC92 B parameters for isotropic atoms.
        xyz_aniso : torch.Tensor, optional
            Anisotropic atom coordinates.
        u_aniso : torch.Tensor, optional
            Anisotropic U parameters.
        occ_aniso : torch.Tensor, optional
            Anisotropic occupancies.
        A_aniso : torch.Tensor, optional
            ITC92 A parameters for anisotropic atoms.
        B_aniso : torch.Tensor, optional
            ITC92 B parameters for anisotropic atoms.
        apply_symmetry : bool, optional
            If True, apply crystallographic symmetry. Default is True.

        Returns
        -------
        sf : torch.Tensor
            Complex structure factors with shape (n_reflections,).
        density_map : torch.Tensor
            Electron density map with shape (nx, ny, nz).
            Note: When using late symmetry, this is the P1 map (without symmetry).
        """
        # Make sure the compatibility flag is known before choosing a path;
        # otherwise the first call always falls back to early symmetry.
        if self.real_space_grid is None:
            self.setup_grid()

        # Decide symmetry strategy:
        # - Late symmetry: build P1 map, apply symmetry in reciprocal space
        # - Early symmetry: apply symmetry to density map before FFT
        use_late = bool(
            apply_symmetry and self.use_late_symmetry and self._late_symmetry_compatible
        )

        # Build density map (with or without early symmetry)
        density_map = self.build_density_map(
            xyz_iso=xyz_iso,
            adp_iso=adp_iso,
            occ_iso=occ_iso,
            A_iso=A_iso,
            B_iso=B_iso,
            xyz_aniso=xyz_aniso,
            u_aniso=u_aniso,
            occ_aniso=occ_aniso,
            A_aniso=A_aniso,
            B_aniso=B_aniso,
            apply_symmetry=apply_symmetry and not use_late,  # Early symmetry
        )

        # Extract structure factors (with or without late symmetry)
        sf = self.map_to_structure_factors(
            density_map,
            hkl,
            apply_symmetry=use_late,  # Late symmetry
        )

        return sf, density_map

    # =========================================================================
    # Cache management and copying
    # =========================================================================

    def reset_cache(self) -> None:
        """Drop the cached symmetry extractor; recomputed on next use."""
        self._sym_extractor = None
        self._sym_extractor_hkl_id = None
        self._sym_extractor_hkl = None
        self._sym_extractor_device = None

    def copy(self) -> "SfFFT":
        """Create a deep copy of this SfFFT module.

        The cell and space group are cloned. If this instance has a grid, the
        copy's grid is rebuilt with exactly the same grid size, so an
        explicitly chosen gridsize is preserved. The cached symmetry extractor
        is not copied.

        Returns
        -------
        SfFFT
            A new SfFFT instance with cloned cell, spacegroup, and grid.
        """
        # Clone the cell
        new_cell = self._cell.clone() if self._cell is not None else None
        # Copy the spacegroup
        new_spacegroup = (
            self._spacegroup.copy() if self._spacegroup is not None else None
        )

        # Create new SfFFT with copied components
        new_fft = SfFFT(
            cell=new_cell,
            spacegroup=new_spacegroup,
            max_res=self.max_res,
            radius_angstrom=self.radius_angstrom,
            dtype_float=self.dtype_float,
            device=self.device,
            verbose=self.verbose,
            use_late_symmetry=self.use_late_symmetry,
        )

        # Reproduce the grid with identical dimensions instead of letting the
        # copy lazily derive a (possibly different) optimal grid size.
        if self.gridsize is not None and new_cell is not None:
            new_fft.setup_grid(gridsize=tuple(int(n) for n in self.gridsize))

        return new_fft
# Backward compatibility alias — deprecated, use SfFFT directly
def FFT(*args, **kwargs):
    """Deprecated alias for SfFFT.

    Emits a ``DeprecationWarning`` attributed to the caller's line, then
    forwards every positional and keyword argument unchanged to ``SfFFT``.
    """
    from warnings import warn as _warn

    deprecation_message = (
        "FFT is deprecated, use SfFFT instead. "
        "FFT will be removed in a future release."
    )
    _warn(deprecation_message, DeprecationWarning, stacklevel=2)
    instance = SfFFT(*args, **kwargs)
    return instance