# Source code for torchref.base.scattering.scattering_table

"""
Table-based ITC92 scattering factor lookup.

This module provides fast, vectorized lookup of scattering factor parameters
using a pre-computed table stored as a .pt file. The table eliminates
runtime gemmi dependency for scattering parameter access.

The table supports:
- Neutral atoms for Z=1 to 103 (H to Lr)
- Common ions with various charge states
- Element symbol to atomic number mapping

Example
-------
::

    from torchref.base.scattering.scattering_table import (
        load_scattering_table,
        get_scattering_params_by_z,
        get_element_to_z_mapping,
    )

    # Load the table
    table = load_scattering_table(device='cuda')

    # Get parameters for multiple atoms at once
    z_tensor = torch.tensor([6, 7, 8])  # C, N, O
    A, B = get_scattering_params_by_z(z_tensor)
"""

import os
from typing import Dict, Optional, Tuple

import torch

from torchref.config import get_float_dtype

# Module-level cache for the scattering table. load_scattering_table() fills it
# on first use with the dict read from disk (tensors mapped to CPU) and reuses
# it on later calls; passing force_reload=True replaces it with a fresh load.
_TABLE_CACHE: Optional[dict] = None


def _get_table_path() -> str:
    """
    Build the filesystem path of the pre-computed ITC92 table.

    Returns
    -------
    str
        ``PATH_TORCHREF_DATA`` joined with ``"itc92_scattering_factors.pt"``.
        The file is not checked for existence here.
    """
    # Imported inside the function rather than at module level; presumably to
    # avoid a circular import with the top-level torchref package — verify.
    from torchref import PATH_TORCHREF_DATA

    table_filename = "itc92_scattering_factors.pt"
    full_path = os.path.join(PATH_TORCHREF_DATA, table_filename)
    return full_path


def _convert_tensor(
    value: torch.Tensor,
    device: Optional[torch.device],
    dtype: Optional[torch.dtype],
) -> torch.Tensor:
    """
    Move a tensor to ``device`` and, if it is floating point, cast it to ``dtype``.

    Both changes are applied in a single ``Tensor.to`` call, so at most one
    copy is made. Non-floating tensors keep their dtype. If there is nothing to
    change, the input tensor is returned unchanged.
    """
    to_kwargs = {}
    if device is not None:
        to_kwargs["device"] = device
    if dtype is not None and value.is_floating_point():
        to_kwargs["dtype"] = dtype
    return value.to(**to_kwargs) if to_kwargs else value


def load_scattering_table(
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    force_reload: bool = False,
) -> dict:
    """
    Load pre-computed ITC92 scattering factors from .pt file.

    The table is read from disk once (tensors mapped to CPU) and cached at
    module level. Later calls reuse the cache unless ``force_reload=True``.

    Parameters
    ----------
    device : torch.device or str, optional
        Device to place tensors on. Default is None, which leaves the tensors
        on CPU, where the cached table is loaded.
    dtype : torch.dtype, optional
        Data type for floating point tensors, including ion (A, B) tensors.
        Non-floating tensors are never cast. Default is None (keeps the
        stored dtype).
    force_reload : bool, optional
        Force reload from disk even if cached. Default is False.

    Returns
    -------
    dict
        Dictionary containing:

        - 'A': Tensor(max_z + 1, 5) - neutral A coefficients indexed by Z
        - 'B': Tensor(max_z + 1, 5) - neutral B coefficients indexed by Z
        - 'element_to_z': dict mapping element symbols to atomic numbers
        - 'z_to_element': dict mapping atomic numbers to element symbols
        - 'ions': dict mapping ion keys (e.g. 'Fe2+', 'O2-') to (A, B) tuples
        - 'metadata': dict with source information

        When neither ``device`` nor ``dtype`` is given, the cached dict itself
        is returned (not a copy), so callers must not mutate it. Otherwise a
        new top-level dict with converted tensors is returned. Its non-tensor
        values (e.g. the mappings) are still shared with the cache.

    Raises
    ------
    FileNotFoundError
        If the pre-computed table file does not exist.

    Examples
    --------
    ::

        table = load_scattering_table(device='cuda', dtype=torch.float32)
        A = table['A']  # Shape (104, 5)
        z_to_elem = table['z_to_element']  # {1: 'H', 6: 'C', ...}
    """
    global _TABLE_CACHE

    if _TABLE_CACHE is None or force_reload:
        table_path = _get_table_path()
        if not os.path.exists(table_path):
            raise FileNotFoundError(
                f"Scattering factor table not found at {table_path}. "
                "Run 'python -m torchref.scripts.generate_scattering_table' to generate it."
            )
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # objects. That is acceptable for the data file bundled with the
        # package, but this path must never point at untrusted files.
        _TABLE_CACHE = torch.load(table_path, map_location="cpu", weights_only=False)

    table = _TABLE_CACHE

    # Fast path: no conversion requested, hand back the cached table.
    if device is None and dtype is None:
        return table

    result = {}
    for key, value in table.items():
        if isinstance(value, torch.Tensor):
            result[key] = _convert_tensor(value, device, dtype)
        elif key == "ions" and isinstance(value, dict):
            # Ion entries are (A, B) tuples. Convert them with the same rules
            # as top-level tensors so behavior stays consistent.
            result[key] = {
                ion_key: (
                    _convert_tensor(A, device, dtype),
                    _convert_tensor(B, device, dtype),
                )
                for ion_key, (A, B) in value.items()
            }
        else:
            result[key] = value
    return result
