"""
PyTorch implementations of mathematical functions for crystallography.
.. deprecated::
This module is deprecated. Please import from domain-specific submodules:
- ``torchref.base.coordinates`` - Coordinate transformations
- ``torchref.base.reciprocal`` - Reciprocal space calculations
- ``torchref.base.direct_summation`` - Structure factor calculations
- ``torchref.base.electron_density`` - Electron density map building
- ``torchref.base.fourier`` - FFT operations
- ``torchref.base.scattering`` - Atomic scattering factors
- ``torchref.base.alignment`` - Coordinate alignment
- ``torchref.base.metrics`` - R-factors and loss functions
- ``torchref.base.kernels`` - Optimized kernels
This module is maintained for backward compatibility and re-exports the public
functions of the submodules above (nothing is currently re-exported from
``scattering`` or ``kernels``).
Example (new style - recommended)::
from torchref.base.coordinates import cartesian_to_fractional_torch
from torchref.base.metrics import get_rfactor_torch
Example (old style - still works)::
from torchref.base.math_torch import cartesian_to_fractional_torch
"""
import hashlib
import torch
# =============================================================================
# Re-exports from coordinates submodule
# =============================================================================
from torchref.base.coordinates import (
cartesian_to_fractional_torch,
fractional_to_cartesian_torch,
get_fractional_matrix,
get_inv_fractional_matrix_torch,
smallest_diff,
smallest_diff_aniso,
)
# =============================================================================
# Re-exports from reciprocal submodule
# =============================================================================
from torchref.base.reciprocal import (
reciprocal_basis_matrix,
get_scattering_vectors,
get_d_spacing,
place_on_grid,
extract_structure_factor_from_grid,
apply_translation_phase,
interpolate_structure_factor_from_grid,
interpolate_complex_from_grid,
trilinear_interpolate_patterson,
)
# =============================================================================
# Re-exports from direct_summation submodule
# =============================================================================
from torchref.base.direct_summation import (
iso_structure_factor_torched,
iso_structure_factor_torched_no_complex,
aniso_structure_factor_torched,
aniso_structure_factor_torched_no_complex,
anharmonic_correction,
anharmonic_correction_no_complex,
core_deformation,
multiplication_quasi_complex_tensor,
)
# =============================================================================
# Re-exports from electron_density submodule
# =============================================================================
from torchref.base.electron_density import (
vectorized_add_to_map,
vectorized_add_to_map_aniso,
scatter_add_nd,
scatter_add_nd_super_slow,
find_relevant_voxels,
excise_angstrom_radius_around_coord,
add_to_solvent_mask,
add_to_phenix_mask,
find_solvent_voids,
)
# =============================================================================
# Re-exports from fourier submodule
# =============================================================================
from torchref.base.fourier import (
fft,
ifft,
get_real_grid,
find_grid_size,
)
# =============================================================================
# Re-exports from alignment submodule
# =============================================================================
from torchref.base.alignment import (
rotate_coords_torch,
axis_angle_to_rotation_matrix,
rotation_matrix_to_axis_angle,
quaternion_to_rotation_matrix,
random_rotation_uniform,
superpose_vectors_robust_torch,
align_torch,
get_alignement_matrix,
apply_transformation,
)
# =============================================================================
# Re-exports from metrics submodule
# =============================================================================
from torchref.base.metrics import (
get_rfactor_torch,
rfactor,
get_rfactors,
bin_wise_rfactors,
calc_outliers,
nll_xray,
nll_xray_sum,
nll_xray_lognormal,
log_loss,
estimate_sigma_I,
estimate_sigma_F,
gaussian_to_lognormal_sigma,
gaussian_to_lognormal_mu,
)
# =============================================================================
# Utility functions (kept here as they don't fit a specific domain)
# =============================================================================
[docs]
def U_to_matrix(U: torch.Tensor) -> torch.Tensor:
    """
    Expand packed anisotropic displacement parameters into symmetric 3x3 matrices.

    Parameters
    ----------
    U : torch.Tensor
        Packed anisotropic displacement parameters of shape (..., 6), ordered
        as [u11, u22, u33, u12, u13, u23].

    Returns
    -------
    torch.Tensor
        Symmetric matrices of shape (..., 3, 3), where
        ``M[..., i, j] == M[..., j, i]``.
    """
    # Index of every matrix element inside the packed 6-vector:
    #   [[u11, u12, u13],
    #    [u12, u22, u23],
    #    [u13, u23, u33]]
    layout = torch.tensor(
        [[0, 3, 4], [3, 1, 5], [4, 5, 2]], dtype=torch.long, device=U.device
    )
    # Advanced indexing gathers along the last axis and is differentiable;
    # gradients of entries used twice (the off-diagonals) are accumulated.
    return U[..., layout]
