Source code for torchref.symmetry.reciprocal_symmetry

"""
Reciprocal space symmetry operations for structure factor grids.

This module provides efficient symmetry operations for reciprocal space data (h, k, l grids),
analogous to map_symmetry.py for real space density maps.

Key concepts:
- Miller indices (row vectors) transform as h' = h @ R under the real-space rotation R
  (equivalently h' = R^T @ h for column vectors)
- Systematic absences occur when translation causes destructive interference
- Centric reflections have phases restricted to 0 or π
- Friedel pairs: F(h,k,l) = F*(-h,-k,-l) for normal scattering

Main interfaces:
- ReciprocalSymmetry: Factory function for grid-based reciprocal space symmetry
- expand_reflections: Expand ReflectionData from asymmetric unit to P1
- expand_reciprocal_grid: Expand a reciprocal space grid from asymmetric unit to P1

Space groups can be specified as strings, integers (1-230), or gemmi.SpaceGroup objects.
"""

from typing import TYPE_CHECKING, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn

from torchref.config import get_default_device, get_float_dtype
from torchref.symmetry.spacegroup import SpaceGroup, SpaceGroupLike
from torchref.utils.device_mixin import DeviceMixin

if TYPE_CHECKING:
    from torchref.io.datasets.reflection_data import ReflectionData


def ReciprocalSymmetry(
    space_group: SpaceGroupLike,
    grid_shape,
    dtype_float=None,
    verbose=1,
    device=None,
):
    """
    Factory function to create the appropriate ReciprocalSymmetry implementation.

    Parameters
    ----------
    space_group : str, int, or gemmi.SpaceGroup
        Space group specification (e.g., 'P21', 4, gemmi.SpaceGroup('P 21')).
    grid_shape : tuple of int
        Shape of the reciprocal space grid (nh, nk, nl). The grid spans
        from -n//2 to n//2 for each dimension.
    dtype_float : torch.dtype, optional
        Floating point precision to use. If None, the configured float dtype
        (``get_float_dtype()``) is looked up at call time.
    verbose : int, default 1
        Verbosity level (0=silent, 1=info, 2=debug).
    device : torch.device, optional
        Device to use for computation. If None, the configured default device
        (``get_default_device()``) is looked up at call time.

    Returns
    -------
    ReciprocalSymmetryGrid
        Implementation for reciprocal space grid symmetry operations.
    """
    # Resolve configured defaults per call: default expressions in the
    # signature would be evaluated only once, at import time.
    if dtype_float is None:
        dtype_float = get_float_dtype()
    if device is None:
        device = get_default_device()
    return ReciprocalSymmetryGrid(space_group, grid_shape, dtype_float, verbose, device)
