torchref.base.scattering package

Atomic scattering factor functions.

This submodule provides functions for computing atomic scattering factors using the ITC92 parameterization.

Two approaches are available:

  1. Table-based lookup (recommended): fast, vectorized, and no gemmi dependency at runtime.
  2. Runtime gemmi calls (legacy): slower, and requires gemmi.

Example using table lookup:

from torchref.base.scattering import (
    load_scattering_table,
    get_scattering_params_by_z,
    elements_to_z,
)

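# Map element symbols to atomic numbers: tensor([6, 7, 8])
z = elements_to_z(['C', 'N', 'O'])
# Gather ITC92 A and B coefficients per atom, each of shape (3, 5)
A, B = get_scattering_params_by_z(z)

torchref.base.scattering.get_scattering_factors_unique(atoms, s)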

Compute scattering factors once for each unique element in a set of atoms.

Parameters:
  • atoms (DataFrame-like) – Atoms with ‘element’ and ‘charge’ attributes.

  • s (array-like) – Scattering vector magnitudes.

Returns:

Dictionary mapping element symbols to scattering factors.

Return type:

dict

torchref.base.scattering.get_scattering_factors(scattering_dict, elements)

Get scattering factors from a pre-computed dictionary.

Parameters:
  • scattering_dict (dict) – Dictionary of scattering factors by element.

  • elements (list) – List of element symbols.

Returns:

Scattering factors for the listed elements, looked up in scattering_dict and concatenated in the order given.

Return type:

torch.Tensor
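
Taken together, get_scattering_factors_unique and get_scattering_factors describe a compute-once-per-element, gather-per-atom pattern. The sketch below illustrates that pattern in plain Python under stated assumptions: toy_form_factor is a hypothetical single-Gaussian stand-in whose value at s = 0 equals the atomic number (not a real parameterization), and per-atom rows are returned as a list rather than concatenated into a torch.Tensor.

```python
import math
from typing import Callable, Dict, List


def toy_form_factor(element: str, s: float) -> float:
    # Hypothetical stand-in: one Gaussian whose value at s = 0 is the atomic number.
    atomic_number = {"C": 6.0, "N": 7.0, "O": 8.0}[element]
    return atomic_number * math.exp(-10.0 * s * s)


def unique_factors(elements: List[str], s_values: List[float],
                   form_factor: Callable[[str, float], float]) -> Dict[str, List[float]]:
    # Evaluate each distinct element once, however many atoms share it.
    return {el: [form_factor(el, s) for s in s_values] for el in sorted(set(elements))}


def per_atom_factors(table: Dict[str, List[float]], elements: List[str]) -> List[List[float]]:
    # Gather the precomputed row for every atom, preserving atom order.
    return [table[el] for el in elements]


atoms = ["C", "C", "N", "O", "C"]
s_values = [0.0, 0.25, 0.5]
table = unique_factors(atoms, s_values, toy_form_factor)
rows = per_atom_factors(table, atoms)
print(len(table), len(rows))  # 3 5
```

The expensive evaluation scales with the number of distinct elements, while the per-atom step is only a lookup.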

torchref.base.scattering.get_scattering_itc92(df, s)

Get ITC92 scattering factors using gemmi.

Parameters:
  • df (DataFrame) – DataFrame with ‘element’ column.

  • s (torch.Tensor) – Scattering vector magnitudes.

Returns:

Scattering factors for all atoms.

Return type:

torch.Tensor

torchref.base.scattering.calc_scattering_factors_paramtetrization(parametrization, s, atom_list)

Calculate scattering factors from ITC92 parametrization.

Parameters:
  • parametrization (dict) – Dictionary of (A, B, C) tuples by element.

  • s (torch.Tensor) – Scattering vector magnitudes.

  • atom_list (list) – List of atom symbols.

Returns:

Scattering factors.

Return type:

torch.Tensor
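
For reference, the ITC92 (International Tables Vol. C) parameterization writes the atomic form factor as a sum of four Gaussians plus a constant, f(s) = Σ a_i exp(-b_i s²) + c. The sketch below evaluates that expression, assuming s = sin θ/λ in Å⁻¹ (the docs above do not state which convention torchref uses for s) and using the Cromer-Mann coefficients commonly tabulated for neutral carbon.

```python
import math
from typing import Sequence


def itc92_form_factor(a: Sequence[float], b: Sequence[float], c: float, s: float) -> float:
    """Evaluate f(s) = sum_i a_i * exp(-b_i * s**2) + c, assuming s = sin(theta)/lambda in 1/Angstrom."""
    return sum(ai * math.exp(-bi * s * s) for ai, bi in zip(a, b)) + c


# Commonly tabulated Cromer-Mann coefficients for neutral carbon (b in Angstrom^2).
A_C = [2.3100, 1.0200, 1.5886, 0.8650]
B_C = [20.8439, 10.2075, 0.5687, 51.6512]
C_C = 0.2156

for s in (0.0, 0.25, 0.5):
    print(f"s={s:.2f}  f_C={itc92_form_factor(A_C, B_C, C_C, s):.3f}")
```

At s = 0 every exponential equals 1, so f(0) = Σ a_i + c ≈ 6, the electron count of carbon; f then falls off with increasing s.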

torchref.base.scattering.get_parameterization(df)

Get ITC92 parametrization for atoms in a DataFrame.

Parameters:

df (DataFrame) – DataFrame with ‘element’ and ‘charge’ columns.

Returns:

Dictionary of parametrization by element.

Return type:

dict

torchref.base.scattering.get_parameterization_extended(df)

Extended parametrization function that handles all atoms in a DataFrame.

Creates a dictionary mapping element symbols (and optionally charges) to their ITC92 parameters (A, B, C). It is intended for Fourier-transform-based calculations, which need fast access to the parameters without evaluating them at specific scattering vectors.

Parameters:

df (pandas.DataFrame) – DataFrame with ‘element’ and ‘charge’ columns.

Returns:

Mapping {element_str: (A, B, C)}, where:
  • A (torch.Tensor, shape (1, 4)) – amplitude coefficients.
  • B (torch.Tensor, shape (1, 4)) – width coefficients (Å²).
  • C (torch.Tensor, shape (1,)) – constant term.

Return type:

dict

torchref.base.scattering.get_parametrization_for_elements(elements, charges=None)

Get ITC92 parametrization for a list of elements.

Useful for getting parametrization for specific atoms without a full DataFrame.

Parameters:
  • elements (list of str) – Element symbols (e.g., [‘C’, ‘N’, ‘O’]).

  • charges (list of int, optional) – Charges for each element (default: all zeros).

Returns:

Mapping {element: (A, B, C)} of ITC92 parameters.

Return type:

dict

torchref.base.scattering.get_parametrization_atom(charge, atom)

Get ITC92 parametrization for a single atom.

Parameters:
  • charge (int) – Atomic charge.

  • atom (str) – Element symbol.

Returns:

[A, B] tensors for the atom.

Return type:

list

torchref.base.scattering.linear_interpolation(x, x0, x1, y0, y1)

Perform linear interpolation.

Parameters:
  • x (array-like) – Query points.

  • x0 (array-like) – Lower bound x values.

  • x1 (array-like) – Upper bound x values.

  • y0 (array-like) – Lower bound y values.

  • y1 (array-like) – Upper bound y values.

Returns:

Interpolated y values.

Return type:

array-like
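
Linear interpolation uses the standard two-point formula y = y0 + (x - x0)(y1 - y0)/(x1 - x0). The sketch below applies it element by element in plain Python; it assumes x1 != x0 for every point and does not attempt the tensor broadcasting the real function presumably supports.

```python
from typing import List, Sequence


def linear_interpolation_sketch(x: Sequence[float], x0: Sequence[float], x1: Sequence[float],
                                y0: Sequence[float], y1: Sequence[float]) -> List[float]:
    # y = y0 + (x - x0) * (y1 - y0) / (x1 - x0), applied point by point.
    return [b0 + (xi - a0) * (b1 - b0) / (a1 - a0)
            for xi, a0, a1, b0, b1 in zip(x, x0, x1, y0, y1)]


print(linear_interpolation_sketch([1.5, 2.0], [1.0, 0.0], [2.0, 4.0], [10.0, 0.0], [20.0, 8.0]))
# [15.0, 4.0]
```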

torchref.base.scattering.load_scattering_table(device=None, dtype=None, force_reload=False)

Load pre-computed ITC92 scattering-factor coefficients from a .pt file.

The table is cached globally after first load for efficiency. Use force_reload=True to reload from disk.

Parameters:
  • device (torch.device, optional) – Device to place tensors on. Default is None (keeps original device).

  • dtype (torch.dtype, optional) – Data type for floating point tensors. Default is None (keeps original dtype).

  • force_reload (bool, optional) – Force reload from disk even if cached. Default is False.

Returns:

Dictionary containing:
  • ‘A’: Tensor(max_z + 1, 5), neutral-atom A coefficients indexed by Z.
  • ‘B’: Tensor(max_z + 1, 5), neutral-atom B coefficients indexed by Z.
  • ‘element_to_z’: dict mapping element symbols to atomic numbers.
  • ‘z_to_element’: dict mapping atomic numbers to element symbols.
  • ‘ions’: dict mapping ion keys to (A, B) tuples.
  • ‘metadata’: dict with source information.

Return type:

dict

Raises:

FileNotFoundError – If the pre-computed table file does not exist.

Examples

import torch
from torchref.base.scattering import load_scattering_table

table = load_scattering_table(device='cuda', dtype=torch.float32)  # requires a CUDA device
A = table['A']  # Shape (104, 5)
z_to_elem = table['z_to_element']  # {1: 'H', 6: 'C', ...}
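
The caching behavior described above (load once, reuse, reload on request) can be sketched as a module-level cache. This is a minimal sketch, assuming a hypothetical loader callable in place of reading the .pt file; the real function also handles device and dtype, which the sketch omits.

```python
from typing import Any, Callable, Dict, List, Optional

_TABLE_CACHE: Optional[Dict[str, Any]] = None


def load_table_cached(loader: Callable[[], Dict[str, Any]],
                      force_reload: bool = False) -> Dict[str, Any]:
    """Return the cached table, calling `loader` only on first use or when force_reload is True."""
    global _TABLE_CACHE
    if _TABLE_CACHE is None or force_reload:
        _TABLE_CACHE = loader()
    return _TABLE_CACHE


calls: List[int] = []


def fake_loader() -> Dict[str, Any]:
    # Hypothetical loader standing in for reading the table from disk.
    calls.append(1)
    return {"element_to_z": {"H": 1, "C": 6}}


t1 = load_table_cached(fake_loader)
t2 = load_table_cached(fake_loader)                     # served from the cache
t3 = load_table_cached(fake_loader, force_reload=True)  # loader runs again
print(len(calls), t1 is t2, t1 is t3)  # 2 True False
```
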

torchref.base.scattering.get_scattering_params_by_z(z_tensor, device=None, dtype=None)

Fast vectorized lookup of scattering parameters by atomic number.

This function provides efficient batch lookup of ITC92 scattering parameters using tensor indexing. All atoms are looked up in a single operation.

Parameters:
  • z_tensor (torch.Tensor) – Atomic numbers for all atoms, shape (n_atoms,). Values should be in range 1-103.

  • device (torch.device, optional) – Device to place output tensors on. Default uses z_tensor’s device.

  • dtype (torch.dtype, optional) – Data type for output tensors. Default is torch.float32.

Returns:

  • A (torch.Tensor) – ITC92 A parameters (amplitudes) with shape (n_atoms, 5).

  • B (torch.Tensor) – ITC92 B parameters (widths) with shape (n_atoms, 5).

Return type:

Tuple[Tensor, Tensor]

Examples

import torch
from torchref.base.scattering import get_scattering_params_by_z

# Get parameters for C, N, O atoms
z = torch.tensor([6, 7, 8, 6, 6, 7])
A, B = get_scattering_params_by_z(z)
# A.shape == (6, 5), B.shape == (6, 5)

# Use on GPU (requires a CUDA device)
z_gpu = z.cuda()
A, B = get_scattering_params_by_z(z_gpu, device='cuda')
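
Because row Z of the coefficient table holds the parameters for atomic number Z (see load_scattering_table), fetching parameters for every atom amounts to a row gather. The plain-Python sketch below shows that gather on a tiny placeholder table whose values are not ITC92 coefficients. In PyTorch the same gather is a single advanced-indexing expression, table[z_tensor], which returns shape (n_atoms, 5); that is presumably how the single-operation lookup works, though the source does not show the implementation.

```python
from typing import List

# Placeholder table: row Z holds 5 dummy coefficients equal to Z (not ITC92 values).
max_z = 8
A_table: List[List[float]] = [[float(z)] * 5 for z in range(max_z + 1)]


def params_by_z(table: List[List[float]], z_values: List[int]) -> List[List[float]]:
    # One row per atom; with torch this would be table[z_tensor].
    return [table[z] for z in z_values]


rows = params_by_z(A_table, [6, 7, 8, 6])
print(len(rows), len(rows[0]))  # 4 5
```
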

torchref.base.scattering.get_element_to_z_mapping()

Return element symbol to atomic number mapping.

This function loads the scattering table if not already cached and returns the element_to_z dictionary.

Returns:

Mapping of element symbols to atomic numbers. Example: {‘H’: 1, ‘C’: 6, ‘N’: 7, ‘O’: 8, …}

Return type:

dict

Examples

from torchref.base.scattering import get_element_to_z_mapping

element_to_z = get_element_to_z_mapping()
z_carbon = element_to_z['C']  # 6

torchref.base.scattering.get_z_to_element_mapping()

Return atomic number to element symbol mapping.

Returns:

Mapping of atomic numbers to element symbols. Example: {1: ‘H’, 6: ‘C’, 7: ‘N’, 8: ‘O’, …}

Return type:

dict

torchref.base.scattering.get_scattering_params_for_ion(element, charge, device=None, dtype=None)

Get scattering parameters for a specific ion.

Parameters:
  • element (str) – Element symbol (e.g., ‘Fe’, ‘O’).

  • charge (int) – Ionic charge (positive or negative integer).

  • device (torch.device, optional) – Device to place tensors on.

  • dtype (torch.dtype, optional) – Data type for tensors.

Returns:

(A, B) tensors of shape (5,), or None if ion not found.

Return type:

tuple or None

Examples

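from torchref.base.scattering import get_scattering_params_for_ion

# Get Fe2+ parameters (returns None if the ion is not in the table)
A, B = get_scattering_params_for_ion('Fe', 2)

# Get O2- parameters
A, B = get_scattering_params_for_ion('O', -2)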
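
The documented contract (look up an ion by element and charge, return None when it is absent) can be sketched with a dictionary lookup. The key format below ('Fe2+', 'O2-') is a hypothetical choice for illustration; torchref's actual ion-key format is not documented here.

```python
from typing import Dict, List, Optional, Tuple

IonParams = Tuple[List[float], List[float]]


def ion_key(element: str, charge: int) -> str:
    # Hypothetical key format, e.g. 'Fe2+' or 'O2-'; neutral atoms use the bare symbol.
    if charge == 0:
        return element
    sign = "+" if charge > 0 else "-"
    return f"{element}{abs(charge)}{sign}"


def lookup_ion(ions: Dict[str, IonParams], element: str, charge: int) -> Optional[IonParams]:
    # Mirror the documented behavior: None when the ion is not tabulated.
    return ions.get(ion_key(element, charge))


# Placeholder table with dummy coefficients (not ITC92 values).
ions = {"Fe2+": ([0.0] * 5, [0.0] * 5)}
print(lookup_ion(ions, "Fe", 2) is not None, lookup_ion(ions, "Fe", 3))  # True None
```

Callers should check for None before unpacking the result into A and B.
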

torchref.base.scattering.elements_to_z(elements, normalize=True)

Convert a list of element symbols to atomic numbers.

Parameters:
  • elements (list of str) – Element symbols (e.g., [‘C’, ‘N’, ‘O’, ‘C’, ‘C’]).

  • normalize (bool, optional) – If True, normalize element names (strip whitespace, capitalize). Default is True.

Returns:

Tensor of atomic numbers with shape (n_atoms,). Unknown elements are assigned Z=0.

Return type:

torch.Tensor

Examples

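from torchref.base.scattering import elements_to_z

z = elements_to_z(['C', 'N', 'O'])
# Returns tensor([6, 7, 8])

# normalize=True (the default) strips whitespace and fixes capitalization
z = elements_to_z([' c ', 'N', 'O'], normalize=True)
# Returns tensor([6, 7, 8])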
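
The normalization and fallback rules above can be sketched in plain Python. This sketch assumes "capitalize" means Python's str.capitalize (first letter upper case, the rest lower case), takes a small hand-written mapping instead of the table's element_to_z, and returns a list rather than a tensor.

```python
from typing import Dict, List


def elements_to_z_sketch(elements: List[str], element_to_z: Dict[str, int],
                         normalize: bool = True) -> List[int]:
    out = []
    for el in elements:
        if normalize:
            # Strip whitespace and capitalize: ' c ' -> 'C', 'FE' -> 'Fe'.
            el = el.strip().capitalize()
        # Unknown symbols fall back to Z = 0, as documented above.
        out.append(element_to_z.get(el, 0))
    return out


mapping = {"H": 1, "C": 6, "N": 7, "O": 8, "Fe": 26}
print(elements_to_z_sketch([" c ", "N", "FE", "Xx"], mapping))  # [6, 7, 26, 0]
```

Without normalization, ' c ' would not match 'C' and would silently become Z = 0.
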

torchref.base.scattering.get_anomalous_correction(element, wavelength)

Get f’ and f’’ for a single element at a given wavelength.

Uses the Cromer-Liberman calculation via gemmi and returns the standard crystallographic f’ and f’’ values, as tabulated in the International Tables for Crystallography.

Parameters:
  • element (str) – Element symbol (e.g., ‘Fe’, ‘Hg’, ‘Se’)

  • wavelength (float) – X-ray wavelength in Angstroms

Returns:

  • f_prime (float) – Real anomalous correction (electrons)

  • f_double_prime (float) – Imaginary anomalous correction (electrons)

Return type:

Tuple[float, float]

Examples

>>> from torchref.base.scattering import get_anomalous_correction
>>> f_prime, f_double_prime = get_anomalous_correction('Se', 0.9792)
>>> print(f"Se at Se K-edge: f'={f_prime:.2f}, f''={f_double_prime:.2f}")

Notes

The gemmi.cromer_liberman function expects energy in eV, not keV.
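
Since gemmi expects the energy in eV, a wavelength in Å must first be converted with E = hc/λ, where hc ≈ 12398.42 eV·Å. The helper below is a hypothetical sketch of that conversion; how torchref performs it internally is not shown in this documentation.

```python
HC_EV_ANGSTROM = 12398.4198  # h*c in eV*Angstrom (approximate)


def wavelength_to_energy_ev(wavelength_angstrom: float) -> float:
    """Convert an X-ray wavelength in Angstrom to photon energy in eV (E = hc / lambda)."""
    if wavelength_angstrom <= 0:
        raise ValueError("wavelength must be positive")
    return HC_EV_ANGSTROM / wavelength_angstrom


energy_ev = wavelength_to_energy_ev(0.9792)
print(f"{energy_ev:.1f} eV = {energy_ev / 1000:.3f} keV")
```

Passing the keV value by mistake would shift the energy by a factor of 1000 and give meaningless f’ and f’’.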

torchref.base.scattering.get_significant_elements(elements, wavelength, threshold=0.5)

Find elements with significant anomalous scattering.

An element is considered significant if |f'| > threshold OR |f''| > threshold. This filtering avoids unnecessary computation for light atoms (C, N, O, H), which have negligible anomalous contributions at typical wavelengths.

Parameters:
  • elements (list of str) – List of unique element symbols

  • wavelength (float) – X-ray wavelength in Angstroms

  • threshold (float, optional) – Significance threshold in electrons (default: 0.5)

Returns:

{element: (f_prime, f_double_prime)} for significant elements only

Return type:

dict

Examples

>>> from torchref.base.scattering import get_significant_elements
>>> elements = ['C', 'N', 'O', 'S', 'Fe', 'Zn']
>>> significant = get_significant_elements(elements, wavelength=1.0)
>>> print(f"Significant anomalous scatterers: {list(significant.keys())}")
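
The significance test itself is a simple filter on precomputed values. The sketch below assumes the f’ and f’’ values are already available in a dictionary (in the real function they come from gemmi for the given wavelength); the numbers used are illustrative placeholders, not tabulated data.

```python
from typing import Dict, Tuple


def filter_significant(corrections: Dict[str, Tuple[float, float]],
                       threshold: float = 0.5) -> Dict[str, Tuple[float, float]]:
    # Keep an element if either |f'| or |f''| exceeds the threshold (in electrons).
    return {el: (fp, fdp) for el, (fp, fdp) in corrections.items()
            if abs(fp) > threshold or abs(fdp) > threshold}


# Illustrative placeholder values, not tabulated f'/f'' data.
corrections = {"C": (0.01, 0.01), "S": (0.3, 0.6), "Fe": (-1.2, 3.1)}
print(sorted(filter_significant(corrections)))  # ['Fe', 'S']
```

Here S is kept because |f’’| alone exceeds the threshold, which is the OR condition stated above.
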

torchref.base.scattering.get_anomalous_corrections_by_indices(element_list, significant_elements, device, dtype)

Get anomalous corrections and mask for atoms needing correction.

This function creates tensors suitable for vectorized computation of the anomalous correction to structure factors.

Parameters:
  • element_list (list of str) – Element symbols for all atoms (length n_atoms)

  • significant_elements (dict) – {element: (f_prime, f_double_prime)} from get_significant_elements()

  • device (torch.device) – Device for output tensors

  • dtype (torch.dtype) – Data type for output tensors

Returns:

  • mask (torch.Tensor) – Boolean mask of shape (n_atoms,) - True for atoms needing correction

  • f_prime (torch.Tensor) – f’ values for significant atoms only, shape (n_significant,).

  • f_double_prime (torch.Tensor) – f’’ values for significant atoms only, shape (n_significant,).

Return type:

Tuple[Tensor, Tensor, Tensor]

Examples

>>> import torch
>>> from torchref.base.scattering import get_anomalous_corrections_by_indices
>>> elements = ['C', 'C', 'N', 'Fe', 'O', 'Fe']
>>> significant = {'Fe': (-1.2, 3.1)}
>>> mask, fp, fdp = get_anomalous_corrections_by_indices(
...     elements, significant, torch.device('cpu'), torch.float32
... )
>>> print(f"Atoms needing correction: {mask.sum().item()}")  # 2 (two Fe atoms)
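
The mask-plus-compact-arrays layout can be sketched in plain Python, together with how such values are typically applied: the standard anomalous form factor is f = f0 + f’ + i f’’. This is a minimal sketch with lists instead of tensors; apply_corrections is a hypothetical helper, and f0 is taken as a single number per atom (the atomic number, i.e. f at s = 0) rather than a function of s.

```python
from typing import Dict, List, Tuple


def corrections_by_mask(element_list: List[str], significant: Dict[str, Tuple[float, float]]
                        ) -> Tuple[List[bool], List[float], List[float]]:
    # mask[i] is True when atom i's element has a significant anomalous correction.
    mask = [el in significant for el in element_list]
    # Compact arrays hold values only for masked atoms, in atom order.
    f_prime = [significant[el][0] for el in element_list if el in significant]
    f_double_prime = [significant[el][1] for el in element_list if el in significant]
    return mask, f_prime, f_double_prime


def apply_corrections(f0: List[float], mask: List[bool], f_prime: List[float],
                      f_double_prime: List[float]) -> List[complex]:
    # f = f0 + f' + i f'' for masked atoms; other atoms keep a real f0.
    out = [complex(v, 0.0) for v in f0]
    masked_indices = [i for i, m in enumerate(mask) if m]
    for k, i in enumerate(masked_indices):
        out[i] = complex(f0[i] + f_prime[k], f_double_prime[k])
    return out


elements = ["C", "C", "N", "Fe", "O", "Fe"]
mask, fp, fdp = corrections_by_mask(elements, {"Fe": (-1.2, 3.1)})
print(sum(mask), fp, fdp)  # 2 [-1.2, -1.2] [3.1, 3.1]
```

Storing only the n_significant values keeps the correction arrays small when most atoms are light elements.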

Submodules