torchref.base.scattering package
Atomic scattering factor functions.
This submodule provides functions for computing atomic scattering factors using the ITC92 parameterization.
Two approaches are available:
1. Table-based lookup (recommended): fast, vectorized, and free of any runtime gemmi dependency.
2. Runtime gemmi calls (legacy): slower, and requires gemmi.
Example using table lookup:
from torchref.base.scattering import (
load_scattering_table,
get_scattering_params_by_z,
elements_to_z,
)
z = elements_to_z(['C', 'N', 'O'])
A, B = get_scattering_params_by_z(z)
- torchref.base.scattering.get_scattering_factors_unique(atoms, s)
Compute unique scattering factors for a set of atoms.
- Parameters:
atoms (DataFrame-like) – Atoms with ‘element’ and ‘charge’ attributes.
s (array-like) – Scattering vector magnitudes.
- Returns:
Dictionary mapping element symbols to scattering factors.
- Return type:
dict
- torchref.base.scattering.get_scattering_factors(scattering_dict, elements)
Get scattering factors from a pre-computed dictionary.
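- Parameters:
scattering_dict (dict) – Pre-computed dictionary mapping element symbols to scattering factors.
elements (list of str) – Element symbols of the atoms whose scattering factors are requested.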
- Returns:
Concatenated scattering factors.
- Return type:
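get_scattering_factors_unique and get_scattering_factors appear to follow a compute-once, gather-per-atom pattern: form factors are evaluated once per distinct element and then expanded to one entry per atom. The sketch below illustrates that pattern with plain Python lists instead of tensors; the helper name `unique_then_gather` and the `form_factor` callable are hypothetical stand-ins, not part of torchref.

```python
from typing import Callable, Dict, List, Sequence


def unique_then_gather(
    elements: Sequence[str],
    s: Sequence[float],
    form_factor: Callable[[str, Sequence[float]], List[float]],
) -> List[List[float]]:
    """Evaluate form_factor once per distinct element, then return one row per atom."""
    # dict.fromkeys de-duplicates while keeping first-seen order.
    per_element: Dict[str, List[float]] = {
        el: form_factor(el, s) for el in dict.fromkeys(elements)
    }
    # Gather step: atoms of the same element share the same cached row.
    return [per_element[el] for el in elements]
```

With N atoms but only a handful of element types, the expensive evaluation runs a handful of times rather than N times.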
- torchref.base.scattering.get_scattering_itc92(df, s)
Get ITC92 scattering factors using gemmi.
- Parameters:
df (DataFrame) – DataFrame with ‘element’ column.
s (torch.Tensor) – Scattering vector magnitudes.
- Returns:
Scattering factors for all atoms.
- Return type:
- torchref.base.scattering.calc_scattering_factors_paramtetrization(parametrization, s, atom_list)
Calculate scattering factors from ITC92 parametrization.
- Parameters:
parametrization (dict) – Dictionary of (A, B, C) tuples by element.
s (torch.Tensor) – Scattering vector magnitudes.
atom_list (list) – List of atom symbols.
- Returns:
Scattering factors.
- Return type:
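The ITC92 (Cromer-Mann) parametrization writes the neutral-atom form factor as four Gaussians plus a constant, f0(s) = Σ a_i exp(-b_i s²) + c, with s = sin θ / λ in Å⁻¹. A minimal stdlib sketch follows. It assumes `s` means sin θ / λ; if the library instead passes 1/d = 2 sin θ / λ, the exponent becomes -b_i s² / 4. The helper name `itc92_f0` is hypothetical, and the carbon coefficients are the commonly tabulated International Tables values for neutral C.

```python
import math
from typing import Sequence


def itc92_f0(a: Sequence[float], b: Sequence[float], c: float, s: float) -> float:
    """f0(s) = sum_i a_i * exp(-b_i * s**2) + c, with s = sin(theta)/lambda in 1/Å (assumed)."""
    if len(a) != len(b):
        raise ValueError("a and b must have the same length")
    return sum(ai * math.exp(-bi * s * s) for ai, bi in zip(a, b)) + c


# Neutral carbon coefficients (International Tables, Vol. C).
C_A = (2.3100, 1.0200, 1.5886, 0.8650)
C_B = (20.8439, 10.2075, 0.5687, 51.6512)
C_C = 0.2156

print(round(itc92_f0(C_A, C_B, C_C, 0.0), 4))  # 5.9992, close to Z = 6 electrons
```

At s = 0 every exponential equals 1, so f0 reduces to Σ a_i + c, which should be close to the number of electrons; this makes a quick sanity check on any parameter table.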
- torchref.base.scattering.get_parameterization(df)
Get ITC92 parametrization for atoms in a DataFrame.
- Parameters:
df (DataFrame) – DataFrame with ‘element’ and ‘charge’ columns.
- Returns:
Dictionary of parametrization by element.
- Return type:
dict
- torchref.base.scattering.get_parameterization_extended(df)
Extended parametrization function that handles all atoms in a DataFrame.
Creates a dictionary mapping element symbols (and, optionally, charges) to their ITC92 parameters (A, B, C). It is intended for FT-based calculations that need fast access to the parameters without evaluating them at specific scattering vectors.
- Parameters:
df (pandas.DataFrame) – DataFrame with ‘element’ and ‘charge’ columns
- Returns:
dict – Mapping from each element (or element and charge) key to its ITC92 parameters:
  - A: torch.Tensor, shape (1, 4), amplitude coefficients
  - B: torch.Tensor, shape (1, 4), width coefficients (Å²)
  - C: torch.Tensor, shape (1,), constant term
- Return type:
dict
- torchref.base.scattering.get_parametrization_for_elements(elements, charges=None)
Get ITC92 parametrization for a list of elements.
Useful for getting parametrization for specific atoms without a full DataFrame.
- torchref.base.scattering.get_parametrization_atom(charge, atom)
Get ITC92 parametrization for a single atom.
- torchref.base.scattering.linear_interpolation(x, x0, x1, y0, y1)
Perform linear interpolation.
- Parameters:
x (array-like) – Query points.
x0 (array-like) – Lower bound x values.
x1 (array-like) – Upper bound x values.
y0 (array-like) – Lower bound y values.
y1 (array-like) – Upper bound y values.
- Returns:
Interpolated y values.
- Return type:
array-like
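Given these parameters, linear_interpolation presumably applies the standard two-point formula y = y0 + (x - x0)(y1 - y0)/(x1 - x0) element by element. A stdlib sketch of that formula, with lists standing in for array-likes (the name `lerp_elementwise` is hypothetical):

```python
from typing import List, Sequence


def lerp_elementwise(
    x: Sequence[float],
    x0: Sequence[float],
    x1: Sequence[float],
    y0: Sequence[float],
    y1: Sequence[float],
) -> List[float]:
    """y = y0 + (x - x0) * (y1 - y0) / (x1 - x0), evaluated element by element."""
    # Raises ZeroDivisionError where x1 == x0; callers must avoid degenerate brackets.
    return [
        ya + (xq - xa) * (yb - ya) / (xb - xa)
        for xq, xa, xb, ya, yb in zip(x, x0, x1, y0, y1)
    ]


print(lerp_elementwise([0.5, 2.0], [0.0, 1.0], [1.0, 3.0], [0.0, 10.0], [2.0, 20.0]))
# [1.0, 15.0]
```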
- torchref.base.scattering.load_scattering_table(device=None, dtype=None, force_reload=False)
Load pre-computed ITC92 scattering factors from .pt file.
The table is cached globally after first load for efficiency. Use force_reload=True to reload from disk.
- Parameters:
device (torch.device, optional) – Device to place tensors on. Default is None (keeps original device).
dtype (torch.dtype, optional) – Data type for floating point tensors. Default is None (keeps original dtype).
force_reload (bool, optional) – Force reload from disk even if cached. Default is False.
- Returns:
Dictionary containing:
  - ‘A’: Tensor(max_z + 1, 5), neutral-atom A coefficients indexed by Z
  - ‘B’: Tensor(max_z + 1, 5), neutral-atom B coefficients indexed by Z
  - ‘element_to_z’: dict mapping element symbols to atomic numbers
  - ‘z_to_element’: dict mapping atomic numbers to element symbols
  - ‘ions’: dict mapping ion keys to (A, B) tuples
  - ‘metadata’: dict with source information
- Return type:
dict
- Raises:
FileNotFoundError – If the pre-computed table file does not exist.
Examples
import torch
from torchref.base.scattering import load_scattering_table
table = load_scattering_table(device='cuda', dtype=torch.float32)
A = table['A']  # Shape (104, 5)
z_to_elem = table['z_to_element']  # {1: 'H', 6: 'C', ...}
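The global caching described above can be pictured as a module-level variable that is filled on first use and refilled only when force_reload=True. In this sketch `_TABLE_CACHE`, `load_table_cached`, and the `loader` callable are hypothetical, and the device/dtype conversion that load_scattering_table also performs is omitted.

```python
from typing import Any, Callable, Dict, Optional

_TABLE_CACHE: Optional[Dict[str, Any]] = None  # module-level cache (hypothetical)


def load_table_cached(
    loader: Callable[[], Dict[str, Any]], force_reload: bool = False
) -> Dict[str, Any]:
    """Return the cached table, calling loader() only on first use or when forced."""
    global _TABLE_CACHE
    if _TABLE_CACHE is None or force_reload:
        _TABLE_CACHE = loader()
    return _TABLE_CACHE
```

In this pattern every caller receives the same dictionary object, so in-place modifications are visible to all later callers.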
- torchref.base.scattering.get_scattering_params_by_z(z_tensor, device=None, dtype=None)
Fast vectorized lookup of scattering parameters by atomic number.
This function provides efficient batch lookup of ITC92 scattering parameters using tensor indexing. All atoms are looked up in a single operation.
- Parameters:
z_tensor (torch.Tensor) – Atomic numbers for all atoms, shape (n_atoms,). Values should be in the range 1-103.
device (torch.device, optional) – Device to place output tensors on. Default uses z_tensor’s device.
dtype (torch.dtype, optional) – Data type for output tensors. Default is torch.float32.
- Returns:
A (torch.Tensor) – ITC92 A parameters (amplitudes) with shape (n_atoms, 5).
B (torch.Tensor) – ITC92 B parameters (widths) with shape (n_atoms, 5).
- Return type:
tuple
Examples
import torch
from torchref.base.scattering import get_scattering_params_by_z
# Get parameters for C, N, O atoms
z = torch.tensor([6, 7, 8, 6, 6, 7])
A, B = get_scattering_params_by_z(z)
# A.shape == (6, 5), B.shape == (6, 5)
# Use on GPU
z_gpu = z.cuda()
A, B = get_scattering_params_by_z(z_gpu, device='cuda')
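Because the table rows are indexed directly by atomic number (row Z holds the coefficients for element Z, and row 0 exists because the table has max_z + 1 rows), a batch lookup amounts to a single gather; with tensors this could be plain advanced indexing such as table['A'][z]. A stdlib sketch of the same gather with explicit bounds checking (the helper name `params_by_z` and the toy table are hypothetical):

```python
from typing import List, Sequence


def params_by_z(table: Sequence[Sequence[float]], z: Sequence[int]) -> List[List[float]]:
    """Gather one coefficient row per atom from a table whose row index is Z."""
    n_rows = len(table)
    out: List[List[float]] = []
    for zi in z:
        # Reject atomic numbers with no row, instead of silently wrapping negatives.
        if not 0 <= zi < n_rows:
            raise IndexError(f"atomic number {zi} outside table rows 0..{n_rows - 1}")
        out.append(list(table[zi]))
    return out


# Toy 3-row table (placeholder values, not ITC92 coefficients).
toy_table = [[0.0, 0.0], [1.0, 1.5], [2.0, 2.5]]
print(params_by_z(toy_table, [2, 1, 2]))  # [[2.0, 2.5], [1.0, 1.5], [2.0, 2.5]]
```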
- torchref.base.scattering.get_element_to_z_mapping()
Return element symbol to atomic number mapping.
This function loads the scattering table if not already cached and returns the element_to_z dictionary.
- Returns:
Mapping of element symbols to atomic numbers. Example: {‘H’: 1, ‘C’: 6, ‘N’: 7, ‘O’: 8, …}
- Return type:
dict
Examples
from torchref.base.scattering import get_element_to_z_mapping
element_to_z = get_element_to_z_mapping()
z_carbon = element_to_z['C']  # Returns 6
- torchref.base.scattering.get_z_to_element_mapping()
Return atomic number to element symbol mapping.
- Returns:
Mapping of atomic numbers to element symbols. Example: {1: ‘H’, 6: ‘C’, 7: ‘N’, 8: ‘O’, …}
- Return type:
dict
- torchref.base.scattering.get_scattering_params_for_ion(element, charge, device=None, dtype=None)
Get scattering parameters for a specific ion.
- Parameters:
element (str) – Element symbol (e.g., ‘Fe’, ‘O’).
charge (int) – Ionic charge (positive or negative integer).
device (torch.device, optional) – Device to place tensors on.
dtype (torch.dtype, optional) – Data type for tensors.
- Returns:
(A, B) tensors of shape (5,), or None if ion not found.
- Return type:
tuple or None
Examples
# Get Fe2+ parameters
A, B = get_scattering_params_for_ion('Fe', 2)
# Get O2- parameters
A, B = get_scattering_params_for_ion('O', -2)
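- torchref.base.scattering.elements_to_z(elements, normalize=True)
Convert a list of element symbols to atomic numbers.
- Parameters:
elements (list of str) – Element symbols, one per atom.
normalize (bool, optional) – If True, clean up symbols (e.g., strip surrounding whitespace and fix capitalization) before lookup. Default is True.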
- Returns:
Tensor of atomic numbers with shape (n_atoms,). Unknown elements are assigned Z=0.
- Return type:
torch.Tensor
Examples
z = elements_to_z(['C', 'N', 'O'])  # Returns tensor([6, 7, 8])
z = elements_to_z([' c ', 'N', 'O'], normalize=True)  # Returns tensor([6, 7, 8])
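The normalize=True example suggests that symbols are cleaned (whitespace stripped, capitalization fixed) before lookup, and that unknown symbols map to Z = 0. The sketch below implements one plausible normalization rule under that assumption; `ELEMENT_TO_Z` is a small hypothetical excerpt rather than the full mapping, and `normalize_symbol` / `symbols_to_z` are illustrative names.

```python
from typing import Dict, List, Sequence

# Hypothetical excerpt; the real mapping covers Z = 1..103.
ELEMENT_TO_Z: Dict[str, int] = {"H": 1, "C": 6, "N": 7, "O": 8, "S": 16, "Fe": 26, "Se": 34}


def normalize_symbol(sym: str) -> str:
    """Strip whitespace, then capitalize the first letter and lowercase the rest."""
    s = sym.strip()
    return s[:1].upper() + s[1:].lower()


def symbols_to_z(elements: Sequence[str], normalize: bool = True) -> List[int]:
    """Map element symbols to atomic numbers; unknown symbols become 0."""
    out: List[int] = []
    for el in elements:
        key = normalize_symbol(el) if normalize else el
        out.append(ELEMENT_TO_Z.get(key, 0))
    return out


print(symbols_to_z([" c ", "N", "O"]))  # [6, 7, 8]
```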
- torchref.base.scattering.get_anomalous_correction(element, wavelength)
Get f’ and f’’ for a single element at a given wavelength.
Uses the Cromer-Liberman calculation via gemmi and returns the standard crystallographic f’ and f’’ values, as used in International Tables.
- Parameters:
element (str) – Element symbol (e.g., ‘Se’).
wavelength (float) – X-ray wavelength in Å.
- Returns:
f_prime (float) – Real anomalous correction (electrons)
f_double_prime (float) – Imaginary anomalous correction (electrons)
- Return type:
tuple
Examples
>>> f_prime, f_double_prime = get_anomalous_correction('Se', 0.9792)
>>> print(f"Se at Se K-edge: f'={f_prime:.2f}, f''={f_double_prime:.2f}")
Notes
The gemmi.cromer_liberman function expects energy in eV, not keV.
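Because gemmi.cromer_liberman takes an energy in eV while get_anomalous_correction accepts a wavelength, the wavelength has to be converted with E = hc/λ, where hc ≈ 12398.4198 eV·Å. A minimal sketch of that conversion; the constant and function names are illustrative, and how torchref performs the conversion internally is not documented here.

```python
HC_EV_ANGSTROM = 12398.4198  # h*c in eV·Å (CODATA 2018, rounded)


def wavelength_to_ev(wavelength_angstrom: float) -> float:
    """Convert an X-ray wavelength in Å to photon energy in eV."""
    if wavelength_angstrom <= 0:
        raise ValueError("wavelength must be positive")
    return HC_EV_ANGSTROM / wavelength_angstrom


print(f"{wavelength_to_ev(0.9792):.1f} eV")  # about 12661.8 eV, near the Se K-edge
```

Passing a keV value (about 12.66 for this wavelength) where eV is expected would place the calculation far below any absorption edge of interest, which is the mistake the note above warns against.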
- torchref.base.scattering.get_significant_elements(elements, wavelength, threshold=0.5)
Find elements with significant anomalous scattering.
An element is considered significant if |f'| > threshold OR |f''| > threshold. This filtering avoids unnecessary computation for light atoms (C, N, O, H) which have negligible anomalous contributions at typical wavelengths.
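- Parameters:
elements (list of str) – Element symbols to check.
wavelength (float) – X-ray wavelength in Å.
threshold (float, optional) – Minimum |f'| or |f''| (electrons) for an element to count as significant. Default is 0.5.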
- Returns:
{element: (f_prime, f_double_prime)} for significant elements only
- Return type:
dict
Examples
>>> elements = ['C', 'N', 'O', 'S', 'Fe', 'Zn']
>>> significant = get_significant_elements(elements, wavelength=1.0)
>>> print(f"Significant anomalous scatterers: {list(significant.keys())}")
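The selection rule above (keep an element if |f'| or |f''| exceeds the threshold) can be sketched independently of gemmi by passing the correction lookup in as a callable. `filter_significant` is a hypothetical helper, and the f'/f'' numbers below are illustrative placeholders, not computed Cromer-Liberman values (the Fe pair is borrowed from the get_anomalous_corrections_by_indices example).

```python
from typing import Callable, Dict, Iterable, Tuple

Correction = Tuple[float, float]  # (f_prime, f_double_prime) in electrons


def filter_significant(
    elements: Iterable[str],
    correction: Callable[[str], Correction],
    threshold: float = 0.5,
) -> Dict[str, Correction]:
    """Keep elements whose |f'| or |f''| exceeds threshold; each element is evaluated once."""
    result: Dict[str, Correction] = {}
    for el in dict.fromkeys(elements):
        fp, fdp = correction(el)
        if abs(fp) > threshold or abs(fdp) > threshold:
            result[el] = (fp, fdp)
    return result


# Placeholder corrections for illustration only.
toy = {"C": (0.0, 0.01), "S": (0.3, 0.6), "Fe": (-1.2, 3.1)}
print(filter_significant(["C", "S", "C", "Fe"], toy.__getitem__))
# {'S': (0.3, 0.6), 'Fe': (-1.2, 3.1)}
```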
- torchref.base.scattering.get_anomalous_corrections_by_indices(element_list, significant_elements, device, dtype)
Get anomalous corrections and mask for atoms needing correction.
This function creates tensors suitable for vectorized computation of the anomalous correction to structure factors.
- Parameters:
element_list (list of str) – Element symbols for all atoms (length n_atoms)
significant_elements (dict) – {element: (f_prime, f_double_prime)} from get_significant_elements()
device (torch.device) – Device for output tensors
dtype (torch.dtype) – Data type for output tensors
- Returns:
mask (torch.Tensor) – Boolean mask of shape (n_atoms,) - True for atoms needing correction
f_prime (torch.Tensor) – f’ values for significant atoms only (n_significant,)
f_double_prime (torch.Tensor) – f’’ values for significant atoms only (n_significant,)
- Return type:
tuple
Examples
>>> import torch
>>> elements = ['C', 'C', 'N', 'Fe', 'O', 'Fe']
>>> significant = {'Fe': (-1.2, 3.1)}
>>> mask, fp, fdp = get_anomalous_corrections_by_indices(
...     elements, significant, torch.device('cpu'), torch.float32
... )
>>> print(f"Atoms needing correction: {mask.sum().item()}")  # 2 (two Fe atoms)
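The mask plus compact f'/f'' layout lets the correction be applied only where it is needed: for the masked atoms, f = f0 + f' + i f''. The sketch below builds the same three outputs from plain lists and then applies them to a list of complex form factors. `corrections_by_mask` and `apply_anomalous` are hypothetical helpers, the f0 values are simply Z (form factors at s = 0), and with tensors the second step could be a masked in-place add.

```python
from typing import Dict, List, Sequence, Tuple


def corrections_by_mask(
    element_list: Sequence[str], significant: Dict[str, Tuple[float, float]]
) -> Tuple[List[bool], List[float], List[float]]:
    """Boolean mask over all atoms plus f'/f'' for the masked atoms only, in atom order."""
    mask = [el in significant for el in element_list]
    fp = [significant[el][0] for el in element_list if el in significant]
    fdp = [significant[el][1] for el in element_list if el in significant]
    return mask, fp, fdp


def apply_anomalous(
    f0: Sequence[float], mask: Sequence[bool], fp: Sequence[float], fdp: Sequence[float]
) -> List[complex]:
    """Return f = f0 + f' + i f'' for masked atoms and f = f0 elsewhere."""
    out = [complex(v) for v in f0]
    j = 0  # position in the compact fp/fdp lists
    for i, needs_correction in enumerate(mask):
        if needs_correction:
            out[i] += complex(fp[j], fdp[j])
            j += 1
    return out


elements = ["C", "C", "N", "Fe", "O", "Fe"]
mask, fp, fdp = corrections_by_mask(elements, {"Fe": (-1.2, 3.1)})
print(sum(mask))  # 2
```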