# Source code for torchref.base.reciprocal.interpolation

"""
Interpolation functions for reciprocal space grids.

These functions enable structure factor extraction at non-integer HKL positions,
which is essential for rotation and translation searches in molecular replacement.
"""

import numpy as np
import torch


def interpolate_structure_factor_from_grid(
    reciprocal_grid: torch.Tensor,
    hkl_float: torch.Tensor,
    interpolate_amplitude: bool = True,
) -> torch.Tensor:
    """
    Sample a reciprocal-space grid at fractional HKL positions (trilinear).

    Parameters
    ----------
    reciprocal_grid : torch.Tensor
        Complex tensor of shape (Nx, Ny, Nz). Indices wrap periodically
        (``torch.remainder``), so negative Miller indices address the upper
        end of each axis.
    hkl_float : torch.Tensor
        Non-integer HKL positions of shape (N, 3). Cast to float32 on the
        grid's device before use.
    interpolate_amplitude : bool, optional
        If True (default), the eight neighbouring voxels are reduced to their
        magnitudes before blending. This sidesteps phase cancellation:
        blending F=1 and F=-1 linearly gives 0, whereas blending |F| gives 1.

    Returns
    -------
    torch.Tensor
        Shape (N,). Real amplitudes when ``interpolate_amplitude`` is True;
        otherwise blended values in the grid's own dtype (complex blending,
        use with caution).

    Notes
    -----
    For a rotation R applied to the model, the structure factor at hkl becomes
    F(R^T @ hkl), so you would call this with hkl_float = hkl @ R.

    WARNING: with ``interpolate_amplitude=False`` neighbours with very
    different phases can cancel. For example, midway between A*exp(i*0) and
    A*exp(i*pi) the result has magnitude 0 instead of A. Prefer amplitude
    interpolation for rotation functions, where only magnitudes matter.
    """
    nx, ny, nz = reciprocal_grid.shape
    coords = hkl_float.to(device=reciprocal_grid.device, dtype=torch.float32)

    # Lower corner of the enclosing voxel (floor) and the offset inside it.
    lower = torch.floor(coords).long()
    frac = coords - lower.float()
    wh, wk, wl = frac[:, 0], frac[:, 1], frac[:, 2]

    # (lower, upper) neighbour along each axis, wrapped periodically.
    h_sel = (torch.remainder(lower[:, 0], nx), torch.remainder(lower[:, 0] + 1, nx))
    k_sel = (torch.remainder(lower[:, 1], ny), torch.remainder(lower[:, 1] + 1, ny))
    l_sel = (torch.remainder(lower[:, 2], nz), torch.remainder(lower[:, 2] + 1, nz))

    if not interpolate_amplitude:
        # Complex blending: the weights must share the grid's (complex) dtype.
        wh, wk, wl = (w.to(reciprocal_grid.dtype) for w in (wh, wk, wl))

    def corner(a, b, c):
        value = reciprocal_grid[h_sel[a], k_sel[b], l_sel[c]]
        return torch.abs(value) if interpolate_amplitude else value

    def lerp(lo, hi, t):
        return lo * (1 - t) + hi * t

    # Collapse the 2x2x2 neighbourhood along l, then k, then h.
    along_l = [
        [lerp(corner(a, b, 0), corner(a, b, 1), wl) for b in (0, 1)]
        for a in (0, 1)
    ]
    along_k = [lerp(pair[0], pair[1], wk) for pair in along_l]
    return lerp(along_k[0], along_k[1], wh)
