Source code for torchref.base.targets.ramachandran

"""Ramachandran restraint via bilinear interpolation on a precomputed NLL surface."""

import torch

from ._common import torsions_from_xyz
from ._dispatch import use_triton


def _bin_and_fraction(angle_deg: torch.Tensor):
    """Place angles onto the periodic 360-bin lookup grid.

    An angle ``a`` maps to grid position ``(a + 180) % 360``, so bin ``i``
    corresponds to angle ``i - 180``.

    Returns
    -------
    (lower, upper, frac)
        ``lower``/``upper`` are the two neighbouring integer bins (wrapped
        modulo 360, detached from autograd). ``frac`` is the offset from
        ``lower``, in [0, 1), and keeps the autograd graph.
    """
    position = (angle_deg + 180.0) % 360.0
    # Bin indices must not carry gradients; only ``frac`` is differentiable.
    base = position.detach().floor()
    # The extra modulo guards against the float remainder rounding to
    # exactly 360.0, which would otherwise index one past the end.
    lower = base.long() % 360
    upper = (lower + 1) % 360
    return lower, upper, position - base


def _ramachandran_math_eager(
    xyz: torch.Tensor,
    phi_idx: torch.Tensor,
    psi_idx: torch.Tensor,
    nll_surfaces: torch.Tensor,
    surface_type: torch.Tensor,
) -> torch.Tensor:
    """Pure-PyTorch Ramachandran NLL via bilinear interpolation.

    Computes the phi/psi torsions from ``xyz``, looks up the four
    surrounding grid values in ``nll_surfaces[surface_type]``, and
    interpolates them bilinearly. Returns the sum over all residues as a
    scalar tensor.

    The torsions are negated before lookup. NOTE(review): presumably this
    matches the sign convention the surfaces were built with — confirm.
    The grid mapping treats the torsions as degrees.
    """
    # Negated torsions, mapped onto the grid (see NOTE above).
    phi_lo, phi_hi, t_phi = _bin_and_fraction(-torsions_from_xyz(xyz, phi_idx))
    psi_lo, psi_hi, t_psi = _bin_and_fraction(-torsions_from_xyz(xyz, psi_idx))

    table = nll_surfaces
    kind = surface_type
    w_phi_lo = 1 - t_phi
    w_psi_lo = 1 - t_psi

    # Weighted sum of the four corner values (phi-major, then psi).
    interpolated = (
        w_phi_lo * w_psi_lo * table[kind, phi_lo, psi_lo]
        + w_phi_lo * t_psi * table[kind, phi_lo, psi_hi]
        + t_phi * w_psi_lo * table[kind, phi_hi, psi_lo]
        + t_phi * t_psi * table[kind, phi_hi, psi_hi]
    )
    return interpolated.sum()


def ramachandran_math(
    xyz: torch.Tensor,
    phi_idx: torch.Tensor,
    psi_idx: torch.Tensor,
    nll_surfaces: torch.Tensor,
    surface_type: torch.Tensor,
) -> torch.Tensor:
    """Ramachandran bilinear-interpolated NLL.

    Mirrors ``RamachandranTarget.forward``. When ``use_triton(xyz)`` is
    true, dispatches to
    :func:`torchref.base.targets.triton.ramachandran_math_triton`
    (intended for CUDA float32; ~10× faster fwd+bw on A100). Otherwise
    falls back to the eager PyTorch implementation.

    Parameters
    ----------
    xyz : torch.Tensor
        (N_atoms, 3) Cartesian coordinates.
    phi_idx, psi_idx : torch.Tensor
        (N, 4) atom indices for the two backbone dihedrals.
    nll_surfaces : torch.Tensor
        (n_surface_types, 360, 360) precomputed NLL = -log P(φ, ψ | type).
    surface_type : torch.Tensor
        (N,) integer type per residue.

    Returns
    -------
    torch.Tensor
        Total NLL over all residues. The eager path returns the scalar
        ``sum`` of the per-residue values; the Triton path is assumed to
        match.
    """
    if use_triton(xyz):
        # Deferred import: the Triton module is only loaded when this
        # branch is taken.
        from .triton.ramachandran import ramachandran_math_triton as impl
    else:
        impl = _ramachandran_math_eager
    return impl(xyz, phi_idx, psi_idx, nll_surfaces, surface_type)