Source code for torchref.base.reciprocal.symmetry

"""
Reciprocal space symmetry application for structure factor calculation.

This module provides an alternative to the MapSymmetry-based approach for
structure factor calculation. Instead of symmetrizing the density map in
real space ("early symmetry"), we apply symmetry directly in reciprocal space
after FFT ("late symmetry").

Late symmetry is approximately 5x faster for structure factor calculation
because it avoids the expensive map symmetrization step.

Mathematical Background
-----------------------
For a symmetry operation {R|t} (rotation R, translation t):

    F_sym(h) = sum_ops exp(2*pi*i * h.t) * F_P1(R^T @ h)

Where:
- F_sym(h) = structure factor with symmetry at Miller index h
- F_P1(h') = P1 structure factor at h' = R^T @ h
- exp(2*pi*i * h.t) = translation phase shift

Since R is an integer matrix (crystallographic, in fractional coordinates), R^T @ h
gives integer indices that can be extracted directly from the reciprocal space grid;
negative or out-of-range indices wrap around the grid periodically.

Terminology
-----------
- Early symmetry: Apply symmetry to density map before FFT (MapSymmetry approach)
- Late symmetry: Apply symmetry in reciprocal space after FFT (this module)
"""

from typing import Optional, TYPE_CHECKING

import numpy as np
import torch

from torchref.config import get_float_dtype
from torchref.utils.autograd_ops import gather_with_index_add
from torchref.utils.device_mixin import DeviceMixin

from .grid_operations import extract_structure_factor_from_grid

if TYPE_CHECKING:
    from torchref.symmetry.spacegroup import SpaceGroup


def compute_symmetry_equivalent_hkls(
    hkl: torch.Tensor,
    rotation_matrices: torch.Tensor,
) -> torch.Tensor:
    """
    Map every Miller index to its symmetry mates, h' = R^T @ h, per operation.

    Treating each Miller index as a row vector, the reciprocal-space action of
    a real-space rotation R is right-multiplication, h' = h @ R. That is the
    same as R^T @ h in column-vector notation, so the real-space matrices are
    used as given and no explicit transpose is taken.

    Parameters
    ----------
    hkl : torch.Tensor, shape (N, 3)
        Miller indices.
    rotation_matrices : torch.Tensor, shape (n_ops, 3, 3)
        Real-space rotation matrices of the symmetry operations.

    Returns
    -------
    torch.Tensor, shape (n_ops, N, 3), dtype int64
        Equivalent Miller indices, one set per symmetry operation.
    """
    target_device = hkl.device
    float_dtype = get_float_dtype()

    rows = hkl.to(dtype=float_dtype, device=target_device)
    rotations = rotation_matrices.to(dtype=float_dtype, device=target_device)

    # Replicate the (N, 3) rows once per operation so a single batched
    # product handles all of them: (n_ops, N, 3) @ (n_ops, 3, 3).
    num_ops = rotations.shape[0]
    batched_rows = rows.unsqueeze(0).expand(num_ops, -1, -1)
    mates = torch.bmm(batched_rows, rotations)

    # Valid crystallographic operators map integers to integers; rounding
    # only strips floating-point noise before the integer cast.
    return mates.round().to(torch.int64)
def compute_translation_phases(
    hkl: torch.Tensor,
    translations: torch.Tensor,
) -> torch.Tensor:
    """
    Compute phase shifts: exp(2*pi*i * h.t) for each operation.

    The translation component of a symmetry operation causes a phase shift
    in the structure factor.

    Parameters
    ----------
    hkl : torch.Tensor, shape (N, 3)
        Miller indices.
    translations : torch.Tensor, shape (n_ops, 3)
        Translation vectors in fractional coordinates.

    Returns
    -------
    torch.Tensor, shape (n_ops, N)
        Complex phase factors exp(2*pi*i * h.t). The complex precision follows
        ``get_float_dtype()``: float32 gives complex64, float64 gives complex128.
    """
    device = hkl.device
    dtype = get_float_dtype()

    hkl_float = hkl.to(dtype=dtype, device=device)  # (N, 3)
    translations = translations.to(dtype=dtype, device=device)  # (n_ops, 3)

    # h.t for every (operation, reflection) pair:
    # (n_ops, 3) @ (3, N) -> (n_ops, N)
    h_dot_t = torch.matmul(translations, hkl_float.T)

    # exp(2*pi*i*x) has period 1 in x, so only the fractional part of h.t
    # matters. Reducing it first keeps the exponent in [0, 2*pi) and avoids
    # precision loss for high-order reflections.
    h_dot_t = torch.remainder(h_dot_t, 1.0)

    # Stay in the configured precision. Previously this was hard-cast to
    # float32, which silently downgraded float64 runs to complex64 phases.
    phase = 2.0 * np.pi * h_dot_t
    phase_factor = torch.exp(1j * phase)

    return phase_factor  # (n_ops, N) complex
