"""
CIF readers for the different data types used in crystallographic refinement.

This module provides four main classes:
- CIFReader: base class for reading CIF/mmCIF files
- ReflectionCIFReader: reads structure factor data (reflection data)
- ModelCIFReader: reads atomic coordinate data (model structures)
- RestraintCIFReader: reads chemical restraint dictionaries

Space groups are intended to be handled as gemmi.SpaceGroup objects for
consistency throughout torchref (note: ReflectionCIFReader.get_space_group
currently returns a gemmi-validated space-group name string).
The specialized classes are type-safe and should handle most edge cases in
CIF files.
"""
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import gemmi
import numpy as np
import pandas as pd
class CIFReader:
    """
    A dictionary-like reader for CIF/mmCIF files.

    Loops are stored as pandas DataFrames keyed by category name (e.g.
    ``'atom_site'``); DataFrame columns keep the full CIF tag
    (e.g. ``'_atom_site.id'``). Single key-value items are stored in nested
    dictionaries (e.g. ``data['cell']['length_a']``). All parsed values are
    strings.

    Parameters
    ----------
    filepath : str, optional
        Path to CIF file to load immediately.
    data_block : str, optional
        Specific data block name to read (e.g., 'r1vlmsf').
        If None and parse_all_blocks=False, reads the first data block.
        If None and parse_all_blocks=True, reads all data blocks.
    parse_all_blocks : bool, default False
        If True, parse all data blocks and merge them into a single
        dictionary (useful for restraint files). If False, parse only
        the specified block or the first block.
    verbose : int, default 0
        Verbosity level used while parsing (0 = silent).

    Attributes
    ----------
    data : dict
        Dictionary storing parsed CIF data.
    filepath : Path or None
        Path to the loaded CIF file.
    available_blocks : list
        List of data block names found in the most recently parsed content.
    """

    def __init__(
        self,
        filepath: Optional[str] = None,
        data_block: Optional[str] = None,
        parse_all_blocks: bool = False,
        verbose: int = 0,
    ):
        """
        Initialize CIF reader.

        Parameters
        ----------
        filepath : str, optional
            Path to CIF file to load immediately.
        data_block : str, optional
            Specific data block name to read.
        parse_all_blocks : bool, default False
            If True, parse all data blocks and merge.
        verbose : int, default 0
            Verbosity level; set before loading so parse messages honour it.
        """
        self.data = {}
        self.filepath = None
        self.data_block = data_block
        self.parse_all_blocks = parse_all_blocks
        self.available_blocks = []
        self.verbose = verbose
        if filepath:
            self.load(filepath)

    @classmethod
    def from_string(cls, content: str, **kwargs) -> "CIFReader":
        """Create CIFReader from string content instead of a file."""
        reader = cls(**kwargs)
        reader._parse(content)
        return reader

    def load(self, filepath: str):
        """
        Load and parse a CIF file.

        Parameters
        ----------
        filepath : str
            Path to CIF file.
        """
        self.filepath = Path(filepath)
        with open(filepath, "r", encoding="utf-8", errors="ignore") as f:
            content = f.read()
        self._parse(content)

    def _parse(self, content: str):
        """
        Parse CIF file content.

        Handles multiple data blocks. Behavior depends on parse_all_blocks flag:
        - If parse_all_blocks=True: Parse all blocks and merge into single dictionary
        - If data_block is specified: Only parse that specific block
        - Otherwise: Parse only the first block (and remember its name in
          ``self.data_block``)

        Parameters
        ----------
        content : str
            String content of CIF file.

        Raises
        ------
        ValueError
            If a specific block was requested (or selected) but not found.
        """
        lines = content.split("\n")

        # First pass: record all data block names. Reset first so that
        # re-parsing does not accumulate names from earlier content.
        self.available_blocks = []
        for line in lines:
            stripped = line.strip()
            if stripped.startswith("data_"):
                self.available_blocks.append(stripped[5:].strip())

        parse_all = self.parse_all_blocks
        if parse_all:
            if self.verbose > 0 and len(self.available_blocks) > 1:
                print(f"Parsing all {len(self.available_blocks)} data blocks")
        elif self.data_block is None and self.available_blocks:
            # No specific block requested, use first one
            self.data_block = self.available_blocks[0]
            if self.verbose > 0 and len(self.available_blocks) > 1:
                print(f"Multiple data blocks found: {self.available_blocks}")
                print(f"Reading block: {self.data_block}")

        # Second pass: parse the target block(s)
        target_block_found = False
        i = 0
        while i < len(lines):
            line = lines[i].strip()
            # Skip empty lines and comments
            if not line or line.startswith("#"):
                i += 1
                continue

            if line.startswith("data_"):
                block_name = line[5:].strip()
                if not parse_all and self.data_block:
                    if block_name == self.data_block:
                        target_block_found = True
                    elif target_block_found:
                        # Finished reading the target block.
                        break
                i += 1
                continue

            in_scope = parse_all or not self.data_block or target_block_found
            if in_scope:
                if line.startswith("loop_"):
                    i = self._parse_loop(lines, i + 1)
                    continue
                if line.startswith("_"):
                    i = self._parse_keyvalue(lines, i)
                    continue
            i += 1

        # Validate that we found the requested block (if not parsing all)
        if not parse_all and self.data_block and not target_block_found:
            raise ValueError(
                f"Data block '{self.data_block}' not found in CIF file.\n"
                f"Available blocks: {self.available_blocks}"
            )

    def _parse_loop(self, lines: List[str], start_idx: int) -> int:
        """
        Parse a loop structure into a pandas DataFrame.

        The DataFrame is stored under the category of the first column tag.
        A blank line, a new tag, ``loop_`` or ``data_`` ends the loop.

        Parameters
        ----------
        lines : list of str
            All lines of the file.
        start_idx : int
            Starting line index (after 'loop_').

        Returns
        -------
        int
            Index of the next line to process.
        """
        # Collect column names
        columns = []
        i = start_idx
        while i < len(lines):
            line = lines[i].strip()
            if not line or line.startswith("#"):
                i += 1
                continue
            if line.startswith("_"):
                columns.append(line)
                i += 1
            else:
                break
        if not columns:
            return i

        # Collect data rows; values may span lines, so tokens are accumulated
        # until a full row (len(columns) values) is available.
        data_rows = []
        current_row = []
        in_multiline = False
        multiline_value = []
        while i < len(lines):
            line = lines[i]
            stripped = line.strip()
            # Check if we've reached the end of the loop
            if not in_multiline and (
                not stripped
                or stripped.startswith("_")
                or stripped.startswith("loop_")
                or stripped.startswith("data_")
            ):
                if current_row:
                    data_rows.append(current_row)
                break
            # Handle multiline strings (semicolon in column 1 opens a text field)
            if not in_multiline and line.startswith(";"):
                in_multiline = True
                multiline_value = [line[1:]]  # Remove leading semicolon
                i += 1
                continue
            if in_multiline:
                if line.startswith(";"):
                    # End of multiline string
                    in_multiline = False
                    current_row.append("\n".join(multiline_value))
                    multiline_value = []
                    if len(current_row) == len(columns):
                        data_rows.append(current_row)
                        current_row = []
                else:
                    multiline_value.append(line)
                i += 1
                continue
            # Parse regular data line
            if stripped and not stripped.startswith("#"):
                for token in self._tokenize_line(stripped):
                    current_row.append(token)
                    if len(current_row) == len(columns):
                        data_rows.append(current_row)
                        current_row = []
            i += 1

        if data_rows:
            df = pd.DataFrame(data_rows, columns=columns)
            # Extract category from first column name (e.g., _atom_site.id -> atom_site)
            category = self._extract_category(columns[0])
            if category:
                self.data[category] = df
        return i

    def _parse_keyvalue(self, lines: List[str], start_idx: int) -> int:
        """
        Parse a single key-value pair.

        Supports three layouts:
        - ``_tag value`` on one line (surrounding quotes are removed);
        - ``_tag`` followed by a ``;`` text field (any text after the opening
          ``;`` is kept as the first line of the value);
        - ``_tag`` followed by the value alone on the next line.

        Parameters
        ----------
        lines : list of str
            All lines of the file.
        start_idx : int
            Starting line index.

        Returns
        -------
        int
            Index of the next line to process.
        """
        line = lines[start_idx].strip()
        parts = line.split(None, 1)
        key = parts[0]

        # Value on the same line as the tag.
        if len(parts) == 2:
            value = parts[1]
            if value.startswith("'") and value.endswith("'"):
                value = value[1:-1]
            elif value.startswith('"') and value.endswith('"'):
                value = value[1:-1]
            self._store_keyvalue(key, value)
            return start_idx + 1

        # Value on a following line: skip blank and comment lines first.
        n_lines = len(lines)
        j = start_idx + 1
        while j < n_lines and (
            not lines[j].strip() or lines[j].lstrip().startswith("#")
        ):
            j += 1

        if j < n_lines and lines[j].startswith(";"):
            # Text field. Text after the opening ';' belongs to the value;
            # a bare ';' line contributes nothing.
            first = lines[j][1:]
            value_lines = [first] if first.strip() else []
            i = j + 1
            while i < n_lines and not lines[i].startswith(";"):
                value_lines.append(lines[i])
                i += 1
            self._store_keyvalue(key, "\n".join(value_lines))
            return i + 1

        if j < n_lines:
            nxt = lines[j].strip()
            if not nxt.startswith(("_", "loop_", "data_", "save_", "global_")):
                tokens = self._tokenize_line(nxt)
                if tokens:
                    self._store_keyvalue(key, tokens[0])
                    return j + 1

        return start_idx + 1

    def _tokenize_line(self, line: str) -> List[str]:
        """
        Tokenize a data line following CIF quoting rules.

        A quote character only opens a quoted value at the start of a token,
        and only closes it when followed by whitespace or end of line, so
        ``C1'`` and ``'it's'`` are single tokens. ``''`` yields an empty
        token. A ``#`` at the start of a token begins a comment.

        Parameters
        ----------
        line : str
            Line to tokenize.

        Returns
        -------
        list of str
            List of tokens.
        """
        tokens = []
        pos = 0
        length = len(line)
        while pos < length:
            char = line[pos]
            if char.isspace():
                pos += 1
                continue
            if char == "#":
                break
            if char in ("'", '"'):
                end = pos + 1
                while True:
                    end = line.find(char, end)
                    if end == -1:
                        # Unterminated quote: keep the remainder of the line.
                        rest = line[pos + 1 :]
                        if rest:
                            tokens.append(rest)
                        return tokens
                    if end + 1 == length or line[end + 1].isspace():
                        break
                    end += 1
                tokens.append(line[pos + 1 : end])
                pos = end + 1
                continue
            end = pos
            while end < length and not line[end].isspace():
                end += 1
            tokens.append(line[pos:end])
            pos = end
        return tokens

    def _extract_category(self, key: str) -> str:
        """
        Extract category from a CIF key.

        Parameters
        ----------
        key : str
            CIF key (e.g., '_atom_site.id').

        Returns
        -------
        str
            Category name (e.g., 'atom_site').
        """
        if key.startswith("_"):
            key = key[1:]
        if "." in key:
            return key.split(".")[0]
        return key

    def _store_keyvalue(self, key: str, value: str):
        """
        Store a key-value pair in the hierarchical dictionary.

        Keys with a '.' go to ``data[category][attribute]`` (ignored if the
        category already holds a loop DataFrame); others go to ``data[key]``.

        Parameters
        ----------
        key : str
            CIF key (e.g., '_entry.id').
        value : str
            Value to store.
        """
        if key.startswith("_"):
            key = key[1:]
        if "." in key:
            category, attribute = key.split(".", 1)
            if category not in self.data:
                self.data[category] = {}
            if isinstance(self.data[category], dict):
                self.data[category][attribute] = value
        else:
            self.data[key] = value

    @staticmethod
    def _format_value(value: Any) -> str:
        """
        Format a single-line value as a CIF token, quoting when required.

        None/NaN become '?'; empty strings become ''; values containing
        whitespace or starting with a reserved character/word are quoted,
        preferring a quote character that does not occur in the value.
        """
        if value is None:
            return "?"
        if isinstance(value, (float, np.floating)) and np.isnan(value):
            return "?"
        text = str(value)
        if text == "":
            return "''"
        needs_quotes = (
            any(ch.isspace() for ch in text)
            or text[0] in "_#$'\"[];"
            or text.lower().startswith(("data_", "loop_", "save_", "global_", "stop_"))
        )
        if not needs_quotes:
            return text
        if "'" not in text:
            return f"'{text}'"
        if '"' not in text:
            return f'"{text}"'
        return f"'{text}'"

    def _format_loop_row(self, row) -> str:
        """Format one loop row; multi-line strings become ';' text fields."""
        out = ""
        for val in row:
            if isinstance(val, str) and "\n" in val:
                # Text fields must begin with ';' in column 1.
                if out and not out.endswith("\n"):
                    out += "\n"
                out += f";{val}\n;\n"
            else:
                if out and not out.endswith("\n"):
                    out += " "
                out += self._format_value(val)
        if not out.endswith("\n"):
            out += "\n"
        return out

    def _write_item(self, f, tag: str, value: Any):
        """Write one ``tag value`` item, using a text field for multi-line values."""
        if "\n" in str(value):
            f.write(f"{tag}\n")
            f.write(";\n")
            f.write(str(value))
            f.write("\n;\n")
        else:
            f.write(f"{tag} {self._format_value(value)}\n")

    def write(self, filepath: str):
        """
        Write the CIF data back to a file (UTF-8) as block ``data_structure``.

        Parameters
        ----------
        filepath : str
            Output file path.
        """
        with open(filepath, "w", encoding="utf-8") as f:
            f.write("data_structure\n")
            f.write("#\n")
            # Write single key-value pairs first
            for category, content in sorted(self.data.items()):
                if isinstance(content, dict):
                    for key, value in sorted(content.items()):
                        self._write_item(f, f"_{category}.{key}", value)
                    f.write("#\n")
                elif isinstance(content, (str, int, float)):
                    # Tags without a '.' are stored at the top level.
                    self._write_item(f, f"_{category}", content)
                    f.write("#\n")
            # Write loops (DataFrames)
            for category, content in sorted(self.data.items()):
                if isinstance(content, pd.DataFrame):
                    f.write("loop_\n")
                    for col in content.columns:
                        f.write(f"{col}\n")
                    # itertuples keeps per-column dtypes (iterrows upcasts
                    # mixed int/float rows to float).
                    for row in content.itertuples(index=False, name=None):
                        f.write(self._format_loop_row(row))
                    f.write("#\n")

    # Dictionary-like interface
    def __getitem__(self, key: str) -> Union[pd.DataFrame, Dict, Any]:
        """Get item by key."""
        return self.data[key]

    def __setitem__(self, key: str, value: Union[pd.DataFrame, Dict, Any]):
        """Set item by key."""
        self.data[key] = value

    def __contains__(self, key: str) -> bool:
        """Check if key exists."""
        return key in self.data

    def __len__(self) -> int:
        """Return number of top-level categories."""
        return len(self.data)

    def keys(self):
        """Return dictionary keys."""
        return self.data.keys()

    def values(self):
        """Return dictionary values."""
        return self.data.values()

    def items(self):
        """Return dictionary items."""
        return self.data.items()

    def get(self, key: str, default=None):
        """Get item with default value."""
        return self.data.get(key, default)

    def __repr__(self) -> str:
        """String representation."""
        categories = list(self.keys())
        loops = [k for k, v in self.items() if isinstance(v, pd.DataFrame)]
        dicts = [k for k, v in self.items() if isinstance(v, dict)]
        return (
            f"CIFReader(categories={len(categories)}, "
            f"loops={len(loops)}, "
            f"key-value_groups={len(dicts)})"
        )

    def summary(self):
        """Print a summary of the CIF contents."""
        print(f"CIF File: {self.filepath}")
        print(f"Total categories: {len(self.data)}")
        print("\nLoops (DataFrames):")
        for key, value in sorted(self.items()):
            if isinstance(value, pd.DataFrame):
                print(f" {key}: {len(value)} rows × {len(value.columns)} columns")
        print("\nKey-Value Groups (Dictionaries):")
        for key, value in sorted(self.items()):
            if isinstance(value, dict):
                print(f" {key}: {len(value)} items")
