Source code for torchref.scripts.generate_scattering_table

#!/usr/bin/env python
"""
Generate pre-computed ITC92 scattering factor table.

This script creates a .pt file containing scattering factor parameters
for all elements (Z=1 to 103) and available ions. The generated file
removes the need for runtime gemmi dependency for scattering parameter lookup.

Usage:
    python -m torchref.scripts.generate_scattering_table

Output:
    torchref/data/itc92_scattering_factors.pt
"""

import os
from pathlib import Path

import pandas as pd
import torch

# gemmi is only required for generation, not at runtime
import gemmi


def get_element_to_z_from_csv(csv_path: str) -> dict:
    """
    Read an element-symbol -> atomic-number lookup from a CSV file.

    Parameters
    ----------
    csv_path : str
        Location of the CSV (e.g. ``atomic_vdw_radii.csv``). Lines starting
        with ``#`` are skipped as comments. Each data row must provide an
        ``element`` (string) and an ``Atomic_Number`` column.

    Returns
    -------
    dict
        ``{symbol: Z}`` where symbols are whitespace-stripped and Z values
        are converted with ``int``. If a symbol occurs on several rows, the
        last occurrence wins.
    """
    frame = pd.read_csv(csv_path, comment="#")
    records = frame.to_dict("records")
    return {
        record["element"].strip(): int(record["Atomic_Number"])
        for record in records
    }
def get_itc92_params(element: str, charge: int = 0):
    """
    Get ITC92 scattering parameters for an element with optional charge.

    Parameters
    ----------
    element : str
        Element symbol (e.g., 'C', 'Fe').
    charge : int, optional
        Ionic charge. Default is 0.

    Returns
    -------
    tuple or None
        ``(A, B)`` float32 tensors built as ``A = [a1, a2, a3, a4, c]`` and
        ``B = [b1, b2, b3, b4, 0.0]`` (shape (5,) assuming gemmi reports four
        Gaussian terms). ``None`` if gemmi has no exact entry for this
        element/charge, or if the gemmi lookup itself fails.
    """
    # Only the gemmi lookup is treated as best-effort: its bindings may raise
    # for symbols/charges they do not know, which we report as "not available".
    # Errors past this point (e.g. a changed gemmi attribute API) are real bugs
    # and must propagate instead of silently producing an all-zero table.
    try:
        sf = gemmi.IT92_get_exact(gemmi.Element(element), charge)
    except Exception:
        return None

    # IT92_get_exact returns None if the entry is not in the table.
    if sf is None:
        return None

    # ITC92 has 4 Gaussians + constant: a1-a4, b1-b4, c.
    # The constant is appended to A; B gets a matching 0.0 so both have the
    # same length.
    A = torch.tensor(list(sf.a) + [sf.c], dtype=torch.float32)
    B = torch.tensor(list(sf.b) + [0.0], dtype=torch.float32)
    return A, B