def extract_structure_factors_with_symmetry(
    reciprocal_grid: torch.Tensor,
    hkl: torch.Tensor,
    rotation_matrices: torch.Tensor,
    translations: torch.Tensor,
) -> torch.Tensor:
    """
    Read symmetry-summed structure factors directly from a P1 reciprocal grid.

    This replaces the MapSymmetry route, which symmetrises the density map and
    then extracts F(hkl). Instead, F_P1 is gathered at every symmetry mate of
    each requested reflection, and the mates are combined with their
    translation phase shifts:

        F_sym(h) = sum_ops exp(2*pi*i * h.t) * F_P1(h @ R)

    Parameters
    ----------
    reciprocal_grid : torch.Tensor, shape (Nx, Ny, Nz)
        Complex reciprocal-space grid obtained by FFT of a P1 density map.
    hkl : torch.Tensor, shape (N, 3)
        Miller indices to evaluate.
    rotation_matrices : torch.Tensor, shape (n_ops, 3, 3)
        Real-space rotation parts of the symmetry operations.
    translations : torch.Tensor, shape (n_ops, 3)
        Translation parts of the symmetry operations.

    Returns
    -------
    torch.Tensor, shape (N,)
        Complex structure factors with symmetry applied.
    """
    target = reciprocal_grid.device
    nx, ny, nz = reciprocal_grid.shape

    # Co-locate every input with the grid before doing any work.
    hkl = hkl.to(device=target)
    rotation_matrices = rotation_matrices.to(device=target)
    translations = translations.to(device=target)

    num_ops = rotation_matrices.shape[0]
    num_refl = hkl.shape[0]

    # Symmetry mates of each reflection, then their flat positions in the grid.
    mates = compute_symmetry_equivalent_hkls(hkl, rotation_matrices)
    flat_positions = _equiv_hkls_to_flat_indices(mates, nx, ny, nz)

    # A single gather covers every operation. The custom autograd op keeps
    # the backward pass to one index_add_ into the grid.
    gathered = gather_with_index_add(reciprocal_grid.reshape(-1), flat_positions)
    per_op_values = gathered.view(num_ops, num_refl)

    shifts = compute_translation_phases(hkl, translations)
    return (per_op_values * shifts).sum(dim=0)
def _equiv_hkls_to_flat_indices(
    equiv_hkls: torch.Tensor,
    Nx: int,
    Ny: int,
    Nz: int,
) -> torch.Tensor:
    """
    Convert (n_ops, N, 3) equivalent HKLs to flat linear grid indices.

    Each (h, k, l) triple is wrapped into the grid with ``torch.remainder``.
    The result takes the sign of the divisor, so negative indices land in the
    upper part of each axis. The triple is then linearised as
    ``h * (Ny * Nz) + k * Nz + l``.

    Parameters
    ----------
    equiv_hkls : torch.Tensor, shape (..., 3)
        Miller indices. All leading dimensions are flattened. Floating-point
        input is rounded to the nearest integer first.
    Nx, Ny, Nz : int
        Grid dimensions along h, k and l.

    Returns
    -------
    torch.Tensor, shape (M,), dtype int64
        Flat indices, where M is the product of the leading dimensions of
        ``equiv_hkls``.
    """
    all_hkl = equiv_hkls.reshape(-1, 3)  # (n_ops*N, 3)
    if all_hkl.is_floating_point():
        # Float index arithmetic stops being exact past ~2**24 points
        # (float32), so snap to integers before linearising.
        all_hkl = torch.round(all_hkl)
    # Promote before multiplying so int32 input cannot overflow on big grids.
    all_hkl = all_hkl.to(torch.int64)

    hi = torch.remainder(all_hkl[:, 0], Nx)
    ki = torch.remainder(all_hkl[:, 1], Ny)
    li = torch.remainder(all_hkl[:, 2], Nz)
    return hi * (Ny * Nz) + ki * Nz + li