def interpolate_complex_from_grid(
    reciprocal_grid: torch.Tensor,
    hkl_float: torch.Tensor,
) -> torch.Tensor:
    """
    Trilinearly interpolate complex structure factors at fractional HKL positions.

    Unlike amplitude interpolation, the complex values themselves are blended,
    so phase information is carried along. That matters for translation
    searches, where correlation functions depend on phases.

    Parameters
    ----------
    reciprocal_grid : torch.Tensor
        Complex tensor of shape (Nx, Ny, Nz). Indices wrap periodically.
    hkl_float : torch.Tensor
        Non-integer HKL positions of shape (N, 3). Cast to float32 on the
        grid's device before use.

    Returns
    -------
    torch.Tensor
        Interpolated values of shape (N,), in the grid's dtype.

    Notes
    -----
    For rotation-only searches (where only magnitudes matter), use
    ``interpolate_structure_factor_from_grid(interpolate_amplitude=True)``
    instead. Translation searches need complex interpolation because the
    translation function depends on the phase relationship between F_obs and
    F_calc.

    WARNING: neighbours with very different phases partially cancel, so the
    blended magnitude can be too small. This is tolerable for translation
    correlation functions but not for rotation searches.
    """
    n_h, n_k, n_l = reciprocal_grid.shape
    pos = hkl_float.to(device=reciprocal_grid.device, dtype=torch.float32)

    lo_idx, hi_idx, weights = [], [], []
    for axis, size in enumerate((n_h, n_k, n_l)):
        coord = pos[:, axis]
        base = torch.floor(coord).long()
        # Offset is taken before wrapping, then cast to the grid's dtype so it
        # can multiply complex values directly.
        weights.append((coord - base.float()).to(reciprocal_grid.dtype))
        lo_idx.append(torch.remainder(base, size))
        hi_idx.append(torch.remainder(base + 1, size))

    h_lo, k_lo, l_lo = lo_idx
    h_hi, k_hi, l_hi = hi_idx
    w_h, w_k, w_l = weights

    def blend(a, b, t):
        return a * (1 - t) + b * t

    # Collapse the 2x2x2 neighbourhood along l, then k, then h.
    edges = {}
    for dh, h_sel in ((0, h_lo), (1, h_hi)):
        for dk, k_sel in ((0, k_lo), (1, k_hi)):
            edges[dh, dk] = blend(
                reciprocal_grid[h_sel, k_sel, l_lo],
                reciprocal_grid[h_sel, k_sel, l_hi],
                w_l,
            )
    face_lo = blend(edges[0, 0], edges[0, 1], w_k)
    face_hi = blend(edges[1, 0], edges[1, 1], w_k)
    return blend(face_lo, face_hi, w_h)
def trilinear_interpolate_patterson(
    grid: torch.Tensor, points: torch.Tensor, chunk_size: int = 10_000_000
) -> torch.Tensor:
    """
    Memory-efficient periodic trilinear interpolation on a 3D grid.

    Pure torch implementation (a replacement for
    ``scipy.ndimage.map_coordinates`` in wrap mode). It runs on the grid's
    device and lets gradients flow to both ``grid`` and ``points``.

    Parameters
    ----------
    grid : torch.Tensor
        3D grid of values with shape (nx, ny, nz).
    points : torch.Tensor
        Fractional coordinates with shape (..., 3), e.g. (n_points, 3) or
        (batch, n_points, 3). Coordinates are wrapped into [0, 1), so any real
        value is accepted. Moved to ``grid``'s device if necessary.
    chunk_size : int, optional
        Number of points processed per iteration. This bounds the size of the
        temporary index and weight tensors. Must be positive. Default is 10M.

    Returns
    -------
    torch.Tensor
        Interpolated values in ``grid``'s dtype with shape
        ``points.shape[:-1]`` (e.g. (n_points,) or (batch, n_points)).

    Raises
    ------
    ValueError
        If ``chunk_size`` is not positive.

    Notes
    -----
    Each chunk is written directly into a preallocated output tensor. The
    eight corner contributions are summed in a single expression, so the
    eight corner arrays are never kept around together.
    """
    if chunk_size <= 0:
        # A zero step would make range() raise; a negative one would skip the
        # loop and return uninitialised memory from torch.empty.
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    nx, ny, nz = grid.shape
    device = grid.device
    dtype = grid.dtype

    original_shape = points.shape[:-1]
    # Index/weight tensors must live on the grid's device for the gather and
    # the weighted sum below.
    points = points.to(device=device).reshape(-1, 3)
    n_total = points.shape[0]
    result = torch.empty(n_total, device=device, dtype=dtype)

    for start in range(0, n_total, chunk_size):
        end = min(start + chunk_size, n_total)
        # torch's % takes the divisor's sign, so coordinates become
        # non-negative. Float rounding can yield exactly 1.0; the index
        # modulo below folds that back to 0.
        pts = points[start:end] % 1.0

        # Scale to grid coordinates
        px = pts[:, 0] * nx
        py = pts[:, 1] * ny
        pz = pts[:, 2] * nz

        # Lower corner (px >= 0, so truncation equals floor) and wrapped upper corner
        x0 = px.long() % nx
        y0 = py.long() % ny
        z0 = pz.long() % nz
        x1 = (x0 + 1) % nx
        y1 = (y0 + 1) % ny
        z1 = (z0 + 1) % nz

        # Fractional parts (weights) in the grid's dtype
        xd = (px - px.floor()).to(dtype)
        yd = (py - py.floor()).to(dtype)
        zd = (pz - pz.floor()).to(dtype)
        xd1 = 1 - xd
        yd1 = 1 - yd
        zd1 = 1 - zd

        result[start:end] = (
            grid[x0, y0, z0] * (xd1 * yd1 * zd1)
            + grid[x0, y0, z1] * (xd1 * yd1 * zd)
            + grid[x0, y1, z0] * (xd1 * yd * zd1)
            + grid[x0, y1, z1] * (xd1 * yd * zd)
            + grid[x1, y0, z0] * (xd * yd1 * zd1)
            + grid[x1, y0, z1] * (xd * yd1 * zd)
            + grid[x1, y1, z0] * (xd * yd * zd1)
            + grid[x1, y1, z1] * (xd * yd * zd)
        )

    return result.reshape(original_shape)
