# Source code for torchref.io.mtz

"""
MTZ file format reading and writing.

This module provides functions for reading and writing MTZ files containing
crystallographic reflection data (structure factor amplitudes, intensities,
R-free flags, etc.).

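The reader returns the space group as a symbol string (gemmi's ``hm``
name); callers wrap it in a space-group object as needed. The writer
accepts a symbol string, a gemmi.SpaceGroup, or a torchref SpaceGroup.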

Functions
---------
read
    Read an MTZ file and return a reader object.
write
    Write reflection data to an MTZ file.

Classes
-------
MTZReader
    Reader class for MTZ files.

Examples
--------
::

    from torchref.io import mtz

    # Reading
    reader = mtz.read('data.mtz', verbose=1)
    data_dict, cell, spacegroup = reader()
    print(spacegroup)  # space-group symbol string (str), not a gemmi.SpaceGroup

    # Writing
    mtz.write(df, cell, spacegroup, 'output.mtz')
"""

from typing import Optional, Tuple, Union

import gemmi
import numpy as np
import pandas as pd
import reciprocalspaceship as rs
import torch


class MTZReader:
    """
    Reader for MTZ files containing crystallographic structure factor data.

    This class reads MTZ files using reciprocalspaceship and extracts:

    - Miller indices (h, k, l)
    - Structure factor amplitudes and/or intensities
    - Associated uncertainties (sigma values)
    - R-free test set flags

    Columns are chosen automatically from ``AMPLITUDE_PRIORITY``,
    ``INTENSITY_PRIORITY`` and ``RFREE_FLAG_NAMES`` unless explicit names
    are supplied through ``column_names``.

    Attributes
    ----------
    verbose : int
        Verbosity level for logging (0=silent, 1=normal, 2=debug).
    column_names : dict
        Explicit column-name overrides; empty when none were given.
    data : dict or None
        Dictionary containing extracted data arrays; ``None`` until a read
        succeeds.
    cell : np.ndarray or None
        Unit cell parameters [a, b, c, alpha, beta, gamma].
    spacegroup : str or None
        Space-group symbol string (the gemmi ``hm`` name of the file's
        space group). Callers wrap it in a space-group object as needed.
    mtz_data : object or None
        Dataset returned by ``rs.read_mtz`` for the last file read.

    Examples
    --------
    ::

        reader = mtz.read('data.mtz', verbose=1)
        data_dict, cell, spacegroup = reader()
        print(f"Found {len(data_dict['HKL'])} reflections in {spacegroup}")
    """

    # Amplitude labels to look for, most preferred first.
    AMPLITUDE_PRIORITY = [
        "F-obs",
        "FOBS",
        "FP",
        "F",
        "F-obs-filtered",
        "FOBS-filtered",
        "F(+)",
        "FPLUS",
        "FMEAN",
        "F-pk",
        "F_pk",
        "FO",
        "FODD",
        "F-model",
        "FC",
        "FCALC",
    ]
    # Intensity labels to look for, most preferred first.
    INTENSITY_PRIORITY = [
        "I-obs",
        "IOBS",
        "I",
        "IMEAN",
        "I-obs-filtered",
        "IOBS-filtered",
        "I(+)",
        "IPLUS",
        "IP",
        "I-pk",
        "I_pk",
        "IHLI",
        "I_full",
        "IOBS_full",
        "IO",
    ]
    # R-free flag labels to look for, most preferred first.
    RFREE_FLAG_NAMES = [
        "R-free-flags",
        "RFREE",
        "FreeR_flag",
        "FREE",
        "R-free",
        "Rfree",
        "FREER",
        "FREE_FLAG",
        "test",
        "TEST",
        "free",
        "Free",
    ]

    def __init__(self, verbose: int = 0, column_names: Optional[dict] = None):
        """
        Initialize MTZ reader.

        Parameters
        ----------
        verbose : int, optional
            Verbosity level (0=silent, 1=normal, 2=debug). Default is 0.
        column_names : dict, optional
            Explicit column name mapping to override automatic detection.
            Supported keys: ``"F"``, ``"SIGF"``, ``"I"``, ``"SIGI"``.
            Example: ``{"F": "DFo", "SIGF": "sig_DFo"}``.
        """
        self.verbose = verbose
        self.column_names = column_names or {}
        self.data = None
        self.cell = None
        self.spacegroup = None
        self.mtz_data = None

    def read(self, filepath: str) -> "MTZReader":
        """
        Read an MTZ file and extract reflection data.

        After a successful call, ``data`` holds:

        - ``"HKL"``: int32 array of Miller indices, one row per reflection.
        - ``"I"`` / ``"SIGI"``: float32 intensities and sigmas when found,
          with their source labels in ``"I_col"`` / ``"SIGI_col"``.
        - ``"F"`` / ``"SIGF"``: float32 amplitudes and sigmas when found,
          with their source labels in ``"F_col"`` / ``"SIGF_col"``.
        - ``"R-free-flags"``: boolean array (``False`` marks the test set)
          with its source label in ``"R-free-source"``, when found.

        Parameters
        ----------
        filepath : str
            Path to the MTZ file.

        Returns
        -------
        MTZReader
            Self, for method chaining.

        Raises
        ------
        ValueError
            If an ``"F"`` or ``"I"`` column named in ``column_names`` is not
            present in the file.
        """
        # Clear any previous result first, so a failed read can never leave a
        # half-filled ``data`` dict for ``__call__`` to hand out.
        self.data = None
        if self.verbose > 1:
            print(f"Reading MTZ file: {filepath}")
        self.mtz_data = rs.read_mtz(filepath)

        uc = self.mtz_data.cell
        self.cell = np.array([uc.a, uc.b, uc.c, uc.alpha, uc.beta, uc.gamma])

        hkl = self.mtz_data.reset_index()[["H", "K", "L"]].to_numpy().astype(np.int32)
        # Store as string (the caller wraps in SpaceGroup as needed)
        self.spacegroup = self.mtz_data.spacegroup.hm

        self.data = {"HKL": hkl}
        try:
            self._extract_amplitudes_and_intensities()
            self._extract_rfree_flags()
        except Exception:
            self.data = None
            raise
        return self

    def __call__(self) -> Tuple[dict, np.ndarray, str]:
        """
        Return extracted data in a standardized format.

        Returns
        -------
        data : dict
            Dictionary with extracted data arrays.
        cell : np.ndarray
            Unit cell parameters [a, b, c, alpha, beta, gamma].
        spacegroup : str
            Space group name string.

        Raises
        ------
        ValueError
            If no file has been read successfully.
        """
        if self.data is None:
            raise ValueError("No data loaded. Call read() first.")
        return self.data, self.cell, self.spacegroup

    def _extract_amplitudes_and_intensities(self) -> None:
        """Extract amplitude and intensity data with priority ordering.

        If ``column_names`` were provided at init, those columns are used
        directly instead of the priority-based search. Both columns are
        resolved before anything is stored, so a missing explicit column
        raises before ``data`` is touched.
        """
        intensity_col = self._select_data_column(
            "I", self.INTENSITY_PRIORITY, ("Intensity", "J"), "Intensity"
        )
        amplitude_col = self._select_data_column(
            "F", self.AMPLITUDE_PRIORITY, ("SFAmplitude", "F"), "Amplitude"
        )
        if intensity_col:
            self._store_with_sigma(intensity_col, "I", "SIGI", is_intensity=True)
        if amplitude_col:
            self._store_with_sigma(amplitude_col, "F", "SIGF", is_intensity=False)

    def _select_data_column(
        self, key: str, priority: list, dtype_markers: tuple, label: str
    ) -> Optional[str]:
        """Return the column to use for ``key`` ("I" or "F"), or None.

        An explicit ``column_names[key]`` must exist in the file (otherwise
        ValueError). Without one, the first label in ``priority`` that is
        present and whose dtype string contains one of ``dtype_markers`` wins.
        """
        available_cols = set(self.mtz_data.columns)
        if key in self.column_names:
            col = self.column_names[key]
            if col not in available_cols:
                raise ValueError(
                    f"{label} column '{col}' not found in MTZ. "
                    f"Available: {sorted(available_cols)}"
                )
            return col
        for col in priority:
            if col in available_cols:
                dtype = str(self.mtz_data.dtypes[col])
                if any(marker in dtype for marker in dtype_markers):
                    return col
        return None

    def _store_with_sigma(
        self, data_col: str, data_key: str, sigma_key: str, is_intensity: bool
    ) -> None:
        """Store ``data_col`` under ``data_key`` plus its sigma under ``sigma_key``.

        An explicit ``column_names[sigma_key]`` is used when present in the
        file; if it is missing a warning is issued and no sigma is stored.
        Otherwise the sigma column is searched for by naming convention.
        """
        available_cols = set(self.mtz_data.columns)
        self.data[data_key] = self.mtz_data[data_col].to_numpy().astype(np.float32)
        self.data[f"{data_key}_col"] = data_col

        if sigma_key in self.column_names:
            sigma_col = self.column_names[sigma_key]
            if sigma_col not in available_cols:
                # The caller asked for this column explicitly; dropping the
                # uncertainties silently would hide a likely label typo.
                import warnings

                warnings.warn(
                    f"{sigma_key} column '{sigma_col}' not found in MTZ; "
                    f"no {sigma_key} values loaded. "
                    f"Available: {sorted(available_cols)}",
                    stacklevel=3,
                )
                return
        else:
            sigma_col = self._find_sigma_column(data_col, is_intensity=is_intensity)

        if sigma_col:
            self.data[sigma_key] = (
                self.mtz_data[sigma_col].to_numpy().astype(np.float32)
            )
            self.data[f"{sigma_key}_col"] = sigma_col

    def _extract_rfree_flags(self) -> None:
        """Extract R-free flags, normalised so that ``False`` marks the test set.

        The first label from ``RFREE_FLAG_NAMES`` whose dtype looks integer-
        or flag-like and whose values convert cleanly is used. Flags are read
        as "0 = free"; when more than half of the reflections are 0 the file
        is taken to use the opposite convention (non-zero = free) and the
        flags are inverted. Missing values (NaN, coerced to -1) end up
        ``True``, i.e. outside the test set. Columns that fail to convert are
        skipped (reported when ``verbose > 0``).
        """
        available_cols = set(self.mtz_data.columns)
        for col in self.RFREE_FLAG_NAMES:
            if col not in available_cols:
                continue
            dtype = str(self.mtz_data.dtypes[col])
            if not ("int" in dtype.lower() or "flag" in dtype.lower() or "I" in dtype):
                continue
            try:
                flags = self.mtz_data[col].to_numpy()
                if flags.dtype == object or not np.issubdtype(flags.dtype, np.integer):
                    # Nullable / non-integer storage: NaN-like entries become -1.
                    flags = pd.to_numeric(flags, errors="coerce")
                    flags = np.nan_to_num(flags, nan=-1).astype(np.int32)
                else:
                    flags = flags.astype(np.int32)
                rfree_flags = np.array(flags, dtype=np.int32)

                n_free = (rfree_flags == 0).sum()
                free_pct = (
                    100.0 * n_free / len(rfree_flags) if len(rfree_flags) > 0 else 0
                )

                # Flip convention if needed: a majority of zeros means zeros
                # are the working set here, so make the non-zero flags 0.
                if free_pct > 50.0:
                    flipped = np.zeros_like(rfree_flags)  # non-zero flags -> 0 (free)
                    flipped[rfree_flags == 0] = 1
                    flipped[rfree_flags < 0] = -1
                    rfree_flags = flipped
                    if self.verbose > 0:
                        n_free = (rfree_flags == 0).sum()
                        free_pct = 100.0 * n_free / len(rfree_flags)
                        print(f" After flip: free={n_free} ({free_pct:.1f}%)")

                self.data["R-free-flags"] = rfree_flags.astype(bool)
                self.data["R-free-source"] = col
                return
            except Exception as e:
                # Best effort: try the next candidate label.
                if self.verbose > 0:
                    print(f"Warning: Could not load R-free flags from {col}: {e}")

    def _find_sigma_column(self, data_col: str, is_intensity: bool) -> Optional[str]:
        """Find the sigma column for a data column.

        Tries labels derived from ``data_col`` (``SIG<col>``, ``SIGM<col>``,
        ``<col>_sigma``, ``<col>-sigma``, and an ``I``/``F`` -> ``SIGI``/``SIGF``
        substitution) followed by common generic sigma labels. Returns the
        first candidate present in the file, or None.
        """
        available_cols = set(self.mtz_data.columns)
        sigma_variants = [
            f"SIG{data_col}",
            f"SIGM{data_col}",
            f"{data_col}_sigma",
            f"{data_col}-sigma",
        ]
        if is_intensity:
            sigma_variants.extend(
                [
                    data_col.replace("I", "SIGI", 1),
                    data_col.replace("I-", "SIGI-"),
                    "SIGI",
                    "SIGIMEAN",
                    "SIGI-obs",
                    "SIGIOBS",
                ]
            )
        else:
            sigma_variants.extend(
                [
                    data_col.replace("F", "SIGF", 1),
                    data_col.replace("F-", "SIGF-"),
                    "SIGF",
                    "SIGFOBS",
                    "SIGF-obs",
                    "SIGFP",
                ]
            )
        for sigma_col in sigma_variants:
            if sigma_col in available_cols:
                dtype = str(self.mtz_data.dtypes[sigma_col])
                if "Stddev" in dtype or "Sigma" in dtype or "SIG" in sigma_col.upper():
                    return sigma_col
        return None
def read(
    filepath: str, verbose: int = 0, column_names: Optional[dict] = None
) -> MTZReader:
    """
    Read an MTZ file.

    Parameters
    ----------
    filepath : str
        Path to the MTZ file.
    verbose : int, optional
        Verbosity level. Default is 0.
    column_names : dict, optional
        Explicit column-name overrides forwarded to :class:`MTZReader`
        (keys ``"F"``, ``"SIGF"``, ``"I"``, ``"SIGI"``). Default ``None``
        keeps automatic column detection.

    Returns
    -------
    MTZReader
        Reader object with data loaded.
    """
    return MTZReader(verbose=verbose, column_names=column_names).read(filepath)
def write(
    df: pd.DataFrame,
    cell: Union[list, np.ndarray, torch.Tensor],
    spacegroup: Union[str, gemmi.SpaceGroup],
    filepath: str,
) -> int:
    """
    Write a DataFrame to an MTZ file.

    Parameters
    ----------
    df : pandas.DataFrame
        Reflection data. If ``H``, ``K`` and ``L`` are columns they are cast
        to the MTZ HKL type and moved into the index; otherwise the existing
        index is used as-is. Columns with known labels (e.g. ``F-obs``,
        ``I-obs``, ``SIGF-obs``, ``PHWT``, ``R-free-flags``) get explicit MTZ
        types; all remaining columns have their MTZ types inferred. ``df``
        itself is not modified.
    cell : list, tuple, numpy.ndarray, torch.Tensor, or gemmi.UnitCell
        Unit cell parameters [a, b, c, alpha, beta, gamma] in Angstrom and
        degrees.
    spacegroup : str, gemmi.SpaceGroup, or torchref SpaceGroup
        Space-group symbol, the ``repr`` of a gemmi.SpaceGroup, or a space
        group object.
    filepath : str
        Output MTZ filename.

    Returns
    -------
    int
        Returns 1 on success.

    Raises
    ------
    ValueError
        If ``spacegroup`` has an unsupported type or an unparsable repr.
    """
    # Normalize the unit cell to gemmi.UnitCell.
    if torch.is_tensor(cell):
        cell = cell.detach().cpu().numpy().tolist()
    elif isinstance(cell, np.ndarray):
        cell = cell.tolist()
    if not isinstance(cell, gemmi.UnitCell):
        cell = gemmi.UnitCell(*cell)

    # Handle spacegroup — normalize to gemmi.SpaceGroup for reciprocalspaceship.
    # Local import; presumably avoids an import cycle with torchref.symmetry.
    from torchref.symmetry import SpaceGroup as TorchRefSpaceGroup

    if isinstance(spacegroup, TorchRefSpaceGroup):
        spacegroup = spacegroup._gemmi
    elif isinstance(spacegroup, gemmi.SpaceGroup):
        pass
    elif isinstance(spacegroup, str):
        if spacegroup.startswith("<gemmi.SpaceGroup"):
            # A stringified gemmi object, e.g. from str()/repr() round-trips.
            import re

            match = re.search(r'SpaceGroup\("([^"]+)"\)', spacegroup)
            if match is None:
                raise ValueError(f"Could not parse spacegroup string: {spacegroup}")
            spacegroup = gemmi.SpaceGroup(match.group(1))
        else:
            spacegroup = gemmi.SpaceGroup(spacegroup)
    else:
        raise ValueError(
            f"Spacegroup must be str, gemmi.SpaceGroup, or torchref SpaceGroup, got {type(spacegroup)}"
        )

    # Work on a copy: building a frame from another frame need not copy its
    # data, and the dtype casts below must not leak into the caller's df.
    mtz_rs = rs.DataSet(df.copy(), cell=cell, spacegroup=spacegroup)

    # Assign MTZ data types
    structure_factor_cols = ["F-obs", "Fobs", "FP", "2FOFCWT", "FOFCWT", "F-model", "FWT", "DELFWT"]
    intensity_cols = ["I-obs", "I"]
    sigma_cols = ["SIGF-obs", "SIGI-obs", "SIGFP", "SIGI"]
    phase_cols = ["PH2FOFCWT", "PHFOFCWT", "PH-model", "PHWT", "PHDELWT"]
    flags = ["R-free-flags", "FreeR_flag", "FREE"]

    if "H" in mtz_rs.columns and "K" in mtz_rs.columns and "L" in mtz_rs.columns:
        for axis in ("H", "K", "L"):
            mtz_rs[axis] = mtz_rs[axis].astype("H")
        # The keys must be passed as ONE list: extra positional arguments
        # would bind to set_index's drop/append parameters instead.
        mtz_rs = mtz_rs.set_index(["H", "K", "L"])

    for mtz_type, labels in (
        ("F", structure_factor_cols),
        ("J", intensity_cols),
        ("Q", sigma_cols),
        ("P", phase_cols),
        ("I", flags),
    ):
        for col in labels:
            if col in mtz_rs.columns:
                mtz_rs[col] = mtz_rs[col].astype(mtz_type)

    # Anything not explicitly typed above gets an inferred MTZ type.
    mtz_rs = mtz_rs.infer_mtz_dtypes()
    mtz_rs.write_mtz(filepath)
    return 1
# Backwards-compatible alias, kept while callers transition from ``MTZ``
# to ``MTZReader``.
MTZ = MTZReader