torchref.base.math_torch module

PyTorch implementations of mathematical functions for crystallography.

Deprecated: This module is deprecated. Please import from the domain-specific submodules:

  • torchref.base.coordinates - Coordinate transformations

  • torchref.base.reciprocal - Reciprocal space calculations

  • torchref.base.direct_summation - Structure factor calculations

  • torchref.base.electron_density - Electron density map building

  • torchref.base.fourier - FFT operations

  • torchref.base.scattering - Atomic scattering factors

  • torchref.base.alignment - Coordinate alignment

  • torchref.base.metrics - R-factors and loss functions

  • torchref.base.kernels - Optimized kernels

This module is maintained for backward compatibility and re-exports all functions from the new submodules.

Example (new style - recommended):

from torchref.base.coordinates import cartesian_to_fractional_torch
from torchref.base.metrics import get_rfactor_torch

Example (old style - still works):

from torchref.base.math_torch import cartesian_to_fractional_torch

torchref.base.math_torch.cartesian_to_fractional_torch(xyz, cell, B_inv=None)

Convert Cartesian coordinates to fractional coordinates.

Parameters:
  • xyz (torch.Tensor) – Cartesian coordinates of shape (N, 3).

  • cell (array-like) – Unit cell parameters [a, b, c, alpha, beta, gamma].

  • B_inv (torch.Tensor, optional) – Inverse fractionalization matrix (Cartesian-to-fractional), as returned by get_inv_fractional_matrix_torch. If None, it is calculated from cell.

Returns:

Fractional coordinates of shape (N, 3).

Return type:

torch.Tensor

torchref.base.math_torch.fractional_to_cartesian_torch(xyz_fractional, cell, B=None)

Convert fractional coordinates to Cartesian coordinates.

Parameters:
  • xyz_fractional (torch.Tensor) – Fractional coordinates of shape (N, 3).

  • cell (array-like) – Unit cell parameters [a, b, c, alpha, beta, gamma].

  • B (torch.Tensor, optional) – Fractionalization matrix (fractional-to-Cartesian), as returned by get_fractional_matrix. If None, it is calculated from cell.

Returns:

Cartesian coordinates of shape (N, 3).

Return type:

torch.Tensor

torchref.base.math_torch.get_fractional_matrix(cell)

Calculate the fractional-to-Cartesian transformation matrix.

Constructs the matrix B that transforms fractional coordinates to Cartesian coordinates based on the unit cell parameters.

Parameters:

cell (torch.Tensor) – Unit cell parameters [a, b, c, alpha, beta, gamma] where lengths are in Angstroms and angles are in degrees.

Returns:

3x3 transformation matrix B such that cart = frac @ B.T.

Return type:

torch.Tensor

torchref.base.math_torch.get_inv_fractional_matrix_torch(cell)

Calculate the Cartesian-to-fractional transformation matrix (PyTorch version).

Computes the inverse of the fractional matrix for converting Cartesian coordinates to fractional coordinates.

Parameters:

cell (torch.Tensor) – Unit cell parameters [a, b, c, alpha, beta, gamma] where lengths are in Angstroms and angles are in degrees.

Returns:

3x3 inverse transformation matrix B_inv such that frac = cart @ B_inv.T.

Return type:

torch.Tensor
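A dependency-free sketch of the two matrices, using nested lists instead of tensors. It assumes the common PDB-style orthogonalization convention (a along x, b in the xy-plane); the frame torchref actually uses is not stated here, so treat that as an assumption. The helper names fractional_matrix, invert_3x3 and apply are hypothetical.

```python
import math


def fractional_matrix(cell):
    """Fractional-to-Cartesian matrix B (cart = B @ frac).

    Assumes a along x and b in the xy-plane (PDB-style convention).
    """
    a, b, c, alpha, beta, gamma = cell
    ca, cb, cg = (math.cos(math.radians(x)) for x in (alpha, beta, gamma))
    sg = math.sin(math.radians(gamma))
    # Triclinic cell volume.
    vol = a * b * c * math.sqrt(1 - ca**2 - cb**2 - cg**2 + 2 * ca * cb * cg)
    return [
        [a, b * cg, c * cb],
        [0.0, b * sg, c * (ca - cb * cg) / sg],
        [0.0, 0.0, vol / (a * b * sg)],
    ]


def invert_3x3(m):
    """Inverse of a 3x3 matrix via the adjugate (cofactor) formula."""
    (a, b, c), (d, e, f), (g, h, i) = m
    det = a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g)
    adj = [
        [e * i - f * h, c * h - b * i, b * f - c * e],
        [f * g - d * i, a * i - c * g, c * d - a * f],
        [d * h - e * g, b * g - a * h, a * e - b * d],
    ]
    return [[x / det for x in row] for row in adj]


def apply(m, rows):
    """Transform each row vector: equivalent to rows @ m.T."""
    return [[sum(m[r][k] * v[k] for k in range(3)) for r in range(3)] for v in rows]


cell = (10.0, 20.0, 30.0, 90.0, 90.0, 90.0)
B = fractional_matrix(cell)
B_inv = invert_3x3(B)
cart = apply(B, [[0.5, 0.5, 0.5]])   # approximately [[5.0, 10.0, 15.0]]
frac = apply(B_inv, cart)            # approximately [[0.5, 0.5, 0.5]]
```

With row vectors, apply(B, frac) corresponds to frac @ B.T and apply(B_inv, cart) to cart @ B_inv.T, matching the Returns descriptions of the two matrix functions.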

torchref.base.math_torch.smallest_diff(diff, inv_frac_matrix, frac_matrix)

Compute minimum image squared distances with periodic boundary conditions.

Parameters:
  • diff (torch.Tensor) – Difference vectors of shape (…, 3).

  • inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix of shape (3, 3).

  • frac_matrix (torch.Tensor) – Fractionalization matrix of shape (3, 3).

Returns:

Squared distances with shape (…).

Return type:

torch.Tensor

torchref.base.math_torch.smallest_diff_aniso(diff, inv_frac_matrix, frac_matrix)

Compute minimum image difference vectors for anisotropic calculations.

Parameters:
  • diff (torch.Tensor) – Difference vectors of shape (…, 3).

  • inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix of shape (3, 3).

  • frac_matrix (torch.Tensor) – Fractionalization matrix of shape (3, 3).

Returns:

Signed difference vectors with shape (…, 3). Note: For anisotropic calculations, the signed vectors are needed to correctly compute the quadratic form r^T × B^(-1) × r with off-diagonal U tensor terms.

Return type:

torch.Tensor
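The minimum-image idea behind smallest_diff and smallest_diff_aniso can be sketched as: move the difference vector into fractional space, subtract the nearest whole lattice translation, and map back to Cartesian. This is a plain-Python illustration with a hypothetical min_image helper, not the torchref code; it returns both the signed vector (as smallest_diff_aniso does) and its squared length (as smallest_diff does).

```python
def matvec(m, v):
    """3x3 matrix times a 3-vector."""
    return [sum(m[r][k] * v[k] for k in range(3)) for r in range(3)]


def min_image(diff, inv_frac_matrix, frac_matrix):
    """Signed minimum-image vector and its squared length for one difference."""
    frac = matvec(inv_frac_matrix, diff)
    # Remove the nearest whole lattice translation along each axis.
    frac = [x - round(x) for x in frac]
    vec = matvec(frac_matrix, frac)
    return vec, sum(x * x for x in vec)


cube = [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]
cube_inv = [[0.1, 0.0, 0.0], [0.0, 0.1, 0.0], [0.0, 0.0, 0.1]]
vec, d2 = min_image([9.0, 0.0, 0.0], cube_inv, cube)  # vec ~ [-1, 0, 0], d2 ~ 1
```

Rounding in fractional space gives the true minimum image for orthogonal and mildly oblique cells; for strongly skewed cells a search over neighbouring lattice images may be needed.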

torchref.base.math_torch.reciprocal_basis_matrix(cell)

Compute the reciprocal space basis matrix from unit cell parameters.

Parameters:

cell (torch.Tensor) – Cell parameters [a, b, c, alpha, beta, gamma].

Returns:

Reciprocal basis matrix of shape (3, 3) with a*, b*, c* as rows.

Return type:

torch.Tensor

torchref.base.math_torch.get_scattering_vectors(hkl, cell, recB=None)

Calculate scattering vectors from Miller indices.

Parameters:
  • hkl (torch.Tensor) – Miller indices of shape (N, 3).

  • cell (torch.Tensor) – Cell parameters [a, b, c, alpha, beta, gamma].

  • recB (torch.Tensor, optional) – Pre-computed reciprocal basis matrix of shape (3, 3).

Returns:

Scattering vectors of shape (N, 3).

Return type:

torch.Tensor

torchref.base.math_torch.get_d_spacing(hkl, cell, recB=None)

Calculate d-spacing from Miller indices.

Parameters:
  • hkl (torch.Tensor) – Miller indices of shape (N, 3).

  • cell (torch.Tensor) – Cell parameters [a, b, c, alpha, beta, gamma].

  • recB (torch.Tensor, optional) – Pre-computed reciprocal basis matrix of shape (3, 3).

Returns:

d-spacing values of shape (N,) in Angstroms.

Return type:

torch.Tensor
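The relation between the reciprocal basis, scattering vectors and d-spacings can be illustrated in plain Python. This sketch assumes a Cartesian frame with a along x and b in the xy-plane and builds a*, b*, c* from cross products, so the scattering vector is s = h a* + k b* + l c* (hkl @ recB with a*, b*, c* as rows) and d = 1/|s|. reciprocal_basis and d_spacing are hypothetical helpers, not the torchref functions.

```python
import math


def cross(u, v):
    return [u[1] * v[2] - u[2] * v[1], u[2] * v[0] - u[0] * v[2], u[0] * v[1] - u[1] * v[0]]


def dot(u, v):
    return sum(x * y for x, y in zip(u, v))


def reciprocal_basis(cell):
    """Rows a*, b*, c* in Cartesian coordinates (a along x, b in the xy-plane)."""
    a, b, c, al, be, ga = cell
    ca, cb, cg = (math.cos(math.radians(x)) for x in (al, be, ga))
    sg = math.sin(math.radians(ga))
    av = [a, 0.0, 0.0]
    bv = [b * cg, b * sg, 0.0]
    cx, cy = c * cb, c * (ca - cb * cg) / sg
    cv = [cx, cy, math.sqrt(c * c - cx * cx - cy * cy)]
    vol = dot(av, cross(bv, cv))
    # a* = (b x c) / V, b* = (c x a) / V, c* = (a x b) / V
    return [[x / vol for x in cross(bv, cv)],
            [x / vol for x in cross(cv, av)],
            [x / vol for x in cross(av, bv)]]


def d_spacing(hkl, rec):
    """d = 1 / |h a* + k b* + l c*| for one reflection."""
    s = [sum(hkl[i] * rec[i][j] for i in range(3)) for j in range(3)]
    return 1.0 / math.sqrt(dot(s, s))


rec = reciprocal_basis((10.0, 20.0, 30.0, 90.0, 90.0, 90.0))
d_100 = d_spacing((1, 0, 0), rec)   # 10.0 up to rounding
```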

torchref.base.math_torch.place_on_grid(hkls, structure_factor, grid_size, enforce_hermitian=True)

Place structure factors on a reciprocal-space grid.

The placement is vectorized and supports batched structure factors.

Parameters:
  • hkls (torch.Tensor) – Miller indices of shape (N, 3).

  • structure_factor (torch.Tensor) – Structure factors of shape (N,) or (B, N) for batched input.

  • grid_size (tuple or torch.Tensor) – Grid dimensions (Nx, Ny, Nz).

  • enforce_hermitian (bool, optional) – Whether to enforce Hermitian symmetry. Default is True.

Returns:

Complex tensor grid of structure factors of shape (Nx, Ny, Nz) or (B, Nx, Ny, Nz) for batched input.

Return type:

torch.Tensor

torchref.base.math_torch.extract_structure_factor_from_grid(reciprocal_grid, hkls)

Extract structure factors from reciprocal space grid at given Miller indices.

Parameters:
  • reciprocal_grid (torch.Tensor) – Complex tensor of shape (Nx, Ny, Nz) or (B, Nx, Ny, Nz).

  • hkls (torch.Tensor) – Miller indices of shape (N, 3).

Returns:

Structure factors of shape (N,) or (B, N) for batched input.

Return type:

torch.Tensor
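A sketch of the usual FFT index layout for placing and extracting structure factors, assuming negative Miller indices wrap modulo the grid size and that enforcing Hermitian symmetry means writing the Friedel mate F(-h) = conj(F(h)). Nested Python lists stand in for the complex tensor; place_hkl_values and extract_hkl_values are hypothetical names.

```python
def place_hkl_values(hkls, values, grid_size, enforce_hermitian=True):
    """Place complex values on a nested-list grid at wrapped (FFT-style) indices."""
    nx, ny, nz = grid_size
    grid = [[[0j] * nz for _ in range(ny)] for _ in range(nx)]
    for (h, k, l), f in zip(hkls, values):
        grid[h % nx][k % ny][l % nz] = f
        if enforce_hermitian:
            # Friedel mate: F(-h) = conj(F(h)) for a real-valued density.
            grid[-h % nx][-k % ny][-l % nz] = f.conjugate()
    return grid


def extract_hkl_values(grid, hkls):
    """Read values back at the same wrapped indices."""
    nx, ny, nz = len(grid), len(grid[0]), len(grid[0][0])
    return [grid[h % nx][k % ny][l % nz] for h, k, l in hkls]


g = place_hkl_values([(1, 0, 0), (0, -1, 1)], [1 + 2j, 3 - 1j], (4, 4, 4))
mates = extract_hkl_values(g, [(-1, 0, 0), (0, 1, -1)])  # [(1-2j), (3+1j)]
```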

torchref.base.math_torch.apply_translation_phase(F_calc, hkl, translation_frac)

Apply translation phase shift to structure factors.

For a translation t in fractional coordinates, the structure factor transforms as F'(hkl) = F(hkl) * exp(2πi * hkl · t).

Parameters:
  • F_calc (torch.Tensor) – Complex structure factors of shape (N,).

  • hkl (torch.Tensor) – Miller indices of shape (N, 3).

  • translation_frac (torch.Tensor) – Translation vector in fractional coordinates of shape (3,).

Returns:

Phase-shifted structure factors of shape (N,).

Return type:

torch.Tensor
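The phase-shift formula can be checked against direct summation for a single unit-weight atom: shifting the atom by t and recomputing F gives the same value as multiplying by exp(2πi h·t). This assumes the exp(+2πi h·x) structure-factor convention stated for ifft in this module; translation_phase and point_atom_F are hypothetical helpers.

```python
import cmath


def translation_phase(F, hkl, t):
    """F'(hkl) = F(hkl) * exp(2*pi*i * hkl . t)."""
    phase = 2 * cmath.pi * sum(h * x for h, x in zip(hkl, t))
    return F * cmath.exp(1j * phase)


def point_atom_F(hkl, xyz_frac):
    """F(hkl) of one unit-weight atom, using the exp(+2*pi*i h.x) convention."""
    return cmath.exp(2j * cmath.pi * sum(h * x for h, x in zip(hkl, xyz_frac)))


hkl, x, t = (1, 2, -1), (0.1, 0.25, 0.4), (0.3, -0.05, 0.2)
shifted = point_atom_F(hkl, [a + b for a, b in zip(x, t)])
via_phase = translation_phase(point_atom_F(hkl, x), hkl, t)  # equals shifted
```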

torchref.base.math_torch.interpolate_structure_factor_from_grid(reciprocal_grid, hkl_float, interpolate_amplitude=True)

Interpolate structure factors from reciprocal space grid at non-integer positions.

Parameters:
  • reciprocal_grid (torch.Tensor) – Complex tensor of shape (Nx, Ny, Nz).

  • hkl_float (torch.Tensor) – Non-integer HKL positions of shape (N, 3).

  • interpolate_amplitude (bool, optional) – If True (default), interpolate amplitudes instead of complex values. This avoids phase cancellation issues where linear interpolation of complex numbers with different phases can give incorrect magnitudes (e.g., interpolating F=1 and F=-1 gives 0 instead of 1).

Returns:

Interpolated structure factors of shape (N,). If interpolate_amplitude=True, returns real-valued amplitudes. If interpolate_amplitude=False, returns complex values (use with caution).

Return type:

torch.Tensor

Notes

For a rotation R applied to the model, the structure factor at hkl becomes F(R^T @ hkl), so you would call this with hkl_float = hkl @ R.

WARNING: Complex interpolation (interpolate_amplitude=False) can give incorrect results when neighboring voxels have very different phases. For example, if F1 = A*exp(i*0) and F2 = A*exp(i*π), linear interpolation gives magnitude 0 at the midpoint instead of A. Use interpolate_amplitude=True for rotation functions where only magnitudes matter.
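The phase-cancellation example from the warning, reduced to one dimension in plain Python (lerp is a hypothetical helper): interpolating the complex values 1 and -1 at the midpoint gives a magnitude of about 0, while interpolating their amplitudes gives 1.

```python
import cmath


def lerp(a, b, w):
    """Linear interpolation between a and b with weight w on b."""
    return (1 - w) * a + w * b


f1 = cmath.rect(1.0, 0.0)        # amplitude 1, phase 0
f2 = cmath.rect(1.0, cmath.pi)   # amplitude 1, phase pi (i.e. -1)
complex_mid = abs(lerp(f1, f2, 0.5))            # ~0: the phases cancel
amplitude_mid = lerp(abs(f1), abs(f2), 0.5)     # 1.0
```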

torchref.base.math_torch.interpolate_complex_from_grid(reciprocal_grid, hkl_float)

Interpolate complex structure factors from reciprocal space grid.

Unlike amplitude interpolation, this preserves phase information, which is essential for translation searches where phases are used to compute correlation functions.

Parameters:
  • reciprocal_grid (torch.Tensor) – Complex tensor of shape (Nx, Ny, Nz).

  • hkl_float (torch.Tensor) – Non-integer HKL positions of shape (N, 3).

Returns:

Interpolated complex structure factors of shape (N,).

Return type:

torch.Tensor

Notes

This function performs trilinear interpolation of complex values. For rotation-only searches (where only magnitudes matter), use interpolate_structure_factor_from_grid(interpolate_amplitude=True) instead.

For translation searches, complex interpolation is needed because the translation function depends on the phase relationship between F_obs and F_calc.

WARNING: Complex interpolation can give reduced magnitudes when neighboring voxels have very different phases. This is acceptable for translation searches where we’re computing correlation functions, but not for rotation searches.

torchref.base.math_torch.trilinear_interpolate_patterson(grid, points, chunk_size=10000000)

Memory-efficient trilinear interpolation on a 3D grid.

Pure torch implementation for GPU acceleration and gradient flow. Replaces scipy.ndimage.map_coordinates for torch tensors.

Parameters:
  • grid (torch.Tensor) – 3D grid of values with shape (nx, ny, nz).

  • points (torch.Tensor) – Coordinates to sample, of shape (n_points, 3), or (K, n_points, 3) for K batches. Should be fractional coordinates in [0, 1) for ‘wrap’ mode.

  • chunk_size (int, optional) – Number of points to process at once. Default is 10,000,000.

Returns:

Interpolated values with shape (n_points,) or (K, n_points).

Return type:

torch.Tensor

Notes

Supports automatic differentiation for gradient-based optimization. Uses chunked processing and in-place accumulation to reduce memory.
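A scalar, plain-Python version of periodic trilinear interpolation, assuming points are fractional coordinates that are scaled by the grid shape and wrapped modulo the grid. It omits chunking, batching and autograd; trilinear_wrap is a hypothetical name.

```python
import math


def trilinear_wrap(grid, point):
    """Trilinear interpolation of grid[i][j][k] at a fractional point, periodic."""
    shape = (len(grid), len(grid[0]), len(grid[0][0]))
    pos = [p * n for p, n in zip(point, shape)]
    base = [math.floor(x) for x in pos]
    frac = [x - b for x, b in zip(pos, base)]
    total = 0.0
    for di in (0, 1):
        for dj in (0, 1):
            for dk in (0, 1):
                # Corner weight is the product of three 1-D linear weights.
                w = ((frac[0] if di else 1 - frac[0])
                     * (frac[1] if dj else 1 - frac[1])
                     * (frac[2] if dk else 1 - frac[2]))
                i = (base[0] + di) % shape[0]
                j = (base[1] + dj) % shape[1]
                k = (base[2] + dk) % shape[2]
                total += w * grid[i][j][k]
    return total


grid = [[[i + 10 * j + 100 * k for k in range(4)] for j in range(4)] for i in range(4)]
inside = trilinear_wrap(grid, (0.125, 0.125, 0.125))  # 55.5 (exact for linear data)
wrapped = trilinear_wrap(grid, (0.875, 0.0, 0.0))     # 1.5, between x = 3 and x = 0
```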

torchref.base.math_torch.iso_structure_factor_torched(hkl, s, xyz_fractional, occ, scattering_factors, adp, spacegroup, max_memory_gb=None, A=None, B_coeff=None)

Calculate isotropic structure factors using PyTorch.

Parameters:
  • hkl (torch.Tensor) – Miller indices of shape (N_reflections, 3).

  • s (torch.Tensor) – Scattering vector magnitudes of shape (N_reflections,).

  • xyz_fractional (torch.Tensor) – Fractional coordinates of shape (N_atoms, 3).

  • occ (torch.Tensor) – Occupancies of shape (N_atoms,).

  • scattering_factors (torch.Tensor or None) – Atomic scattering factors of shape (N_reflections, N_atoms). If None, A and B_coeff must be provided so that the scattering factors can be computed in batches.

  • adp (torch.Tensor) – Atomic displacement parameters (isotropic) of shape (N_atoms,).

  • spacegroup (callable) – Space group symmetry operator function.

  • max_memory_gb (float, optional) – Maximum memory to use in GB. If None, no batching is applied.

  • A (torch.Tensor, optional) – ITC92 A coefficients (N_atoms, 5) for computing scattering factors. Required if scattering_factors is None and batching is needed.

  • B_coeff (torch.Tensor, optional) – ITC92 B coefficients (N_atoms, 5) for computing scattering factors. Required if scattering_factors is None and batching is needed.

Returns:

Complex structure factors of shape (N_reflections,).

Return type:

torch.Tensor

torchref.base.math_torch.iso_structure_factor_torched_no_complex(hkl, s, fractional_coords, occ, scattering_factors, tempfactor, space_group)

Calculate isotropic structure factors without complex numbers.

Returns real and imaginary parts as separate rows.

Parameters:
  • hkl (torch.Tensor) – Miller indices of shape (N_reflections, 3).

  • s (torch.Tensor) – Scattering vector magnitudes of shape (N_reflections,).

  • fractional_coords (torch.Tensor) – Fractional coordinates of shape (N_atoms, 3).

  • occ (torch.Tensor) – Occupancies of shape (N_atoms,).

  • scattering_factors (torch.Tensor) – Atomic scattering factors of shape (N_reflections, N_atoms).

  • tempfactor (torch.Tensor) – Isotropic temperature factors (B-factors) of shape (N_atoms,).

  • space_group (callable) – Space group symmetry operator function.

Returns:

Structure factors as [real, imag] of shape (2, N_reflections).

Return type:

torch.Tensor
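A minimal P1 direct-summation sketch of the [real, imag] layout returned by the no_complex variant. It assumes s = 1/d, so the isotropic Debye-Waller factor is exp(-B s²/4), and ignores space-group symmetry and memory batching; iso_sf_p1 is a hypothetical helper.

```python
import math


def iso_sf_p1(hkl, s, xyz_frac, occ, f, bfac):
    """Return [real_row, imag_row] of P1 structure factors.

    Assumes s = 1/d, so the Debye-Waller factor is exp(-B s^2 / 4);
    f[r][j] is the scattering factor of atom j at reflection r.
    """
    real, imag = [], []
    for r, (h, s_r) in enumerate(zip(hkl, s)):
        f_re = f_im = 0.0
        for j, (x, q, b) in enumerate(zip(xyz_frac, occ, bfac)):
            amp = q * f[r][j] * math.exp(-b * s_r * s_r / 4.0)
            phase = 2.0 * math.pi * sum(hi * xi for hi, xi in zip(h, x))
            f_re += amp * math.cos(phase)
            f_im += amp * math.sin(phase)
        real.append(f_re)
        imag.append(f_im)
    return [real, imag]


# Two identical atoms half a cell apart cancel for odd h and add for even h.
out = iso_sf_p1(hkl=[(1, 0, 0), (2, 0, 0)], s=[0.1, 0.2],
                xyz_frac=[(0.0, 0.0, 0.0), (0.5, 0.0, 0.0)], occ=[1.0, 1.0],
                f=[[6.0, 6.0], [6.0, 6.0]], bfac=[0.0, 0.0])
```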

torchref.base.math_torch.aniso_structure_factor_torched(hkl, s_vector, xyz_fractional, occ, scattering_factors, U, spacegroup, max_memory_gb=None, A=None, B_coeff=None)

Calculate anisotropic structure factors using PyTorch.

Parameters:
  • hkl (torch.Tensor) – Miller indices of shape (N_reflections, 3).

  • s_vector (torch.Tensor) – Scattering vectors of shape (N_reflections, 3).

  • xyz_fractional (torch.Tensor) – Fractional coordinates of shape (N_atoms, 3).

  • occ (torch.Tensor) – Occupancies of shape (N_atoms,).

  • scattering_factors (torch.Tensor or None) – Atomic scattering factors of shape (N_reflections, N_atoms). If None, A and B_coeff must be provided so that the scattering factors can be computed in batches.

  • U (torch.Tensor) – Anisotropic displacement parameters of shape (N_atoms, 6).

  • spacegroup (callable) – Space group symmetry operator function.

  • max_memory_gb (float, optional) – Maximum memory to use in GB. If None, no batching is applied.

  • A (torch.Tensor, optional) – ITC92 A coefficients (N_atoms, 5) for computing scattering factors.

  • B_coeff (torch.Tensor, optional) – ITC92 B coefficients (N_atoms, 5) for computing scattering factors.

Returns:

Complex structure factors of shape (N_reflections,).

Return type:

torch.Tensor

torchref.base.math_torch.aniso_structure_factor_torched_no_complex(hkl, s_vector, fractional_coords, occ, scattering_factors, U, space_group)

Calculate anisotropic structure factors without complex numbers.

Returns real and imaginary parts as separate rows.

Parameters:
  • hkl (torch.Tensor) – Miller indices of shape (N_reflections, 3).

  • s_vector (torch.Tensor) – Scattering vectors of shape (N_reflections, 3).

  • fractional_coords (torch.Tensor) – Fractional coordinates of shape (N_atoms, 3).

  • occ (torch.Tensor) – Occupancies of shape (N_atoms,).

  • scattering_factors (torch.Tensor) – Atomic scattering factors of shape (N_reflections, N_atoms).

  • U (torch.Tensor) – Anisotropic displacement parameters of shape (N_atoms, 6).

  • space_group (callable) – Space group symmetry operator function.

Returns:

Structure factors as [real, imag] of shape (2, N_reflections).

Return type:

torch.Tensor
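A sketch of the anisotropic Debye-Waller factor exp(-2π² sᵀUs), assuming s_vector is the Cartesian scattering vector with |s| = 1/d and U holds Cartesian components in the order (u11, u22, u33, u12, u13, u23) used elsewhere in this module. For an isotropic U = u·I it reduces to exp(-B s²/4) with B = 8π²u; aniso_dw and iso_dw are hypothetical names.

```python
import math


def aniso_dw(s_vec, U):
    """exp(-2 pi^2 s^T U s) with U = (u11, u22, u33, u12, u13, u23), Cartesian."""
    u11, u22, u33, u12, u13, u23 = U
    sx, sy, sz = s_vec
    # Off-diagonal terms appear twice in the symmetric quadratic form.
    quad = (u11 * sx * sx + u22 * sy * sy + u33 * sz * sz
            + 2.0 * (u12 * sx * sy + u13 * sx * sz + u23 * sy * sz))
    return math.exp(-2.0 * math.pi ** 2 * quad)


def iso_dw(s_mag, B):
    """exp(-B s^2 / 4) for s = |s_vec| = 1/d."""
    return math.exp(-B * s_mag ** 2 / 4.0)


u = 0.25
s_vec = (0.1, 0.2, -0.3)
aniso = aniso_dw(s_vec, (u, u, u, 0.0, 0.0, 0.0))
iso = iso_dw(math.sqrt(sum(x * x for x in s_vec)), 8 * math.pi ** 2 * u)  # same value
```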

torchref.base.math_torch.anharmonic_correction(hkl, c)

Apply anharmonic (third-order) correction to structure factors.

Parameters:
  • hkl (torch.Tensor) – Miller indices of shape (N_reflections, 3).

  • c (tuple or list) – Ten anharmonic coefficients (C111, C222, C333, C112, C122, C113, C133, C223, C233, C123).

Returns:

Complex anharmonic correction factors of shape (N_reflections,).

Return type:

torch.Tensor

torchref.base.math_torch.anharmonic_correction_no_complex(hkl, c)

Apply anharmonic (third-order) correction without complex numbers.

Returns real and imaginary parts as separate rows.

Parameters:
  • hkl (torch.Tensor) – Miller indices of shape (N_reflections, 3).

  • c (tuple or list) – Ten anharmonic coefficients (C111, C222, C333, C112, C122, C113, C133, C223, C233, C123).

Returns:

Correction factors as [cos, sin] of shape (2, N_reflections).

Return type:

torch.Tensor
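The ten coefficients are the independent components of a symmetric third-order tensor, so the full contraction over all 27 index triples weights mixed terms by their multiplicities (3 for C112-type terms, 6 for C123). The sketch below builds that contraction from the documented coefficient order. The phase factor uses the textbook third-order Gram-Charlier term (2πi)³/3! written in exponential form; that this is what torchref computes is an assumption. Both helper names are hypothetical.

```python
import cmath
import itertools

# Index triples in the documented order: C111, C222, C333, C112, C122,
# C113, C133, C223, C233, C123 (0-based axis indices).
ORDER = [(0, 0, 0), (1, 1, 1), (2, 2, 2), (0, 0, 1), (0, 1, 1),
         (0, 0, 2), (0, 2, 2), (1, 1, 2), (1, 2, 2), (0, 1, 2)]


def contract_third_order(hkl, c):
    """Full symmetric contraction sum_{jkl} C_jkl h_j h_k h_l (27 terms)."""
    tensor = {}
    for key, value in zip(ORDER, c):
        for perm in itertools.permutations(key):
            tensor[perm] = value  # every permutation shares one coefficient
    return sum(tensor[idx] * hkl[idx[0]] * hkl[idx[1]] * hkl[idx[2]]
               for idx in itertools.product(range(3), repeat=3))


def gram_charlier_phase_factor(hkl, c):
    """Assumed pure-phase form exp((2 pi i)^3 / 3! * S) = exp(-i (4/3) pi^3 S)."""
    s = contract_third_order(hkl, c)
    return cmath.exp(-1j * (4.0 / 3.0) * cmath.pi ** 3 * s)


total = contract_third_order((1, 1, 1), [1.0] * 10)   # 27.0 = 3*1 + 6*3 + 1*6
```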

torchref.base.math_torch.core_deformation(core_correction, s)

Apply core deformation correction to scattering.

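Parameters:
  • core_correction – Core deformation correction parameters.

  • s (torch.Tensor) – Scattering vector magnitudes.

Returns: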

Core deformation correction factors.

Return type:

torch.Tensor

torchref.base.math_torch.multiplication_quasi_complex_tensor(a, b)

Multiply two quasi-complex tensors represented as [real, imag] rows.

Parameters:
  • a (torch.Tensor) – First quasi-complex tensor of shape (2, N).

  • b (torch.Tensor) – Second quasi-complex tensor of shape (2, N).

Returns:

Product as [real, imag] of shape (2, N).

Return type:

torch.Tensor
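The quasi-complex product is ordinary complex multiplication written out row by row: (ar + i·ai)(br + i·bi) = (ar·br - ai·bi) + i(ar·bi + ai·br). A plain-list sketch with a hypothetical name:

```python
def mul_quasi_complex(a, b):
    """a, b are [real_row, imag_row]; returns their elementwise complex product."""
    ar, ai = a
    br, bi = b
    real = [x * y - u * v for x, u, y, v in zip(ar, ai, br, bi)]
    imag = [x * v + u * y for x, u, y, v in zip(ar, ai, br, bi)]
    return [real, imag]


a = [[1.0, 0.0], [2.0, 1.0]]    # represents 1+2j and 0+1j
b = [[3.0, 2.0], [-1.0, 5.0]]   # represents 3-1j and 2+5j
prod = mul_quasi_complex(a, b)  # [[5.0, -5.0], [5.0, 2.0]]
```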

torchref.base.math_torch.vectorized_add_to_map(surrounding_coords, voxel_indices, density_map, xyz, b, inv_frac_matrix, frac_matrix, A, B, occ)

Add atoms to density map using ITC92 Gaussian parameterization.

Automatically selects an implementation based on the device. On GPU, the default is a fused Triton kernel (3-6x faster; falls back to the JIT implementation if Triton is unavailable). Override it by setting the environment variable TORCHREF_ATOM_PLACEMENT_GPU_MODE to jit or simple.

Parameters:
  • surrounding_coords (torch.Tensor) – Cartesian coordinates of voxels, shape (N_atoms, N_voxels, 3).

  • voxel_indices (torch.Tensor) – Indices of voxels in the map, shape (N_atoms, N_voxels, 3).

  • density_map (torch.Tensor) – Electron density map to update, shape (nx, ny, nz).

  • xyz (torch.Tensor) – Atom positions in Cartesian coordinates, shape (N_atoms, 3).

  • b (torch.Tensor) – Isotropic B-factors, shape (N_atoms,).

  • inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix, shape (3, 3).

  • frac_matrix (torch.Tensor) – Fractionalization matrix, shape (3, 3).

  • A (torch.Tensor) – ITC92 amplitude coefficients, shape (N_atoms, 5).

  • B (torch.Tensor) – ITC92 width coefficients, shape (N_atoms, 5).

  • occ (torch.Tensor) – Atomic occupancies, shape (N_atoms,).

Returns:

Updated electron density map (modified in-place).

Return type:

torch.Tensor
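Using the isotropic convention stated for vectorized_add_to_map_aniso (B_total = (B_itc92 + B_atomic) / 4), the density of one atom at distance r can be sketched as follows. This evaluates the formula at a single point, without voxel lists or an in-place map update; iso_density is a hypothetical helper. Each Gaussian term integrates to its A coefficient over all space, so the total charge is occ × ΣA.

```python
import math


def iso_density(r, A, B, b_atom, occ=1.0):
    """Electron density of one isotropic atom at distance r (Angstrom)."""
    rho = 0.0
    for a_i, b_i in zip(A, B):
        bt = (b_i + b_atom) / 4.0  # B_total = (B_itc92 + B_atomic) / 4
        rho += a_i * (math.pi / bt) ** 1.5 * math.exp(-math.pi ** 2 * r * r / bt)
    return occ * rho


A, B = [2.0, 1.0], [10.0, 1.0]
dr = 0.001
# Midpoint-rule radial integral of 4 pi r^2 rho(r); should approach 0.5 * 3.0.
charge = sum(4 * math.pi * ((i + 0.5) * dr) ** 2
             * iso_density((i + 0.5) * dr, A, B, 20.0, occ=0.5) * dr
             for i in range(12000))
```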

torchref.base.math_torch.vectorized_add_to_map_aniso(surrounding_coords, voxel_indices, map, xyz, U, inv_frac_matrix, frac_matrix, A, B, occ)

Add anisotropic atoms to density map using ITC92 Gaussian parameterization.

Uses the same convention as the isotropic case for consistency:

  • B_total = (B_itc92 + B_atomic) / 4

  • rho = A × (π/B_total)^(3/2) × exp(-π² r² / B_total)

For anisotropic atoms, this generalizes to:

  • B_atomic_ij = 8π² × U_atomic_ij (standard crystallographic conversion)

  • B_total_ij = (B_itc92 × δ_ij + 8π² × U_atomic_ij) / 4

  • Normalization: (π³ / det(B_total))^(1/2)

  • Exponent: exp(-π² × r^T × B_total^(-1) × r)

Parameters:
  • surrounding_coords (torch.Tensor) – Coordinates of voxels around each atom of shape (N_atoms, N_voxels, 3).

  • voxel_indices (torch.Tensor) – Indices of voxels in the map of shape (N_atoms, N_voxels, 3).

  • map (torch.Tensor) – Electron density map of shape (nx, ny, nz).

  • xyz (torch.Tensor) – Atom positions in Cartesian coordinates of shape (N_atoms, 3).

  • U (torch.Tensor) – Anisotropic displacement parameters in Angstroms squared (u11, u22, u33, u12, u13, u23) of shape (N_atoms, 6).

  • inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix of shape (3, 3).

  • frac_matrix (torch.Tensor) – Fractionalization matrix of shape (3, 3).

  • A (torch.Tensor) – ITC92 amplitude coefficients for each atom of shape (N_atoms, 5).

  • B (torch.Tensor) – ITC92 width coefficients (b parameters) in Angstroms squared for each atom of shape (N_atoms, 5).

  • occ (torch.Tensor) – Occupancies for each atom of shape (N_atoms,).

Returns:

Updated electron density map.

Return type:

torch.Tensor
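A single-term sketch of the anisotropic formula, with hypothetical helpers det3, inv3 and aniso_term. It builds B_total from one ITC92 width b_i and U, then evaluates the normalization and exponent at a Cartesian offset r. For an isotropic U = u·I it reduces to the isotropic expression with B_atomic = 8π²u.

```python
import math


def det3(m):
    (a, b, c), (d, e, f), (g, h, i) = m
    return a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g)


def inv3(m):
    """3x3 inverse via the adjugate formula."""
    (a, b, c), (d, e, f), (g, h, i) = m
    det = det3(m)
    adj = [[e * i - f * h, c * h - b * i, b * f - c * e],
           [f * g - d * i, a * i - c * g, c * d - a * f],
           [d * h - e * g, b * g - a * h, a * e - b * d]]
    return [[x / det for x in row] for row in adj]


def aniso_term(r, a_i, b_i, U):
    """One ITC92 term; U = (u11, u22, u33, u12, u13, u23) in Angstrom^2."""
    u11, u22, u33, u12, u13, u23 = U
    u = [[u11, u12, u13], [u12, u22, u23], [u13, u23, u33]]
    # B_total_ij = (B_itc92 * delta_ij + 8 pi^2 U_ij) / 4
    bt = [[((b_i if i == j else 0.0) + 8 * math.pi ** 2 * u[i][j]) / 4.0
           for j in range(3)] for i in range(3)]
    inv = inv3(bt)
    quad = sum(r[i] * inv[i][j] * r[j] for i in range(3) for j in range(3))
    return a_i * math.sqrt(math.pi ** 3 / det3(bt)) * math.exp(-math.pi ** 2 * quad)


u_iso = 0.3
r = (0.4, -0.2, 0.7)
bt_iso = (10.0 + 8 * math.pi ** 2 * u_iso) / 4
iso_val = 2.0 * (math.pi / bt_iso) ** 1.5 * math.exp(-math.pi ** 2 * sum(x * x for x in r) / bt_iso)
aniso_val = aniso_term(r, 2.0, 10.0, (u_iso, u_iso, u_iso, 0.0, 0.0, 0.0))  # equals iso_val
```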

torchref.base.math_torch.scatter_add_nd(source, index, map)

Vectorized n-dimensional scatter add operation.

Parameters:
  • source (torch.Tensor) – Values to add to the map of shape (N,).

  • index (torch.Tensor) – Indices where values should be added of shape (N, ndim).

  • map (torch.Tensor) – N-dimensional tensor of shape (d1, d2, …, dn) to add values into.

Returns:

Modified map with values added.

Return type:

torch.Tensor

torchref.base.math_torch.scatter_add_nd_super_slow(source, index, map)

Non-vectorized n-dimensional scatter add operation (slow reference implementation).

Parameters:
  • source (torch.Tensor) – Values to add to the map of shape (N,).

  • index (torch.Tensor) – Indices where values should be added of shape (N, ndim).

  • map (torch.Tensor) – N-dimensional tensor to add values into.

Returns:

Modified map with values added.

Return type:

torch.Tensor
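Both scatter-add variants must accumulate values at repeated indices rather than overwrite them. The plain-Python sketch below contrasts a loop over n-dimensional indices with accumulation into a flattened buffer through row-major linear indices; both function names are hypothetical. In PyTorch, one way to vectorize the flattened form is map.view(-1).index_add_(0, linear_index, source), though this is not necessarily how torchref does it.

```python
def scatter_add_nested(source, index, grid):
    """Reference loop: add each value at its (i, j, k) position."""
    for value, (i, j, k) in zip(source, index):
        grid[i][j][k] += value
    return grid


def scatter_add_flat(source, index, shape):
    """Accumulate into a flat buffer using row-major (C-order) linear indices."""
    size = 1
    for n in shape:
        size *= n
    flat = [0.0] * size
    for value, idx in zip(source, index):
        lin = 0
        for i, n in zip(idx, shape):
            lin = lin * n + i
        flat[lin] += value  # duplicates accumulate
    return flat


src, idx, shape = [1.0, 2.0, 0.5], [(1, 2, 3), (0, 0, 0), (1, 2, 3)], (2, 3, 4)
flat = scatter_add_flat(src, idx, shape)   # flat[23] == 1.5, flat[0] == 2.0
nested = scatter_add_nested(src, idx, [[[0.0] * 4 for _ in range(3)] for _ in range(2)])
```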

torchref.base.math_torch.find_relevant_voxels(real_space_grid, xyz, radius_angstrom=4, inv_frac_matrix=None)

Identify surrounding voxels of atoms in a real space grid.

This is a vectorized function that finds all voxels within a spherical radius around each atom position.

Parameters:
  • real_space_grid (torch.Tensor) – Real space grid containing xyz coordinates at each grid point, of shape (nx, ny, nz, 3).

  • xyz (torch.Tensor) – Atom coordinates in real space (Cartesian coordinates), of shape (N, 3) or (3,).

  • radius_angstrom (float, optional) – Radius around each atom in Angstroms. Default is 4.

  • inv_frac_matrix (torch.Tensor, optional) – Matrix to convert Cartesian to fractional coordinates of shape (3, 3). Required for proper handling of non-orthogonal cells.

Returns:

  • surrounding_coords (torch.Tensor) – Coordinates of surrounding voxels for each atom of shape (N, R, 3), where R is the number of voxels within the radius.

  • voxel_indices_wrapped (torch.Tensor) – Wrapped voxel indices of shape (N, R, 3).

Notes

Atom coordinates are NOT wrapped here - periodic boundary conditions are handled in smallest_diff() which finds the minimum image distance. We only wrap voxel indices to ensure they’re valid array indices.

torchref.base.math_torch.excise_angstrom_radius_around_coord(real_space_grid, start_indices, radius_angstrom=4.0)

Identify voxel indices within an Angstrom radius around specified grid positions.

Parameters:
  • real_space_grid (torch.Tensor) – Real space grid of shape (nx, ny, nz, 3) containing xyz coordinates.

  • start_indices (torch.Tensor) – Starting grid indices of shape (N, 3) or (3,).

  • radius_angstrom (float, optional) – Radius in Angstroms. Default is 4.0.

Returns:

Wrapped voxel indices of shape (N, R, 3), where R is the number of voxels within the radius.

Return type:

torch.Tensor

Notes

Periodic boundary conditions are handled by wrapping the indices to ensure they’re valid array indices.
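A simplified sketch of the radius search, assuming an orthogonal grid with known per-axis voxel spacing in Angstroms (the documented functions also handle non-orthogonal cells through the real-space grid). Integer offsets within reach are kept if their length is at most the radius, and the resulting indices are wrapped modulo the grid shape. voxels_within_radius is a hypothetical name.

```python
import itertools
import math


def voxels_within_radius(start, grid_shape, spacing, radius):
    """Wrapped grid indices within `radius` Angstrom of a start voxel."""
    reach = [math.ceil(radius / s) for s in spacing]
    found = []
    for off in itertools.product(*(range(-n, n + 1) for n in reach)):
        dist2 = sum((o * s) ** 2 for o, s in zip(off, spacing))
        if dist2 <= radius ** 2:
            # Wrap so that every index is valid for the periodic grid.
            found.append(tuple((st + o) % n for st, o, n in zip(start, off, grid_shape)))
    return found


faces = voxels_within_radius((5, 5, 5), (10, 10, 10), (1.0, 1.0, 1.0), 1.0)   # 7 voxels
edges = voxels_within_radius((5, 5, 5), (10, 10, 10), (1.0, 1.0, 1.0), 1.5)   # 19 voxels
```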

torchref.base.math_torch.add_to_solvent_mask(surrounding_coords, voxel_indices, mask, xyz, radius, inv_frac_matrix, frac_matrix)

Create solvent mask by placing spheres around atom positions.

Parameters:
  • surrounding_coords (torch.Tensor) – Coordinates of voxels around each atom of shape (N_atoms, N_voxels, 3).

  • voxel_indices (torch.Tensor) – Indices of voxels in the map of shape (N_atoms, N_voxels, 3).

  • mask (torch.Tensor) – Solvent mask to be updated of shape (nx, ny, nz).

  • xyz (torch.Tensor) – Atom positions of shape (N_atoms, 3).

  • radius (float) – Radius of the sphere around each atom in Angstroms.

  • inv_frac_matrix (torch.Tensor) – Inverse fractionalization matrix of shape (3, 3).

  • frac_matrix (torch.Tensor) – Fractionalization matrix of shape (3, 3).

Returns:

Updated solvent mask as boolean tensor.

Return type:

torch.Tensor

torchref.base.math_torch.add_to_phenix_mask(surrounding_coords, voxel_indices, xyz, vdw_radii, solvent_radius, inv_frac_matrix, frac_matrix, grid_shape, device)

Create Phenix-style three-valued mask by placing spheres around atom positions.

This is a vectorized implementation that processes all atoms and voxels at once. It creates two binary masks:

  • protein_mask: 1 inside the VdW radius (protein core)

  • boundary_mask: 1 between the VdW radius and VdW + solvent_radius (accessible surface)

The final three-valued mask is:

  • 0: protein_mask == 1 (protein core)

  • -1: boundary_mask == 1 and protein_mask == 0 (accessible surface)

  • 1: both masks == 0 (bulk solvent)

Parameters:
  • surrounding_coords (torch.Tensor) – Fractional coordinates of voxels around each atom of shape (N_atoms, N_voxels, 3).

  • voxel_indices (torch.Tensor) – Grid indices of voxels in the map of shape (N_atoms, N_voxels, 3).

  • xyz (torch.Tensor) – Atom positions in fractional coordinates of shape (N_atoms, 3).

  • vdw_radii (torch.Tensor) – VdW radius for each atom in Angstroms of shape (N_atoms,).

  • solvent_radius (float) – Probe radius in Angstroms (added to VdW to get accessible surface).

  • inv_frac_matrix (torch.Tensor) – Inverse fractional matrix for distance calculations of shape (3, 3).

  • frac_matrix (torch.Tensor) – Fractional matrix for distance calculations of shape (3, 3).

  • grid_shape (tuple) – Shape of the output mask (nx, ny, nz).

  • device (torch.device) – Device for tensor operations.

Returns:

  • protein_mask (torch.Tensor) – Boolean mask for protein core of shape grid_shape.

  • boundary_mask (torch.Tensor) – Boolean mask for accessible surface of shape grid_shape.
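A per-voxel sketch of the three-valued mask logic, using plain Euclidean distances without periodic images. A voxel is protein if it lies inside any atom's VdW radius, boundary if it lies in any atom's shell between VdW and VdW + solvent_radius, and the protein label takes priority. Whether the shell limits are inclusive or exclusive is an assumption here; phenix_style_mask is a hypothetical helper.

```python
def phenix_style_mask(voxels, atoms, vdw_radii, solvent_radius):
    """Return (protein, boundary, three_valued) lists for Cartesian voxel points."""
    protein, boundary, final = [], [], []
    for v in voxels:
        dists = [sum((a - b) ** 2 for a, b in zip(v, x)) ** 0.5 for x in atoms]
        in_core = any(d < r for d, r in zip(dists, vdw_radii))
        in_shell = any(r <= d < r + solvent_radius for d, r in zip(dists, vdw_radii))
        protein.append(in_core)
        boundary.append(in_shell)
        if in_core:
            final.append(0)     # protein core
        elif in_shell:
            final.append(-1)    # accessible surface
        else:
            final.append(1)     # bulk solvent
    return protein, boundary, final


line = [(float(x), 0.0, 0.0) for x in range(5)]
_, _, labels = phenix_style_mask(line, [(0.0, 0.0, 0.0)], [1.7], 1.1)  # [0, 0, -1, 1, 1]
```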

torchref.base.math_torch.find_solvent_voids(mask, periodic=True)

Identify void regions in a 3D boolean tensor using connected component analysis.

A void is defined as a connected region of False values (solvent). With periodic boundary conditions, voids can wrap around the edges of the array (like in a crystallographic unit cell). Without periodic boundaries, only enclosed voids are detected.

Parameters:
  • mask (torch.Tensor or numpy.ndarray) – Boolean tensor of shape (nx, ny, nz) where True indicates solid regions (e.g., protein) and False indicates empty regions (e.g., solvent). Can be either PyTorch tensor or NumPy array.

  • periodic (bool, optional) – If True, apply periodic boundary conditions (voids can wrap around edges). If False, only detect voids that are completely enclosed and don’t touch the boundaries. Default is True.

Returns:

Dictionary where keys are int volumes (number of voxels) of each void in the original array, and values are boolean masks (torch.Tensor or numpy.ndarray) of same shape as input with True only for that specific void region. Returns an empty dict if no voids are found.

Return type:

dict

Examples

import torch
from torchref.base.math_torch import find_solvent_voids
# Create a simple 5x5x5 grid with a void in the center
mask = torch.ones(5, 5, 5, dtype=torch.bool)
mask[2, 2, 2] = False  # Single void voxel
voids = find_solvent_voids(mask)
print(voids)

{1: tensor([[[False, False, …]], dtype=torch.bool)}

Notes

  • Uses scipy.ndimage.label for connected component analysis.

  • Connectivity is 26-connected (face, edge, and corner neighbors).

  • With periodic=True, the array is padded by wrapping to detect cross-boundary voids.

  • Performance is O(n) where n is the total number of voxels.

  • With periodic boundaries, large percolating voids are still detected.
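The documented function relies on scipy.ndimage.label with wrap padding. As a swapped-in alternative that makes periodic 26-connectivity explicit, a breadth-first search with modular neighbour indices can group the False voxels directly. This is a hypothetical plain-Python sketch; it returns a list of index sets rather than a volume-keyed dict.

```python
import itertools
from collections import deque


def periodic_components(mask):
    """26-connected components of False voxels with periodic wrap-around.

    mask is a nested list mask[i][j][k] of booleans (True = solid).
    Returns a list of sets of (i, j, k) indices, one set per void.
    """
    nx, ny, nz = len(mask), len(mask[0]), len(mask[0][0])
    offsets = [o for o in itertools.product((-1, 0, 1), repeat=3) if o != (0, 0, 0)]
    seen, components = set(), []
    for start in itertools.product(range(nx), range(ny), range(nz)):
        if start in seen or mask[start[0]][start[1]][start[2]]:
            continue
        seen.add(start)
        queue, comp = deque([start]), {start}
        while queue:
            i, j, k = queue.popleft()
            for di, dj, dk in offsets:
                # Modular indices make the cell wrap around on every face.
                nb = ((i + di) % nx, (j + dj) % ny, (k + dk) % nz)
                if nb not in seen and not mask[nb[0]][nb[1]][nb[2]]:
                    seen.add(nb)
                    comp.add(nb)
                    queue.append(nb)
        components.append(comp)
    return components


mask = [[[True] * 5 for _ in range(5)] for _ in range(5)]
mask[2][2][2] = False                    # enclosed single-voxel void
mask[0][0][0] = mask[4][4][4] = False    # opposite corners touch through the boundary
sizes = sorted(len(c) for c in periodic_components(mask))  # [1, 2]
```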

torchref.base.math_torch.fft(reciprocal_grid, volume=None)

Perform FFT to obtain real space electron density.

Uses fftn with norm="forward" to match the crystallographic sign convention directly, avoiding expensive flip/roll operations.

Crystallographic convention: ρ(r) = (1/V) Σ F(h) exp(-2πi h·r)

PyTorch fftn with norm="forward" gives:

fftn(x)[n] = (1/N) Σ_k x[k] exp(-2πi k·n/N)

When input structure factors F are correctly scaled (with V/N factor from ifft), we need to multiply by N/V to recover the original electron density:

ρ = fftn(F) * (N / V)

Parameters:
  • reciprocal_grid (torch.Tensor) – Reciprocal space grid of shape (Nx, Ny, Nz) or (B, Nx, Ny, Nz). Expected to contain correctly scaled structure factors (from ifft with volume).

  • volume (float, optional) – Unit cell volume in Å³. If provided, the result is scaled by N/V to give correctly normalized electron density.

Returns:

Real-valued tensor of electron density with same shape as input.

Return type:

torch.Tensor

torchref.base.math_torch.ifft(real_space_map, volume=None)

Perform inverse FFT to obtain reciprocal space structure factors.

Crystallographic convention: F(h) = Σ ρ(r) exp(+2πi h·r) * ΔV where ΔV = V_cell / N is the voxel volume.

PyTorch ifftn with norm="forward" gives the unnormalized DFT:

DFT[k] = Σ x[n] exp(+2πi k·n/N)

To obtain correctly scaled structure factors, we multiply by voxel volume:

F(h) = DFT(ρ) * (V_cell / N)

Parameters:
  • real_space_map (torch.Tensor) – Real space electron density map of shape (Nx, Ny, Nz) or (B, Nx, Ny, Nz).

  • volume (float, optional) – Unit cell volume in Å³. If provided, the result is scaled by the voxel volume (V_cell / N_total) to give correctly normalized structure factors.

Returns:

Complex-valued tensor of structure factors with same shape as input.

Return type:

torch.Tensor
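A one-dimensional check of the fft/ifft scaling conventions, using an explicit O(N²) DFT in place of torch.fft. The 3D transform is separable, so the same V/N and N/V factors apply with N the total number of grid points. density_to_sf and sf_to_density are hypothetical names. F(0) comes out as ΔV·Σρ, the number of electrons in the cell, and the round trip recovers ρ.

```python
import cmath


def dft(x, sign):
    """Unnormalized DFT: out[k] = sum_n x[n] * exp(sign * 2*pi*i * k*n / N)."""
    n_pts = len(x)
    return [sum(x[n] * cmath.exp(sign * 2j * cmath.pi * k * n / n_pts)
                for n in range(n_pts)) for k in range(n_pts)]


def density_to_sf(rho, volume):
    """F(h) = DFT_+(rho) * (V / N), the exp(+2 pi i h.r) convention."""
    return [f * volume / len(rho) for f in dft(rho, +1)]


def sf_to_density(F, volume):
    """rho(r) = (1/V) * DFT_-(F), i.e. fftn(F, norm='forward') * (N / V)."""
    return [f / volume for f in dft(F, -1)]


rho = [0.5, 2.0, 1.0, 0.25]
F = density_to_sf(rho, 40.0)      # F[0] = (40 / 4) * 3.75 = 37.5 electrons
back = sf_to_density(F, 40.0)     # recovers rho (with ~0 imaginary parts)
```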

torchref.base.math_torch.get_real_grid(cell=None, fractional_matrix=None, max_res=0.8, gridsize=None, device=None)

Generate a real space grid for electron density calculations.

Parameters:
  • cell (torch.Tensor, optional) – Unit cell parameters [a, b, c, alpha, beta, gamma].

  • fractional_matrix (torch.Tensor, optional) – Pre-computed fractionalization matrix.

  • max_res (float, optional) – Maximum resolution in Angstroms used for automatic grid sizing. Default is 0.8.

  • gridsize (torch.Tensor or array-like, optional) – Explicit grid dimensions [nx, ny, nz]. If None, calculated from max_res.

  • device (torch.device or str, optional) – Device for tensor placement. If None, inferred from fractional_matrix or cell (whichever tensor is provided); falls back to CPU.

Returns:

Real space grid of shape (nx, ny, nz, 3) containing Cartesian coordinates.

Return type:

torch.Tensor

torchref.base.math_torch.find_grid_size(cell, max_res)

Calculate grid size based on unit cell and resolution.

Parameters:
  • cell (torch.Tensor) – Unit cell parameters [a, b, c, alpha, beta, gamma].

  • max_res (float) – Maximum resolution in Angstroms.

Returns:

Grid dimensions [nx, ny, nz] as int32.

Return type:

torch.Tensor

torchref.base.math_torch.rotate_coords_torch(coords, phi, rho)

Rotate coordinates using phi and rho angles (PyTorch version).

Parameters:
  • coords (torch.Tensor) – Coordinates of shape (N, 3) to rotate.

  • phi (float) – Rotation angle phi in degrees.

  • rho (float) – Rotation angle rho in degrees.

Returns:

Rotated coordinates of shape (N, 3).

Return type:

torch.Tensor

torchref.base.math_torch.axis_angle_to_rotation_matrix(axis_angle)

Convert axis-angle representation to 3x3 rotation matrix.

Uses Rodrigues’ formula. Supports batched input and gradients.

Parameters:

axis_angle (torch.Tensor) – Axis-angle representation with shape (3,) or (N, 3). Direction is rotation axis, magnitude is angle in radians.

Returns:

Rotation matrix with shape (3, 3) or (N, 3, 3).

Return type:

torch.Tensor

torchref.base.math_torch.rotation_matrix_to_axis_angle(R)

Convert 3x3 rotation matrix to axis-angle representation.

Parameters:

R (torch.Tensor) – Rotation matrix with shape (3, 3) or (N, 3, 3).

Returns:

Axis-angle representation with shape (3,) or (N, 3).

Return type:

torch.Tensor
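Rodrigues' formula and its inverse in plain Python, as a sketch of the two conversions (hypothetical names; no batching or gradients). The inverse shown is only valid for angles strictly between 0 and π; near π the axis has to be recovered from the symmetric part of R instead.

```python
import math


def axis_angle_to_matrix(v):
    """Rodrigues: R = I + sin(t) K + (1 - cos(t)) K^2, t = |v|, K = skew(v / t)."""
    t = math.sqrt(sum(x * x for x in v))
    if t < 1e-12:
        return [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
    x, y, z = (c / t for c in v)
    K = [[0.0, -z, y], [z, 0.0, -x], [-y, x, 0.0]]
    K2 = [[sum(K[i][m] * K[m][j] for m in range(3)) for j in range(3)] for i in range(3)]
    return [[(1.0 if i == j else 0.0) + math.sin(t) * K[i][j] + (1 - math.cos(t)) * K2[i][j]
             for j in range(3)] for i in range(3)]


def matrix_to_axis_angle(R):
    """Inverse map, valid for rotation angles strictly between 0 and pi."""
    cos_t = (R[0][0] + R[1][1] + R[2][2] - 1.0) / 2.0
    t = math.acos(max(-1.0, min(1.0, cos_t)))
    s = 2.0 * math.sin(t)
    axis = [(R[2][1] - R[1][2]) / s, (R[0][2] - R[2][0]) / s, (R[1][0] - R[0][1]) / s]
    return [t * a for a in axis]


Rz = axis_angle_to_matrix([0.0, 0.0, math.pi / 2])   # maps x onto y
v = [0.3, -0.5, 0.8]
back = matrix_to_axis_angle(axis_angle_to_matrix(v))  # approximately v
```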

torchref.base.math_torch.quaternion_to_rotation_matrix(q)

Convert quaternion to rotation matrix.

Parameters:

q (torch.Tensor) – Quaternion with shape (4,) or (N, 4). Format: [w, x, y, z].

Returns:

Rotation matrix with shape (3, 3) or (N, 3, 3).

Return type:

torch.Tensor
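The standard unit-quaternion formula for the [w, x, y, z] format, as a plain-Python sketch with a hypothetical name. Normalizing the input first is an assumption, not documented behaviour.

```python
import math


def quat_to_matrix(q):
    """[w, x, y, z] quaternion (normalized first) to a 3x3 rotation matrix."""
    n = math.sqrt(sum(c * c for c in q))
    w, x, y, z = (c / n for c in q)
    return [
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ]


# 90 degrees about z: half-angle pi/4 in the quaternion.
R90 = quat_to_matrix([math.cos(math.pi / 4), 0.0, 0.0, math.sin(math.pi / 4)])
I3 = quat_to_matrix([2.0, 0.0, 0.0, 0.0])   # normalizes to the identity
```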

torchref.base.math_torch.random_rotation_uniform(n=1, device=device(type='cpu'), dtype=torch.float32)

Generate uniform random rotations over SO(3).

Uses Shoemake’s quaternion-based uniform sampling.

Parameters:
  • n (int, optional) – Number of rotations to generate. Default is 1.

  • device (str, optional) – Device for output tensor. Defaults to the configured device.current.

  • dtype (torch.dtype, optional) – Data type for output tensor. Default is dtypes.float.

Returns:

Rotation matrices with shape (n, 3, 3) or (3, 3) if n=1.

Return type:

torch.Tensor
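Shoemake's method draws three uniform numbers u1, u2, u3 in [0, 1) and maps them to a unit quaternion that is uniformly distributed on the 3-sphere. The resulting rotations are uniform (Haar-distributed) over SO(3). A self-contained sketch follows. The name shoemake_rotations is hypothetical. The optional generator argument is an addition for reproducibility that the documented function does not list, and the device argument is left out for brevity.

```python
import math
import torch

def shoemake_rotations(n=1, generator=None, dtype=torch.float32):
    """Sketch: n uniform random rotation matrices via Shoemake's quaternion sampling."""
    u1, u2, u3 = torch.rand(3, n, generator=generator, dtype=dtype)
    a, b = torch.sqrt(1 - u1), torch.sqrt(u1)
    # w^2 + x^2 = 1 - u1 and y^2 + z^2 = u1, so [w, x, y, z] is a unit quaternion
    w = a * torch.sin(2 * math.pi * u2)
    x = a * torch.cos(2 * math.pi * u2)
    y = b * torch.sin(2 * math.pi * u3)
    z = b * torch.cos(2 * math.pi * u3)
    # Standard unit-quaternion -> rotation-matrix conversion
    R = torch.stack([
        1 - 2 * (y * y + z * z), 2 * (x * y - w * z),     2 * (x * z + w * y),
        2 * (x * y + w * z),     1 - 2 * (x * x + z * z), 2 * (y * z - w * x),
        2 * (x * z - w * y),     2 * (y * z + w * x),     1 - 2 * (x * x + y * y),
    ], dim=-1).reshape(n, 3, 3)
    return R[0] if n == 1 else R   # mirrors the documented (3, 3) return for n=1
```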

torchref.base.math_torch.superpose_vectors_robust_torch(ref_coords, mov_coords, weights=None, max_iterations=10)[source]

Perform weighted superposition of two coordinate sets using SVD (PyTorch version).

Parameters:
  • ref_coords (torch.Tensor) – Reference coordinates of shape (N, 3).

  • mov_coords (torch.Tensor) – Mobile coordinates of shape (N, 3) to be superposed onto reference.

  • weights (torch.Tensor, optional) – Per-atom weights of shape (N, 1). Defaults to uniform weights.

  • max_iterations (int, optional) – Maximum number of iterations for refinement. Default is 10.

Returns:

Transformation matrix of shape (3, 4), i.e. the upper three rows [R | t] of the 4x4 homogeneous transform.

Return type:

torch.Tensor
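The core of an SVD-based weighted superposition is the weighted Kabsch algorithm. The sketch below performs a single weighted pass and returns [R | t] as a (3, 4) matrix that maps the mobile coordinates onto the reference. It does not reproduce the iterative refinement implied by max_iterations. The name weighted_kabsch is hypothetical.

```python
import torch

def weighted_kabsch(ref: torch.Tensor, mov: torch.Tensor, weights=None) -> torch.Tensor:
    """Sketch: single-pass weighted Kabsch fit; returns a (3, 4) [R | t] mapping mov onto ref."""
    if weights is None:
        weights = torch.ones(ref.shape[0], 1, dtype=ref.dtype, device=ref.device)
    w = weights / weights.sum()                      # (N, 1), normalized weights
    ref_c = (w * ref).sum(0)                         # weighted centroids
    mov_c = (w * mov).sum(0)
    P, Q = mov - mov_c, ref - ref_c
    H = P.T @ (w * Q)                                # (3, 3) weighted covariance
    U, S, Vh = torch.linalg.svd(H)
    d = torch.sign(torch.linalg.det(Vh.T @ U.T))     # guard against improper rotations (reflections)
    D = torch.diag(torch.stack([torch.ones_like(d), torch.ones_like(d), d]))
    R = Vh.T @ D @ U.T
    t = ref_c - R @ mov_c
    return torch.cat([R, t[:, None]], dim=1)
```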

torchref.base.math_torch.align_torch(xyz1, xyz2, idx_to_move=None)[source]

Align two coordinate sets using superposition (PyTorch version).

Parameters:
  • xyz1 (torch.Tensor) – Target coordinates of shape (N, 3).

  • xyz2 (torch.Tensor) – Coordinates to be aligned of shape (N, 3).

  • idx_to_move (torch.Tensor, optional) – Indices of atoms to use for alignment. If None, uses all atoms.

Returns:

Aligned coordinates of shape (N, 3).

Return type:

torch.Tensor

torchref.base.math_torch.get_alignement_matrix(xyz1, xyz2, idx_to_move=None)[source]

Get the alignment transformation matrix between two coordinate sets.

Parameters:
  • xyz1 (torch.Tensor) – Target coordinates of shape (N, 3).

  • xyz2 (torch.Tensor) – Coordinates to be aligned of shape (N, 3).

  • idx_to_move (torch.Tensor, optional) – Indices of atoms to use for alignment. If None, uses all atoms.

Returns:

Transformation matrix of shape (3, 4).

Return type:

torch.Tensor

torchref.base.math_torch.apply_transformation(points, transformation_matrix)[source]

Apply a (3, 4) or (4, 4) homogeneous transformation matrix to 3D points (PyTorch version).

Parameters:
  • points (torch.Tensor) – 3D points of shape (N, 3).

  • transformation_matrix (torch.Tensor) – Transformation matrix of shape (3, 4) or (4, 4).

Returns:

Transformed 3D points of shape (N, 3).

Return type:

torch.Tensor
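For row-vector points, applying such a matrix amounts to p' = R p + t. Here R is the upper-left 3x3 block and t is the last column. The minimal sketch below assumes an affine matrix whose fourth row, if present, is [0, 0, 0, 1] and can therefore be ignored. The name apply_affine is hypothetical.

```python
import torch

def apply_affine(points: torch.Tensor, T: torch.Tensor) -> torch.Tensor:
    """Sketch: apply a (3, 4) or (4, 4) homogeneous transform to (N, 3) points."""
    R, t = T[:3, :3], T[:3, 3]          # linear (rotation) part and translation
    return points @ R.T + t             # row-vector form of R @ p + t
```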

torchref.base.math_torch.get_rfactor_torch(F_obs, F_calc)[source]

Calculate R-factor between observed and calculated structure factors (PyTorch version).

Parameters:
  • F_obs (torch.Tensor) – Observed structure factor amplitudes.

  • F_calc (torch.Tensor) – Calculated structure factor amplitudes.

Returns:

R-factor value.

Return type:

torch.Tensor
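The conventional crystallographic R-factor is R = sum(| |F_obs| - |F_calc| |) / sum(|F_obs|). The sketch below computes exactly that. The documentation here does not state whether get_rfactor_torch or rfactor scales F_calc to F_obs before comparing, so this sketch applies no scaling. The name r_factor is hypothetical.

```python
import torch

def r_factor(F_obs: torch.Tensor, F_calc: torch.Tensor) -> torch.Tensor:
    """Sketch of the conventional R = sum(| |F_obs| - |F_calc| |) / sum(|F_obs|), unscaled."""
    F_obs, F_calc = F_obs.abs(), F_calc.abs()   # abs() also turns complex F_calc into amplitudes
    return (F_obs - F_calc).abs().sum() / F_obs.sum()
```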

torchref.base.math_torch.rfactor(F_obs, F_calc)[source]

Calculate R-factor between observed and calculated structure factors.

Parameters:
  • F_obs (torch.Tensor) – Observed structure factor amplitudes of shape (N,).

  • F_calc (torch.Tensor) – Calculated structure factor amplitudes of shape (N,).

Returns:

R-factor value.

Return type:

float

torchref.base.math_torch.get_rfactors(F_obs, F_calc, rfree)[source]

Get R-factors for working and test sets.

Parameters:
  • F_obs (torch.Tensor) – Observed structure factor amplitudes of shape (N,).

  • F_calc (torch.Tensor) – Calculated structure factor amplitudes of shape (N,).

  • rfree (torch.Tensor) – Boolean mask of shape (N,) that splits the reflections into the working and test (R-free) sets: 1 (True) marks the working set, 0 (False) the test set.

Returns:

(r_work, r_test) where r_work is the R-factor for the working set and r_test is the R-factor for the test set.

Return type:

tuple
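Under the mask convention above (1 = working set, 0 = test set), the two values are the same R-factor evaluated on complementary subsets. The minimal sketch below applies no amplitude scaling. The name r_work_free is hypothetical.

```python
import torch

def r_work_free(F_obs, F_calc, rfree):
    """Sketch: (r_work, r_test) from a mask where True/1 = working set, False/0 = test set."""
    def r(fo, fc):
        return (fo - fc).abs().sum() / fo.sum()

    work = rfree.bool()
    return r(F_obs[work], F_calc[work]), r(F_obs[~work], F_calc[~work])
```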

torchref.base.math_torch.bin_wise_rfactors(F_obs, F_calc, rfree, bins)[source]

Calculate bin-wise R-factors between observed and calculated structure factors.

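Parameters:
  • F_obs (torch.Tensor) – Observed structure factor amplitudes of shape (N,).

  • F_calc (torch.Tensor) – Calculated structure factor amplitudes of shape (N,).

  • rfree (torch.Tensor) – Boolean mask of shape (N,); 1 is Working set, 0 is Test set.

  • bins – Assignment of the reflections to bins; one working-set and one test-set R-factor are returned per bin.

Returns: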

  • r_work_bins (torch.Tensor) – R-factors for working set (per bin).

  • r_test_bins (torch.Tensor) – R-factors for test set (per bin).

Return type:

tuple
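Per-bin R-factors can be computed without a Python loop over bins by accumulating numerators and denominators with index_add_. This sketch assumes that bins supplies an integer bin index (0 to n_bins - 1) for every reflection, which the documentation above does not confirm. The names binwise_r and n_bins are hypothetical.

```python
import torch

def binwise_r(F_obs, F_calc, rfree, bin_idx, n_bins):
    """Sketch: per-bin (r_work, r_test); bin_idx holds one integer bin per reflection."""
    diff = (F_obs - F_calc).abs()
    work = rfree.bool()
    bin_idx = bin_idx.long()

    def per_bin(mask):
        num = torch.zeros(n_bins, dtype=diff.dtype, device=diff.device)
        den = torch.zeros(n_bins, dtype=diff.dtype, device=diff.device)
        num.index_add_(0, bin_idx[mask], diff[mask])     # sum of |F_obs - F_calc| per bin
        den.index_add_(0, bin_idx[mask], F_obs[mask])    # sum of F_obs per bin
        return num / den                                 # NaN for bins empty in this set

    return per_bin(work), per_bin(~work)
```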

torchref.base.math_torch.calc_outliers(F_obs, F_calc, z)[source]

Identify outlier reflections based on deviation from expected values (PyTorch version).

Parameters:
  • F_obs (torch.Tensor) – Observed structure factor amplitudes.

  • F_calc (torch.Tensor) – Calculated structure factor amplitudes.

  • z (float) – Number of standard deviations for outlier threshold.

Returns:

Boolean mask where True indicates outlier reflections.

Return type:

torch.Tensor

torchref.base.math_torch.nll_xray(F_obs, F_calc, sigma_F_obs)[source]

Compute the X-ray negative log-likelihood assuming a Gaussian distribution.

Parameters:
  • F_obs (torch.Tensor or MaskedTensor) – Observed structure factor amplitudes.

  • F_calc (torch.Tensor) – Calculated structure factors (complex).

  • sigma_F_obs (torch.Tensor or MaskedTensor) – Standard deviations of observed amplitudes.

Returns:

Mean negative log-likelihood.

Return type:

torch.Tensor

torchref.base.math_torch.nll_xray_sum(F_obs, F_calc, sigma_F_obs)[source]

Compute the summed X-ray negative log-likelihood assuming a Gaussian distribution.

Parameters:
  • F_obs (torch.Tensor or MaskedTensor) – Observed structure factor amplitudes.

  • F_calc (torch.Tensor) – Calculated structure factors (complex).

  • sigma_F_obs (torch.Tensor or MaskedTensor) – Standard deviations of observed amplitudes.

Returns:

Sum of negative log-likelihood values.

Return type:

torch.Tensor
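Suppose the Gaussian model compares F_obs with the calculated amplitude |F_calc|. Then each reflection contributes 0.5*((F_obs - |F_calc|)/sigma_F_obs)^2 + log(sigma_F_obs) + 0.5*log(2*pi). nll_xray averages these terms and nll_xray_sum adds them. The sketch below handles only plain tensors (no MaskedTensor support). It keeps the constant terms, which torchref may or may not include. The name gaussian_nll is hypothetical.

```python
import math
import torch

def gaussian_nll(F_obs, F_calc, sigma_F_obs, reduction="mean"):
    """Sketch: Gaussian NLL of F_obs around |F_calc| (complex F_calc allowed)."""
    resid = (F_obs - F_calc.abs()) / sigma_F_obs
    nll = 0.5 * resid ** 2 + torch.log(sigma_F_obs) + 0.5 * math.log(2 * math.pi)
    return nll.mean() if reduction == "mean" else nll.sum()
```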

torchref.base.math_torch.nll_xray_lognormal(F_obs, F_calc, sigma_F_obs, eps=1e-10)[source]

Compute the X-ray negative log-likelihood assuming a lognormal distribution.

This is a more realistic model for structure factor amplitudes, which cannot be negative. For a lognormal distribution LogNormal(mu, sigma^2), the NLL of a value x is:

NLL = 0.5*(log(x) - mu)^2/sigma^2 + log(x) + log(sigma) + 0.5*log(2*pi)

where x is the calculated amplitude |F_calc|, and mu and sigma are derived from F_obs and sigma_F_obs (written F and sigma_F below) using:

  • sigma = sqrt(log(1 + (sigma_F/F)^2))

  • mu = log(F) - sigma^2/2

Parameters:
  • F_obs (torch.Tensor) – Observed structure factor amplitudes.

  • F_calc (torch.Tensor) – Calculated structure factors (complex).

  • sigma_F_obs (torch.Tensor) – Standard deviations of observed amplitudes.

  • eps (float, optional) – Small value to avoid numerical issues. Default is 1e-10.

Returns:

Mean negative log-likelihood.

Return type:

torch.Tensor
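The formula above translates directly into tensor code. In this sketch the NLL is evaluated at x = |F_calc|, with mu and sigma taken from F_obs and sigma_F_obs. eps is applied as a floor before the logarithms and the division. Both choices are assumptions about how torchref applies them. The name lognormal_nll is hypothetical.

```python
import math
import torch

def lognormal_nll(F_obs, F_calc, sigma_F_obs, eps=1e-10):
    """Sketch: mean lognormal NLL of |F_calc| given moment-matched (mu, sigma) from F_obs."""
    F = F_obs.clamp_min(eps)
    x = F_calc.abs().clamp_min(eps)
    sigma = torch.sqrt(torch.log1p((sigma_F_obs / F) ** 2))   # sqrt(log(1 + (sigma_F/F)^2))
    mu = torch.log(F) - 0.5 * sigma ** 2                      # log(F) - sigma^2/2
    log_x = torch.log(x)
    nll = (0.5 * (log_x - mu) ** 2 / sigma ** 2 + log_x
           + torch.log(sigma) + 0.5 * math.log(2 * math.pi))
    return nll.mean()
```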

torchref.base.math_torch.log_loss(F_obs, F_calc, sigma_F_obs)[source]

Compute log-space loss between observed and calculated structure factors.

Parameters:
  • F_obs (torch.Tensor) – Observed structure factor amplitudes.

  • F_calc (torch.Tensor) – Calculated structure factors (complex).

  • sigma_F_obs (torch.Tensor) – Standard deviations of observed amplitudes (unused).

Returns:

Mean absolute difference in log space.

Return type:

torch.Tensor
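A mean absolute difference in log space penalizes ratios rather than absolute differences. A 10% mismatch therefore costs the same for weak and strong reflections. The minimal sketch below follows that definition. The eps floor and the name log_space_l1 are assumptions, not taken from the documentation.

```python
import torch

def log_space_l1(F_obs, F_calc, eps=1e-10):
    """Sketch: mean |log(F_obs) - log(|F_calc|)|; sigma_F_obs is not needed."""
    log_obs = torch.log(F_obs.clamp_min(eps))
    log_calc = torch.log(F_calc.abs().clamp_min(eps))   # complex F_calc -> amplitude
    return (log_obs - log_calc).abs().mean()
```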

torchref.base.math_torch.estimate_sigma_I(I)[source]

Estimate standard deviation of intensities.

Separates positive and negative values for robust estimation.

Parameters:

I (torch.Tensor) – Intensity values.

Returns:

Estimated standard deviations.

Return type:

torch.Tensor

torchref.base.math_torch.estimate_sigma_F(F)[source]

Estimate standard deviation of structure factor amplitudes.

Parameters:

F (torch.Tensor) – Structure factor amplitudes.

Returns:

Estimated standard deviations.

Return type:

torch.Tensor

torchref.base.math_torch.gaussian_to_lognormal_sigma(F, sigma_F, eps=1e-10)[source]

Approximate the sigma parameter of a lognormal distribution from Gaussian statistics.

If F is assumed to come from a lognormal distribution X ~ LogNormal(mu, sigma^2), then:

  • Mean: E[X] = F

  • Standard deviation: sqrt(Var[X]) = sigma_F

For a lognormal distribution:

  • E[X] = exp(mu + sigma^2/2)

  • Var[X] = exp(2*mu + sigma^2)*(exp(sigma^2) - 1)

From these it follows that:

  • CV^2 = Var[X]/E[X]^2 = exp(sigma^2) - 1

  • sigma = sqrt(log(1 + CV^2))

where CV = sigma_F/F is the coefficient of variation.

Parameters:
  • F (torch.Tensor) – Structure factor amplitudes (mean of the distribution).

  • sigma_F (torch.Tensor) – Standard deviations.

  • eps (float, optional) – Small value to avoid division by zero. Default is 1e-10.

Returns:

Sigma parameter for lognormal distribution.

Return type:

torch.Tensor

torchref.base.math_torch.gaussian_to_lognormal_mu(F, sigma_lognormal, eps=1e-10)[source]

Calculate the mu parameter of a lognormal distribution given F and sigma.

For a lognormal distribution X ~ LogNormal(mu, sigma^2), E[X] = exp(mu + sigma^2/2). Solving for mu gives mu = log(E[X]) - sigma^2/2, with E[X] = F.

Parameters:
  • F (torch.Tensor) – Structure factor amplitudes (mean of the distribution).

  • sigma_lognormal (torch.Tensor) – Sigma parameter from lognormal distribution.

  • eps (float, optional) – Small value to avoid log of zero. Default is 1e-10.

Returns:

Mu parameter for lognormal distribution.

Return type:

torch.Tensor
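Together, gaussian_to_lognormal_sigma and gaussian_to_lognormal_mu amount to moment matching. The sketch below computes sigma and mu from F and sigma_F. It then checks that the resulting lognormal has mean F and standard deviation sigma_F. The name lognormal_params is hypothetical, and applying eps as a floor on F is an assumption.

```python
import torch

def lognormal_params(F, sigma_F, eps=1e-10):
    """Sketch: moment-match LogNormal(mu, sigma^2) to mean F and standard deviation sigma_F."""
    F = F.clamp_min(eps)
    cv2 = (sigma_F / F) ** 2                  # squared coefficient of variation
    sigma = torch.sqrt(torch.log1p(cv2))      # sigma = sqrt(log(1 + CV^2))
    mu = torch.log(F) - 0.5 * sigma ** 2      # mu = log(E[X]) - sigma^2/2
    return mu, sigma


# Round trip: recover the mean and standard deviation from (mu, sigma)
F = torch.tensor([10.0, 5.0], dtype=torch.float64)
sigma_F = torch.tensor([2.0, 1.0], dtype=torch.float64)
mu, sigma = lognormal_params(F, sigma_F)
mean = torch.exp(mu + 0.5 * sigma ** 2)
std = torch.sqrt(torch.exp(2 * mu + sigma ** 2) * torch.expm1(sigma ** 2))
```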

torchref.base.math_torch.U_to_matrix(U)[source]

Convert anisotropic displacement parameters from 6-component vector to 3x3 matrix.

Parameters:

U (torch.Tensor) – Anisotropic displacement parameters in the order [u11, u22, u33, u12, u13, u23] of shape (…, 6).

Returns:

Anisotropic displacement parameter matrices of shape (…, 3, 3).

Return type:

torch.Tensor
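The displacement tensor is symmetric, so the six components fill the 3x3 matrix with u12, u13 and u23 mirrored across the diagonal. The minimal sketch below uses the documented [u11, u22, u33, u12, u13, u23] order. The name u6_to_matrix is hypothetical.

```python
import torch

def u6_to_matrix(U: torch.Tensor) -> torch.Tensor:
    """Sketch: [u11, u22, u33, u12, u13, u23] of shape (..., 6) -> symmetric (..., 3, 3)."""
    u11, u22, u33, u12, u13, u23 = U.unbind(-1)
    rows = [
        torch.stack([u11, u12, u13], -1),
        torch.stack([u12, u22, u23], -1),
        torch.stack([u13, u23, u33], -1),
    ]
    return torch.stack(rows, -2)
```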

torchref.base.math_torch.deterministic_tensor_digest(t, n_chunks=16)[source]

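Compute a deterministic digest vector for a tensor directly on the tensor's device (e.g. the GPU).

This function is deterministic across devices and runs, fully vectorized with no Python loops, and suitable for large GPU tensors. It summarizes each chunk with simple mean/std statistics, which is fully deterministic. The digest therefore reacts to a change in any value and to reordering across chunks, but a permutation within a single chunk can leave it unchanged.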

Parameters:
  • t (torch.Tensor) – Input tensor to compute digest for.

  • n_chunks (int, optional) – Number of chunks to divide the tensor into. Default is 16.

Returns:

Digest vector of length n_chunks.

Return type:

torch.Tensor
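A per-chunk mean/std digest can be computed with a single reshape, as in the sketch below. Several details are assumptions of this sketch rather than documented torchref behavior: the zero padding, the float64 accumulation, and adding the two statistics to give one value per chunk. The name chunk_digest is hypothetical.

```python
import torch

def chunk_digest(t: torch.Tensor, n_chunks: int = 16) -> torch.Tensor:
    """Sketch: one value per chunk (mean + std), computed on t's own device."""
    flat = t.detach().reshape(-1).to(torch.float64)
    pad = (-flat.numel()) % n_chunks                      # zeros needed to fill the last chunk
    flat = torch.cat([flat, flat.new_zeros(pad)])
    chunks = flat.view(n_chunks, -1)                      # (n_chunks, chunk_len), no Python loop
    mean = chunks.mean(dim=1)
    std = ((chunks - mean[:, None]) ** 2).mean(dim=1).sqrt()
    return mean + std                                     # arbitrary combination for this sketch
```

Because only per-chunk statistics survive, swapping two values inside the same chunk leaves this digest unchanged.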

torchref.base.math_torch.hash_tensors(tensors)[source]

Compute a hash of multiple tensors for caching purposes.

Deprecated: use (tensor.data_ptr(), tensor._version, tensor.numel()) tuples for lightweight fingerprinting instead. hash_tensors copies the data to the CPU and computes a SHA-1 hash, which is expensive.

Parameters:

tensors (list of torch.Tensor or None) – List of tensors to hash. None values are handled.

Returns:

SHA-1 hash of the tensor contents.

Return type:

str
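The lighter-weight alternative recommended in the deprecation note can be sketched as follows. Such a key tracks buffer identity and in-place modification rather than content, because tensor._version increments on in-place operations. Two distinct tensors holding equal values therefore get different keys. The name fingerprint is hypothetical.

```python
import torch

def fingerprint(tensors):
    """Sketch: cache key from (data_ptr, _version, numel) tuples; None entries stay None."""
    return tuple(
        None if t is None else (t.data_ptr(), t._version, t.numel())
        for t in tensors
    )


a = torch.zeros(3)
key = fingerprint([a, None])
a.add_(1.0)   # in-place update bumps a._version
assert fingerprint([a, None]) != key
```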

torchref.base.math_torch.french_wilson_conversion(Iobs, sigma_I=None)[source]

Convert intensities to structure factor amplitudes using French-Wilson method.

Also converts standard deviations.

Parameters:
  • Iobs (torch.Tensor) – Observed intensity values.

  • sigma_I (torch.Tensor, optional) – Estimated standard deviations of intensities.

Returns:

  • F (torch.Tensor) – Structure factor amplitudes.

  • sigma_F (torch.Tensor) – Standard deviations of structure factor amplitudes.