def interpolate_for_rotation(hkl, R, cell, reciprocal_space_grid):
    """
    Interpolate structure factors for rotated HKL positions.

    Works for any cell geometry. Before sampling, the rotation is conjugated
    into fractional (HKL) coordinates using the cell's reciprocal basis.

    Parameters
    ----------
    hkl : torch.Tensor
        HKL positions, shape (N, 3).
    R : torch.Tensor
        Rotation matrices, shape (B, 3, 3), or a single (3, 3) matrix.
    cell : Cell
        Unit cell object exposing ``reciprocal_basis_matrix`` (3, 3).
    reciprocal_space_grid : torch.Tensor
        Reciprocal space grid of structure factors, shape (Nx, Ny, Nz).

    Returns
    -------
    torch.Tensor
        Values from ``interpolate_structure_factor_from_grid`` in its default
        (amplitude) mode. Shape is (B, N), or (N,) when ``R`` was a single
        (3, 3) matrix.

    Notes
    -----
    With M = ``cell.reciprocal_basis_matrix``, each rotated index set is
    ``hkl @ (M @ R_b^T @ M^-1)`` (row-vector convention).
    NOTE(review): which rotation sense this represents depends on how
    ``Cell`` defines its reciprocal basis — confirm against Cell.
    """
    batched = R.dim() != 2
    if not batched:
        R = R.unsqueeze(0)

    device = reciprocal_space_grid.device
    basis = cell.reciprocal_basis_matrix
    # torch.einsum does not type-promote, so e.g. a float64 basis combined with
    # float32 hkl would raise. Work in one floating dtype (at least float32) on
    # the grid's device; the sampler casts indices to float32 afterwards anyway.
    dtype = torch.promote_types(torch.promote_types(basis.dtype, R.dtype), torch.float32)
    basis = basis.to(device=device, dtype=dtype)
    R = R.to(device=device, dtype=dtype)
    hkl = hkl.to(device=device, dtype=dtype)

    rotation_in_s = torch.einsum("ij,bjk->bik", basis, R.transpose(-2, -1))
    rotation_in_hkl = torch.einsum("bij,jk->bik", rotation_in_s, torch.linalg.inv(basis))
    reoriented_hkl = torch.einsum("aj,bji->bai", hkl, rotation_in_hkl)

    n_batch, n_refl = reoriented_hkl.shape[:2]
    interpolated = interpolate_structure_factor_from_grid(
        reciprocal_space_grid, reoriented_hkl.reshape(-1, 3)
    ).reshape(n_batch, n_refl)
    return interpolated if batched else interpolated.squeeze(0)
def smooth_reciprocal_grid(
    reciprocal_grid: torch.Tensor,
    sigma: float,
    mode: str = "amplitude_phase",
) -> torch.Tensor:
    """
    Gaussian-smooth a complex reciprocal space grid with native PyTorch.

    Parameters
    ----------
    reciprocal_grid : torch.Tensor
        Complex tensor of shape (Nx, Ny, Nz).
    sigma : float
        Standard deviation of the Gaussian kernel in voxel units.
    mode : str, optional
        How the complex values are smoothed:

        - "amplitude_phase" (default): smooth amplitudes, and take the phase
          from an amplitude-weighted circular mean. This avoids amplitude loss
          from phase cancellation.
        - "amplitude_only": smooth amplitudes and keep each voxel's original
          phase. Use this when phase accuracy is critical.
        - "complex": smooth real and imaginary parts independently.
          WARNING: neighbours with different phases cancel (e.g. averaging
          exp(i*0) and exp(i*pi) gives 0).

    Returns
    -------
    torch.Tensor
        Smoothed grid of shape (Nx, Ny, Nz).

    Raises
    ------
    ValueError
        If ``mode`` is not one of the options above.

    Notes
    -----
    The convolution is done via FFT, which implies periodic (wrap) boundaries.
    """
    known_modes = ("amplitude_phase", "amplitude_only", "complex")
    if mode not in known_modes:
        raise ValueError(f"Unknown mode: {mode}. Use 'amplitude_phase', 'amplitude_only', or 'complex'.")
    if mode == "complex":
        return _smooth_complex_cartesian(reciprocal_grid, sigma)
    if mode == "amplitude_only":
        return _smooth_amplitude_only(reciprocal_grid, sigma)
    return _smooth_amplitude_and_phase(reciprocal_grid, sigma)
