Source code for torchref.io.ihm

"""
IHM mmCIF reader and writer for kinetic ensemble data.

Provides ``IHMReader`` and ``IHMWriter`` for reading and writing IHM
(Integrative/Hybrid Methods) mmCIF files that describe multi-state
kinetic ensembles.

Requires the optional ``python-ihm`` dependency::

    pip install torchref[ihm]

Detection of IHM files (via ``is_ihm_file``) uses only gemmi and does
not require ``python-ihm``.
"""

from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

import gemmi
import numpy as np
import pandas as pd

from torchref.io.ihm_mapping import IHMEnsembleMapping, IHMModelGroupInfo, IHMStateInfo

if TYPE_CHECKING:
    import torch

    from torchref.model.model_collection import ModelCollection
    from torchref.model.model_ft import ModelFT


def _check_ihm_available():
    """
    Verify that the optional ``python-ihm`` dependency can be imported.

    Raises
    ------
    ImportError
        If ``import ihm`` fails. The message carries install instructions
        and the original import failure is attached as ``__cause__`` so the
        underlying reason (missing package, broken install, ...) stays
        visible in the traceback.
    """
    try:
        import ihm  # noqa: F401
    except ImportError as err:
        raise ImportError(
            "python-ihm is required for IHM mmCIF support. "
            "Install with: pip install torchref[ihm]"
        ) from err