class ReciprocalSymmetryExtractor(DeviceMixin):
    """
    Class-based interface for reciprocal space symmetry extraction.

    This provides a more efficient interface when computing structure factors
    multiple times with the same symmetry and HKLs (e.g., during refinement).

    Precomputes equivalent HKLs, phase factors, and flat grid indices so that
    each call reduces to a single gather + multiply + sum (~3 GPU kernels
    instead of ~28).

    Parameters
    ----------
    hkl : torch.Tensor, shape (N, 3)
        Target Miller indices.
    symmetry : SpaceGroup
        SpaceGroup object. Its ``n_ops``, ``matrices`` and ``translations``
        attributes are read here.
    grid_shape : tuple of int
        Reciprocal grid dimensions (Nx, Ny, Nz). The precomputed flat indices
        are only valid for grids of exactly this shape.
    device : torch.device, optional
        Device for computation. Defaults to ``hkl.device``.

    Examples
    --------
    >>> extractor = ReciprocalSymmetryExtractor(hkl, symmetry, grid_shape=(209, 86, 67))
    >>> f_calc = extractor.extract_from_grid(reciprocal_grid)
    """

    def __init__(
        self,
        hkl: torch.Tensor,
        symmetry: "SpaceGroup",
        grid_shape: tuple,
        device: Optional[torch.device] = None,
    ):
        """Precompute equivalent HKLs, phase factors and flat grid indices."""
        self.device = device or hkl.device
        self.hkl = hkl.to(device=self.device)
        self.symmetry = symmetry
        self.n_ops = symmetry.n_ops
        self.N = len(hkl)
        self.grid_shape = grid_shape

        # Precompute equivalent HKLs
        self.equiv_hkls = compute_symmetry_equivalent_hkls(
            self.hkl,
            symmetry.matrices.to(device=self.device),
        )  # (n_ops, N, 3)

        # Precompute phase factors
        self.phases = compute_translation_phases(
            self.hkl,
            symmetry.translations.to(device=self.device),
        )  # (n_ops, N) complex

        # Precompute flat linear indices for single-gather extraction
        Nx, Ny, Nz = grid_shape
        self._flat_indices = _equiv_hkls_to_flat_indices(
            self.equiv_hkls,
            Nx,
            Ny,
            Nz,
        )  # (n_ops * N,) int64

    def __call__(self, density_map: torch.Tensor) -> torch.Tensor:
        """
        Compute structure factors from P1 density map.

        Equivalent to ``self.extract(density_map)``.

        Parameters
        ----------
        density_map : torch.Tensor, shape (Nx, Ny, Nz)
            P1 electron density map (NO symmetry applied).

        Returns
        -------
        torch.Tensor, shape (N,)
            Complex structure factors with symmetry applied.
        """
        return self.extract(density_map)

    def extract(self, density_map: torch.Tensor) -> torch.Tensor:
        """
        Extract structure factors from P1 density map.

        The map is transformed with ``torchref.base.fourier.fft.ifft`` and the
        result is passed to ``extract_from_grid``.

        Parameters
        ----------
        density_map : torch.Tensor, shape (Nx, Ny, Nz)
            P1 electron density map (NO symmetry applied).

        Returns
        -------
        torch.Tensor, shape (N,)
            Complex structure factors with symmetry applied.
        """
        # FFT the density map
        from torchref.base.fourier.fft import ifft

        reciprocal_grid = ifft(density_map)

        # Extract from grid
        return self.extract_from_grid(reciprocal_grid)

    def extract_from_grid(self, reciprocal_grid: torch.Tensor) -> torch.Tensor:
        """
        Extract structure factors from precomputed reciprocal grid.

        Uses precomputed flat indices for a single vectorized gather, avoiding
        per-symop Python loops and kernel launches.

        Parameters
        ----------
        reciprocal_grid : torch.Tensor, shape (Nx, Ny, Nz)
            Complex reciprocal space grid from FFT. Must match ``grid_shape``.

        Returns
        -------
        torch.Tensor, shape (N,)
            Complex structure factors with symmetry applied.

        Raises
        ------
        ValueError
            If ``reciprocal_grid`` does not have the shape the flat indices
            were precomputed for. A mismatched grid would otherwise be read
            at the wrong positions without any error.
        """
        expected_shape = tuple(self.grid_shape)
        actual_shape = tuple(reciprocal_grid.shape)
        if actual_shape != expected_shape:
            raise ValueError(
                f"reciprocal_grid has shape {actual_shape}, but this extractor "
                f"was built for grid_shape={expected_shape}"
            )

        # Gather via custom autograd op so the backward is a single
        # ``index_add_`` (atomic scatter, no radix sort + dedup).
        f_all = gather_with_index_add(
            reciprocal_grid.reshape(-1),
            self._flat_indices,
        )  # (n_ops * N,)
        f_sym = (f_all.view(self.n_ops, self.N) * self.phases).sum(dim=0)
        return f_sym