def _create_gaussian_kernel_fft(shape: tuple, sigma: float, device: torch.device) -> torch.Tensor: """ Create the FFT of a 3D Gaussian kernel for convolution. Parameters ---------- shape : tuple Shape of the grid (Nx, Ny, Nz). sigma : float Standard deviation of the Gaussian in voxel units. device : torch.device Device to create the kernel on. Returns ------- torch.Tensor FFT of the Gaussian kernel (real tensor since Gaussian is symmetric). """ Nx, Ny, Nz = shape # Create frequency grids (shifted so DC is at center conceptually, but we use fftfreq) fx = torch.fft.fftfreq(Nx, device=device) fy = torch.fft.fftfreq(Ny, device=device) fz = torch.fft.fftfreq(Nz, device=device) # Create 3D frequency grid FX, FY, FZ = torch.meshgrid(fx, fy, fz, indexing='ij') # Gaussian in Fourier space: exp(-2 * pi^2 * sigma^2 * |f|^2) # This is the FT of a Gaussian with std=sigma freq_sq = FX**2 + FY**2 + FZ**2 gaussian_fft = torch.exp(-2.0 * (np.pi * sigma) ** 2 * freq_sq) return gaussian_fft def _convolve_3d_fft(grid: torch.Tensor, kernel_fft: torch.Tensor) -> torch.Tensor: """ Convolve a 3D grid with a kernel using FFT (periodic boundary conditions). Parameters ---------- grid : torch.Tensor Input grid (real or complex). kernel_fft : torch.Tensor FFT of the convolution kernel (real tensor). Returns ------- torch.Tensor Convolved grid with same shape and dtype as input. """ grid_fft = torch.fft.fftn(grid) convolved_fft = grid_fft * kernel_fft convolved = torch.fft.ifftn(convolved_fft) # Return real part if input was real, otherwise complex if not grid.is_complex(): return convolved.real return convolved def _smooth_complex_cartesian(reciprocal_grid: torch.Tensor, sigma: float) -> torch.Tensor: """ Smooth by convolving real and imaginary parts separately. WARNING: This causes phase cancellation when neighboring voxels have different phases. 
""" kernel_fft = _create_gaussian_kernel_fft( reciprocal_grid.shape, sigma, reciprocal_grid.device ) # Smooth real and imaginary parts separately smoothed_real = _convolve_3d_fft(reciprocal_grid.real, kernel_fft) smoothed_imag = _convolve_3d_fft(reciprocal_grid.imag, kernel_fft) return torch.complex(smoothed_real, smoothed_imag) def _smooth_amplitude_only(reciprocal_grid: torch.Tensor, sigma: float) -> torch.Tensor: """ Smooth only the amplitudes, preserving original phases. """ kernel_fft = _create_gaussian_kernel_fft( reciprocal_grid.shape, sigma, reciprocal_grid.device ) # Get amplitude and phase amplitude = torch.abs(reciprocal_grid) phase = torch.angle(reciprocal_grid) # Smooth only the amplitude smoothed_amplitude = _convolve_3d_fft(amplitude, kernel_fft) # Reconstruct with original phases return smoothed_amplitude * torch.exp(1j * phase) def _smooth_amplitude_and_phase(reciprocal_grid: torch.Tensor, sigma: float) -> torch.Tensor: """ Smooth amplitude and phase separately, using circular mean for phases. This avoids phase cancellation by treating phase as a circular quantity. The circular mean is computed as: atan2(mean(sin(phi)), mean(cos(phi))) which is equivalent to averaging unit vectors and taking their angle. 
""" kernel_fft = _create_gaussian_kernel_fft( reciprocal_grid.shape, sigma, reciprocal_grid.device ) # Get amplitude and phase amplitude = torch.abs(reciprocal_grid) phase = torch.angle(reciprocal_grid) # Smooth the amplitude smoothed_amplitude = _convolve_3d_fft(amplitude, kernel_fft) # For phase: use amplitude-weighted circular mean # This weights the phase contribution by the amplitude (stronger reflections # contribute more to the smoothed phase) # Circular mean: atan2(sum(w*sin(phi)), sum(w*cos(phi))) # where w is the weight (amplitude in this case) weighted_sin = amplitude * torch.sin(phase) weighted_cos = amplitude * torch.cos(phase) smoothed_weighted_sin = _convolve_3d_fft(weighted_sin, kernel_fft) smoothed_weighted_cos = _convolve_3d_fft(weighted_cos, kernel_fft) smoothed_phase = torch.atan2(smoothed_weighted_sin, smoothed_weighted_cos) # Reconstruct complex values return smoothed_amplitude * torch.exp(1j * smoothed_phase)