# Source code for torchref.io.datasets.fcalc_data

"""
FcalcDataset - Dataset for storing calculated structure factors.

This module provides a lightweight container for calculated structure factors
(Fcalc) with support for:
- Creation from cell/spacegroup/resolution
- Complex Fcalc with amplitude/phase decomposition
- MTZ file export
"""

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import pandas as pd
import torch

from torchref.config import get_default_device, get_float_dtype
from torchref.symmetry import Cell, SpaceGroup, SpaceGroupLike

from .base import CrystalDataset


@dataclass
class FcalcDataset(CrystalDataset):
    """
    Dataset for storing calculated structure factors.

    Provides a lightweight container for Fcalc values with:

    - Cell and spacegroup information (using torchref.symmetry types)
    - HKL indices and resolution
    - Complex Fcalc with amplitude/phase decomposition
    - MTZ export capability

    This class inherits from CrystalDataset and overrides the spacegroup
    field to store torchref.symmetry.SpaceGroup instead of gemmi.SpaceGroup.

    Parameters
    ----------
    hkl : torch.Tensor, optional
        Miller indices of shape (N, 3).
    resolution : torch.Tensor, optional
        Resolution per reflection of shape (N,).
    cell : Cell, optional
        Unit cell object.
    spacegroup : SpaceGroup, optional
        Space group object (torchref.symmetry.SpaceGroup).
    fcalc : torch.Tensor, optional
        Complex structure factors of shape (N,).
    fcalc_amp : torch.Tensor, optional
        Amplitudes |Fcalc| of shape (N,).
    fcalc_phase : torch.Tensor, optional
        Phases in radians of shape (N,).
    device : torch.device
        Device for tensors.

    Examples
    --------
    Create from cell and resolution::

        from torchref.io.datasets import FcalcDataset

        dataset = FcalcDataset.from_cell_and_resolution(
            cell=[50.0, 60.0, 70.0, 90.0, 90.0, 90.0],
            spacegroup='P212121',
            d_min=2.0,
        )

        # Set Fcalc values (complex tensor)
        fcalc = torch.randn(len(dataset), dtype=torch.complex64)
        dataset.set_fcalc(fcalc)

        # Write to MTZ
        dataset.write_mtz('output.mtz')
    """

    # Override spacegroup to use torchref.symmetry.SpaceGroup (not gemmi)
    spacegroup: Optional[SpaceGroup] = None  # type: ignore[assignment]

    # Fcalc-specific fields
    fcalc: Optional[torch.Tensor] = None  # Complex (N,)
    fcalc_amp: Optional[torch.Tensor] = None  # |Fcalc| (N,)
    fcalc_phase: Optional[torch.Tensor] = None  # Phase in radians (N,)

    @classmethod
    def from_cell_and_resolution(
        cls,
        cell: Union[torch.Tensor, List[float], Cell],
        spacegroup: SpaceGroupLike,
        d_min: float = 2.0,
        d_max: Optional[float] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "FcalcDataset":
        """
        Create a dataset with HKL generated to a given resolution.

        Parameters
        ----------
        cell : torch.Tensor, list, or Cell
            Unit cell [a, b, c, alpha, beta, gamma] or Cell object.
        spacegroup : SpaceGroupLike
            Space group (str, int, gemmi.SpaceGroup, or
            torchref.symmetry.SpaceGroup).
        d_min : float, optional
            High resolution limit in Angstroms. Must be positive.
            Default is 2.0.
        d_max : float, optional
            Low resolution limit in Angstroms. If provided, reflections
            with d-spacing > d_max are removed.
        device : torch.device, optional
            Target device. Defaults to ``get_default_device()``, evaluated
            at call time.
        dtype : torch.dtype, optional
            Float dtype for tensors built from a list/tensor cell. Defaults
            to ``get_float_dtype()``, evaluated at call time. Ignored when
            ``cell`` is already a Cell object.

        Returns
        -------
        FcalcDataset
            New dataset (of the class this is called on) with HKL and
            resolution populated.

        Raises
        ------
        ValueError
            If ``d_min`` is not positive.

        Examples
        --------
        ::

            from torchref.symmetry import Cell, SpaceGroup

            cell = Cell([50.0, 60.0, 70.0, 90.0, 90.0, 90.0])
            sg = SpaceGroup('P212121')
            dataset = FcalcDataset.from_cell_and_resolution(
                cell=cell,
                spacegroup=sg,
                d_min=2.0,
            )
            print(f"Generated {len(dataset)} reflections")
        """
        import gemmi

        from torchref.base.reciprocal import get_d_spacing

        if d_min <= 0:
            raise ValueError(f"d_min must be positive, got {d_min}")

        # Resolve defaults here rather than in the signature: signature
        # defaults are evaluated once at import time and would ignore
        # later changes to the torchref configuration.
        if device is None:
            device = get_default_device()
        if dtype is None:
            dtype = get_float_dtype()

        # Normalize the cell into both a Cell object and its parameter tensor.
        if isinstance(cell, Cell):
            cell_obj = cell.to(device=device)
            cell_tensor = cell_obj.data
        else:
            if isinstance(cell, torch.Tensor):
                cell_tensor = cell.to(device=device, dtype=dtype)
            else:
                cell_tensor = torch.tensor(cell, dtype=dtype, device=device)
            cell_obj = Cell(cell_tensor, dtype=dtype, device=device)

        # Normalize any space-group input to torchref.symmetry.SpaceGroup.
        if isinstance(spacegroup, SpaceGroup):
            sg_obj = spacegroup
        else:
            sg_obj = SpaceGroup(spacegroup)

        # Generate HKL indices using gemmi (respects space group symmetry).
        # make_miller_array returns the unique reflections of the
        # asymmetric unit.
        cell_list = cell_tensor.detach().cpu().tolist()
        gemmi_cell = gemmi.UnitCell(*cell_list[:6])
        gemmi_sg = sg_obj._gemmi  # Underlying gemmi.SpaceGroup
        hkl_array = gemmi.make_miller_array(gemmi_cell, gemmi_sg, d_min)
        hkl = torch.tensor(hkl_array, dtype=torch.int32, device=device)

        # Compute d-spacings in the cell's own float dtype; a hard-coded
        # float32 cast could mismatch a float64 cell tensor.
        resolution = get_d_spacing(hkl.to(dtype=cell_tensor.dtype), cell_tensor)

        # Apply low resolution cutoff if requested.
        if d_max is not None:
            mask = resolution <= d_max
            hkl = hkl[mask]
            resolution = resolution[mask]

        print(f"Generated dataset with {len(hkl)} reflections.")

        return cls(
            hkl=hkl,
            resolution=resolution,
            cell=cell_obj,
            spacegroup=sg_obj,
            device=device,
        )

    def set_fcalc(self, fcalc: torch.Tensor) -> None:
        """
        Assign complex Fcalc values.

        Amplitude and phase are computed from the complex values and stored
        alongside them, all on ``self.device``.

        Parameters
        ----------
        fcalc : torch.Tensor
            Complex structure factors with shape (N,).

        Raises
        ------
        ValueError
            If HKL is not set, if fcalc is not 1-D, or if its length
            doesn't match the HKL length.

        Examples
        --------
        ::

            # Create complex Fcalc values
            fcalc = torch.randn(len(dataset), dtype=torch.complex64)
            dataset.set_fcalc(fcalc)

            print(dataset.fcalc_amp[:5])    # Amplitudes
            print(dataset.fcalc_phase[:5])  # Phases in radians
        """
        if self.hkl is None:
            raise ValueError("HKL not set. Cannot assign Fcalc without HKL indices.")
        if fcalc.ndim != 1:
            raise ValueError(
                f"Fcalc must be 1-D with shape (N,), got shape {tuple(fcalc.shape)}"
            )
        if fcalc.shape[0] != len(self.hkl):
            raise ValueError(
                f"Fcalc length {fcalc.shape[0]} != HKL length {len(self.hkl)}"
            )

        # Move once, then derive amplitude/phase on the target device.
        fcalc = fcalc.to(device=self.device)
        self.fcalc = fcalc
        self.fcalc_amp = torch.abs(fcalc)
        self.fcalc_phase = torch.angle(fcalc)

    # ========== EXPORT HELPERS ==========

    @staticmethod
    def _to_numpy(tensor: torch.Tensor):
        """Return ``tensor`` as a NumPy array, detached from autograd and on CPU."""
        # .numpy() raises on tensors that require grad, which is the normal
        # case for Fcalc produced by a differentiable model.
        return tensor.detach().cpu().numpy()

    def _amplitude_and_phase(self) -> tuple:
        """
        Return ``(amplitude, phase_radians)`` for export.

        Stored ``fcalc_amp`` / ``fcalc_phase`` are used when present.
        Otherwise they are derived from ``fcalc`` if that is set. Either
        element may be None when neither source is available.
        """
        amp = self.fcalc_amp
        phase = self.fcalc_phase
        if self.fcalc is not None:
            if amp is None:
                amp = torch.abs(self.fcalc)
            if phase is None:
                phase = torch.angle(self.fcalc)
        return amp, phase

    def _require_mtz_metadata(self) -> None:
        """Raise ValueError if HKL, cell, or spacegroup needed for MTZ output is missing."""
        if self.hkl is None:
            raise ValueError("No HKL indices set.")
        if self.cell is None:
            raise ValueError("No cell set.")
        if self.spacegroup is None:
            raise ValueError("No spacegroup set.")

    def write_mtz(self, filepath: str) -> None:
        """
        Write Fcalc to an MTZ file.

        Columns written are H, K, L, ``F-model`` (amplitude) and
        ``PH-model`` (phase in degrees).

        Parameters
        ----------
        filepath : str
            Output MTZ filename.

        Raises
        ------
        ValueError
            If no Fcalc values have been set, or HKL/cell/spacegroup is
            missing.

        Examples
        --------
        ::

            dataset.set_fcalc(fcalc_values)
            dataset.write_mtz('calculated.mtz')
        """
        from torchref.io import mtz

        if self.fcalc is None:
            raise ValueError("No Fcalc values set. Call set_fcalc() first.")
        self._require_mtz_metadata()

        # fcalc is set, so both are guaranteed non-None here.
        amp, phase = self._amplitude_and_phase()

        hkl_np = self._to_numpy(self.hkl)
        df = pd.DataFrame(
            {
                "H": hkl_np[:, 0],
                "K": hkl_np[:, 1],
                "L": hkl_np[:, 2],
                "F-model": self._to_numpy(amp),
                "PH-model": self._to_numpy(torch.rad2deg(phase)),
            }
        )

        # Write using existing mtz.write() - pass SpaceGroup wrapper
        mtz.write(df, self.cell.data, self.spacegroup, filepath)

    def write_mtz_as_fobs(
        self,
        filepath: str,
        sigma_frac: float = 0.05,
        f_column: str = "F-obs",
        sigf_column: str = "SIGF-obs",
        phase_column: str = "PHIF-model",
    ) -> None:
        """
        Write Fcalc to MTZ as if it were observed data (F-obs columns).

        Useful for creating simulated "experimental" MTZ files that can be
        read back by ReflectionData.load_mtz() as observed amplitudes.

        Parameters
        ----------
        filepath : str
            Output MTZ filename.
        sigma_frac : float, optional
            Sigma as a fraction of |F|. Default is 0.05 (5%).
        f_column : str, optional
            Column name for amplitudes. Default is 'F-obs'.
        sigf_column : str, optional
            Column name for sigma. Default is 'SIGF-obs'.
        phase_column : str, optional
            Column name for model phases (degrees). Default is 'PHIF-model'.
            Written only when phases are available.

        Raises
        ------
        ValueError
            If no amplitudes are available, or HKL/cell/spacegroup is
            missing.

        Examples
        --------
        ::

            dataset.set_fcalc(fcalc_values)
            dataset.write_mtz_as_fobs('simulated_obs.mtz', sigma_frac=0.05)
        """
        from torchref.io import mtz

        amp_t, phase_t = self._amplitude_and_phase()
        if amp_t is None:
            raise ValueError("No Fcalc values set. Call set_fcalc() first.")
        self._require_mtz_metadata()

        amp = self._to_numpy(amp_t)
        sigma = amp * sigma_frac

        hkl_np = self._to_numpy(self.hkl)
        columns = {
            "H": hkl_np[:, 0],
            "K": hkl_np[:, 1],
            "L": hkl_np[:, 2],
            f_column: amp,
            sigf_column: sigma,
        }
        if phase_t is not None:
            columns[phase_column] = self._to_numpy(torch.rad2deg(phase_t))

        df = pd.DataFrame(columns)
        mtz.write(df, self.cell.data, self.spacegroup, filepath)

    # ========== SERIALIZATION OVERRIDES ==========

    def _get_state(self) -> Dict[str, Any]:
        """
        Get serializable state, handling the SpaceGroup wrapper.

        Returns
        -------
        Dict[str, Any]
            State dictionary from the base class, with ``spacegroup``
            replaced by its Hermann-Mauguin string (``SpaceGroup.hm``).
        """
        state = super()._get_state()
        # SpaceGroup is stored as torchref.symmetry.SpaceGroup; serialize as string.
        if self.spacegroup is not None:
            state["spacegroup"] = self.spacegroup.hm
        return state

    @classmethod
    def _from_state(
        cls, state: Dict[str, Any], device: Optional[torch.device] = None
    ) -> "FcalcDataset":
        """
        Reconstruct from state, creating the SpaceGroup wrapper.

        Parameters
        ----------
        state : Dict[str, Any]
            State dictionary from _get_state(). It is not modified.
        device : torch.device or str, optional
            Device to load tensors onto. Defaults to
            ``get_default_device()``, evaluated at call time.

        Returns
        -------
        FcalcDataset
            Reconstructed dataset.
        """
        from torchref.utils.utils import TensorMasks

        if device is None:
            device = get_default_device()

        # Work on a shallow copy so the caller's dict is left untouched.
        state = dict(state)

        # Extract masks before creating object
        masks_state = state.pop("masks", {})

        # Convert device string back to torch.device
        if "device" in state:
            state["device"] = torch.device(state["device"])

        # Convert spacegroup string to SpaceGroup object
        if isinstance(state.get("spacegroup"), str):
            state["spacegroup"] = SpaceGroup(state["spacegroup"])

        # Convert cell tensor back to Cell object
        if isinstance(state.get("cell"), torch.Tensor):
            state["cell"] = Cell(state["cell"], dtype=get_float_dtype(), device=device)

        # Create object with remaining state
        obj = cls(**state)

        # Restore masks
        if masks_state:
            obj.masks = TensorMasks(data=masks_state, device=device)

        return obj.to(device)

    # ========== UTILITY METHODS ==========

    def __repr__(self) -> str:
        """String representation of dataset."""
        n_refl = len(self)
        sg = self.spacegroup.name if self.spacegroup is not None else "unknown"
        has_fcalc = "yes" if self.fcalc is not None else "no"
        return (
            f"{self.__class__.__name__}(n_reflections={n_refl}, "
            f"spacegroup='{sg}', fcalc={has_fcalc}, device={self.device})"
        )

    @property
    def spacegroup_name(self) -> Optional[str]:
        """Get space group name as string (short form, e.g., 'P212121')."""
        if self.spacegroup is None:
            return None
        return self.spacegroup.name

    @property
    def spacegroup_hm(self) -> Optional[str]:
        """Get space group Hermann-Mauguin name with spaces (e.g., 'P 21 21 21')."""
        if self.spacegroup is None:
            return None
        return self.spacegroup.hm

    @property
    def spacegroup_number(self) -> Optional[int]:
        """Get space group number (1-230)."""
        if self.spacegroup is None:
            return None
        return self.spacegroup.number