torchref.symmetry.cell module

Cell - A dataclass for crystallographic unit cells with cached derived quantities.

Provides a simple container for unit cell parameters with automatic caching of derived quantities (fractional matrix, volume, etc.).

class torchref.symmetry.cell.Cell(data, *, dtype=torch.float32, device=device(type='cpu'), requires_grad=False)

Bases: DeviceMixin

Dataclass for crystallographic unit cells with cached derived quantities.

Stores six parameters, [a, b, c, alpha, beta, gamma]:
  • a, b, c: cell lengths in Angstroms.

  • alpha, beta, gamma: cell angles in degrees.

Derived quantities (fractional_matrix, volume, etc.) are computed on first access and cached. The cache is cleared when the cell is moved to a different device or dtype.
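The compute-on-first-access caching described above can be pictured with a small stand-in. This is a minimal sketch that assumes a plain dictionary keyed by quantity name. `LazyCache`, `get`, `clear` and `misses` are hypothetical names used only for illustration and are not torchref's internals. Here `clear()` plays the role of `reset_cache()`.

```python
class LazyCache:
    """Hypothetical helper: compute a value on first request, then reuse it."""

    def __init__(self):
        self._store = {}
        self.misses = 0  # number of times a value actually had to be computed

    def get(self, key, compute):
        # Compute only when the key is absent; later calls return the stored value.
        if key not in self._store:
            self.misses += 1
            self._store[key] = compute()
        return self._store[key]

    def clear(self):
        # Analogue of Cell.reset_cache(): forget everything derived so far.
        self._store.clear()


a, b, c = 50.0, 60.0, 70.0
cache = LazyCache()
v1 = cache.get("volume", lambda: a * b * c)  # orthorhombic cell: V = a*b*c
v2 = cache.get("volume", lambda: a * b * c)  # served from the cache
cache.clear()
v3 = cache.get("volume", lambda: a * b * c)  # recomputed after clearing
```

After these calls, `cache.misses` is 2: one computation before clearing and one after.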

Examples

>>> cell = Cell([50, 60, 70, 90, 90, 90])
>>> cell.volume  # Computed and cached
tensor(210000.)
>>> cell_gpu = cell.to('cuda')  # Move to GPU (returns new Cell)
>>> cell_gpu.device.type
'cuda'

__init__(data, *, dtype=torch.float32, device=device(type='cpu'), requires_grad=False)

Create a new Cell.

Parameters:
  • data (array-like) – Unit cell parameters [a, b, c, alpha, beta, gamma]. Can be a list, numpy array, or torch tensor.

  • dtype (torch.dtype, optional) – Desired data type. Defaults to the configured dtypes.float.

  • device (torch.device or str, optional) – Desired device. Defaults to the configured device.current.

  • requires_grad (bool, optional) – Whether to track gradients. Defaults to False.

Raises:

ValueError – If data does not have exactly 6 elements.
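The input contract documented above accepts any sequence of six numbers and raises ValueError otherwise. It can be sketched in plain Python as follows. `parse_cell_params` is a hypothetical helper, and the error message text is an assumption rather than torchref's actual wording.

```python
def parse_cell_params(data):
    """Hypothetical helper mirroring the documented contract of Cell(data).

    Accepts any iterable of six numbers and splits it into lengths and angles.
    """
    values = [float(x) for x in data]
    if len(values) != 6:
        raise ValueError(f"expected 6 cell parameters, got {len(values)}")
    lengths = tuple(values[:3])  # a, b, c in Angstroms
    angles = tuple(values[3:])   # alpha, beta, gamma in degrees
    return lengths, angles


lengths, angles = parse_cell_params([50, 60, 70, 90, 90, 90])
try:
    parse_cell_params([50, 60, 70])  # too few parameters
except ValueError as err:
    message = str(err)
```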

reset_cache()

Clear cached derived quantities (fractional matrix, volume, etc.).

detach()

Return a new Cell whose tensor is detached from the autograd graph (no gradient tracking).

Returns:

New Cell with detached data.

Return type:

Cell

clone()

Return a new Cell with cloned tensor data.

Returns:

New Cell with cloned data.

Return type:

Cell

property device: device

Return the device of the underlying tensor.

property dtype: dtype

Return the dtype of the underlying tensor.

property data: Tensor

Return the underlying tensor (for buffer registration).

property requires_grad: bool

Return whether gradients are tracked.

property a: Tensor

Cell length a in Angstroms.

property b: Tensor

Cell length b in Angstroms.

property c: Tensor

Cell length c in Angstroms.

property alpha: Tensor

Cell angle alpha in degrees.

property beta: Tensor

Cell angle beta in degrees.

property gamma: Tensor

Cell angle gamma in degrees.

property fractional_matrix: Tensor

Orthogonalization matrix B (fractional -> Cartesian).

Returns the 3x3 matrix B such that cart = frac @ B.T, with coordinates stored as rows of shape (…, 3).

Returns:

Shape (3, 3) orthogonalization matrix.

Return type:

torch.Tensor
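Because cart = frac @ B.T, the columns of B are the Cartesian cell vectors a, b and c. The sketch below builds B in a common convention: a lies along x and b lies in the xy-plane, which makes B upper triangular. This orientation is an assumption, since torchref may orient the cell differently. `orthogonalization_matrix` is a hypothetical name, and the sketch returns nested lists rather than a tensor.

```python
import math


def orthogonalization_matrix(a, b, c, alpha, beta, gamma):
    """Build B (fractional -> Cartesian) with a along x and b in the xy-plane.

    Angles are in degrees. Returns a 3x3 nested list whose columns are the
    Cartesian cell vectors a, b, c (an assumed convention, see above).
    """
    ca, cb, cg = (math.cos(math.radians(t)) for t in (alpha, beta, gamma))
    sg = math.sin(math.radians(gamma))
    # Dimensionless volume factor: V = a * b * c * root
    root = math.sqrt(1 - ca**2 - cb**2 - cg**2 + 2 * ca * cb * cg)
    return [
        [a, b * cg, c * cb],
        [0.0, b * sg, c * (ca - cb * cg) / sg],
        [0.0, 0.0, c * root / sg],
    ]


B = orthogonalization_matrix(10.0, 12.0, 15.0, 80.0, 95.0, 105.0)
# Columns of B are the Cartesian cell vectors; each should have the input length.
col_lengths = [math.sqrt(sum(B[i][j] ** 2 for i in range(3))) for j in range(3)]
```

The determinant of this B is a * b * c * root, which is the cell volume.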

property inv_fractional_matrix: Tensor

Fractionalization matrix B^-1 (Cartesian -> fractional).

Returns the 3x3 matrix B^-1 such that frac = cart @ (B^-1).T.

Returns:

Shape (3, 3) fractionalization matrix.

Return type:

torch.Tensor
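Suppose B is upper triangular, as it is when a lies along x and b in the xy-plane. That is a common convention, but torchref may not use it. In that case B^-1 has a closed form and needs no general matrix inversion. The sketch below checks the closed form on an example matrix. `invert_upper_triangular` and `matmul` are hypothetical plain-Python helpers.

```python
def invert_upper_triangular(m):
    """Closed-form inverse of a 3x3 upper-triangular matrix (nested lists)."""
    (u11, u12, u13), (_, u22, u23), (_, _, u33) = m
    return [
        [1 / u11, -u12 / (u11 * u22), (u12 * u23 - u13 * u22) / (u11 * u22 * u33)],
        [0.0, 1 / u22, -u23 / (u22 * u33)],
        [0.0, 0.0, 1 / u33],
    ]


def matmul(x, y):
    """Plain 3x3 matrix product, used here only to check the inverse."""
    return [[sum(x[i][k] * y[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


B = [[10.0, 2.0, 3.0], [0.0, 12.0, 4.0], [0.0, 0.0, 15.0]]  # example upper-triangular B
B_inv = invert_upper_triangular(B)
identity = matmul(B, B_inv)  # should be the 3x3 identity up to rounding
```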

property volume: Tensor

Unit cell volume in cubic Angstroms.

Returns:

Scalar tensor with the cell volume.

Return type:

torch.Tensor
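For a general triclinic cell, the volume follows from the six parameters by the standard formula shown in the docstring below. This plain-Python sketch is not torchref's code, and `cell_volume` is a hypothetical name. For the orthorhombic cell in the class example, it reduces to a * b * c = 210000 cubic Angstroms.

```python
import math


def cell_volume(a, b, c, alpha, beta, gamma):
    """Unit cell volume in cubic Angstroms from lengths (A) and angles (degrees).

    V = a*b*c*sqrt(1 - cos^2(alpha) - cos^2(beta) - cos^2(gamma)
                   + 2*cos(alpha)*cos(beta)*cos(gamma))
    """
    ca, cb, cg = (math.cos(math.radians(t)) for t in (alpha, beta, gamma))
    return a * b * c * math.sqrt(1 - ca**2 - cb**2 - cg**2 + 2 * ca * cb * cg)


v_ortho = cell_volume(50, 60, 70, 90, 90, 90)   # orthorhombic: reduces to a*b*c
v_hex = cell_volume(40, 40, 100, 90, 90, 120)   # hexagonal: a*a*c*sin(120 deg)
```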

property reciprocal_basis_matrix: Tensor

Reciprocal basis matrix with [a*, b*, c*] as rows.

Returns:

Shape (3, 3) matrix where rows are the reciprocal basis vectors.

Return type:

torch.Tensor
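The reciprocal basis vectors can be formed from cross products of the direct Cartesian cell vectors: a* = (b × c) / V, b* = (c × a) / V and c* = (a × b) / V. The sketch below assumes the crystallographic convention without a 2π factor. The document does not state which convention torchref uses. `cross`, `dot` and `reciprocal_basis` are hypothetical helpers.

```python
def cross(u, v):
    """Cross product of two 3-vectors."""
    return (u[1] * v[2] - u[2] * v[1],
            u[2] * v[0] - u[0] * v[2],
            u[0] * v[1] - u[1] * v[0])


def dot(u, v):
    """Dot product of two 3-vectors."""
    return sum(x * y for x, y in zip(u, v))


def reciprocal_basis(a_vec, b_vec, c_vec):
    """Rows a*, b*, c* with a_i* . a_j = delta_ij (crystallographic, no 2*pi)."""
    volume = dot(a_vec, cross(b_vec, c_vec))  # triple product = V for a right-handed basis
    return [
        tuple(x / volume for x in cross(b_vec, c_vec)),  # a* = (b x c) / V
        tuple(x / volume for x in cross(c_vec, a_vec)),  # b* = (c x a) / V
        tuple(x / volume for x in cross(a_vec, b_vec)),  # c* = (a x b) / V
    ]


direct = [(10.0, 0.0, 0.0), (2.0, 12.0, 0.0), (3.0, 4.0, 15.0)]  # example Cartesian a, b, c
recip = reciprocal_basis(*direct)
# Each a_i* . a_j should be 1 when i == j and 0 otherwise.
gram = [[dot(r, d) for d in direct] for r in recip]
```

When the direct vectors are the columns of the orthogonalization matrix B, the rows produced here satisfy the same relations as the rows of B^-1. So, under the no-2π assumption, reciprocal_basis_matrix would be expected to match inv_fractional_matrix.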

compute_grid_size(max_res, oversampling=3.0)

Compute minimum grid dimensions for a given resolution.

Uses the Shannon-Nyquist sampling criterion to determine the minimum number of grid points needed along each axis.

Parameters:
  • max_res (float) – Maximum resolution in Angstroms.

  • oversampling (float, optional) – Oversampling factor relative to max_res; the grid spacing along each axis is at most max_res / oversampling. Default is 3.0, a common choice for crystallographic calculations.

Returns:

Minimum grid dimensions (nx, ny, nz).

Return type:

tuple of int

Examples

>>> cell = Cell([50, 60, 70, 90, 90, 90])
>>> cell.compute_grid_size(2.0)
(75, 90, 105)
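The example output is consistent with taking ceil(length * oversampling / max_res) points along each cell axis. For instance, 50 * 3.0 / 2.0 = 75. Nyquist alone would only require a spacing of max_res / 2, so the default of 3.0 leaves a margin. The sketch below reproduces that arithmetic and is hedged on two counts. Real implementations often round up to FFT-friendly sizes, and for non-orthogonal cells they may use reciprocal-lattice spacings instead of cell lengths. The sketch does neither, and the documentation above does not say what torchref does. `grid_size_sketch` is a hypothetical name.

```python
import math


def grid_size_sketch(lengths, max_res, oversampling=3.0):
    """Hypothetical re-derivation of compute_grid_size for cell lengths (A).

    Grid spacing is at most max_res / oversampling along each axis, so each
    axis needs ceil(length * oversampling / max_res) points.
    """
    return tuple(math.ceil(length * oversampling / max_res) for length in lengths)


dims = grid_size_sketch((50.0, 60.0, 70.0), max_res=2.0)
```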

tolist()

Convert Cell parameters to a standard Python list.

Returns:

List of cell parameters [a, b, c, alpha, beta, gamma].

Return type:

list

fractional_to_cartesian(frac_coords)

Convert fractional coordinates to Cartesian coordinates.

Parameters:

frac_coords (torch.Tensor) – Tensor of fractional coordinates, shape (…, 3).

Returns:

Tensor of Cartesian coordinates, shape (…, 3).

Return type:

torch.Tensor

cartesian_to_fractional(cart_coords)

Convert Cartesian coordinates to fractional coordinates.

Parameters:

cart_coords (torch.Tensor) – Tensor of Cartesian coordinates, shape (…, 3).

Returns:

Tensor of fractional coordinates, shape (…, 3).

Return type:

torch.Tensor
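Both conversions follow the row-vector convention stated for the matrix properties. Fractional to Cartesian is frac @ B.T, and Cartesian to fractional is cart @ (B^-1).T. The sketch below spells out those products in plain Python and checks a round trip. It uses an example upper-triangular B with its exact inverse. `rows_times_transpose` is a hypothetical helper, and the matrices are illustrative values rather than output from torchref.

```python
def rows_times_transpose(coords, m):
    """Return coords @ m.T for a list of 3-vectors (row-vector convention)."""
    return [
        tuple(sum(m[i][k] * p[k] for k in range(3)) for i in range(3))
        for p in coords
    ]


# Example upper-triangular B (columns are Cartesian a, b, c) and its exact inverse.
B = [[10.0, 2.0, 3.0], [0.0, 12.0, 4.0], [0.0, 0.0, 15.0]]
B_inv = [[1 / 10, -1 / 60, -7 / 450], [0.0, 1 / 12, -1 / 45], [0.0, 0.0, 1 / 15]]

frac = [(0.5, 0.5, 0.5), (0.0, 0.0, 1.0)]
cart = rows_times_transpose(frac, B)       # fractional -> Cartesian: frac @ B.T
back = rows_times_transpose(cart, B_inv)   # Cartesian -> fractional: cart @ B_inv.T
```

With tensors of shape (…, 3), the same products are `frac @ B.T` and `cart @ B_inv.T`. PyTorch's `@` broadcasts over the leading dimensions.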

__repr__()

Return a string representation of the cell.

__getitem__(idx)

Support indexing into the six parameters; for example, cell[0] returns cell length a.

__len__()

Return 6 (number of cell parameters).

torchref.symmetry.cell.CellTensor

alias of Cell