class IHMReader:
    """
    Read IHM mmCIF files into torchref ModelCollection + IHMEnsembleMapping.

    Uses ``python-ihm`` to parse IHM-specific categories
    (``_ihm_multi_state_modeling``, ``_ihm_model_group``, etc.) and
    gemmi / ``ModelCIFReader`` to parse standard ``_atom_site`` data.

    Parameters
    ----------
    filepath : str
        Path to IHM mmCIF file.
    verbose : int
        Verbosity level (0=silent, 1=info, 2=debug).

    Raises
    ------
    ImportError
        If ``python-ihm`` is not installed.
    FileNotFoundError
        If ``filepath`` does not exist.

    Examples
    --------
    ::

        reader = IHMReader("ensemble.cif")
        model_collection, mapping = reader(max_res=1.5)

        # Or step by step:
        mapping = reader.read_mapping()
        mapping.atom_data_per_state = reader.read_atom_data(mapping)
        model_collection = reader.build_model_collection(mapping)
    """

    def __init__(self, filepath: str, verbose: int = 0):
        _check_ihm_available()
        self.filepath = Path(filepath)
        self.verbose = verbose
        if not self.filepath.exists():
            raise FileNotFoundError(f"IHM file not found: {self.filepath}")

    # ------------------------------------------------------------------
    # Static detection (gemmi only, no python-ihm needed)
    # ------------------------------------------------------------------

    @staticmethod
    def is_ihm_file(filepath: str) -> bool:
        """
        Quick check whether a CIF file contains IHM categories.

        Uses gemmi to avoid requiring ``python-ihm`` for detection.
        Looks for ``_ihm_model_list`` or ``_ihm_multi_state_modeling`` loops.

        Parameters
        ----------
        filepath : str
            Path to a CIF/mmCIF file.

        Returns
        -------
        bool
            ``True`` if any block has one of the IHM columns; ``False``
            otherwise, including when the file cannot be parsed at all.
        """
        try:
            doc = gemmi.cif.read(str(filepath))
        except Exception:
            # Deliberate best-effort: an unreadable file is "not an IHM file".
            return False

        for block in doc:
            if block.find(["_ihm_model_list.model_id"]):
                return True
            if block.find(["_ihm_multi_state_modeling.state_id"]):
                return True
        return False

    # ------------------------------------------------------------------
    # Read IHM metadata -> IHMEnsembleMapping
    # ------------------------------------------------------------------

    def read_mapping(self) -> "IHMEnsembleMapping":
        """
        Parse IHM categories and build an ``IHMEnsembleMapping``.

        Reads the following IHM categories:

        - ``_ihm_multi_state_modeling`` -> states
        - ``_ihm_model_list`` -> model enumeration
        - ``_ihm_model_group`` + ``_ihm_model_group_link`` -> groups
        - ``_ihm_multi_state_model_group_link`` -> state-group fractions

        Also extracts cell and spacegroup from standard mmCIF categories.

        Returns
        -------
        IHMEnsembleMapping

        Raises
        ------
        ValueError
            If ``python-ihm`` finds no system in the file.
        """
        import ihm.reader

        with open(self.filepath) as f:
            systems = ihm.reader.read(f)
        if not systems:
            raise ValueError(f"No IHM system found in {self.filepath}")
        system = systems[0]

        state_groups = getattr(system, "state_groups", None) or []

        # --- Extract named states (skip unnamed container states) ---
        # States are numbered 1, 2, ... in encounter order; that counter is
        # used both as state_id and as the expected pdbx_PDB_model_num.
        states = []
        model_num = 1
        for sg in state_groups:
            for state in sg:
                name = getattr(state, "name", None)
                stype = getattr(state, "type", None)
                # Skip unnamed states that are just model containers
                if name is None and stype is None:
                    continue
                details = getattr(state, "details", "") or ""
                states.append(
                    IHMStateInfo(
                        state_id=model_num,
                        name=name or f"state_{model_num}",
                        details=details,
                        model_num=model_num,
                    )
                )
                model_num += 1

        # Fallback: if no named states, infer from models
        if not states:
            states = self._infer_states_from_models(system)

        # --- Extract model groups from all state groups ---
        model_groups = []
        seen_ids = set()  # deduplicate by Python object identity
        group_id = 1
        for sg in state_groups:
            for state in sg:
                for mg in state:
                    if id(mg) in seen_ids:
                        continue
                    seen_ids.add(id(mg))
                    mg_name = getattr(mg, "name", None)
                    # Skip unnamed placeholder groups with no models
                    if mg_name is None and len(mg) == 0:
                        continue
                    model_groups.append(
                        IHMModelGroupInfo(
                            group_id=group_id,
                            name=mg_name or f"group_{group_id}",
                            state_fractions={},
                        )
                    )
                    group_id += 1

        # Fallback: build groups from model groups directly
        if not model_groups:
            model_groups = self._infer_groups_from_system(system, states)

        # Parse the file with gemmi once and reuse it for both the fraction
        # table and the cell / symmetry lookup.
        doc = gemmi.cif.read(str(self.filepath))

        # --- Fill in fractions from _ihm_multi_state_model_group_link ---
        # Parse fractions directly from gemmi since python-ihm may not
        # fully expose them on the State objects
        self._fill_fractions_from_cif(states, model_groups, doc=doc)

        # --- Extract cell and spacegroup ---
        cell, spacegroup = self._extract_cell_and_spacegroup(doc)

        mapping = IHMEnsembleMapping(
            states=states,
            model_groups=model_groups,
            cell=cell,
            spacegroup=spacegroup,
        )

        if self.verbose > 0:
            print(f"IHM mapping: {len(states)} states, {len(model_groups)} groups")
            for s in states:
                print(f" State {s.state_id}: {s.name} (model_num={s.model_num})")
            for g in model_groups:
                print(f" Group {g.group_id}: {g.name} fractions={g.state_fractions}")

        return mapping

    def _infer_states_from_models(self, system) -> List["IHMStateInfo"]:
        """Infer states from ihm.System model groups when state_groups is empty."""
        states = []
        model_nums = set()
        for mg in getattr(system, "model_groups", []):
            for model in mg:
                num = getattr(model, "_id", len(states) + 1)
                if num not in model_nums:
                    model_nums.add(num)
                    name = getattr(model, "name", None) or f"state_{num}"
                    states.append(
                        IHMStateInfo(
                            state_id=num, name=name, details="", model_num=num
                        )
                    )
        if not states:
            # Last resort: use atom_site model numbers
            states = self._infer_states_from_atom_site()
        return states

    def _infer_states_from_atom_site(self) -> List["IHMStateInfo"]:
        """Infer states from pdbx_PDB_model_num in _atom_site."""
        from torchref.io.cif_readers import ModelCIFReader

        reader = ModelCIFReader(str(self.filepath), verbose=0)
        by_model = reader.get_atom_data_by_model()
        return [
            IHMStateInfo(
                state_id=num,
                name=f"state_{num}",
                details="",
                model_num=num,
            )
            for num in sorted(by_model.keys())
        ]

    def _infer_groups_from_system(
        self, system, states: List["IHMStateInfo"]
    ) -> List["IHMModelGroupInfo"]:
        """
        Build model groups with equal fractions when IHM categories are sparse.

        ``system`` is accepted for signature symmetry but is not inspected;
        the result is a single group named ``"ensemble"`` (or an empty list
        when there are no states).
        """
        n_states = len(states)
        if n_states == 0:
            return []
        equal_frac = 1.0 / n_states
        fracs = {s.state_id: equal_frac for s in states}
        return [
            IHMModelGroupInfo(
                group_id=1,
                name="ensemble",
                state_fractions=fracs,
            )
        ]

    def _fill_fractions_from_cif(
        self,
        states: List["IHMStateInfo"],
        model_groups: List["IHMModelGroupInfo"],
        doc=None,
    ) -> None:
        """
        Parse ``_ihm_multi_state_model_group_link`` via gemmi to fill
        population fractions in model groups.

        Only the first block that has the link table is used. If no block
        has it, every group gets equal fractions across all states.

        Parameters
        ----------
        states, model_groups : list
            Entries to update; ``state_fractions`` of each group is
            replaced in place.
        doc : gemmi.cif.Document, optional
            Already-parsed document. When omitted the file is read again.
        """
        if doc is None:
            doc = gemmi.cif.read(str(self.filepath))

        state_ids = [s.state_id for s in states]
        known_states = set(state_ids)
        # Build lookup: group_id -> IHMModelGroupInfo
        group_by_id = {g.group_id: g for g in model_groups}

        for block in doc:
            table = block.find(
                [
                    "_ihm_multi_state_model_group_link.state_id",
                    "_ihm_multi_state_model_group_link.group_id",
                    "_ihm_multi_state_model_group_link.population_fraction",
                ]
            )
            if not table:
                continue

            # Initialize all fractions to 0
            for g in model_groups:
                g.state_fractions = {sid: 0.0 for sid in state_ids}

            for row in table:
                sid = int(row[0])
                gid = int(row[1])
                frac_str = row[2]
                # CIF "." / "?" mean inapplicable / unknown -> treat as 0
                frac = float(frac_str) if frac_str not in (".", "?") else 0.0
                if gid in group_by_id and sid in known_states:
                    group_by_id[gid].state_fractions[sid] = frac
            return

        # No link table found — use equal fractions as fallback
        if states:
            equal = 1.0 / len(states)
            for g in model_groups:
                g.state_fractions = {s.state_id: equal for s in states}

    def _extract_cell_and_spacegroup(
        self, doc
    ) -> Tuple[Optional[List[float]], Optional[str]]:
        """
        Return the first unit cell and first space group found in ``doc``.

        The two values are tracked independently so that a block lacking
        symmetry information cannot overwrite a space group found in an
        earlier block. Scanning stops once both are known.
        """
        cell = None
        spacegroup = None
        for block in doc:
            if cell is None:
                cell = self._extract_cell_from_block(block)
            if spacegroup is None:
                spacegroup = self._extract_spacegroup_from_block(block)
            if cell is not None and spacegroup is not None:
                break
        return cell, spacegroup

    def _extract_cell_from_block(self, block) -> Optional[List[float]]:
        """
        Extract unit cell from a gemmi CIF block.

        Returns ``[a, b, c, alpha, beta, gamma]`` with missing angles
        defaulting to 90, or ``None`` if a length is missing or any present
        value cannot be converted to ``float``.
        """
        try:
            a = block.find_value("_cell.length_a")
            b = block.find_value("_cell.length_b")
            c = block.find_value("_cell.length_c")
            alpha = block.find_value("_cell.angle_alpha")
            beta = block.find_value("_cell.angle_beta")
            gamma = block.find_value("_cell.angle_gamma")
            if a and b and c:
                return [
                    float(a),
                    float(b),
                    float(c),
                    float(alpha) if alpha else 90.0,
                    float(beta) if beta else 90.0,
                    float(gamma) if gamma else 90.0,
                ]
        except (ValueError, TypeError):
            pass
        return None

    def _extract_spacegroup_from_block(self, block) -> Optional[str]:
        """
        Extract space group from a gemmi CIF block.

        Tries ``_symmetry.space_group_name_H-M`` then
        ``_space_group.name_H-M_alt``; surrounding quotes are stripped.
        """
        for tag in [
            "_symmetry.space_group_name_H-M",
            "_space_group.name_H-M_alt",
        ]:
            val = block.find_value(tag)
            if val and val not in ("?", "."):
                return val.strip("'\"")
        return None

    # ------------------------------------------------------------------
    # Read atom data split by state
    # ------------------------------------------------------------------

    def read_atom_data(self, mapping: "IHMEnsembleMapping") -> Dict[int, pd.DataFrame]:
        """
        Extract per-state atom DataFrames using ``pdbx_PDB_model_num``.

        Reuses ``ModelCIFReader`` for robust ``_atom_site`` parsing, then
        splits by model number and maps to state IDs.

        Parameters
        ----------
        mapping : IHMEnsembleMapping
            Mapping with state info (specifically ``model_num`` per state).

        Returns
        -------
        dict of int -> pandas.DataFrame
            Mapping of ``state_id`` -> atom DataFrame.

        Raises
        ------
        ValueError
            If a state's ``model_num`` is absent from the file, or the
            states do not all have the same number of atoms.
        """
        from torchref.io.cif_readers import ModelCIFReader

        reader = ModelCIFReader(str(self.filepath), verbose=0)
        by_model = reader.get_atom_data_by_model()

        result = {}
        for state in mapping.states:
            if state.model_num not in by_model:
                raise ValueError(
                    f"State '{state.name}' expects pdbx_PDB_model_num={state.model_num} "
                    f"but file only contains model numbers: {sorted(by_model.keys())}"
                )
            result[state.state_id] = by_model[state.model_num]

        # Validate atom consistency across states
        self._validate_atom_consistency(result, mapping)
        return result

    def _validate_atom_consistency(
        self, atom_data: Dict[int, pd.DataFrame], mapping: "IHMEnsembleMapping"
    ) -> None:
        """
        Check that all states have the same number of atoms.

        Only row counts are compared (against the lowest state id); atom
        identity and ordering are not inspected here.

        Raises
        ------
        ValueError
            On the first state whose atom count differs from the reference.
        """
        state_ids = sorted(atom_data.keys())
        if len(state_ids) < 2:
            return

        ref_id = state_ids[0]
        ref_count = len(atom_data[ref_id])

        for sid in state_ids[1:]:
            df = atom_data[sid]
            if len(df) != ref_count:
                ref_name = mapping.get_state_by_id(ref_id).name
                this_name = mapping.get_state_by_id(sid).name
                raise ValueError(
                    f"Atom count mismatch: state '{ref_name}' has {ref_count} atoms "
                    f"but state '{this_name}' has {len(df)} atoms. "
                    "All states must have identical atom lists for kinetic refinement."
                )

    # ------------------------------------------------------------------
    # Build ModelCollection from mapping + atom data
    # ------------------------------------------------------------------

    @staticmethod
    def _normalized_fractions(state_fractions, state_ids) -> List[float]:
        """
        Return fractions for ``state_ids`` (in that order) scaled to sum to 1.

        States missing from ``state_fractions`` count as 0. If the total is
        not positive, equal fractions are returned instead. An empty
        ``state_ids`` yields an empty list rather than dividing by zero.
        """
        fractions = [state_fractions.get(sid, 0.0) for sid in state_ids]
        if not fractions:
            return []
        total = sum(fractions)
        if total > 0:
            return [f / total for f in fractions]
        return [1.0 / len(fractions)] * len(fractions)

    def build_model_collection(
        self,
        mapping: "IHMEnsembleMapping",
        max_res: float = 1.5,
        radius_angstrom: float = 4.0,
        device: "Optional[torch.device]" = None,
    ) -> "ModelCollection":
        """
        Build a ``ModelCollection`` from parsed IHM data.

        For each state, creates a ``ModelFT`` and loads atom data. Then
        assembles a ``ModelCollection`` with one timepoint per model group.

        Parameters
        ----------
        mapping : IHMEnsembleMapping
            Must have ``atom_data_per_state`` populated.
        max_res : float
            Maximum resolution for FFT grid setup.
        radius_angstrom : float
            Radius for electron density calculation.
        device : torch.device, optional
            Device for model tensors. Defaults to
            ``torchref.config.get_default_device()``.

        Returns
        -------
        ModelCollection

        Raises
        ------
        ValueError
            If ``mapping.atom_data_per_state`` is ``None``.
        """
        from torchref.model.model_collection import ModelCollection
        from torchref.model.model_ft import ModelFT

        if mapping.atom_data_per_state is None:
            raise ValueError(
                "mapping.atom_data_per_state is None. "
                "Call read_atom_data() first."
            )

        if device is None:
            from torchref.config import get_default_device

            device = get_default_device()

        # One ordering is used for both the base models and the per-group
        # fraction vectors, so fractions[i] always refers to base_models[i].
        sorted_states = sorted(mapping.states, key=lambda s: s.state_id)
        state_ids = [s.state_id for s in sorted_states]

        # Cell and space group are the same for every state.
        cell = mapping.cell
        sg = gemmi.SpaceGroup(mapping.spacegroup or "P 1")

        # Build one ModelFT per state
        base_models = []
        for state in sorted_states:
            df = mapping.atom_data_per_state[state.state_id]
            model = ModelFT(
                max_res=max_res,
                radius_angstrom=radius_angstrom,
                device=device,
            )
            # Build a lightweight reader-like callable for Model.load()
            model.load(_DataFrameReader(df, cell, sg))

            if self.verbose > 0:
                print(
                    f" Loaded state '{state.name}': "
                    f"{len(df)} atoms"
                )
            base_models.append(model)

        # Create ModelCollection
        collection = ModelCollection(
            base_models=base_models,
            verbose=self.verbose,
        )

        # Add timepoints from model groups
        dark_name = mapping.identify_dark_group()
        for group in sorted(mapping.model_groups, key=lambda g: g.group_id):
            fractions = self._normalized_fractions(group.state_fractions, state_ids)
            if group.name == dark_name:
                collection.add_dark(fractions=fractions)
            else:
                collection.add_timepoint(group.name, fractions=fractions)

        return collection

    # ------------------------------------------------------------------
    # Convenience: read everything in one call
    # ------------------------------------------------------------------

    def __call__(
        self,
        max_res: float = 1.5,
        radius_angstrom: float = 4.0,
        device: "Optional[torch.device]" = None,
    ) -> Tuple["ModelCollection", "IHMEnsembleMapping"]:
        """
        Read IHM file and return ``(ModelCollection, IHMEnsembleMapping)``.

        Runs ``read_mapping``, ``read_atom_data`` and
        ``build_model_collection`` in sequence.

        Parameters
        ----------
        max_res : float
            Maximum resolution for FFT grid.
        radius_angstrom : float
            Radius for electron density calculation.
        device : torch.device, optional
            Device for tensors.

        Returns
        -------
        tuple of (ModelCollection, IHMEnsembleMapping)
        """
        mapping = self.read_mapping()
        mapping.atom_data_per_state = self.read_atom_data(mapping)
        model_collection = self.build_model_collection(
            mapping,
            max_res=max_res,
            radius_angstrom=radius_angstrom,
            device=device,
        )
        return model_collection, mapping
