# Source code for torchref.io.ihm_mapping

"""
Intermediate representation for IHM mmCIF ensemble data.

Provides dataclasses that map between IHM categories and torchref's
ModelCollection / DatasetCollection structures. This mapping can be:

1. Populated by IHMReader from an IHM mmCIF file
2. Constructed programmatically (e.g., from a kinetic model)
3. Passed to IHMWriter for round-trip output

Concept Mapping
---------------
IHM state  ->  base model (ModelFT) in ModelCollection
IHM model group  ->  timepoint entry (_SharedMixedModel) in ModelCollection
IHM population fraction  ->  fraction_params in _SharedMixedModel
"""

from dataclasses import dataclass, field
from typing import Dict, List, Optional

import pandas as pd


@dataclass
class IHMStateInfo:
    """
    Descriptor of one structural state (ground state, intermediate, ...).

    Attributes
    ----------
    state_id : int
        Identifier of the state; matches
        ``_ihm_multi_state_modeling.state_id``.
    name : str
        Readable label such as ``"ground_state"`` or ``"intermediate_1"``.
    details : str
        Optional free-text description. Empty by default.
    model_num : int
        The ``pdbx_PDB_model_num`` value in the ``_atom_site`` loop whose
        coordinates belong to this state. Defaults to ``1``.
    """

    # Required fields.
    state_id: int
    name: str

    # Optional fields.
    details: str = ""
    model_num: int = 1
@dataclass
class IHMModelGroupInfo:
    """
    Descriptor of one model group (an experimental condition / timepoint).

    Attributes
    ----------
    group_id : int
        Identifier of the group; matches ``_ihm_model_group.id``.
    name : str
        Readable label such as ``"dark"``, ``"1ps"`` or ``"5ps"``.
    state_fractions : Dict[int, float]
        Population fraction of each state in this group, keyed by
        ``state_id``. The fractions are expected to add up to 1.0.
    time_delay : float, optional
        Delay expressed in ``time_delay_units`` (time-resolved experiments).
        ``None`` when not set.
    time_delay_units : str
        Unit of ``time_delay``; ``"s"`` (seconds) unless overridden.
    """

    # Required fields.
    group_id: int
    name: str

    # A factory is used so every instance owns its own dictionary.
    state_fractions: Dict[int, float] = field(default_factory=dict)

    # Timing information.
    time_delay: Optional[float] = None
    time_delay_units: str = "s"
