Source code for torchref.base.reciprocal.hkl

"""
Miller index (HKL) operations for crystallography.

Functions for generating Miller indices, computing d-spacings,
and other HKL-related calculations.
"""

from typing import Optional

import torch

from torchref.config import dtypes
from .basis import reciprocal_basis_matrix, get_scattering_vectors


def get_d_spacing(hkl: torch.Tensor, cell: torch.Tensor, recB=None):
    """
    Convert Miller indices into d-spacings.

    The scattering vector of every reflection is obtained from
    ``get_scattering_vectors``; the d-spacing is the reciprocal of that
    vector's Euclidean length.

    Parameters
    ----------
    hkl : torch.Tensor
        Miller indices of shape (N, 3).
    cell : torch.Tensor
        Cell parameters [a, b, c, alpha, beta, gamma].
    recB : torch.Tensor, optional
        Pre-computed reciprocal basis matrix of shape (3, 3). Forwarded
        unchanged to ``get_scattering_vectors``.

    Returns
    -------
    torch.Tensor
        D-spacing values of shape (N,) in Angstroms.
    """
    scattering = get_scattering_vectors(hkl, cell, recB)
    # One length per reflection: reduce over the xyz component axis.
    lengths = torch.linalg.norm(scattering, dim=1)
    return 1.0 / lengths
def compute_d_spacing_batch(
    hkl: torch.Tensor,
    cell: torch.Tensor,
    recB: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Return the d-spacing of every reflection in a batch.

    Convenience alias that forwards all of its arguments to
    ``get_d_spacing`` and returns its result unchanged.

    Parameters
    ----------
    hkl : torch.Tensor, shape (N, 3)
        Miller indices.
    cell : torch.Tensor, shape (6,)
        Unit cell parameters.
    recB : torch.Tensor, optional
        Pre-computed reciprocal basis matrix.

    Returns
    -------
    torch.Tensor, shape (N,)
        D-spacing values in Angstroms.
    """
    d_spacing = get_d_spacing(hkl, cell, recB)
    return d_spacing
def generate_possible_hkl(
    cell: torch.Tensor, d_min: float, device: Optional[torch.device] = None
) -> torch.Tensor:
    """
    Generate all possible Miller indices within a resolution limit.

    Creates a complete set of (h, k, l) indices where the d-spacing is
    greater than or equal to d_min. The origin (0, 0, 0) is excluded.

    Parameters
    ----------
    cell : torch.Tensor, shape (6,)
        Unit cell parameters [a, b, c, alpha, beta, gamma] in Angstroms
        and degrees.
    d_min : float
        High resolution limit in Angstroms (minimum d-spacing). Must be
        strictly positive.
    device : torch.device, optional
        Device for computation. If None, uses cell's device.

    Returns
    -------
    torch.Tensor, shape (M, 3), dtype int32
        All Miller indices with d-spacing >= d_min.

    Raises
    ------
    ValueError
        If ``d_min`` is not strictly positive.

    Examples
    --------
    ::

        import torch
        cell = torch.tensor([50.0, 60.0, 70.0, 90.0, 90.0, 90.0])
        hkl = generate_possible_hkl(cell, d_min=2.0)
        print(f"Generated {len(hkl)} reflections")
    """
    # `not d_min > 0` also rejects NaN, which `d_min <= 0` would let through.
    if not d_min > 0:
        raise ValueError(f"d_min must be positive, got {d_min}")

    if device is None:
        device = cell.device
    cell = cell.to(device)

    # Index bounds from the real-space axis lengths.
    # Because a . a* = 1 and a . b* = a . c* = 0, the index h equals s . a
    # for s = h*a* + k*b* + l*c*.  Cauchy-Schwarz then gives
    # |h| <= |s| * |a| <= a / d_min (likewise for k and l).
    # Bounding with 1/|a*| instead is only safe for orthogonal cells:
    # 1/|a*| is the (100) plane spacing, which is shorter than a whenever
    # the axes are oblique, so valid high-index reflections would be lost.
    axis_lengths = cell[:3].detach()
    h_max, k_max, l_max = (
        int(bound) for bound in torch.ceil(axis_lengths / d_min).tolist()
    )

    # Generate all combinations of h, k, l, negative values included.
    h_range = torch.arange(-h_max, h_max + 1, device=device, dtype=dtypes.int)
    k_range = torch.arange(-k_max, k_max + 1, device=device, dtype=dtypes.int)
    l_range = torch.arange(-l_max, l_max + 1, device=device, dtype=dtypes.int)

    hh, kk, ll = torch.meshgrid(h_range, k_range, l_range, indexing="ij")
    hkl_all = torch.stack([hh.flatten(), kk.flatten(), ll.flatten()], dim=1)

    # Remove (0, 0, 0) - not a valid reflection.
    not_origin = (hkl_all != 0).any(dim=1)
    hkl_all = hkl_all[not_origin]

    # The box above over-covers the resolution sphere; the exact d-spacing
    # filter trims it to the reflections that are actually in range.
    recB = reciprocal_basis_matrix(cell)
    d_spacing = get_d_spacing(hkl_all.float(), cell, recB=recB)
    valid_res = d_spacing >= d_min

    return hkl_all[valid_res]