class _DataFrameReader:
    """
    In-memory stand-in for a file reader, for use with ``Model.load()``.

    ``Model.load(reader)`` is expected to call ``reader()`` and receive
    ``(dataframe, cell, spacegroup)``; this class simply hands back the
    three objects it was constructed with.
    """

    def __init__(self, df: pd.DataFrame, cell: Optional[List[float]], spacegroup):
        # Stored as plain attributes; nothing is copied or validated.
        self.df, self.cell, self.spacegroup = df, cell, spacegroup

    def __call__(self):
        """Return the stored ``(df, cell, spacegroup)`` triple."""
        return (self.df, self.cell, self.spacegroup)
class IHMWriter:
    """
    Write a torchref ``ModelCollection`` to IHM mmCIF format.

    Uses ``python-ihm`` to build an IHM System object and write it as
    mmCIF, then uses gemmi to add cell/symmetry, the ``_atom_site`` loop
    and (optionally) per-dataset ``_refln`` blocks.

    Parameters
    ----------
    model_collection : ModelCollection
        The collection to write.
    mapping : IHMEnsembleMapping, optional
        Original mapping for round-tripping metadata. If ``None``,
        creates a minimal mapping from the collection structure.
    datasets : dict of str -> ReflectionData, optional
        Reflection datasets keyed by name. A mapping with ``items()`` or
        any iterable of ``(name, dataset)`` pairs is accepted. Each one is
        written to its own data block.
    verbose : int
        Verbosity level.

    Examples
    --------
    ::

        writer = IHMWriter(model_collection, mapping=mapping)
        writer.write("refined_ensemble.cif")

        # Or without a pre-existing mapping:
        writer = IHMWriter(model_collection)
        writer.write("refined_ensemble.cif")
    """

    # Column names of the _atom_site loop, in output order. Includes the
    # auth_* columns required by Coot/viewers.
    _ATOM_SITE_TAGS = (
        "group_PDB",
        "id",
        "type_symbol",
        "label_atom_id",
        "label_alt_id",
        "label_comp_id",
        "label_asym_id",
        "label_seq_id",
        "pdbx_PDB_ins_code",
        "Cartn_x",
        "Cartn_y",
        "Cartn_z",
        "occupancy",
        "B_iso_or_equiv",
        "auth_seq_id",
        "auth_comp_id",
        "auth_asym_id",
        "auth_atom_id",
        "pdbx_PDB_model_num",
    )

    def __init__(
        self,
        model_collection: "ModelCollection",
        mapping: Optional["IHMEnsembleMapping"] = None,
        datasets: Optional[Dict[str, "ReflectionData"]] = None,
        verbose: int = 0,
    ):
        _check_ihm_available()
        self.model_collection = model_collection
        self.mapping = mapping or self._create_default_mapping()
        self.datasets = datasets
        self.verbose = verbose

    def _create_default_mapping(self) -> "IHMEnsembleMapping":
        """
        Create a minimal ``IHMEnsembleMapping`` from ``ModelCollection`` structure.

        Each base model becomes a state; each timepoint becomes a model group
        with current fractions. Cell and space group are taken from the
        first base model when it exposes them.
        """
        mc = self.model_collection

        # States from base models
        states = [
            IHMStateInfo(
                state_id=i + 1,
                name=f"state_{i + 1}",
                details="",
                model_num=i + 1,
            )
            for i in range(mc.n_base_models)
        ]

        # Model groups from timepoints
        state_ids = [s.state_id for s in states]
        model_groups = []
        for group_id, (name, mixed) in enumerate(mc, start=1):
            fracs = mixed.fractions.detach().cpu().tolist()
            model_groups.append(
                IHMModelGroupInfo(
                    group_id=group_id,
                    name=name,
                    state_fractions=dict(zip(state_ids, fracs)),
                )
            )

        # Cell and spacegroup from first base model
        cell = None
        spacegroup = None
        if mc.n_base_models > 0:
            model0 = mc.base_models[0]
            if hasattr(model0, "cell") and model0.cell is not None:
                cell_obj = model0.cell
                if hasattr(cell_obj, "parameters"):
                    cell = cell_obj.parameters.tolist()
                elif hasattr(cell_obj, "tolist"):
                    cell = cell_obj.tolist()
            if hasattr(model0, "spacegroup") and model0.spacegroup is not None:
                sg = model0.spacegroup
                if hasattr(sg, "hm"):
                    spacegroup = sg.hm
                elif hasattr(sg, "xhm"):
                    spacegroup = sg.xhm()
                else:
                    spacegroup = str(sg)

        return IHMEnsembleMapping(
            states=states,
            model_groups=model_groups,
            cell=cell,
            spacegroup=spacegroup,
        )

    def write(self, filepath: str) -> None:
        """
        Write IHM mmCIF file.

        Builds a ``python-ihm`` System object with:

        - Entity and assembly from base model atom data
        - Multi-state definitions from mapping states
        - Model groups from mapping groups with population fractions
        - Atom coordinates per state via ``pdbx_PDB_model_num``

        The file is first written by ``python-ihm`` and then re-read and
        rewritten by gemmi to add atom and reflection data.

        Parameters
        ----------
        filepath : str
            Output file path.
        """
        import ihm
        import ihm.dumper
        import ihm.model
        import ihm.representation

        system = ihm.System(title="TorchRef kinetic ensemble")
        mc = self.model_collection
        mapping = self.mapping

        # --- Build entities and asym units from first base model ---
        lpep = ihm.LPeptideAlphabet()
        asym_units = []
        if mc.n_base_models > 0:
            model0 = mc.base_models[0]
            for chain_id, seq_str in model0.chain_sequences:
                seq = []
                for char in seq_str:
                    if char == "?":
                        continue  # skip gaps
                    try:
                        seq.append(lpep[char])
                    except KeyError:
                        # Residue code not in the L-peptide alphabet
                        seq.append(lpep["UNK"])
                if seq:
                    entity = ihm.Entity(seq, description=f"Chain {chain_id}")
                    system.entities.append(entity)
                    asym = ihm.AsymUnit(entity, details=f"Chain {chain_id}")
                    system.asym_units.append(asym)
                    asym_units.append(asym)

        if not asym_units:
            # No usable sequence: fall back to a single-residue placeholder
            entity = ihm.Entity(
                [lpep["UNK"]],
                description="Crystallographic model",
            )
            system.entities.append(entity)
            asym = ihm.AsymUnit(entity)
            system.asym_units.append(asym)
            asym_units.append(asym)

        # --- Assembly and representation ---
        assembly = ihm.Assembly(asym_units, name="Complete assembly")
        system.orphan_assemblies.append(assembly)

        rep = ihm.representation.Representation(
            [ihm.representation.AtomicSegment(a, rigid=False) for a in asym_units]
        )
        system.orphan_representations.append(rep)

        # --- Build states ---
        state_group = ihm.model.StateGroup()
        ihm_states = {}
        for state_info in sorted(mapping.states, key=lambda s: s.state_id):
            ihm_state = ihm.model.State(name=state_info.name)
            ihm_states[state_info.state_id] = ihm_state
            state_group.append(ihm_state)
        system.state_groups.append(state_group)

        # --- Build model groups and link to states ---
        # NOTE(review): population_fraction is a single attribute per state
        # here, so each later group with a positive fraction overwrites the
        # value set by an earlier group. Confirm whether per-group fractions
        # are meant to survive a write/read round trip.
        for group_info in sorted(mapping.model_groups, key=lambda g: g.group_id):
            mg = ihm.model.ModelGroup(name=group_info.name)
            for state_id, fraction in group_info.state_fractions.items():
                if state_id in ihm_states:
                    ihm_state = ihm_states[state_id]
                    ihm_state.append(mg)
                    if fraction > 0:
                        ihm_state.population_fraction = fraction

        # --- Write atom coordinates as a separate _atom_site block ---
        # python-ihm writes IHM categories; we append atom_site via gemmi
        filepath = Path(filepath)

        # First write the IHM system
        with open(filepath, "w") as f:
            ihm.dumper.write(f, [system])

        # Then append multi-model atom_site data via gemmi
        self._append_atom_site(filepath, mapping)

        # Append per-timepoint reflection data blocks if provided
        if self.datasets:
            self._append_refln_blocks(filepath, mapping)

        if self.verbose > 0:
            print(f"Wrote IHM mmCIF: {filepath}")
            print(f" {len(mapping.states)} states, {len(mapping.model_groups)} groups")
            if self.datasets:
                print(f" {len(self.datasets)} reflection dataset(s)")

    @classmethod
    def _from_mixed_model(
        cls,
        mixed_model,
        mapping: "IHMEnsembleMapping",
        verbose: int = 0,
        datasets: Optional[Dict[str, "ReflectionData"]] = None,
    ) -> "IHMWriter":
        """
        Create an IHMWriter from a MixedModel (instead of ModelCollection).

        Wraps the MixedModel's constituent models in a lightweight adapter
        that provides the interface expected by the writer.

        Parameters
        ----------
        mixed_model : MixedModel
            The mixed model to write.
        mapping : IHMEnsembleMapping
            Mapping with state/group metadata.
        verbose : int
            Verbosity level.
        datasets : dict, optional
            Reflection datasets to append, as for the constructor.

        Returns
        -------
        IHMWriter
        """
        _check_ihm_available()
        # __init__ is bypassed (it would build a default mapping), so every
        # attribute that write() reads must be assigned here -- including
        # ``datasets``, which write() checks unconditionally.
        writer = cls.__new__(cls)
        writer.model_collection = _MixedModelAdapter(mixed_model)
        writer.mapping = mapping
        writer.datasets = datasets
        writer.verbose = verbose
        return writer

    @staticmethod
    def _cif_text(value, default: str = ".") -> str:
        """
        Convert an optional field to a CIF token.

        ``None``, NaN/NA and empty or whitespace-only text all become
        ``default`` (CIF "inapplicable"); anything else is ``str(value)``
        with surrounding whitespace removed.
        """
        if value is None or pd.isna(value):
            return default
        return str(value).strip() or default

    @staticmethod
    def _write_cell_and_spacegroup(block, mapping: "IHMEnsembleMapping") -> None:
        """
        Set ``_cell.*`` and ``_symmetry.space_group_name_H-M`` on ``block``.

        Nothing is written for whichever of ``mapping.cell`` /
        ``mapping.spacegroup`` is empty. Space-group names containing a
        space are wrapped in single quotes unless already quoted.
        """
        if mapping.cell:
            a, b, c, alpha, beta, gamma = mapping.cell
            for tag, value in (
                ("_cell.length_a", a),
                ("_cell.length_b", b),
                ("_cell.length_c", c),
                ("_cell.angle_alpha", alpha),
                ("_cell.angle_beta", beta),
                ("_cell.angle_gamma", gamma),
            ):
                block.set_pair(tag, f"{value:.4f}")
        if mapping.spacegroup:
            sg_val = mapping.spacegroup
            if " " in sg_val and not sg_val.startswith("'"):
                sg_val = f"'{sg_val}'"
            block.set_pair("_symmetry.space_group_name_H-M", sg_val)

    @classmethod
    def _atom_site_values(cls, row, model_num) -> List[str]:
        """
        Build one ``_atom_site`` row, ordered as ``_ATOM_SITE_TAGS``.

        ``row`` only needs a ``get(key, default)`` method. The same values
        are written to both the ``label_*`` and ``auth_*`` columns.
        """
        atom_name = str(row.get("name", "CA"))
        altloc = cls._cif_text(row.get("altloc", "."))
        resname = str(row.get("resname", "UNK"))
        chainid = str(row.get("chainid", "A"))
        resseq = str(row.get("resseq", 1))
        icode = cls._cif_text(row.get("icode", "."))
        return [
            str(row.get("ATOM", "ATOM")),
            str(row.get("serial", 0)),
            str(row.get("element", "C")),
            atom_name,
            altloc,
            resname,
            chainid,
            resseq,
            icode,
            f"{row.get('x', 0.0):.3f}",
            f"{row.get('y', 0.0):.3f}",
            f"{row.get('z', 0.0):.3f}",
            f"{row.get('occupancy', 1.0):.2f}",
            f"{row.get('tempfactor', 20.0):.2f}",
            resseq,
            resname,
            chainid,
            atom_name,
            str(model_num),
        ]

    def _append_atom_site(
        self, filepath: Path, mapping: "IHMEnsembleMapping"
    ) -> None:
        """
        Append ``_atom_site`` loop with ``pdbx_PDB_model_num`` to the CIF file.

        Iterates over base models, extracts current coordinates, and writes
        combined atom data with model numbers distinguishing states. The
        file is rewritten even when there are no atoms, so that the cell and
        symmetry items are not lost.
        """
        mc = self.model_collection

        # Read existing CIF
        doc = gemmi.cif.read(str(filepath))
        block = doc[0] if len(doc) > 0 else doc.add_new_block("torchref")

        # Add cell and spacegroup before atom_site so readers find them early
        self._write_cell_and_spacegroup(block, mapping)

        # Collect all atom data with model numbers
        all_rows = []
        states_sorted = sorted(mapping.states, key=lambda s: s.state_id)
        for i, state in enumerate(states_sorted):
            # States beyond the available base models have no coordinates
            if i >= mc.n_base_models:
                break
            model = mc.base_models[i]

            # Update PDB DataFrame with current refined coordinates
            if hasattr(model, "update_pdb"):
                model.update_pdb()

            for _, row in model.pdb.iterrows():
                all_rows.append(self._atom_site_values(row, state.model_num))

        if all_rows:
            loop = block.init_loop("_atom_site.", list(self._ATOM_SITE_TAGS))
            for values in all_rows:
                loop.add_row(values)

        doc.write_file(str(filepath))

    # ------------------------------------------------------------------
    # Append per-timepoint reflection data
    # ------------------------------------------------------------------

    def _append_refln_blocks(
        self, filepath: Path, mapping: "IHMEnsembleMapping"
    ) -> None:
        """
        Append per-timepoint ``_refln`` data blocks to the CIF file.

        Each dataset gets its own CIF data block with cell, spacegroup,
        and a ``_refln`` loop containing HKL plus whichever of F, σF, I, σI
        and R-free status the dataset provides. A dataset without ``hkl``
        still gets a block, but no ``_refln`` loop.
        """
        doc = gemmi.cif.read(str(filepath))

        datasets = self.datasets
        # Support both dict and DatasetCollection (which is iterable)
        if hasattr(datasets, "items"):
            items = datasets.items()
        else:
            items = iter(datasets)

        for name, dataset in items:
            block = doc.add_new_block(name)

            # Cell and spacegroup per block
            self._write_cell_and_spacegroup(block, mapping)

            # Build _refln loop
            if dataset.hkl is None:
                continue

            hkl = dataset.hkl.detach().cpu().numpy()
            n_refln = len(hkl)

            # Optional float columns, in output order; only present ones
            # are converted to numpy and written.
            float_columns = []
            for tag, tensor in (
                ("F_meas_au", dataset.F),
                ("F_meas_sigma_au", dataset.F_sigma),
                ("intensity_meas", dataset.I),
                ("intensity_sigma", dataset.I_sigma),
            ):
                if tensor is not None:
                    float_columns.append((tag, tensor.detach().cpu().numpy()))

            has_rfree = dataset.rfree_flags is not None
            rfree_np = dataset.rfree_flags.detach().cpu().numpy() if has_rfree else None

            # Build tag list
            tags = ["index_h", "index_k", "index_l"]
            tags.extend(tag for tag, _ in float_columns)
            if has_rfree:
                tags.append("status")

            loop = block.init_loop("_refln.", tags)

            for i in range(n_refln):
                row = [str(int(hkl[i, 0])), str(int(hkl[i, 1])), str(int(hkl[i, 2]))]
                for _, values in float_columns:
                    row.append(f"{values[i]:.4f}")
                if has_rfree:
                    # CIF convention: 'f'=free, 'o'=working
                    row.append("f" if int(rfree_np[i]) == 0 else "o")
                loop.add_row(row)

            if self.verbose > 0:
                print(f" Dataset '{name}': {n_refln} reflections")

        doc.write_file(str(filepath))

    # ------------------------------------------------------------------
    # Convenience for building mapping from ModelCollection + metadata
    # ------------------------------------------------------------------

    @staticmethod
    def mapping_from_kinetic_refinement(
        model_collection: "ModelCollection",
        state_names: Optional[List[str]] = None,
        time_delays: Optional[Dict[str, float]] = None,
    ) -> "IHMEnsembleMapping":
        """
        Build an ``IHMEnsembleMapping`` from a ``ModelCollection`` with
        optional kinetic metadata.

        The returned mapping carries states and model groups only; cell
        and space group are not filled in here.

        Parameters
        ----------
        model_collection : ModelCollection
            The refined model collection.
        state_names : list of str, optional
            Names for each base model / state. Default: state_1, state_2, ...
            A list shorter than the number of base models is padded with
            the default names.
        time_delays : dict, optional
            Mapping of timepoint name -> time delay in seconds.

        Returns
        -------
        IHMEnsembleMapping
        """
        mc = model_collection

        # States
        states = []
        for i in range(mc.n_base_models):
            name = state_names[i] if state_names and i < len(state_names) else f"state_{i + 1}"
            states.append(
                IHMStateInfo(
                    state_id=i + 1,
                    name=name,
                    model_num=i + 1,
                )
            )

        # Model groups from timepoints
        state_ids = [s.state_id for s in states]
        model_groups = []
        for group_id, (name, mixed) in enumerate(mc, start=1):
            fracs = mixed.fractions.detach().cpu().tolist()
            state_fractions = dict(zip(state_ids, fracs))
            delay = time_delays.get(name) if time_delays else None
            model_groups.append(
                IHMModelGroupInfo(
                    group_id=group_id,
                    name=name,
                    state_fractions=state_fractions,
                    time_delay=delay,
                )
            )

        return IHMEnsembleMapping(
            states=states,
            model_groups=model_groups,
        )
class _MixedModelAdapter:
    """
    Present a MixedModel through the subset of the ModelCollection API
    that IHMWriter uses.

    The writer reads ``base_models`` and ``n_base_models`` and iterates
    the collection as ``(name, mixed_model)`` pairs. Here those are served
    from the wrapped MixedModel's ``models`` attribute, and iteration
    produces exactly one pair named ``"ensemble"``.
    """

    def __init__(self, mixed_model):
        self._mixed_model = mixed_model

    @property
    def base_models(self):
        """The wrapped MixedModel's ``models``."""
        wrapped = self._mixed_model
        return wrapped.models

    @property
    def n_base_models(self):
        """Number of entries in the wrapped MixedModel's ``models``."""
        wrapped = self._mixed_model
        return len(wrapped.models)

    def __iter__(self):
        """Yield (name, mixed_model) — single group."""
        single_group = ("ensemble", self._mixed_model)
        yield single_group

    def __len__(self):
        # Always one group: the wrapped MixedModel itself.
        return 1