def get_element_to_z_mapping() -> Dict[str, int]:
    """
    Return the element-symbol → atomic-number mapping from the scattering table.

    The table is loaded (and cached) on first use via ``load_scattering_table``.
    Because no device/dtype conversion is requested, the returned dict is the
    one held in the cached table, not a copy. Treat it as read-only.

    Returns
    -------
    dict
        Mapping of element symbols to atomic numbers.
        Example: {'H': 1, 'C': 6, 'N': 7, 'O': 8, ...}

    Examples
    --------
    ::

        element_to_z = get_element_to_z_mapping()
        z_carbon = element_to_z['C']  # Returns 6
    """
    scattering_table = load_scattering_table()
    symbol_lookup = scattering_table["element_to_z"]
    return symbol_lookup
def get_z_to_element_mapping() -> Dict[int, str]:
    """
    Return the atomic-number → element-symbol mapping from the scattering table.

    The returned dict is the one held in the cached table (no copy is made),
    so treat it as read-only.

    Returns
    -------
    dict
        Mapping of atomic numbers to element symbols.
        Example: {1: 'H', 6: 'C', 7: 'N', 8: 'O', ...}
    """
    scattering_table = load_scattering_table()
    number_lookup = scattering_table["z_to_element"]
    return number_lookup
def get_scattering_params_by_z(
    z_tensor: torch.Tensor,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Fast vectorized lookup of scattering parameters by atomic number.

    All atoms are looked up in a single tensor-indexing operation on the
    neutral-atom A/B tables.

    Parameters
    ----------
    z_tensor : torch.Tensor
        Atomic numbers for all atoms, typically shape (n_atoms,). Values
        should be in range 1-103. Values are used directly as row indices:
        out-of-range values raise an indexing error and negative values wrap
        around. Z=0 (what ``elements_to_z`` assigns to unknown elements)
        selects row 0 of the table, whose contents depend on the generated
        table file — verify before relying on it.
    device : torch.device, optional
        Device to place output tensors on. Default uses z_tensor's device.
    dtype : torch.dtype, optional
        Data type for output tensors. Default is the project-wide float
        dtype returned by ``torchref.config.get_float_dtype()``.

    Returns
    -------
    A : torch.Tensor
        ITC92 A parameters (amplitudes) with shape ``z_tensor.shape + (5,)``,
        i.e. (n_atoms, 5) for a 1-D input.
    B : torch.Tensor
        ITC92 B parameters (widths) with the same shape as ``A``.

    Examples
    --------
    ::

        # Get parameters for C, N, O atoms
        z = torch.tensor([6, 7, 8, 6, 6, 7])
        A, B = get_scattering_params_by_z(z)
        # A.shape == (6, 5), B.shape == (6, 5)

        # Use on GPU
        z_gpu = z.cuda()
        A, B = get_scattering_params_by_z(z_gpu, device='cuda')
    """
    if device is None:
        device = z_tensor.device
    if dtype is None:
        dtype = get_float_dtype()

    # Only the neutral A/B tables are needed. Convert just those two small
    # tensors, rather than asking load_scattering_table to copy every ion
    # entry to `device`/`dtype` on each call.
    table = load_scattering_table()
    A_table = table["A"].to(device=device, dtype=dtype)
    B_table = table["B"].to(device=device, dtype=dtype)

    # Index tensors must be integer (long) and on the same device as the table.
    z_idx = z_tensor.to(device=device, dtype=torch.long)

    return A_table[z_idx], B_table[z_idx]
def get_scattering_params_for_ion(
    element: str,
    charge: int,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
    """
    Get scattering parameters for a specific ion.

    A non-zero charge is looked up in the table's ``'ions'`` dict under the
    key ``f"{element}{abs(charge)}{'+' or '-'}"`` (e.g. 'Fe2+', 'O2-'). A
    zero charge falls back to the neutral-atom table via ``'element_to_z'``.
    The symbol is matched exactly (case-sensitive, no whitespace stripping).

    Parameters
    ----------
    element : str
        Element symbol (e.g., 'Fe', 'O').
    charge : int
        Ionic charge (positive or negative integer); 0 means neutral.
    device : torch.device, optional
        Device to place tensors on. Default leaves them on the cached table's
        device (CPU).
    dtype : torch.dtype, optional
        Data type for tensors. Default is ``torchref.config.get_float_dtype()``.

    Returns
    -------
    tuple or None
        (A, B) tensors of shape (5,), or None if the ion (or, for charge 0,
        the element) is not in the table.

    Examples
    --------
    ::

        # Get Fe2+ parameters
        A, B = get_scattering_params_for_ion('Fe', 2)

        # Get O2- parameters
        A, B = get_scattering_params_for_ion('O', -2)
    """
    if dtype is None:
        dtype = get_float_dtype()

    # Look the entry up in the cached table first, then convert only the
    # selected (A, B) pair instead of converting every ion in the table.
    table = load_scattering_table()

    if charge == 0:
        # Neutral: use the Z-indexed neutral-atom table.
        z = table["element_to_z"].get(element)
        if z is None:
            return None
        A, B = table["A"][z], table["B"][z]
    else:
        sign = "+" if charge > 0 else "-"
        key = f"{element}{abs(charge)}{sign}"
        ions = table.get("ions", {})
        if key not in ions:
            return None
        A, B = ions[key]

    to_kwargs = {"dtype": dtype}
    if device is not None:
        to_kwargs["device"] = device
    return A.to(**to_kwargs), B.to(**to_kwargs)
def elements_to_z(elements: list, normalize: bool = True) -> torch.Tensor:
    """
    Translate element symbols into a tensor of atomic numbers.

    Parameters
    ----------
    elements : list of str
        Element symbols (e.g., ['C', 'N', 'O', 'C', 'C']).
    normalize : bool, optional
        When True, each symbol is passed through ``str.strip()`` and then
        ``str.capitalize()`` before lookup (so ' fe ' becomes 'Fe').
        When False, symbols are looked up exactly as given. Default is True.

    Returns
    -------
    torch.Tensor
        int32 tensor of atomic numbers with shape (n_atoms,). Symbols missing
        from the table's element mapping become Z=0.

    Examples
    --------
    ::

        z = elements_to_z(['C', 'N', 'O'])
        # Returns tensor([6, 7, 8])

        z = elements_to_z([' c ', 'N', 'O'], normalize=True)
        # Returns tensor([6, 7, 8])
    """
    symbol_lookup = get_element_to_z_mapping()

    def _clean(symbol):
        # Optional canonicalisation: trim whitespace, then "Xx" casing.
        return symbol.strip().capitalize() if normalize else symbol

    atomic_numbers = [symbol_lookup.get(_clean(symbol), 0) for symbol in elements]
    return torch.tensor(atomic_numbers, dtype=torch.int32)