"""
Grid operations for reciprocal space.
Functions for placing structure factors on grids and extracting them,
which are essential for FFT-based calculations.
"""
import math
import torch
[docs]
def place_on_grid(
    hkls, structure_factor, grid_size, enforce_hermitian: bool = True
) -> torch.Tensor:
    """
    Scatter structure factors onto a reciprocal-space grid.

    Each Miller index (h, k, l) is wrapped into the grid with periodic
    (modulo) indexing, so negative indices land at the far end of each axis.
    Values that map to the same grid point are summed.

    Parameters
    ----------
    hkls : torch.Tensor
        Miller indices of shape (N, 3); cast to int64 before use.
    structure_factor : torch.Tensor
        Structure factors of shape (N,), or (B, N) for batched input.
    grid_size : tuple or torch.Tensor
        Grid dimensions (Nx, Ny, Nz).
    enforce_hermitian : bool, optional
        If True (default), additionally add conj(F) at (-h, -k, -l) for every
        input reflection.

    Returns
    -------
    torch.Tensor
        Grid of shape (Nx, Ny, Nz), or (B, Nx, Ny, Nz) for batched input,
        with the dtype and device of ``structure_factor``.

    Notes
    -----
    With ``enforce_hermitian=True`` the conjugate is added unconditionally.
    A reflection whose wrapped index coincides with that of its Friedel mate
    (e.g. (0, 0, 0)) therefore ends up as F + conj(F) = 2 * Re(F), and if
    ``hkls`` already contains both hkl and -hkl those entries are accumulated
    twice. NOTE(review): presumably callers pass a Friedel-unique set of
    reflections without F(000) -- confirm.
    """
    unbatched = structure_factor.ndim == 1
    values = structure_factor.unsqueeze(0) if unbatched else structure_factor
    n_batch = values.shape[0]
    nx, ny, nz = (int(n) for n in grid_size)

    miller = hkls.to(device=values.device, dtype=torch.int64)
    h_idx, k_idx, l_idx = miller[:, 0], miller[:, 1], miller[:, 2]

    def _flat(hh, kk, ll):
        # Row-major (C-order) flat index of the periodically wrapped (h, k, l).
        return (
            torch.remainder(hh, nx) * (ny * nz)
            + torch.remainder(kk, ny) * nz
            + torch.remainder(ll, nz)
        )

    flat_grid = torch.zeros(
        (n_batch, nx * ny * nz), dtype=values.dtype, device=values.device
    )
    # Scatter-add along the flattened grid axis; duplicate indices accumulate.
    flat_grid = flat_grid.index_add(1, _flat(h_idx, k_idx, l_idx), values)
    if enforce_hermitian:
        # Friedel relation: F(-h, -k, -l) = conj(F(h, k, l)).
        flat_grid = flat_grid.index_add(
            1, _flat(-h_idx, -k_idx, -l_idx), torch.conj(values)
        )

    grid = flat_grid.view(n_batch, nx, ny, nz)
    return grid.squeeze(0) if unbatched else grid
[docs]
def apply_translation_phase(
    F_calc: torch.Tensor,
    hkl: torch.Tensor,
    translation_frac: torch.Tensor,
) -> torch.Tensor:
    """
    Apply a translation phase shift to structure factors.

    For a translation t in fractional coordinates, the structure factor
    transforms as:

        F'(hkl) = F(hkl) * exp(2πi * hkl · t)

    Parameters
    ----------
    F_calc : torch.Tensor
        Structure factors of shape (N,), or any shape that broadcasts against
        the (..., N) phase tensor (e.g. (B, N)). Real-valued input is
        promoted to the matching complex dtype.
    hkl : torch.Tensor
        Miller indices of shape (N, 3).
    translation_frac : torch.Tensor
        Translation vector in fractional coordinates of shape (3,), or a
        stack of translations of shape (..., 3).

    Returns
    -------
    torch.Tensor
        Phase-shifted structure factors: shape (N,) for a single translation
        and (N,)-shaped ``F_calc``; in general the broadcast of ``F_calc``
        with (..., N). The dtype is complex, at the precision of ``F_calc``.
    """
    if not F_calc.is_complex():
        # Casting the complex phase factor down to a real dtype would silently
        # discard its imaginary part, so promote real input to complex.
        F_calc = F_calc.to(torch.promote_types(F_calc.dtype, torch.complex64))

    # Do the phase arithmetic at the precision of F_calc (float64 for
    # complex128) instead of always in float32, and on F_calc's device.
    real_dtype = F_calc.real.dtype
    device = F_calc.device
    hkl_r = hkl.to(device=device, dtype=real_dtype)  # (N, 3)
    t = translation_frac.to(device=device, dtype=real_dtype)  # (..., 3)

    # hkl · t for every reflection (and every stacked translation): (..., N).
    dot = (hkl_r @ t.unsqueeze(-1)).squeeze(-1)
    phase = 2.0 * math.pi * dot

    # exp(i * phase) = cos(phase) + i * sin(phase)
    phase_factor = torch.complex(torch.cos(phase), torch.sin(phase))
    return F_calc * phase_factor