[docs]
def deterministic_tensor_digest(t: torch.Tensor, n_chunks: int = 16) -> torch.Tensor:
    """
    Compute a small fingerprint vector for a tensor on its own device.

    The flattened tensor is zero-padded to a multiple of ``n_chunks`` and split
    into ``n_chunks`` contiguous chunks. Each chunk contributes
    ``mean + 0.61803398875 * std`` (std is taken as 0 when a chunk holds a
    single value). Only vectorized reductions are used, with no Python loops
    over elements.

    Notes
    -----
    * This is a summary, not a content hash. Permuting values *within* a chunk,
      or appending zeros that fit inside the padding, leaves it unchanged.
    * Floating-point reductions on different devices are not guaranteed to use
      the same summation order, so CPU and GPU digests may differ in the last
      bits.
    * Integer and bool tensors are cast to float32. Complex tensors are
      digested through their interleaved real/imaginary parts, so both
      components affect the result.

    Parameters
    ----------
    t : torch.Tensor
        Input tensor to compute the digest for. It is detached, so autograd is
        not involved.
    n_chunks : int, optional
        Number of chunks (and length of the digest). Must be >= 1. Default 16.

    Returns
    -------
    torch.Tensor
        Digest vector of length ``n_chunks``, on the same device as ``t``.

    Raises
    ------
    ValueError
        If ``n_chunks`` is smaller than 1.
    """
    if n_chunks < 1:
        raise ValueError(f"n_chunks must be a positive integer, got {n_chunks}")

    flat = t.detach().reshape(-1)
    if flat.is_complex():
        # A plain .float() cast would drop the imaginary part, so tensors that
        # differ only in Im would collide. view_as_real rejects tensors with a
        # pending conjugate bit, hence resolve_conj().
        flat = torch.view_as_real(flat.resolve_conj()).reshape(-1)
    elif not torch.is_floating_point(flat):
        flat = flat.float()

    # ceil(n / n_chunks), but at least 1. An empty tensor then yields n_chunks
    # zero-valued chunks instead of NaN means over empty rows.
    n = flat.numel()
    chunk_size = max(1, -(-n // n_chunks))
    padded_size = chunk_size * n_chunks
    if padded_size > n:
        flat = torch.nn.functional.pad(flat, (0, padded_size - n))

    chunks = flat.reshape(n_chunks, chunk_size)
    digest_mean = chunks.mean(dim=1)
    # Unbiased std of a single element is NaN (division by n - 1 == 0).
    if chunk_size > 1:
        digest_std = chunks.std(dim=1)
    else:
        digest_std = torch.zeros_like(digest_mean)

    # Fold mean and std into one length-n_chunks vector with a fixed weight
    # (the golden-ratio conjugate) rather than doubling the digest size.
    return digest_mean + 0.61803398875 * digest_std
[docs]
def hash_tensors(tensors) -> str:
    """
    Compute a SHA-1 cache key for a sequence of tensors.

    .. deprecated::
        Use ``(tensor.data_ptr(), tensor._version, tensor.numel())`` tuples
        for lightweight fingerprinting instead. ``hash_tensors`` runs a full
        reduction pass over every tensor and copies each digest to the CPU
        before hashing it with SHA-1, which is comparatively expensive.

    Each tensor contributes the raw bytes of its
    ``deterministic_tensor_digest``, its shape and its dtype. ``None`` entries
    contribute a fixed marker so that positions are preserved. Because the
    digest is a per-chunk summary, distinct tensors can collide (e.g. values
    permuted within a chunk). Treat the result as a heuristic cache key, not a
    content hash.

    Parameters
    ----------
    tensors : iterable of torch.Tensor or None
        Tensors to hash. ``None`` entries are allowed.

    Returns
    -------
    str
        Hex-encoded SHA-1 digest.
    """
    import warnings

    warnings.warn(
        "hash_tensors is deprecated. Use (tensor.data_ptr(), tensor._version, "
        "tensor.numel()) tuples for lightweight fingerprinting instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    h = hashlib.sha1()
    for t in tensors:
        if t is None:
            h.update(b"<None>")
            continue
        # Only the small digest vector is moved to the CPU.
        digest = deterministic_tensor_digest(t).cpu()
        # NumPy has no bfloat16 (or other exotic float) dtype, so .numpy() would
        # raise TypeError. The upcast is exact, and float16/32/64 digests are
        # hashed unchanged so existing keys stay stable.
        if digest.dtype not in (torch.float16, torch.float32, torch.float64):
            digest = digest.float()
        h.update(digest.numpy().tobytes())
        h.update(str(t.shape).encode())
        h.update(str(t.dtype).encode())
    return h.hexdigest()
[docs]
def french_wilson_conversion(Iobs, sigma_I=None):
    """
    Convert intensities to structure factor amplitudes with a French-Wilson-style
    heuristic, and propagate the standard deviations.

    Reflections with ``I > 3 * sigma_I`` are treated as strong and get
    ``F = sqrt(I)``. All other reflections get a prior-corrected estimate
    ``sqrt(max(I + sigma**2 / mean_I, 0))``, where ``mean_I`` is the mean of
    the non-NaN intensities clamped at zero. For negative intensities this is
    blended with ``sqrt(sigma**2 / 2)``, using the weight
    ``tanh(|I / sigma|)`` on the sigma-based term.

    If no intensity is positive (all ``I <= 0``, all NaN, or empty input),
    ``mean_I`` is zero or undefined and the correction would divide by zero,
    producing ``inf``. In that case the prior correction is dropped, and the
    non-blended estimates reduce to ``sqrt(max(I, 0))``.

    Parameters
    ----------
    Iobs : torch.Tensor
        Observed intensity values.
    sigma_I : torch.Tensor, optional
        Standard deviations of ``Iobs`` (same shape). When omitted they are
        estimated as ``sqrt(max(Iobs, 1e-6))``.

    Returns
    -------
    F : torch.Tensor
        Structure factor amplitudes, same shape as ``Iobs``.
    sigma_F : torch.Tensor
        Standard deviations of ``F``:

        * ``sigma_I / (2 F)`` where ``F > 1e-6``;
        * ``sqrt(mean_I / 4)`` where ``F <= 1e-6`` and ``sigma_I > 0``;
        * zero elsewhere, including entries where ``F`` is NaN.
    """
    if sigma_I is None:
        sigma_I = torch.sqrt(torch.clamp(Iobs, min=1e-6))

    # Prior intensity scale: mean of the non-NaN intensities, negatives
    # clamped to zero.
    mean_I = torch.mean(torch.clamp(Iobs[~torch.isnan(Iobs)], min=0))
    wilson_param = mean_I / 2.0
    # `mean_I > 0` is False both for 0 and for NaN (empty / all-NaN input).
    # Without a positive scale, sigma^2 / mean_I would be inf or NaN.
    prior_usable = bool(mean_I > 0)

    def _prior_correction(sigma):
        # Additive intensity correction sigma^2 / (2 * wilson_param), or zero
        # when there is no usable prior scale.
        if not prior_usable:
            return torch.zeros_like(sigma)
        return sigma**2 / (2.0 * wilson_param)

    # Strong reflections: plain square root.
    strong_mask = Iobs > 3.0 * sigma_I
    F = torch.zeros_like(Iobs)
    F[strong_mask] = torch.sqrt(Iobs[strong_mask])

    # Weak and negative reflections. NaN intensities also land here because
    # every comparison with NaN is False.
    weak_mask = ~strong_mask
    if weak_mask.any():
        I_weak = Iobs[weak_mask]
        sigma_weak = sigma_I[weak_mask]
        neg_local_mask = I_weak < 0
        pos_local_mask = ~neg_local_mask
        F_weak = torch.zeros_like(I_weak)

        if neg_local_mask.any():
            I_neg = I_weak[neg_local_mask]
            sigma_neg = sigma_weak[neg_local_mask]
            # |I / sigma| measures how negative the observation is. tanh maps
            # it to a [0, 1) weight that shifts towards the sigma-based
            # estimate as the observation becomes more negative.
            epsilon = torch.abs(I_neg / torch.clamp(sigma_neg, min=1e-10))
            weight_sigma_prior = torch.tanh(epsilon)
            F_sigma_prior = torch.sqrt(sigma_neg**2 / 2.0)
            F_global = torch.sqrt(
                torch.clamp(I_neg + _prior_correction(sigma_neg), min=0)
            )
            F_weak[neg_local_mask] = (
                weight_sigma_prior * F_sigma_prior
                + (1 - weight_sigma_prior) * F_global
            )

        if pos_local_mask.any():
            I_pos = I_weak[pos_local_mask]
            sigma_pos = sigma_weak[pos_local_mask]
            F_weak[pos_local_mask] = torch.sqrt(
                torch.clamp(I_pos + _prior_correction(sigma_pos), min=0)
            )

        F[weak_mask] = F_weak

    # Error propagation for F = sqrt(I): sigma_F = sigma_I / (2 F), applied
    # only where F is safely away from zero.
    sigma_F = torch.zeros_like(sigma_I)
    nonzero_F = F > 1e-6
    sigma_F[nonzero_F] = sigma_I[nonzero_F] / (2.0 * F[nonzero_F])

    # Near F = 0 the propagated sigma diverges, so substitute
    # sqrt(wilson_param / 2) == sqrt(mean_I / 4) instead.
    tiny_F = (F <= 1e-6) & (sigma_I > 0)
    if tiny_F.any():
        sigma_F[tiny_F] = torch.sqrt(wilson_param / 2.0)

    return F, sigma_F
# =============================================================================
# __all__ - Public API
# =============================================================================
# Public API, assembled group by group in the same order as the imports above.
# The `__all__ +=` form is understood by static analysers (e.g. pyright).
__all__ = [
    # Coordinate transforms (torchref.base.coordinates)
    "cartesian_to_fractional_torch",
    "fractional_to_cartesian_torch",
    "get_fractional_matrix",
    "get_inv_fractional_matrix_torch",
    "smallest_diff",
    "smallest_diff_aniso",
]
__all__ += [
    # Reciprocal space (torchref.base.reciprocal)
    "reciprocal_basis_matrix",
    "get_scattering_vectors",
    "get_d_spacing",
    "place_on_grid",
    "extract_structure_factor_from_grid",
    "apply_translation_phase",
    "interpolate_structure_factor_from_grid",
    "interpolate_complex_from_grid",
    "trilinear_interpolate_patterson",
]
__all__ += [
    # Structure factors (torchref.base.direct_summation)
    "iso_structure_factor_torched",
    "iso_structure_factor_torched_no_complex",
    "aniso_structure_factor_torched",
    "aniso_structure_factor_torched_no_complex",
    "anharmonic_correction",
    "anharmonic_correction_no_complex",
    "core_deformation",
    "multiplication_quasi_complex_tensor",
]
__all__ += [
    # Electron density (torchref.base.electron_density)
    "vectorized_add_to_map",
    "vectorized_add_to_map_aniso",
    "scatter_add_nd",
    "scatter_add_nd_super_slow",
    "find_relevant_voxels",
    "excise_angstrom_radius_around_coord",
    "add_to_solvent_mask",
    "add_to_phenix_mask",
    "find_solvent_voids",
]
__all__ += [
    # Fourier (torchref.base.fourier)
    "fft",
    "ifft",
    "get_real_grid",
    "find_grid_size",
]
__all__ += [
    # Alignment (torchref.base.alignment)
    "rotate_coords_torch",
    "axis_angle_to_rotation_matrix",
    "rotation_matrix_to_axis_angle",
    "quaternion_to_rotation_matrix",
    "random_rotation_uniform",
    "superpose_vectors_robust_torch",
    "align_torch",
    "get_alignement_matrix",
    "apply_transformation",
]
__all__ += [
    # Metrics (torchref.base.metrics)
    "get_rfactor_torch",
    "rfactor",
    "get_rfactors",
    "bin_wise_rfactors",
    "calc_outliers",
    "nll_xray",
    "nll_xray_sum",
    "nll_xray_lognormal",
    "log_loss",
    "estimate_sigma_I",
    "estimate_sigma_F",
    "gaussian_to_lognormal_sigma",
    "gaussian_to_lognormal_mu",
]
__all__ += [
    # Utility functions defined in this module
    "U_to_matrix",
    "deterministic_tensor_digest",
    "hash_tensors",
    "french_wilson_conversion",
]