class ReciprocalSymmetryGrid(DeviceMixin, nn.Module):
    """
    Reciprocal space symmetry operations for Miller index grids.

    Handles symmetry operations in reciprocal space including:

    - Miller index transformation under symmetry operations
    - Systematic absence detection
    - Centric reflection identification
    - Friedel pair handling
    - Symmetry expansion/averaging of structure factors

    Miller indices are treated as row vectors and transform as::

        (h', k', l') = (h, k, l) @ R

    where R is the rotation matrix of the real-space symmetry operation
    (equivalently ``h' = R^T @ h`` for column vectors). The translation part
    of an operation does not move a reflection; it only contributes a phase
    shift of ``2*pi*(h . t)``.

    Attributes
    ----------
    space_group : str
        Space group name.
    grid_shape : tuple of int
        Shape of the reciprocal space grid (nh, nk, nl).
    symmetry : SpaceGroup
        Base symmetry operations handler.
    n_ops : int
        Number of symmetry operations.

    Examples
    --------
    ::

        recip_sym = ReciprocalSymmetry('P21', grid_shape=(64, 64, 64))
        F_expanded = recip_sym(F_asym, mode='expand')  # Expand from asymmetric unit
        F_avg = recip_sym.symmetry_average(F_full)  # Average symmetry-related reflections
    """

    def __init__(
        self,
        space_group,
        grid_shape,
        dtype_float=None,
        verbose=1,
        device=None,
    ):
        """
        Initialize reciprocal space symmetry operator.

        Parameters
        ----------
        space_group : str
            Space group name.
        grid_shape : tuple of int
            Shape of the reciprocal space grid (nh, nk, nl).
        dtype_float : torch.dtype, optional
            Floating point precision. If None, ``get_float_dtype()`` is
            looked up at call time.
        verbose : int, default 1
            Verbosity level.
        device : torch.device, optional
            Computation device. If None, ``get_default_device()`` is looked
            up at call time.
        """
        super().__init__()
        # Resolve configured defaults per call rather than once at import time.
        if dtype_float is None:
            dtype_float = get_float_dtype()
        if device is None:
            device = get_default_device()

        self.dtype_float = dtype_float
        self.space_group = space_group
        self.grid_shape = tuple(grid_shape)
        self.verbose = verbose
        self.device = device

        # Get symmetry operations from base class
        self.symmetry = SpaceGroup(space_group, dtype=dtype_float, device=device)
        self.n_ops = self.symmetry.matrices.shape[0]

        # Precompute reciprocal space rotation matrices (transpose of real space)
        self._setup_reciprocal_matrices()

        if self.verbose > 0:
            print(f"ReciprocalSymmetryGrid initialized for {space_group}")
            print(f" Number of symmetry operations: {self.n_ops}")
            print(f" Grid shape: {self.grid_shape}")

        # Setup Miller index grid
        self._setup_hkl_grid()
        # Precompute index mappings for each symmetry operation
        self._setup_symmetry_index_grids()
        # Compute systematic absences mask
        self._setup_systematic_absences()
        # Compute centric reflection mask
        self._setup_centric_reflections()

        if self.verbose > 0:
            n_absent = self.systematic_absences.sum().item()
            n_centric = self.centric_mask.sum().item()
            n_total = np.prod(self.grid_shape)
            print(f" Systematic absences: {n_absent} ({100*n_absent/n_total:.2f}%)")
            print(f" Centric reflections: {n_centric} ({100*n_centric/n_total:.2f}%)")

    # ------------------------------------------------------------------ setup

    def _setup_reciprocal_matrices(self):
        """
        Register the reciprocal-space rotation matrices (transposed real-space R).

        Row-vector Miller indices are transformed as ``h @ reciprocal.T``,
        i.e. ``h @ R``. Translations only cause phase shifts, not index changes.
        """
        recip_matrices = self.symmetry.matrices.transpose(-2, -1).contiguous()
        self.register_buffer("reciprocal_matrices", recip_matrices)

    @staticmethod
    def _fft_indices(n, device):
        """Signed FFT-order integer indices 0, 1, ..., -1 for an axis of length n.

        Equivalent to ``fftfreq(n) * n`` but computed exactly in integers, so no
        float product like ``2.9999998`` can be truncated to the wrong index.
        """
        idx = torch.arange(n, dtype=torch.int64, device=device)
        return (idx + n // 2) % n - n // 2

    def _setup_hkl_grid(self):
        """
        Setup Miller index grid.

        Creates a grid where indices span from -n//2 to n//2 for each dimension,
        following the FFT convention (0...n//2, -n//2+1...-1).
        """
        nh, nk, nl = self.grid_shape
        h = self._fft_indices(nh, self.device)
        k = self._fft_indices(nk, self.device)
        l = self._fft_indices(nl, self.device)

        # Create 3D grid of Miller indices, shape (nh, nk, nl, 3)
        grid_h, grid_k, grid_l = torch.meshgrid(h, k, l, indexing="ij")
        hkl_grid = torch.stack([grid_h, grid_k, grid_l], dim=-1)
        self.register_buffer("hkl_grid", hkl_grid)

        # Also store as float for matrix operations
        self.register_buffer("hkl_grid_float", hkl_grid.to(dtype=self.dtype_float))

    def _transformed_hkl(self, op_index):
        """Return ``round(h @ R)`` for every grid point as int64, shape (N, 3)."""
        hkl_flat = self.hkl_grid_float.reshape(-1, 3)
        transformed = torch.matmul(hkl_flat, self.reciprocal_matrices[op_index].T)
        # Rounding should be exact for valid (integer) symmetry rotations.
        return torch.round(transformed).to(torch.int64)

    def _setup_symmetry_index_grids(self):
        """
        Precompute index grids mapping each (h,k,l) to its symmetry equivalents.

        For each symmetry operation, computes where each reflection maps to in
        the grid (wrapped periodically) and the translation phase 2*pi*h.t,
        allowing efficient gathering operations.
        """
        nh, nk, nl = self.grid_shape
        hkl_flat = self.hkl_grid_float.reshape(-1, 3)  # (N, 3)
        sizes = torch.tensor([nh, nk, nl], dtype=torch.int64, device=hkl_flat.device)

        index_grids_list = []
        phase_shift_grids_list = []
        for i in range(self.n_ops):
            transformed_int = self._transformed_hkl(i)

            # Phase shift from translation: phase = 2*pi * h . t
            translation = self.symmetry.translations[i]
            phase_shift = 2.0 * np.pi * torch.matmul(hkl_flat, translation)
            phase_shift_grids_list.append(phase_shift.reshape(nh, nk, nl))

            # Convert to grid indices (wrap with periodic boundary, per axis)
            index_grids_list.append((transformed_int % sizes).reshape(nh, nk, nl, 3))

        # (n_ops, nh, nk, nl, 3)
        self.register_buffer("index_grids", torch.stack(index_grids_list, dim=0))
        # (n_ops, nh, nk, nl)
        self.register_buffer("phase_shifts", torch.stack(phase_shift_grids_list, dim=0))

    def _setup_systematic_absences(self):
        """
        Compute mask for systematic absences.

        A reflection (h,k,l) is systematically absent if for some symmetry
        operation with translation t: h' = h (maps to itself) AND
        h.t != 0 (mod 1).
        """
        absences = torch.zeros(self.grid_shape, dtype=torch.bool, device=self.device)
        hkl_flat = self.hkl_grid_float.reshape(-1, 3)
        hkl_int = self.hkl_grid.reshape(-1, 3)

        for i in range(self.n_ops):
            same_reflection = (self._transformed_hkl(i) == hkl_int).all(dim=-1)

            h_dot_t = torch.matmul(hkl_flat, self.symmetry.translations[i])
            non_zero_phase = torch.abs(h_dot_t - torch.round(h_dot_t)) > 1e-6

            absences = absences | (same_reflection & non_zero_phase).reshape(self.grid_shape)

        self.register_buffer("systematic_absences", absences)

    def _setup_centric_reflections(self):
        """
        Compute mask for centric reflections.

        A reflection is marked centric if some symmetry operation maps h to -h.
        """
        centric = torch.zeros(self.grid_shape, dtype=torch.bool, device=self.device)
        hkl_int = self.hkl_grid.reshape(-1, 3)

        for i in range(self.n_ops):
            maps_to_minus_h = (self._transformed_hkl(i) == -hkl_int).all(dim=-1)
            centric = centric | maps_to_minus_h.reshape(self.grid_shape)

        self.register_buffer("centric_mask", centric)

    # ---------------------------------------------------------------- indices

    def apply_to_indices(self, hkl, operation_index=None):
        """
        Apply symmetry operation(s) to Miller indices.

        Parameters
        ----------
        hkl : torch.Tensor, shape (..., 3)
            Miller indices (h, k, l).
        operation_index : int, optional
            If specified, apply only this operation. If None, apply all operations.

        Returns
        -------
        torch.Tensor
            Transformed (rounded, int64) Miller indices.
            If operation_index is None: shape (n_ops, ..., 3)
            Otherwise: shape (..., 3)
        """
        hkl = hkl.to(dtype=self.dtype_float, device=self.device)
        original_shape = hkl.shape[:-1]

        if operation_index is not None:
            R = self.reciprocal_matrices[operation_index]
            return torch.round(torch.matmul(hkl, R.T)).to(torch.int64)

        hkl_flat = hkl.reshape(-1, 3)  # (N, 3)
        results = [
            torch.round(torch.matmul(hkl_flat, R.T)).to(torch.int64)
            for R in self.reciprocal_matrices
        ]
        stacked = torch.stack(results, dim=0)  # (n_ops, N, 3)
        return stacked.reshape(self.n_ops, *original_shape, 3)

    def get_phase_shift(self, hkl, operation_index):
        """
        Get the translation phase shift 2*pi*(h . t) for a symmetry operation.

        Parameters
        ----------
        hkl : torch.Tensor, shape (..., 3)
            Miller indices.
        operation_index : int
            Symmetry operation index.

        Returns
        -------
        torch.Tensor
            Phase shift in radians, shape (...).
        """
        hkl = hkl.to(dtype=self.dtype_float, device=self.device)
        translation = self.symmetry.translations[operation_index]
        return 2.0 * np.pi * torch.matmul(hkl, translation)

    # ------------------------------------------------------------------- grids

    def get_symmetry_mate(self, F_grid, operation_index):
        """
        Apply a single symmetry operation to a structure factor grid.

        For each grid point h the value is gathered from the (wrapped) point
        ``h @ R`` and, for complex grids, multiplied by ``exp(i * 2*pi * h.t)``.

        Parameters
        ----------
        F_grid : torch.Tensor, shape (nh, nk, nl)
            Structure factor grid (complex, or real amplitudes).
        operation_index : int
            Index of the symmetry operation (0 to n_ops-1).

        Returns
        -------
        torch.Tensor, shape (nh, nk, nl)
            Structure factors after applying the symmetry operation. Real-valued
            grids are gathered without a phase factor.

        Raises
        ------
        ValueError
            If the operation index is out of range or the grid shape is wrong.
        """
        if operation_index < 0 or operation_index >= self.n_ops:
            raise ValueError(
                f"Operation index {operation_index} out of range [0, {self.n_ops-1}]"
            )
        if F_grid.shape != self.grid_shape:
            raise ValueError(
                f"Grid shape {F_grid.shape} doesn't match expected {self.grid_shape}"
            )

        idx_grid = self.index_grids[operation_index]  # (nh, nk, nl, 3)
        F_transformed = F_grid[idx_grid[..., 0], idx_grid[..., 1], idx_grid[..., 2]]

        if F_grid.is_complex():
            phase = self.phase_shifts[operation_index]
            F_transformed = F_transformed * torch.exp(1j * phase.to(F_grid.dtype))
        return F_transformed

    def get_all_symmetry_mates(self, F_grid):
        """
        Get all symmetry-related structure factor grids.

        Parameters
        ----------
        F_grid : torch.Tensor, shape (nh, nk, nl)
            Structure factor grid.

        Returns
        -------
        list of torch.Tensor
            One grid per symmetry operation (see ``get_symmetry_mate``).
        """
        return [self.get_symmetry_mate(F_grid, i) for i in range(self.n_ops)]

    def symmetry_average(self, F_grid, weights=None):
        """
        Average structure factors over all symmetry equivalents.

        This is useful for enforcing symmetry constraints on structure factors.

        Parameters
        ----------
        F_grid : torch.Tensor, shape (nh, nk, nl)
            Structure factor grid.
        weights : torch.Tensor, optional
            Per-reflection weights, shape (nh, nk, nl). Each symmetry mate is
            weighted by the weight of the reflection it was gathered from, and
            the result is ``sum(w_i * F_i) / sum(w_i)`` per grid point (0 where
            all contributing weights are 0). If None, equal weights are used.

        Returns
        -------
        torch.Tensor, shape (nh, nk, nl)
            Symmetry-averaged structure factors.
        """
        mates = torch.stack(self.get_all_symmetry_mates(F_grid), dim=0)  # (n_ops, nh, nk, nl)
        if weights is None:
            return mates.mean(dim=0)

        weights = weights.to(device=mates.device)
        # Gather each mate's weight from the same source position as its value.
        mate_weights = torch.stack(
            [weights[g[..., 0], g[..., 1], g[..., 2]] for g in self.index_grids],
            dim=0,
        )
        weight_sum = mate_weights.sum(dim=0)
        # Avoid 0/0 where no mate carries weight (numerator is 0 there too).
        safe_sum = torch.where(weight_sum != 0, weight_sum, torch.ones_like(weight_sum))
        return (mates * mate_weights).sum(dim=0) / safe_sum

    def expand_to_p1(self, F_asym, asym_mask=None):
        """
        Expand structure factors from asymmetric unit to full P1.

        Takes structure factors defined on the asymmetric unit and fills the
        remaining grid points with symmetry mates gathered from it. For each
        empty point the first operation (in order, skipping the identity) that
        maps it onto an asymmetric-unit point supplies its value.

        Parameters
        ----------
        F_asym : torch.Tensor, shape (nh, nk, nl)
            Structure factors on asymmetric unit (other positions can be zero).
        asym_mask : torch.Tensor, optional
            Boolean mask indicating asymmetric unit positions. Values outside
            the mask are ignored. If None, points with ``|F| >= 1e-10`` are
            taken to be the asymmetric unit.

        Returns
        -------
        torch.Tensor, shape (nh, nk, nl)
            Full structure factor grid with all symmetry equivalents filled.
        """
        if asym_mask is None:
            filled = F_asym.abs() >= 1e-10
            source = F_asym
        else:
            filled = asym_mask.to(device=F_asym.device, dtype=torch.bool)
            source = torch.where(filled, F_asym, torch.zeros_like(F_asym))

        source_filled = filled.clone()
        F_full = source.clone()
        for i in range(1, self.n_ops):  # Skip identity
            idx = self.index_grids[i]
            F_mate = self.get_symmetry_mate(source, i)
            # Only take mates that were gathered from an asymmetric-unit point,
            # and only write into points that have not been set yet.
            from_asym = source_filled[idx[..., 0], idx[..., 1], idx[..., 2]]
            take = from_asym & ~filled
            F_full = torch.where(take, F_mate, F_full)
            filled = filled | take
        return F_full

    def apply_friedel(self, F_grid):
        """
        Enforce Friedel's law: F(-h,-k,-l) = F*(h,k,l).

        Returns the average of F(h) and conj(F(-h)) (no conjugation for real
        grids).

        Parameters
        ----------
        F_grid : torch.Tensor, shape (nh, nk, nl)
            Structure factor grid in FFT index order.

        Returns
        -------
        torch.Tensor, shape (nh, nk, nl)
            Structure factors with Friedel symmetry enforced.
        """
        # flip maps index i -> n-1-i; rolling by one gives n-i == -i (mod n).
        F_friedel = torch.flip(F_grid, dims=[0, 1, 2])
        F_friedel = torch.roll(F_friedel, shifts=(1, 1, 1), dims=(0, 1, 2))
        if F_grid.is_complex():
            F_friedel = F_friedel.conj()
        return 0.5 * (F_grid + F_friedel)

    # ------------------------------------------------------------------ masks

    def is_systematic_absence(self, h, k, l):
        """
        Check if a reflection is systematically absent.

        Parameters
        ----------
        h, k, l : int
            Miller indices (wrapped modulo the grid shape).

        Returns
        -------
        bool
            True if the reflection is systematically absent.
        """
        nh, nk, nl = self.grid_shape
        return self.systematic_absences[h % nh, k % nk, l % nl].item()

    def is_centric(self, h, k, l):
        """
        Check if a reflection is centric.

        Parameters
        ----------
        h, k, l : int
            Miller indices (wrapped modulo the grid shape).

        Returns
        -------
        bool
            True if the reflection is marked centric.
        """
        nh, nk, nl = self.grid_shape
        return self.centric_mask[h % nh, k % nk, l % nl].item()

    def get_epsilon(self):
        """
        Compute epsilon (multiplicity) factors for each reflection.

        For every grid point, counts the symmetry operations whose rotation
        maps h onto h or onto -h. Note that counting -h differs from the
        conventional epsilon (h -> h only) for centric reflections.

        Returns
        -------
        torch.Tensor, shape (nh, nk, nl), dtype int32
            Epsilon factors for each reflection.
        """
        device = self.hkl_grid.device
        sizes = torch.tensor(self.grid_shape, dtype=torch.int64, device=device)
        # index_grids hold wrapped (non-negative) indices, so compare against
        # wrapped h and -h; comparing with signed indices missed every
        # reflection that has a negative component.
        hkl_wrapped = (self.hkl_grid % sizes).reshape(-1, 3)
        neg_wrapped = ((-self.hkl_grid) % sizes).reshape(-1, 3)

        epsilon = torch.zeros(self.grid_shape, dtype=torch.int32, device=device)
        for i in range(self.n_ops):
            transformed = self.index_grids[i].reshape(-1, 3)
            same = (transformed == hkl_wrapped).all(dim=-1)
            friedel = (transformed == neg_wrapped).all(dim=-1)
            epsilon += (same | friedel).reshape(self.grid_shape).to(torch.int32)
        return epsilon

    # --------------------------------------------------------------- dispatch

    def forward(self, F_grid, mode="average"):
        """
        Apply symmetry to structure factor grid.

        Parameters
        ----------
        F_grid : torch.Tensor, shape (nh, nk, nl)
            Structure factor grid.
        mode : str, default 'average'
            Operation mode:
            - 'average': Average over all symmetry equivalents
            - 'expand': Expand from asymmetric unit to full grid
            - 'sum': Sum all symmetry mates (for accumulation)

        Returns
        -------
        torch.Tensor, shape (nh, nk, nl)
            Processed structure factor grid.

        Raises
        ------
        ValueError
            If ``mode`` is not one of the supported values.
        """
        if mode == "average":
            return self.symmetry_average(F_grid)
        if mode == "expand":
            return self.expand_to_p1(F_grid)
        if mode == "sum":
            return torch.stack(self.get_all_symmetry_mates(F_grid), dim=0).sum(dim=0)
        raise ValueError(
            f"Unknown mode: {mode}. Use 'average', 'expand', or 'sum'."
        )

    def __call__(self, F_grid, mode="average"):
        """Make the class callable (dispatches directly to ``forward``)."""
        return self.forward(F_grid, mode=mode)

    def get_symmetry_info(self):
        """
        Get information about reciprocal space symmetry.

        Returns
        -------
        dict
            Dictionary with symmetry information.
        """
        return {
            "space_group": self.space_group,
            "n_operations": self.n_ops,
            "reciprocal_matrices": self.reciprocal_matrices,
            "translations": self.symmetry.translations,
            "n_systematic_absences": self.systematic_absences.sum().item(),
            "n_centric": self.centric_mask.sum().item(),
            "grid_shape": self.grid_shape,
        }

    def __repr__(self):
        return (
            f"ReciprocalSymmetryGrid(space_group='{self.space_group}', "
            f"n_ops={self.n_ops}, grid_shape={self.grid_shape})"
        )
# =============================================================================
# Standalone functions for symmetry expansion
# =============================================================================
def expand_hkl(
    hkl: torch.Tensor,
    spacegroup: SpaceGroupLike,
    include_friedel: bool = True,
    remove_absences: bool = True,
    device: Optional[torch.device] = None,
    return_friedel: bool = False,
) -> Tuple[torch.Tensor, ...]:
    """
    Expand Miller indices under crystallographic symmetry.

    This is the core low-level function for HKL expansion. It takes Miller
    indices and a space group, and returns the expanded indices along with an
    index mapping and phase offsets needed to expand any associated data.

    Images are generated operation by operation (all ``h @ R_i`` first, then,
    if requested, all ``-(h @ R_i)``); for duplicate indices the first
    generated entry is kept, and output order follows first occurrence.

    Parameters
    ----------
    hkl : torch.Tensor, shape (N, 3)
        Input Miller indices (asymmetric unit).
    spacegroup : str, int, or gemmi.SpaceGroup
        Space group specification.
    include_friedel : bool, default True
        Include Friedel mates (-h, -k, -l).
    remove_absences : bool, default True
        Remove systematically absent reflections.
    device : torch.device, optional
        Computation device. If None, uses hkl's device.
    return_friedel : bool, default False
        If True, also return a boolean mask flagging entries that were
        generated through a Friedel mate.

    Returns
    -------
    expanded_hkl : torch.Tensor, shape (M, 3), dtype=int32
        All unique expanded Miller indices.
    orig_indices : torch.Tensor, shape (M,), dtype=int64
        Index mapping expanded -> original reflection.
        Use to expand any data: ``F_expanded = F_orig[orig_indices]``
    phase_shifts : torch.Tensor, shape (M,), configured float dtype
        Phase offsets from translations (radians), ``2*pi*(h . t)``; negated
        for Friedel-generated entries.
    is_friedel : torch.Tensor, shape (M,), dtype=bool
        Only returned when ``return_friedel=True``. Because phi(-h) = -phi(h),
        the source phase must be negated for these entries::

            phi = torch.where(is_friedel, -phi_orig[orig_indices], phi_orig[orig_indices])
            phi_expanded = phi + phase_shifts

        The plain ``phi_orig[orig_indices] + phase_shifts`` is only correct
        for entries with ``is_friedel == False``.

    Examples
    --------
    ::

        import torch
        from torchref.symmetry import expand_hkl

        hkl_asu = torch.tensor([[1, 0, 0], [0, 1, 0], [1, 1, 1]], dtype=torch.int32)
        hkl_p1, indices, phases, is_friedel = expand_hkl(
            hkl_asu, 'P21', return_friedel=True
        )

        # Expand amplitude data
        F_asu = torch.tensor([100.0, 80.0, 75.0])
        F_p1 = F_asu[indices]

        # Expand phase data (negate for Friedel mates, then apply shifts)
        phi_asu = torch.tensor([0.0, 1.5, 2.0])
        phi_src = phi_asu[indices]
        phi_p1 = torch.where(is_friedel, -phi_src, phi_src) + phases
    """
    if device is None:
        device = hkl.device
    float_dtype = get_float_dtype()

    symmetry = SpaceGroup(spacegroup, dtype=float_dtype, device=device)
    n_ops = symmetry.matrices.shape[0]
    # Reciprocal space matrices (transpose of real space)
    recip_matrices = symmetry.matrices.transpose(-2, -1)
    translations = symmetry.translations

    hkl_float = hkl.to(dtype=float_dtype, device=device).reshape(-1, 3)
    n_orig = hkl_float.shape[0]

    if n_orig == 0:
        result = (
            torch.zeros((0, 3), dtype=torch.int32, device=device),
            torch.zeros(0, dtype=torch.int64, device=device),
            torch.zeros(0, dtype=float_dtype, device=device),
        )
        if return_friedel:
            return result + (torch.zeros(0, dtype=torch.bool, device=device),)
        return result

    all_hkl = []
    all_phases = []
    for i in range(n_ops):
        # h' = h @ R  (recip_matrices[i].T == R)
        all_hkl.append(
            torch.round(torch.matmul(hkl_float, recip_matrices[i].T)).to(torch.int32)
        )
        # Phase shift from translation: 2*pi * h.t
        all_phases.append(2.0 * np.pi * torch.matmul(hkl_float, translations[i]))

    if include_friedel:
        all_hkl += [-h for h in all_hkl[:n_ops]]
        all_phases += [-p for p in all_phases[:n_ops]]

    hkl_expanded = torch.cat(all_hkl, dim=0)
    phases_expanded = torch.cat(all_phases, dim=0)

    # Keep the first occurrence of every unique (h, k, l), in first-occurrence
    # order (np.unique returns sorted uniques; re-sorting the first indices
    # restores generation order).
    _, first_idx = np.unique(hkl_expanded.cpu().numpy(), axis=0, return_index=True)
    keep = torch.as_tensor(np.sort(first_idx), dtype=torch.int64, device=device)

    expanded_hkl = hkl_expanded[keep]
    phase_shifts = phases_expanded[keep]
    # Concatenation is op-major with n_orig rows per block.
    orig_idx_tensor = keep % n_orig
    is_friedel = keep >= n_ops * n_orig

    if remove_absences and symmetry.number != 1:
        absence_mask = _check_systematic_absences(
            expanded_hkl, symmetry.matrices, translations, device
        )
        keep_mask = ~absence_mask
        expanded_hkl = expanded_hkl[keep_mask]
        phase_shifts = phase_shifts[keep_mask]
        orig_idx_tensor = orig_idx_tensor[keep_mask]
        is_friedel = is_friedel[keep_mask]

    if return_friedel:
        return expanded_hkl, orig_idx_tensor, phase_shifts, is_friedel
    return expanded_hkl, orig_idx_tensor, phase_shifts
def complete_hkl(
    input_hkl: torch.Tensor,
    cell: torch.Tensor,
    spacegroup: SpaceGroupLike,
    d_min: float,
    device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Complete a set of Miller indices by identifying missing reflections.

    Generates all possible reflections within the resolution limit for the
    given spacegroup (removing systematic absences), then maps the input
    reflections to this complete set.

    NOTE: This does NOT expand symmetry - stays in the same spacegroup.
    Use this to identify which reflections are missing from a dataset.

    Parameters
    ----------
    input_hkl : torch.Tensor, shape (N, 3)
        Input Miller indices (may be incomplete). If an index occurs more
        than once, the last occurrence is used.
    cell : torch.Tensor, shape (6,)
        Unit cell parameters [a, b, c, alpha, beta, gamma].
    spacegroup : str, int, or gemmi.SpaceGroup
        Space group specification.
    d_min : float
        High resolution limit in Angstroms.
    device : torch.device, optional
        Computation device. If None, uses input_hkl's device.

    Returns
    -------
    complete_hkl : torch.Tensor, shape (M, 3)
        All possible Miller indices within resolution (minus systematic
        absences), as produced by ``generate_possible_hkl``.
    input_indices : torch.Tensor, shape (M,), dtype int64
        Index mapping: complete -> input. For present reflections, gives the
        index in input_hkl. For missing reflections, gives -1.
        Use: ``F_complete[~missing] = F_input[input_indices[~missing]]``
    missing_mask : torch.Tensor, shape (M,), dtype bool
        True where reflection is missing from input.

    Examples
    --------
    ::

        import torch
        from torchref.symmetry import complete_hkl

        # Incomplete dataset
        input_hkl = torch.tensor([[1, 0, 0], [0, 1, 0]], dtype=torch.int32)
        cell = torch.tensor([50.0, 60.0, 70.0, 90.0, 90.0, 90.0])
        complete, indices, missing = complete_hkl(input_hkl, cell, 'P21', d_min=10.0)

        # Fill F values
        F_input = torch.tensor([100.0, 80.0])
        F_complete = torch.zeros(len(complete))
        present_mask = ~missing
        F_complete[present_mask] = F_input[indices[present_mask]]
    """
    from torchref.base.reciprocal import generate_possible_hkl

    if device is None:
        device = input_hkl.device

    # Generate all possible HKL within resolution
    all_hkl = generate_possible_hkl(cell, d_min, device=device)

    # Remove systematic absences (none exist in P1)
    symmetry = SpaceGroup(spacegroup, dtype=get_float_dtype(), device=device)
    if symmetry.number != 1:
        absence_mask = _check_systematic_absences(
            all_hkl, symmetry.matrices, symmetry.translations, device
        )
        all_hkl = all_hkl[~absence_mask]

    # Lookup (h, k, l) -> input row; later duplicates overwrite earlier ones.
    input_lookup = {tuple(row): idx for idx, row in enumerate(input_hkl.cpu().tolist())}

    # Build the mapping on the host and transfer once, instead of issuing one
    # device write per reflection.
    index_list = [input_lookup.get(tuple(row), -1) for row in all_hkl.cpu().tolist()]
    input_indices = torch.tensor(index_list, dtype=torch.int64, device=device)
    missing_mask = input_indices < 0

    return all_hkl, input_indices, missing_mask
def _friedel_generated_mask(
    hkl_expanded: torch.Tensor,
    hkl_source: torch.Tensor,
    space_group,
    remove_absences: bool,
    device,
) -> torch.Tensor:
    """Flag expanded reflections that are only reachable through a Friedel mate.

    expand_hkl() lists the proper symmetry images of every input reflection
    before any Friedel image and keeps the first occurrence of each index, so
    an expanded index came from a Friedel mate exactly when it is absent from
    the expansion performed without Friedel mates.
    """
    plus_hkl, _, _ = expand_hkl(
        hkl_source,
        space_group,
        include_friedel=False,
        remove_absences=remove_absences,
        device=device,
    )
    reachable = {tuple(row) for row in plus_hkl.cpu().tolist()}
    flags = [tuple(row) not in reachable for row in hkl_expanded.cpu().tolist()]
    return torch.tensor(flags, dtype=torch.bool, device=device)


def expand_reflections(
    reflection_data: "ReflectionData",
    include_friedel: bool = True,
    remove_absences: bool = True,
    verbose: int = 1,
) -> "ReflectionData":
    """
    Expand reflection data from asymmetric unit to P1 using symmetry operations.

    Takes a ReflectionData object containing reflections in the asymmetric
    unit and generates all symmetry-equivalent reflections, returning a new
    ReflectionData object with the expanded set.

    This is a high-level wrapper around expand_hkl() that handles all
    ReflectionData fields automatically. Phases are expanded as
    ``phi(h') = phi(h) + shift`` for proper symmetry images and as
    ``-phi(h) + shift`` for Friedel-generated images (phi(-h) = -phi(h);
    expand_hkl already negates the shift for those entries).

    Parameters
    ----------
    reflection_data : ReflectionData
        Input reflection data with hkl, F, F_sigma, etc.
    include_friedel : bool, default True
        If True, also include Friedel mates (-h, -k, -l).
    remove_absences : bool, default True
        If True, remove systematically absent reflections from output.
    verbose : int, default 1
        Verbosity level.

    Returns
    -------
    ReflectionData
        New ReflectionData object with expanded reflections.
        The spacegroup is set to 'P1' since symmetry has been expanded.

    Raises
    ------
    ValueError
        If ``reflection_data`` has no Miller indices.

    See Also
    --------
    expand_hkl : Low-level function for HKL expansion without ReflectionData.

    Examples
    --------
    ::

        from torchref.io.datasets.reflection_data import ReflectionData
        from torchref.symmetry import expand_reflections

        data = ReflectionData().load_mtz('data.mtz')
        data_p1 = expand_reflections(data)
        print(f"Expanded from {len(data)} to {len(data_p1)} reflections")
    """
    from torchref.io.datasets.reflection_data import ReflectionData as RefData

    if reflection_data.hkl is None:
        raise ValueError("ReflectionData has no Miller indices loaded")

    space_group = reflection_data.spacegroup or "P1"
    device = reflection_data.device
    n_orig = len(reflection_data.hkl)

    if verbose > 0:
        symmetry = SpaceGroup(space_group, dtype=get_float_dtype(), device=device)
        print(f"Expanding reflections for {space_group}")
        print(f" Original reflections: {n_orig}")
        print(f" Symmetry operations: {symmetry.n_ops}")

    # Use low-level expand_hkl for core expansion
    hkl_expanded, orig_idx_tensor, phase_shifts = expand_hkl(
        reflection_data.hkl,
        space_group,
        include_friedel=include_friedel,
        remove_absences=remove_absences,
        device=device,
    )

    if verbose > 0:
        print(f" After expansion: {len(hkl_expanded)} unique reflections")

    expanded = RefData(verbose=reflection_data.verbose, device=device)

    expanded.hkl = hkl_expanded
    expanded.cell = (
        reflection_data.cell.clone() if reflection_data.cell is not None else None
    )
    expanded.spacegroup = SpaceGroup("P1")  # Now in P1 since symmetry is expanded

    # Amplitude / intensity data are invariant under symmetry and Friedel.
    if reflection_data.F is not None:
        expanded.F = reflection_data.F[orig_idx_tensor]
    if reflection_data.F_sigma is not None:
        expanded.F_sigma = reflection_data.F_sigma[orig_idx_tensor]
    if reflection_data.I is not None:
        expanded.I = reflection_data.I[orig_idx_tensor]
    if hasattr(reflection_data, "I_sigma") and reflection_data.I_sigma is not None:
        expanded.I_sigma = reflection_data.I_sigma[orig_idx_tensor]

    if hasattr(reflection_data, "phase") and reflection_data.phase is not None:
        source_phase = reflection_data.phase[orig_idx_tensor]
        if include_friedel:
            # Friedel-generated entries need the negated source phase.
            is_friedel = _friedel_generated_mask(
                hkl_expanded, reflection_data.hkl, space_group, remove_absences, device
            )
            source_phase = torch.where(is_friedel, -source_phase, source_phase)
        expanded.phase = source_phase + phase_shifts
    else:
        # Store phase shifts for potential later use
        expanded._expansion_phase_shifts = phase_shifts

    if hasattr(reflection_data, "fom") and reflection_data.fom is not None:
        expanded.fom = reflection_data.fom[orig_idx_tensor]

    if reflection_data.rfree_flags is not None:
        expanded.rfree_flags = reflection_data.rfree_flags[orig_idx_tensor]

    if expanded.cell is not None:
        expanded._calculate_resolution()

    # Bin indices are invalidated by expansion
    expanded.bin_indices = None

    # Copy metadata
    expanded.amplitude_source = reflection_data.amplitude_source
    expanded.intensity_source = reflection_data.intensity_source
    expanded.phase_source = reflection_data.phase_source
    expanded.rfree_source = reflection_data.rfree_source

    # Track provenance
    expanded.source = reflection_data
    expanded.last_op = f"expand_to_p1(include_friedel={include_friedel})"

    return expanded
def _check_systematic_absences(
    hkl: torch.Tensor,
    matrices: torch.Tensor,
    translations: torch.Tensor,
    device: torch.device,
) -> torch.Tensor:
    """
    Check which reflections are systematically absent.

    A reflection is absent if some symmetry operation maps h to h with a
    non-integer phase shift (h·t not integer).

    Parameters
    ----------
    hkl : torch.Tensor, shape (N, 3)
        Miller indices to check. It may live on any device; it is moved to
        ``device`` before being compared with the transformed indices.
    matrices : torch.Tensor, shape (n_ops, 3, 3)
        Rotation matrices; their transposes are applied to the Miller indices.
    translations : torch.Tensor, shape (n_ops, 3)
        Translation vectors.
    device : torch.device
        Computation device.

    Returns
    -------
    torch.Tensor, shape (N,), dtype bool
        True for systematically absent reflections.
    """
    float_dtype = get_float_dtype()

    # Bring every input onto the computation device once. The integer indices
    # must live there too, because they are compared element-wise with the
    # transformed indices below.
    hkl_dev = hkl.to(device=device)
    hkl_float = hkl_dev.to(dtype=float_dtype)
    recip_matrices = matrices.transpose(-2, -1).to(dtype=float_dtype, device=device)
    translations = translations.to(dtype=float_dtype, device=device)

    absent = torch.zeros(len(hkl_dev), dtype=torch.bool, device=device)
    for R, t in zip(recip_matrices, translations):
        # Transform: h' = h @ R^T, rounded back to integer indices.
        hkl_transformed = torch.round(torch.matmul(hkl_float, R.T)).to(torch.int32)
        # Only operators that map h onto itself can cause an absence.
        same_reflection = (hkl_transformed == hkl_dev).all(dim=-1)
        # Phase shift h·t (in cycles); a non-integer value means destructive
        # interference between h and its own image.
        h_dot_t = torch.matmul(hkl_float, t)
        non_integer_phase = torch.abs(h_dot_t - torch.round(h_dot_t)) > 1e-6
        absent |= same_reflection & non_integer_phase
    return absent
def _first_lexicographic_min(candidates: np.ndarray) -> np.ndarray:
    """Index (along axis 1) of the first lexicographically smallest row of an (N, C, 3) int array."""
    sentinel = np.iinfo(candidates.dtype).max
    best = np.ones(candidates.shape[:2], dtype=bool)
    # Narrow the set of minimal candidates component by component (h, then k, then l).
    for axis in range(3):
        component = candidates[..., axis]
        masked = np.where(best, component, sentinel)
        best &= component == masked.min(axis=1, keepdims=True)
    # argmax on a boolean array returns the first True, i.e. the first minimal candidate.
    return best.argmax(axis=1)


def reduce_hkl(
    hkl_p1: torch.Tensor,
    spacegroup: SpaceGroupLike,
    include_friedel: bool = True,
    device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Reduce P1 Miller indices to asymmetric unit of a target spacegroup.

    This is the inverse of expand_hkl(). Takes a complete set of P1 reflections
    and maps them back to the asymmetric unit of the target spacegroup.
    Multiple P1 reflections that are symmetry-equivalent merge into single
    ASU reflections.

    The canonical representative of a reflection is the lexicographically
    smallest (h, k, l) among its symmetry equivalents (and, when
    ``include_friedel`` is True, their Friedel mates).

    The return uses a 2D index mapping with "constant multiplicity" design:
    each ASU reflection has exactly one slot per symmetry operation (two when
    Friedel mates are included), enabling simple vectorized aggregation.

    Parameters
    ----------
    hkl_p1 : torch.Tensor, shape (N, 3)
        Input Miller indices in P1 (complete hemisphere).
    spacegroup : str, int, or gemmi.SpaceGroup
        Target space group specification.
    include_friedel : bool, default True
        If True, also consider Friedel mates when finding ASU representative.
    device : torch.device, optional
        Computation device. If None, uses hkl_p1's device.

    Returns
    -------
    hkl_asu : torch.Tensor, shape (M, 3), dtype int32
        Unique Miller indices in the asymmetric unit, sorted lexicographically.
    reduction_indices : torch.Tensor, shape (M, n_ops * (2 if include_friedel else 1)), dtype int64
        For each ASU reflection, indices into hkl_p1 for its symmetry
        equivalents. A P1 reflection is stored in slot ``i`` when operator
        ``i`` is the first operator mapping it onto the ASU index. With
        Friedel mates it is stored in slot ``n_ops + i`` when the first match
        is the Friedel mate of its operator-``i`` image. Operators are
        searched in order, each operator before its Friedel mate. Value of -1
        indicates no P1 reflection at that position. If hkl_p1 contains
        duplicate rows, the last occurrence is kept.
        Use: ``F_asu = aggregate(F_p1[reduction_indices], dim=1)``
    phase_shifts : torch.Tensor, shape (M, n_ops * (2 if include_friedel else 1))
        Phase shifts ``2*pi*h.t`` in radians (configured float dtype) to apply
        before aggregation. For Friedel slots the shift is negated. Empty
        slots hold 0.

    Examples
    --------
    ::

        import torch
        from torchref.symmetry import expand_hkl, reduce_hkl

        # Start with ASU, expand to P1, then reduce back
        hkl_asu = torch.tensor([[1, 0, 0], [0, 1, 0], [1, 1, 1]], dtype=torch.int32)
        hkl_p1, exp_idx, exp_phase = expand_hkl(hkl_asu, 'P21')

        # Reduce back to ASU
        hkl_asu_back, red_idx, red_phase = reduce_hkl(hkl_p1, 'P21')

        # Aggregate F values from P1 to ASU
        F_p1 = torch.randn(len(hkl_p1))
        valid_mask = red_idx >= 0
        F_gathered = torch.where(valid_mask, F_p1[red_idx.clamp(min=0)],
                                 torch.zeros_like(F_p1[0]))
        F_asu = F_gathered.sum(dim=1) / valid_mask.sum(dim=1).clamp(min=1)

    Notes
    -----
    The "constant multiplicity" design means the second dimension is always
    n_ops (or 2*n_ops with Friedel), regardless of whether all equivalent
    reflections exist in hkl_p1. Missing positions are marked with -1. This
    enables efficient batch aggregation without variable-length operations.
    """
    if device is None:
        device = hkl_p1.device
    float_dtype = get_float_dtype()

    # Get symmetry operations
    symmetry = SpaceGroup(spacegroup, dtype=float_dtype, device=device)
    n_ops = symmetry.matrices.shape[0]
    # Total number of equivalent positions (slots) per ASU reflection
    n_equiv = n_ops * (2 if include_friedel else 1)

    hkl_float = hkl_p1.to(dtype=float_dtype, device=device).reshape(-1, 3)
    n_p1 = hkl_float.shape[0]
    if n_p1 == 0:
        # Keep the documented (M, 3) / (M, n_equiv) shapes for empty input.
        return (
            torch.empty((0, 3), dtype=torch.int32, device=device),
            torch.empty((0, n_equiv), dtype=torch.int64, device=device),
            torch.empty((0, n_equiv), dtype=float_dtype, device=device),
        )

    # Reciprocal space matrices (transpose of real space)
    recip_matrices = symmetry.matrices.transpose(-2, -1).to(dtype=float_dtype, device=device)
    translations = symmetry.translations.to(dtype=float_dtype, device=device)

    # All equivalents at once: equiv[n, o] = h_n @ R_o^T, shape (N, n_ops, 3).
    equiv = torch.round(torch.matmul(hkl_float, recip_matrices.transpose(-2, -1)))
    equiv = equiv.permute(1, 0, 2).to(torch.int64).cpu().numpy()
    # shift[n, o] = 2*pi * h_n . t_o, shape (N, n_ops).
    shift = (2.0 * np.pi * torch.matmul(hkl_float, translations.T)).cpu().numpy()

    ops = np.arange(n_ops)
    if include_friedel:
        # Interleave (op0, -op0, op1, -op1, ...) so that "first candidate equal
        # to the canonical index" follows the per-operator search order.
        candidates = np.stack([equiv, -equiv], axis=2).reshape(n_p1, n_equiv, 3)
        cand_shift = np.stack([shift, -shift], axis=2).reshape(n_p1, n_equiv)
        cand_slot = np.stack([ops, ops + n_ops], axis=1).reshape(n_equiv)
    else:
        candidates, cand_shift, cand_slot = equiv, shift, ops

    rows = np.arange(n_p1)
    first = _first_lexicographic_min(candidates)  # (N,)
    canonical = candidates[rows, first]  # (N, 3)
    slot = cand_slot[first]  # (N,)
    phase = cand_shift[rows, first]  # (N,)

    # Group by canonical index; unique(axis=0) sorts rows lexicographically.
    asu_np, inverse = np.unique(canonical, axis=0, return_inverse=True)
    inverse = inverse.reshape(-1)
    n_asu = len(asu_np)

    red_np = np.full((n_asu, n_equiv), -1, dtype=np.int64)
    # maximum.at is deterministic if hkl_p1 has duplicate rows: the last one wins.
    np.maximum.at(red_np, (inverse, slot), rows)
    phase_np = np.zeros((n_asu, n_equiv), dtype=phase.dtype)
    filled = red_np >= 0
    phase_np[filled] = phase[red_np[filled]]

    hkl_asu = torch.as_tensor(asu_np, dtype=torch.int32, device=device)
    reduction_indices = torch.as_tensor(red_np, dtype=torch.int64, device=device)
    phase_shifts = torch.as_tensor(phase_np, dtype=float_dtype, device=device)
    return hkl_asu, reduction_indices, phase_shifts
# Reciprocal-ASU membership predicates, keyed by the exact CCP4 condition
# string reported by ``gemmi.ReciprocalAsu.condition_str()``. The ten
# distinct conditions cover all 230 space groups.
_CCP4_ASU_CONDITIONS = {
    # Laue -1 (triclinic)
    "l>0 or (l=0 and (h>0 or (h=0 and k>=0)))": (
        lambda h, k, l: (l > 0) | ((l == 0) & ((h > 0) | ((h == 0) & (k >= 0))))
    ),
    # Laue 2/m (monoclinic)
    "k>=0 and (l>0 or (l=0 and h>=0))": (
        lambda h, k, l: (k >= 0) & ((l > 0) | ((l == 0) & (h >= 0)))
    ),
    # Laue mmm (orthorhombic)
    "h>=0 and k>=0 and l>=0": (
        lambda h, k, l: (h >= 0) & (k >= 0) & (l >= 0)
    ),
    # Laue 4/m, 6/m (tetragonal, hexagonal)
    "l>=0 and ((h>=0 and k>0) or (h=0 and k=0))": (
        lambda h, k, l: (l >= 0) & (((h >= 0) & (k > 0)) | ((h == 0) & (k == 0)))
    ),
    # Laue 4/mmm, 6/mmm
    "h>=k and k>=0 and l>=0": (
        lambda h, k, l: (h >= k) & (k >= 0) & (l >= 0)
    ),
    # Laue -3 (trigonal, no mirror)
    "(h>=0 and k>0) or (h=0 and k=0 and l>=0)": (
        lambda h, k, l: ((h >= 0) & (k > 0)) | ((h == 0) & (k == 0) & (l >= 0))
    ),
    # Laue -3m, P312 variant
    "h>=k and k>=0 and (k>0 or l>=0)": (
        lambda h, k, l: (h >= k) & (k >= 0) & ((k > 0) | (l >= 0))
    ),
    # Laue -3m, P321 variant
    "h>=k and k>=0 and (h>k or l>=0)": (
        lambda h, k, l: (h >= k) & (k >= 0) & ((h > k) | (l >= 0))
    ),
    # Laue m-3 (cubic)
    "h>=0 and ((l>=h and k>h) or (l=h and k=h))": (
        lambda h, k, l: (h >= 0) & (((l >= h) & (k > h)) | ((l == h) & (k == h)))
    ),
    # Laue m-3m (cubic, full symmetry)
    "k>=l and l>=h and h>=0": (
        lambda h, k, l: (k >= l) & (l >= h) & (h >= 0)
    ),
}


def _asu_condition_vectorized(h, k, l, condition_key):
    """Evaluate a CCP4 reciprocal-ASU condition element-wise.

    Parameters
    ----------
    h, k, l : np.ndarray, shape ``(...,)``
        Miller index components (any broadcastable shape).
    condition_key : str
        CCP4 condition string as returned by
        ``gemmi.ReciprocalAsu.condition_str()``.

    Returns
    -------
    np.ndarray, same shape as input, dtype bool
        True where (h, k, l) satisfies the condition, i.e. lies inside the
        CCP4 reciprocal ASU.

    Raises
    ------
    ValueError
        If ``condition_key`` is not one of the known CCP4 conditions.
    """
    try:
        predicate = _CCP4_ASU_CONDITIONS[condition_key]
    except KeyError:
        raise ValueError(f"Unknown ASU condition: {condition_key}") from None
    return predicate(h, k, l)
def _make_asu_checker(asu, condition_key):
    """Return an ``(M, 3) int array -> (M,) bool`` reciprocal-ASU membership test.

    The vectorized CCP4 condition table is preferred. It is validated once
    against gemmi's own ``asu.is_in`` on a small probe grid of Miller indices.
    If the condition string is unknown, or the two disagree, the
    per-reflection gemmi check is used instead. NOTE(review): the CCP4
    conditions are written for the reference setting. gemmi's ``is_in``
    presumably also applies the change of basis of non-reference settings,
    which is where a disagreement would come from.
    """

    def per_row(equiv):
        return np.fromiter(
            (asu.is_in(row.tolist()) for row in equiv), dtype=bool, count=len(equiv)
        )

    def vectorized(equiv):
        return np.asarray(
            _asu_condition_vectorized(equiv[:, 0], equiv[:, 1], equiv[:, 2], condition_key),
            dtype=bool,
        )

    axis = np.arange(-3, 4, dtype=np.int32)
    probe = np.stack(np.meshgrid(axis, axis, axis, indexing="ij"), axis=-1).reshape(-1, 3)
    try:
        fast = vectorized(probe)
    except ValueError:  # condition string not in the vectorized table
        return per_row
    return vectorized if np.array_equal(fast, per_row(probe)) else per_row


def canonicalize_hkl(
    hkl: torch.Tensor,
    spacegroup: SpaceGroupLike,
    include_friedel: bool = True,
    device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Map Miller indices to canonical CCP4 ASU representatives.

    Uses the standard CCP4 asymmetric unit convention to select a unique
    representative for each reflection. Symmetry equivalents are generated
    one operator at a time with integer matrix products, and the ASU
    membership check is evaluated over all unresolved reflections at once.

    Parameters
    ----------
    hkl : torch.Tensor, shape (N, 3), dtype int32
        Input Miller indices.
    spacegroup : str, int, or gemmi.SpaceGroup
        Space group specification.
    include_friedel : bool, default True
        Whether Friedel mates are considered equivalent. When False, a
        reflection whose symmetry equivalents all lie outside the CCP4 ASU is
        represented by the equivalent whose Friedel mate lies inside it.
    device : torch.device, optional
        Computation device. If None, uses hkl's device.

    Returns
    -------
    canonical_hkl : torch.Tensor, shape (N, 3), dtype of ``hkl``
        Remapped indices, sorted lexicographically by (h, k, l) (stable sort).
    phase_shifts : torch.Tensor, shape (N,), configured float dtype
        Additive phase correction in radians.
    friedel_flags : torch.Tensor, shape (N,), dtype bool
        True where Friedel conjugation was applied.
    sort_indices : torch.Tensor, shape (N,), dtype int64
        Permutation from original to sorted order.

    Raises
    ------
    RuntimeError
        If some reflection cannot be mapped into the reciprocal ASU.

    Notes
    -----
    Phase correction contract — to convert structure factors to the
    canonical basis, the caller applies::

        phi_new = torch.where(friedel_flags, -phi_old, phi_old) + phase_shifts
    """
    import gemmi

    if device is None:
        device = hkl.device
    hkl_dtype = hkl.dtype
    float_dtype = get_float_dtype()

    n_refl = len(hkl)
    if n_refl == 0:
        return (
            torch.empty((0, 3), dtype=hkl_dtype, device=device),
            torch.empty(0, dtype=float_dtype, device=device),
            torch.empty(0, dtype=torch.bool, device=device),
            torch.empty(0, dtype=torch.int64, device=device),
        )

    # Normalize spacegroup (CPU-only: this branch builds numpy-backed lookup tables)
    sg_obj = SpaceGroup(spacegroup, dtype=float_dtype, device=torch.device("cpu"))
    asu = gemmi.ReciprocalAsu(sg_obj._gemmi)
    in_asu = _make_asu_checker(asu, asu.condition_str())

    # Reciprocal-space rotation matrices are always integer-valued (0, ±1).
    recip_mats_i = np.round(sg_obj.matrices.transpose(-2, -1).numpy()).astype(np.int32)
    translations_np = sg_obj.translations.numpy()  # (n_ops, 3)
    n_ops = len(recip_mats_i)
    hkl_np = hkl.cpu().numpy().astype(np.int32)  # (N, 3)

    canonical_np = np.empty_like(hkl_np)
    op_idx = np.empty(n_refl, dtype=np.int32)
    friedel_np = np.zeros(n_refl, dtype=bool)
    remaining = np.ones(n_refl, dtype=bool)

    def resolve(i_op, test_negated, friedel):
        """Assign unresolved reflections whose operator-``i_op`` equivalent
        (negated first when ``test_negated``) lies in the ASU."""
        idx = np.flatnonzero(remaining)
        if idx.size == 0:
            return
        equiv = hkl_np[idx] @ recip_mats_i[i_op].T  # int32 matmul — exact
        tested = -equiv if test_negated else equiv
        hits = in_asu(tested)
        global_idx = idx[hits]
        canonical_np[global_idx] = (tested if friedel else equiv)[hits]
        op_idx[global_idx] = i_op
        friedel_np[global_idx] = friedel
        remaining[global_idx] = False

    # Early-exit scan, one operator (+ its Friedel mate) at a time: most
    # reflections are resolved by the first few operators, so later operators
    # only process the shrinking unresolved subset.
    for i_op in range(n_ops):
        if not remaining.any():
            break
        resolve(i_op, test_negated=False, friedel=False)
        if include_friedel:
            resolve(i_op, test_negated=True, friedel=True)

    if not include_friedel:
        # The CCP4 conditions describe the ASU of the Laue group, which treats
        # Friedel mates as equivalent. Without Friedel symmetry an acentric
        # orbit can miss that region entirely; represent such a reflection by
        # the equivalent whose Friedel mate lies in the ASU (no conjugation).
        for i_op in range(n_ops):
            if not remaining.any():
                break
            resolve(i_op, test_negated=True, friedel=False)

    if remaining.any():
        raise RuntimeError(
            f"{int(remaining.sum())} reflection(s) could not be mapped into the "
            f"reciprocal ASU of space group {spacegroup!r}"
        )

    # phase_shift[i] = 2*pi * hkl_orig[i] . translations[op_idx[i]]
    t_selected = translations_np[op_idx]  # (N, 3)
    phase_shifts_np = (
        2.0 * np.pi * np.sum(hkl_np.astype(np.float32) * t_selected, axis=1)
    ).astype(np.float32)

    # Stable lexicographic sort by (h, k, l); np.lexsort treats its last key as primary.
    order = np.lexsort((canonical_np[:, 2], canonical_np[:, 1], canonical_np[:, 0]))

    return (
        torch.tensor(canonical_np[order], dtype=hkl_dtype, device=device),
        torch.tensor(phase_shifts_np[order], dtype=float_dtype, device=device),
        torch.tensor(friedel_np[order], dtype=torch.bool, device=device),
        torch.tensor(order, dtype=torch.int64, device=device),
    )
def expand_reciprocal_grid(
    F_grid: torch.Tensor,
    space_group: SpaceGroupLike,
    mode: str = "average",
    include_friedel: bool = True,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """
    Expand or symmetrize a reciprocal space grid using crystallographic symmetry.

    This is a convenience function that creates a ReciprocalSymmetryGrid and
    applies it to the input grid.

    Parameters
    ----------
    F_grid : torch.Tensor, shape (nh, nk, nl)
        Input structure factor grid (can be complex or real).
    space_group : str, int, or gemmi.SpaceGroup
        Space group specification (e.g., 'P21', 'P212121').
    mode : str, default 'average'
        Operation mode:
        - 'average': Average over all symmetry equivalents (symmetrize)
        - 'expand': Expand from asymmetric unit to full grid
        - 'sum': Sum all symmetry mates
    include_friedel : bool, default True
        If True, also apply Friedel symmetry after space group symmetry.
    device : torch.device, optional
        Device for computation. If None, uses F_grid's device; otherwise
        F_grid is moved to this device first.

    Returns
    -------
    torch.Tensor, shape (nh, nk, nl)
        Symmetrized or expanded structure factor grid.

    Examples
    --------
    ::

        import torch
        from torchref.symmetry import expand_reciprocal_grid

        # Create a test grid with some values in asymmetric unit
        F = torch.zeros(32, 32, 32, dtype=torch.complex64)
        F[5, 3, 2] = 1.0 + 0.5j

        # Expand to full grid
        F_full = expand_reciprocal_grid(F, 'P21', mode='expand')

        # Or symmetrize an existing full grid
        F_sym = expand_reciprocal_grid(F_full, 'P21', mode='average')
    """
    if device is None:
        device = F_grid.device
    # The symmetry operator is built on `device`; keep the data next to it.
    F_grid = F_grid.to(device)

    # Complex grids keep the precision of their real part; real grids use the
    # configured float dtype.
    dtype = F_grid.real.dtype if F_grid.is_complex() else get_float_dtype()

    # Create symmetry handler
    recip_sym = ReciprocalSymmetryGrid(
        space_group=space_group,
        grid_shape=F_grid.shape,
        dtype_float=dtype,
        verbose=0,
        device=device,
    )

    # Apply symmetry operation
    F_result = recip_sym(F_grid, mode=mode)

    # Apply Friedel symmetry if requested
    if include_friedel:
        F_result = recip_sym.apply_friedel(F_result)

    return F_result