# Source code for torchref.symmetry.cell

"""
Cell - A dataclass for crystallographic unit cells with cached derived quantities.

Provides a simple container for unit cell parameters with automatic caching
of derived quantities (fractional matrix, volume, etc.).
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any

import torch

from torchref.config import get_default_device, get_float_dtype
from torchref.utils.device_mixin import _NonModuleDeviceMixin


@dataclass
class Cell(_NonModuleDeviceMixin):
    """
    Dataclass for crystallographic unit cells with cached derived quantities.

    Stores 6 parameters: [a, b, c, alpha, beta, gamma]

    - a, b, c: cell lengths in Angstroms
    - alpha, beta, gamma: cell angles in degrees

    Derived quantities (fractional_matrix, volume, etc.) are computed on
    first access and cached. The cache is cleared by ``reset_cache``, which
    the inherited ``to`` / ``cuda`` / ``cpu`` methods call after moving the
    data to a different device or dtype. The cache is *not* invalidated
    automatically when the underlying tensor is modified in place (for
    example by an optimizer step); call ``reset_cache()`` in that case.

    Examples
    --------
    >>> cell = Cell([50, 60, 70, 90, 90, 90])
    >>> cell.volume  # Computed and cached
    tensor(210000.)
    >>> _ = cell.to('cuda')  # Moves this Cell in place  # doctest: +SKIP
    >>> cell.device.type  # doctest: +SKIP
    'cuda'
    """

    # Flat (6,) tensor of [a, b, c, alpha, beta, gamma].
    _data: torch.Tensor
    # Lazily-filled store of derived quantities, keyed by property name.
    _cache: dict = field(default_factory=dict, repr=False)

    # NOTE(review): ``@dataclass`` (eq=True by default) generates ``__eq__``
    # from the fields above, which compares tensors element-wise and then
    # needs ``bool()`` of a 6-element tensor -- torch rejects that as
    # ambiguous, so ``cell1 == cell2`` on distinct Cells is expected to raise.
    # Confirm whether value equality is needed before relying on ``==``.

    def __init__(
        self,
        data: Any,
        *,
        dtype: torch.dtype | None = None,
        device: torch.device | str | None = None,
        requires_grad: bool = False,
    ) -> None:
        """
        Create a new Cell.

        Parameters
        ----------
        data : array-like
            Unit cell parameters [a, b, c, alpha, beta, gamma].
            Can be a list, numpy array, or torch tensor. Any shape with
            exactly 6 elements is accepted and flattened to shape (6,).
        dtype : torch.dtype, optional
            Desired data type. Defaults to the configured ``dtypes.float``,
            looked up each time a Cell is created.
        device : torch.device or str, optional
            Desired device. Defaults to the configured ``device.current``,
            looked up each time a Cell is created.
        requires_grad : bool, optional
            Whether to track gradients. Defaults to False.

        Raises
        ------
        ValueError
            If data does not have exactly 6 elements.
        """
        # Resolve configuration defaults per call instead of once at import
        # time, so later changes to the torchref configuration are honoured.
        if dtype is None:
            dtype = get_float_dtype()
        if device is None:
            device = get_default_device()

        # Convert to a tensor first so the element count can be validated.
        # NOTE: ``Tensor.to`` returns ``data`` itself when dtype and device
        # already match, so the Cell may then share storage with the caller's
        # tensor (and ``requires_grad_`` below would flag that tensor too).
        if isinstance(data, torch.Tensor):
            tensor = data.to(dtype=dtype, device=device)
        else:
            tensor = torch.tensor(data, dtype=dtype, device=device)

        if tensor.numel() != 6:
            raise ValueError(
                f"Cell requires exactly 6 elements [a, b, c, alpha, beta, gamma], "
                f"got {tensor.numel()}"
            )

        # Always store a flat (6,) tensor, e.g. for (1, 6) or (6, 1) input.
        tensor = tensor.reshape(6)
        if requires_grad:
            tensor = tensor.requires_grad_(True)

        # Fields are assigned via object.__setattr__, as everywhere else in
        # this class (reset_cache, _from_tensor).
        object.__setattr__(self, "_data", tensor)
        object.__setattr__(self, "_cache", {})

    # =========================================================================
    # Device/dtype movement methods
    # =========================================================================
    # ``to``, ``cuda``, ``cpu`` are inherited from ``_NonModuleDeviceMixin``
    # and operate in place — they walk ``self.__dict__`` (moving ``_data``
    # and any cached tensor values) and then call ``reset_cache`` below.

    def reset_cache(self) -> None:
        """Clear cached derived quantities (fractional matrix, volume, etc.)."""
        object.__setattr__(self, "_cache", {})

    def _from_tensor(self, tensor: torch.Tensor) -> "Cell":
        """Wrap ``tensor`` in a new instance of this Cell's type with an empty cache."""
        # type(self) keeps subclasses intact; __new__ skips re-validation in
        # __init__ because ``tensor`` is derived from an already-valid _data.
        cls = type(self)
        new_cell = cls.__new__(cls)
        object.__setattr__(new_cell, "_data", tensor)
        object.__setattr__(new_cell, "_cache", {})
        return new_cell

    def detach(self) -> "Cell":
        """
        Return a new Cell with detached tensor (no gradient tracking).

        The new Cell shares storage with this one and starts with an
        empty cache.

        Returns
        -------
        Cell
            New Cell (same type as ``self``) with detached data.
        """
        return self._from_tensor(self._data.detach())

    def clone(self) -> "Cell":
        """
        Return a new Cell with cloned tensor data.

        The new Cell owns a copy of the data and starts with an empty cache.

        Returns
        -------
        Cell
            New Cell (same type as ``self``) with cloned data.
        """
        return self._from_tensor(self._data.clone())

    # =========================================================================
    # Basic properties
    # =========================================================================

    @property
    def device(self) -> torch.device:
        """Return the device of the underlying tensor."""
        return self._data.device

    @property
    def dtype(self) -> torch.dtype:
        """Return the dtype of the underlying tensor."""
        return self._data.dtype

    @property
    def data(self) -> torch.Tensor:
        """Return the underlying tensor (for buffer registration)."""
        return self._data

    @property
    def requires_grad(self) -> bool:
        """Return whether gradients are tracked."""
        return self._data.requires_grad

    # =========================================================================
    # Convenience properties for cell parameters
    # =========================================================================

    @property
    def a(self) -> torch.Tensor:
        """Cell length a in Angstroms."""
        return self._data[0]

    @property
    def b(self) -> torch.Tensor:
        """Cell length b in Angstroms."""
        return self._data[1]

    @property
    def c(self) -> torch.Tensor:
        """Cell length c in Angstroms."""
        return self._data[2]

    @property
    def alpha(self) -> torch.Tensor:
        """Cell angle alpha in degrees."""
        return self._data[3]

    @property
    def beta(self) -> torch.Tensor:
        """Cell angle beta in degrees."""
        return self._data[4]

    @property
    def gamma(self) -> torch.Tensor:
        """Cell angle gamma in degrees."""
        return self._data[5]

    # =========================================================================
    # Cached derived properties
    # =========================================================================

    @property
    def fractional_matrix(self) -> torch.Tensor:
        """
        Orthogonalization matrix B (fractional -> Cartesian).

        Returns the 3x3 matrix B such that: cart = frac @ B.T

        Returns
        -------
        torch.Tensor
            Shape (3, 3) orthogonalization matrix.
        """
        if "fractional_matrix" not in self._cache:
            self._cache["fractional_matrix"] = self._compute_fractional_matrix()
        return self._cache["fractional_matrix"]

    @property
    def inv_fractional_matrix(self) -> torch.Tensor:
        """
        Fractionalization matrix B^-1 (Cartesian -> fractional).

        Returns the 3x3 matrix B^-1 such that: frac = cart @ B^-1.T

        Returns
        -------
        torch.Tensor
            Shape (3, 3) fractionalization matrix.
        """
        if "inv_fractional_matrix" not in self._cache:
            self._cache["inv_fractional_matrix"] = torch.linalg.inv(
                self.fractional_matrix
            )
        return self._cache["inv_fractional_matrix"]

    @property
    def volume(self) -> torch.Tensor:
        """
        Unit cell volume in cubic Angstroms.

        Returns
        -------
        torch.Tensor
            Scalar tensor with the cell volume.
        """
        if "volume" not in self._cache:
            self._cache["volume"] = self._compute_volume()
        return self._cache["volume"]

    @property
    def reciprocal_basis_matrix(self) -> torch.Tensor:
        """
        Reciprocal basis matrix with [a*, b*, c*] as rows.

        Returns
        -------
        torch.Tensor
            Shape (3, 3) matrix where rows are the reciprocal basis vectors.
        """
        if "reciprocal_basis_matrix" not in self._cache:
            self._cache["reciprocal_basis_matrix"] = (
                self._compute_reciprocal_basis_matrix()
            )
        return self._cache["reciprocal_basis_matrix"]

    # =========================================================================
    # Internal computation methods
    # =========================================================================

    def _compute_fractional_matrix(self) -> torch.Tensor:
        """
        Compute the fractional-to-Cartesian transformation matrix.

        Delegates to ``math_torch.get_fractional_matrix`` for the computation.
        """
        from torchref.base import math_torch

        return math_torch.get_fractional_matrix(self._data)

    def _compute_volume(self) -> torch.Tensor:
        """
        Compute the unit cell volume.

        Uses the formula:
        V = abc * sqrt(1 - cos^2(alpha) - cos^2(beta) - cos^2(gamma)
                       + 2*cos(alpha)*cos(beta)*cos(gamma))
        """
        a, b, c = self._data[0], self._data[1], self._data[2]
        angles_rad = torch.deg2rad(self._data[3:])
        cos_alpha, cos_beta, cos_gamma = torch.cos(angles_rad)

        volume_factor = torch.sqrt(
            1
            - cos_alpha**2
            - cos_beta**2
            - cos_gamma**2
            + 2 * cos_alpha * cos_beta * cos_gamma
        )
        return a * b * c * volume_factor

    def _compute_reciprocal_basis_matrix(self) -> torch.Tensor:
        """
        Compute the reciprocal space basis matrix.

        Delegates to ``math_torch.reciprocal_basis_matrix`` for the computation.
        """
        from torchref.base import math_torch

        return math_torch.reciprocal_basis_matrix(self._data)

    # =========================================================================
    # Grid computation methods
    # =========================================================================

    def compute_grid_size(self, max_res: float, oversampling: float = 3.0) -> tuple:
        """
        Compute minimum grid dimensions for a given resolution.

        Uses Shannon-Nyquist sampling criterion to determine the minimum
        number of grid points needed along each axis.

        Parameters
        ----------
        max_res : float
            Maximum resolution in Angstroms. Must be positive.
        oversampling : float, optional
            Oversampling factor relative to max_res. Default is 3.0
            (standard for crystallographic calculations). Must be positive.

        Returns
        -------
        tuple of int
            Minimum grid dimensions (nx, ny, nz).

        Raises
        ------
        ValueError
            If ``max_res`` or ``oversampling`` is not positive.

        Examples
        --------
        >>> cell = Cell([50, 60, 70, 90, 90, 90])
        >>> cell.compute_grid_size(2.0)
        (75, 90, 105)
        """
        import math

        # Non-positive inputs would otherwise divide by zero or silently
        # produce zero/negative grid dimensions.
        if max_res <= 0:
            raise ValueError(f"max_res must be positive, got {max_res}")
        if oversampling <= 0:
            raise ValueError(f"oversampling must be positive, got {oversampling}")

        lengths = (self.a.item(), self.b.item(), self.c.item())
        # Shannon-Nyquist: sample at oversampling × the maximum frequency.
        # NOTE(review): floor() yields a spacing slightly coarser than
        # max_res / oversampling whenever the ratio is non-integer; confirm
        # whether ceil() was intended for a true "minimum" grid.
        nx, ny, nz = (
            int(math.floor(length / max_res * oversampling)) for length in lengths
        )
        return (nx, ny, nz)

    def tolist(self) -> list:
        """
        Convert Cell parameters to a standard Python list.

        Returns
        -------
        list
            List of cell parameters [a, b, c, alpha, beta, gamma].
        """
        return self._data.tolist()

    # =========================================================================
    # Fractional/Cartesian conversion methods
    # =========================================================================

    def fractional_to_cartesian(self, frac_coords: torch.Tensor) -> torch.Tensor:
        """
        Convert fractional coordinates to Cartesian coordinates.

        Parameters
        ----------
        frac_coords : torch.Tensor
            Tensor of fractional coordinates, shape (..., 3).

        Returns
        -------
        torch.Tensor
            Tensor of Cartesian coordinates, shape (..., 3).
        """
        return torch.matmul(frac_coords, self.fractional_matrix.T)

    def cartesian_to_fractional(self, cart_coords: torch.Tensor) -> torch.Tensor:
        """
        Convert Cartesian coordinates to fractional coordinates.

        Parameters
        ----------
        cart_coords : torch.Tensor
            Tensor of Cartesian coordinates, shape (..., 3).

        Returns
        -------
        torch.Tensor
            Tensor of fractional coordinates, shape (..., 3).
        """
        return torch.matmul(cart_coords, self.inv_fractional_matrix.T)

    # =========================================================================
    # Dunder methods
    # =========================================================================

    def __repr__(self) -> str:
        """Return string representation."""
        return (
            f"Cell([{self.a.item():.4f}, {self.b.item():.4f}, {self.c.item():.4f}, "
            f"{self.alpha.item():.4f}, {self.beta.item():.4f}, {self.gamma.item():.4f}], "
            f"device={self.device}, dtype={self.dtype})"
        )

    def __getitem__(self, idx: int | slice) -> torch.Tensor:
        """Allow indexing like cell[0] for cell length a (slices also work)."""
        return self._data[idx]

    def __len__(self) -> int:
        """Return 6 (number of cell parameters)."""
        return 6
# Keep CellTensor as an alias for backward compatibility.
CellTensor = Cell