Source code for torchref.base.direct_summation.isotropic

"""
Isotropic structure factor calculations.

Functions for computing structure factors from atomic models with
isotropic (spherical) atomic displacement parameters.
"""

from typing import Optional

import numpy as np
import torch

from torchref.config import dtypes


def _estimate_batch_size(n_refl: int, n_atoms: int, n_ops: int, max_memory_gb: float) -> int:
    """
    Estimate optimal batch size based on memory constraints.

    Main memory consumers per reflection:
    - dot_product: (N_atoms * n_ops) float32 = 4 bytes each
    - pidot: same size
    - sin_cos: (N_atoms * n_ops) complex64 = 8 bytes each
    - terms: (N_atoms) float32 = 4 bytes each
    - scattering_factors slice: (N_atoms) float32 = 4 bytes each

    Conservative estimate: ~50 bytes per (refl, atom, op) combination.
    """
    bytes_per_refl = n_atoms * n_ops * 50  # Conservative estimate
    max_bytes = max_memory_gb * 1e9
    batch_size = max(1, int(max_bytes / bytes_per_refl))
    return min(batch_size, n_refl)


def iso_structure_factor_torched(
    hkl,
    s,
    xyz_fractional,
    occ,
    scattering_factors,
    adp,
    spacegroup,
    max_memory_gb: Optional[float] = None,
    A: Optional[torch.Tensor] = None,
    B_coeff: Optional[torch.Tensor] = None,
):
    """
    Calculate isotropic structure factors using PyTorch.

    Parameters
    ----------
    hkl : torch.Tensor
        Miller indices of shape (N_reflections, 3).
    s : torch.Tensor
        Scattering vector magnitudes of shape (N_reflections,).
    xyz_fractional : torch.Tensor
        Fractional coordinates of shape (N_atoms, 3).
    occ : torch.Tensor
        Occupancies of shape (N_atoms,).
    scattering_factors : torch.Tensor or None
        Atomic scattering factors of shape (N_reflections, N_atoms).
        If None, ``A`` and ``B_coeff`` must be provided so the factors can
        be computed (all at once, or per batch when batching applies).
    adp : torch.Tensor
        Atomic displacement parameters (isotropic) of shape (N_atoms,).
    spacegroup : callable
        Space group symmetry operator function. It is called with the
        transposed coordinates (3, N_atoms) and is expected to return the
        symmetry-expanded coordinates (3, N_atoms, n_ops).
    max_memory_gb : float, optional
        Maximum memory to use in GB. If None, no batching is applied.
    A : torch.Tensor, optional
        ITC92 A coefficients (N_atoms, 5) for computing scattering factors.
        Required if ``scattering_factors`` is None.
    B_coeff : torch.Tensor, optional
        ITC92 B coefficients (N_atoms, 5) for computing scattering factors.
        Required if ``scattering_factors`` is None.

    Returns
    -------
    torch.Tensor
        Complex structure factors of shape (N_reflections,).

    Raises
    ------
    ValueError
        If ``scattering_factors`` is None and either ``A`` or ``B_coeff``
        is missing.
    """
    # Fail early with a clear message rather than a TypeError from
    # ``None * tensor`` deep inside the arithmetic or the batched helper.
    if scattering_factors is None and (A is None or B_coeff is None):
        raise ValueError(
            "scattering_factors is None: both A and B_coeff must be provided "
            "so that scattering factors can be computed."
        )

    # Apply spacegroup to get symmetry-expanded coordinates
    xyz_expanded = spacegroup(xyz_fractional.T)  # (3, N_atoms, n_ops)
    fractional_shape = xyz_expanded.shape
    n_ops = fractional_shape[2] if len(fractional_shape) > 2 else 1
    n_atoms = fractional_shape[1]
    n_refl = hkl.shape[0]

    # Reshape for matrix multiplication
    xyz_flat = xyz_expanded.reshape(3, -1)  # (3, N_atoms * n_ops)

    # Precompute ADP terms (doesn't depend on batch)
    adp_row = adp.reshape(1, -1)  # (1, N_atoms)

    # Delegate to the batched implementation when the estimated memory
    # footprint of the full computation exceeds the budget.
    if max_memory_gb is not None:
        batch_size = _estimate_batch_size(n_refl, n_atoms, n_ops, max_memory_gb)
        if batch_size < n_refl:
            return _iso_sf_batched(
                hkl,
                s,
                xyz_flat,
                fractional_shape,
                occ,
                scattering_factors,
                adp_row,
                batch_size,
                A,
                B_coeff,
            )

    # Batching not needed but scattering_factors may be None - compute them.
    # A and B_coeff are guaranteed non-None by the check above.
    if scattering_factors is None:
        scattering_factors = _compute_scattering_factors_batch(s, A, B_coeff)

    # No batching - compute all at once.
    # h·x for every reflection, atom and symmetry copy: (N_refl, N_atoms, n_ops)
    dot_product = torch.matmul(hkl.to(dtypes.float), xyz_flat).reshape(
        n_refl, n_atoms, -1
    )
    s_col = s.reshape(-1, 1)
    # Per-atom attenuation exp(-B * s^2 / 4), broadcast to (N_refl, N_atoms)
    B = -adp_row * (s_col**2) / 4
    exp_B = torch.exp(B)
    terms = scattering_factors * exp_B * occ
    pidot = 2 * np.pi * dot_product
    # exp(i * 2*pi*h·x) summed over symmetry operators: (N_refl, N_atoms)
    sin_cos = torch.sum(1j * torch.sin(pidot) + torch.cos(pidot), axis=-1)
    # Sum over atoms -> (N_refl,)
    return torch.sum(terms * sin_cos, axis=1)