class ReflectionCIFReader:
    """
    Reader for structure factor CIF files (e.g., *-sf.cif from PDB).

    Handles extraction of:
    - Miller indices (h, k, l)
    - Structure factor amplitudes (F) and uncertainties (σF)
    - Intensities (I) and uncertainties (σI)
    - Phases and figures of merit
    - R-free flags
    - Unit cell and space group metadata

    Compatible with legacy MTZ reader interface:
        reader = ReflectionCIFReader('7JI4-sf.cif').read()
        data_dict, cell, spacegroup = reader()

    Example:
        reader = ReflectionCIFReader('7JI4-sf.cif')
        refln_data = reader.get_reflection_data()
        h, k, l = refln_data['h'], refln_data['k'], refln_data['l']
        F_obs = refln_data['F_obs']
    """

    # Candidate column names in order of preference. Shared by
    # get_reflection_data() and the has_*() helpers so they always agree.
    _H_COLS = ["_refln.index_h", "_refln.h"]
    _K_COLS = ["_refln.index_k", "_refln.k"]
    _L_COLS = ["_refln.index_l", "_refln.l"]
    _F_COLS = [
        "_refln.F_meas_au",
        "_refln.F_meas",
        "_refln.pdbx_F_plus",
        "_refln.F_calc",
        "_refln.F-obs",
    ]
    _SIGF_COLS = [
        "_refln.F_meas_sigma_au",
        "_refln.F_meas_sigma",
        "_refln.SIGF-obs",
    ]
    # F_squared_* are on the |F|^2 (intensity) scale, so they are treated
    # as intensities rather than amplitudes.
    _I_COLS = [
        "_refln.intensity_meas",
        "_refln.I_meas",
        "_refln.pdbx_I_plus",
        "_refln.I-obs",
        "_refln.pdbx_I",
        "_refln.F_squared_meas",
    ]
    _SIGI_COLS = [
        "_refln.intensity_sigma",
        "_refln.I_sigma",
        "_refln.pdbx_I_plus_sigma",
        "_refln.SIGI-obs",
        "_refln.pdbx_I_sigma",
        "_refln.F_squared_sigma",
    ]
    _PHASE_COLS = ["_refln.phase_meas", "_refln.phase_calc", "_refln.pdbx_PHIB"]
    _FOM_COLS = ["_refln.fom", "_refln.pdbx_FOM"]
    _FREE_COLS = ["_refln.status", "_refln.pdbx_r_free_flag", "_refln.free_flag"]
    # _refln.status code -> R-free flag (1 = work, 0 = test, -1 = unusable).
    # Codes not listed here are also mapped to -1.
    _RFREE_STATUS_CODES = {"f": 0, "x": -1, "o": 1}

    def __init__(
        self, filepath: str, verbose: int = 0, data_block: Optional[str] = None
    ):
        """
        Initialize and load structure factor CIF file.

        Parameters
        ----------
        filepath : str
            Path to structure factor CIF file.
        verbose : int, default 0
            Verbosity level (0=silent, 1=info, 2=debug).
        data_block : str, optional
            Specific data block name to read (e.g., 'r1vlmsf').
            If None, reads the first data block. Useful for files with
            multiple datasets.

        Raises
        ------
        ValueError
            If the selected block has no _refln loop or no unit cell.
        """
        self.filepath = Path(filepath)
        self.verbose = verbose
        # Set verbosity before parsing so block-selection messages emitted
        # while loading respect it.
        self.cif_reader = CIFReader(data_block=data_block)
        self.cif_reader.verbose = verbose
        self.cif_reader.load(filepath)
        self._validate()
        self._extract_data()

    def _validate(self):
        """Validate that this is a structure factor CIF file."""
        if "refln" not in self.cif_reader:
            error_msg = (
                f"File {self.filepath} does not contain reflection data (_refln loop) "
                f"in the selected data block '{self.cif_reader.data_block}'.\n"
                f"Available data blocks in file: {self.cif_reader.available_blocks}\n"
                f"Available categories in selected block: {list(self.cif_reader.keys())}\n\n"
            )
            if len(self.cif_reader.available_blocks) > 1:
                error_msg += (
                    f"This file contains multiple data blocks. Try specifying a different block:\n"
                    f" reader = ReflectionCIFReader('{self.filepath}', data_block='BLOCKNAME')\n"
                    f"where BLOCKNAME is one of: {self.cif_reader.available_blocks}"
                )
            raise ValueError(error_msg)

    @staticmethod
    def _column_label(df: pd.DataFrame, column: str) -> str:
        """Return the (constant) source-label string stored in ``column``."""
        if column not in df.columns or len(df) == 0:
            return "None"
        return str(df[column].iloc[0])

    def _extract_data(self):
        """
        Extract data in legacy MTZ-compatible format.

        Fills ``self.data`` (arrays plus source-column labels as plain
        strings), ``self.cell`` (numpy array) and ``self.spacegroup`` (str).

        Raises
        ------
        ValueError
            If no unit cell parameters are found.
        """
        self.data = {}
        refln_df = self.get_reflection_data()

        # Combined HKL array (like MTZ reader) - this is the primary format
        self.data["HKL"] = refln_df[["h", "k", "l"]].to_numpy().astype(np.int32)
        self.data["HKL_key"] = self._column_label(refln_df, "hkl_key")

        # Amplitudes and intensities with sigmas (standardized MTZ-like keys)
        for value_name, out_key in (("F_obs", "F"), ("I_obs", "I")):
            if not refln_df[value_name].notna().any():
                continue
            self.data[out_key] = refln_df[value_name].to_numpy().astype(np.float32)
            self.data[f"{out_key}_col"] = self._column_label(
                refln_df, f"{value_name}_key"
            )
            sigma_name = f"sigma_{value_name}"
            if refln_df[sigma_name].notna().any():
                self.data[f"SIG{out_key}"] = (
                    refln_df[sigma_name].to_numpy().astype(np.float32)
                )
                self.data[f"SIG{out_key}_col"] = self._column_label(
                    refln_df, f"{sigma_name}_key"
                )

        # R-free flags (standardized keys matching MTZ reader)
        if refln_df["free_flag"].notna().any():
            self._extract_rfree(refln_df)

        cell = self.get_cell_parameters()
        if cell is None:
            raise ValueError(
                f"Unit cell parameters not found in CIF file: {self.filepath}"
            )
        self.cell = np.array(cell)
        self.spacegroup = self.get_space_group()

        if self.verbose > 1:
            print(f"Loaded CIF file: {self.filepath}")
            print(f" Reflections: {len(refln_df)}")
            print(f" Has F: {'F' in self.data}")
            print(f" Has I: {'I' in self.data}")
            print(f" Has R-free: {'R-free-flags' in self.data}")
            print(f" Cell: {self.cell}")
            print(f" Spacegroup: {self.spacegroup}")

    def _extract_rfree(self, refln_df: pd.DataFrame):
        """Translate ``free_flag`` status codes into ``R-free-flags``."""
        codes = refln_df["free_flag"].astype(str).str.lower()
        # Unknown codes (e.g. '<', '-', 'h', 'l', or numeric flags) map to NaN,
        # and casting NaN to int32 yields an arbitrary integer; mark them -1.
        rfree_flags = codes.map(self._RFREE_STATUS_CODES).fillna(-1)
        n_total = len(rfree_flags)
        percentage_work = (rfree_flags == 1).sum() / n_total * 100.0
        percentage_test = (rfree_flags == 0).sum() / n_total * 100.0
        # NOTE(review): the percentages are on a 0-100 scale, but the
        # thresholds below (0.9 and 0.01) look like fractions while the
        # messages say ">90%" / ">1%". Confirm the intended limits before
        # tightening them; kept as-is to avoid discarding deposited flags.
        if percentage_work < 0.9:
            if self.verbose > 0:
                print(
                    f"WARNING: R-free flags indicate only {percentage_work:.2f}% work reflections. Skipping R-free flags. >90% expected. Generating new Rfree flags"
                )
            self.data["R-free-source"] = "None"
        elif percentage_test < 0.01:
            if self.verbose > 0:
                print(
                    f"WARNING: R-free flags indicate only {percentage_test:.2f}% test reflections. Skipping R-free flags. >1% expected. Generating new Rfree flags"
                )
            self.data["R-free-source"] = "None"
        else:
            self.data["R-free-flags"] = rfree_flags.to_numpy().astype(np.int32)
            self.data["R-free-source"] = self._column_label(refln_df, "free_flag_key")

    def read(self, filepath: str = None):
        """
        Read a CIF file (for compatibility with legacy interface).

        Parameters
        ----------
        filepath : str, optional
            Path to CIF file; uses the initialization path if not provided.

        Returns
        -------
        ReflectionCIFReader
            Self, for method chaining.
        """
        if filepath is not None:
            self.__init__(filepath, verbose=self.verbose)
        return self

    def __call__(self) -> Tuple[Dict[str, Any], np.ndarray, str]:
        """
        Get data in legacy MTZ-compatible format.

        Returns
        -------
        data : dict
            Extracted arrays and source labels:
            - 'HKL': (N, 3) int32 Miller indices; 'HKL_key': source columns
            - 'F', 'SIGF' (+ 'F_col', 'SIGF_col'): amplitudes and sigmas (if available)
            - 'I', 'SIGI' (+ 'I_col', 'SIGI_col'): intensities and sigmas (if available)
            - 'R-free-flags', 'R-free-source': R-free test set flags (if available)
        cell : numpy.ndarray
            Cell parameters [a, b, c, alpha, beta, gamma].
        spacegroup : str
            Space group name validated with gemmi ("P 1" if unavailable).

        Raises
        ------
        ValueError
            If data has not been loaded.
        """
        try:
            return self.data, self.cell, self.spacegroup
        except AttributeError as e:
            raise ValueError(
                "Data not loaded. Call read() first or provide filepath in __init__"
            ) from e

    def get_reflection_data(self) -> pd.DataFrame:
        """
        Extract reflection data with standardized column names.

        Returns
        -------
        pandas.DataFrame
            DataFrame with columns:
            - h, k, l: Miller indices (+ hkl_key)
            - F_obs, sigma_F_obs: Observed amplitudes (if available)
            - I_obs, sigma_I_obs: Observed intensities (if available)
            - phase, fom: Phase and figure of merit (if available)
            - free_flag: R-free flags as strings (if available)
            Each ``*_key`` column holds the source CIF column name used.

        Notes
        -----
        Missing columns are filled with NaN. Anomalous pairs (F+/F-, I+/I-)
        are averaged when both columns exist.

        Raises
        ------
        ValueError
            If Miller index columns are missing.
        """
        refln_df = self.cif_reader["refln"].copy()
        result = pd.DataFrame()

        # Miller indices (required)
        result["h"], hkey = self._extract_numeric(
            refln_df, self._H_COLS, required=True, target_type="int"
        )
        result["k"], kkey = self._extract_numeric(
            refln_df, self._K_COLS, required=True, target_type="int"
        )
        result["l"], lkey = self._extract_numeric(
            refln_df, self._L_COLS, required=True, target_type="int"
        )
        result["hkl_key"] = f"{hkey},{kkey},{lkey}"

        f_obs, f_key, sig_f, sig_f_key = self._extract_observation(
            refln_df,
            "F",
            self._F_COLS,
            self._SIGF_COLS,
            "Anomalous data detected: averaging F+ and F-",
        )
        result["F_obs"] = f_obs
        result["F_obs_key"] = f_key
        result["sigma_F_obs"] = sig_f
        result["sigma_F_obs_key"] = sig_f_key

        i_obs, i_key, sig_i, sig_i_key = self._extract_observation(
            refln_df,
            "I",
            self._I_COLS,
            self._SIGI_COLS,
            "Anomalous intensity data detected: averaging I+ and I-",
        )
        result["I_obs"] = i_obs
        result["I_obs_key"] = i_key
        result["sigma_I_obs"] = sig_i
        result["sigma_I_obs_key"] = sig_i_key

        # Phase information
        result["phase"], phase_key = self._extract_numeric(
            refln_df, self._PHASE_COLS, target_type="float"
        )
        result["phase_key"] = phase_key
        result["fom"], fom_key = self._extract_numeric(
            refln_df, self._FOM_COLS, target_type="float"
        )
        result["fom_key"] = fom_key

        # R-free flags (kept as strings)
        result["free_flag"], free_flag_key = self._extract_numeric(
            refln_df, self._FREE_COLS, target_type="None"
        )
        result["free_flag_key"] = free_flag_key
        return result

    @staticmethod
    def _to_float(series: pd.Series) -> pd.Series:
        """Convert a CIF string column to float, treating '?' and '.' as NaN."""
        return pd.to_numeric(series.replace(["?", "."], np.nan), errors="coerce")

    def _extract_observation(
        self,
        df: pd.DataFrame,
        symbol: str,
        value_cols: List[str],
        sigma_cols: List[str],
        header: str,
    ) -> Tuple[pd.Series, str, pd.Series, str]:
        """
        Extract an observation (``symbol`` 'F' or 'I') and its sigma.

        If both ``_refln.pdbx_{symbol}_plus`` and ``..._minus`` exist, the
        pair is averaged (or whichever is present is used). Sigmas are
        combined as sqrt((s+^2 + s-^2) / 4), falling back to a single sigma
        where only one is present. Otherwise the first available column of
        ``value_cols`` / ``sigma_cols`` is used.

        Returns
        -------
        tuple
            (values, value_key, sigmas, sigma_key)
        """
        plus_col = f"_refln.pdbx_{symbol}_plus"
        minus_col = f"_refln.pdbx_{symbol}_minus"
        sig_plus_col = f"_refln.pdbx_{symbol}_plus_sigma"
        sig_minus_col = f"_refln.pdbx_{symbol}_minus_sigma"
        cols = df.columns

        if not (plus_col in cols and minus_col in cols):
            values, value_key = self._extract_numeric(
                df, value_cols, target_type="float"
            )
            sigmas, sigma_key = self._extract_numeric(
                df, sigma_cols, target_type="float"
            )
            return values, value_key, sigmas, sigma_key

        plus = self._to_float(df[plus_col])
        minus = self._to_float(df[minus_col])
        # Average where both are present, otherwise use whichever is available
        values = ((plus + minus) / 2).fillna(plus).fillna(minus)
        value_key = f"{plus_col}+{minus_col}_averaged"

        if sig_plus_col in cols and sig_minus_col in cols:
            sig_plus = self._to_float(df[sig_plus_col])
            sig_minus = self._to_float(df[sig_minus_col])
            # Uncertainty of a two-measurement mean: sqrt((s1^2 + s2^2) / 4)
            combined = np.sqrt((sig_plus**2 + sig_minus**2) / 4)
            sigmas = combined.combine_first(sig_plus).combine_first(sig_minus)
            sigma_key = f"{sig_plus_col}+{sig_minus_col}_averaged"
        else:
            sigmas, sigma_key = self._extract_numeric(
                df, sigma_cols, target_type="float"
            )

        if self.verbose > 0:
            n_both = (plus.notna() & minus.notna()).sum()
            n_plus_only = (plus.notna() & minus.isna()).sum()
            n_minus_only = (plus.isna() & minus.notna()).sum()
            print(header)
            print(f" Reflections with both {symbol}+/{symbol}-: {n_both}")
            print(f" Reflections with {symbol}+ only: {n_plus_only}")
            print(f" Reflections with {symbol}- only: {n_minus_only}")
        return values, value_key, sigmas, sigma_key

    def _extract_numeric(
        self,
        df: pd.DataFrame,
        possible_cols: List[str],
        required: bool = False,
        target_type: str = "float",
    ) -> Tuple[pd.Series, str]:
        """
        Extract data from DataFrame, trying multiple column names.

        Parameters
        ----------
        df : pandas.DataFrame
            Source DataFrame.
        possible_cols : list of str
            List of possible column names to try, in order.
        required : bool, default False
            If True, raise error if no column found.
        target_type : str, default 'float'
            Target data type ('int', 'float', or 'None' for string).

        Returns
        -------
        tuple
            (Series with converted data, column name used) or
            (NaN series aligned to ``df``, 'None') if not found.

        Raises
        ------
        ValueError
            If ``target_type`` is unknown, or ``required`` and no column exists.
        """
        if target_type not in ("int", "float", "None"):
            raise ValueError(f"Unknown target_type: {target_type}")
        for col in possible_cols:
            if col in df.columns:
                try:
                    # Handle '?' and '.' as missing data
                    data = df[col].replace(["?", "."], np.nan)
                    if target_type == "int":
                        return (
                            pd.to_numeric(data, errors="coerce").fillna(0).astype(int),
                            col,
                        )
                    if target_type == "float":
                        return pd.to_numeric(data, errors="coerce"), col
                    return data.astype(str), col
                except Exception:
                    continue
        if required:
            raise ValueError(
                f"Required column not found. Tried: {possible_cols}\n"
                f"Available columns: {list(df.columns)}"
            )
        return pd.Series(np.nan, index=df.index), "None"

    def _has_any_column(self, candidates: List[str]) -> bool:
        """Return True if the refln loop exists and has any of ``candidates``."""
        if "refln" not in self.cif_reader:
            return False
        columns = self.cif_reader["refln"].columns
        return any(col in columns for col in candidates)

    def has_miller_indices(self) -> bool:
        """Check if file contains Miller indices."""
        return self._has_any_column(self._H_COLS)

    def has_amplitudes(self) -> bool:
        """Check if file contains structure factor amplitudes."""
        return self._has_any_column(self._F_COLS)

    def has_intensities(self) -> bool:
        """Check if file contains intensity measurements."""
        return self._has_any_column(self._I_COLS)

    def has_phases(self) -> bool:
        """Check if file contains phase information."""
        return self._has_any_column(self._PHASE_COLS)

    def has_rfree_flags(self) -> bool:
        """Check if file contains R-free flags."""
        return self._has_any_column(self._FREE_COLS)

    def get_miller_indices(self) -> Optional[np.ndarray]:
        """
        Get Miller indices as Nx3 array.

        Returns
        -------
        numpy.ndarray or None
            Array of shape (N, 3) with h, k, l indices.
        """
        data = self.get_reflection_data()
        if data is None or "h" not in data.columns:
            return None
        return data[["h", "k", "l"]].values

    def get_amplitudes(self) -> Optional[Dict[str, np.ndarray]]:
        """
        Get structure factor amplitudes and uncertainties.

        Returns
        -------
        dict or None
            Keys 'F' and 'sigma_F', or None if not available.
        """
        data = self.get_reflection_data()
        if data is None or "F_obs" not in data.columns:
            return None
        if data["F_obs"].isna().all():
            return None
        return {"F": data["F_obs"].values, "sigma_F": data["sigma_F_obs"].values}

    def get_intensities(self) -> Optional[Dict[str, np.ndarray]]:
        """
        Get intensities and uncertainties.

        Returns
        -------
        dict or None
            Keys 'I' and 'sigma_I', or None if not available.
        """
        data = self.get_reflection_data()
        if data is None or "I_obs" not in data.columns:
            return None
        if data["I_obs"].isna().all():
            return None
        return {"I": data["I_obs"].values, "sigma_I": data["sigma_I_obs"].values}

    def get_cell_parameters(self) -> Optional[List[float]]:
        """
        Extract unit cell parameters [a, b, c, alpha, beta, gamma].

        Returns
        -------
        list of float or None
            Six floats, or None if the cell category is missing or unparsable.
        """
        if "cell" not in self.cif_reader:
            return None
        cell_data = self.cif_reader["cell"]
        try:
            a = float(self._get_value(cell_data, ["_cell.length_a", "length_a"], "1.0"))
            b = float(self._get_value(cell_data, ["_cell.length_b", "length_b"], "1.0"))
            c = float(self._get_value(cell_data, ["_cell.length_c", "length_c"], "1.0"))
            alpha = float(
                self._get_value(cell_data, ["_cell.angle_alpha", "angle_alpha"], "90.0")
            )
            beta = float(
                self._get_value(cell_data, ["_cell.angle_beta", "angle_beta"], "90.0")
            )
            gamma = float(
                self._get_value(cell_data, ["_cell.angle_gamma", "angle_gamma"], "90.0")
            )
            return [a, b, c, alpha, beta, gamma]
        except Exception:
            return None

    def get_space_group(self) -> str:
        """
        Extract space group name.

        Looks in ``_symmetry.space_group_name_H-M`` first, then in
        ``_space_group.name_H-M_alt``. The name is validated with gemmi; if it
        only parses with spaces removed, the compact form is returned.

        Returns
        -------
        str
            Space group name string. Returns "P 1" if not found or invalid.
        """
        sg_name = None
        for category, keys in (
            (
                "symmetry",
                [
                    "_symmetry.space_group_name_H-M",
                    "space_group_name_H-M",
                    "_space_group.name_H-M_alt",
                ],
            ),
            ("space_group", ["_space_group.name_H-M_alt", "name_H-M_alt"]),
        ):
            if category not in self.cif_reader:
                continue
            candidate = self._get_value(self.cif_reader[category], keys, None)
            if candidate is not None and str(candidate).strip() not in ("", "?", "."):
                sg_name = str(candidate)
                break
        if sg_name is None:
            sg_name = "P 1"

        # Validate the name by trying to parse it
        try:
            gemmi.SpaceGroup(sg_name)
            return sg_name
        except Exception:
            compact = sg_name.replace(" ", "")
            try:
                gemmi.SpaceGroup(compact)
                return compact
            except Exception:
                return "P 1"

    def _get_value(self, data, possible_keys: List[str], default: Any = None) -> Any:
        """Get value from DataFrame (first row) or dict, trying multiple keys."""
        if isinstance(data, pd.DataFrame):
            for key in possible_keys:
                if key in data.columns and len(data) > 0:
                    return data[key].iloc[0]
        elif isinstance(data, dict):
            for key in possible_keys:
                if key in data:
                    return data[key]
        return default
