# Source code for torchref.refinement.targets.forcefield_target

"""
Force Field Target for Molecular Dynamics Energy Calculations.

This module provides a target that computes molecular energy using
TorchMD-Net neural network potentials for crystallographic refinement.
"""

from typing import TYPE_CHECKING, Dict, Optional
import torch
from torch import nn

from .base import ModelTarget
from torchref.utils.stats import (
    VERBOSITY_DEBUG,
    VERBOSITY_DETAILED,
    VERBOSITY_STANDARD,
    StatEntry,
    stat,
)

if TYPE_CHECKING:
    from torchref.model.model import Model


class ForceFieldTarget(ModelTarget):
    """
    Force field energy target using TorchMD-Net ML potentials.

    Computes molecular energy from atomic coordinates using a pre-trained
    neural network potential. Returns energy as a differentiable tensor
    suitable for gradient-based refinement.

    The TorchMD-Net checkpoint is loaded lazily on the first call to
    :meth:`forward`, so constructing the target does not require the
    ``torchmd-net`` package or touch the checkpoint file.

    Parameters
    ----------
    model : Model, optional
        Reference to the Model object. Must have hydrogens
        (load with strip_H=False). A model is required by the time
        :meth:`forward` is called.
    model_path : str, optional
        Path to TorchMD-Net checkpoint file (.ckpt). Required by the time
        :meth:`forward` is called.
    cutoff : float, optional
        Interaction cutoff distance in Angstroms. Default is 5.0.
        Stored as the ``_cutoff`` buffer; it is not read anywhere else in
        this class.
    normalize_by_atoms : bool, optional
        If True, return energy per atom. Default is True.
    verbose : int, optional
        Verbosity level. Default is 0. The missing-hydrogen warning is only
        emitted when this is greater than 0.

    Examples
    --------
    >>> from torchref.model import Model
    >>> from torchref.refinement.targets import ForceFieldTarget
    >>>
    >>> # Load model WITH hydrogens
    >>> model = Model(strip_H=False)
    >>> model.load_pdb('structure_with_H.pdb')
    >>>
    >>> # Create force field target
    >>> ff_target = ForceFieldTarget(
    ...     model=model,
    ...     model_path='path/to/torchmdnet.ckpt',
    ... )
    >>>
    >>> # Get energy
    >>> energy = ff_target()

    Notes
    -----
    Requires torchmd-net package: pip install torchmd-net

    Pre-trained models available at:
    https://github.com/torchmd/torchmd-net/tree/main/examples
    """

    name: str = "forcefield"

    def __init__(
        self,
        model: Optional["Model"] = None,
        model_path: Optional[str] = None,
        cutoff: float = 5.0,
        normalize_by_atoms: bool = True,
        verbose: int = 0,
    ):
        """
        Store the configuration; the neural network itself is loaded lazily.

        Parameters
        ----------
        model : Model, optional
            Atomic model whose coordinates and atomic numbers are evaluated.
        model_path : str, optional
            Path to the TorchMD-Net checkpoint (.ckpt).
        cutoff : float, optional
            Interaction cutoff in Angstroms, stored as a buffer.
        normalize_by_atoms : bool, optional
            If True, :meth:`forward` divides the energy by the atom count.
        verbose : int, optional
            Verbosity level forwarded to the base class.
        """
        super().__init__(model=model, verbose=verbose)

        # Store configuration
        self._model_path = model_path
        self._normalize_by_atoms = normalize_by_atoms

        # Register cutoff as buffer for state_dict compatibility
        self.register_buffer("_cutoff", torch.tensor(cutoff))

        # Neural network potential (lazy initialization).
        # NOTE(review): if ModelTarget is an nn.Module, assigning the loaded
        # network to this attribute later registers it as a submodule, so its
        # weights show up in parameters()/state_dict() only after first use --
        # confirm that is intended.
        self._nn_potential = None

        # One-shot flag: the hydrogen check runs on the first forward() only.
        self._validated = False

    def _ensure_nn_potential(self) -> None:
        """
        Initialize the TorchMD-Net model on first use.

        Raises
        ------
        ValueError
            If no ``model_path`` was given.
        ImportError
            If the ``torchmdnet`` package cannot be imported.
        """
        if self._nn_potential is not None:
            return

        # Validate model path
        if self._model_path is None:
            raise ValueError(
                "model_path is required. Provide path to a TorchMD-Net checkpoint (.ckpt).\n"
                "Pre-trained models: https://github.com/torchmd/torchmd-net/tree/main/examples"
            )

        # Import TorchMD-Net lazily so the dependency stays optional
        try:
            from torchmdnet.models.model import load_model
        except ImportError:
            raise ImportError(
                "ForceFieldTarget requires torchmd-net package.\n"
                "Install with: pip install torchref[forcefield]\n"
                "Pre-trained models: https://github.com/torchmd/torchmd-net/tree/main/examples"
            ) from None

        # Load model
        self._nn_potential = load_model(self._model_path)

        # Move to same device as atomic model. A model without parameters
        # gives no device hint, in which case the potential stays where
        # load_model() put it.
        if self.model is not None:
            first_param = next(self.model.parameters(), None)
            if first_param is not None:
                self._nn_potential = self._nn_potential.to(first_param.device)

    def _validate_hydrogens(self) -> None:
        """Warn if model appears to lack hydrogens (only when verbose > 0)."""
        # The warning is gated on verbosity, so skip the tensor reduction and
        # the host synchronisation of .item() when it could not be shown.
        if self.model is None or self.verbose <= 0:
            return

        # Check if any hydrogen atoms present (atomic number 1)
        has_hydrogens = (self.model.Z == 1).any().item()
        if not has_hydrogens:
            import warnings

            warnings.warn(
                "Model appears to have no hydrogen atoms. "
                "TorchMD-Net typically requires all-atom structures. "
                "Load with Model(strip_H=False) if hydrogens are needed.",
                UserWarning
            )

    def forward(self) -> torch.Tensor:
        """
        Compute force field energy for current model coordinates.

        Returns
        -------
        torch.Tensor
            Scalar energy tensor with gradient support. Divided by the number
            of atoms when ``normalize_by_atoms`` is True.

        Raises
        ------
        ValueError
            If no ``model_path`` was given or no model is attached.
        ImportError
            If the ``torchmdnet`` package cannot be imported.
        """
        # Ensure neural network is loaded
        self._ensure_nn_potential()

        # Fail with an explicit message instead of an AttributeError on None,
        # and before the one-shot validation flag is consumed.
        if self.model is None:
            raise ValueError(
                "ForceFieldTarget.forward() requires a model, but self.model is None."
            )

        # Validate on first call
        if not getattr(self, "_validated", False):
            self._validate_hydrogens()
            self._validated = True

        # Get coordinates and atomic numbers
        xyz = self.model.xyz()
        # Atomic numbers are passed as int64 on the same device as the
        # coordinates and the batch index; .to() is a no-op when both
        # already match.
        Z = self.model.Z.to(device=xyz.device, dtype=torch.long)
        n_atoms = len(Z)

        # Create batch tensor (single structure = all zeros)
        batch = torch.zeros(n_atoms, dtype=torch.long, device=xyz.device)

        # Compute energy via TorchMD-Net
        # Returns (energy, forces) or just energy depending on model config
        result = self._nn_potential(Z, xyz, batch)
        energy = result[0] if isinstance(result, tuple) else result

        # Ensure scalar
        if energy.dim() > 0:
            energy = energy.sum()

        # Normalize by number of atoms if requested
        if self._normalize_by_atoms:
            energy = energy / n_atoms

        return energy

    def stats(self) -> Dict[str, StatEntry]:
        """
        Get statistics for this target.

        Runs :meth:`forward` once without gradient tracking, so it raises the
        same errors as :meth:`forward`.

        Returns
        -------
        dict
            Dictionary with StatEntry values under the keys ``"loss"``,
            ``"n_atoms"`` and ``"model_path"``.
        """
        # NOTE(review): checkpoints configured to also return forces may need
        # autograd inside the network call -- confirm they work under no_grad.
        with torch.no_grad():
            energy = self.forward().item()

        # Defensive guard: forward() above already requires a model, but a
        # subclass may override it.
        n_atoms = len(self.model.Z) if self.model is not None else 0

        return {
            "loss": stat(energy, VERBOSITY_STANDARD),
            "n_atoms": stat(n_atoms, VERBOSITY_DEBUG),
            "model_path": stat(str(self._model_path), VERBOSITY_DETAILED),
        }