"""
Anisotropic structure factor calculations.
Functions for computing structure factors from atomic models with
anisotropic (non-spherical) atomic displacement parameters.
"""
from typing import Optional
import numpy as np
import torch
def _estimate_batch_size_aniso(n_refl: int, n_atoms: int, n_ops: int, max_memory_gb: float) -> int:
    """
    Estimate how many reflections fit in one batch of the anisotropic path.

    The budget assumes a conservative ~80 bytes per
    (reflection, atom, symmetry-op) combination, which leaves room for the
    extra U-matrix intermediates of the anisotropic calculation.

    Parameters
    ----------
    n_refl : int
        Total number of reflections.
    n_atoms : int
        Number of atoms.
    n_ops : int
        Number of symmetry operations.
    max_memory_gb : float
        Memory budget in GB (1 GB is taken as 1e9 bytes).

    Returns
    -------
    int
        Reflections per batch: at least 1 before being capped at ``n_refl``
        (so the result is 0 when ``n_refl`` is 0).
    """
    # Clamp to one byte so an empty model (no atoms or no symmetry ops) does
    # not divide by zero; the estimate is then simply capped by n_refl.
    bytes_per_refl = max(1, n_atoms * n_ops * 80)
    max_bytes = max_memory_gb * 1e9
    batch_size = max(1, int(max_bytes / bytes_per_refl))
    return min(batch_size, n_refl)
# [docs]
def aniso_structure_factor_torched(
    hkl, s_vector, xyz_fractional, occ, scattering_factors, U, spacegroup,
    max_memory_gb: Optional[float] = None,
    A: Optional[torch.Tensor] = None,
    B_coeff: Optional[torch.Tensor] = None,
):
    """
    Calculate anisotropic structure factors using PyTorch.

    For every reflection the result is the sum over atoms of
    ``f * occ * exp(-2 pi^2 s^T U s)`` multiplied by the sum over symmetry
    copies of ``exp(2 pi i h.x)``.

    Parameters
    ----------
    hkl : torch.Tensor
        Miller indices of shape (N_reflections, 3).
    s_vector : torch.Tensor
        Scattering vectors of shape (N_reflections, 3).
    xyz_fractional : torch.Tensor
        Fractional coordinates of shape (N_atoms, 3).
    occ : torch.Tensor
        Occupancies of shape (N_atoms,).
    scattering_factors : torch.Tensor or None
        Atomic scattering factors of shape (N_reflections, N_atoms).
        If None, A and B_coeff must both be given so they can be computed.
    U : torch.Tensor
        Anisotropic displacement parameters of shape (N_atoms, 6). Columns
        0-2 fill the diagonal of the symmetric 3x3 matrix and columns 3, 4, 5
        fill its (0,1), (0,2) and (1,2) off-diagonal entries.
    spacegroup : callable
        Space group symmetry operator function. It is called with the
        transposed coordinates (3, N_atoms) and its result is treated as
        (3, N_atoms, n_ops).
    max_memory_gb : float, optional
        Maximum memory to use in GB. If None, no batching is applied.
    A : torch.Tensor, optional
        ITC92 A coefficients (N_atoms, 5) for computing scattering factors.
    B_coeff : torch.Tensor, optional
        ITC92 B coefficients (N_atoms, 5) for computing scattering factors.

    Returns
    -------
    torch.Tensor
        Complex structure factors of shape (N_reflections,).

    Raises
    ------
    ValueError
        If ``scattering_factors`` is None and ``A`` or ``B_coeff`` is missing.
    """
    # Fail early with a clear message instead of an opaque "None * Tensor"
    # TypeError (or a failure deep inside the batched helper) later on.
    if scattering_factors is None and (A is None or B_coeff is None):
        raise ValueError(
            "scattering_factors is None: both A and B_coeff are required "
            "to compute the scattering factors."
        )

    # Apply spacegroup to get symmetry-expanded coordinates
    xyz_expanded = spacegroup(xyz_fractional.T)  # (3, N_atoms, n_ops)
    fractional_shape = xyz_expanded.shape
    n_ops = fractional_shape[2] if len(fractional_shape) > 2 else 1
    n_atoms = fractional_shape[1]
    n_refl = hkl.shape[0]

    # Flatten atoms and symmetry copies so h.x is one matrix multiplication
    xyz_flat = xyz_expanded.reshape(3, -1)  # (3, N_atoms * n_ops)

    # Symmetric U matrix per atom (does not depend on the reflection batch)
    U_row1 = torch.stack([U[:, 0], U[:, 3], U[:, 4]], dim=0)
    U_row2 = torch.stack([U[:, 3], U[:, 1], U[:, 5]], dim=0)
    U_row3 = torch.stack([U[:, 4], U[:, 5], U[:, 2]], dim=0)
    U_matrix = torch.stack([U_row1, U_row2, U_row3], dim=0)  # (3, 3, N_atoms)

    # Delegate to the batched kernel only when the budget forces a split
    if max_memory_gb is not None:
        batch_size = _estimate_batch_size_aniso(n_refl, n_atoms, n_ops, max_memory_gb)
        if batch_size < n_refl:
            return _aniso_sf_batched(
                hkl, s_vector, xyz_flat, fractional_shape, occ, scattering_factors,
                U_matrix, batch_size, A, B_coeff
            )

    # Single pass: scattering factors may still need to be computed
    if scattering_factors is None:
        s_mag = torch.norm(s_vector, dim=1)
        scattering_factors = _compute_scattering_factors_batch(s_mag, A, B_coeff)

    dot_product = torch.matmul(hkl.to(torch.float64), xyz_flat).reshape(
        n_refl, n_atoms, -1
    )

    # Quadratic form s^T U s for every (reflection, atom) pair.
    # NOTE(review): this displacement factor has no symmetry-op axis, i.e. the
    # same factor is applied to all symmetry copies of an atom - confirm this
    # is intended for operators with a rotational part.
    U_dot_s = torch.einsum("jik,li->jkl", U_matrix, s_vector)  # (3, N_atoms, N_refl)
    StUS = torch.einsum("li,ikl->lk", s_vector, U_dot_s)  # (N_refl, N_atoms)
    exp_B = torch.exp(-2 * (np.pi**2) * StUS)
    terms = scattering_factors * exp_B * occ

    # exp(2 pi i h.x) assembled directly from its real and imaginary parts,
    # then summed over the symmetry copies of each atom.
    pidot = 2 * np.pi * dot_product
    phase_sum = torch.sum(torch.complex(torch.cos(pidot), torch.sin(pidot)), dim=-1)
    return torch.sum(terms * phase_sum, dim=1)