[docs]
class ModelCIFReader:
"""
Reader for model/structure CIF files (e.g., *.cif from PDB).
Handles extraction of:
- Atomic coordinates and properties
- Alternative conformations
- Anisotropic displacement parameters
- Unit cell and space group
Compatible with legacy PDB reader interface:
reader = ModelCIFReader('3E98.cif').read()
dataframe, cell, spacegroup = reader()
Example:
reader = ModelCIFReader('3E98.cif')
atom_df = reader.get_atom_data()
cell = reader.get_cell_parameters()
"""
def __init__(self, filepath: str, verbose: int = 0):
    """
    Load a model (coordinate) CIF file and extract its contents.

    Parameters
    ----------
    filepath : str
        Path to model CIF file.
    verbose : int, default 0
        Verbosity level (0=silent, 1=info, 2=debug).

    Raises
    ------
    ValueError
        Propagated from ``_validate`` when the file has no ``_atom_site`` loop.
    """
    self.verbose = verbose
    self.filepath = Path(filepath)
    self.cif = CIFReader(filepath)
    # Fail fast on non-model files before attempting any extraction.
    self._validate()
    self._extract_data()
def _validate(self):
"""Validate that this is a model CIF file."""
if "atom_site" not in self.cif.data:
raise ValueError(
f"File {self.filepath} does not contain atomic coordinate data (_atom_site loop).\n"
f"This does not appear to be a model CIF file.\n"
f"Available data blocks: {list(self.cif.data.keys())}"
)
def _extract_data(self):
"""Extract data in legacy PDB-compatible format."""
# Get atom data as DataFrame
self.dataframe = self.get_atom_data()
# Extract cell and spacegroup
cell_params = self.get_cell_parameters()
if cell_params is None:
self.cell = None
else:
self.cell = cell_params
self.spacegroup = self.get_space_group()
# Store as DataFrame attributes (like legacy PDB reader)
self.dataframe.attrs["cell"] = self.cell
self.dataframe.attrs["spacegroup"] = self.spacegroup
self.dataframe.attrs["z"] = None # CIF files typically don't have Z value
if self.verbose > 1:
print(f"Loaded CIF model file: {self.filepath}")
print(f" Atoms: {len(self.dataframe)}")
print(f" Cell: {self.cell}")
print(f" Spacegroup: {self.spacegroup}")
[docs]
def read(self, filepath: str = None):
"""
Read a CIF file (for compatibility with legacy interface).
Parameters
----------
filepath : str, optional
Path to CIF file. Uses initialization path if not provided.
Returns
-------
ModelCIFReader
Self for method chaining.
"""
if filepath is not None:
self.__init__(filepath, verbose=self.verbose)
return self
[docs]
def __call__(self) -> Tuple[pd.DataFrame, List[float], gemmi.SpaceGroup]:
"""
Get data in legacy PDB-compatible format.
Returns
-------
dataframe : pandas.DataFrame
Atom data with columns: ATOM, serial, name, altloc, resname, chainid,
resseq, icode, x, y, z, occupancy, tempfactor, element, charge,
anisou_flag, u11, u22, u33, u12, u13, u23.
cell : list
Cell parameters [a, b, c, alpha, beta, gamma].
spacegroup : gemmi.SpaceGroup
Space group object.
"""
try:
return self.dataframe, self.cell, self.spacegroup
except AttributeError as e:
raise ValueError(
"Data not loaded. Call read() first or provide filepath in __init__"
) from e
[docs]
def get_atom_data(self) -> pd.DataFrame:
"""
Extract atomic coordinate data in PDB-compatible format.
Returns
-------
pandas.DataFrame
DataFrame with columns matching PDB format:
- ATOM, serial, name, altloc, resname, chainid, resseq, icode
- x, y, z, occupancy, tempfactor
- element, charge
- anisou_flag, u11, u22, u33, u12, u13, u23
"""
atom_df = self.cif.data["atom_site"].copy()
result = pd.DataFrame()
# Record type (ATOM or HETATM)
result["ATOM"] = self._extract_string(
atom_df, ["_atom_site.group_PDB"], default="ATOM"
)
# Serial number
result["serial"] = self._extract_int(
atom_df, ["_atom_site.id"], default_range=True
)
# Atom identification
result["name"] = self._extract_string(
atom_df,
["_atom_site.label_atom_id", "_atom_site.auth_atom_id"],
required=True,
)
result["altloc"] = self._extract_string(
atom_df, ["_atom_site.label_alt_id"], default="", replace_dot=True
)
result["resname"] = self._extract_string(
atom_df,
["_atom_site.label_comp_id", "_atom_site.auth_comp_id"],
required=True,
)
result["chainid"] = self._extract_string(
atom_df,
["_atom_site.auth_asym_id", "_atom_site.label_asym_id"],
default="",
replace_dot=True,
)
result["resseq"] = self._extract_int(
atom_df, ["_atom_site.auth_seq_id", "_atom_site.label_seq_id"], default=0
)
result["icode"] = self._extract_string(
atom_df, ["_atom_site.pdbx_PDB_ins_code"], default="", replace_dot=True
)
# Coordinates
result["x"] = self._extract_float(
atom_df, ["_atom_site.Cartn_x"], required=True
)
result["y"] = self._extract_float(
atom_df, ["_atom_site.Cartn_y"], required=True
)
result["z"] = self._extract_float(
atom_df, ["_atom_site.Cartn_z"], required=True
)
# Properties
result["occupancy"] = self._extract_float(
atom_df, ["_atom_site.occupancy"], default=1.0
)
result["tempfactor"] = self._extract_float(
atom_df, ["_atom_site.B_iso_or_equiv"], default=20.0
)
# Element and charge
result["element"] = self._extract_string(
atom_df, ["_atom_site.type_symbol"], required=True
)
result["charge"] = self._extract_int(
atom_df, ["_atom_site.pdbx_formal_charge"], default=0
)
# Model number (for multi-model structures / ensembles)
result["model_num"] = self._extract_int(
atom_df, ["_atom_site.pdbx_PDB_model_num"], default=1
)
# Anisotropic displacement parameters
aniso_cols = [
"_atom_site.aniso_U[1][1]",
"_atom_site.aniso_U[2][2]",
"_atom_site.aniso_U[3][3]",
"_atom_site.aniso_U[1][2]",
"_atom_site.aniso_U[1][3]",
"_atom_site.aniso_U[2][3]",
]
if all(col in atom_df.columns for col in aniso_cols):
result["u11"] = pd.to_numeric(
atom_df["_atom_site.aniso_U[1][1]"], errors="coerce"
)
result["u22"] = pd.to_numeric(
atom_df["_atom_site.aniso_U[2][2]"], errors="coerce"
)
result["u33"] = pd.to_numeric(
atom_df["_atom_site.aniso_U[3][3]"], errors="coerce"
)
result["u12"] = pd.to_numeric(
atom_df["_atom_site.aniso_U[1][2]"], errors="coerce"
)
result["u13"] = pd.to_numeric(
atom_df["_atom_site.aniso_U[1][3]"], errors="coerce"
)
result["u23"] = pd.to_numeric(
atom_df["_atom_site.aniso_U[2][3]"], errors="coerce"
)
result["anisou_flag"] = ~pd.isna(result["u11"])
else:
result["u11"] = np.nan
result["u22"] = np.nan
result["u33"] = np.nan
result["u12"] = np.nan
result["u13"] = np.nan
result["u23"] = np.nan
result["anisou_flag"] = False
# Add index column for compatibility with legacy PDB format
result["index"] = np.arange(len(result), dtype=int)
result["element"] = result["element"].str.strip().str.capitalize()
return result
[docs]
def get_atom_data_by_model(self) -> Dict[int, pd.DataFrame]:
"""
Split atom data by ``pdbx_PDB_model_num``.
For single-model files, returns ``{1: dataframe}``.
For multi-model files, returns one DataFrame per model number.
Returns
-------
dict of int -> pandas.DataFrame
Mapping of model number to atom DataFrame.
"""
df = self.get_atom_data()
if "model_num" not in df.columns:
return {1: df}
return {
int(num): group.reset_index(drop=True)
for num, group in df.groupby("model_num")
}
def _extract_string(
self,
df: pd.DataFrame,
possible_cols: List[str],
required: bool = False,
default: str = "",
replace_dot: bool = False,
) -> pd.Series:
"""Extract string column with fallbacks."""
for col in possible_cols:
if col in df.columns:
data = df[col].fillna(default)
if replace_dot:
data = data.replace([".", "?"], default)
return data
if required:
raise ValueError(f"Required column not found. Tried: {possible_cols}")
return pd.Series([default] * len(df))
def _extract_float(
self,
df: pd.DataFrame,
possible_cols: List[str],
required: bool = False,
default: float = np.nan,
) -> pd.Series:
"""Extract float column with fallbacks."""
for col in possible_cols:
if col in df.columns:
return pd.to_numeric(
df[col].replace(["?", "."], np.nan), errors="coerce"
).fillna(default)
if required:
raise ValueError(f"Required column not found. Tried: {possible_cols}")
return pd.Series([default] * len(df))
def _extract_int(
self,
df: pd.DataFrame,
possible_cols: List[str],
required: bool = False,
default: int = 0,
default_range: bool = False,
) -> pd.Series:
"""Extract integer column with fallbacks."""
for col in possible_cols:
if col in df.columns:
# Replace missing values and convert to numeric
# Use mask to avoid FutureWarning about downcasting in replace
series = df[col].copy()
series = series.mask(series.isin(["?", "."]), default)
return (
pd.to_numeric(series, errors="coerce").fillna(default).astype(int)
)
if required:
raise ValueError(f"Required column not found. Tried: {possible_cols}")
if default_range:
return pd.Series(range(1, len(df) + 1))
return pd.Series([default] * len(df))
[docs]
def get_cell_parameters(self) -> Optional[List[float]]:
"""Extract unit cell parameters [a, b, c, alpha, beta, gamma]."""
if "cell" not in self.cif.data:
return None
cell_data = self.cif.data["cell"]
try:
a = float(
self._get_first_value(cell_data, ["_cell.length_a", "length_a"], "1.0")
)
b = float(
self._get_first_value(cell_data, ["_cell.length_b", "length_b"], "1.0")
)
c = float(
self._get_first_value(cell_data, ["_cell.length_c", "length_c"], "1.0")
)
alpha = float(
self._get_first_value(
cell_data, ["_cell.angle_alpha", "angle_alpha"], "90.0"
)
)
beta = float(
self._get_first_value(
cell_data, ["_cell.angle_beta", "angle_beta"], "90.0"
)
)
gamma = float(
self._get_first_value(
cell_data, ["_cell.angle_gamma", "angle_gamma"], "90.0"
)
)
return [a, b, c, alpha, beta, gamma]
except Exception:
return None
[docs]
def get_space_group(self) -> str:
"""
Extract space group name.
Returns
-------
str
Space group name string. Returns "P 1" if not found.
"""
sg_name = "P 1"
if "symmetry" in self.cif.data:
sym_data = self.cif.data["symmetry"]
sg_name = self._get_first_value(
sym_data,
[
"_symmetry.space_group_name_H-M",
"space_group_name_H-M",
"_space_group.name_H-M_alt",
],
"P 1",
)
# Validate the name by trying to parse it
try:
gemmi.SpaceGroup(sg_name)
return sg_name
except Exception:
try:
gemmi.SpaceGroup(sg_name.replace(" ", ""))
return sg_name.replace(" ", "")
except Exception:
return "P 1"
def _get_first_value(
self, data, possible_keys: List[str], default: Any = None
) -> Any:
"""Get value from DataFrame or dict, trying multiple keys."""
if isinstance(data, pd.DataFrame):
for key in possible_keys:
if key in data.columns and len(data) > 0:
return data[key].iloc[0]
elif isinstance(data, dict):
for key in possible_keys:
if key in data:
return data[key]
return default
# Convenience methods for testing
[docs]
def has_coordinates(self) -> bool:
"""Check if atomic coordinates are available."""
return "atom_site" in self.cif.data
[docs]
def has_cell_parameters(self) -> bool:
"""Check if unit cell parameters are available."""
return "cell" in self.cif.data
[docs]
def has_space_group(self) -> bool:
"""Check if space group information is available."""
return "symmetry" in self.cif.data
[docs]
def has_occupancy(self) -> bool:
"""Check if occupancy data is available."""
if "atom_site" not in self.cif.data:
return False
return "_atom_site.occupancy" in self.cif.data["atom_site"].columns
[docs]
def has_bfactor(self) -> bool:
"""Check if B-factor/temperature factor data is available."""
if "atom_site" not in self.cif.data:
return False
return "_atom_site.B_iso_or_equiv" in self.cif.data["atom_site"].columns
[docs]
def has_anisotropic_data(self) -> bool:
"""Check if anisotropic displacement parameters are available."""
if "atom_site" not in self.cif.data:
return False
aniso_cols = [
"_atom_site.aniso_U[1][1]",
"_atom_site.aniso_U[2][2]",
"_atom_site.aniso_U[3][3]",
]
return all(col in self.cif.data["atom_site"].columns for col in aniso_cols)
[docs]
def get_coordinates(self) -> Optional[np.ndarray]:
"""
Extract atomic coordinates as numpy array.
Returns
-------
numpy.ndarray or None
Nx3 array of [x, y, z] coordinates, or None if not available.
"""
if not self.has_coordinates():
return None
atom_data = self.get_atom_data()
return atom_data[["x", "y", "z"]].values
[docs]
def get_atom_info(self) -> pd.DataFrame:
"""
Extract atom information (without coordinates).
Returns
-------
pandas.DataFrame
DataFrame with atom names, residue info, elements, etc.
"""
atom_data = self.get_atom_data()
return atom_data[
[
"serial",
"name",
"altloc",
"resname",
"chainid",
"resseq",
"icode",
"element",
"charge",
]
]
[docs]
class RestraintCIFReader:
    """
    Reader for chemical restraint dictionary CIF files (e.g., from monomer library).

    Handles extraction of:
    - Bond restraints (ideal lengths and ESDs)
    - Angle restraints
    - Torsion/dihedral restraints
    - Planarity restraints
    - Chirality definitions

    Validates that the file contains proper restraint parameters
    (not just structure definitions).

    Example:
        reader = RestraintCIFReader('external_monomer_library/a/ALA.cif')
        comp_data = reader.get_all_restraints()
        bond_df = comp_data['ALA']['bonds']
    """

    # (result key, preferred table key, fallback table key, standardizer
    # method name). The order fixes the key order of the dicts returned by
    # get_compound_restraints() and get_all_restraints().
    _RESTRAINT_TABLES = (
        ("bonds", "comp_bond", "chem_comp_bond", "_standardize_bonds"),
        ("angles", "comp_angle", "chem_comp_angle", "_standardize_angles"),
        ("torsions", "comp_tor", "chem_comp_tor", "_standardize_torsions"),
        ("planes", "comp_plane_atom", "chem_comp_plane_atom", "_standardize_planes"),
        ("chirals", "comp_chir", "chem_comp_chir", "_standardize_chirals"),
        ("atoms", "comp_atom", "chem_comp_atom", "_standardize_atoms"),
    )

    def __init__(self, filepath: str):
        """
        Initialize and load restraint CIF file.

        Parameters
        ----------
        filepath : str
            Path to restraint dictionary CIF file.

        Raises
        ------
        ValueError
            If the file has no bond table or lacks ideal bond parameters.
        """
        self.filepath = Path(filepath)
        # Use parse_all_blocks=True because restraint files often have multiple blocks
        # (e.g., data_comp_list and data_comp_PRO)
        self.cif = CIFReader(filepath, parse_all_blocks=True)
        self.compounds = self._extract_compounds()
        self._validate()

    @staticmethod
    def _first_column_values(df: pd.DataFrame, cols: List[str]) -> list:
        """Return the values of the first column of ``cols`` present in ``df``, else []."""
        for col in cols:
            if col in df.columns:
                return df[col].tolist()
        return []

    def _extract_compounds(self) -> List[str]:
        """
        Extract list of compound IDs from the file.

        Sources are tried in order: ``comp_list`` (monomer library format),
        ``chem_comp`` (eLBOW/phenix format), then a bare ``comp`` table, of
        which only the first ID is used.

        Returns
        -------
        list of str
            List of compound IDs (e.g., ['ALA'], ['2BA']); may be empty.
        """
        data = self.cif.data
        compounds: List[str] = []
        if "comp_list" in data:
            compounds = self._first_column_values(
                data["comp_list"], ["id", "_chem_comp.id"]
            )
        if not compounds and "chem_comp" in data:
            compounds = self._first_column_values(
                data["chem_comp"], ["_chem_comp.id", "id"]
            )
        if not compounds and "comp" in data:
            compounds = self._first_column_values(
                data["comp"], ["id", "_chem_comp.id"]
            )[:1]
        return compounds

    @staticmethod
    def _has_column(df: pd.DataFrame, item: str) -> bool:
        """True if ``df`` has a column named ``item`` or ``<category>.<item>``."""
        suffix = "." + item
        return any(str(c) == item or str(c).endswith(suffix) for c in df.columns)

    def _validate(self):
        """
        Validate that this is a proper restraint file with geometry parameters.

        Side effect: when no compound IDs were found, ``self.compounds`` is set
        to ``[filepath.stem]``.

        Raises
        ------
        ValueError
            If there is no bond table, or it lacks ``value_dist`` /
            ``value_dist_esd``.
        """
        if not self.compounds:
            # Try to infer compound ID from filename
            comp_id = self.filepath.stem
            if comp_id:
                self.compounds = [comp_id]
        # Try both naming conventions: comp_bond and chem_comp_bond
        if not self._has_category("comp_bond", "chem_comp_bond"):
            raise ValueError(
                f"File {self.filepath} does not contain bond restraint data (_chem_comp_bond).\n"
                f"Available data blocks: {list(self.cif.data.keys())}\n\n"
                f"This is not a valid restraint dictionary file."
            )
        bond_df = self._raw_table("comp_bond", "chem_comp_bond")
        required_cols = ["value_dist", "value_dist_esd"]
        # Match whole item names (bare or category-prefixed); a substring test
        # would let "value_dist_esd" satisfy the "value_dist" requirement.
        missing_cols = [
            col for col in required_cols if not self._has_column(bond_df, col)
        ]
        if missing_cols:
            raise ValueError(
                f"Restraint file {self.filepath} is missing required bond parameters.\n"
                f"Missing columns: {missing_cols}\n"
                f"Available columns: {list(bond_df.columns)}\n\n"
                f"This appears to be a structure definition file (from PDB) rather than\n"
                f"a proper restraint dictionary. Restraint files must include ideal\n"
                f"geometry parameters such as 'value_dist' and 'value_dist_esd'.\n\n"
                f"Solution: Use files from the CCP4 Monomer Library\n"
                f"which contain proper restraint parameters."
            )

    def _has_category(self, preferred: str, fallback: str) -> bool:
        """True if either table key is present in the parsed CIF."""
        return preferred in self.cif.data or fallback in self.cif.data

    def _raw_table(self, preferred: str, fallback: str) -> pd.DataFrame:
        """Return the table under ``preferred``, else ``fallback``, else an empty DataFrame."""
        data = self.cif.data
        if preferred in data:
            return data[preferred]
        return data.get(fallback, pd.DataFrame())

    def _collect_restraints(self, comp_id: Optional[str]) -> Dict[str, pd.DataFrame]:
        """
        Build the standardized restraint dict.

        Rows are filtered to ``comp_id`` via _filter_by_comp(); with
        ``comp_id=None`` all rows are kept.
        """
        restraints = {}
        for key, preferred, fallback, standardizer in self._RESTRAINT_TABLES:
            raw = self._raw_table(preferred, fallback)
            if comp_id is not None:
                raw = self._filter_by_comp(raw, comp_id)
            restraints[key] = getattr(self, standardizer)(raw)
        return restraints

    def get_all_restraints(self) -> Dict[str, Dict[str, pd.DataFrame]]:
        """
        Extract all restraint data for all compounds with standardized column names.

        Returns
        -------
        dict
            Dictionary mapping compound ID to dict of restraint types::

                {
                    'ALA': {
                        'bonds': DataFrame(atom1, atom2, value, sigma),
                        'angles': DataFrame(atom1, atom2, atom3, value, sigma),
                        'torsions': DataFrame(atom1, atom2, atom3, atom4, value, sigma, periodicity),
                        'planes': DataFrame(atom, plane_id, sigma),
                        'chirals': DataFrame(atom_centre, atom1, atom2, atom3, volume_sign),
                        'atoms': DataFrame(atom_id, type_symbol, charge, ...)
                    },
                    ...
                }

            If no compound IDs are known, a single entry keyed by the file
            stem holds the unfiltered tables.
        """
        result = {
            comp_id: self.get_compound_restraints(comp_id)
            for comp_id in self.compounds
        }
        if not result:
            result[self.filepath.stem] = self._collect_restraints(None)
        return result

    def get_compound_restraints(self, comp_id: str) -> Dict[str, pd.DataFrame]:
        """
        Extract restraints for a specific compound with standardized column names.

        Parameters
        ----------
        comp_id : str
            Compound identifier (e.g., 'ALA').

        Returns
        -------
        dict
            Dictionary of restraint DataFrames with standardized columns::

                {
                    'bonds': DataFrame(atom1, atom2, value, sigma)
                    'angles': DataFrame(atom1, atom2, atom3, value, sigma)
                    'torsions': DataFrame(atom1, atom2, atom3, atom4, value, sigma, periodicity)
                    'planes': DataFrame(atom, plane_id, sigma)
                    'chirals': DataFrame(atom_centre, atom1, atom2, atom3, volume_sign)
                    'atoms': DataFrame(atom_id, type_symbol, charge, etc.)
                }
        """
        return self._collect_restraints(comp_id)

    def _standardize_bonds(self, df: pd.DataFrame) -> pd.DataFrame:
        """Standardize bond restraint columns to: atom1, atom2, value, sigma."""
        if df.empty:
            return pd.DataFrame(columns=["atom1", "atom2", "value", "sigma"])
        result = pd.DataFrame()
        result["atom1"] = self._extract_col(
            df, ["atom_id_1", "_chem_comp_bond.atom_id_1", "atom1"]
        )
        result["atom2"] = self._extract_col(
            df, ["atom_id_2", "_chem_comp_bond.atom_id_2", "atom2"]
        )
        result["value"] = pd.to_numeric(
            self._extract_col(
                df, ["value_dist", "_chem_comp_bond.value_dist", "value"]
            ),
            errors="coerce",
        )
        result["sigma"] = pd.to_numeric(
            self._extract_col(
                df, ["value_dist_esd", "_chem_comp_bond.value_dist_esd", "sigma", "esd"]
            ),
            errors="coerce",
        )
        return result

    def _standardize_angles(self, df: pd.DataFrame) -> pd.DataFrame:
        """Standardize angle restraint columns to: atom1, atom2, atom3, value, sigma."""
        if df.empty:
            return pd.DataFrame(columns=["atom1", "atom2", "atom3", "value", "sigma"])
        result = pd.DataFrame()
        result["atom1"] = self._extract_col(
            df, ["atom_id_1", "_chem_comp_angle.atom_id_1", "atom1"]
        )
        result["atom2"] = self._extract_col(
            df, ["atom_id_2", "_chem_comp_angle.atom_id_2", "atom2"]
        )
        result["atom3"] = self._extract_col(
            df, ["atom_id_3", "_chem_comp_angle.atom_id_3", "atom3"]
        )
        result["value"] = pd.to_numeric(
            self._extract_col(
                df, ["value_angle", "_chem_comp_angle.value_angle", "value"]
            ),
            errors="coerce",
        )
        result["sigma"] = pd.to_numeric(
            self._extract_col(
                df,
                ["value_angle_esd", "_chem_comp_angle.value_angle_esd", "sigma", "esd"],
            ),
            errors="coerce",
        )
        return result

    def _standardize_torsions(self, df: pd.DataFrame) -> pd.DataFrame:
        """Standardize torsion restraint columns to: atom1, atom2, atom3, atom4, value, sigma, periodicity."""
        if df.empty:
            return pd.DataFrame(
                columns=[
                    "atom1",
                    "atom2",
                    "atom3",
                    "atom4",
                    "value",
                    "sigma",
                    "periodicity",
                ]
            )
        result = pd.DataFrame()
        for n in range(1, 5):
            result[f"atom{n}"] = self._extract_col(
                df, [f"atom_id_{n}", f"_chem_comp_tor.atom_id_{n}", f"atom{n}"]
            )
        result["value"] = pd.to_numeric(
            self._extract_col(
                df, ["value_angle", "_chem_comp_tor.value_angle", "value"]
            ),
            errors="coerce",
        )
        result["sigma"] = pd.to_numeric(
            self._extract_col(
                df,
                ["value_angle_esd", "_chem_comp_tor.value_angle_esd", "sigma", "esd"],
            ),
            errors="coerce",
        )
        result["periodicity"] = pd.to_numeric(
            self._extract_col(df, ["period", "_chem_comp_tor.period", "periodicity"]),
            errors="coerce",
        )
        return result

    def _standardize_planes(self, df: pd.DataFrame) -> pd.DataFrame:
        """Standardize plane restraint columns to: atom, plane_id, sigma."""
        if df.empty:
            return pd.DataFrame(columns=["atom", "plane_id", "sigma"])
        result = pd.DataFrame()
        result["atom"] = self._extract_col(
            df, ["atom_id", "_chem_comp_plane_atom.atom_id", "atom"]
        )
        result["plane_id"] = self._extract_col(
            df, ["plane_id", "_chem_comp_plane_atom.plane_id", "id"]
        )
        # Extract sigma (dist_esd) and convert to numeric
        sigma = pd.to_numeric(
            self._extract_col(
                df, ["dist_esd", "_chem_comp_plane_atom.dist_esd", "sigma"]
            ),
            errors="coerce",
        )
        # Fill missing values with 0.01 Å default, then clip minimum to 0.001 Å
        # (avoid overly tight restraints while allowing looser ones)
        result["sigma"] = sigma.fillna(0.01).clip(lower=0.001)
        return result

    def _standardize_chirals(self, df: pd.DataFrame) -> pd.DataFrame:
        """Standardize chirality columns to: atom_centre, atom1, atom2, atom3, volume_sign."""
        if df.empty:
            return pd.DataFrame(
                columns=["atom_centre", "atom1", "atom2", "atom3", "volume_sign"]
            )
        result = pd.DataFrame()
        result["atom_centre"] = self._extract_col(
            df, ["atom_id_centre", "_chem_comp_chir.atom_id_centre", "atom_centre"]
        )
        for n in range(1, 4):
            result[f"atom{n}"] = self._extract_col(
                df, [f"atom_id_{n}", f"_chem_comp_chir.atom_id_{n}", f"atom{n}"]
            )
        result["volume_sign"] = self._extract_col(
            df, ["volume_sign", "_chem_comp_chir.volume_sign", "sign"]
        )
        return result

    def _standardize_atoms(self, df: pd.DataFrame) -> pd.DataFrame:
        """Standardize atom definition columns to: atom_id, type_symbol, charge, and x/y/z when present."""
        if df.empty:
            return pd.DataFrame(columns=["atom_id", "type_symbol", "charge"])
        result = pd.DataFrame()
        result["atom_id"] = self._extract_col(
            df, ["atom_id", "_chem_comp_atom.atom_id", "id"]
        )
        result["type_symbol"] = self._extract_col(
            df, ["type_symbol", "_chem_comp_atom.type_symbol", "symbol"]
        )
        result["charge"] = pd.to_numeric(
            self._extract_col(
                df, ["charge", "_chem_comp_atom.charge", "partial_charge"]
            ),
            errors="coerce",
        )
        # Include x,y,z if present (for ideal coordinates)
        for coord in ["x", "y", "z"]:
            coord_cols = [
                f"pdbx_model_Cartn_{coord}_ideal",
                f"_chem_comp_atom.pdbx_model_Cartn_{coord}_ideal",
                f"_chem_comp_atom.{coord}",
                coord,
            ]
            if any(col in df.columns for col in coord_cols):
                result[coord] = pd.to_numeric(
                    self._extract_col(df, coord_cols), errors="coerce"
                )
        return result

    def _filter_by_comp(self, df: pd.DataFrame, comp_id: str) -> pd.DataFrame:
        """
        Filter DataFrame to only rows matching the compound ID.

        The first recognised compound-ID column is used; if none exists, all
        rows are assumed to belong to ``comp_id``. The original row index is
        kept, so it need not start at 0.

        Parameters
        ----------
        df : pandas.DataFrame
            Source table.
        comp_id : str
            Compound ID to filter for.

        Returns
        -------
        pandas.DataFrame
            Filtered copy of ``df``.
        """
        if df.empty:
            return df
        # Include all naming conventions: monomer library, eLBOW/phenix, short forms
        id_cols = [
            "comp_id",
            "_chem_comp.id",
            "_chem_comp_bond.comp_id",
            "_chem_comp_angle.comp_id",
            "_chem_comp_tor.comp_id",
            "_chem_comp_atom.comp_id",
            "_chem_comp_plane_atom.comp_id",
            "_chem_comp_chir.comp_id",
            "id",
        ]
        for col in id_cols:
            if col in df.columns:
                return df[df[col] == comp_id].copy()
        return df.copy()

    def get_bond_restraints(self, comp_id: str) -> pd.DataFrame:
        """
        Get bond restraints with standardized column names.

        Parameters
        ----------
        comp_id : str
            Compound identifier.

        Returns
        -------
        pandas.DataFrame
            Columns: atom1, atom2 (atom names), value (ideal bond length, Å),
            sigma (estimated standard deviation, Å).
        """
        return self.get_compound_restraints(comp_id)["bonds"]

    def _extract_col(self, df: pd.DataFrame, possible_cols: List[str]) -> pd.Series:
        """
        Return the first of ``possible_cols`` present in ``df``.

        If none is present, return an all-None Series aligned to ``df.index``.
        Alignment matters because filtered tables keep non-contiguous indices,
        and a RangeIndex placeholder would misalign the real columns into NaN.
        """
        for col in possible_cols:
            if col in df.columns:
                return df[col]
        return pd.Series([None] * len(df), index=df.index, dtype=object)

    # Convenience methods for testing
    def get_compound_id(self) -> str:
        """Get the primary compound ID from this file (falls back to the file stem)."""
        if self.compounds:
            return self.compounds[0]
        return self.filepath.stem

    def has_bond_restraints(self) -> bool:
        """Check if bond restraints are available."""
        return self._has_category("comp_bond", "chem_comp_bond")

    def has_angle_restraints(self) -> bool:
        """Check if angle restraints are available."""
        return self._has_category("comp_angle", "chem_comp_angle")

    def has_torsion_restraints(self) -> bool:
        """Check if torsion restraints are available."""
        return self._has_category("comp_tor", "chem_comp_tor")

    def has_plane_restraints(self) -> bool:
        """Check if plane restraints are available."""
        return self._has_category("comp_plane_atom", "chem_comp_plane_atom")

    def has_chirality_restraints(self) -> bool:
        """Check if chirality definitions are available."""
        return self._has_category("comp_chir", "chem_comp_chir")