torchref.base package

Mathematical functions for crystallographic computations.

This module provides PyTorch and NumPy implementations of:

  • Coordinate transformations (Cartesian <-> fractional)

  • Structure factor calculations

  • R-factor computations

  • French-Wilson intensity conversion

  • Atomic scattering factors

  • Grid and reciprocal space utilities

Submodules (New Organization)

coordinates

Coordinate transformation functions (Cartesian <-> fractional).

reciprocal

Reciprocal space calculations (basis, HKL, d-spacing, grid operations).

structure_factors

Structure factor calculations (isotropic, anisotropic, corrections).

electron_density

Electron density map building functions.

fourier

FFT operations and grid utilities.

scattering

Atomic scattering factors (ITC92 parameterization).

alignment

Coordinate alignment and superposition functions.

metrics

R-factor and loss function calculations.

kernels

Optimized GPU/CPU kernels for performance-critical operations.

Legacy Submodules (For Backward Compatibility)

math_torch

PyTorch implementations (deprecated, use domain-specific submodules).

math_numpy

NumPy implementations (deprecated, use domain-specific submodules).

french_wilson

French-Wilson treatment for negative intensities.

get_scattering_factor_torch

Atomic scattering factor calculations (use scattering.itc92 instead).

Example

New-style imports (recommended):

from torchref.base.coordinates import cartesian_to_fractional_torch
from torchref.base.metrics import get_rfactor_torch
from torchref.base.reciprocal import reciprocal_basis_matrix

Legacy imports (still supported):

from torchref.base import cartesian_to_fractional_torch
from torchref.base import math_torch, math_numpy
class torchref.base.FrenchWilson(hkl, cell, space_group='P1', n_bins=60, min_per_bin=40, h_min=-4.0, verbose=1)[source]

Bases: DeviceMixin, Module

PyTorch module for French-Wilson conversion from intensities to structure factors.

Pre-computes all necessary metadata (d-spacings, centric flags, resolution bins) during initialization, so the forward pass needs only I and sigma_I.

Parameters:
  • hkl (torch.Tensor) – Miller indices of shape (n_reflections, 3), integer tensor.

  • cell (torch.Tensor) – Unit cell parameters [a, b, c, alpha, beta, gamma] in Å and degrees.

  • space_group (str, int, or gemmi.SpaceGroup, optional) – Space group specification (e.g., 'P21', 4, gemmi.SpaceGroup('P 21')). Default is 'P1'.

  • n_bins (int, optional) – Number of resolution bins for mean intensity estimation. Default is 60.

  • min_per_bin (int, optional) – Minimum reflections per bin. Default is 40.

  • h_min (float, optional) – Minimum h value for rejection. Default is -4.0.

  • verbose (int, optional) – Verbosity level (0=silent, 1=basic, 2=detailed). Default is 1.

hkl

Miller indices.

Type:

torch.Tensor

d_spacings

Resolution for each reflection in Å.

Type:

torch.Tensor

is_centric

Boolean mask for centric reflections.

Type:

torch.Tensor

Examples

import torch
from torchref.base import FrenchWilson

hkl = torch.tensor([[1, 2, 3], [2, 0, 0], [0, 3, 0], [1, 1, 1]])
# Cell parameters as a tensor, matching the documented type of `cell`.
cell = torch.tensor([50.0, 60.0, 70.0, 90.0, 90.0, 90.0])
fw_module = FrenchWilson(hkl, cell, 'P212121')
I = torch.tensor([100.0, 50.0, 30.0, 200.0])
sigma_I = torch.tensor([10.0, 8.0, 7.0, 15.0])
F, sigma_F = fw_module(I, sigma_I)
__init__(hkl, cell, space_group='P1', n_bins=60, min_per_bin=40, h_min=-4.0, verbose=1)[source]
forward(I, sigma_I)[source]

Apply French-Wilson conversion.

Args:

  • I – Measured intensities, shape (n_reflections,).

  • sigma_I – Standard deviations of intensities, shape (n_reflections,).

Returns:

  • F – Structure factor amplitudes, shape (n_reflections,).

  • sigma_F – Standard deviations of F, shape (n_reflections,).

class torchref.base.CachedRadiusMask[source]

Bases: object

Cache the radius mask computation to avoid recomputing for every atom batch.

This eliminates redundant computation when processing multiple atoms with the same voxel size and radius.

Usage

>>> cache = CachedRadiusMask()
>>> offsets = cache.get_offsets(voxel_size, radius_angstrom, device)

_cache

Internal cache storing computed offsets.

Type:

dict

__init__()[source]
get_offsets(voxel_size, radius_angstrom, device)[source]

Get cached offset grid for given parameters.

Parameters:
  • voxel_size (torch.Tensor) – Voxel dimensions, shape (3,).

  • radius_angstrom (float) – Radius in Angstroms.

  • device (torch.device) – Device for the output tensor.

Returns:

Voxel offsets within radius, shape (N_voxels, 3).

Return type:

torch.Tensor
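
The caching idea can be sketched in NumPy. This is a minimal illustration, not the library's implementation: the class name RadiusOffsetCache is hypothetical, the voxels are assumed to be orthogonal, and the device argument is omitted.

```python
import numpy as np


class RadiusOffsetCache:
    """Hypothetical stand-in for CachedRadiusMask (orthogonal voxels only)."""

    def __init__(self):
        self._cache = {}

    def get_offsets(self, voxel_size, radius_angstrom):
        # Key on the rounded voxel size and radius so repeated calls hit the cache.
        key = (tuple(np.round(voxel_size, 6)), float(radius_angstrom))
        if key not in self._cache:
            # Largest integer step per axis that can still lie inside the sphere.
            n = np.ceil(radius_angstrom / np.asarray(voxel_size)).astype(int)
            axes = [np.arange(-k, k + 1) for k in n]
            grid = np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1).reshape(-1, 3)
            dist2 = ((grid * voxel_size) ** 2).sum(axis=1)
            self._cache[key] = grid[dist2 <= radius_angstrom ** 2]
        return self._cache[key]


cache = RadiusOffsetCache()
voxel_size = np.array([1.0, 1.0, 1.0])
offsets = cache.get_offsets(voxel_size, 1.0)
offsets_again = cache.get_offsets(voxel_size, 1.0)  # served from the cache
```

With 1 Å voxels and a 1 Å radius, the mask contains the central voxel and its six face neighbours.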

class torchref.base.ReciprocalSymmetryExtractor(hkl, symmetry, grid_shape, device=None)[source]

Bases: DeviceMixin

Class-based interface for reciprocal space symmetry extraction.

This provides a more efficient interface when computing structure factors multiple times with the same symmetry and HKLs (e.g., during refinement). Precomputes equivalent HKLs, phase factors, and flat grid indices so that each call reduces to a single gather + multiply + sum (~3 GPU kernels instead of ~28).

Parameters:
  • hkl (torch.Tensor, shape (N, 3)) – Target Miller indices.

  • symmetry (SpaceGroup) – SpaceGroup object containing rotation matrices and translations.

  • grid_shape (tuple of int) – Reciprocal grid dimensions (Nx, Ny, Nz).

  • device (torch.device, optional) – Device for computation.

Examples

>>> extractor = ReciprocalSymmetryExtractor(hkl, symmetry, grid_shape=(209, 86, 67))
>>> f_calc = extractor.extract_from_grid(reciprocal_grid)
__init__(hkl, symmetry, grid_shape, device=None)[source]
__call__(density_map)[source]

Compute structure factors from P1 density map.

Parameters:

density_map (torch.Tensor, shape (Nx, Ny, Nz)) – P1 electron density map (NO symmetry applied).

Returns:

Complex structure factors with symmetry applied.

Return type:

torch.Tensor, shape (N,)

extract(density_map)[source]

Extract structure factors from P1 density map.

Parameters:

density_map (torch.Tensor, shape (Nx, Ny, Nz)) – P1 electron density map (NO symmetry applied).

Returns:

Complex structure factors with symmetry applied.

Return type:

torch.Tensor, shape (N,)

extract_from_grid(reciprocal_grid)[source]

Extract structure factors from precomputed reciprocal grid.

Uses precomputed flat indices for a single vectorized gather, avoiding per-symop Python loops and kernel launches.

Parameters:

reciprocal_grid (torch.Tensor, shape (Nx, Ny, Nz)) – Complex reciprocal space grid from FFT.

Returns:

Complex structure factors with symmetry applied.

Return type:

torch.Tensor, shape (N,)

torchref.base.cartesian_to_fractional_torch(xyz, cell, B_inv=None)[source]

Convert Cartesian coordinates to fractional coordinates.

Parameters:
  • xyz (torch.Tensor) – Cartesian coordinates of shape (N, 3).

  • cell (array-like) – Unit cell parameters [a, b, c, alpha, beta, gamma].

  • B_inv (torch.Tensor, optional) – Inverse fractionalization matrix. If None, it will be calculated from cell.

Returns:

Fractional coordinates of shape (N, 3).

Return type:

torch.Tensor

torchref.base.fractional_to_cartesian_torch(xyz_fractional, cell, B=None)[source]

Convert fractional coordinates to Cartesian coordinates.

Parameters:
  • xyz_fractional (torch.Tensor) – Fractional coordinates of shape (N, 3).

  • cell (array-like) – Unit cell parameters [a, b, c, alpha, beta, gamma].

  • B (torch.Tensor, optional) – Fractionalization matrix. If None, it will be calculated from cell.

Returns:

Cartesian coordinates of shape (N, 3).

Return type:

torch.Tensor

torchref.base.get_fractional_matrix(cell)[source]

Calculate the fractional-to-Cartesian transformation matrix.

Constructs the matrix B that transforms fractional coordinates to Cartesian coordinates based on the unit cell parameters.

Parameters:

cell (torch.Tensor) – Unit cell parameters [a, b, c, alpha, beta, gamma] where lengths are in Angstroms and angles are in degrees.

Returns:

3x3 transformation matrix B such that cart = frac @ B.T.

Return type:

torch.Tensor
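
As an illustration of the relation cart = frac @ B.T, the sketch below builds B in NumPy. It assumes the common orthogonalization convention (a along x, b in the xy plane); whether torchref uses this exact convention is an assumption, and the function name fractional_matrix is hypothetical.

```python
import numpy as np


def fractional_matrix(cell):
    """Return B whose columns are the cell vectors a, b, c (assumed convention)."""
    a, b, c = cell[:3]
    alpha, beta, gamma = np.radians(cell[3:])
    ca, cb, cg = np.cos(alpha), np.cos(beta), np.cos(gamma)
    sg = np.sin(gamma)
    # Cell volume divided by a*b*c.
    v = np.sqrt(1.0 - ca**2 - cb**2 - cg**2 + 2.0 * ca * cb * cg)
    return np.array([
        [a, b * cg, c * cb],
        [0.0, b * sg, c * (ca - cb * cg) / sg],
        [0.0, 0.0, c * v / sg],
    ])


cell = np.array([50.0, 60.0, 70.0, 90.0, 110.0, 90.0])  # monoclinic example
B = fractional_matrix(cell)
B_inv = np.linalg.inv(B)

frac = np.array([[0.1, 0.2, 0.3], [0.5, 0.5, 0.5]])
cart = frac @ B.T          # fractional -> Cartesian
frac_back = cart @ B_inv.T  # Cartesian -> fractional
```

The column norms of B recover the cell lengths, which is a quick sanity check for any convention of this kind.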

torchref.base.get_inv_fractional_matrix_torch(cell)[source]

Calculate the Cartesian-to-fractional transformation matrix (PyTorch version).

Computes the inverse of the fractional matrix for converting Cartesian coordinates to fractional coordinates.

Parameters:

cell (torch.Tensor) – Unit cell parameters [a, b, c, alpha, beta, gamma] where lengths are in Angstroms and angles are in degrees.

Returns:

3x3 inverse transformation matrix B_inv such that frac = cart @ B_inv.T.

Return type:

torch.Tensor

torchref.base.cartesian_to_fractional(xyz, cell)[source]

Convert Cartesian coordinates to fractional coordinates.

Parameters:
  • xyz (numpy.ndarray) – Cartesian coordinates with shape (N, 3).

  • cell (numpy.ndarray or list) – Unit cell parameters [a, b, c, alpha, beta, gamma] where lengths are in Angstroms and angles are in degrees.

Returns:

Fractional coordinates with shape (N, 3).

Return type:

numpy.ndarray

torchref.base.fractional_to_cartesian(xyz_fractional, cell)[source]

Convert fractional coordinates to Cartesian coordinates.

Parameters:
  • xyz_fractional (numpy.ndarray) – Fractional coordinates with shape (N, 3).

  • cell (numpy.ndarray or list) – Unit cell parameters [a, b, c, alpha, beta, gamma] where lengths are in Angstroms and angles are in degrees.

Returns:

Cartesian coordinates with shape (N, 3).

Return type:

numpy.ndarray

torchref.base.get_inv_fractional_matrix(cell)[source]

Calculate the Cartesian-to-fractional transformation matrix.

Computes the inverse of the fractional matrix for converting Cartesian coordinates to fractional coordinates.

Parameters:

cell (numpy.ndarray or list) – Unit cell parameters [a, b, c, alpha, beta, gamma] where lengths are in Angstroms and angles are in degrees.

Returns:

3x3 inverse transformation matrix B_inv such that frac = cart @ B_inv.T.

Return type:

numpy.ndarray

torchref.base.convert_coords_to_fractional(df, cell)[source]

Convert coordinates from a DataFrame to fractional coordinates.

Extracts x, y, z columns from a DataFrame and converts them from Cartesian to fractional coordinates.

Parameters:
  • df (pandas.DataFrame) – DataFrame containing ‘x’, ‘y’, ‘z’ columns with Cartesian coordinates.

  • cell (numpy.ndarray or list) – Unit cell parameters [a, b, c, alpha, beta, gamma] where lengths are in Angstroms and angles are in degrees.

Returns:

Fractional coordinates with shape (N, 3).

Return type:

numpy.ndarray

torchref.base.smallest_diff(diff, inv_frac_matrix, frac_matrix)[source]

Compute minimum image squared distances with periodic boundary conditions.

Parameters:
  • diff (torch.Tensor) – Difference vectors of shape (…, 3).

  • inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix of shape (3, 3).

  • frac_matrix (torch.Tensor) – Fractionalization matrix of shape (3, 3).

Returns:

Squared distances with shape (…).

Return type:

torch.Tensor
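
A minimal NumPy sketch of the minimum image convention follows. It assumes frac_matrix maps fractional to Cartesian coordinates (cart = frac @ frac_matrix.T) and wraps each fractional component to the nearest image. For strongly skewed cells this component-wise wrap is only an approximation to the true shortest vector. The function name is hypothetical.

```python
import numpy as np


def min_image_sq_dist(diff, inv_frac_matrix, frac_matrix):
    """Squared minimum-image distances for Cartesian difference vectors."""
    frac = diff @ inv_frac_matrix.T   # Cartesian -> fractional
    frac = frac - np.round(frac)      # wrap each component into [-0.5, 0.5]
    cart = frac @ frac_matrix.T       # back to Cartesian
    return (cart ** 2).sum(axis=-1)


frac_matrix = np.diag([10.0, 10.0, 10.0])  # cubic 10 Angstrom cell
inv_frac_matrix = np.linalg.inv(frac_matrix)
diff = np.array([[9.0, 0.0, 0.0], [4.0, 0.0, 0.0]])
d2 = min_image_sq_dist(diff, inv_frac_matrix, frac_matrix)
```

A 9 Å separation in a 10 Å cell is 1 Å from the nearest periodic image, so its squared distance is 1 rather than 81.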

torchref.base.smallest_diff_aniso(diff, inv_frac_matrix, frac_matrix)[source]

Compute minimum image difference vectors for anisotropic calculations.

Parameters:
  • diff (torch.Tensor) – Difference vectors of shape (…, 3).

  • inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix of shape (3, 3).

  • frac_matrix (torch.Tensor) – Fractionalization matrix of shape (3, 3).

Returns:

Signed difference vectors with shape (…, 3). Note: For anisotropic calculations, the signed vectors are needed to correctly compute the quadratic form r^T × B^(-1) × r with off-diagonal U tensor terms.

Return type:

torch.Tensor

torchref.base.reciprocal_basis_matrix(cell)[source]

Compute the reciprocal space basis matrix from unit cell parameters.

Parameters:

cell (torch.Tensor) – Cell parameters [a, b, c, alpha, beta, gamma].

Returns:

Reciprocal basis matrix of shape (3, 3) with a*, b*, c* as rows.

Return type:

torch.Tensor
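
One way to obtain a matrix with a*, b*, c* as rows is to invert the matrix whose columns are the real-space cell vectors. The sketch below assumes the crystallographic convention without a factor of 2π and builds the direct matrix by hand for a monoclinic cell. All variable names are illustrative.

```python
import numpy as np

a, b, c, beta = 50.0, 60.0, 70.0, np.radians(110.0)

# Columns are the real-space cell vectors a, b, c (monoclinic, unique axis b).
B = np.array([
    [a, 0.0, c * np.cos(beta)],
    [0.0, b, 0.0],
    [0.0, 0.0, c * np.sin(beta)],
])

# Rows of the inverse are a*, b*, c*, since rec_b @ B = identity.
rec_b = np.linalg.inv(B)

hkl = np.array([[1, 0, 0], [0, 2, 0]])
s_vectors = hkl @ rec_b                      # s = h a* + k b* + l c*
d = 1.0 / np.linalg.norm(s_vectors, axis=1)  # d-spacing in Angstroms
```

For this cell the (1, 0, 0) spacing is a·sin(β) and the (0, 2, 0) spacing is b/2.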

torchref.base.reciprocal_basis_matrix_numpy(cell)[source]

Calculate the reciprocal basis matrix from unit cell parameters (NumPy version).

Computes the reciprocal space basis vectors (a*, b*, c*) that define the transformation from Miller indices to scattering vectors.

Parameters:

cell (numpy.ndarray or list) – Unit cell parameters [a, b, c, alpha, beta, gamma] where lengths are in Angstroms and angles are in degrees.

Returns:

3x3 matrix containing reciprocal basis vectors as rows [a*, b*, c*].

Return type:

numpy.ndarray

torchref.base.get_scattering_vectors(hkl, cell, recB=None)[source]

Calculate scattering vectors from Miller indices.

Parameters:
  • hkl (torch.Tensor) – Miller indices of shape (N, 3).

  • cell (torch.Tensor) – Cell parameters [a, b, c, alpha, beta, gamma].

  • recB (torch.Tensor, optional) – Pre-computed reciprocal basis matrix of shape (3, 3).

Returns:

Scattering vectors of shape (N, 3).

Return type:

torch.Tensor

torchref.base.get_scattering_vectors_numpy(hkl, cell)[source]

Calculate scattering vectors from Miller indices and unit cell (NumPy version).

Transforms Miller indices to reciprocal space scattering vectors using the reciprocal basis matrix.

Parameters:
  • hkl (numpy.ndarray or list) – Miller indices with shape (N, 3).

  • cell (numpy.ndarray or list) – Unit cell parameters [a, b, c, alpha, beta, gamma] where lengths are in Angstroms and angles are in degrees.

Returns:

Scattering vectors in reciprocal space with shape (N, 3).

Return type:

numpy.ndarray

torchref.base.get_s(hkl, cell)[source]

Calculate the magnitude of scattering vectors for given Miller indices.

Computes |s| = 1/d where d is the interplanar spacing for each reflection.

Parameters:
  • hkl (numpy.ndarray) – Miller indices with shape (N, 3).

  • cell (numpy.ndarray or list) – Unit cell parameters [a, b, c, alpha, beta, gamma] where lengths are in Angstroms and angles are in degrees.

Returns:

Magnitude of scattering vectors with shape (N,).

Return type:

numpy.ndarray

torchref.base.get_d_spacing(hkl, cell, recB=None)[source]

Calculate d-spacing from Miller indices.

Parameters:
  • hkl (torch.Tensor) – Miller indices of shape (N, 3).

  • cell (torch.Tensor) – Cell parameters [a, b, c, alpha, beta, gamma].

  • recB (torch.Tensor, optional) – Pre-computed reciprocal basis matrix of shape (3, 3).

Returns:

D-spacing values of shape (N,) in Angstroms.

Return type:

torch.Tensor
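
The d-spacing can also be computed from the metric tensor, using 1/d² = h G* hᵀ with G* the inverse of the real-space metric tensor. The NumPy sketch below is an independent illustration of that formula, not the library's code path. The function name d_spacing is hypothetical.

```python
import numpy as np


def d_spacing(hkl, cell):
    """d-spacing in Angstroms from Miller indices via the metric tensor."""
    a, b, c = cell[:3]
    ca, cb, cg = np.cos(np.radians(cell[3:]))
    # Real-space metric tensor G_ij = a_i . a_j.
    G = np.array([
        [a * a, a * b * cg, a * c * cb],
        [a * b * cg, b * b, b * c * ca],
        [a * c * cb, b * c * ca, c * c],
    ])
    G_star = np.linalg.inv(G)
    hkl = np.asarray(hkl, dtype=float)
    inv_d2 = np.einsum("ni,ij,nj->n", hkl, G_star, hkl)
    return 1.0 / np.sqrt(inv_d2)  # (0, 0, 0) would divide by zero


cubic = np.array([10.0, 10.0, 10.0, 90.0, 90.0, 90.0])
hexagonal = np.array([10.0, 10.0, 20.0, 90.0, 90.0, 120.0])
d_cubic = d_spacing([[1, 1, 1], [2, 0, 0]], cubic)
d_hex = d_spacing([[1, 0, 0]], hexagonal)
```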

torchref.base.compute_d_spacing_batch(hkl, cell, recB=None)[source]

Compute d-spacing for a batch of Miller indices.

Wrapper around get_d_spacing for convenience.

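Parameters:
  • hkl (torch.Tensor) – Miller indices of shape (N, 3).

  • cell (torch.Tensor) – Cell parameters [a, b, c, alpha, beta, gamma].

  • recB (torch.Tensor, optional) – Pre-computed reciprocal basis matrix of shape (3, 3).

Returns: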

D-spacing values in Angstroms.

Return type:

torch.Tensor, shape (N,)

torchref.base.generate_possible_hkl(cell, d_min, device=None)[source]

Generate all possible Miller indices within a resolution limit.

Creates a complete set of (h, k, l) indices where the d-spacing is greater than or equal to d_min.

Parameters:
  • cell (torch.Tensor, shape (6,)) – Unit cell parameters [a, b, c, alpha, beta, gamma] in Angstroms and degrees.

  • d_min (float) – High resolution limit in Angstroms (minimum d-spacing).

  • device (torch.device, optional) – Device for computation. If None, uses cell’s device.

Returns:

All Miller indices with d-spacing >= d_min.

Return type:

torch.Tensor, shape (M, 3), dtype int32

Examples

import torch
from torchref.base import generate_possible_hkl
cell = torch.tensor([50.0, 60.0, 70.0, 90.0, 90.0, 90.0])
hkl = generate_possible_hkl(cell, d_min=2.0)
print(f"Generated {len(hkl)} reflections")
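
The enumeration itself can be sketched in NumPy. Since h = s · a and |s| ≤ 1/d_min, the index h is bounded by a/d_min (likewise for k and l), so a box search followed by a resolution filter is sufficient. The helper name possible_hkl is hypothetical, and dropping (0, 0, 0) is an assumption about the desired output.

```python
import itertools

import numpy as np


def possible_hkl(cell, d_min):
    """Enumerate Miller indices with d >= d_min (illustrative NumPy helper)."""
    a, b, c = cell[:3]
    ca, cb, cg = np.cos(np.radians(cell[3:]))
    G = np.array([
        [a * a, a * b * cg, a * c * cb],
        [a * b * cg, b * b, b * c * ca],
        [a * c * cb, b * c * ca, c * c],
    ])
    G_star = np.linalg.inv(G)  # 1/d^2 = h G* h^T
    # |h| <= a/d_min, |k| <= b/d_min, |l| <= c/d_min bound the search box.
    limits = [int(np.floor(x / d_min)) for x in (a, b, c)]
    ranges = [range(-m, m + 1) for m in limits]
    hkl = np.array(list(itertools.product(*ranges)), dtype=np.int32)
    hkl = hkl[np.any(hkl != 0, axis=1)]  # drop (0, 0, 0); an assumption
    h = hkl.astype(float)
    inv_d2 = np.einsum("ni,ij,nj->n", h, G_star, h)
    # Small relative tolerance so reflections exactly at d_min are kept.
    keep = inv_d2 <= (1.0 / d_min**2) * (1.0 + 1e-9)
    return hkl[keep]


cell = np.array([10.0, 10.0, 10.0, 90.0, 90.0, 90.0])
hkl = possible_hkl(cell, d_min=5.0)
```

For a 10 Å cubic cell at 5 Å resolution, the result is every nonzero index triple with h² + k² + l² ≤ 4.
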
torchref.base.place_on_grid(hkls, structure_factor, grid_size, enforce_hermitian=True)[source]

Place structure factors on a reciprocal-space grid.

Vectorized placement of batched structure factors on reciprocal-space grid.

Parameters:
  • hkls (torch.Tensor) – Miller indices of shape (N, 3).

  • structure_factor (torch.Tensor) – Structure factors of shape (N,) or (B, N) for batched input.

  • grid_size (tuple or torch.Tensor) – Grid dimensions (Nx, Ny, Nz).

  • enforce_hermitian (bool, optional) – Whether to enforce Hermitian symmetry. Default is True.

Returns:

Complex tensor grid of structure factors of shape (Nx, Ny, Nz) or (B, Nx, Ny, Nz) for batched input.

Return type:

torch.Tensor
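
The placement can be illustrated in NumPy for the unbatched case. The sketch assumes FFT-style wrap-around indexing (negative indices map to N - |h|) and interprets Hermitian enforcement as also writing the conjugate at (-h, -k, -l). Both are assumptions about the convention, and the function name is hypothetical.

```python
import numpy as np


def place_on_grid_np(hkls, structure_factor, grid_size, enforce_hermitian=True):
    """Place F(hkl) on a complex grid with wrap-around indexing."""
    shape = np.asarray(grid_size)
    grid = np.zeros(grid_size, dtype=np.complex128)
    h, k, l = (hkls % shape).T
    grid[h, k, l] = structure_factor
    if enforce_hermitian:
        # Friedel mate: F(-h, -k, -l) = conj(F(h, k, l)).
        h, k, l = ((-hkls) % shape).T
        grid[h, k, l] = np.conj(structure_factor)
    return grid


hkls = np.array([[1, 0, 0], [0, -2, 1]])
F = np.array([3.0 + 4.0j, 1.0 - 2.0j])
grid = place_on_grid_np(hkls, F, (8, 8, 8))
density = np.fft.ifftn(grid)  # real-valued up to rounding when Hermitian
```

Hermitian symmetry of the grid is what makes the inverse FFT a real-valued map.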

torchref.base.extract_structure_factor_from_grid(reciprocal_grid, hkls)[source]

Extract structure factors from reciprocal space grid at given Miller indices.

Parameters:
  • reciprocal_grid (torch.Tensor) – Complex tensor of shape (Nx, Ny, Nz) or (B, Nx, Ny, Nz).

  • hkls (torch.Tensor) – Miller indices of shape (N, 3).

Returns:

Structure factors of shape (N,) or (B, N) for batched input.

Return type:

torch.Tensor

torchref.base.apply_translation_phase(F_calc, hkl, translation_frac)[source]

Apply translation phase shift to structure factors.

For a translation t in fractional coordinates, the structure factor transforms as: F’(hkl) = F(hkl) * exp(2πi * hkl · t)

Parameters:
  • F_calc (torch.Tensor) – Complex structure factors of shape (N,).

  • hkl (torch.Tensor) – Miller indices of shape (N, 3).

  • translation_frac (torch.Tensor) – Translation vector in fractional coordinates of shape (3,).

Returns:

Phase-shifted structure factors of shape (N,).

Return type:

torch.Tensor
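
The phase-shift formula above translates directly into NumPy. This sketch only illustrates the formula as stated, and the function name is hypothetical.

```python
import numpy as np


def apply_translation_phase_np(F_calc, hkl, translation_frac):
    """Return F(hkl) * exp(2*pi*i * hkl . t) for a fractional translation t."""
    phase = 2.0 * np.pi * (hkl @ translation_frac)
    return F_calc * np.exp(1j * phase)


hkl = np.array([[1, 0, 0], [2, 0, 0], [0, 1, 0]])
F_calc = np.array([1.0 + 1.0j, 2.0 + 0.0j, 0.0 + 3.0j])
t = np.array([0.5, 0.0, 0.0])
F_shifted = apply_translation_phase_np(F_calc, hkl, t)
```

A shift of half a cell along a flips the sign of reflections with odd h and leaves all amplitudes unchanged.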

torchref.base.interpolate_structure_factor_from_grid(reciprocal_grid, hkl_float, interpolate_amplitude=True)[source]

Interpolate structure factors from reciprocal space grid at non-integer positions.

Parameters:
  • reciprocal_grid (torch.Tensor) – Complex tensor of shape (Nx, Ny, Nz).

  • hkl_float (torch.Tensor) – Non-integer HKL positions of shape (N, 3).

  • interpolate_amplitude (bool, optional) – If True (default), interpolate amplitudes instead of complex values. This avoids phase cancellation issues where linear interpolation of complex numbers with different phases can give incorrect magnitudes (e.g., interpolating F=1 and F=-1 gives 0 instead of 1).

Returns:

Interpolated structure factors of shape (N,). If interpolate_amplitude=True, returns real-valued amplitudes. If interpolate_amplitude=False, returns complex values (use with caution).

Return type:

torch.Tensor

Notes

For a rotation R applied to the model, the structure factor at hkl becomes F(R^T @ hkl), so you would call this with hkl_float = hkl @ R.

WARNING: Complex interpolation (interpolate_amplitude=False) can give incorrect results when neighboring voxels have very different phases. For example, if F1 = A*exp(i*0) and F2 = A*exp(i*π), linear interpolation gives magnitude 0 at the midpoint instead of A. Use interpolate_amplitude=True for rotation functions where only magnitudes matter.
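
The cancellation described in the warning can be reproduced with two values of equal amplitude and opposite phase. The amplitude A = 2 and the midpoint weight are illustrative.

```python
import numpy as np

A = 2.0
f1 = A * np.exp(1j * 0.0)    # phase 0
f2 = A * np.exp(1j * np.pi)  # phase pi

w = 0.5  # midpoint between the two grid points
complex_mid = (1.0 - w) * f1 + w * f2              # linear interpolation of complex values
amplitude_mid = (1.0 - w) * abs(f1) + w * abs(f2)  # linear interpolation of amplitudes

print(abs(complex_mid), amplitude_mid)
```

The complex midpoint has magnitude zero up to rounding, while the amplitude midpoint keeps the value A.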

torchref.base.interpolate_complex_from_grid(reciprocal_grid, hkl_float)[source]

Interpolate complex structure factors from reciprocal space grid.

Unlike amplitude interpolation, this preserves phase information, which is essential for translation searches where phases are used to compute correlation functions.

Parameters:
  • reciprocal_grid (torch.Tensor) – Complex tensor of shape (Nx, Ny, Nz).

  • hkl_float (torch.Tensor) – Non-integer HKL positions of shape (N, 3).

Returns:

Interpolated complex structure factors of shape (N,).

Return type:

torch.Tensor

Notes

This function performs trilinear interpolation of complex values. For rotation-only searches (where only magnitudes matter), use interpolate_structure_factor_from_grid(interpolate_amplitude=True) instead.

For translation searches, complex interpolation is needed because the translation function depends on the phase relationship between F_obs and F_calc.

WARNING: Complex interpolation can give reduced magnitudes when neighboring voxels have very different phases. This is acceptable for translation searches where we’re computing correlation functions, but not for rotation searches.

torchref.base.trilinear_interpolate_patterson(grid, points, chunk_size=10000000)[source]

Memory-efficient trilinear interpolation on a 3D grid.

Pure torch implementation for GPU acceleration and gradient flow. Replaces scipy.ndimage.map_coordinates for torch tensors.

Parameters:
  • grid (torch.Tensor) – 3D grid of values with shape (nx, ny, nz).

  • points (torch.Tensor) – Coordinates to sample with shape (n_points, 3). Should be in fractional coordinates [0, 1) for 'wrap' mode. A batched input of shape (K, n_points, 3) is also accepted.

  • chunk_size (int, optional) – Number of points to process at once. Default is 10,000,000 (10M).

Returns:

Interpolated values with shape (n_points,) or (batch, n_points).

Return type:

torch.Tensor

Notes

Supports automatic differentiation for gradient-based optimization. Uses chunked processing and in-place accumulation to reduce memory.
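
Trilinear interpolation with periodic wrapping can be sketched in NumPy as a weighted sum over the eight surrounding grid points. The sketch leaves out chunking, batching and autograd. It assumes fractional coordinate f maps to grid position f × n along each axis, and the function name is hypothetical.

```python
import itertools

import numpy as np


def trilinear_wrap(grid, points_frac):
    """Trilinear interpolation on a periodic 3D grid at fractional points."""
    shape = np.array(grid.shape)
    pos = (points_frac % 1.0) * shape  # position in voxel units
    i0 = np.floor(pos).astype(int)
    t = pos - i0                       # offset within the voxel, in [0, 1)
    i0 = i0 % shape
    i1 = (i0 + 1) % shape              # upper neighbours wrap around
    out = np.zeros(len(points_frac), dtype=float)
    for corner in itertools.product((0, 1), repeat=3):
        idx = np.where(corner, i1, i0)                  # per-axis lower/upper index
        w = np.where(corner, t, 1.0 - t).prod(axis=1)   # product of 1D weights
        out += w * grid[idx[:, 0], idx[:, 1], idx[:, 2]]
    return out


grid = np.arange(64, dtype=float).reshape(4, 4, 4)
points = np.array([
    [0.25, 0.5, 0.75],   # exactly on grid point (1, 2, 3)
    [0.375, 0.5, 0.75],  # halfway between (1, 2, 3) and (2, 2, 3)
    [0.875, 0.0, 0.0],   # halfway between (3, 0, 0) and the wrapped (0, 0, 0)
])
values = trilinear_wrap(grid, points)
```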

torchref.base.compute_symmetry_equivalent_hkls(hkl, rotation_matrices)[source]

Compute symmetry-equivalent HKLs: h’ = R^T @ h for each operation.

In reciprocal space, Miller indices transform as h’ = h @ R (or equivalently h’ = R^T @ h when treating h as a column vector).

Parameters:
  • hkl (torch.Tensor, shape (N, 3)) – Miller indices.

  • rotation_matrices (torch.Tensor, shape (n_ops, 3, 3)) – Real-space rotation matrices. These are transposed internally for reciprocal space transformation.

Returns:

Equivalent HKLs for each symmetry operation.

Return type:

torch.Tensor, shape (n_ops, N, 3)
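
The transformation h' = h @ R for all operations at once can be written as a single einsum. The example uses the two rotation parts of space group P2₁ (unique axis b) as illustrative input.

```python
import numpy as np

hkl = np.array([[1, 2, 3], [0, 1, 0]])

# Rotation parts: identity and a two-fold axis along b.
rotations = np.array([np.eye(3, dtype=int), np.diag([-1, 1, -1])])

# equiv[o, n] = hkl[n] @ rotations[o], giving shape (n_ops, N, 3).
equiv = np.einsum("ni,oij->onj", hkl, rotations)

# Same result with h treated as a column vector: R^T @ h.
equiv_col = np.stack([(R.T @ hkl.T).T for R in rotations])
```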

torchref.base.compute_translation_phases(hkl, translations)[source]

Compute phase shifts: exp(2*pi*i * h.t) for each operation.

The translation component of a symmetry operation causes a phase shift in the structure factor.

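Parameters:
  • hkl (torch.Tensor, shape (N, 3)) – Miller indices.

  • translations (torch.Tensor, shape (n_ops, 3)) – Translation components of the symmetry operations in fractional coordinates.

Returns: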

Complex phase factors exp(2*pi*i * h.t).

Return type:

torch.Tensor, shape (n_ops, N)
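
A NumPy sketch of the phase factors exp(2πi h·t) follows, using the two translations of P2₁ (unique axis b) as illustrative input. The (n_ops, 3) layout of translations is an assumption inferred from the documented return shape.

```python
import numpy as np

hkl = np.array([[0, 1, 0], [0, 2, 0], [1, 0, 0]])
translations = np.array([[0.0, 0.0, 0.0], [0.0, 0.5, 0.0]])  # identity, 2_1 screw

# phases[o, n] = exp(2*pi*i * hkl[n] . translations[o])
phases = np.exp(2j * np.pi * (translations @ hkl.T))

# 0k0 reflections are invariant under the two-fold rotation, so their
# symmetry sum is proportional to the sum of the two phase factors.
sum_0k0 = phases[:, :2].sum(axis=0)
```

The sum vanishes for (0, 1, 0) and equals 2 for (0, 2, 0), which is the familiar 0k0, k odd, systematic absence of a 2₁ screw axis.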

torchref.base.extract_structure_factors_with_symmetry(reciprocal_grid, hkl, rotation_matrices, translations)[source]

Extract structure factors with symmetry applied in reciprocal space.

This is the main function that replaces the MapSymmetry approach for structure factor extraction. Instead of symmetrizing the density map and then extracting F(hkl), we extract F at all symmetry-equivalent positions and sum with phases.

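Parameters:
  • reciprocal_grid (torch.Tensor, shape (Nx, Ny, Nz)) – Complex reciprocal space grid from FFT.

  • hkl (torch.Tensor, shape (N, 3)) – Target Miller indices.

  • rotation_matrices (torch.Tensor, shape (n_ops, 3, 3)) – Real-space rotation matrices.

  • translations (torch.Tensor, shape (n_ops, 3)) – Translation components in fractional coordinates.

Returns: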

Complex structure factors with symmetry applied.

Return type:

torch.Tensor, shape (N,)

torchref.base.interpolate_for_rotation(hkl, R, cell, reciprocal_space_grid)[source]

Interpolate structure factors for rotated HKL positions. This works for all cells.

Parameters:
  • hkl (torch.Tensor) – HKL positions, shape (N, 3).

  • R (torch.Tensor) – Rotation matrix, shape (B, 3, 3).

  • cell (Cell) – Unit cell object with reciprocal_basis_matrix.

  • reciprocal_space_grid (torch.Tensor) – Reciprocal space grid of structure factors, shape (Nx, Ny, Nz).

Returns:

Interpolated structure factors, shape (B, N) or (N,).

Return type:

torch.Tensor

torchref.base.smooth_reciprocal_grid(reciprocal_grid, sigma, mode='amplitude_phase')[source]

Apply Gaussian smoothing to a reciprocal space grid using native PyTorch.

Parameters:
  • reciprocal_grid (torch.Tensor) – Complex tensor of shape (Nx, Ny, Nz).

  • sigma (float) – Standard deviation of the Gaussian kernel in voxel units.

  • mode (str, optional) –

    Smoothing mode. Options:

    • "amplitude_phase": Smooth amplitude and phase separately using circular mean for phases. This avoids phase cancellation while preserving the relationship between amplitude and phase. Default.

    • "amplitude_only": Smooth only amplitudes, preserve original phases. Useful when phase accuracy is critical.

    • "complex": Smooth real and imaginary parts separately. WARNING: This can cause phase cancellation when neighboring voxels have different phases (e.g., averaging exp(i*0) and exp(i*π) gives 0).

Returns:

Smoothed reciprocal space grid of shape (Nx, Ny, Nz).

Return type:

torch.Tensor

Notes

Uses FFT-based convolution for efficient 3D Gaussian smoothing with periodic (wrap) boundary conditions.
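
FFT-based Gaussian smoothing with periodic boundaries can be sketched in NumPy by multiplying the transform of the data by the Gaussian transfer function exp(-2π²σ²f²), with f in cycles per voxel. The sketch applies it in an "amplitude_only" style. Both function names are hypothetical and the library's kernel construction may differ.

```python
import numpy as np


def gaussian_smooth_periodic(values, sigma):
    """Convolve with a Gaussian (std sigma, in voxels) using wrap boundaries."""
    freqs = np.meshgrid(*[np.fft.fftfreq(n) for n in values.shape], indexing="ij")
    f2 = sum(f**2 for f in freqs)
    transfer = np.exp(-2.0 * np.pi**2 * sigma**2 * f2)  # Fourier transform of the kernel
    return np.fft.ifftn(np.fft.fftn(values) * transfer)


def smooth_amplitude_only(reciprocal_grid, sigma):
    """Smooth |F| and reattach the original phases."""
    amplitude = gaussian_smooth_periodic(np.abs(reciprocal_grid), sigma).real
    return amplitude * np.exp(1j * np.angle(reciprocal_grid))


rng = np.random.default_rng(0)
random_phases = rng.uniform(-np.pi, np.pi, size=(8, 8, 8))
grid = 2.0 * np.exp(1j * random_phases)  # constant amplitude, random phases
smoothed = smooth_amplitude_only(grid, sigma=1.5)

delta = np.zeros((8, 8, 8))
delta[0, 0, 0] = 1.0
blurred = gaussian_smooth_periodic(delta, sigma=1.5).real
```

Because the transfer function equals 1 at zero frequency, the smoothing preserves the grid sum and leaves a constant amplitude unchanged.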

torchref.base.iso_structure_factor_torched(hkl, s, xyz_fractional, occ, scattering_factors, adp, spacegroup, max_memory_gb=None, A=None, B_coeff=None)[source]

Calculate isotropic structure factors using PyTorch.

Parameters:
  • hkl (torch.Tensor) – Miller indices of shape (N_reflections, 3).

  • s (torch.Tensor) – Scattering vector magnitudes of shape (N_reflections,).

  • xyz_fractional (torch.Tensor) – Fractional coordinates of shape (N_atoms, 3).

  • occ (torch.Tensor) – Occupancies of shape (N_atoms,).

  • scattering_factors (torch.Tensor or None) – Atomic scattering factors of shape (N_reflections, N_atoms). If None, must provide A and B_coeff to compute them in batches.

  • adp (torch.Tensor) – Atomic displacement parameters (isotropic) of shape (N_atoms,).

  • spacegroup (callable) – Space group symmetry operator function.

  • max_memory_gb (float, optional) – Maximum memory to use in GB. If None, no batching is applied.

  • A (torch.Tensor, optional) – ITC92 A coefficients (N_atoms, 5) for computing scattering factors. Required if scattering_factors is None and batching is needed.

  • B_coeff (torch.Tensor, optional) – ITC92 B coefficients (N_atoms, 5) for computing scattering factors. Required if scattering_factors is None and batching is needed.

Returns:

Complex structure factors of shape (N_reflections,).

Return type:

torch.Tensor
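
For orientation, the direct summation for isotropic atoms can be sketched in NumPy for the P1 case. The sketch assumes adp holds B-factors in Å² and s = 1/d, so the Debye-Waller factor is exp(-B s²/4). The symmetry callable and memory batching are omitted, and the function name is hypothetical.

```python
import numpy as np


def iso_structure_factor_p1(hkl, s, xyz_fractional, occ, scattering_factors, adp):
    """F(h) = sum_j occ_j f_j(h) exp(-B_j s^2 / 4) exp(2*pi*i h . x_j), P1 only."""
    phase = 2.0 * np.pi * (hkl @ xyz_fractional.T)              # (N_refl, N_atoms)
    debye_waller = np.exp(-adp[None, :] * s[:, None] ** 2 / 4.0)
    terms = occ[None, :] * scattering_factors * debye_waller * np.exp(1j * phase)
    return terms.sum(axis=1)


hkl = np.array([[1, 0, 0], [2, 0, 0]])
s = np.array([0.1, 0.2])                          # 1/d for a 10 Angstrom axis
xyz = np.array([[0.0, 0.0, 0.0], [0.5, 0.0, 0.0]])
occ = np.ones(2)
f = np.full((2, 2), 6.0)                          # constant form factor for illustration

F_static = iso_structure_factor_p1(hkl, s, xyz, occ, f, np.zeros(2))
F_thermal = iso_structure_factor_p1(hkl, s, xyz, occ, f, np.full(2, 20.0))
```

Two identical atoms half a cell apart cancel for odd h and add for even h. The B-factor only scales the amplitudes.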

torchref.base.iso_structure_factor_torched_no_complex(hkl, s, fractional_coords, occ, scattering_factors, tempfactor, space_group)[source]

Calculate isotropic structure factors without complex numbers.

Returns real and imaginary parts as separate rows.

Parameters:
  • hkl (torch.Tensor) – Miller indices of shape (N_reflections, 3).

  • s (torch.Tensor) – Scattering vector magnitudes of shape (N_reflections,).

  • fractional_coords (torch.Tensor) – Fractional coordinates of shape (N_atoms, 3).

  • occ (torch.Tensor) – Occupancies of shape (N_atoms,).

  • scattering_factors (torch.Tensor) – Atomic scattering factors of shape (N_reflections, N_atoms).

  • tempfactor (torch.Tensor) – Isotropic temperature factors (B-factors) of shape (N_atoms,).

  • space_group (callable) – Space group symmetry operator function.

Returns:

Structure factors as [real, imag] of shape (2, N_reflections).

Return type:

torch.Tensor

torchref.base.aniso_structure_factor_torched(hkl, s_vector, xyz_fractional, occ, scattering_factors, U, spacegroup, max_memory_gb=None, A=None, B_coeff=None)[source]

Calculate anisotropic structure factors using PyTorch.

Parameters:
  • hkl (torch.Tensor) – Miller indices of shape (N_reflections, 3).

  • s_vector (torch.Tensor) – Scattering vectors of shape (N_reflections, 3).

  • xyz_fractional (torch.Tensor) – Fractional coordinates of shape (N_atoms, 3).

  • occ (torch.Tensor) – Occupancies of shape (N_atoms,).

  • scattering_factors (torch.Tensor or None) – Atomic scattering factors of shape (N_reflections, N_atoms). If None, must provide A and B_coeff to compute them in batches.

  • U (torch.Tensor) – Anisotropic displacement parameters of shape (N_atoms, 6).

  • spacegroup (callable) – Space group symmetry operator function.

  • max_memory_gb (float, optional) – Maximum memory to use in GB. If None, no batching is applied.

  • A (torch.Tensor, optional) – ITC92 A coefficients (N_atoms, 5) for computing scattering factors.

  • B_coeff (torch.Tensor, optional) – ITC92 B coefficients (N_atoms, 5) for computing scattering factors.

Returns:

Complex structure factors of shape (N_reflections,).

Return type:

torch.Tensor

torchref.base.aniso_structure_factor_torched_no_complex(hkl, s_vector, fractional_coords, occ, scattering_factors, U, space_group)[source]

Calculate anisotropic structure factors without complex numbers.

Returns real and imaginary parts as separate rows.

Parameters:
  • hkl (torch.Tensor) – Miller indices of shape (N_reflections, 3).

  • s_vector (torch.Tensor) – Scattering vectors of shape (N_reflections, 3).

  • fractional_coords (torch.Tensor) – Fractional coordinates of shape (N_atoms, 3).

  • occ (torch.Tensor) – Occupancies of shape (N_atoms,).

  • scattering_factors (torch.Tensor) – Atomic scattering factors of shape (N_reflections, N_atoms).

  • U (torch.Tensor) – Anisotropic displacement parameters of shape (N_atoms, 6).

  • space_group (callable) – Space group symmetry operator function.

Returns:

Structure factors as [real, imag] of shape (2, N_reflections).

Return type:

torch.Tensor

torchref.base.anharmonic_correction(hkl, c)[source]

Apply anharmonic (third-order) correction to structure factors.

Parameters:
  • hkl (torch.Tensor) – Miller indices of shape (N_reflections, 3).

  • c (tuple or list) – Ten anharmonic coefficients (C111, C222, C333, C112, C122, C113, C133, C223, C233, C123).

Returns:

Complex anharmonic correction factors of shape (N_reflections,).

Return type:

torch.Tensor

torchref.base.anharmonic_correction_no_complex(hkl, c)[source]

Apply anharmonic (third-order) correction without complex numbers.

Returns real and imaginary parts as separate rows.

Parameters:
  • hkl (torch.Tensor) – Miller indices of shape (N_reflections, 3).

  • c (tuple or list) – Ten anharmonic coefficients (C111, C222, C333, C112, C122, C113, C133, C223, C233, C123).

Returns:

Correction factors as [cos, sin] of shape (2, N_reflections).

Return type:

torch.Tensor

torchref.base.core_deformation(core_correction, s)[source]

Apply core deformation correction to scattering.

Parameters:
Returns:

Core deformation correction factors.

Return type:

torch.Tensor

torchref.base.multiplication_quasi_complex_tensor(a, b)[source]

Multiply two quasi-complex tensors represented as [real, imag] rows.

Parameters:
  • a (torch.Tensor) – First quasi-complex tensor of shape (2, N).

  • b (torch.Tensor) – Second quasi-complex tensor of shape (2, N).

Returns:

Product as [real, imag] of shape (2, N).

Return type:

torch.Tensor
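
The [real, imag] row representation multiplies as (a_r b_r - a_i b_i, a_r b_i + a_i b_r). A NumPy sketch, checked against native complex multiplication, is shown below. The function name is hypothetical.

```python
import numpy as np


def multiply_quasi_complex(a, b):
    """Multiply two (2, N) arrays holding [real, imag] rows."""
    real = a[0] * b[0] - a[1] * b[1]
    imag = a[0] * b[1] + a[1] * b[0]
    return np.stack([real, imag])


a = np.array([[1.0, 0.0, 2.0], [2.0, 1.0, -1.0]])  # 1+2i, i, 2-i
b = np.array([[3.0, 0.0, 0.5], [-1.0, 1.0, 4.0]])  # 3-i, i, 0.5+4i
product = multiply_quasi_complex(a, b)

reference = (a[0] + 1j * a[1]) * (b[0] + 1j * b[1])
```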

torchref.base.vectorized_add_to_map(surrounding_coords, voxel_indices, density_map, xyz, b, inv_frac_matrix, frac_matrix, A, B, occ)[source]

Add atoms to density map using ITC92 Gaussian parameterization.

Automatically selects the optimal implementation based on device. GPU default: Triton fused kernel (3-6x faster, falls back to JIT if Triton is unavailable). Override with TORCHREF_ATOM_PLACEMENT_GPU_MODE=jit or simple.

Parameters:
  • surrounding_coords (torch.Tensor) – Cartesian coordinates of voxels, shape (N_atoms, N_voxels, 3).

  • voxel_indices (torch.Tensor) – Indices of voxels in the map, shape (N_atoms, N_voxels, 3).

  • density_map (torch.Tensor) – Electron density map to update, shape (nx, ny, nz).

  • xyz (torch.Tensor) – Atom positions in Cartesian coordinates, shape (N_atoms, 3).

  • b (torch.Tensor) – Isotropic B-factors, shape (N_atoms,).

  • inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix, shape (3, 3).

  • frac_matrix (torch.Tensor) – Fractionalization matrix, shape (3, 3).

  • A (torch.Tensor) – ITC92 amplitude coefficients, shape (N_atoms, 5).

  • B (torch.Tensor) – ITC92 width coefficients, shape (N_atoms, 5).

  • occ (torch.Tensor) – Atomic occupancies, shape (N_atoms,).

Returns:

Updated electron density map (modified in-place).

Return type:

torch.Tensor

torchref.base.vectorized_add_to_map_aniso(surrounding_coords, voxel_indices, map, xyz, U, inv_frac_matrix, frac_matrix, A, B, occ)[source]

Add anisotropic atoms to density map using ITC92 Gaussian parameterization.

Uses the same convention as the isotropic case for consistency:

  • B_total = (B_itc92 + B_atomic) / 4

  • rho = A × (π/B_total)^(3/2) × exp(-π² r² / B_total)

For anisotropic atoms, this generalizes to:

  • B_atomic_ij = 8π² × U_atomic_ij (standard crystallographic conversion)

  • B_total_ij = (B_itc92 × δ_ij + 8π² × U_atomic_ij) / 4

  • Normalization: (π³ / det(B_total))^(1/2)

  • Exponent: exp(-π² × r^T × B_total^(-1) × r)

Parameters:
  • surrounding_coords (torch.Tensor) – Coordinates of voxels around each atom of shape (N_atoms, N_voxels, 3).

  • voxel_indices (torch.Tensor) – Indices of voxels in the map of shape (N_atoms, N_voxels, 3).

  • map (torch.Tensor) – Electron density map of shape (nx, ny, nz).

  • xyz (torch.Tensor) – Atom positions in Cartesian coordinates of shape (N_atoms, 3).

  • U (torch.Tensor) – Anisotropic displacement parameters in Angstroms squared (u11, u22, u33, u12, u13, u23) of shape (N_atoms, 6).

  • inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix of shape (3, 3).

  • frac_matrix (torch.Tensor) – Fractionalization matrix of shape (3, 3).

  • A (torch.Tensor) – ITC92 amplitude coefficients for each atom of shape (N_atoms, 5).

  • B (torch.Tensor) – ITC92 width coefficients (b parameters) in Angstroms squared for each atom of shape (N_atoms, 5).

  • occ (torch.Tensor) – Occupancies for each atom of shape (N_atoms,).

Returns:

Updated electron density map.

Return type:

torch.Tensor
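
The anisotropic convention listed for vectorized_add_to_map_aniso should reduce to the isotropic formula when U is a multiple of the identity. The following is a minimal NumPy sketch of that check for a single Gaussian term, not torchref's implementation; the helper names density_iso and density_aniso and all numeric values are illustrative assumptions.

```python
import numpy as np

def density_iso(r_vec, a, b_itc, b_atomic):
    """One isotropic term: A (pi/B_total)^(3/2) exp(-pi^2 r^2 / B_total)."""
    b_total = (b_itc + b_atomic) / 4.0
    r2 = float(r_vec @ r_vec)
    return a * (np.pi / b_total) ** 1.5 * np.exp(-np.pi**2 * r2 / b_total)

def density_aniso(r_vec, a, b_itc, u6):
    """One anisotropic term; u6 = (u11, u22, u33, u12, u13, u23)."""
    u11, u22, u33, u12, u13, u23 = u6
    u_mat = np.array([[u11, u12, u13],
                      [u12, u22, u23],
                      [u13, u23, u33]])
    # B_total_ij = (B_itc92 * delta_ij + 8 pi^2 U_ij) / 4
    b_total = (b_itc * np.eye(3) + 8.0 * np.pi**2 * u_mat) / 4.0
    norm = np.sqrt(np.pi**3 / np.linalg.det(b_total))
    exponent = -np.pi**2 * (r_vec @ np.linalg.solve(b_total, r_vec))
    return a * norm * np.exp(exponent)

r = np.array([0.3, -0.2, 0.5])     # offset from the atom centre (illustrative)
u_iso = 0.25                       # isotropic U in A^2 (illustrative)
rho_iso = density_iso(r, a=2.0, b_itc=10.0, b_atomic=8.0 * np.pi**2 * u_iso)
rho_aniso = density_aniso(r, a=2.0, b_itc=10.0,
                          u6=(u_iso, u_iso, u_iso, 0.0, 0.0, 0.0))
```

With an isotropic U the determinant becomes B_total³, so the normalization collapses to (π/B_total)^(3/2) and both functions return the same value.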

torchref.base.scatter_add_nd(source, index, map)[source]

Vectorized n-dimensional scatter add operation.

Parameters:
  • source (torch.Tensor) – Values to add to the map of shape (N,).

  • index (torch.Tensor) – Indices where values should be added of shape (N, ndim).

  • map (torch.Tensor) – N-dimensional tensor of shape (d1, d2, …, dn) to add values into.

Returns:

Modified map with values added.

Return type:

torch.Tensor
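
An n-dimensional scatter add has to accumulate values when the same index appears more than once. A minimal NumPy sketch of that behaviour follows; it uses np.add.at rather than whatever torchref does internally, and the name scatter_add_nd_sketch is hypothetical.

```python
import numpy as np

def scatter_add_nd_sketch(source, index, grid):
    """Add source[i] to grid[index[i]]; repeated indices accumulate."""
    # index has shape (N, ndim); split it into one index array per axis.
    np.add.at(grid, tuple(index.T), source)
    return grid

grid = np.zeros((2, 3, 4))
source = np.array([1.0, 2.0, 3.0])
index = np.array([[0, 1, 2],
                  [0, 1, 2],    # duplicate of the first index
                  [1, 2, 3]])
result = scatter_add_nd_sketch(source, index, grid)
```

Plain fancy-index assignment (grid[idx] += source) would keep only one of the duplicate contributions, which is why an accumulating primitive is needed.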

torchref.base.scatter_add_nd_super_slow(source, index, map)[source]

Non-vectorized n-dimensional scatter add operation (slow reference implementation).

Parameters:
  • source (torch.Tensor) – Values to add to the map of shape (N,).

  • index (torch.Tensor) – Indices where values should be added of shape (N, ndim).

  • map (torch.Tensor) – N-dimensional tensor to add values into.

Returns:

Modified map with values added.

Return type:

torch.Tensor

torchref.base.find_relevant_voxels(real_space_grid, xyz, radius_angstrom=4, inv_frac_matrix=None)[source]

Identify surrounding voxels of atoms in a real space grid.

This is a vectorized function that finds all voxels within a spherical radius around each atom position.

Parameters:
  • real_space_grid (torch.Tensor) – Real space grid containing xyz coordinates at each grid point, of shape (nx, ny, nz, 3).

  • xyz (torch.Tensor) – Atom coordinates in real space (Cartesian coordinates), of shape (N, 3) or (3,).

  • radius_angstrom (float, optional) – Radius around each atom in Angstroms. Default is 4.

  • inv_frac_matrix (torch.Tensor, optional) – Matrix to convert Cartesian to fractional coordinates of shape (3, 3). Required for proper handling of non-orthogonal cells.

Returns:

  • surrounding_coords (torch.Tensor) – Coordinates of surrounding voxels for each atom of shape (N, R, 3), where R is the number of voxels within the radius.

  • voxel_indices_wrapped (torch.Tensor) – Wrapped voxel indices of shape (N, R, 3).

Notes

Atom coordinates are NOT wrapped here - periodic boundary conditions are handled in smallest_diff() which finds the minimum image distance. We only wrap voxel indices to ensure they’re valid array indices.

torchref.base.excise_angstrom_radius_around_coord(real_space_grid, start_indices, radius_angstrom=4.0)[source]

Identify voxel indices within an Angstrom radius around specified grid positions.

Parameters:
  • real_space_grid (torch.Tensor) – Real space grid of shape (nx, ny, nz, 3) containing xyz coordinates.

  • start_indices (torch.Tensor) – Starting grid indices of shape (N, 3) or (3,).

  • radius_angstrom (float, optional) – Radius in Angstroms. Default is 4.0.

Returns:

Wrapped voxel indices of shape (N, R, 3), where R is the number of voxels within the radius.

Return type:

torch.Tensor

Notes

Periodic boundary conditions are handled by wrapping the indices to ensure they’re valid array indices.

torchref.base.add_to_solvent_mask(surrounding_coords, voxel_indices, mask, xyz, radius, inv_frac_matrix, frac_matrix)[source]

Create solvent mask by placing spheres around atom positions.

Parameters:
  • surrounding_coords (torch.Tensor) – Coordinates of voxels around each atom of shape (N_atoms, N_voxels, 3).

  • voxel_indices (torch.Tensor) – Indices of voxels in the map of shape (N_atoms, N_voxels, 3).

  • mask (torch.Tensor) – Solvent mask to be updated of shape (nx, ny, nz).

  • xyz (torch.Tensor) – Atom positions of shape (N_atoms, 3).

  • radius (float) – Radius of the sphere around each atom in Angstroms.

  • inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix of shape (3, 3).

  • frac_matrix (torch.Tensor) – Fractionalization matrix of shape (3, 3).

Returns:

Updated solvent mask as boolean tensor.

Return type:

torch.Tensor

torchref.base.add_to_phenix_mask(surrounding_coords, voxel_indices, xyz, vdw_radii, solvent_radius, inv_frac_matrix, frac_matrix, grid_shape, device)[source]

Create Phenix-style three-valued mask by placing spheres around atom positions.

This is a vectorized implementation that processes all atoms and voxels at once. It creates two binary masks:

  • protein_mask: 1 where inside VdW radius (protein core)

  • boundary_mask: 1 where between VdW and VdW+solvent_radius (accessible surface)

Final three-valued mask:

  • 0: protein_mask == 1 (protein core)

  • -1: boundary_mask == 1 and protein_mask == 0 (accessible surface)

  • 1: both masks == 0 (bulk solvent)

Parameters:
  • surrounding_coords (torch.Tensor) – Fractional coordinates of voxels around each atom of shape (N_atoms, N_voxels, 3).

  • voxel_indices (torch.Tensor) – Grid indices of voxels in the map of shape (N_atoms, N_voxels, 3).

  • xyz (torch.Tensor) – Atom positions in fractional coordinates of shape (N_atoms, 3).

  • vdw_radii (torch.Tensor) – VdW radius for each atom in Angstroms of shape (N_atoms,).

  • solvent_radius (float) – Probe radius in Angstroms (added to VdW to get accessible surface).

  • inv_frac_matrix (torch.Tensor) – Inverse fractional matrix for distance calculations of shape (3, 3).

  • frac_matrix (torch.Tensor) – Fractional matrix for distance calculations of shape (3, 3).

  • grid_shape (tuple) – Shape of the output mask (nx, ny, nz).

  • device (torch.device) – Device for tensor operations.

Returns:

  • protein_mask (torch.Tensor) – Boolean mask for protein core of shape grid_shape.

  • boundary_mask (torch.Tensor) – Boolean mask for accessible surface of shape grid_shape.
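
The two boolean masks returned by add_to_phenix_mask can be combined into the three-valued mask described above. A minimal NumPy sketch of that combination step, assuming boolean inputs of equal shape; combine_three_valued is a hypothetical helper, not part of torchref.

```python
import numpy as np

def combine_three_valued(protein_mask, boundary_mask):
    """Return 0 for protein core, -1 for accessible surface, 1 for bulk solvent."""
    mask = np.ones(protein_mask.shape, dtype=np.int8)   # start as bulk solvent
    mask[boundary_mask & ~protein_mask] = -1            # surface shell only
    mask[protein_mask] = 0                              # core overrides the shell
    return mask

protein = np.array([True, False, False, True])
boundary = np.array([False, True, False, True])
three_valued = combine_three_valued(protein, boundary)
```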

torchref.base.find_solvent_voids(mask, periodic=True)[source]

Identify void regions in a 3D boolean tensor using connected component analysis.

A void is defined as a connected region of False values (solvent). With periodic boundary conditions, voids can wrap around the edges of the array (like in a crystallographic unit cell). Without periodic boundaries, only enclosed voids are detected.

Parameters:
  • mask (torch.Tensor or numpy.ndarray) – Boolean tensor of shape (nx, ny, nz) where True indicates solid regions (e.g., protein) and False indicates empty regions (e.g., solvent). Can be either PyTorch tensor or NumPy array.

  • periodic (bool, optional) – If True, apply periodic boundary conditions (voids can wrap around edges). If False, only detect voids that are completely enclosed and don’t touch the boundaries. Default is True.

Returns:

Dictionary where keys are int volumes (number of voxels) of each void in the original array, and values are boolean masks (torch.Tensor or numpy.ndarray) of same shape as input with True only for that specific void region. Returns an empty dict if no voids are found.

Return type:

dict

Examples

import torch
# Create a simple 5x5x5 grid with a void in the center
mask = torch.ones(5, 5, 5, dtype=torch.bool)
mask[2, 2, 2] = False  # Single void voxel
voids = find_solvent_voids(mask)
print(voids)

{1: tensor([[[False, False, …]], dtype=torch.bool)}

Notes

  • Uses scipy.ndimage.label for connected component analysis.

  • Connectivity is 26-connected (face, edge, and corner neighbors).

  • With periodic=True, the array is padded by wrapping to detect cross-boundary voids.

  • Performance is O(n) where n is the total number of voxels.

  • With periodic boundaries, large percolating voids are still detected.

torchref.base.fft(reciprocal_grid, volume=None)[source]

Perform FFT to obtain real space electron density.

Uses fftn with norm="forward" to match the crystallographic sign convention directly, avoiding expensive flip/roll operations.

Crystallographic convention: ρ(r) = (1/V) Σ F(h) exp(-2πi h·r)

PyTorch fftn with norm="forward" gives:

fftn(x)[n] = (1/N) Σ_k x[k] exp(-2πi k·n/N)

When input structure factors F are correctly scaled (with V/N factor from ifft), we need to multiply by N/V to recover the original electron density:

ρ = fftn(F) * (N / V)

Parameters:
  • reciprocal_grid (torch.Tensor) – Reciprocal space grid of shape (Nx, Ny, Nz) or (B, Nx, Ny, Nz). Expected to contain correctly scaled structure factors (from ifft with volume).

  • volume (float, optional) – Unit cell volume in ų. If provided, result is scaled by N/V to give correctly normalized electron density.

Returns:

Real-valued tensor of electron density with same shape as input.

Return type:

torch.Tensor

torchref.base.ifft(real_space_map, volume=None)[source]

Perform inverse FFT to obtain reciprocal space structure factors.

Crystallographic convention: F(h) = Σ ρ(r) exp(+2πi h·r) * ΔV where ΔV = V_cell / N is the voxel volume.

PyTorch ifftn with norm="forward" gives the unnormalized DFT:

DFT[k] = Σ x[n] exp(+2πi k·n/N)

To obtain correctly scaled structure factors, we multiply by voxel volume:

F(h) = DFT(ρ) * (V_cell / N)

Parameters:
  • real_space_map (torch.Tensor) – Real space electron density map of shape (Nx, Ny, Nz) or (B, Nx, Ny, Nz).

  • volume (float, optional) – Unit cell volume in ų. If provided, result is scaled by voxel volume (V_cell / N_total) to give correctly normalized structure factors.

Returns:

Complex-valued tensor of structure factors with same shape as input.

Return type:

torch.Tensor
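
The scaling conventions described for ifft and fft are inverses of each other. The sketch below checks this with numpy.fft, which accepts the same norm="forward" option as torch.fft; it calls NumPy directly rather than torchref, and the grid size and cell volume are arbitrary illustrative values.

```python
import numpy as np

rng = np.random.default_rng(0)
rho = rng.random((4, 6, 8))      # toy density map
volume = 1200.0                  # unit cell volume in A^3 (illustrative)
n_total = rho.size

# Density -> structure factors: unnormalized DFT times the voxel volume V/N.
F = np.fft.ifftn(rho, norm="forward") * (volume / n_total)

# Structure factors -> density: (1/N)-normalized DFT times N/V.
rho_back = np.fft.fftn(F, norm="forward") * (n_total / volume)
```

Under this convention F(000) equals the density summed over the cell times the voxel volume, which is the integrated electron count of the map.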

torchref.base.get_real_grid(cell=None, fractional_matrix=None, max_res=0.8, gridsize=None, device=None)[source]

Generate a real space grid for electron density calculations.

Parameters:
  • cell (torch.Tensor) – Unit cell parameters [a, b, c, alpha, beta, gamma].

  • fractional_matrix (torch.Tensor, optional) – Pre-computed fractionalization matrix.

  • max_res (float, optional) – Maximum resolution for automatic grid sizing. Default is 0.8.

  • gridsize (torch.Tensor or array-like, optional) – Explicit grid dimensions [nx, ny, nz]. If None, calculated from max_res.

  • device (torch.device or str, optional) – Device for tensor placement. If None, inferred from fractional_matrix or cell (whichever tensor is provided); falls back to CPU.

Returns:

Real space grid of shape (nx, ny, nz, 3) containing Cartesian coordinates.

Return type:

torch.Tensor

torchref.base.find_grid_size(cell, max_res)[source]

Calculate grid size based on unit cell and resolution.

Parameters:
  • cell (torch.Tensor) – Unit cell parameters [a, b, c, alpha, beta, gamma].

  • max_res (float) – Maximum resolution in Angstroms.

Returns:

Grid dimensions [nx, ny, nz] as int32.

Return type:

torch.Tensor

torchref.base.get_real_grid_numpy(cell, max_res=0.8, gridsize=None)[source]

Generate a real-space grid of Cartesian coordinates (NumPy version).

Creates a 3D grid in fractional coordinates and converts it to Cartesian coordinates. Grid points are placed at cell edges following CCTBX convention.

Parameters:
  • cell (numpy.ndarray or list) – Unit cell parameters [a, b, c, alpha, beta, gamma] where lengths are in Angstroms and angles are in degrees.

  • max_res (float, optional) – Maximum resolution in Angstroms for grid spacing. Default is 0.8. Ignored if gridsize is provided.

  • gridsize (list or numpy.ndarray, optional) – Explicit grid dimensions [nx, ny, nz]. If provided, overrides max_res.

Returns:

Real-space grid coordinates with shape (nx, ny, nz, 3).

Return type:

numpy.ndarray

torchref.base.get_grids(cell, max_res=0.8)[source]

Generate real-space and reciprocal-space grids for Fourier transforms.

Creates a 3D grid in fractional coordinates and converts it to Cartesian coordinates, along with an empty reciprocal space grid.

Parameters:
  • cell (numpy.ndarray or list) – Unit cell parameters [a, b, c, alpha, beta, gamma] where lengths are in Angstroms and angles are in degrees.

  • max_res (float, optional) – Maximum resolution in Angstroms for grid spacing. Default is 0.8.

Returns:

  • recgrid (numpy.ndarray) – Empty reciprocal space grid with shape determined by resolution.

  • xyz_real_grid (numpy.ndarray) – Real-space grid coordinates with shape (nx, ny, nz, 3).

torchref.base.put_hkl_on_grid(real_space_grid, diff, hkl)[source]

Place structure factors on a reciprocal space grid.

Maps structure factor values to their corresponding positions on a 3D reciprocal space grid based on Miller indices.

Parameters:
  • real_space_grid (numpy.ndarray) – Real-space grid used to determine the reciprocal grid dimensions. Shape should be (nx, ny, nz, 3) or similar.

  • diff (numpy.ndarray) – Structure factor values (complex) to place on the grid.

  • hkl (numpy.ndarray) – Miller indices with shape (N, 3), used as grid indices.

Returns:

Complex reciprocal space grid with shape (nx, ny, nz).

Return type:

numpy.ndarray

torchref.base.get_scattering_factors(scattering_dict, elements)[source]

Get scattering factors from a pre-computed dictionary.

Parameters:
  • scattering_dict (dict) – Dictionary of scattering factors by element.

  • elements (list) – List of element symbols.

Returns:

Concatenated scattering factors.

Return type:

torch.Tensor

torchref.base.get_scattering_factors_unique(atoms, s)[source]

Compute unique scattering factors for a set of atoms.

Parameters:
  • atoms (DataFrame-like) – Atoms with ‘element’ and ‘charge’ attributes.

  • s (array-like) – Scattering vector magnitudes.

Returns:

Dictionary mapping element symbols to scattering factors.

Return type:

dict

torchref.base.get_parametrization_for_elements(elements, charges=None)[source]

Get ITC92 parametrization for a list of elements.

Useful for getting parametrization for specific atoms without a full DataFrame.

Parameters:
  • elements (list of str) – Element symbols (e.g., [‘C’, ‘N’, ‘O’])

  • charges (list of int, optional) – Charges for each element (default: all zeros)

Returns:

Dictionary of the form {element: (A, B, C)}.

Return type:

dict

torchref.base.calc_scattering_factors_paramtetrization(parametrization, s, atom_list)[source]

Calculate scattering factors from ITC92 parametrization.

Parameters:
  • parametrization (dict) – Dictionary of (A, B, C) tuples by element.

  • s (torch.Tensor) – Scattering vector magnitudes.

  • atom_list (list) – List of atom symbols.

Returns:

Scattering factors.

Return type:

torch.Tensor

torchref.base.compute_radial_shells(d_min, d_max, n_shells, device=None)[source]

Compute uniform radial shell boundaries in reciprocal space.

Shells are spaced uniformly in 1/d (|s|) for even coverage of resolution.

Parameters:
  • d_min (float) – High resolution limit in Angstroms.

  • d_max (float) – Low resolution limit in Angstroms.

  • n_shells (int) – Number of radial shells.

  • device (torch.device, optional) – Device for output tensors. Default is CPU.

Returns:

  • shell_edges (torch.Tensor) – Shell boundaries in Angstroms^-1, shape (n_shells+1,).

  • shell_centers (torch.Tensor) – Shell centers in Angstroms^-1, shape (n_shells,).

Return type:

Tuple[Tensor, Tensor]
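
Shells that are uniform in 1/d can be sketched as follows. This is a NumPy illustration under two assumptions: edges run from 1/d_max to 1/d_min, and shell centers are the midpoints of adjacent edges. The name radial_shells is hypothetical.

```python
import numpy as np

def radial_shells(d_min, d_max, n_shells):
    """Return (edges, centers) in A^-1, uniformly spaced in |s| = 1/d."""
    s_low, s_high = 1.0 / d_max, 1.0 / d_min
    edges = np.linspace(s_low, s_high, n_shells + 1)
    centers = 0.5 * (edges[:-1] + edges[1:])   # midpoint of each shell (assumed)
    return edges, centers

edges, centers = radial_shells(d_min=2.0, d_max=50.0, n_shells=4)
```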

torchref.base.assign_to_shells(s_mag, shell_edges)[source]

Assign reflections to radial shells.

Parameters:
  • s_mag (torch.Tensor) – |s| values in Angstroms^-1, shape (N,).

  • shell_edges (torch.Tensor) – Shell boundaries in Angstroms^-1, shape (n_shells+1,).

Returns:

shell_idx – Shell index for each reflection, shape (N,). Values 0 to n_shells-1, or -1 for out-of-range.

Return type:

torch.Tensor
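
Shell assignment with -1 for out-of-range reflections can be sketched with a sorted search. This is a NumPy illustration, not torchref's code; treating the upper edge as part of the last shell is an assumption, and assign_shells is a hypothetical name.

```python
import numpy as np

def assign_shells(s_mag, shell_edges):
    """Return the shell index per reflection, or -1 outside the edge range."""
    s_mag = np.asarray(s_mag, dtype=float)
    shell_edges = np.asarray(shell_edges, dtype=float)
    n_shells = len(shell_edges) - 1
    # Index of the last edge that is <= s, so shell i covers [edge_i, edge_{i+1}).
    idx = np.searchsorted(shell_edges, s_mag, side="right") - 1
    idx[s_mag == shell_edges[-1]] = n_shells - 1          # include the top edge
    idx[(s_mag < shell_edges[0]) | (s_mag > shell_edges[-1])] = -1
    return idx

shell_idx = assign_shells([0.05, 0.1, 0.25, 0.4, 0.5], [0.1, 0.2, 0.3, 0.4])
```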

torchref.base.compute_anisotropy_correction(s_vectors, U)[source]

Compute anisotropic correction factor for F² values.

The correction is: exp(-2*pi^2 * s^T U s)

This scales F² values to correct for anisotropic diffraction before normalizing to E-values.

Parameters:
  • s_vectors (torch.Tensor) – Reciprocal space vectors in Angstroms^-1, shape (N, 3).

  • U (torch.Tensor) – Anisotropic parameters [u11, u22, u33, u12, u13, u23], shape (6,).

Returns:

correction – Correction factors, shape (N,).

Return type:

torch.Tensor
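
The correction exp(-2π² sᵀ U s) can be evaluated for all reflections at once by expanding the six parameters into a symmetric matrix. A minimal NumPy sketch, assuming the parameter order [u11, u22, u33, u12, u13, u23] given above; aniso_correction is a hypothetical name.

```python
import numpy as np

def aniso_correction(s_vectors, u6):
    """Return exp(-2 pi^2 s^T U s) for each row of s_vectors."""
    u11, u22, u33, u12, u13, u23 = u6
    u_mat = np.array([[u11, u12, u13],
                      [u12, u22, u23],
                      [u13, u23, u33]])
    # Quadratic form s^T U s evaluated per reflection.
    quad = np.einsum("ni,ij,nj->n", s_vectors, u_mat, s_vectors)
    return np.exp(-2.0 * np.pi**2 * quad)

s = np.array([[0.1, 0.0, 0.0],
              [0.0, 0.2, 0.1]])
u = 0.02
correction = aniso_correction(s, (u, u, u, 0.0, 0.0, 0.0))
```

For an isotropic U the result depends only on |s|², which gives a quick sanity check.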

torchref.base.compute_shell_cv(F_squared, shell_idx, n_shells, min_count=10)[source]

Compute mean coefficient of variation of F² values within resolution shells.

After proper anisotropy correction, F² should have similar CV in all directions within each resolution shell.

Parameters:
  • F_squared (torch.Tensor) – F² values, shape (N,).

  • shell_idx (torch.Tensor) – Shell assignments, shape (N,).

  • n_shells (int) – Number of shells.

  • min_count (int) – Minimum reflections per shell to include in calculation.

Returns:

mean_cv – Mean coefficient of variation across shells.

Return type:

float

torchref.base.fit_anisotropy_correction(F_squared, s_vectors, n_shells=20, d_min=4.0, d_max=50.0, n_iterations=100, lr=0.01, verbose=True)[source]

Fit anisotropy correction parameters to minimize variance within shells.

Optimizes U parameters so that corrected F² values have minimal coefficient of variation within each resolution shell, making the distribution more isotropic before E-value conversion.

Uses PyTorch’s LBFGS optimizer for efficient optimization.

Parameters:
  • F_squared (torch.Tensor) – F² values, shape (N,).

  • s_vectors (torch.Tensor) – Reciprocal space vectors in Angstroms^-1, shape (N, 3).

  • n_shells (int) – Number of resolution shells for variance calculation.

  • d_min (float) – High resolution limit in Angstroms.

  • d_max (float) – Low resolution limit in Angstroms.

  • n_iterations (int) – Number of optimization iterations.

  • lr (float) – Learning rate for optimizer.

  • verbose (bool) – Print progress.

Returns:

  • U_optimal (torch.Tensor) – Optimal anisotropic parameters, shape (6,).

  • final_cv (float) – Final mean coefficient of variation after correction.

Return type:

Tuple[Tensor, float]

torchref.base.apply_anisotropy_correction(F_squared, s_vectors, U)[source]

Apply anisotropic correction to F² values.

Should be called BEFORE converting F² to E-values.

Parameters:
  • F_squared (torch.Tensor) – F² values, shape (N,).

  • s_vectors (torch.Tensor) – Reciprocal space vectors in Angstroms^-1, shape (N, 3).

  • U (torch.Tensor) – Anisotropic parameters, shape (6,).

Returns:

F2_corrected – Corrected F² values, shape (N,).

Return type:

torch.Tensor

torchref.base.F_squared_to_E_values(F_squared, s_vectors, n_shells=20, d_min=None, d_max=None)[source]

Convert F² values to E-values by normalizing within resolution shells.

E-values are normalized structure factors where <E²> = 1 within each resolution shell.

Parameters:
  • F_squared (torch.Tensor) – F² values, shape (N,).

  • s_vectors (torch.Tensor) – Reciprocal space vectors in Angstroms^-1, shape (N, 3).

  • n_shells (int) – Number of resolution shells for normalization.

  • d_min (float, optional) – High resolution limit in Angstroms. If None, derived from s_vectors.

  • d_max (float, optional) – Low resolution limit in Angstroms. If None, derived from s_vectors.

Returns:

  • E_values (torch.Tensor) – Normalized E-values, shape (N,).

  • E_squared (torch.Tensor) – E² values (for correlation calculations), shape (N,).

  • shell_idx (torch.Tensor) – Shell assignments for each reflection, shape (N,).

Return type:

Tuple[Tensor, Tensor, Tensor]
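
The normalization <E²> = 1 within each shell amounts to dividing F² by its shell mean. A minimal NumPy sketch of that step, assuming shell indices are already available and that reflections with index -1 are left at zero; e_values_sketch is a hypothetical name and not torchref's implementation.

```python
import numpy as np

def e_values_sketch(f_squared, shell_idx, n_shells):
    """Return (E, E^2) with mean(E^2) == 1 inside every populated shell."""
    f_squared = np.asarray(f_squared, dtype=float)
    e_squared = np.zeros_like(f_squared)
    for shell in range(n_shells):
        members = shell_idx == shell
        if members.any():
            e_squared[members] = f_squared[members] / f_squared[members].mean()
    return np.sqrt(e_squared), e_squared

f2 = np.array([1.0, 3.0, 2.0, 6.0, 10.0])
shells = np.array([0, 0, 1, 1, 1])
e_vals, e_sq = e_values_sketch(f2, shells, n_shells=2)
```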

torchref.base.rotate_coords_torch(coords, phi, rho)[source]

Rotate coordinates using phi and rho angles (PyTorch version).

Parameters:
  • coords (torch.Tensor) – Coordinates of shape (N, 3) to rotate.

  • phi (float) – Rotation angle phi in degrees.

  • rho (float) – Rotation angle rho in degrees.

Returns:

Rotated coordinates of shape (N, 3).

Return type:

torch.Tensor

torchref.base.rotate_coords_numpy(coords, phi, rho)[source]

Rotate 3D coordinates by phi and rho angles (NumPy version).

Applies a rotation transformation to a set of 3D coordinates using two rotation angles (phi and rho) in degrees.

Parameters:
  • coords (numpy.ndarray) – Array of 3D coordinates with shape (N, 3).

  • phi (float) – First rotation angle in degrees.

  • rho (float) – Second rotation angle in degrees.

Returns:

Rotated coordinates with the same shape as input (N, 3).

Return type:

numpy.ndarray

torchref.base.axis_angle_to_rotation_matrix(axis_angle)[source]

Convert axis-angle representation to 3x3 rotation matrix.

Uses Rodrigues’ formula. Supports batched input and gradients.

Parameters:

axis_angle (torch.Tensor) – Axis-angle representation with shape (3,) or (N, 3). Direction is rotation axis, magnitude is angle in radians.

Returns:

Rotation matrix with shape (3, 3) or (N, 3, 3).

Return type:

torch.Tensor
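
Rodrigues' formula builds the rotation as R = I + sin(θ) K + (1 - cos(θ)) K², where K is the cross-product matrix of the unit axis. A minimal single-vector NumPy sketch follows; the batching and gradient support mentioned above are not reproduced, and axis_angle_to_matrix is a hypothetical name.

```python
import numpy as np

def axis_angle_to_matrix(axis_angle):
    """Convert an axis-angle vector (direction = axis, norm = angle in rad)."""
    axis_angle = np.asarray(axis_angle, dtype=float)
    theta = np.linalg.norm(axis_angle)
    if theta < 1e-12:
        return np.eye(3)                      # no rotation
    kx, ky, kz = axis_angle / theta
    K = np.array([[0.0, -kz, ky],
                  [kz, 0.0, -kx],
                  [-ky, kx, 0.0]])            # cross-product matrix of the axis
    return np.eye(3) + np.sin(theta) * K + (1.0 - np.cos(theta)) * (K @ K)

R_z90 = axis_angle_to_matrix([0.0, 0.0, np.pi / 2])   # 90 degrees about z
```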

torchref.base.rotation_matrix_to_axis_angle(R)[source]

Convert 3x3 rotation matrix to axis-angle representation.

Parameters:

R (torch.Tensor) – Rotation matrix with shape (3, 3) or (N, 3, 3).

Returns:

Axis-angle representation with shape (3,) or (N, 3).

Return type:

torch.Tensor

torchref.base.quaternion_to_rotation_matrix(q)[source]

Convert quaternion to rotation matrix.

Parameters:

q (torch.Tensor) – Quaternion with shape (4,) or (N, 4). Format: [w, x, y, z].

Returns:

Rotation matrix with shape (3, 3) or (N, 3, 3).

Return type:

torch.Tensor
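
For a quaternion in [w, x, y, z] order, the rotation matrix follows from the standard closed form. A minimal NumPy sketch for a single quaternion; normalizing the input first is an assumption made here for safety, and quaternion_to_matrix is a hypothetical name.

```python
import numpy as np

def quaternion_to_matrix(q):
    """Convert a quaternion [w, x, y, z] to a 3x3 rotation matrix."""
    q = np.asarray(q, dtype=float)
    w, x, y, z = q / np.linalg.norm(q)        # normalize (assumed step)
    return np.array([
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ])

# 90 degrees about z: w = cos(45 deg), z = sin(45 deg)
R_q = quaternion_to_matrix([np.cos(np.pi / 4), 0.0, 0.0, np.sin(np.pi / 4)])
```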

torchref.base.random_rotation_uniform(n=1, device=device(type='cpu'), dtype=torch.float32)[source]

Generate uniform random rotations over SO(3).

Uses Shoemake’s quaternion-based uniform sampling.

Parameters:
  • n (int, optional) – Number of rotations to generate. Default is 1.

  • device (str, optional) – Device for output tensor. Defaults to the configured device.current.

  • dtype (torch.dtype, optional) – Data type for output tensor. Default is dtypes.float.

Returns:

Rotation matrices with shape (n, 3, 3) or (3, 3) if n=1.

Return type:

torch.Tensor
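
Shoemake's method draws three uniform numbers, maps them to a unit quaternion that is uniform on the 3-sphere, and converts it to a matrix. The NumPy sketch below illustrates the idea; it always returns shape (n, 3, 3), takes an explicit random generator, and uses the hypothetical name random_rotations, so it differs from the torchref signature.

```python
import numpy as np

def random_rotations(n, rng):
    """Return n rotation matrices sampled uniformly over SO(3)."""
    u1, u2, u3 = rng.random((3, n))
    # Uniform unit quaternion (Shoemake).
    x = np.sqrt(1.0 - u1) * np.sin(2.0 * np.pi * u2)
    y = np.sqrt(1.0 - u1) * np.cos(2.0 * np.pi * u2)
    z = np.sqrt(u1) * np.sin(2.0 * np.pi * u3)
    w = np.sqrt(u1) * np.cos(2.0 * np.pi * u3)
    rot = np.empty((n, 3, 3))
    rot[:, 0, 0] = 1 - 2 * (y * y + z * z)
    rot[:, 0, 1] = 2 * (x * y - w * z)
    rot[:, 0, 2] = 2 * (x * z + w * y)
    rot[:, 1, 0] = 2 * (x * y + w * z)
    rot[:, 1, 1] = 1 - 2 * (x * x + z * z)
    rot[:, 1, 2] = 2 * (y * z - w * x)
    rot[:, 2, 0] = 2 * (x * z - w * y)
    rot[:, 2, 1] = 2 * (y * z + w * x)
    rot[:, 2, 2] = 1 - 2 * (x * x + y * y)
    return rot

rotations = random_rotations(100, np.random.default_rng(0))
```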

torchref.base.superpose_vectors_robust_torch(ref_coords, mov_coords, weights=None, max_iterations=10)[source]

Perform weighted superposition of two coordinate sets using SVD (PyTorch version).

Parameters:
  • ref_coords (torch.Tensor) – Reference coordinates of shape (N, 3).

  • mov_coords (torch.Tensor) – Mobile coordinates of shape (N, 3) to be superposed onto reference.

  • weights (torch.Tensor, optional) – Weights for each atom of shape (N, 1). Default is uniform weights.

  • max_iterations (int, optional) – Maximum number of iterations for refinement. Default is 10.

Returns:

Transformation matrix. It is described as 4x4, but the returned tensor has shape (3, 4).

Return type:

torch.Tensor

torchref.base.superpose_vectors_robust(target_coords, mobile_coords, weights=None, max_iterations=1)[source]

Superpose mobile coordinates onto target coordinates using the Kabsch algorithm.

Computes the optimal rotation and translation to minimize the weighted RMSD between two sets of 3D coordinates, with robust handling of special cases such as reflection.

Parameters:
  • target_coords (numpy.ndarray) – Target coordinates with shape (N, 3).

  • mobile_coords (numpy.ndarray) – Mobile coordinates with shape (N, 3) to be superposed onto target.

  • weights (numpy.ndarray, optional) – Per-atom weights for the superposition with shape (N,). Default is uniform weights.

  • max_iterations (int, optional) – Number of iterations for refinement. Default is 1 (standard Kabsch).

Returns:

  • transformation_matrix (numpy.ndarray) – 4x4 transformation matrix that maps mobile_coords onto target_coords.

  • rmsd (float) – Weighted root-mean-square deviation after superposition.

Raises:

ValueError – If input coordinate arrays have different shapes.

Notes

The algorithm uses SVD decomposition of the covariance matrix and handles the reflection case by checking the determinant of the rotation matrix.
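
The SVD step and the reflection check described in the notes can be sketched as a single Kabsch pass. This is a NumPy illustration under stated assumptions, not torchref's implementation: kabsch_sketch is a hypothetical name, weights are normalized to sum to 1, and the iterative refinement controlled by max_iterations is omitted.

```python
import numpy as np

def kabsch_sketch(target, mobile, weights=None):
    """Return (4x4 matrix mapping mobile onto target, weighted RMSD)."""
    if target.shape != mobile.shape:
        raise ValueError("coordinate arrays must have the same shape")
    w = np.ones(len(target)) if weights is None else np.asarray(weights, float)
    w = w / w.sum()
    t_cen, m_cen = w @ target, w @ mobile            # weighted centroids
    P, Q = mobile - m_cen, target - t_cen
    H = (P * w[:, None]).T @ Q                       # 3x3 weighted covariance
    U, _, Vt = np.linalg.svd(H)
    # Flip the last axis if the best fit would otherwise be a reflection.
    d = 1.0 if np.linalg.det(Vt.T @ U.T) >= 0 else -1.0
    R = Vt.T @ np.diag([1.0, 1.0, d]) @ U.T
    T = np.eye(4)
    T[:3, :3] = R
    T[:3, 3] = t_cen - R @ m_cen
    moved = mobile @ R.T + T[:3, 3]
    rmsd = np.sqrt((w * ((moved - target) ** 2).sum(axis=1)).sum())
    return T, rmsd

rng = np.random.default_rng(1)
mobile = rng.normal(size=(20, 3))
angle = np.deg2rad(30.0)
true_rot = np.array([[np.cos(angle), -np.sin(angle), 0.0],
                     [np.sin(angle), np.cos(angle), 0.0],
                     [0.0, 0.0, 1.0]])
true_shift = np.array([1.0, -2.0, 0.5])
target = mobile @ true_rot.T + true_shift
matrix, rmsd = kabsch_sketch(target, mobile)
```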

torchref.base.align_torch(xyz1, xyz2, idx_to_move=None)[source]

Align two coordinate sets using superposition (PyTorch version).

Parameters:
  • xyz1 (torch.Tensor) – Target coordinates of shape (N, 3).

  • xyz2 (torch.Tensor) – Coordinates to be aligned of shape (N, 3).

  • idx_to_move (torch.Tensor, optional) – Indices of atoms to use for alignment. If None, uses all atoms.

Returns:

Aligned coordinates of shape (N, 3).

Return type:

torch.Tensor

torchref.base.align_pdbs(pdb1, pdb2, Atoms=None)[source]

Align two PDB structures using the Kabsch algorithm.

Superimposes pdb2 onto pdb1 by minimizing the RMSD between corresponding atoms. The transformation is applied in-place to pdb2.

Parameters:
  • pdb1 (pandas.DataFrame) – Reference PDB structure with ‘x’, ‘y’, ‘z’, ‘name’, and ‘tempfactor’ columns.

  • pdb2 (pandas.DataFrame) – Mobile PDB structure to be aligned onto pdb1.

  • Atoms (list, optional) – List of atom names to use for alignment. If None, all atoms are used.

Returns:

  • pdb2 (pandas.DataFrame) – Transformed pdb2 with updated coordinates.

  • rmsd (float) – Root-mean-square deviation after alignment.

torchref.base.get_alignment_matrix(pdb1, pdb2, Atoms=None)[source]

Calculate the transformation matrix to align two PDB structures.

Computes the 4x4 transformation matrix that would superimpose pdb2 onto pdb1 without actually applying the transformation.

Parameters:
  • pdb1 (pandas.DataFrame) – Reference PDB structure with ‘x’, ‘y’, ‘z’, ‘name’, and ‘tempfactor’ columns.

  • pdb2 (pandas.DataFrame) – Mobile PDB structure.

  • Atoms (list, optional) – List of atom names to use for alignment. If None, all atoms are used.

Returns:

  • transformation_matrix (numpy.ndarray) – 4x4 transformation matrix.

  • rmsd (float) – Root-mean-square deviation that would result from the alignment.

torchref.base.apply_transformation(points, transformation_matrix)[source]

Apply a 4x4 transformation matrix to 3D points (PyTorch version).

Parameters:
  • points (torch.Tensor) – 3D points of shape (N, 3).

  • transformation_matrix (torch.Tensor) – Transformation matrix of shape (3, 4) or (4, 4).

Returns:

Transformed 3D points of shape (N, 3).

Return type:

torch.Tensor

torchref.base.apply_transformation_numpy(points, transformation_matrix)[source]

Apply a 4x4 transformation matrix to 3D points (NumPy version).

Converts points to homogeneous coordinates, applies the transformation, and returns the transformed 3D coordinates.

Parameters:
  • points (numpy.ndarray) – 3D coordinates with shape (N, 3).

  • transformation_matrix (numpy.ndarray) – 4x4 transformation matrix containing rotation and translation.

Returns:

Transformed 3D coordinates with shape (N, 3).

Return type:

numpy.ndarray

torchref.base.invert_transformation_matrix(transformation_matrix)[source]

Compute the inverse of a 4x4 transformation matrix.

Efficiently inverts a rigid-body transformation matrix by transposing the rotation component and computing the inverse translation.

Parameters:

transformation_matrix (numpy.ndarray) – 4x4 transformation matrix containing rotation (top-left 3x3) and translation (top-right 3x1).

Returns:

Inverse 4x4 transformation matrix.

Return type:

numpy.ndarray

Notes

This function assumes the input is a valid rigid-body transformation (rotation + translation). For such matrices, the inverse rotation is simply the transpose, and the inverse translation is computed as -R^T @ t.
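
The closed-form inverse from the notes, R^T for the rotation and -R^T @ t for the translation, can be sketched and verified as follows. invert_rigid is a hypothetical name and the example transform is illustrative.

```python
import numpy as np

def invert_rigid(T):
    """Invert a 4x4 rigid-body transform without a general matrix inverse."""
    R, t = T[:3, :3], T[:3, 3]
    T_inv = np.eye(4)
    T_inv[:3, :3] = R.T          # inverse rotation is the transpose
    T_inv[:3, 3] = -R.T @ t      # inverse translation
    return T_inv

angle = np.deg2rad(40.0)
T = np.eye(4)
T[:3, :3] = [[np.cos(angle), -np.sin(angle), 0.0],
             [np.sin(angle), np.cos(angle), 0.0],
             [0.0, 0.0, 1.0]]
T[:3, 3] = [3.0, -1.0, 2.0]
T_inv = invert_rigid(T)
```

The shortcut is only valid when the top-left 3x3 block is orthonormal; for matrices with scaling or shear a general inverse is needed.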

torchref.base.get_rfactor_torch(F_obs, F_calc)[source]

Calculate R-factor between observed and calculated structure factors (PyTorch version).

Parameters:
  • F_obs (torch.Tensor) – Observed structure factor amplitudes.

  • F_calc (torch.Tensor) – Calculated structure factor amplitudes.

Returns:

R-factor value.

Return type:

torch.Tensor

torchref.base.get_rfactor(F_obs, F_calc)[source]

Calculate the R-factor between observed and calculated structure factors (NumPy version).

The R-factor is a measure of agreement between observed and calculated structure factor amplitudes, defined as sum(|F_obs - F_calc|) / sum(F_obs).

Parameters:
  • F_obs (numpy.ndarray) – Observed structure factor amplitudes.

  • F_calc (numpy.ndarray) – Calculated structure factor amplitudes.

Returns:

R-factor value; 0 indicates perfect agreement. The value is usually below 1 but is not bounded above by it.

Return type:

float

torchref.base.rfactor(F_obs, F_calc)[source]

Calculate R-factor between observed and calculated structure factors.

Parameters:
  • F_obs (torch.Tensor) – Observed structure factor amplitudes of shape (N,).

  • F_calc (torch.Tensor) – Calculated structure factor amplitudes of shape (N,).

Returns:

R-factor value.

Return type:

float

torchref.base.get_rfactors(F_obs, F_calc, rfree)[source]

Get R-factors for working and test sets.

Parameters:
  • F_obs (torch.Tensor) – Observed structure factor amplitudes of shape (N,).

  • F_calc (torch.Tensor) – Calculated structure factor amplitudes of shape (N,).

  • rfree (torch.Tensor) – Mask of shape (N,) that splits reflections into working and test sets: 1 marks the working set, 0 marks the test (R-free) set.

Returns:

(r_work, r_test) where r_work is the R-factor for the working set and r_test is the R-factor for the test set.

Return type:

tuple
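
Splitting the R-factor by the flag convention above (1 for working, 0 for test) can be sketched as follows. This NumPy illustration uses the definition sum(|F_obs - F_calc|) / sum(F_obs) given for get_rfactor; the helper names and the amplitudes are illustrative assumptions.

```python
import numpy as np

def r_factor(f_obs, f_calc):
    """R = sum(|F_obs - F_calc|) / sum(F_obs)."""
    return np.abs(f_obs - f_calc).sum() / f_obs.sum()

def r_work_and_test(f_obs, f_calc, work_flag):
    """Return (r_work, r_test) with work_flag == 1 marking the working set."""
    work = np.asarray(work_flag).astype(bool)
    return (r_factor(f_obs[work], f_calc[work]),
            r_factor(f_obs[~work], f_calc[~work]))

f_obs = np.array([10.0, 20.0, 30.0, 40.0])
f_calc = np.array([9.0, 22.0, 30.0, 44.0])
flags = np.array([1, 1, 1, 0])
r_work, r_test = r_work_and_test(f_obs, f_calc, flags)
```

Here the working set gives (1 + 2 + 0) / 60 = 0.05 and the single test reflection gives 4 / 40 = 0.1.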

torchref.base.bin_wise_rfactors(F_obs, F_calc, rfree, bins)[source]

Calculate bin-wise R-factors between observed and calculated structure factors.

Parameters:
Returns:

  • r_work_bins (torch.Tensor) – R-factors for working set (per bin).

  • r_test_bins (torch.Tensor) – R-factors for test set (per bin).

Return type:

tuple

torchref.base.calc_outliers(F_obs, F_calc, z)[source]

Identify outlier reflections based on deviation from expected values (PyTorch version).

Parameters:
  • F_obs (torch.Tensor) – Observed structure factor amplitudes.

  • F_calc (torch.Tensor) – Calculated structure factor amplitudes.

  • z (float) – Number of standard deviations for outlier threshold.

Returns:

Boolean mask where True indicates outlier reflections.

Return type:

torch.Tensor

torchref.base.calc_outliers_numpy(F_obs, F_calc, z)[source]

Identify outlier reflections based on structure factor differences (NumPy version).

Detects reflections where the normalized difference between observed and calculated structure factors exceeds z standard deviations.

Parameters:
  • F_obs (numpy.ndarray) – Observed structure factor amplitudes.

  • F_calc (numpy.ndarray) – Calculated structure factor amplitudes.

  • z (float) – Number of standard deviations for outlier threshold.

Returns:

Boolean array where True indicates an outlier reflection.

Return type:

numpy.ndarray

torchref.base.nll_xray(F_obs, F_calc, sigma_F_obs)[source]

Compute X-ray negative log-likelihood assuming Gaussian distribution.

Parameters:
  • F_obs (torch.Tensor or MaskedTensor) – Observed structure factor amplitudes.

  • F_calc (torch.Tensor) – Calculated structure factors (complex).

  • sigma_F_obs (torch.Tensor or MaskedTensor) – Standard deviations of observed amplitudes.

Returns:

Mean negative log-likelihood.

Return type:

torch.Tensor

torchref.base.nll_xray_sum(F_obs, F_calc, sigma_F_obs)[source]

Compute summed X-ray negative log-likelihood assuming Gaussian distribution.

Parameters:
  • F_obs (torch.Tensor or MaskedTensor) – Observed structure factor amplitudes.

  • F_calc (torch.Tensor) – Calculated structure factors (complex).

  • sigma_F_obs (torch.Tensor or MaskedTensor) – Standard deviations of observed amplitudes.

Returns:

Sum of negative log-likelihood values.

Return type:

torch.Tensor

torchref.base.nll_xray_lognormal(F_obs, F_calc, sigma_F_obs, eps=1e-10)[source]

Compute X-ray negative log-likelihood assuming lognormal distribution.

This is a more realistic model for structure factor amplitudes, which must be positive. For a lognormal distribution LogNormal(mu, sigma^2), the NLL is:

NLL = 0.5*(log(x) - mu)^2/sigma^2 + log(x) + log(sigma) + 0.5*log(2*pi)

where mu and sigma are derived from F_obs and sigma_F_obs using:

  • sigma = sqrt(log(1 + (sigma_F/F)^2))

  • mu = log(F) - sigma^2/2

Parameters:
  • F_obs (torch.Tensor) – Observed structure factor amplitudes.

  • F_calc (torch.Tensor) – Calculated structure factors (complex).

  • sigma_F_obs (torch.Tensor) – Standard deviations of observed amplitudes.

  • eps (float, optional) – Small value to avoid numerical issues. Default is 1e-10.

Returns:

Mean negative log-likelihood.

Return type:

torch.Tensor
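
The formulas above can be sketched with plain Python scalars. This is a minimal illustration, not the library's implementation: it assumes x is the amplitude |F_calc|, and the places where eps is applied (clamping F_obs, |F_calc| and sigma away from zero) are guesses. The name lognormal_nll is hypothetical.

```python
import math


def lognormal_nll(f_obs, f_calc, sigma_f_obs, eps=1e-10):
    """Mean lognormal NLL following the documented formulas (assumed clamping)."""
    terms = []
    for fo, fc, s in zip(f_obs, f_calc, sigma_f_obs):
        fo = max(fo, eps)        # assumption: keep log(F) finite
        x = max(abs(fc), eps)    # assumption: x is the amplitude |F_calc|
        # sigma = sqrt(log(1 + (sigma_F/F)^2)), clamped to avoid dividing by zero
        sigma = max(math.sqrt(math.log1p((s / fo) ** 2)), eps)
        # mu = log(F) - sigma^2/2
        mu = math.log(fo) - 0.5 * sigma ** 2
        terms.append(
            0.5 * (math.log(x) - mu) ** 2 / sigma ** 2
            + math.log(x)
            + math.log(sigma)
            + 0.5 * math.log(2 * math.pi)
        )
    return sum(terms) / len(terms)


# |6 + 8j| = 10 matches F_obs; |12 + 16j| = 20 is off by a factor of two.
good = lognormal_nll([10.0], [complex(6.0, 8.0)], [1.0])
bad = lognormal_nll([10.0], [complex(12.0, 16.0)], [1.0])
```

With a 10% relative error on F_obs, the mismatched amplitude is penalised far more heavily than the matching one, as expected for a narrow distribution.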

torchref.base.log_loss(F_obs, F_calc, sigma_F_obs)[source]

Compute log-space loss between observed and calculated structure factors.

Parameters:
  • F_obs (torch.Tensor) – Observed structure factor amplitudes.

  • F_calc (torch.Tensor) – Calculated structure factors (complex).

  • sigma_F_obs (torch.Tensor) – Standard deviations of observed amplitudes (unused).

Returns:

Mean absolute difference in log space.

Return type:

torch.Tensor
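
A minimal sketch of a log-space loss of this kind, written with plain Python scalars. It assumes the loss is the mean of |log(F_obs) - log(|F_calc|)|. The eps clamp and the name log_loss_sketch are hypothetical additions for illustration and are not taken from the library.

```python
import math


def log_loss_sketch(f_obs, f_calc, eps=1e-10):
    """Mean absolute difference of log-amplitudes (assumed form)."""
    diffs = [
        abs(math.log(max(fo, eps)) - math.log(max(abs(fc), eps)))
        for fo, fc in zip(f_obs, f_calc)
    ]
    return sum(diffs) / len(diffs)


# One reflection is overestimated by a factor of e, the other underestimated
# by the same factor; in log space both contribute the same penalty.
loss = log_loss_sketch([math.e, 1.0], [complex(1.0, 0.0), complex(math.e, 0.0)])
```

Because the comparison is made on ratios rather than differences, weak and strong reflections contribute on the same scale, which is consistent with sigma_F_obs being unused.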

torchref.base.estimate_sigma_I(I)[source]

Estimate standard deviation of intensities.

Separates positive and negative values for robust estimation.

Parameters:

I (torch.Tensor) – Intensity values.

Returns:

Estimated standard deviations.

Return type:

torch.Tensor

torchref.base.estimate_sigma_F(F)[source]

Estimate standard deviation of structure factor amplitudes.

Parameters:

F (torch.Tensor) – Structure factor amplitudes.

Returns:

Estimated standard deviations.

Return type:

torch.Tensor

torchref.base.gaussian_to_lognormal_sigma(F, sigma_F, eps=1e-10)[source]

Approximate the sigma parameter of a lognormal distribution from Gaussian statistics.

If we assume F comes from a lognormal distribution X ~ LogNormal(mu, sigma^2), then:

  • Mean: E[X] = F

  • Std: sqrt(Var[X]) = sigma_F

For a lognormal distribution:

  • E[X] = exp(mu + sigma^2/2)

  • Var[X] = exp(2*mu + sigma^2) * (exp(sigma^2) - 1)

From these we can derive:

  • CV^2 = Var[X]/E[X]^2 = exp(sigma^2) - 1

  • sigma = sqrt(log(1 + CV^2))

where CV = sigma_F/F is the coefficient of variation.

Parameters:
  • F (torch.Tensor) – Structure factor amplitudes (mean of the distribution).

  • sigma_F (torch.Tensor) – Standard deviations.

  • eps (float, optional) – Small value to avoid division by zero. Default is 1e-10.

Returns:

Sigma parameter for lognormal distribution.

Return type:

torch.Tensor
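
The derivation above reduces to one line per reflection. The scalar sketch below is an illustration only. The name lognormal_sigma is hypothetical, and using eps to guard the division is an assumption about how the library applies it.

```python
import math


def lognormal_sigma(f, sigma_f, eps=1e-10):
    """sigma = sqrt(log(1 + CV^2)) with CV = sigma_F / F."""
    cv = sigma_f / max(f, eps)  # assumption: eps guards the division
    return math.sqrt(math.log1p(cv * cv))


# A 10% relative error: for small CV the lognormal sigma is close to CV itself.
sigma = lognormal_sigma(10.0, 1.0)
cv_squared_back = math.expm1(sigma ** 2)  # exp(sigma^2) - 1 recovers CV^2
```

For small relative errors, log(1 + CV^2) is approximately CV^2, so sigma is approximately sigma_F/F. The two only diverge noticeably for weak reflections with large relative uncertainty.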

torchref.base.gaussian_to_lognormal_mu(F, sigma_lognormal, eps=1e-10)[source]

Calculate the mu parameter of a lognormal distribution given F and sigma.

For a lognormal distribution X ~ LogNormal(mu, sigma^2):

E[X] = exp(mu + sigma^2/2)

Solving for mu:

mu = log(E[X]) - sigma^2/2

Parameters:
  • F (torch.Tensor) – Structure factor amplitudes (mean of the distribution).

  • sigma_lognormal (torch.Tensor) – Sigma parameter from lognormal distribution.

  • eps (float, optional) – Small value to avoid log of zero. Default is 1e-10.

Returns:

Mu parameter for lognormal distribution.

Return type:

torch.Tensor
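
As a sanity check on the two conversions, the scalar sketch below computes sigma and mu from a Gaussian mean and standard deviation, then recovers both moments from the lognormal formulas. The function names are hypothetical and the use of eps is an assumption.

```python
import math


def lognormal_sigma(f, sigma_f, eps=1e-10):
    """sigma = sqrt(log(1 + (sigma_F/F)^2))."""
    cv = sigma_f / max(f, eps)
    return math.sqrt(math.log1p(cv * cv))


def lognormal_mu(f, sigma_lognormal, eps=1e-10):
    """mu = log(F) - sigma^2/2, with F clamped so the log stays finite."""
    return math.log(max(f, eps)) - 0.5 * sigma_lognormal ** 2


F, sigma_F = 10.0, 2.0
sigma = lognormal_sigma(F, sigma_F)
mu = lognormal_mu(F, sigma)

# Moments of LogNormal(mu, sigma^2) should reproduce the Gaussian inputs.
mean_back = math.exp(mu + 0.5 * sigma ** 2)
var_back = math.exp(2 * mu + sigma ** 2) * math.expm1(sigma ** 2)
std_back = math.sqrt(var_back)
```

The round trip confirms that the pair (mu, sigma) is chosen by moment matching: the lognormal has the same mean and standard deviation as the Gaussian description of the observation.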

torchref.base.compute_metric_tensor(frac_matrix)[source]

Compute the metric tensor for calculating r² in fractional coordinates.

The metric tensor G allows computing squared distances in Cartesian space from fractional coordinate differences:

r² = diff_frac @ G @ diff_frac.T

Parameters:

frac_matrix (torch.Tensor) – Fractionalization matrix, shape (3, 3).

Returns:

Metric tensor G = frac_matrix.T @ frac_matrix, shape (3, 3).

Return type:

torch.Tensor
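
The identity behind the metric tensor can be checked with a small dependency-free sketch. Here M stands in for frac_matrix and is assumed to map a fractional column vector to Cartesian coordinates, which is what the formula G = frac_matrix.T @ frac_matrix implies. The matrix values and helper names are made up for illustration.

```python
def transpose(m):
    """Transpose a matrix stored as a list of rows."""
    return [list(col) for col in zip(*m)]


def matmul(a, b):
    """Plain matrix product of two lists of rows."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def metric_tensor(m):
    """G = M.T @ M."""
    return matmul(transpose(m), m)


def r_squared(diff_frac, g):
    """r^2 = d @ G @ d.T for a single fractional difference vector d."""
    return sum(diff_frac[i] * g[i][j] * diff_frac[j] for i in range(3) for j in range(3))


# Hypothetical non-orthogonal cell (values chosen only for the example).
M = [[10.0, 0.0, -3.0],
     [0.0, 20.0, 0.0],
     [0.0, 0.0, 29.0]]
G = metric_tensor(M)
d = [0.1, 0.2, 0.3]

# Direct route: convert to Cartesian first, then take the squared norm.
cart = [sum(M[i][j] * d[j] for j in range(3)) for i in range(3)]
r2_direct = sum(c * c for c in cart)

# Metric-tensor route: stay in fractional coordinates.
r2_metric = r_squared(d, G)
```

Both routes give the same squared distance. The metric-tensor form is attractive because G is computed once per cell and the per-voxel conversion to Cartesian coordinates is avoided.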

torchref.base.precompute_fractional_coords(coords_cart, inv_frac_matrix)[source]

Convert Cartesian voxel coordinates to fractional coordinates.

Parameters:
  • coords_cart (torch.Tensor) – Cartesian coordinates, shape (N_atoms, N_voxels, 3).

  • inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix, shape (3, 3).

Returns:

Fractional coordinates, shape (N_atoms, N_voxels, 3).

Return type:

torch.Tensor

torchref.base.warmup(device='auto')[source]

Pre-compile kernels to avoid compilation overhead during first use.

Parameters:

device (str) – Device to warm up: "cpu", "cuda", or "auto" (default).

torchref.base.get_cache_dir()[source]

Return the path to the JIT kernel cache directory.

torchref.base.clear_cache()[source]

Clear the JIT kernel cache.

torchref.base.warmup_cuda_operations(device='cuda')[source]

Warm up CUDA kernels to avoid lazy loading overhead.

This function runs dummy operations to trigger CUDA kernel compilation and loading, so subsequent operations don’t incur this overhead.

Call this once after moving the model to the GPU.

Parameters:

device (str) – Device to warm up. Default is "cuda".

torchref.base.get_cached_radius_offsets(voxel_size, radius_angstrom, device)[source]

Get cached radius offsets to avoid recomputation.

This eliminates redundant computation when processing multiple atoms with the same voxel size and radius.

Parameters:
  • voxel_size (torch.Tensor) – Voxel dimensions, shape (3,).

  • radius_angstrom (float) – Radius in Angstroms.

  • device (torch.device) – Device for the output tensor.

Returns:

Voxel offsets within radius, shape (N_voxels, 3).

Return type:

torch.Tensor
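
The caching idea can be sketched without tensors using functools.lru_cache. This is an illustration of the technique under stated assumptions, not the library's code: the voxel size is passed as a hashable tuple rather than a tensor, the device argument is omitted, and the name radius_offsets is hypothetical.

```python
import math
from functools import lru_cache


@lru_cache(maxsize=32)
def radius_offsets(voxel_size, radius_angstrom):
    """Integer voxel offsets whose centres lie within the given radius.

    voxel_size must be a tuple of three floats so it can serve as a cache key.
    """
    vx, vy, vz = voxel_size
    # Largest index that can still fall inside the sphere along each axis.
    reach = [int(math.ceil(radius_angstrom / v)) for v in voxel_size]
    r2 = radius_angstrom ** 2
    offsets = []
    for i in range(-reach[0], reach[0] + 1):
        for j in range(-reach[1], reach[1] + 1):
            for k in range(-reach[2], reach[2] + 1):
                if (i * vx) ** 2 + (j * vy) ** 2 + (k * vz) ** 2 <= r2:
                    offsets.append((i, j, k))
    return tuple(offsets)  # immutable, so the cached value cannot be altered


# The second call with the same arguments is served from the cache.
first = radius_offsets((1.0, 1.0, 1.0), 1.0)
second = radius_offsets((1.0, 1.0, 1.0), 1.0)
```

Since every atom drawn with the same voxel size and radius shares the same offset stencil, computing it once and reusing it removes a triple loop from the per-atom work.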

Subpackages

Submodules