Source code for torchref.io.datasets.base

"""
Base dataclass for crystallographic datasets.

This module defines the CrystalDataset dataclass that provides:
- All possible tensor fields for crystallographic data (optional)
- Device management (to, cuda, cpu)
- Serialization (save, load)

Space groups are stored as name strings; gemmi.SpaceGroup objects are
constructed on demand by the ``spacegroup_*`` properties.
"""

from dataclasses import dataclass, field, fields
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

import gemmi
import torch

from torchref.config import get_default_device
from torchref.symmetry import Cell
from torchref.utils.device_mixin import DeviceMovementMixin

if TYPE_CHECKING:
    pass


@dataclass
class CrystalDataset(DeviceMovementMixin):
    """
    Base dataclass for crystallographic datasets.

    Defines all possible tensor fields (optional) and handles device
    management and serialization. Subclasses add domain-specific methods.

    This lightweight design enables scaling to 1000s of datasets without
    the overhead of torch.nn.Module.

    Parameters
    ----------
    device : torch.device
        Device for tensors ('cpu', 'cuda', etc.). Defaults to the configured
        device.current. A string is converted to ``torch.device`` in
        ``__post_init__``.
    verbose : int
        Verbosity level (0=silent, 1=normal, 2=debug). Default is 1.

    Examples
    --------
    Basic usage::

        data = CrystalDataset(device='cuda')
        data.hkl = torch.tensor([[1, 0, 0], [0, 1, 0]])
        data.cpu()  # Move all tensors to CPU
    """

    # === Core reflection tensors ===
    hkl: Optional[torch.Tensor] = None  # Miller indices (N, 3), int32
    F: Optional[torch.Tensor] = None  # Structure factor amplitudes (N,)
    F_sigma: Optional[torch.Tensor] = None  # Amplitude uncertainties (N,)
    I: Optional[torch.Tensor] = None  # Intensities (N,)
    I_sigma: Optional[torch.Tensor] = None  # Intensity uncertainties (N,)
    rfree_flags: Optional[torch.Tensor] = None  # R-free test set flags (N,), int32
    resolution: Optional[torch.Tensor] = None  # Resolution per reflection (N,)
    bin_indices: Optional[torch.Tensor] = None  # Resolution bin assignments (N,), int32
    outlier_flags: Optional[torch.Tensor] = None  # Outlier flags (N,), bool
    phase: Optional[torch.Tensor] = None  # Phases in radians (N,)
    fom: Optional[torch.Tensor] = None  # Figure of merit (N,)
    _centric_flags: Optional[torch.Tensor] = None  # Centric flags (N,), bool

    # === E-value and anisotropy correction fields ===
    E: Optional[torch.Tensor] = None  # E-values (N,)
    E_squared: Optional[torch.Tensor] = None  # E² values (N,)
    F_squared_corrected: Optional[torch.Tensor] = None  # Anisotropy-corrected F² (N,)
    U_aniso: Optional[torch.Tensor] = None  # Fitted anisotropy parameters (6,)
    radial_shell_indices: Optional[torch.Tensor] = None  # Shell assignments (N,)

    # === Unit cell and symmetry ===
    cell: Optional[Cell] = None  # Cell object with [a, b, c, alpha, beta, gamma]
    spacegroup: Optional[str] = None  # Space group name string

    # === Metadata ===
    device: torch.device = field(default_factory=get_default_device)
    verbose: int = 1

    # === Source tracking ===
    rfree_source: Optional[str] = None
    amplitude_source: Optional[str] = None
    intensity_source: Optional[str] = None
    phase_source: Optional[str] = None

    # === Wilson B-factors ===
    wilson_b: Optional[float] = None
    wilson_b_structure: Optional[float] = None
    wilson_b_solvent: Optional[float] = None
    wilson_k_sol: Optional[float] = None

    # === Outlier detection parameters ===
    outlier_detection_params: Optional[Dict[str, Any]] = None

    # === Masks (initialized in __post_init__) ===
    # Note: masks is not a dataclass field to avoid serialization issues
    # It's initialized in __post_init__ and handled specially

    def __post_init__(self):
        """Initialize non-field attributes after dataclass init."""
        # Ensure device is a torch.device object
        if isinstance(self.device, str):
            object.__setattr__(self, "device", torch.device(self.device))

        # Import here to avoid circular imports
        from torchref.utils.utils import TensorMasks

        # Initialize masks as TensorMasks (dict subclass)
        if not hasattr(self, "masks") or self.masks is None:
            self.masks = TensorMasks(device=self.device)

    # ========== DEVICE MANAGEMENT ==========

    def _tensor_fields(self):
        """
        Yield (name, tensor) for all tensor attributes.

        Yields
        ------
        Tuple[str, torch.Tensor]
            Field name and tensor value for each tensor field.

        Note
        ----
        Only dataclass fields whose current value is a ``torch.Tensor`` are
        yielded. Cell objects are excluded; they are handled separately
        in to().
        """
        for f in fields(self):
            val = getattr(self, f.name)
            if isinstance(val, torch.Tensor):
                yield f.name, val

    # ========== SERIALIZATION ==========

    def _get_state(self) -> Dict[str, Any]:
        """
        Get serializable state dictionary.

        Returns
        -------
        Dict[str, Any]
            State dictionary with all tensor and metadata fields. Tensors
            are moved to CPU, the cell is stored as its CPU tensor data, and
            the device and space group are stored as strings. Masks, which
            are not a dataclass field, are stored under the ``"masks"`` key.
        """
        state: Dict[str, Any] = {}
        for f in fields(self):
            val = getattr(self, f.name)
            if isinstance(val, torch.Tensor):
                state[f.name] = val.cpu()
            elif f.name == "cell" and val is not None:
                # Store Cell tensor data for serialization
                state[f.name] = val.data.cpu()
            elif f.name == "device":
                # Store device as string
                state[f.name] = str(val)
            elif f.name == "spacegroup" and val is not None:
                # The field is declared as a name string, which has no
                # .xhm(); store it unchanged so it round-trips exactly.
                # Any other object is still serialized through .xhm()
                # (extended Hermann-Mauguin), as the original code did.
                state[f.name] = val if isinstance(val, str) else val.xhm()
            else:
                state[f.name] = val

        # Handle masks specially
        if hasattr(self, "masks") and self.masks is not None:
            state["masks"] = {k: v.cpu() for k, v in self.masks.items()}

        return state

    @classmethod
    def _from_state(
        cls,
        state: Dict[str, Any],
        device: Optional[Union[str, torch.device]] = None,
    ) -> "CrystalDataset":
        """
        Reconstruct from state dictionary.

        Parameters
        ----------
        state : Dict[str, Any]
            State dictionary from _get_state(). It is not modified.
        device : str or torch.device, optional
            Device to load tensors onto. When omitted, the configured
            default device is looked up at call time.

        Returns
        -------
        CrystalDataset
            Reconstructed dataset.
        """
        from torchref.utils.utils import TensorMasks

        # Resolve the default per call; a default of get_default_device() in
        # the signature would be frozen when the class body is executed.
        if device is None:
            device = get_default_device()

        # Work on a shallow copy so the caller's dictionary stays intact.
        state = dict(state)

        # The class marker written by save_state() is not a dataclass field.
        state.pop("__class__", None)

        # Extract masks before creating object
        masks_state = state.pop("masks", {})

        # Convert device string back to torch.device
        if "device" in state:
            state["device"] = torch.device(state["device"])

        # Spacegroup is stored as string — keep as-is
        # (no conversion needed since CrystalDataset.spacegroup is now str)

        # Convert cell tensor back to Cell object
        cell_state = state.get("cell")
        if isinstance(cell_state, torch.Tensor):
            state["cell"] = Cell(cell_state, dtype=torch.float32, device=device)

        # Create object with remaining state
        obj = cls(**state)

        # Restore masks
        if masks_state:
            obj.masks = TensorMasks(data=masks_state, device=device)

        return obj.to(device)

    def save_state(self, path: str) -> None:
        """
        Save dataset state to file.

        Parameters
        ----------
        path : str
            Output file path.

        Examples
        --------
        Save to file::

            data.save_state('reflection_data.pt')
        """
        state = self._get_state()
        state["__class__"] = self.__class__.__name__
        torch.save(state, path)
        if self.verbose > 0:
            print(f"Saved {self.__class__.__name__} to {path}")

    @classmethod
    def load_state(
        cls,
        path: str,
        device: Optional[Union[str, torch.device]] = None,
    ) -> "CrystalDataset":
        """
        Load dataset state from file.

        Parameters
        ----------
        path : str
            Input file path.
        device : str or torch.device, optional
            Device to load tensors onto. When omitted, the configured
            default device is looked up at call time.

        Returns
        -------
        CrystalDataset
            Loaded dataset.

        Examples
        --------
        Load from file::

            data = ReflectionData.load_state('reflection_data.pt', device='cuda')
        """
        # NOTE(review): torch.load unpickles the file; only load files from
        # trusted sources.
        state = torch.load(path, map_location="cpu")

        # Remove class marker if present
        state.pop("__class__", None)

        obj = cls._from_state(state, device)
        if obj.verbose > 0:
            print(f"Loaded {cls.__name__} from {path}")
        return obj

    # ========== UTILITY METHODS ==========

    def __len__(self) -> int:
        """Return number of reflections in dataset (0 when hkl is unset)."""
        if self.hkl is not None:
            return len(self.hkl)
        return 0

    def __repr__(self) -> str:
        """String representation of dataset."""
        n_refl = len(self)
        sg = self.spacegroup if self.spacegroup else "unknown"
        return f"{self.__class__.__name__}(n_reflections={n_refl}, spacegroup='{sg}', device={self.device})"

    @property
    def spacegroup_name(self) -> Optional[str]:
        """Get space group name as string (short form, e.g., 'P212121')."""
        if self.spacegroup is None:
            return None
        return gemmi.SpaceGroup(self.spacegroup).short_name()

    @property
    def spacegroup_hm(self) -> Optional[str]:
        """Get space group Hermann-Mauguin name with spaces (e.g., 'P 21 21 21')."""
        if self.spacegroup is None:
            return None
        return gemmi.SpaceGroup(self.spacegroup).hm

    @property
    def spacegroup_number(self) -> Optional[int]:
        """Get space group number (1-230)."""
        if self.spacegroup is None:
            return None
        return gemmi.SpaceGroup(self.spacegroup).number