from torchref.base.direct_summation import (
    compute_scattering_factors_batch as _compute_scattering_factors_batch,
)


def _iso_sf_batched(
    hkl,
    s,
    xyz_flat,
    fractional_shape,
    occ,
    scattering_factors,
    adp_row,
    batch_size,
    A=None,
    B_coeff=None,
):
    """
    Compute isotropic structure factors in batches over reflections.

    Parameters
    ----------
    hkl : torch.Tensor
        Miller indices (N_refl, 3).
    s : torch.Tensor
        Scattering vector magnitudes (N_refl,).
    xyz_flat : torch.Tensor
        Flattened fractional coordinates (3, N_atoms * n_ops).
    fractional_shape : tuple
        Original shape (3, N_atoms, n_ops).
    occ : torch.Tensor
        Occupancies (N_atoms,).
    scattering_factors : torch.Tensor or None
        Scattering factors (N_refl, N_atoms). If None, computed per batch
        from A/B_coeff.
    adp_row : torch.Tensor
        ADP values reshaped to (1, N_atoms).
    batch_size : int
        Number of reflections per batch; must be >= 1.
    A : torch.Tensor, optional
        ITC92 A coefficients for computing scattering factors.
    B_coeff : torch.Tensor, optional
        ITC92 B coefficients for computing scattering factors.

    Returns
    -------
    torch.Tensor
        Complex structure factors (N_refl,). The dtype is the dtype of the
        per-batch result, so it matches the unbatched computation. An empty
        input returns an empty complex128 tensor.

    Raises
    ------
    ValueError
        If ``batch_size`` < 1, or if ``scattering_factors`` is None and
        ``A`` or ``B_coeff`` is missing.
    """
    if batch_size < 1:
        # A zero step makes range() raise, and a negative step would silently
        # skip every batch and return all zeros.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if scattering_factors is None and (A is None or B_coeff is None):
        raise ValueError(
            "scattering_factors is None: both A and B_coeff must be provided "
            "so that scattering factors can be computed."
        )

    n_refl = hkl.shape[0]
    n_atoms = fractional_shape[1]
    device = hkl.device
    # Cast once; slicing the cast tensor gives the same values as casting
    # each slice.
    hkl_float = hkl.to(dtypes.float)

    # The output buffer is allocated lazily from the first batch result. Its
    # dtype then matches the unbatched path instead of a hard-coded
    # complex128, which would also double the output memory.
    sf_out = None

    for start in range(0, n_refl, batch_size):
        end = min(start + batch_size, n_refl)
        s_batch = s[start:end]

        # Get or compute scattering factors for this batch
        if scattering_factors is not None:
            sf_batch = scattering_factors[start:end]
        else:
            sf_batch = _compute_scattering_factors_batch(s_batch, A, B_coeff)

        # h·x for this batch: (batch, N_atoms, n_ops)
        dot_product = torch.matmul(hkl_float[start:end], xyz_flat).reshape(
            end - start, n_atoms, -1
        )
        s_col = s_batch.reshape(-1, 1)
        # Per-atom attenuation exp(-B * s^2 / 4), shape (batch, N_atoms)
        B = -adp_row * (s_col**2) / 4
        exp_B = torch.exp(B)
        terms = sf_batch * exp_B * occ
        pidot = 2 * np.pi * dot_product
        # exp(i * 2*pi*h·x) summed over symmetry operators
        sin_cos = torch.sum(1j * torch.sin(pidot) + torch.cos(pidot), axis=-1)
        batch_result = torch.sum(terms * sin_cos, axis=1)

        if sf_out is None:
            sf_out = torch.empty(
                n_refl, dtype=batch_result.dtype, device=batch_result.device
            )
        sf_out[start:end] = batch_result

    if sf_out is None:
        # n_refl == 0: no batch ran, so keep the original empty-result dtype.
        return torch.zeros(0, dtype=torch.complex128, device=device)
    return sf_out