from torchref.base.direct_summation import (
compute_scattering_factors_batch as _compute_scattering_factors_batch,
)
def _aniso_sf_batched(
    hkl, s_vector, xyz_flat, fractional_shape, occ, scattering_factors, U_matrix, batch_size,
    A=None, B_coeff=None
):
    """
    Compute anisotropic structure factors in batches over reflections.

    Parameters
    ----------
    hkl : torch.Tensor
        Miller indices (N_refl, 3).
    s_vector : torch.Tensor
        Scattering vectors (N_refl, 3).
    xyz_flat : torch.Tensor
        Flattened fractional coordinates (3, N_atoms * n_ops).
    fractional_shape : tuple
        Original shape (3, N_atoms, n_ops); only entry 1 (N_atoms) is read.
    occ : torch.Tensor
        Occupancies (N_atoms,).
    scattering_factors : torch.Tensor or None
        Scattering factors (N_refl, N_atoms). If None, computed per batch
        from A/B_coeff.
    U_matrix : torch.Tensor
        Precomputed U matrix (3, 3, N_atoms).
    batch_size : int
        Number of reflections per batch; must be at least 1.
    A : torch.Tensor, optional
        ITC92 A coefficients for computing scattering factors.
    B_coeff : torch.Tensor, optional
        ITC92 B coefficients for computing scattering factors.

    Returns
    -------
    torch.Tensor
        Complex structure factors (N_refl,), complex128 on the device of hkl.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1, or if ``scattering_factors`` is
        None and ``A`` or ``B_coeff`` is missing.
    """
    # A negative step would make the range below empty and silently return
    # an all-zero result; a zero step is rejected by range() itself.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    if scattering_factors is None and (A is None or B_coeff is None):
        raise ValueError(
            "scattering_factors is None: both A and B_coeff are required "
            "to compute the scattering factors."
        )

    n_refl = hkl.shape[0]
    n_atoms = fractional_shape[1]

    # Pre-allocate the output instead of accumulating per-batch results
    sf_out = torch.zeros(n_refl, dtype=torch.complex128, device=hkl.device)

    # Loop invariants: the float64 cast of hkl is only (N_refl, 3)
    hkl_f64 = hkl.to(torch.float64)
    two_pi = 2 * np.pi
    b_scale = -2 * (np.pi**2)

    for start in range(0, n_refl, batch_size):
        end = min(start + batch_size, n_refl)
        s_batch = s_vector[start:end]

        # Slice the provided scattering factors or compute them for this batch
        if scattering_factors is not None:
            sf_batch = scattering_factors[start:end]
        else:
            s_mag = torch.norm(s_batch, dim=1)
            sf_batch = _compute_scattering_factors_batch(s_mag, A, B_coeff)

        # Phase 2*pi*h.x, shaped (batch, N_atoms, n_ops)
        pidot = two_pi * torch.matmul(hkl_f64[start:end], xyz_flat).reshape(
            end - start, n_atoms, -1
        )

        # Quadratic form s^T U s for every (reflection, atom) pair
        U_dot_s = torch.einsum("jik,li->jkl", U_matrix, s_batch)  # (3, N_atoms, batch)
        StUS = torch.einsum("li,ikl->lk", s_batch, U_dot_s)  # (batch, N_atoms)
        terms = sf_batch * torch.exp(b_scale * StUS) * occ

        # exp(i*phase) summed over symmetry copies, then weighted sum over atoms
        phase_sum = torch.sum(torch.complex(torch.cos(pidot), torch.sin(pidot)), dim=-1)
        sf_out[start:end] = torch.sum(terms * phase_sum, dim=1)

    return sf_out
