# Source code for torchref.io.datasets.collection

"""
Dataset collection for handling multiple crystallographic datasets.

This module provides the DatasetCollection class for managing multiple
related ReflectionData objects, useful for joint refinement, MAD phasing,
and time-series crystallography.
"""

from dataclasses import dataclass, field
from typing import Dict, Iterator, List, Optional, Tuple

import torch

from .base import CrystalDataset
from .reflection_data import ReflectionData


@dataclass
class DatasetCollection(CrystalDataset):
    """
    Container for multiple related crystal datasets.

    All datasets share a common HKL set for efficient computation. Datasets
    are aligned using the reference dataset (by default the first one added)
    as a reference, with missing reflections in the other datasets masked out.

    Parameters
    ----------
    verbose : int, optional
        Verbosity level (0=silent, 1=normal, 2=debug). Default is 1.
    device : str, optional
        Device for tensors ('cpu', 'cuda', etc.). Defaults to the configured
        device.current.

    Attributes
    ----------
    hkl : torch.Tensor
        Common HKL set for all datasets.
    n_datasets : int
        Number of datasets in collection.

    Examples
    --------
    ::

        from torchref.io import DatasetCollection, ReflectionData

        collection = DatasetCollection(device='cuda')
        native = ReflectionData().load_mtz('native.mtz')
        derivative = ReflectionData().load_mtz('derivative.mtz')
        collection.add_dataset('native', native, set_as_reference=True)
        collection.add_dataset('derivative', derivative)

        for name, dataset in collection:
            print(f"{name}: {len(dataset)} reflections")

        # Access by name
        native_F = collection['native'].F
    """

    # Collection-specific fields (not inherited from CrystalDataset).
    # Field order and defaults feed the generated __init__; keep them stable.
    _datasets: Dict[str, ReflectionData] = field(default_factory=dict, repr=False)
    _dataset_order: List[str] = field(default_factory=list, repr=False)
    _reference_dataset: Optional[str] = field(default=None, repr=False)
    _common_hkl: Optional[torch.Tensor] = field(default=None, repr=False)
    _cell: Optional[torch.Tensor] = field(default=None, repr=False)
    _spacegroup: Optional[str] = field(default=None, repr=False)
    _resolution: Optional[torch.Tensor] = field(default=None, repr=False)
    _scale_factors: Dict[str, torch.Tensor] = field(default_factory=dict, repr=False)

    def add_dataset(
        self, name: str, dataset: ReflectionData, set_as_reference: bool = False
    ) -> "DatasetCollection":
        """
        Add a dataset to the collection.

        Parameters
        ----------
        name : str
            Identifier for this dataset.
        dataset : ReflectionData
            The dataset to add. It is moved (in place) to the collection's
            device.
        set_as_reference : bool, optional
            If True, use this dataset's HKL, cell and space group as the
            reference. Default is False, but the first dataset added
            automatically becomes the reference.

        Returns
        -------
        DatasetCollection
            Self, for method chaining.

        Raises
        ------
        ValueError
            If a dataset with the same name already exists.

        Examples
        --------
        ::

            collection = DatasetCollection()
            collection.add_dataset('native', native_data, set_as_reference=True)
            collection.add_dataset('derivative', derivative_data)
        """
        if name in self._datasets:
            raise ValueError(f"Dataset '{name}' already exists in collection")

        # Move dataset to same device as collection *first*, so the reference
        # tensors cloned below live on the collection's device and
        # validate_hkl compares tensors on a single device. ``.to`` is a no-op
        # when the dataset is already on this device (handles cuda/mps/cpu
        # uniformly and avoids the ``mps`` vs ``mps:0`` index-mismatch edge
        # case).
        dataset.to(self.device)

        # First dataset or explicit reference becomes the reference.
        becomes_reference = not self._datasets or set_as_reference
        if becomes_reference:
            reference_hkl = dataset.hkl.clone() if dataset.hkl is not None else None
        else:
            reference_hkl = self._common_hkl

        # Validate HKL against the (possibly new) reference before touching
        # any collection state, so a failed validation leaves it unchanged.
        if reference_hkl is not None and dataset.hkl is not None:
            dataset.validate_hkl(reference_hkl)

        if becomes_reference:
            self._reference_dataset = name
            self._common_hkl = reference_hkl
            # Cell and space group always follow the reference; never keep a
            # previous reference's cell next to a new HKL set.
            self._cell = dataset.cell.clone() if dataset.cell is not None else None
            self._spacegroup = dataset.spacegroup
            # Any cached resolution was computed for the old reference.
            self._resolution = None
            # NOTE(review): datasets added before a reference switch were
            # validated against the previous reference's HKL only.

        self._datasets[name] = dataset
        self._dataset_order.append(name)

        if self.verbose > 0:
            print(f"Added dataset '{name}' ({len(dataset)} reflections)")

        return self

    @property
    def hkl(self) -> Optional[torch.Tensor]:
        """Common HKL set for all datasets."""
        return self._common_hkl

    @hkl.setter
    def hkl(self, value: Optional[torch.Tensor]) -> None:
        """Set common HKL (redirects to _common_hkl)."""
        self._common_hkl = value

    @property
    def datasets(self) -> Dict[str, ReflectionData]:
        """Access all datasets as a dictionary."""
        return self._datasets

    @property
    def n_datasets(self) -> int:
        """Number of datasets in collection."""
        return len(self._datasets)

    @property
    def reference_dataset(self) -> Optional[str]:
        """Name of the reference dataset."""
        return self._reference_dataset

    @property
    def spacegroup(self) -> Optional[str]:
        """Space group of the reference dataset."""
        return self._spacegroup

    @spacegroup.setter
    def spacegroup(self, value: Optional[str]) -> None:
        """Set space group (redirects to _spacegroup)."""
        self._spacegroup = value

    def __getitem__(self, name: str) -> ReflectionData:
        """
        Get dataset by name.

        Parameters
        ----------
        name : str
            Name of the dataset.

        Returns
        -------
        ReflectionData
            The requested dataset.

        Raises
        ------
        KeyError
            If dataset name not found.
        """
        return self._datasets[name]

    def __iter__(self) -> Iterator[Tuple[str, ReflectionData]]:
        """
        Iterate over (name, dataset) pairs in order of addition.

        Yields
        ------
        tuple of (str, ReflectionData)
            Name and dataset for each dataset in collection.
        """
        for name in self._dataset_order:
            yield name, self._datasets[name]

    def __len__(self) -> int:
        """Number of reflections in common HKL set (0 if none is set)."""
        return len(self._common_hkl) if self._common_hkl is not None else 0

    def __contains__(self, name: str) -> bool:
        """Check if dataset exists in collection."""
        return name in self._datasets

    def _calculate_resolution(self) -> None:
        """
        Calculate resolution for the common HKL set and store it in
        ``_resolution``.

        Does nothing when either the common HKL or the cell is unset.
        """
        from torchref.base import math_torch

        if self._common_hkl is None or self._cell is None:
            return

        s = math_torch.get_scattering_vectors(self._common_hkl, self._cell)
        # Resolution is the reciprocal of the scattering-vector length.
        self._resolution = 1.0 / torch.linalg.norm(s, dim=1)

    def __call__(self, mask: bool = True) -> Dict[str, Tuple]:
        """
        Return all datasets' data scaled if scale factors are set.

        Parameters
        ----------
        mask : bool, optional
            Whether to apply masking. Default is True.

        Returns
        -------
        dict
            Dictionary mapping name to (hkl, F, F_sigma, rfree) tuples, in
            order of addition.
        """
        return {name: ds(mask=mask, scale=True) for name, ds in self}

    def scale(self, *, n_steps: int = 10, max_iter: int = 100) -> "DatasetCollection":
        """
        Scale all non-reference datasets onto the reference dataset.

        The scaling parameters of every non-reference dataset are optimized
        to minimize the sum of squared differences between its corrected
        structure factors and those of the reference dataset. Only
        reflections that are valid in both datasets contribute. The intent
        is to correct both overall scale differences and anisotropy.

        Optimization uses L-BFGS with a strong Wolfe line search, stepped
        ``n_steps`` times.

        Parameters
        ----------
        n_steps : int, optional
            Number of ``optimizer.step`` calls. Default is 10.
        max_iter : int, optional
            Maximum L-BFGS iterations per step. Default is 100.

        Returns
        -------
        DatasetCollection
            Self, for method chaining.

        Raises
        ------
        ValueError
            If no reference dataset has been set, if the collection contains
            only the reference dataset, or if the datasets to scale expose no
            parameters.

        Notes
        -----
        The reference dataset's own parameters are not optimized; its
        corrected data is treated as a fixed target. Each optimized
        parameter's ``requires_grad`` flag is restored to its previous
        value afterwards, even if optimization raises.
        """
        if self._reference_dataset is None:
            raise ValueError("No reference dataset set for scaling")

        ref_ds = self._datasets[self._reference_dataset]
        to_scale = [ds for name, ds in self if name != self._reference_dataset]
        if not to_scale:
            raise ValueError("No datasets to scale against reference")

        parameters = [p for ds in to_scale for p in ds.parameters()]
        if not parameters:
            raise ValueError("Datasets to scale expose no scaling parameters")

        # The reference is not optimized, so its corrected data and all masks
        # are invariant during optimization: compute them once. no_grad keeps
        # the target out of the autograd graph (and stops gradients piling up
        # in reference parameters that the optimizer never zeroes).
        ref_mask = ref_ds.masks()
        with torch.no_grad():
            # get_corrected_data bypasses MaskedTensor, which lacks autograd.
            ref_F_scaled, _ = ref_ds.get_corrected_data()
        combined_masks = [ds.masks() & ref_mask for ds in to_scale]
        ref_targets = [ref_F_scaled[m] for m in combined_masks]

        previous_flags = [p.requires_grad for p in parameters]
        for p in parameters:
            p.requires_grad_(True)
        try:
            optimizer = torch.optim.LBFGS(
                parameters, max_iter=max_iter, line_search_fn="strong_wolfe"
            )

            def closure():
                optimizer.zero_grad()
                loss = 0.0
                for ds, m, target in zip(to_scale, combined_masks, ref_targets):
                    F_scaled, _ = ds.get_corrected_data()
                    # Loss only on reflections valid in both datasets.
                    loss = loss + torch.sum((F_scaled[m] - target) ** 2)
                loss.backward()
                return loss

            for _ in range(n_steps):
                optimizer.step(closure)
        finally:
            for p, flag in zip(parameters, previous_flags):
                p.requires_grad_(flag)

        return self

    def keys(self) -> List[str]:
        """Return list of dataset names."""
        return list(self._dataset_order)

    def values(self) -> List[ReflectionData]:
        """Return list of datasets."""
        return [self._datasets[name] for name in self._dataset_order]

    def items(self) -> List[Tuple[str, ReflectionData]]:
        """Return list of (name, dataset) tuples."""
        return [(name, self._datasets[name]) for name in self._dataset_order]

    def get(self, name: str, default=None) -> Optional[ReflectionData]:
        """Get dataset by name with default fallback."""
        return self._datasets.get(name, default)

    def __repr__(self) -> str:
        """String representation of collection."""
        n_datasets = self.n_datasets
        n_refl = len(self)
        sg = self.spacegroup or "unknown"
        names = ", ".join(self._dataset_order[:3])
        if n_datasets > 3:
            names += f", ... ({n_datasets} total)"
        return (
            f"DatasetCollection(datasets=[{names}], "
            f"n_reflections={n_refl}, spacegroup='{sg}', device={self.device})"
        )