Source code for torchref.refinement.targets.amber_target

"""
AMBER14/GAFF2 Force Field as a Differentiable Restraint.

Uses OpenMM to evaluate the AMBER14 energy for current model coordinates.
Analytical forces from OpenMM are bridged into PyTorch autograd via a
custom Function, making the energy fully differentiable w.r.t. xyz.

Non-standard residues (HETATM not in AMBER14_STANDARD) are parameterised
automatically via antechamber/GAFF2.  Results are cached under
``PATH_TORCHREF_DATA / "amber_cache" / {resname}/``.

Intended workflow::

    # Canonical one-liner — strips altlocs, adds H, then build target:
    mh = (Model(verbose=0, strip_H=True)
          .load_pdb('structure.pdb')
          .strip_altlocs()
          .generate_hydrogens())
    target = AmberTarget(model=mh)                               # protein-only
    target = AmberTarget(model=mh, residue_charges={'LIG': -1})  # with ligand

    loss = target()          # kJ/mol per atom
    loss.backward()
    # xyz gradient is now populated with AMBER forces

Performance note
----------------
OpenMM's ``Modeller.addHydrogens()`` is 4–5× faster when H atoms are already
present in the model (it refines positions rather than building from scratch).
For pure-protein structures: init ~3 s with H, ~11 s from heavy atoms only.
Gradient and energy are identical either way (H are stripped from the atom map;
``n_model_atoms`` changes only the energy normalisation).

Design notes
------------
- gemmi + pdbfixer are used on the standard-residues path only, to restore
  missing terminal / side-chain heavy atoms.  H atoms are handled by
  OpenMM's Modeller (standard-residues path) or tleap (GAFF2 path).
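- Models that still contain alternate conformations (any non-blank altloc)
  are rejected at construction with a ValueError; call ``strip_altlocs()``
  first.  The PDB filters used to build the OpenMM system additionally keep
  only altloc == '' or 'A'.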
- OXT and H atoms are excluded from the PDB written to tleap; tleap
  re-adds them via its C-terminal and H-addition templates.
- H positions in the OpenMM context are set once at construction and are
  NOT updated during forward() — a good approximation for small refinement
  steps (< 0.1 Å heavy-atom displacement).
- model_to_omm maps model-atom index → OpenMM atom index for HEAVY atoms
  only.  Model H atoms receive -1 and are skipped in forward().
"""

from __future__ import annotations

import hashlib
import json
import os
import shutil
import subprocess
import tempfile
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

import numpy as np
import torch

from torchref import PATH_TORCHREF_DATA
from torchref.utils.stats import (
    VERBOSITY_DEBUG,
    VERBOSITY_DETAILED,
    VERBOSITY_STANDARD,
    StatEntry,
    stat,
)

from .base import ModelTarget

if TYPE_CHECKING:
    from torchref.model.model import Model


# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Per-process memo of resolved AmberTools executables (name → absolute path).
# Filled lazily on first GAFF2 use so that importing torchref never fails
# even when ambertools is absent.
_AMBERTOOLS_BINARIES: Dict[str, Optional[str]] = {}


def _find_ambertools_binary(name: str) -> str:
    """Resolve the path of an AmberTools executable.

    Lookup order: the module-level memo, then ``PATH`` (``shutil.which``),
    then ``<home>/bin/<name>`` for each of ``$AMBERHOME`` and
    ``$AMBERTOOLS_HOME``.  A successful lookup is memoised; failures are not.

    Parameters
    ----------
    name : str
        Executable name, e.g. ``"antechamber"``.

    Returns
    -------
    str
        Path to the executable.

    Raises
    ------
    FileNotFoundError
        If the executable cannot be located; the message carries install
        instructions.
    """
    remembered = _AMBERTOOLS_BINARIES.get(name)
    if remembered is not None:
        return remembered

    located = shutil.which(name)
    if not located:
        # Fall back to the conventional AmberTools install roots.
        for env_name in ("AMBERHOME", "AMBERTOOLS_HOME"):
            root = os.environ.get(env_name)
            if not root:
                continue
            guess = os.path.join(root, "bin", name)
            if os.path.isfile(guess) and os.access(guess, os.X_OK):
                located = guess
                break

    if located:
        _AMBERTOOLS_BINARIES[name] = located
        return located

    raise FileNotFoundError(
        f"'{name}' not found on PATH or in $AMBERHOME.\n"
        f"Install AmberTools:  conda install -c conda-forge ambertools\n"
        f"Or provide pre-computed mol2/frcmod files via gaff2_files=."
    )

#: Residue names covered by AMBER14 force field — antechamber not needed.
AMBER14_STANDARD: frozenset = frozenset(
    # Protein residues
    "ALA ARG ASN ASP CYS CYX GLN GLU GLY "
    "HID HIE HIP HIS ILE LEU LYS MET PHE "
    "PRO SER THR TRP TYR VAL "
    # Terminal caps
    "ACE NME "
    # Water and common ions
    "HOH WAT NA K CL MG ZN CA FE MN "
    # RNA / DNA nucleotides
    "A G C U T DA DG DC DT".split()
)

# Atom names tleap generates on its own (terminal / template logic).  They
# are left out of the PDB handed to tleap to avoid "does not have a type"
# errors.
_TLEAP_SKIP_ATOMS: frozenset = frozenset(("OXT", "OT1", "OT2"))

# Residues handled by amber14-all.xml + amber14/tip3pfb.xml in OpenMM Modeller
# (does NOT include Mg/Zn/Ca/Fe etc. — those lack templates in the default
# XML set).
_MODELLER_FF_RESIDUES: frozenset = frozenset(
    # Protein residues
    "ALA ARG ASN ASP CYS CYX GLN GLU GLY "
    "HID HIE HIP HIS ILE LEU LYS MET PHE "
    "PRO SER THR TRP TYR VAL "
    # Terminal caps
    "ACE NME "
    # Water
    "HOH WAT "
    # Ions in amber14/tip3pfb.xml
    "NA K CL "
    # RNA / DNA nucleotides
    "A G C U T DA DG DC DT".split()
)

# Residues to exclude from the protein PDB written to tleap (GAFF2 path).
# Currently empty: every AMBER14_STANDARD residue (protein, ions, water) is
# written, so all of them take part in the LJ (steric) and Coulomb gradients.
#
# Rationale for keeping waters:
# - Crystal waters are poorly restrained by X-ray data (weak density, high B)
#   so their AMBER LJ/Coulomb gradient is their primary positional restraint.
# - tleap reads the PDB sequentially and preserves HOH order, so the
#   sequential residue map still matches correctly.
# - The Coulomb magnitude is wrong (no dielectric screening) but the direction
#   is correct; the AMBER weight in the total loss absorbs the scale error.
#
# Rationale for keeping monatomic ions (MG, ZN, CA, …): they are covered by
# leaprc.water.tip3p (Li/Merz 12-6 + Joung-Cheatham sets) and are critical
# for electrostatics near charged ligands.
_TLEAP_EXCLUDE_RESIDUES: frozenset = frozenset()


# ---------------------------------------------------------------------------
# Autograd bridge
# ---------------------------------------------------------------------------


class _OpenMMAMBERFunction(torch.autograd.Function):
    """
    Autograd adapter around an OpenMM energy/force evaluation.

    forward : xyz_ang (Å, float, [n_model, 3]) → energy (kJ/mol, scalar)
    backward: ∂loss/∂xyz = −F, using the (clamped) OpenMM forces

    ``context``, ``model_to_omm`` and ``pos_buf`` are plain Python / NumPy
    objects; ``backward`` returns ``None`` for each of them.
    """

    @staticmethod
    def forward(ctx, xyz_ang, context, model_to_omm, pos_buf):
        """
        Push model coordinates into the OpenMM context and read back energy.

        Parameters
        ----------
        xyz_ang : torch.Tensor
            Model coordinates, scaled by 0.1 (Å → nm) before being written
            into ``pos_buf``.
        context : object
            Must provide ``setPositions`` and ``getState`` as used below.
        model_to_omm : numpy.ndarray
            Integer index array; negative entries mark model atoms that are
            not sent to OpenMM and receive zero force.
        pos_buf : numpy.ndarray
            Position buffer indexed by OpenMM atom index; mapped rows are
            overwritten IN PLACE, all other rows are left as they are.

        Returns
        -------
        torch.Tensor
            Scalar energy with the dtype and device of ``xyz_ang``.
        """
        import openmm.unit as unit  # noqa: PLC0415

        kj_per_mol = unit.kilojoules_per_mole

        # Å → nm, then scatter the mapped rows into the shared buffer.
        coords_nm = xyz_ang.detach().cpu().numpy().astype(np.float64) * 0.1
        mapped = model_to_omm >= 0
        omm_rows = model_to_omm[mapped]
        pos_buf[omm_rows] = coords_nm[mapped]

        context.setPositions(pos_buf)
        snapshot = context.getState(getEnergy=True, getForces=True)
        energy = snapshot.getPotentialEnergy().value_in_unit(kj_per_mol)
        omm_forces = snapshot.getForces(asNumpy=True).value_in_unit(
            kj_per_mol / unit.nanometer
        )

        # Gather back to model ordering; kJ/mol/nm → kJ/mol/Å.
        per_atom = np.zeros((xyz_ang.shape[0], 3), dtype=np.float64)
        per_atom[mapped] = omm_forces[omm_rows] * 0.1

        # Cap each atom's force magnitude so extreme LJ clashes cannot blow
        # up the optimizer.  The 1e-10 floor avoids division by zero for
        # atoms with no force.  Only the forces are capped, not the energy.
        force_cap = 1000.0  # kJ/mol/Å
        magnitude = np.maximum(
            np.linalg.norm(per_atom, axis=1, keepdims=True), 1e-10
        )
        per_atom *= np.minimum(force_cap / magnitude, 1.0)

        like_input = {"dtype": xyz_ang.dtype, "device": xyz_ang.device}
        ctx.save_for_backward(torch.tensor(per_atom, **like_input))
        return torch.tensor(energy, **like_input)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(-forces * grad_output, None, None, None)``."""
        (clamped_forces,) = ctx.saved_tensors
        # F = −∂E/∂x  →  ∂E/∂x = −F
        grad_xyz = -clamped_forces * grad_output
        return grad_xyz, None, None, None


# ---------------------------------------------------------------------------
# AmberTarget
# ---------------------------------------------------------------------------


class AmberTarget(ModelTarget):
    """
    Differentiable AMBER14/GAFF2 force-field energy restraint.

    On construction the target:

    1. Detects non-standard residues (HETATM not in :data:`AMBER14_STANDARD`).
    2. Runs antechamber + parmchk2 (parallel, cached) for each non-standard
       residue.
    3. Builds an OpenMM system:

       * **Standard path** (no non-standard residues): filter model PDB to
         primary conformation + heavy atoms, use ``openmm.app.Modeller`` to
         re-add H with AMBER14-compatible names, create system with
         ``ForceField('amber14-all.xml')``.
       * **GAFF2 path** (with non-standard residues): same protein PDB
         (additionally removing OXT) handed to tleap together with each
         ligand's mol2 via ``combine{}``.  Combined AMBER14+GAFF2 topology
         is parameterised by parmed.

    4. Creates an OpenMM Context on the platform that matches the model's
       device: CUDA for ``model.device.type == 'cuda'``, CPU otherwise.
       Falls back CUDA → OpenCL → CPU if the preferred platform is
       unavailable.
    5. Builds a model-atom → OpenMM-atom index map so that only heavy atoms
       are transferred; H positions are kept from the initial OpenMM setup.

    Parameters
    ----------
    model : Model, optional
        TorchRef model.  Heavy-atom-only models (``strip_H=True``) are
        accepted.  H atoms are added internally by OpenMM's Modeller or tleap
        and are NOT included in the atom map or gradient.

        Passing a model that already has H atoms (via
        ``model.generate_hydrogens()`` or loading a PDB with H) speeds up
        initialisation ~4× because ``Modeller.addHydrogens()`` converges
        faster from existing positions.

        **Required for GAFF2 ligands**: antechamber's BCC charge scheme runs
        a semiempirical QM step (sqm) that requires a fully protonated
        molecule.  If the model has no H atoms for a non-standard residue, an
        explicit error is raised.  Call ``model.generate_hydrogens()`` or
        load the PDB with ``strip_H=False`` before creating the target.

        When ``None`` the constructor returns before building anything
        (empty init for state_dict loading).
    cutoff : float
        Non-bonded cutoff in Angstroms.  Default 5.0.
    normalize_by_atoms : bool
        If True the energy is divided by the number of model atoms.
        Default True.
    residue_charges : dict[str, int], optional
        Net formal charge per non-standard residue name, e.g.
        ``{'LIG': -1, 'ATP': -4}``.  Residues not listed default to 0 with a
        warning.
    gaff2_files : dict[str, tuple[str, str]], optional
        Pre-computed ``(mol2, frcmod)`` file paths per non-standard residue
        name.  Residues listed here use the supplied files instead of the
        cache / antechamber step.
    verbose : int
        Verbosity level (0 = silent, 1 = informational, 2 = debug).

    Raises
    ------
    ImportError
        If OpenMM is not installed.
    """

    name: str = "amber"

    def __init__(
        self,
        model: Optional["Model"] = None,
        cutoff: float = 5.0,
        normalize_by_atoms: bool = True,
        residue_charges: Optional[Dict[str, int]] = None,
        gaff2_files: Optional[Dict[str, Tuple[str, str]]] = None,
        verbose: int = 0,
    ):
        # Fail early with install instructions instead of a bare ImportError
        # from deep inside the build.
        try:
            import openmm  # noqa: F401, PLC0415
        except ImportError:
            raise ImportError(
                "AmberTarget requires OpenMM.\n"
                "Install with: pip install torchref[amber]\n"
                "Or via conda: conda install -c conda-forge openmm"
            ) from None

        super().__init__(model=model, verbose=verbose)

        self._normalize = normalize_by_atoms
        # Copy the caller's dicts so later mutation on their side has no
        # effect on this target.
        self._residue_charges = dict(residue_charges) if residue_charges else {}
        self._gaff2_files = dict(gaff2_files) if gaff2_files else {}
        self.register_buffer("_cutoff_buf", torch.tensor(float(cutoff)))

        # Internal state (None until fully initialised)
        self._context = None
        self._platform_name: str = "none"
        self._model_to_omm: Optional[np.ndarray] = None
        self._pos_buf: Optional[np.ndarray] = None
        self._n_omm_atoms: int = 0
        self._n_model_atoms: int = 0
        self._n_nonstandard: int = 0
        # GAFF2 path: ordered residue map for atom matching (None = standard path)
        self._tleap_residue_map: Optional[List[Dict[str, int]]] = None

        if model is None:
            return  # Allow empty init for state_dict loading

        self._build(model)
# ------------------------------------------------------------------ # Top-level build orchestration # ------------------------------------------------------------------ def _build(self, model: "Model") -> None: """Detect → antechamber → build OpenMM system → map atoms.""" # Reject models with alternate conformations — OpenMM only handles # a single conformation. Call model.strip_altlocs() first. altlocs = model.pdb["altloc"].astype(str).str.strip() if (altlocs != "").any(): raise ValueError( "[AmberTarget] Model contains alternate conformations. " "OpenMM requires a single conformation.\n" "Fix: model = model.strip_altlocs() before creating AmberTarget." ) nonstandard = self._detect_nonstandard_residues() self._n_nonstandard = len(nonstandard) gaff2_params = self._run_antechamber_parallel(nonstandard) system, topology, positions_nm = self._build_omm_system(gaff2_params) self._system = system self._topology = topology # Make tleap positions available to _build_atom_map (GAFF2 path uses # position-based matching; positions_nm will be cleaned up afterward). self._tleap_pos_nm = positions_nm self._build_atom_map() del self._tleap_pos_nm self._build_context(positions_nm) # Pre-allocate nm position buffer: H positions pre-filled from OpenMM init self._pos_buf = positions_nm.copy() self._n_model_atoms = len(model.pdb) if self.verbose >= 1: print( f"[AmberTarget] platform={self._platform_name}, " f"n_omm={self._n_omm_atoms}, n_model={self._n_model_atoms}, " f"n_nonstandard={self._n_nonstandard}" ) # ------------------------------------------------------------------ # Step 1 — Detect non-standard residues # ------------------------------------------------------------------ def _detect_nonstandard_residues(self) -> List[Tuple[str, int]]: """ Return ``(resname, net_charge)`` for HETATM residues not in :data:`AMBER14_STANDARD`. ATOM records with unknown resnames warn. 
""" pdb = self._model.pdb nonstandard: List[Tuple[str, int]] = [] seen: set = set() records = pdb["ATOM"].astype(str).str.strip() resnames = pdb["resname"].astype(str).str.strip() for record, resname in zip(records, resnames): if resname in seen: continue seen.add(resname) if resname in AMBER14_STANDARD: continue if record == "HETATM": charge = self._residue_charges.get(resname, None) if charge is None: warnings.warn( f"[AmberTarget] Non-standard residue '{resname}' has no " f"charge in residue_charges; assuming 0. " f"Pass residue_charges={{'{resname}': <charge>}} to suppress.", UserWarning, stacklevel=4, ) charge = 0 nonstandard.append((resname, charge)) else: # ATOM record with unrecognised name warnings.warn( f"[AmberTarget] ATOM record with unrecognised residue name " f"'{resname}'. pdbfixer / tleap may or may not handle it.", UserWarning, stacklevel=4, ) return nonstandard # ------------------------------------------------------------------ # Step 2 — Antechamber pipeline # ------------------------------------------------------------------ @staticmethod def _cache_key(resname: str, atom_names: List[str], charge: int) -> str: content = f"{resname}:{':'.join(sorted(atom_names))}:{charge}" return hashlib.sha1(content.encode()).hexdigest() def _get_cache_dir(self, resname: str) -> Path: d = PATH_TORCHREF_DATA / "amber_cache" / resname d.mkdir(parents=True, exist_ok=True) return d def _write_residue_pdb(self, res_atoms, path: Path) -> None: """Write a minimal single-residue PDB file for antechamber input.""" with open(path, "w") as f: for serial, (_, row) in enumerate(res_atoms.iterrows(), 1): name = str(row["name"]).strip() resname = str(row["resname"]).strip() x, y, z = float(row["x"]), float(row["y"]), float(row["z"]) elem = str(row.get("element", name[0])).strip() chain = str(row.get("chainid", "A")).strip() or "A" resseq = int(row.get("resseq", 1)) f.write( f"HETATM{serial:5d} {name:<4s} {resname:3s} {chain}" f"{resseq:4d} " f"{x:8.3f}{y:8.3f}{z:8.3f}" f" 
1.00 0.00 {elem:>2s}\n" ) f.write("END\n") def _run_antechamber_one( self, resname: str, charge: int ) -> Tuple[str, Path, Path]: """ Run antechamber + parmchk2 for one non-standard residue. Cache is checked first. On a miss, work happens in a temp dir and results are atomically moved to the cache (write-then-rename). """ pdb = self._model.pdb res_atoms = pdb[pdb["resname"].astype(str).str.strip() == resname] atom_names = res_atoms["name"].astype(str).str.strip().tolist() # antechamber's BCC charges require a fully protonated molecule so that # the semiempirical QM step (sqm) has an even electron count. # Detect odd-electron count early and emit a clear error. _Z = {"H":1,"He":2,"Li":3,"Be":4,"B":5,"C":6,"N":7,"O":8,"F":9,"Ne":10, "Na":11,"Mg":12,"Al":13,"Si":14,"P":15,"S":16,"Cl":17,"Ar":18, "K":19,"Ca":20,"Cr":24,"Mn":25,"Fe":26,"Co":27,"Ni":28,"Cu":29, "Zn":30,"Br":35,"I":53,"Se":34,"Mo":42,"W":74,"Pt":78,"Au":79} elems = res_atoms["element"].astype(str).str.strip().str.capitalize() n_protons = sum(_Z.get(e, 0) for e in elems) n_electrons = n_protons - charge if n_electrons % 2 != 0: # Odd electron count → sqm cannot converge. Almost certainly caused # by missing H atoms in an organic molecule. has_h = elems.isin(["H", "D"]).any() raise RuntimeError( f"[AmberTarget] Cannot run antechamber for '{resname}': " f"odd electron count ({n_electrons}) for charge {charge:+d}.\n" + ( "The model has no H atoms for this residue; sqm (inside antechamber) " "requires a fully protonated molecule.\n" "Fix: call model.generate_hydrogens() or load the PDB with " "strip_H=False before creating AmberTarget." if not has_h else f"Check the net charge (residue_charges={{'{resname}': <charge>}}) " f"and verify the protonation state." 
) ) key = self._cache_key(resname, atom_names, charge) cache_dir = self._get_cache_dir(resname) mol2_cached = cache_dir / f"{key}.mol2" frcmod_cached = cache_dir / f"{key}.frcmod" if mol2_cached.exists() and frcmod_cached.exists(): if self.verbose >= 1: print(f"[AmberTarget] Cache hit: {resname} ({key[:8]}...)") return resname, mol2_cached, frcmod_cached if self.verbose >= 1: print(f"[AmberTarget] antechamber: {resname} (charge={charge:+d})") work_dir = Path(tempfile.mkdtemp(prefix=f"amber_{resname}_")) try: lig_pdb = work_dir / "lig.pdb" lig_mol2 = work_dir / "lig.mol2" lig_frcmod = work_dir / "lig.frcmod" self._write_residue_pdb(res_atoms, lig_pdb) # antechamber r = subprocess.run( [ _find_ambertools_binary("antechamber"), "-i", str(lig_pdb), "-fi", "pdb", "-o", str(lig_mol2), "-fo", "mol2", "-c", "bcc", "-nc", str(charge), "-s", "2", "-at", "gaff2", "-dr", "no", ], cwd=str(work_dir), capture_output=True, text=True, timeout=600, ) if r.returncode != 0 or not lig_mol2.exists(): raise RuntimeError( f"antechamber failed for '{resname}':\n" f"STDOUT: {r.stdout}\nSTDERR: {r.stderr}" ) # parmchk2 r = subprocess.run( [ _find_ambertools_binary("parmchk2"), "-i", str(lig_mol2), "-f", "mol2", "-o", str(lig_frcmod), "-s", "gaff2", ], cwd=str(work_dir), capture_output=True, text=True, timeout=120, ) if r.returncode != 0 or not lig_frcmod.exists(): raise RuntimeError( f"parmchk2 failed for '{resname}':\n" f"STDOUT: {r.stdout}\nSTDERR: {r.stderr}" ) # Atomic cache write (temp file → rename) shutil.copy2(lig_mol2, cache_dir / f"{key}.mol2.tmp") shutil.copy2(lig_frcmod, cache_dir / f"{key}.frcmod.tmp") (cache_dir / f"{key}.mol2.tmp").rename(mol2_cached) (cache_dir / f"{key}.frcmod.tmp").rename(frcmod_cached) (cache_dir / f"{key}.meta.json").write_text( json.dumps( { "resname": resname, "charge": charge, "atom_names": sorted(atom_names), "cache_key": key, }, indent=2, ) ) if self.verbose >= 1: print(f"[AmberTarget] Cached: {resname}{cache_dir}") return resname, mol2_cached, 
frcmod_cached finally: shutil.rmtree(work_dir, ignore_errors=True) def _run_antechamber_parallel( self, nonstandard: List[Tuple[str, int]] ) -> Dict[str, Tuple[Path, Path]]: """Resolve GAFF2 parameters for non-standard residues. Checks (in order): user-supplied gaff2_files → cache → antechamber. """ if not nonstandard: return {} results: Dict[str, Tuple[Path, Path]] = {} need_antechamber: List[Tuple[str, int]] = [] for rn, charge in nonstandard: # 1. User-supplied files if rn in self._gaff2_files: mol2, frcmod = self._gaff2_files[rn] if self.verbose >= 1: print(f"[AmberTarget] Using supplied files for '{rn}'") results[rn] = (Path(mol2), Path(frcmod)) else: need_antechamber.append((rn, charge)) if not need_antechamber: return results # 2. Cache + antechamber for remaining residues with ThreadPoolExecutor(max_workers=min(len(need_antechamber), 4)) as pool: futures = { pool.submit(self._run_antechamber_one, rn, ch): rn for rn, ch in need_antechamber } for fut in as_completed(futures): rn = futures[fut] try: rn_out, mol2, frcmod = fut.result() results[rn_out] = (mol2, frcmod) except Exception as exc: raise RuntimeError( f"[AmberTarget] Failed to parameterise '{rn}': {exc}" ) from exc return results # ------------------------------------------------------------------ # Step 3 — Build OpenMM system # ------------------------------------------------------------------ def _filter_pdb_for_omm(self, include_nonstandard: bool = False): """ Return a filtered copy of model.pdb suitable for OpenMM / tleap: - Primary conformation only (altloc == '' or 'A') - Heavy atoms only (element != H or D) - Optionally exclude non-standard residues (standard path) The returned DataFrame keeps the original model.pdb integer index so that ``df.index`` can be used as model row indices in the atom map. 
""" pdb = self._model.update_pdb() mask = pdb["altloc"].astype(str).str.strip().isin(["", "A"]) mask &= ~pdb["element"].astype(str).str.strip().isin(["H", "D"]) if not include_nonstandard: ns_resnames = { rn for rn in pdb["resname"].astype(str).str.strip().unique() if rn not in _MODELLER_FF_RESIDUES } if ns_resnames: mask &= ~pdb["resname"].astype(str).str.strip().isin(ns_resnames) # Do NOT reset_index: keep original model.pdb row positions as index return pdb[mask].copy() def _filter_pdb_for_tleap(self): """ Filter model.pdb for the tleap protein PDB (GAFF2 path): - Primary conformation only (altloc == '' or 'A') - Heavy atoms only (element != H or D) - Standard AMBER residues only (``AMBER14_STANDARD``) — non-standard HETATM residues are handled via antechamber / mol2 separately - Waters excluded (``_TLEAP_EXCLUDE_RESIDUES``) — tleap reorders waters, breaking sequential atom-map strategy; no gradient loss since crystal waters are not primary refinement targets - Monatomic ions (MG, ZN, CA, …) ARE included — covered by ``leaprc.water.tip3p`` (Li/Merz 12-6 set), appear in fixed PDB order, important for electrostatics near charged ligands - Terminal atoms tleap regenerates (OXT …) excluded Note: uses ``AMBER14_STANDARD`` (not ``_MODELLER_FF_RESIDUES``) so that ions absent from amber14-all.xml are still sent to tleap. Index is preserved (original model.pdb row positions). 
""" pdb = self._model.update_pdb() mask = pdb["altloc"].astype(str).str.strip().isin(["", "A"]) mask &= ~pdb["element"].astype(str).str.strip().isin(["H", "D"]) # Allow AMBER-standard residues; exclude HOH/WAT and non-standard HETATM res_col = pdb["resname"].astype(str).str.strip() tleap_allowed = AMBER14_STANDARD - _TLEAP_EXCLUDE_RESIDUES mask &= res_col.isin(tleap_allowed) # Strip tleap-regenerated terminal atoms mask &= ~pdb["name"].astype(str).str.strip().isin(_TLEAP_SKIP_ATOMS) return pdb[mask].copy() def _build_omm_system( self, gaff2_params: Dict[str, Tuple[Path, Path]] ) -> Tuple: """ Build OpenMM system. Returns ``(system, omm_topology, pos_nm_array)``. Standard path (no non-standard residues) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Filter model PDB → heavy atoms, primary conformation, standard residues. Use ``openmm.app.Modeller.addHydrogens()`` to re-add H with AMBER names. Create system with ``ForceField('amber14-all.xml')``. GAFF2 path (non-standard residues present) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Write protein PDB (no OXT, no H) + mol2 per ligand. Combine via tleap ``combine{}`` command → prmtop/inpcrd. Load with parmed → ``AmberParm.createSystem()``. """ import openmm as mm # noqa: PLC0415 import openmm.app as app # noqa: PLC0415 import openmm.unit as unit # noqa: PLC0415 cutoff_A = float(self._cutoff_buf.item()) if not gaff2_params: system, topology, pos_nm = self._build_standard(cutoff_A, app, unit) else: system, topology, pos_nm = self._build_gaff2( gaff2_params, cutoff_A, app, unit ) # Remove CMMotionRemover so raw per-atom forces are available for i in range(system.getNumForces() - 1, -1, -1): if isinstance(system.getForce(i), mm.CMMotionRemover): system.removeForce(i) return system, topology, pos_nm def _build_standard(self, cutoff_A: float, app, unit) -> Tuple: """ AMBER14 standard-residue path using gemmi + pdbfixer + OpenMM. 
gemmi writes proper chain termination / TER records so pdbfixer can detect and fix missing terminal atoms (OXT). pdbfixer also handles missing sidechain atoms and non-standard residue names. """ import gemmi # noqa: PLC0415 from pdbfixer import PDBFixer # noqa: PLC0415 from torchref.io import pdb as pdbio # noqa: PLC0415 # Standard path: Modeller preserves chain/resseq → use key-based mapping self._tleap_residue_map = None pdb_heavy = self._filter_pdb_for_omm(include_nonstandard=False) tmp = tempfile.NamedTemporaryFile(suffix=".pdb", delete=False) tmp2 = tempfile.NamedTemporaryFile(suffix=".pdb", delete=False) tmp.close() tmp2.close() try: # Write via torchref, then re-read/write with gemmi to get # proper chain breaks and TER records that pdbfixer needs. pdbio.write(pdb_heavy, tmp.name) st = gemmi.read_structure(tmp.name) st.setup_entities() st.assign_subchains() st.write_pdb(tmp2.name) # pdbfixer: add missing terminal atoms and sidechain atoms fixer = PDBFixer(filename=tmp2.name) fixer.findMissingResidues() fixer.missingResidues = {} # don't fill gaps fixer.findMissingAtoms() if self.verbose >= 1: n_missing = sum(len(v) for v in fixer.missingAtoms.values()) n_terminals = sum( 1 for v in fixer.missingTerminals.values() if v ) if n_missing or n_terminals: print( f"[AmberTarget] pdbfixer: {n_missing} missing atoms, " f"{n_terminals} terminal fixes" ) fixer.addMissingAtoms() finally: os.unlink(tmp.name) os.unlink(tmp2.name) ff = app.ForceField("amber14-all.xml", "amber14/tip3pfb.xml") modeller = app.Modeller(fixer.topology, fixer.positions) modeller.addHydrogens(ff) system = ff.createSystem( modeller.topology, nonbondedMethod=app.CutoffNonPeriodic, nonbondedCutoff=cutoff_A * unit.angstrom, constraints=None, ) # positions in nm pos_nm = np.array( modeller.positions.value_in_unit(unit.nanometer), dtype=np.float64 ) return system, modeller.topology, pos_nm def _build_gaff2( self, gaff2_params: Dict[str, Tuple[Path, Path]], cutoff_A: float, app, unit, ) -> Tuple: """ 
AMBER14 + GAFF2 path via tleap + parmed. All AMBER14-standard heavy atoms (protein, ions, waters — no OXT, no H, no non-standard HETATM) plus each ligand mol2 are combined by tleap ``combine{}``. parmed loads the resulting prmtop/inpcrd. Atom mapping uses position-based matching (see :meth:`_build_atom_map`): tleap's initial coordinates are taken directly from the PDB we write, so model and tleap positions agree to 3 decimal places (PDB precision), making a KD-tree nearest-neighbour search unambiguous. This avoids relying on tleap's residue-sequential numbering, which is fragile for water molecules. """ import parmed as pmd # noqa: PLC0415 from torchref.io import pdb as pdbio # noqa: PLC0415 work_dir = Path(tempfile.mkdtemp(prefix="amber_gaff2_")) try: prot_pdb = work_dir / "protein.pdb" pdb_tleap = self._filter_pdb_for_tleap() # Signal GAFF2 path to _build_atom_map (position-based + name fallback) self._tleap_residue_map = True # type: ignore[assignment] # Store GAFF2 resnames so _build_atom_map can do name-based fallback # for ligand atoms (mol2 may have old coords if model was refined first) self._gaff2_resnames: set = set(gaff2_params.keys()) pdbio.write(pdb_tleap.reset_index(drop=True), str(prot_pdb)) prmtop = work_dir / "complex.prmtop" inpcrd = work_dir / "complex.inpcrd" # Build tleap source lines + one mol2 load per ligand lig_loads = [] lig_names = [] for rn, (mol2, frcmod) in gaff2_params.items(): lig_loads.append(f"loadAmberParams {frcmod}") lig_loads.append(f"{rn} = loadMol2 {mol2}") lig_names.append(rn) combine_list = " ".join(["protein"] + lig_names) tleap_script = "\n".join( [ "source leaprc.protein.ff14SB", "source leaprc.water.tip3p", "source leaprc.gaff2", ] + lig_loads + [ f"protein = loadPdb {prot_pdb}", f"complex = combine {{{combine_list}}}", f"saveAmberParm complex {prmtop} {inpcrd}", "quit", ] ) (work_dir / "tleap.in").write_text(tleap_script + "\n") r = subprocess.run( [_find_ambertools_binary("tleap"), "-f", str(work_dir / "tleap.in")], 
cwd=str(work_dir), capture_output=True, text=True, timeout=300, ) if not inpcrd.exists(): raise RuntimeError( f"[AmberTarget] tleap failed (GAFF2 path).\n" f"STDOUT (last 1000 chars):\n{r.stdout[-1000:]}\n" f"STDERR: {r.stderr[-500:]}" ) combined = pmd.amber.AmberParm(str(prmtop), str(inpcrd)) system = combined.createSystem( nonbondedMethod=app.CutoffNonPeriodic, nonbondedCutoff=cutoff_A * unit.angstrom, constraints=None, ) topology = combined.topology pos_nm = np.array( combined.positions.value_in_unit(unit.nanometer), dtype=np.float64 ) return system, topology, pos_nm finally: shutil.rmtree(work_dir, ignore_errors=True) # ------------------------------------------------------------------ # Step 4 — Atom map (model index → OpenMM index) # ------------------------------------------------------------------ @staticmethod def _is_hydrogen(omm_atom) -> bool: elem = omm_atom.element if elem is not None: return elem.symbol == "H" return omm_atom.name.startswith("H") # heuristic fallback def _build_atom_map(self) -> None: """ Build ``self._model_to_omm``: int32 array [n_model] where entry *i* is the OpenMM atom index corresponding to model atom *i*, or -1 for unmatched atoms (H atoms, altloc-B atoms, non-standard HETATM, …). Two strategies depending on how the system was built: **Standard path** (``_tleap_residue_map is None``): OpenMM Modeller preserves chain IDs and residue numbers from the input PDB, so matching uses the key ``(chain_id, resseq, icode, atom_name)``. **GAFF2 path** (``_tleap_residue_map is not None``): tleap strips chain IDs and renumbers residues sequentially, making name/number-based matching unreliable (especially for waters). Instead, the tleap initial positions are taken from the exact coordinates we wrote to the PDB (via ``update_pdb()``), so model and tleap positions agree to within PDB precision (0.001 Å = 0.0001 nm). 
A KD-tree nearest-neighbour search with a tight threshold (0.005 nm) unambiguously identifies each tleap heavy atom's model counterpart. """ from scipy.spatial import cKDTree # noqa: PLC0415 pdb = self._model.pdb n_model = len(pdb) model_to_omm = np.full(n_model, -1, dtype=np.int32) if self._tleap_residue_map is None: # ---- Standard path: match by (chain, resseq, icode, atom_name) ---- # No altlocs at this point (checked in _build). model_key_to_idx: Dict[Tuple, int] = {} for i in range(n_model): row = pdb.iloc[i] key = ( str(row["chainid"]).strip(), int(row["resseq"]), str(row.get("icode", "")).strip(), str(row["name"]).strip(), ) model_key_to_idx[key] = i for omm_atom in self._topology.atoms(): if self._is_hydrogen(omm_atom): continue chain_id = omm_atom.residue.chain.id.strip() try: resseq = int(omm_atom.residue.id) except ValueError: raw = omm_atom.residue.id.strip() resseq = int(raw.rstrip("ABCDEFGHIJKLMNOPQRSTUVWXYZ") or "0") icode = (omm_atom.residue.insertionCode or "").strip() idx = model_key_to_idx.get( (chain_id, resseq, icode, omm_atom.name.strip()) ) if idx is not None: model_to_omm[idx] = omm_atom.index else: # ---- GAFF2 path: position-based matching via KD-tree ---- # Collect tleap heavy-atom positions (nm) and their indices. tleap_pos_nm = self._tleap_pos_nm # set by _build() before this call tleap_ha_omm_idx: List[int] = [] tleap_ha_pos: List[np.ndarray] = [] for omm_atom in self._topology.atoms(): if not self._is_hydrogen(omm_atom): tleap_ha_omm_idx.append(omm_atom.index) tleap_ha_pos.append(tleap_pos_nm[omm_atom.index]) tleap_ha_pos_arr = np.array(tleap_ha_pos) # (N_tleap_heavy, 3) nm tree = cKDTree(tleap_ha_pos_arr) # Collect model primary-altloc heavy-atom positions (nm) and indices. # Use update_pdb() coords — same values that were written to tleap PDB. 
fresh_pdb = self._model.update_pdb() altloc_ok = fresh_pdb["altloc"].astype(str).str.strip().isin(["", "A"]) not_h = ~fresh_pdb["element"].astype(str).str.strip().isin(["H", "D"]) primary_heavy = np.where((altloc_ok & not_h).values)[0] model_pos_nm = np.column_stack([ fresh_pdb["x"].values[primary_heavy], fresh_pdb["y"].values[primary_heavy], fresh_pdb["z"].values[primary_heavy], ]) * 0.1 # Å → nm # Match: threshold = 0.005 nm (50× PDB precision of 0.0001 nm) dists, nn_idx = tree.query(model_pos_nm, k=1) matched = dists < 0.005 for local_i, (model_i, nn_i) in enumerate(zip(primary_heavy, nn_idx)): if matched[local_i]: model_to_omm[model_i] = tleap_ha_omm_idx[nn_i] # Name-based fallback for GAFF2 ligand residues whose mol2 positions # differ from the current model (e.g. after refinement steps). # The cached mol2 retains original antechamber coordinates, so a # second AmberTarget init after LBFGS will have position shifts. gaff2_resnames = getattr(self, "_gaff2_resnames", set()) if gaff2_resnames: # Build (resname, atom_name) → model positional index for primary heavy lig_key_to_model: Dict[Tuple[str, str], int] = {} for arr_pos in primary_heavy: rn = str(fresh_pdb["resname"].values[arr_pos]).strip() if rn not in gaff2_resnames: continue aname = str(fresh_pdb["name"].values[arr_pos]).strip() lig_key_to_model[(rn, aname)] = int(arr_pos) for omm_atom in self._topology.atoms(): if self._is_hydrogen(omm_atom): continue rn = omm_atom.residue.name if rn not in gaff2_resnames: continue aname = omm_atom.name.strip() model_arr_pos = lig_key_to_model.get((rn, aname)) if model_arr_pos is not None and model_to_omm[model_arr_pos] < 0: model_to_omm[model_arr_pos] = omm_atom.index # Warn about UNEXPECTED unmatched heavy atoms. 
# Expected to be unmatched (silently skipped in gradient): # - H / D atoms # - Waters, ions excluded from tleap (_TLEAP_EXCLUDE_RESIDUES) # - C-terminal OXT regenerated by tleap (_TLEAP_SKIP_ATOMS) # - Alternate conformer atoms (altloc != '' and != 'A') elem_col = pdb["element"].astype(str).str.strip() altloc_col = pdb["altloc"].astype(str).str.strip() resname_col = pdb["resname"].astype(str).str.strip() name_col = pdb["name"].astype(str).str.strip() heavy_mask = ~elem_col.isin(["H", "D"]) # Residues in AMBER14_STANDARD but without an amber14-all.xml template: # unmatched on the standard (Modeller) path; matched via tleap on GAFF2 path. _no_modeller_template = AMBER14_STANDARD - _MODELLER_FF_RESIDUES expected_mask = ( # Waters always excluded from tleap; no AMBER gradient expected resname_col.isin(_TLEAP_EXCLUDE_RESIDUES) | # Ions that lack Modeller templates (matched in GAFF2 path, not standard) resname_col.isin(_no_modeller_template) | # tleap-regenerated terminal atoms (OXT etc.) name_col.isin(_TLEAP_SKIP_ATOMS) | # alternate conformers (altloc B, C, …) (~altloc_col.isin(["", "A"])) ) unexpected_unmatched = np.where( heavy_mask.values & ~expected_mask.values & (model_to_omm < 0) )[0] if len(unexpected_unmatched) > 0: ex = [ f"{pdb.iloc[i]['name'].strip()} " f"({pdb.iloc[i]['resname'].strip()} {pdb.iloc[i]['resseq']})" for i in unexpected_unmatched[:5] ] warnings.warn( f"[AmberTarget] {len(unexpected_unmatched)} heavy model atom(s) " f"could not be matched to OpenMM topology " f"(e.g. {', '.join(ex)}). 
Their gradients will be zero.", UserWarning, stacklevel=3, ) elif self.verbose >= 2: unmatched_heavy = int(heavy_mask.values.sum()) - int( (heavy_mask.values & (model_to_omm >= 0)).sum() ) print( f"[AmberTarget] {unmatched_heavy} heavy atoms have model_to_omm=-1 " f"(expected: non-standard HETATM / altloc-B / OXT)" ) self._model_to_omm = model_to_omm self._n_omm_atoms = self._system.getNumParticles() if self.verbose >= 2: matched = int((model_to_omm >= 0).sum()) print( f"[AmberTarget] atom map: {matched}/{n_model} model atoms matched " f"({self._n_omm_atoms} total OpenMM atoms)" ) # ------------------------------------------------------------------ # Step 5 — OpenMM Context # ------------------------------------------------------------------ def _build_context(self, pos_nm: np.ndarray) -> None: """ Create an OpenMM Context on the platform that matches the model's device. Mapping: ``model.device.type == 'cuda'`` → CUDA, otherwise CPU. Falls back CUDA → OpenCL → CPU if the preferred platform is unavailable. """ import openmm as mm # noqa: PLC0415 device_type = getattr(self._model.device, "type", "cpu") preferred = "CUDA" if device_type == "cuda" else "CPU" seen: set = set() platforms = [ p for p in [preferred, "OpenCL", "CPU"] if not (p in seen or seen.add(p)) # type: ignore[func-returns-value] ] for name in platforms: try: platform = mm.Platform.getPlatformByName(name) integrator = mm.VerletIntegrator(1.0) context = mm.Context(self._system, integrator, platform) context.setPositions(pos_nm) # Warmup + validation context.getState(getEnergy=True, getForces=True) self._context = context self._platform_name = name if self.verbose >= 1: print(f"[AmberTarget] OpenMM platform: {name}") return except Exception as exc: if self.verbose >= 1: print(f"[AmberTarget] Platform {name} unavailable: {exc}") raise RuntimeError( f"[AmberTarget] No usable OpenMM platform (tried {platforms})." 
) # ------------------------------------------------------------------ # forward # ------------------------------------------------------------------
def forward(self) -> torch.Tensor:
    """
    Evaluate the AMBER14 energy at the model's current coordinates.

    Returns
    -------
    torch.Tensor
        Scalar energy in kJ/mol, divided by ``self._n_model_atoms``
        (kJ/mol/atom) when ``self._normalize`` is set.
        Gradient flows to ``model.xyz`` via OpenMM analytical forces.

    Raises
    ------
    RuntimeError
        If no OpenMM context has been built (``self._context is None``).
    """
    context = self._context
    if context is None:
        raise RuntimeError(
            "[AmberTarget] Not initialised. Pass model= to constructor."
        )

    coords = self._model.xyz()  # (n_model_atoms, 3), Å
    total = _OpenMMAMBERFunction.apply(
        coords,
        context,
        self._model_to_omm,
        self._pos_buf,
    )
    return total / self._n_model_atoms if self._normalize else total
# ------------------------------------------------------------------ # stats # ------------------------------------------------------------------
def stats(self) -> Dict[str, "StatEntry"]:
    """
    Collect this target's statistics for the logging pipeline.

    ``loss`` is the value returned by ``forward()``.  ``energy_kJ_mol`` is
    that value multiplied by ``self._n_model_atoms`` when
    ``self._normalize`` is set, and the same value otherwise.
    """
    # Reporting only: evaluate without recording an autograd graph.
    with torch.no_grad():
        loss_value = self.forward().item()

    if self._normalize:
        total_energy = loss_value * self._n_model_atoms
    else:
        total_energy = loss_value

    entries = {
        "loss": stat(loss_value, VERBOSITY_STANDARD),
        "energy_kJ_mol": stat(total_energy, VERBOSITY_DETAILED),
        "n_atoms": stat(self._n_model_atoms, VERBOSITY_DEBUG),
        "platform": stat(self._platform_name, VERBOSITY_DETAILED),
        "n_nonstandard": stat(self._n_nonstandard, VERBOSITY_DEBUG),
    }
    return entries