Source code for torchref.io.metadata
"""
Unified metadata dictionary for PDB deposition headers and mmCIF categories.
This module provides a ``RefinementMetadata`` dataclass that serves as the
single source of truth for all deposition-related information. The same
metadata can be rendered as PDB REMARK 3 header records or as PDBx/mmCIF
``_refine`` category fields.
Examples
--------
::

    from torchref.io.metadata import RefinementMetadata

    # Build from a completed refinement
    meta = RefinementMetadata.from_refinement(refinement)

    # Render for output
    pdb_header = meta.render_pdb_header()
    cif_dict = meta.render_cif_categories()

    # Merge pass-through headers from input file
    input_meta = RefinementMetadata.from_pdb_file("input.pdb")
    meta = input_meta.merge(meta)

    # Serialization
    d = meta.to_dict()
    meta2 = RefinementMetadata.from_dict(d)
"""
from __future__ import annotations

import json
import logging
from dataclasses import dataclass, field, fields, asdict
from datetime import date
from typing import Any, Dict, List, Optional
[docs]
@dataclass
class RefinementMetadata:
    """Unified metadata for PDB headers and mmCIF categories.

    Fields map to PDB ``REMARK 3`` lines and/or PDBx/mmCIF ``_refine``-family
    category items. Only populated (non-None) fields are rendered as values;
    unpopulated ``REMARK 3`` statistics are written as ``NULL``.

    Parameters
    ----------
    program : str
        Refinement program name.
    program_version : str
        Program version string.
    refinement_method : str
        Free-text method description (e.g. ``"difference-refine"``, ``"LBFGS"``).
    resolution_high, resolution_low : float, optional
        d_min and d_max of the data used, in Angstroms.
    n_reflections_work, n_reflections_test, n_reflections_all : int, optional
        Reflection counts in the work set, the test (free) set, and in total.
    percent_free : float, optional
        Test-set size as a percentage of all reflections.
    r_work, r_free : float, optional
        Working-set and free R-factors.
    b_mean_overall, b_min, b_max : float, optional
        Atomic B-factor statistics (A**2).
    rmsd_bond_lengths : float, optional
        RMS deviation from ideal bond lengths (Angstroms).
    rmsd_bond_angles : float, optional
        RMS deviation from ideal bond angles (degrees).
    n_atoms_total, n_atoms_protein, n_atoms_solvent : int, optional
        Atom counts of the model.
    solvent_model_ksol, solvent_model_bsol : float, optional
        Bulk-solvent model parameters.
    cell : list of float, optional
        Unit cell ``[a, b, c, alpha, beta, gamma]``.
    spacegroup : str, optional
        Space-group symbol.
    title : str
        Entry title (PDB ``TITLE`` / ``_struct.title``).
    authors : list of str
        Author names (PDB ``AUTHOR`` / ``_audit_author.name``).
    passthrough_pdb_remarks : list of str
        Raw ``REMARK`` lines copied from an input PDB file.
    passthrough_cif_categories : dict
        mmCIF categories copied from an input file, keyed by category name.
    custom_remarks : list of str
        Extra free-text lines emitted inside the ``REMARK 3`` section.
    """

    # Program identification
    program: str = "TORCHREF"
    program_version: str = ""
    refinement_method: str = ""  # e.g. "difference-refine", "LBFGS"
    # Resolution
    resolution_high: Optional[float] = None  # d_min in Angstroms
    resolution_low: Optional[float] = None  # d_max in Angstroms
    # Reflection counts
    n_reflections_work: Optional[int] = None
    n_reflections_test: Optional[int] = None
    n_reflections_all: Optional[int] = None
    percent_free: Optional[float] = None
    # R-factors
    r_work: Optional[float] = None
    r_free: Optional[float] = None
    # B-factor statistics
    b_mean_overall: Optional[float] = None
    b_min: Optional[float] = None
    b_max: Optional[float] = None
    # Geometry deviations
    rmsd_bond_lengths: Optional[float] = None  # Angstroms
    rmsd_bond_angles: Optional[float] = None  # degrees
    # Model contents
    n_atoms_total: Optional[int] = None
    n_atoms_protein: Optional[int] = None
    n_atoms_solvent: Optional[int] = None
    # Solvent model
    solvent_model_ksol: Optional[float] = None
    solvent_model_bsol: Optional[float] = None
    # Cell and spacegroup
    cell: Optional[List[float]] = None  # [a, b, c, alpha, beta, gamma]
    spacegroup: Optional[str] = None
    # Free-form fields
    title: str = ""
    authors: List[str] = field(default_factory=list)
    # Pass-through: raw header lines from input file
    passthrough_pdb_remarks: List[str] = field(default_factory=list)
    passthrough_cif_categories: Dict[str, Any] = field(default_factory=dict)
    # Custom remarks
    custom_remarks: List[str] = field(default_factory=list)

    # ------------------------------------------------------------------ #
    # Serialization
    # ------------------------------------------------------------------ #

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-compatible dictionary.

        Unset fields are omitted: ``None`` values, empty lists/dicts, and
        empty strings (except ``program``, which is always kept).
        Container values are returned as-is (not copied).
        """
        out: Dict[str, Any] = {}
        for f in fields(self):
            val = getattr(self, f.name)
            if val is None:
                continue
            if isinstance(val, (list, dict)) and not val:
                continue
            if isinstance(val, str) and not val and f.name != "program":
                continue
            out[f.name] = val
        return out

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "RefinementMetadata":
        """Reconstruct from a dictionary (inverse of ``to_dict``).

        Keys that are not dataclass fields are ignored.
        """
        valid_fields = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in d.items() if k in valid_fields})

    # ------------------------------------------------------------------ #
    # Construction from refinement
    # ------------------------------------------------------------------ #

    @staticmethod
    def _log_skip(what: str) -> None:
        """Log (DEBUG, with traceback) that an optional statistic was skipped.

        Must be called from inside an ``except`` block so ``exc_info`` picks
        up the active exception.
        """
        logging.getLogger(__name__).debug(
            "RefinementMetadata: skipped %s", what, exc_info=True
        )

    @classmethod
    def from_refinement(cls, refinement) -> "RefinementMetadata":
        """Extract metadata from a completed Refinement object.

        Reuses statistics the refinement already exposes: ``get_rfactor()``,
        the reflection data, the model and its restraints, and the scaler's
        solvent model. Each statistic is collected independently. When one
        cannot be obtained, its fields stay ``None`` and the failure is
        logged at DEBUG level rather than raised.

        Parameters
        ----------
        refinement : torchref.refinement.Refinement
            A refinement object (after refinement is complete).

        Returns
        -------
        RefinementMetadata
            Metadata whose ``program_version`` is the installed torchref
            version, with every obtainable statistic filled in.
        """
        import torch

        from torchref import __version__

        meta = cls(program_version=__version__)

        # --- R-factors (from scaler, already computed) ---
        try:
            rwork, rfree = refinement.get_rfactor()
            meta.r_work = float(rwork.item() if hasattr(rwork, "item") else rwork)
            meta.r_free = float(rfree.item() if hasattr(rfree, "item") else rfree)
        except Exception:
            cls._log_skip("R-factors")

        # --- Resolution range from reflection data ---
        try:
            resolution = refinement.reflection_data.resolution
            if resolution is not None:
                meta.resolution_high = float(resolution.min())  # d_min
                meta.resolution_low = float(resolution.max())  # d_max
        except Exception:
            cls._log_skip("resolution range")

        # --- Reflection counts ---
        try:
            with torch.no_grad():
                _, fobs, _, rfree_flags = refinement.reflection_data()
            n_all = len(fobs)
            # In torchref, rfree_flags == True marks the WORK set (same
            # convention as populate_state_meta), so the flag sum is n_work.
            n_work = int(rfree_flags.sum().item())
            n_test = n_all - n_work
            meta.n_reflections_all = n_all
            meta.n_reflections_work = n_work
            meta.n_reflections_test = n_test
            meta.percent_free = 100.0 * n_test / n_all if n_all > 0 else None
        except Exception:
            cls._log_skip("reflection counts")

        # --- B-factor statistics from model ---
        try:
            model = refinement.model
            model.update_pdb()
            bvals = model.pdb["tempfactor"]
            meta.b_mean_overall = float(bvals.mean())
            meta.b_min = float(bvals.min())
            meta.b_max = float(bvals.max())
        except Exception:
            cls._log_skip("B-factor statistics")

        # --- Geometry deviations (skipped when no restraints are set up) ---
        try:
            model = refinement.model
            if model.initialized and model._restraints is not None:
                restraints = model.restraints
                with torch.no_grad():
                    if hasattr(restraints, "bond_deviations"):
                        bond_devs, _ = restraints.bond_deviations()
                        # An empty tensor would give NaN; leave the field unset.
                        if bond_devs.numel() > 0:
                            meta.rmsd_bond_lengths = float(
                                torch.sqrt((bond_devs**2).mean())
                            )
                    if hasattr(restraints, "angle_deviations"):
                        angle_devs, _ = restraints.angle_deviations()
                        if angle_devs.numel() > 0:
                            meta.rmsd_bond_angles = float(
                                torch.sqrt((angle_devs**2).mean())
                            )
        except Exception:
            cls._log_skip("geometry deviations")

        # --- Atom counts ---
        try:
            pdb = refinement.model.pdb
            meta.n_atoms_total = len(pdb)
            # ATOM records are counted as "protein", HETATM records as
            # "solvent". NOTE(review): HETATM also covers ligands/ions, and
            # hydrogens are included in these counts — confirm intent.
            record_types = pdb["ATOM"]
            meta.n_atoms_protein = int((record_types == "ATOM").sum())
            meta.n_atoms_solvent = int((record_types == "HETATM").sum())
        except Exception:
            cls._log_skip("atom counts")

        # --- Solvent model parameters ---
        try:
            sm = getattr(refinement.scaler, "solvent_model", None)
            if sm is not None:
                # k_sol is stored in log space on the solvent model.
                meta.solvent_model_ksol = float(torch.exp(sm.log_k_solvent).item())
                meta.solvent_model_bsol = float(sm.b_solvent.item())
        except Exception:
            cls._log_skip("solvent model parameters")

        # --- Cell and spacegroup ---
        try:
            model = refinement.model
            if model.cell is not None:
                meta.cell = [float(x) for x in model.cell.data.tolist()]
            if model.spacegroup is not None:
                meta.spacegroup = model.spacegroup.hm
        except Exception:
            cls._log_skip("cell and spacegroup")

        return meta

    # ------------------------------------------------------------------ #
    # Construction from input files (pass-through)
    # ------------------------------------------------------------------ #

    @classmethod
    def from_pdb_file(cls, filepath: str) -> "RefinementMetadata":
        """Extract header metadata from an existing PDB file.

        Captures TITLE (continuations joined with spaces), AUTHOR
        (comma-separated names), and all REMARK records for pass-through.
        Reading stops at the first ATOM/HETATM record.

        If the file cannot be read, a warning is logged and whatever was
        parsed so far (possibly nothing) is returned. No exception is raised.
        """
        meta = cls()
        try:
            # PDB files are ASCII; decode as UTF-8 and replace stray bytes
            # instead of depending on the platform's locale encoding.
            with open(filepath, "r", encoding="utf-8", errors="replace") as fh:
                for line in fh:
                    record = line[:6].strip()
                    if record in ("ATOM", "HETATM"):
                        break
                    if record == "TITLE":
                        # Text occupies columns 11-80.
                        text = line[10:].strip()
                        if text:
                            meta.title = f"{meta.title} {text}" if meta.title else text
                    elif record == "AUTHOR":
                        for author in line[10:].split(","):
                            author = author.strip()
                            if author:
                                meta.authors.append(author)
                    elif record == "REMARK":
                        meta.passthrough_pdb_remarks.append(line.rstrip("\r\n"))
        except Exception:
            logging.getLogger(__name__).warning(
                "Could not read PDB header metadata from %s", filepath, exc_info=True
            )
        return meta

    @classmethod
    def from_cif_file(cls, filepath: str) -> "RefinementMetadata":
        """Extract refinement metadata from an existing mmCIF file.

        Captures ``_struct.title``, ``_audit_author.name``, and a fixed set of
        ``_refine`` items (stored under ``passthrough_cif_categories["_refine"]``).
        Only the first data block is read, and CIF null values (``?`` and
        ``.``) are skipped.

        Requires ``gemmi``. On any failure, including gemmi being
        unavailable, a warning is logged and the (possibly empty) metadata
        is returned.
        """
        meta = cls()
        cif_nulls = ("?", ".")
        try:
            import gemmi

            block = gemmi.cif.read(filepath)[0]

            # Title
            title = block.find_value("_struct.title")
            if title and title not in cif_nulls:
                meta.title = gemmi.cif.as_string(title)

            # Authors
            author_table = block.find(["_audit_author.name"])
            if author_table:
                for row in author_table:
                    name = gemmi.cif.as_string(row[0])
                    if name and name not in cif_nulls:
                        meta.authors.append(name)

            # _refine category pass-through (selected items only)
            refine_items = {}
            for item_name in (
                "_refine.ls_R_factor_R_work",
                "_refine.ls_R_factor_R_free",
                "_refine.ls_d_res_high",
                "_refine.ls_d_res_low",
                "_refine.ls_number_reflns_R_work",
                "_refine.ls_number_reflns_R_free",
                "_refine.B_iso_mean",
            ):
                val = block.find_value(item_name)
                if val and val not in cif_nulls:
                    refine_items[item_name] = gemmi.cif.as_string(val)
            if refine_items:
                meta.passthrough_cif_categories["_refine"] = refine_items
        except Exception:
            logging.getLogger(__name__).warning(
                "Could not read mmCIF metadata from %s", filepath, exc_info=True
            )
        return meta

    # ------------------------------------------------------------------ #
    # Merge
    # ------------------------------------------------------------------ #

    def merge(self, other: "RefinementMetadata") -> "RefinementMetadata":
        """Return a new instance combining *self* and *other*.

        Scalar fields take *other*'s value when it is neither ``None``, ``""``
        nor ``[]``; otherwise *self*'s value is kept. Containers are combined
        rather than replaced:

        - ``authors`` / ``passthrough_pdb_remarks``: *self*'s entries followed
          by entries of *other* not already present in *self*.
        - ``passthrough_cif_categories``: category-level union; *other* wins on
          categories present in both.
        - ``custom_remarks``: concatenation without de-duplication.

        Parameters
        ----------
        other : RefinementMetadata
            Metadata to merge in (takes precedence for non-None fields).

        Returns
        -------
        RefinementMetadata
            A new merged instance; neither input is modified.
        """
        merged = RefinementMetadata()
        for f in fields(RefinementMetadata):
            mine = getattr(self, f.name)
            theirs = getattr(other, f.name)
            if f.name in ("passthrough_pdb_remarks", "authors"):
                value = list(mine) + [item for item in theirs if item not in mine]
            elif f.name == "passthrough_cif_categories":
                value = {**mine, **theirs}
            elif f.name == "custom_remarks":
                value = list(mine) + list(theirs)
            elif theirs is not None and theirs != "" and theirs != []:
                # Copy lists (e.g. ``cell``) so merged does not alias *other*.
                value = list(theirs) if isinstance(theirs, list) else theirs
            else:
                value = list(mine) if isinstance(mine, list) else mine
            setattr(merged, f.name, value)
        return merged

    # ------------------------------------------------------------------ #
    # PDB header rendering
    # ------------------------------------------------------------------ #

    @staticmethod
    def _split_passthrough_remarks(remarks: List[str]) -> "tuple[List[str], List[str]]":
        """Partition pass-through REMARK lines around the generated REMARK 3.

        Returns ``(before, after)``:

        - ``before`` holds lines whose remark number is below 3.
        - ``after`` holds everything else: higher numbers, unnumbered lines
          and non-REMARK lines.

        Pass-through ``REMARK 3`` lines are dropped because
        ``render_pdb_header`` always writes a fresh REMARK 3 section, and two
        such sections would conflict.
        """
        before: List[str] = []
        after: List[str] = []
        for line in remarks:
            number = None
            if line.startswith("REMARK"):
                tokens = line[6:].split(None, 1)
                if tokens and tokens[0].isdecimal():
                    number = int(tokens[0])
            if number == 3:
                continue
            if number is not None and number < 3:
                before.append(line)
            else:
                after.append(line)
        return before, after

    def render_pdb_header(self) -> str:
        """Render metadata as PDB header records (TITLE, AUTHOR, REMARK).

        Records are emitted in PDB order:

        1. TITLE, then AUTHOR.
        2. Pass-through REMARKs numbered below 3.
        3. The generated REMARK 3 refinement section, including any custom
           remarks.
        4. The remaining pass-through REMARKs.

        Pass-through REMARK 3 lines are not repeated, since the generated
        section supersedes them.

        Returns
        -------
        str
            Multi-line string (newline-terminated) ready to insert into a
            PDB file.
        """
        lines: List[str] = []

        # TITLE / AUTHOR precede all REMARK records in the PDB format.
        if self.title:
            _wrap_pdb_record(lines, "TITLE", self.title)
        if self.authors:
            _wrap_pdb_record(lines, "AUTHOR", ", ".join(self.authors))

        remarks_before, remarks_after = self._split_passthrough_remarks(
            self.passthrough_pdb_remarks
        )
        lines.extend(remarks_before)

        # REMARK 3 - Refinement statistics
        lines.append("REMARK 3")
        lines.append("REMARK 3 REFINEMENT.")
        lines.append(
            f"REMARK 3 PROGRAM : {self.program} {self.program_version}".rstrip()
        )
        if self.refinement_method:
            lines.append(f"REMARK 3 METHOD : {self.refinement_method}")
        lines.append("REMARK 3")

        # Data used in refinement
        lines.append("REMARK 3 DATA USED IN REFINEMENT.")
        _remark3(lines, "RESOLUTION RANGE HIGH (ANGSTROMS)", self.resolution_high, ".2f")
        _remark3(lines, "RESOLUTION RANGE LOW (ANGSTROMS)", self.resolution_low, ".2f")
        _remark3(lines, "NUMBER OF REFLECTIONS", self.n_reflections_all, "d")
        lines.append("REMARK 3")

        # Fit to data
        lines.append("REMARK 3 FIT TO DATA USED IN REFINEMENT.")
        _remark3(lines, "R VALUE (WORKING SET)", self.r_work, ".4f")
        _remark3(lines, "FREE R VALUE", self.r_free, ".4f")
        _remark3(lines, "FREE R VALUE TEST SET SIZE (%)", self.percent_free, ".1f")
        _remark3(lines, "FREE R VALUE TEST SET COUNT", self.n_reflections_test, "d")
        lines.append("REMARK 3")

        # B-values (the Wilson B is not tracked, so it is always NULL)
        lines.append("REMARK 3 B VALUES.")
        _remark3(lines, "FROM WILSON PLOT (A**2)", None, ".2f")
        _remark3(lines, "MEAN B VALUE (OVERALL, A**2)", self.b_mean_overall, ".2f")
        _remark3(lines, "B MIN (A**2)", self.b_min, ".2f")
        _remark3(lines, "B MAX (A**2)", self.b_max, ".2f")
        lines.append("REMARK 3")

        # RMS deviations
        lines.append("REMARK 3 RMS DEVIATIONS FROM IDEAL VALUES.")
        _remark3(lines, "BOND LENGTHS (A)", self.rmsd_bond_lengths, ".3f")
        _remark3(lines, "BOND ANGLES (DEGREES)", self.rmsd_bond_angles, ".2f")
        lines.append("REMARK 3")

        # Model contents
        lines.append("REMARK 3 NUMBER OF NON-HYDROGEN ATOMS USED IN REFINEMENT.")
        _remark3(lines, "PROTEIN ATOMS", self.n_atoms_protein, "d")
        _remark3(lines, "SOLVENT ATOMS", self.n_atoms_solvent, "d")
        _remark3(lines, "TOTAL", self.n_atoms_total, "d")
        lines.append("REMARK 3")

        # Solvent model (only when at least one parameter is known)
        if self.solvent_model_ksol is not None or self.solvent_model_bsol is not None:
            lines.append("REMARK 3 BULK SOLVENT MODELLING.")
            _remark3(lines, "K_SOL", self.solvent_model_ksol, ".4f")
            _remark3(lines, "B_SOL", self.solvent_model_bsol, ".2f")
            lines.append("REMARK 3")

        # Custom remarks
        for remark in self.custom_remarks:
            lines.append(f"REMARK 3 {remark}")
        lines.append("REMARK 3")

        lines.extend(remarks_after)
        return "\n".join(lines) + "\n"

    # ------------------------------------------------------------------ #
    # mmCIF rendering
    # ------------------------------------------------------------------ #

    def render_cif_categories(self) -> Dict[str, Dict[str, str]]:
        """Render metadata as mmCIF category dictionaries.

        Uses PDBx/mmCIF item names. Scalar items map to strings. Loop
        categories (``_audit_author``, ``_refine_ls_restr``) map each item
        to a list of values. Pass-through categories are added only when no
        category of the same name was generated.

        Returns
        -------
        dict
            Nested dictionary ``{category: {item: value}}``.
        """
        cats: Dict[str, Dict[str, str]] = {}

        # _software
        sw = {"_software.name": self.program}
        if self.program_version:
            sw["_software.version"] = self.program_version
        sw["_software.classification"] = "refinement"
        if self.refinement_method:
            sw["_software.description"] = self.refinement_method
        sw["_software.pdbx_ordinal"] = "1"
        cats["_software"] = sw

        # _struct
        if self.title:
            cats["_struct"] = {"_struct.title": self.title}

        # _audit_author (loop)
        if self.authors:
            cats["_audit_author"] = {"_audit_author.name": self.authors}

        # _refine: (attribute, item name, format spec); "" == str()
        refine_items = (
            ("r_work", "_refine.ls_R_factor_R_work", ".4f"),
            ("r_free", "_refine.ls_R_factor_R_free", ".4f"),
            ("resolution_high", "_refine.ls_d_res_high", ".2f"),
            ("resolution_low", "_refine.ls_d_res_low", ".2f"),
            ("n_reflections_all", "_refine.ls_number_reflns_all", ""),
            ("n_reflections_work", "_refine.ls_number_reflns_R_work", ""),
            ("n_reflections_test", "_refine.ls_number_reflns_R_free", ""),
            ("percent_free", "_refine.ls_percent_reflns_R_free", ".1f"),
            ("b_mean_overall", "_refine.B_iso_mean", ".2f"),
            ("b_min", "_refine.B_iso_min", ".2f"),
            ("b_max", "_refine.B_iso_max", ".2f"),
            ("solvent_model_ksol", "_refine.solvent_model_param_ksol", ".4f"),
            ("solvent_model_bsol", "_refine.solvent_model_param_bsol", ".2f"),
        )
        ref = {}
        for attr, item_name, fmt in refine_items:
            value = getattr(self, attr)
            if value is not None:
                ref[item_name] = format(value, fmt)
        if ref:
            cats["_refine"] = ref

        # _refine_ls_restr (geometry deviations, as loop)
        restr_types = []
        restr_devs = []
        if self.rmsd_bond_lengths is not None:
            restr_types.append("f_bond_d")
            restr_devs.append(f"{self.rmsd_bond_lengths:.4f}")
        if self.rmsd_bond_angles is not None:
            restr_types.append("f_angle_d")
            restr_devs.append(f"{self.rmsd_bond_angles:.4f}")
        if restr_types:
            cats["_refine_ls_restr"] = {
                "_refine_ls_restr.type": restr_types,
                "_refine_ls_restr.dev_ideal": restr_devs,
            }

        # _refine_hist (atom counts). The PDBx item for protein atoms is
        # "pdbx_number_atoms_protein"; solvent/total have no pdbx_ prefix.
        if self.n_atoms_total is not None:
            hist = {"_refine_hist.number_atoms_total": str(self.n_atoms_total)}
            if self.n_atoms_protein is not None:
                hist["_refine_hist.pdbx_number_atoms_protein"] = str(self.n_atoms_protein)
            if self.n_atoms_solvent is not None:
                hist["_refine_hist.number_atoms_solvent"] = str(self.n_atoms_solvent)
            cats["_refine_hist"] = hist

        # Pass-through CIF categories never override generated ones.
        for cat_name, items in self.passthrough_cif_categories.items():
            cats.setdefault(cat_name, items)
        return cats
# ====================================================================== #
# Private helpers
# ====================================================================== #
def _remark3(
lines: List[str], label: str, value: Any, fmt: str = ""
) -> None:
"""Append a REMARK 3 line with label : value formatting."""
if value is not None:
formatted = f"{value:{fmt}}"
else:
formatted = "NULL"
line = f"REMARK 3 {label} : {formatted}"
lines.append(line)
def _wrap_pdb_record(lines: List[str], record: str, text: str) -> None:
"""Wrap long text into multiple PDB records (80-char lines).
Continuation lines use record + spaces + continuation number.
"""
max_text = 80 - 10 # 10 chars for record + spaces
words = text.split()
current = ""
continuation = 0
for word in words:
if current and len(current) + 1 + len(word) > max_text:
if continuation == 0:
lines.append(f"{record:<10}{current}")
else:
lines.append(f"{record:<8}{continuation + 1:>2}{current}")
current = word
continuation += 1
else:
current = current + " " + word if current else word
if current:
if continuation == 0:
lines.append(f"{record:<10}{current}")
else:
lines.append(f"{record:<8}{continuation + 1:>2}{current}")