# [docs]
def aniso_structure_factor_torched_no_complex(
    hkl, s_vector, fractional_coords, occ, scattering_factors, U, space_group
):
    """
    Anisotropic structure factors expressed with real arithmetic only.

    The real (cosine) and imaginary (sine) components are evaluated
    separately and returned stacked as two rows.

    Parameters
    ----------
    hkl : torch.Tensor
        Miller indices of shape (N_reflections, 3).
    s_vector : torch.Tensor
        Scattering vectors of shape (N_reflections, 3).
    fractional_coords : torch.Tensor
        Fractional coordinates of shape (N_atoms, 3).
    occ : torch.Tensor
        Occupancies of shape (N_atoms,).
    scattering_factors : torch.Tensor
        Atomic scattering factors of shape (N_reflections, N_atoms).
    U : torch.Tensor
        Anisotropic displacement parameters of shape (N_atoms, 6).
    space_group : callable
        Space group symmetry operator function; called with the transposed
        coordinates.

    Returns
    -------
    torch.Tensor
        Structure factors as [real, imag] of shape (2, N_reflections).
    """
    # Symmetry-expand, then flatten atoms and symmetry copies for one matmul
    expanded = space_group(fractional_coords.T)
    n_refl = hkl.shape[0]
    n_atoms = expanded.shape[1]
    h_dot_x = torch.matmul(hkl.to(torch.float64), expanded.reshape(3, -1))
    phase = 2 * np.pi * h_dot_x.reshape(n_refl, n_atoms, -1)

    # Symmetric 3x3 displacement matrix per atom from the six U columns:
    # columns 0-2 are the diagonal, 3-5 the (0,1), (0,2), (1,2) entries.
    u11, u22, u33, u12, u13, u23 = (U[:, col] for col in range(6))
    U_matrix = torch.stack(
        [
            torch.stack([u11, u12, u13], dim=0),
            torch.stack([u12, u22, u23], dim=0),
            torch.stack([u13, u23, u33], dim=0),
        ],
        dim=0,
    )

    # s^T U s for every (reflection, atom) pair
    U_times_s = torch.einsum("jik,li->jkl", U_matrix, s_vector)
    quad_form = torch.einsum("li,ikl->lk", s_vector, U_times_s)
    damping = torch.exp(-2 * (np.pi**2) * quad_form)
    weights = scattering_factors * damping * occ

    # Sum over symmetry copies first, then the weighted sum over atoms
    imag_part = torch.sum(torch.sum(torch.sin(phase), dim=-1) * weights, dim=1)
    real_part = torch.sum(torch.sum(torch.cos(phase), dim=-1) * weights, dim=1)
    return torch.vstack([real_part, imag_part])