Source code for torchref.refinement.targets.geometry.ramachandran

import torch
from typing import TYPE_CHECKING, Dict

from torchref.base.targets.ramachandran import ramachandran_math
from torchref.utils.stats import (
    VERBOSITY_DEBUG,
    VERBOSITY_STANDARD,
    StatEntry,
    stat,
)

from .base import GeometryTarget

if TYPE_CHECKING:
    from torchref.model.model import Model


class RamachandranTarget(GeometryTarget):
    """
    Ramachandran restraint via pre-computed NLL surfaces.

    Uses 6 residue-type-dependent NLL surfaces (general, glycine,
    cis-proline, trans-proline, pre-proline, ile/val) at 1-degree
    resolution. The loss is computed by bilinear interpolation of the
    NLL surface at the current (phi, psi) angles.

    The surfaces store NLL = -log P(phi, psi | residue_type), so favored
    regions have low values and outlier regions have high values —
    consistent with all other geometry targets.
    """

    name: str = "geometry/ramachandran"

    def __init__(self, model: "Model" = None, verbose: int = 0):
        """
        Initialize the target.

        Parameters
        ----------
        model : Model, optional
            Forwarded unchanged to ``GeometryTarget.__init__``.
        verbose : int, optional
            Forwarded unchanged to ``GeometryTarget.__init__``.

        Notes
        -----
        The base class is always called with ``target_value=2.0`` and
        ``sigma=1.0``.
        """
        super().__init__(model, verbose, target_value=2.0, sigma=1.0)

    def forward(self) -> torch.Tensor:
        """
        Evaluate the Ramachandran loss for the current model coordinates.

        Returns
        -------
        torch.Tensor
            The result of ``ramachandran_math`` applied to ``self.model.xyz()``
            and the ``_rama_*`` arrays stored on ``self.restraints``. When
            ``self.restraints`` has no ``_rama_phi_indices`` attribute, or it
            is ``None``, a scalar ``0.0`` tensor on the coordinates' device is
            returned instead.
        """
        # Coordinates are fetched first because the fallback value below
        # needs their device.
        xyz = self.model.xyz()

        # A missing attribute and an explicit None both mean "no restraints".
        phi_indices = getattr(self.restraints, "_rama_phi_indices", None)
        if phi_indices is None:
            return torch.tensor(0.0, device=xyz.device)

        restraints = self.restraints
        return ramachandran_math(
            xyz,
            phi_indices,
            restraints._rama_psi_indices,
            restraints._rama_surfaces,
            restraints._rama_surface_type,
        )

    def stats(self) -> Dict[str, StatEntry]:
        """
        Get Ramachandran restraint statistics.

        Returns
        -------
        Dict[str, StatEntry]
            Empty when ``self.restraints`` has no ``_rama_phi_indices``
            attribute or it is ``None``. Otherwise contains ``"loss"`` (the
            value of :meth:`forward` as a Python float, at
            ``VERBOSITY_STANDARD``) and ``"n"`` (the first dimension of
            ``_rama_phi_indices``, at ``VERBOSITY_DEBUG``).
        """
        phi_indices = getattr(self.restraints, "_rama_phi_indices", None)
        if phi_indices is None:
            return {}

        n = phi_indices.shape[0]

        # Only the scalar value is reported, so skip autograd bookkeeping
        # instead of building a graph that is discarded immediately.
        with torch.no_grad():
            loss_value = self.forward().item()

        return {
            "loss": stat(loss_value, VERBOSITY_STANDARD),
            "n": stat(n, VERBOSITY_DEBUG),
        }