torchref.io.datasets.base module

Base dataclass for crystallographic datasets.

This module defines the CrystalDataset dataclass, which provides:
  • All possible tensor fields for crystallographic data (all optional)

  • Device management (to, cuda, cpu)

  • Serialization (save_state, load_state)

Space groups are stored as gemmi.SpaceGroup objects for consistency and direct access to symmetry operations.

class torchref.io.datasets.base.CrystalDataset(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None)

Bases: DeviceMixin

Base dataclass for crystallographic datasets.

Defines all possible tensor fields (optional) and handles device management and serialization. Subclasses add domain-specific methods.

Because it is a plain dataclass rather than a torch.nn.Module, the class stays lightweight, which makes it practical to hold thousands of datasets at once.
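CrystalDataset inherits to(), cuda() and cpu() from DeviceMixin, whose implementation is not shown on this page. The following is a minimal sketch of how such a mixin could work, assuming it walks the dataclass fields and moves every torch.Tensor it finds. TinyDataset is a hypothetical stand-in, not part of torchref.

```python
from dataclasses import dataclass, field, fields
from typing import Optional

import torch


@dataclass
class TinyDataset:
    """Hypothetical stand-in for CrystalDataset: optional tensor fields only."""

    hkl: Optional[torch.Tensor] = None
    F: Optional[torch.Tensor] = None
    F_sigma: Optional[torch.Tensor] = None
    device: torch.device = field(default_factory=lambda: torch.device("cpu"))

    def to(self, device):
        """Move every tensor-valued field to `device` in place and return self."""
        self.device = torch.device(device)
        for f in fields(self):
            value = getattr(self, f.name)
            # Unset (None) fields and non-tensor fields such as `device` are skipped.
            if isinstance(value, torch.Tensor):
                setattr(self, f.name, value.to(self.device))
        return self

    def cpu(self):
        return self.to("cpu")

    def cuda(self):
        return self.to("cuda")


data = TinyDataset()
data.hkl = torch.tensor([[1, 0, 0], [0, 1, 0]])
data.F = torch.tensor([10.0, 12.5])
data.cpu()  # exercises the field loop; works on machines without a GPU
print(data.hkl.device, data.F.device)
```

Since the fields are plain attributes, there is no parameter or buffer registration as in torch.nn.Module; moving the object costs roughly one .to() call per populated tensor.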

Parameters:
  • device (torch.device) – Device for tensors (‘cpu’, ‘cuda’, etc.). Defaults to the configured device.current.

  • verbose (int) – Verbosity level (0=silent, 1=normal, 2=debug). Default is 1.

Examples

Basic usage:

import torch
from torchref.io.datasets.base import CrystalDataset

data = CrystalDataset(device='cuda')
data.hkl = torch.tensor([[1, 0, 0], [0, 1, 0]])
data.cpu()  # Move all tensors to CPU
hkl: Tensor | None = None
F: Tensor | None = None
F_sigma: Tensor | None = None
I: Tensor | None = None
I_sigma: Tensor | None = None
rfree_flags: Tensor | None = None
resolution: Tensor | None = None
bin_indices: Tensor | None = None
outlier_flags: Tensor | None = None
phase: Tensor | None = None
fom: Tensor | None = None
E: Tensor | None = None
E_squared: Tensor | None = None
F_squared_corrected: Tensor | None = None
U_aniso: Tensor | None = None
radial_shell_indices: Tensor | None = None
cell: Cell | None = None
spacegroup: str | None = None
device: device
verbose: int = 1
rfree_source: str | None = None
amplitude_source: str | None = None
intensity_source: str | None = None
phase_source: str | None = None
wilson_b: float | None = None
wilson_b_structure: float | None = None
wilson_b_solvent: float | None = None
wilson_k_sol: float | None = None
outlier_detection_params: Dict[str, Any] | None = None
__post_init__()

Initialize non-field attributes after dataclass init.
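This page does not list which non-field attributes CrystalDataset creates. The snippet below only illustrates the standard dataclass mechanism behind __post_init__, using a hypothetical WithCache class with a hypothetical _cache attribute.

```python
from dataclasses import dataclass, fields


@dataclass
class WithCache:
    """Hypothetical example: `verbose` is a dataclass field, `_cache` is not."""

    verbose: int = 1

    def __post_init__(self):
        # Called by the generated __init__ after all fields are assigned.
        # Attributes set here are plain instance attributes: they do not appear
        # in __init__'s signature, the generated __repr__, or dataclasses.fields().
        self._cache = {}


obj = WithCache(verbose=2)
print(obj)                            # WithCache(verbose=2)
print([f.name for f in fields(obj)])  # ['verbose']
```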

save_state(path)

Save dataset state to file.

Parameters:

path (str) – Output file path.

Examples

Save to file:

data.save_state('reflection_data.pt')
classmethod load_state(path, device=device(type='cpu'))

Load dataset state from file.

Parameters:
  • path (str) – Input file path.

  • device (str or torch.device) – Device to load tensors onto. Defaults to CPU.

Returns:

Loaded dataset.

Return type:

CrystalDataset

Examples

Load from file:

data = CrystalDataset.load_state('reflection_data.pt', device='cuda')
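The on-disk format used by save_state/load_state is not documented on this page. A minimal sketch of one plausible round trip, assuming the state is a dict of field values written with torch.save and read back with torch.load(map_location=...) so tensors land on the requested device. TinyDataset is hypothetical, and an in-memory buffer replaces the file path.

```python
import io
from dataclasses import dataclass, fields
from typing import Optional

import torch


@dataclass
class TinyDataset:
    """Hypothetical stand-in for CrystalDataset (two optional tensor fields)."""

    hkl: Optional[torch.Tensor] = None
    F: Optional[torch.Tensor] = None

    def save_state(self, path):
        # Store every field by name; unset (None) fields are kept so loading is symmetric.
        state = {f.name: getattr(self, f.name) for f in fields(self)}
        torch.save(state, path)

    @classmethod
    def load_state(cls, path, device=torch.device("cpu")):
        # map_location places every stored tensor on the requested device.
        state = torch.load(path, map_location=device)
        return cls(**state)


buffer = io.BytesIO()  # in-memory file; a path such as 'reflection_data.pt' also works
original = TinyDataset(
    hkl=torch.tensor([[1, 0, 0], [0, 1, 0]]),
    F=torch.tensor([10.0, 12.5]),
)
original.save_state(buffer)
buffer.seek(0)
restored = TinyDataset.load_state(buffer)
print(torch.equal(restored.hkl, original.hkl))
```

Making load_state a classmethod that calls cls(...) means a subclass calling it gets an instance of that subclass back.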
__len__()

Return number of reflections in dataset.

__repr__()

String representation of dataset.
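The two dunder methods above could plausibly look like this. This is a sketch that assumes the reflection count is the number of rows in the (N, 3) hkl tensor; the real implementation may count differently. TinyDataset is hypothetical.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass(repr=False)
class TinyDataset:
    """Hypothetical stand-in; counts reflections as rows of the (N, 3) hkl tensor."""

    hkl: Optional[torch.Tensor] = None

    def __len__(self):
        # Assumption: one reflection per Miller-index row; zero if hkl is unset.
        return 0 if self.hkl is None else int(self.hkl.shape[0])

    def __repr__(self):
        # A compact summary instead of printing every tensor element.
        return f"{type(self).__name__}(n_reflections={len(self)})"


data = TinyDataset(hkl=torch.tensor([[1, 0, 0], [0, 1, 0], [1, 1, 1]]))
print(len(data))  # 3
print(data)       # TinyDataset(n_reflections=3)
```

A hand-written __repr__ is a reasonable choice here because the dataclass-generated one would print every tensor in full.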

property spacegroup_name: str | None

Get space group name as string (short form, e.g., ‘P212121’).

property spacegroup_hm: str | None

Get space group Hermann-Mauguin name with spaces (e.g., ‘P 21 21 21’).

property spacegroup_number: int | None

Get space group number (1-230).