@dataclass
class IHMEnsembleMapping:
    """
    Complete mapping between IHM mmCIF categories and torchref structures.

    This is the central interchange object: both ``IHMReader`` and
    ``IHMWriter`` operate through it. It can also be constructed manually for
    programmatic workflows (e.g., building an IHM file from a
    ``KineticRefinement`` result without reading one first).

    Parameters
    ----------
    states : List[IHMStateInfo]
        Structural states (one per base model in ModelCollection).
    model_groups : List[IHMModelGroupInfo]
        Model groups / timepoints (one per timepoint in ModelCollection).
    cell : list of float, optional
        Unit cell parameters ``[a, b, c, alpha, beta, gamma]``.
    spacegroup : str, optional
        Space group name (Hermann-Mauguin notation).
    atom_data_per_state : dict, optional
        Mapping of ``state_id`` -> pandas DataFrame with atom data.
        Populated by ``IHMReader.read_atom_data()``.
    """

    states: List["IHMStateInfo"] = field(default_factory=list)
    model_groups: List["IHMModelGroupInfo"] = field(default_factory=list)
    cell: Optional[List[float]] = None
    spacegroup: Optional[str] = None
    atom_data_per_state: Optional[Dict[int, pd.DataFrame]] = None

    # ------------------------------------------------------------------
    # Query helpers
    # ------------------------------------------------------------------

    def get_state_ids(self) -> List[int]:
        """Return state IDs ordered by ``state_id``."""
        return sorted(s.state_id for s in self.states)

    def get_timepoint_names(self) -> List[str]:
        """Return model group names ordered by ``group_id``."""
        ordered = sorted(self.model_groups, key=lambda g: g.group_id)
        return [g.name for g in ordered]

    def get_fractions_for_group(self, group_name: str) -> List[float]:
        """
        Return population fractions for a model group, ordered by state_id.

        States that the group does not list are reported with fraction 0.0.

        Parameters
        ----------
        group_name : str
            Name of the model group.

        Returns
        -------
        list of float
            Fractions ordered by ascending ``state_id``.

        Raises
        ------
        KeyError
            If no group with the given name exists.
        """
        # Delegate the lookup so the "not found" behaviour stays in one place.
        group = self.get_group_by_name(group_name)
        return [group.state_fractions.get(sid, 0.0) for sid in self.get_state_ids()]

    def identify_dark_group(self) -> Optional[str]:
        """
        Heuristic: identify the reference / dark group.

        Returns the name of the first model group (in ascending ``group_id``
        order) where a single state has population fraction >= 0.95, or
        ``None`` if no such group exists. Groups without any fractions are
        skipped.
        """
        for group in sorted(self.model_groups, key=lambda g: g.group_id):
            # default=0.0 makes a group with no fractions fail the threshold.
            if max(group.state_fractions.values(), default=0.0) >= 0.95:
                return group.name
        return None

    def get_state_by_id(self, state_id: int) -> "IHMStateInfo":
        """
        Look up a state by its ID.

        Parameters
        ----------
        state_id : int
            Identifier of the state to find.

        Returns
        -------
        IHMStateInfo
            The first state whose ``state_id`` matches.

        Raises
        ------
        KeyError
            If no state with the given ID exists.
        """
        for state in self.states:
            if state.state_id == state_id:
                return state
        raise KeyError(f"No state with id {state_id}")

    def get_group_by_name(self, name: str) -> "IHMModelGroupInfo":
        """
        Look up a model group by name.

        Parameters
        ----------
        name : str
            Name of the model group to find.

        Returns
        -------
        IHMModelGroupInfo
            The first group whose ``name`` matches.

        Raises
        ------
        KeyError
            If no group with the given name exists.
        """
        for group in self.model_groups:
            if group.name == name:
                return group
        raise KeyError(f"No model group named '{name}'")

    # ------------------------------------------------------------------
    # Validation
    # ------------------------------------------------------------------

    def validate(self) -> None:
        """
        Check internal consistency.

        Raises
        ------
        ValueError
            If there are no states or no model groups, if a group's fractions
            reference a state ID that is not in ``states``, or if a group's
            fractions do not sum to ~1.0 (tolerance 0.05). A NaN sum is
            rejected as well.
        """
        if not self.states:
            raise ValueError("IHMEnsembleMapping has no states defined.")
        if not self.model_groups:
            raise ValueError("IHMEnsembleMapping has no model groups defined.")

        state_ids = {s.state_id for s in self.states}

        for group in self.model_groups:
            # Check that all referenced state IDs exist
            for sid in group.state_fractions:
                if sid not in state_ids:
                    raise ValueError(
                        f"Model group '{group.name}' references state_id={sid} "
                        f"which is not in the states list."
                    )

            # Check fractions sum to ~1.0. The test is written as a negated
            # "<=" because every comparison with NaN is False: a plain
            # "> 0.05" check would let a NaN total through unnoticed.
            total = sum(group.state_fractions.values())
            if not abs(total - 1.0) <= 0.05:
                raise ValueError(
                    f"Model group '{group.name}' fractions sum to {total:.4f}, "
                    f"expected ~1.0."
                )

    # ------------------------------------------------------------------
    # Repr
    # ------------------------------------------------------------------

    def __repr__(self) -> str:
        """Return a compact summary: state names, group names, space group."""
        state_names = [s.name for s in self.states]
        group_names = [g.name for g in self.model_groups]
        # !r quotes real strings but leaves None bare, so an unset space
        # group is not confused with the literal string 'None'.
        return (
            f"IHMEnsembleMapping("
            f"states={state_names}, "
            f"groups={group_names}, "
            f"spacegroup={self.spacegroup!r})"
        )