# Source code for torchref.model.sf_ds

"""
SfDS - Structure Factor calculation via Direct Summation.

This module provides a PyTorch nn.Module that handles:
- Direct summation structure factor calculations
- Support for both isotropic and anisotropic atoms
- Crystallographic symmetry handling

The SfDS class provides an alternative to SfFFT for computing structure
factors without requiring a grid-based FFT approach.
"""

from typing import Optional, Tuple

import torch
import torch.nn as nn

from torchref.base.direct_summation import (
    iso_structure_factor_torched,
    aniso_structure_factor_torched,
)
from torchref.base.reciprocal import (
    get_scattering_vectors,
    reciprocal_basis_matrix,
)
from torchref.config import dtypes, get_default_device
from torchref.symmetry import Cell, SpaceGroup
from torchref.symmetry.spacegroup import SpaceGroupLike
from torchref.utils.device_mixin import DeviceMovementMixin


class SfDS(DeviceMovementMixin, nn.Module):
    """
    Structure Factor calculator using Direct Summation.

    This module computes structure factors by directly summing atomic
    contributions without building an intermediate electron density map.
    It is initialized with a Cell and optionally a SpaceGroup.

    Includes automatic batching to handle memory constraints for large
    structures or high-resolution data.

    Parameters
    ----------
    cell : Cell, optional
        Unit cell object containing cell parameters.
    spacegroup : SpaceGroupLike, optional
        Space group specification (string, int, or gemmi.SpaceGroup).
        If None, defaults to P1.
    dtype_float : torch.dtype, optional
        Data type for floating point tensors. Default is dtypes.float.
    device : torch.device, optional
        Computation device. If None (the default), the device returned by
        ``get_default_device()`` at construction time is used.
    verbose : int, optional
        Verbosity level for logging. Default is 0.
    max_memory_gb : float, optional
        Maximum memory to use for intermediate tensors in GB. Default is 2.0.
        Set to None to disable batching.

    Attributes
    ----------
    cell : Cell
        Unit cell object.
    spacegroup : SpaceGroup
        Space group object (SpaceGroup nn.Module with matrices and
        translations).

    Examples
    --------
    Standalone usage::

        from torchref.symmetry import Cell

        cell = Cell([50, 60, 70, 90, 90, 90])
        sf_ds = SfDS(cell, spacegroup='P212121')
        sf, _ = sf_ds.compute_structure_factors(
            hkl, xyz_iso, adp_iso, occ_iso, A_iso, B_iso
        )

    With memory limit for large structures::

        sf_ds = SfDS(cell, spacegroup='P212121', max_memory_gb=4.0)
        sf, _ = sf_ds.compute_structure_factors(...)  # Auto-batches if needed

    Notes
    -----
    Key differences from SfFFT:

    - No grid setup required
    - No build_density_map() or map_to_structure_factors() methods
    - Computes scattering factors internally from A/B ITC92 coefficients
    - Returns (sf, None) instead of (sf, density_map)
    - Automatic batching for memory management
    """

    def __init__(
        self,
        cell: Optional[Cell] = None,
        spacegroup: SpaceGroupLike = None,
        dtype_float: torch.dtype = dtypes.float,
        device: Optional[torch.device] = None,
        verbose: int = 0,
        max_memory_gb: Optional[float] = 2.0,
    ):
        """
        Initialize the SfDS module with cell and spacegroup.

        Parameters
        ----------
        cell : Cell, optional
            Unit cell object. If None, must be set later.
        spacegroup : SpaceGroupLike, optional
            Space group specification. If None, defaults to P1.
        dtype_float : torch.dtype, optional
            Data type for floating point tensors. Default is dtypes.float.
        device : torch.device, optional
            Computation device. If None, ``get_default_device()`` is queried
            when the instance is created.
        verbose : int, optional
            Verbosity level for logging. Default is 0.
        max_memory_gb : float, optional
            Maximum memory for intermediate tensors in GB. Default is 2.0.
        """
        super().__init__()

        # Resolve the default device per call. A call expression used as a
        # parameter default is evaluated only once, when the class body is
        # executed, which would freeze whatever device was configured at
        # import time.
        if device is None:
            device = get_default_device()

        self.dtype_float = dtype_float
        self.device = device
        self.verbose = verbose
        self.max_memory_gb = max_memory_gb

        # Store cell and spacegroup. A SpaceGroup object is built whenever a
        # cell or a spacegroup is supplied; with neither, it stays None.
        self._cell = cell
        self._spacegroup = None
        if spacegroup is not None or cell is not None:
            self._spacegroup = SpaceGroup(
                spacegroup, dtype=dtype_float, device=device
            )

        # Lazily computed reciprocal basis matrix (see
        # _get_reciprocal_basis_matrix / reset_cache).
        self._recB: Optional[torch.Tensor] = None

    # =========================================================================
    # Cell and SpaceGroup properties
    # =========================================================================

    @property
    def cell(self) -> Optional[Cell]:
        """Unit cell object."""
        return self._cell

    @cell.setter
    def cell(self, value: Cell):
        """Set unit cell and invalidate cached reciprocal basis matrix."""
        self._cell = value
        self._recB = None  # Invalidate cache

    @property
    def spacegroup(self) -> Optional[SpaceGroup]:
        """Space group object (SpaceGroup nn.Module)."""
        return self._spacegroup

    @spacegroup.setter
    def spacegroup(self, value: SpaceGroupLike):
        """Set space group; ``None`` clears it (treated as P1 downstream)."""
        if value is not None:
            self._spacegroup = SpaceGroup(
                value, dtype=self.dtype_float, device=self.device
            )
        else:
            self._spacegroup = None

    @property
    def fractional_matrix(self) -> Optional[torch.Tensor]:
        """Return ``cell.fractional_matrix``, or None when no cell is set."""
        if self._cell is not None:
            return self._cell.fractional_matrix
        return None

    @property
    def inv_fractional_matrix(self) -> Optional[torch.Tensor]:
        """
        Return ``cell.inv_fractional_matrix``, or None when no cell is set.

        This is the matrix ``_cartesian_to_fractional`` multiplies Cartesian
        coordinates by.
        """
        # NOTE(review): the previous docstrings called fractional_matrix the
        # "fractionalization" matrix and this one the "orthogonalization"
        # matrix, yet this one is used for Cartesian -> fractional. Confirm
        # the naming convention against Cell.
        if self._cell is not None:
            return self._cell.inv_fractional_matrix
        return None

    def set_cell_and_spacegroup(self, cell: Cell, spacegroup: SpaceGroupLike = None):
        """
        Set cell and spacegroup for this SfDS instance.

        Parameters
        ----------
        cell : Cell
            Unit cell object.
        spacegroup : SpaceGroupLike, optional
            Space group specification.
        """
        self._cell = cell
        self._recB = None  # Invalidate cache
        self.spacegroup = spacegroup

    # =========================================================================
    # Internal helper methods
    # =========================================================================

    def _require_cell(self) -> Cell:
        """
        Return the current cell, failing loudly if none has been set.

        Raises
        ------
        RuntimeError
            If cell has not been set.
        """
        if self._cell is None:
            raise RuntimeError("Cell not set. Call set_cell_and_spacegroup() first.")
        return self._cell

    def _get_reciprocal_basis_matrix(self) -> torch.Tensor:
        """
        Get or compute the reciprocal basis matrix.

        The result is cached in ``self._recB`` until the cell is replaced or
        ``reset_cache()`` is called.

        Returns
        -------
        torch.Tensor
            Reciprocal basis matrix of shape (3, 3) with a*, b*, c* as rows.

        Raises
        ------
        RuntimeError
            If cell has not been set.
        """
        cell = self._require_cell()
        if self._recB is None:
            self._recB = reciprocal_basis_matrix(cell.data)
        return self._recB

    def _compute_scattering_factors(
        self, s: torch.Tensor, A: torch.Tensor, B: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute atomic scattering factors from ITC92 A and B coefficients.

        The scattering factor is computed as:
            f(s) = sum_i A_i * exp(-B_i * s^2 / 4)

        Parameters
        ----------
        s : torch.Tensor
            Scattering vector magnitudes of shape (N_reflections,).
        A : torch.Tensor
            ITC92 A parameters (amplitudes) with shape (N_atoms, 5).
        B : torch.Tensor
            ITC92 B parameters (widths) with shape (N_atoms, 5).

        Returns
        -------
        torch.Tensor
            Atomic scattering factors of shape (N_reflections, N_atoms).
        """
        # s: (N_refl,) -> (N_refl, 1, 1) so it broadcasts against (1, N_atoms, 5)
        s_sq = (s.reshape(-1, 1, 1) ** 2) / 4

        # A, B: (N_atoms, 5) -> (1, N_atoms, 5)
        B_expanded = B.unsqueeze(0)
        A_expanded = A.unsqueeze(0)

        # Exponential terms: (N_refl, N_atoms, 5)
        exp_terms = torch.exp(-B_expanded * s_sq)

        # Sum over Gaussian components: (N_refl, N_atoms)
        return torch.sum(A_expanded * exp_terms, dim=-1)

    def _get_spacegroup_callable(self):
        """
        Create a callable for direct summation functions.

        The direct summation functions expect a callable that takes fractional
        coordinates with shape (3, N) and returns coordinates with shape
        (3, N, n_ops). This allows the structure factor summation to:

        1. Compute h.r for all (reflection, atom, symop) combinations
        2. Sum exp(2*pi*i*h.r) over symmetry operations for each
           (reflection, atom)
        3. Multiply by atom-specific factors and sum over atoms

        Returns
        -------
        callable
            Function that applies symmetry operations.
        """
        if self._spacegroup is None:
            # P1 symmetry - identity operation only
            def p1_symmetry(coords_3N):
                # coords_3N: (3, N) -> (3, N, 1)
                return coords_3N.unsqueeze(2)

            return p1_symmetry

        def apply_symmetry(coords_3N):
            # coords_3N: (3, N) -> coords_N3: (N, 3)
            coords_N3 = coords_3N.T
            # Apply symmetry: (N, 3) -> (N, 3, ops)
            transformed = self._spacegroup.apply(coords_N3)
            # Reorder (N, 3, ops) -> (3, N, ops)
            return transformed.permute(1, 0, 2)

        return apply_symmetry

    def _cartesian_to_fractional(self, xyz_cartesian: torch.Tensor) -> torch.Tensor:
        """
        Convert Cartesian coordinates to fractional coordinates.

        Parameters
        ----------
        xyz_cartesian : torch.Tensor
            Cartesian coordinates with shape (N, 3).

        Returns
        -------
        torch.Tensor
            Fractional coordinates with shape (N, 3).

        Raises
        ------
        RuntimeError
            If cell has not been set.
        """
        self._require_cell()
        # fractional = cartesian @ inv_frac_matrix.T
        return torch.matmul(xyz_cartesian, self.inv_fractional_matrix.T)

    # =========================================================================
    # Structure Factor Computation
    # =========================================================================

    def compute_structure_factors(
        self,
        hkl: torch.Tensor,
        xyz_iso: torch.Tensor,
        adp_iso: torch.Tensor,
        occ_iso: torch.Tensor,
        A_iso: torch.Tensor,
        B_iso: torch.Tensor,
        xyz_aniso: Optional[torch.Tensor] = None,
        u_aniso: Optional[torch.Tensor] = None,
        occ_aniso: Optional[torch.Tensor] = None,
        A_aniso: Optional[torch.Tensor] = None,
        B_aniso: Optional[torch.Tensor] = None,
        apply_symmetry: bool = True,
    ) -> Tuple[torch.Tensor, None]:
        """
        Compute structure factors from atomic parameters using direct summation.

        Uses "late symmetry" approach (same as SfFFT): first computes P1
        structure factors at symmetry-equivalent HKLs, then combines them
        with phase shifts.

        The symmetry formula is:
            F_sym(h) = Σ_ops exp(2πi h.t) * F_P1(R^T @ h)

        Parameters
        ----------
        hkl : torch.Tensor
            Miller indices with shape (n_reflections, 3).
        xyz_iso : torch.Tensor
            Isotropic atom coordinates (Cartesian) with shape (n_iso, 3).
        adp_iso : torch.Tensor
            Isotropic ADPs (atomic displacement parameters) with shape (n_iso,).
        occ_iso : torch.Tensor
            Isotropic occupancies with shape (n_iso,).
        A_iso : torch.Tensor
            ITC92 A parameters for isotropic atoms with shape (n_iso, 5).
        B_iso : torch.Tensor
            ITC92 B parameters for isotropic atoms with shape (n_iso, 5).
        xyz_aniso : torch.Tensor, optional
            Anisotropic atom coordinates (Cartesian) with shape (n_aniso, 3).
        u_aniso : torch.Tensor, optional
            Anisotropic U parameters with shape (n_aniso, 6).
        occ_aniso : torch.Tensor, optional
            Anisotropic occupancies with shape (n_aniso,).
        A_aniso : torch.Tensor, optional
            ITC92 A parameters for anisotropic atoms with shape (n_aniso, 5).
        B_aniso : torch.Tensor, optional
            ITC92 B parameters for anisotropic atoms with shape (n_aniso, 5).
        apply_symmetry : bool, optional
            If True, apply crystallographic symmetry. Default is True.

        Returns
        -------
        sf : torch.Tensor
            Complex structure factors with shape (n_reflections,).
        None
            Second return value is None (for API compatibility with SfFFT).

        Raises
        ------
        RuntimeError
            If cell has not been set.
        """
        self._require_cell()

        # Convert coordinates once; an empty or missing atom set becomes None
        # so that _compute_p1_sf skips the corresponding contribution.
        xyz_frac_iso = (
            self._cartesian_to_fractional(xyz_iso) if len(xyz_iso) > 0 else None
        )
        xyz_frac_aniso = (
            self._cartesian_to_fractional(xyz_aniso)
            if xyz_aniso is not None and len(xyz_aniso) > 0
            else None
        )

        # Per-atom arguments are identical for every P1 evaluation below.
        atom_args = (
            xyz_frac_iso, adp_iso, occ_iso, A_iso, B_iso,
            xyz_frac_aniso, u_aniso, occ_aniso, A_aniso, B_aniso,
        )

        # No symmetry: compute F_P1 directly
        if not apply_symmetry or self._spacegroup is None:
            return self._compute_p1_sf(hkl, *atom_args), None

        # Apply late symmetry: F_sym(h) = Σ_ops exp(2πi h.t) * F_P1(R^T @ h)
        from torchref.base.reciprocal import (
            compute_symmetry_equivalent_hkls,
            compute_translation_phases,
        )

        n_ops = self._spacegroup.n_ops
        rotation_matrices = self._spacegroup.matrices
        translations = self._spacegroup.translations

        # Equivalent HKLs: (n_ops, N, 3)
        equiv_hkls = compute_symmetry_equivalent_hkls(hkl, rotation_matrices)

        # Translation phase shifts: (n_ops, N)
        phases = compute_translation_phases(hkl, translations)

        # Accumulate one operator at a time so only a single P1 result is
        # alive per iteration.
        sf_total = torch.zeros(hkl.shape[0], dtype=torch.complex128, device=self.device)
        for i in range(n_ops):
            equiv_hkl_i = equiv_hkls[i].float()  # (N, 3)
            sf_p1_i = self._compute_p1_sf(equiv_hkl_i, *atom_args)
            # Apply phase and accumulate
            sf_total = sf_total + phases[i] * sf_p1_i

        return sf_total, None

    def _compute_p1_sf(
        self,
        hkl: torch.Tensor,
        xyz_frac_iso: Optional[torch.Tensor],
        adp_iso: torch.Tensor,
        occ_iso: torch.Tensor,
        A_iso: torch.Tensor,
        B_iso: torch.Tensor,
        xyz_frac_aniso: Optional[torch.Tensor],
        u_aniso: Optional[torch.Tensor],
        occ_aniso: Optional[torch.Tensor],
        A_aniso: Optional[torch.Tensor],
        B_aniso: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Compute P1 structure factors (no symmetry expansion).

        Always passes A/B coefficients to low-level functions so scattering
        factors are computed in batches, avoiding large (N_refl, N_atoms)
        tensors.

        Returns
        -------
        torch.Tensor
            complex128 tensor with one entry per row of ``hkl``; all zeros
            when both atom sets are None or empty.
        """

        # P1 identity
        def p1_symmetry(coords_3N):
            return coords_3N.unsqueeze(2)  # (3, N) -> (3, N, 1)

        # Get reciprocal basis matrix and compute scattering vectors
        recB = self._get_reciprocal_basis_matrix()
        s_vectors = get_scattering_vectors(hkl, self._cell.data, recB)
        s = torch.norm(s_vectors, dim=1)

        sf_total = torch.zeros(hkl.shape[0], dtype=torch.complex128, device=self.device)

        # Isotropic contribution - always pass A/B coefficients
        if xyz_frac_iso is not None and len(xyz_frac_iso) > 0:
            sf_iso = iso_structure_factor_torched(
                hkl=hkl,
                s=s,
                xyz_fractional=xyz_frac_iso,
                occ=occ_iso,
                scattering_factors=None,
                adp=adp_iso,
                spacegroup=p1_symmetry,
                max_memory_gb=self.max_memory_gb,
                A=A_iso,
                B_coeff=B_iso,
            )
            sf_total = sf_total + sf_iso

        # Anisotropic contribution - always pass A/B coefficients
        if xyz_frac_aniso is not None and len(xyz_frac_aniso) > 0:
            sf_aniso = aniso_structure_factor_torched(
                hkl=hkl,
                s_vector=s_vectors,
                xyz_fractional=xyz_frac_aniso,
                occ=occ_aniso,
                scattering_factors=None,
                U=u_aniso,
                spacegroup=p1_symmetry,
                max_memory_gb=self.max_memory_gb,
                A=A_aniso,
                B_coeff=B_aniso,
            )
            sf_total = sf_total + sf_aniso

        return sf_total

    # =========================================================================
    # Cache management and copying
    # =========================================================================

    def reset_cache(self) -> None:
        """Drop the cached reciprocal-basis matrix; recomputed on next use."""
        self._recB = None

    def copy(self) -> "SfDS":
        """Create a deep copy of this SfDS module.

        The cached reciprocal basis matrix is not carried over; the new
        instance recomputes it on first use.

        Returns
        -------
        SfDS
            A new SfDS instance with cloned cell and spacegroup.
        """
        # Clone the cell
        new_cell = self._cell.clone() if self._cell is not None else None

        # Copy the spacegroup
        new_spacegroup = (
            self._spacegroup.copy() if self._spacegroup is not None else None
        )

        # Create new SfDS with copied components
        return SfDS(
            cell=new_cell,
            spacegroup=new_spacegroup,
            dtype_float=self.dtype_float,
            device=self.device,
            verbose=self.verbose,
            max_memory_gb=self.max_memory_gb,
        )