def iso_structure_factor_torched_no_complex(
    hkl, s, fractional_coords, occ, scattering_factors, tempfactor, space_group
):
    """
    Calculate isotropic structure factors using only real arithmetic.

    The real (cosine) and imaginary (sine) parts are accumulated separately
    and stacked, so no complex tensors are created.

    Parameters
    ----------
    hkl : torch.Tensor
        Miller indices of shape (N_reflections, 3).
    s : torch.Tensor
        Scattering vector magnitudes of shape (N_reflections,).
    fractional_coords : torch.Tensor
        Fractional coordinates of shape (N_atoms, 3).
    occ : torch.Tensor
        Occupancies of shape (N_atoms,).
    scattering_factors : torch.Tensor
        Atomic scattering factors of shape (N_reflections, N_atoms).
    tempfactor : torch.Tensor
        Isotropic temperature factors (B-factors) of shape (N_atoms,).
    space_group : callable
        Space group symmetry operator function, called with the transposed
        coordinates.

    Returns
    -------
    torch.Tensor
        Structure factors as [real, imag] of shape (2, N_reflections).
    """
    expanded = space_group(fractional_coords.T)
    n_reflections = hkl.shape[0]
    n_sites = expanded.shape[1]

    # 2*pi*h·x for every reflection, atom and symmetry copy.
    phase = 2 * np.pi * torch.matmul(
        hkl.to(dtypes.float), expanded.reshape(3, -1)
    ).reshape(n_reflections, n_sites, -1)

    # Per-atom attenuation exp(-B * s^2 / 4), broadcast to (N_refl, N_atoms).
    attenuation = torch.exp(-tempfactor.reshape(1, -1) * (s.reshape(-1, 1) ** 2) / 4)
    weights = scattering_factors * attenuation * occ

    # Sum each trig term over symmetry copies, weight per atom, sum over atoms.
    imag_part = torch.sum(torch.sum(torch.sin(phase), dim=-1) * weights, dim=1)
    real_part = torch.sum(torch.sum(torch.cos(phase), dim=-1) * weights, dim=1)

    return torch.vstack((real_part, imag_part))