# Source code for torchref.base.targets.adp

"""ADP (B-factor) restraint NLLs: similarity, KL-divergence, locality."""

import torch

from ._common import LOG_2PI
from ._dispatch import use_triton


def _adp_simu_math_eager(
    b: torch.Tensor,
    pair_indices: torch.Tensor,
    simu_sigma: torch.Tensor,
) -> torch.Tensor:
    """Pure-PyTorch SIMU NLL, summed over every pair in ``pair_indices``.

    For each row ``(i, j)`` of ``pair_indices``, the difference
    ``b[i] - b[j]`` is scored with a Gaussian NLL of width ``simu_sigma``:
    ``0.5 * (d / sigma)**2 + log(sigma) + 0.5 * LOG_2PI`` (``LOG_2PI`` is
    presumably ``log(2*pi)``; see ``._common``). The per-pair values are
    summed into a scalar.
    """
    first = pair_indices[:, 0]
    second = pair_indices[:, 1]
    z = (b[first] - b[second]) / simu_sigma
    per_pair = 0.5 * z ** 2 + torch.log(simu_sigma) + 0.5 * LOG_2PI
    return per_pair.sum()


def adp_simu_math(
    b: torch.Tensor,
    pair_indices: torch.Tensor,
    simu_sigma: torch.Tensor,
) -> torch.Tensor:
    """ADP similarity (SIMU) NLL on bonded-atom B-factor differences.

    When :func:`use_triton` accepts ``b`` (per the original notes: CUDA
    float32, ~1.6x faster fwd+bw on A100), the work is handed to
    :func:`torchref.base.targets.triton.adp_simu_math_triton`; in every
    other case the pure-PyTorch reference implementation is used.

    Parameters
    ----------
    b : torch.Tensor
        (N_atoms,) B-factors.
    pair_indices : torch.Tensor
        (N, 2) bonded-atom pairs to compare.
    simu_sigma : torch.Tensor
        Scalar sigma on the difference (a buffer in the target).

    Returns
    -------
    torch.Tensor
        Scalar NLL summed over all pairs.
    """
    if not use_triton(b):
        return _adp_simu_math_eager(b, pair_indices, simu_sigma)

    # Imported lazily so the Triton kernel module is only loaded when the
    # fast path is actually taken.
    from .triton.adp_simu import adp_simu_math_triton

    return adp_simu_math_triton(b, pair_indices, simu_sigma)


def adp_kl_math(
    log_adp: torch.Tensor,
    target_log_std: float = 0.2,
) -> torch.Tensor:
    """KL-divergence regularizer on the spread of log(B).

    Mirrors ``Model.adp_kl_divergence_loss``. The empirical distribution
    of ``log_adp`` is modelled as a Gaussian with std
    ``s = torch.std(log_adp)`` (PyTorch's default unbiased estimator). It
    is compared to a target Gaussian with the same mean and fixed std
    ``t = target_log_std``::

        KL = log(t / s) + s**2 / (2 * t**2) - 0.5

    Because both means are equal, the mean term vanishes and is never
    computed. The result is 0 when ``s == t``. With fewer than two
    elements ``torch.std`` yields NaN, and with all-equal inputs
    ``s == 0`` gives ``+inf``.

    Parameters
    ----------
    log_adp : torch.Tensor
        Log B-factors; the std is taken over all elements.
    target_log_std : float
        Desired std of ``log_adp``.

    Returns
    -------
    torch.Tensor
        0-d tensor with the same dtype/device as ``log_adp``.
    """
    spread = log_adp.std()
    target = torch.tensor(
        target_log_std, device=log_adp.device, dtype=log_adp.dtype
    )
    log_ratio = (target / spread).log()
    quadratic = spread.pow(2) / (2 * target_log_std ** 2)
    return log_ratio + quadratic - 0.5


def adp_locality_math(
    b: torch.Tensor,
    neighbor_indices: torch.Tensor,
    neighbor_distances: torch.Tensor,
    *,
    sigma: float = 0.5,
    b_min: float = 1e-3,
    distance_eps: float = 1e-6,
) -> torch.Tensor:
    """ADP locality restraint: distance-weighted squared log(B) differences.

    Mirrors ``ADPLocalityTarget.forward``. Building the neighbor list is
    the target's bookkeeping and is not included here.

    Computes the *sum* (not a mean) over every atom ``i`` and neighbor
    slot ``k``::

        w_ik * ((log b_i - log b_{idx[i, k]}) / sigma) ** 2
        w_ik = 1 / (d_ik + distance_eps)

    There is no 0.5 factor or Gaussian normalization term, so this is an
    unnormalized weighted squared error rather than a full NLL.

    Parameters
    ----------
    b : torch.Tensor
        (N_atoms,) B-factors. They are clamped below at ``b_min`` before
        the log, so zero or negative values do not produce -inf/NaN.
    neighbor_indices : torch.Tensor
        Integer indices into ``b``. ``log_adp.unsqueeze(1)`` is broadcast
        against ``log_adp[neighbor_indices]``, so this is expected to be
        (N_atoms, K) — TODO confirm against the target.
    neighbor_distances : torch.Tensor
        Distances broadcastable to the same shape as ``neighbor_indices``.
        Closer neighbors get larger weights.
    sigma : float, keyword-only
        Scale on the log(B) difference (previously hard-coded to 0.5).
    b_min : float, keyword-only
        Lower clamp applied to ``b`` before the log (previously 1e-3).
    distance_eps : float, keyword-only
        Added to distances to avoid division by zero (previously 1e-6).

    Returns
    -------
    torch.Tensor
        Scalar summed restraint value.
    """
    log_adp = torch.log(b.clamp(min=b_min))
    # (N, 1) - (N, K): each atom against each of its neighbors.
    diff = log_adp.unsqueeze(1) - log_adp[neighbor_indices]
    weights = 1.0 / (neighbor_distances + distance_eps)
    return (weights * (diff / sigma) ** 2).sum()