def generate_scattering_table(output_path: str, csv_path: str, verbose: bool = True):
    """
    Generate the complete scattering factor table and save it with ``torch.save``.

    Parameters
    ----------
    output_path : str
        Path to save the .pt file. Missing parent directories are created.
    csv_path : str
        Path to atomic_vdw_radii.csv for element-to-Z mapping.
    verbose : bool, optional
        Print progress information. Default is True.

    Returns
    -------
    dict
        The table that was saved, with keys:

        - ``"A"``, ``"B"``: float32 tensors of shape ``(max_z + 1, 5)`` with
          neutral-atom coefficients. Row ``z`` holds the element with atomic
          number ``z``. Row 0 is unused, and rows without an ITC92 entry
          stay zero.
        - ``"element_to_z"`` / ``"z_to_element"``: symbol <-> Z mappings.
        - ``"ions"``: ``{"Fe2+": (A, B), "O2-": (A, B), ...}`` for every
          charge state (-4..+8, excluding 0) that gemmi provides.
        - ``"metadata"``: source, version, ``max_z``, ``n_neutral``, ``n_ions``.

    Raises
    ------
    ValueError
        If the CSV yields no elements.
    """
    if verbose:
        print("Generating ITC92 scattering factor table...")

    # Build element-to-Z mapping from CSV
    element_to_z = get_element_to_z_from_csv(csv_path)
    if not element_to_z:
        # Without this check, max() below fails with an unhelpful message.
        raise ValueError(f"No elements found in CSV file: {csv_path}")
    z_to_element = {z: elem for elem, z in element_to_z.items()}

    # The largest Z present sizes the tables (expected 103 / Lr).
    max_z = max(element_to_z.values())
    if verbose:
        print(f"  Found {len(element_to_z)} elements (Z=1 to {max_z})")

    # Index 0 is unused so that row index == atomic number.
    A_neutral = torch.zeros(max_z + 1, 5, dtype=torch.float32)
    B_neutral = torch.zeros(max_z + 1, 5, dtype=torch.float32)

    # Neutral atoms: track which elements gemmi had parameters for.
    elements_found = []
    elements_missing = []
    for elem, z in element_to_z.items():
        params = get_itc92_params(elem, charge=0)
        if params is None:
            elements_missing.append(elem)
            continue
        A_neutral[z] = params[0]
        B_neutral[z] = params[1]
        elements_found.append(elem)

    if verbose:
        print(f"  Neutral atoms: {len(elements_found)} found, {len(elements_missing)} missing")
        if elements_missing:
            print(f"    Missing: {elements_missing}")

    # Ions: probe a wide charge window. Whether a given ion exists is decided
    # by gemmi's ITC92 table (get_itc92_params returns None otherwise).
    charge_range = range(-4, 9)  # -4 to +8 covers most cases
    ions = {}
    for elem in element_to_z:
        for charge in charge_range:
            if charge == 0:
                continue  # Neutral atoms are already stored in A/B above
            params = get_itc92_params(elem, charge)
            if params is None:
                continue
            # Key format: symbol, charge magnitude, sign -- e.g. "Fe2+", "O2-"
            sign = "+" if charge > 0 else "-"
            ions[f"{elem}{abs(charge)}{sign}"] = (params[0], params[1])
    ion_count = len(ions)

    if verbose:
        print(f"  Ions: {ion_count} charge states found")
        if ion_count > 0:
            # Show some examples
            examples = list(ions.keys())[:10]
            print(f"    Examples: {examples}")

    # Build the complete table dictionary
    table = {
        "A": A_neutral,
        "B": B_neutral,
        "element_to_z": element_to_z,
        "z_to_element": z_to_element,
        "ions": ions,
        "metadata": {
            "source": "ITC92 via gemmi",
            "version": "1.0",
            "max_z": max_z,
            "n_neutral": len(elements_found),
            "n_ions": ion_count,
        },
    }

    # torch.save does not create missing directories itself.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(table, output_path)

    if verbose:
        file_size = os.path.getsize(output_path) / 1024
        print(f"  Saved to: {output_path}")
        print(f"  File size: {file_size:.1f} KB")

    return table
def main():
    """
    Command-line entry point: rebuild ``itc92_scattering_factors.pt``.

    Both the input CSV and the output file live in the package ``data``
    directory, located one level above this script's directory.

    Raises
    ------
    FileNotFoundError
        If ``atomic_vdw_radii.csv`` is absent from that data directory.
    """
    data_dir = Path(__file__).parent.parent / "data"

    csv_file = data_dir / "atomic_vdw_radii.csv"
    if not csv_file.exists():
        raise FileNotFoundError(f"CSV file not found: {csv_file}")

    target_file = data_dir / "itc92_scattering_factors.pt"
    generate_scattering_table(str(target_file), str(csv_file), verbose=True)
    print("\nGeneration complete!")


if __name__ == "__main__":
    main()