# Source code for torchref.refinement.targets.adp.similarity

from typing import TYPE_CHECKING, Any, Dict, Optional

import numpy as np
import torch

from torchref.base.targets.adp import adp_simu_math
from torchref.utils.stats import (
    VERBOSITY_DEBUG,
    VERBOSITY_DETAILED,
    VERBOSITY_STANDARD,
    StatEntry,
    stat,
)

from .base import ADPTarget
from ..base import adp_similarity_nll

if TYPE_CHECKING:
    from torchref.model.model import Model


class ADPSimilarityTarget(ADPTarget):
    """
    ADP Similarity restraint (SIMU in Phenix/SHELX).

    Restrains B-factors of bonded atoms to be similar.

    NLL = 0.5 * ((B_i - B_j) / σ)² + log(σ) + 0.5 * log(2π)

    Tunable parameters (as buffers):
    - _simu_sigma: float, sigma for B-factor differences (default 2.0 Ų)
    """

    name: str = "adp/simu"

    def __init__(
        self,
        model: Optional["Model"] = None,
        simu_sigma: float = 2.0,
        verbose: int = 0,
    ):
        """Create the SIMU target.

        Args:
            model: Model whose ADPs are restrained (may be ``None`` here and
                attached later — NOTE(review): confirm with ``ADPTarget``).
            simu_sigma: Sigma for B-factor differences between bonded atoms.
            verbose: Verbosity level forwarded to the base target.
        """
        super().__init__(model, verbose, target_value=4.0, sigma=1.2)
        # Register simu-specific sigma as buffer (separate from base sigma)
        self.register_buffer("_simu_sigma", torch.tensor(simu_sigma))

    @property
    def simu_sigma(self) -> float:
        """Get SIMU sigma value."""
        return self._simu_sigma.item()

    @simu_sigma.setter
    def simu_sigma(self, value: float):
        """Set SIMU sigma value.

        Updates the buffer in place so its storage (and any device it was
        already moved to) is preserved.
        """
        self._simu_sigma.fill_(value)

    def _get_pair_indices(self) -> torch.Tensor:
        """Return the (N, 2) atom-pair list used by the SIMU restraint.

        Concatenates the ``"indices"`` tensors of every bond restraint group
        except the one with origin ``"all"``. Groups without indices, or with
        empty indices, are skipped. When nothing remains, an empty (0, 2)
        long tensor on the model's coordinate device is returned.

        The result is cached on first call. Nothing in this class invalidates
        the cache, so later edits to the bond restraints are not picked up.
        """
        cached = getattr(self, "_simu_pair_indices_cache", None)
        if cached is not None:
            return cached

        chunks = []
        for origin, group in self.restraints.restraints.get("bond", {}).items():
            if origin == "all":
                continue
            idx_ = group.get("indices")
            if idx_ is not None and len(idx_) > 0:
                chunks.append(idx_)

        if chunks:
            cached = torch.cat(chunks, dim=0).contiguous()
        else:
            cached = torch.empty(
                0, 2, dtype=torch.long, device=self.model.xyz().device
            )
        self._simu_pair_indices_cache = cached
        return cached

    def forward(self) -> torch.Tensor:
        """Compute the SIMU restraint loss.

        Returns:
            The result of ``adp_simu_math`` for the current ADPs and pair
            list, or a 0-d zero tensor (ADP dtype/device) when there are no
            restrained pairs.
        """
        # Use the adp_simu_math dispatcher (Triton on CUDA fp32).
        pair_indices = self._get_pair_indices()
        adp_t = self.model.adp()
        if pair_indices.shape[0] == 0:
            return torch.zeros((), device=adp_t.device, dtype=adp_t.dtype)

        # Lazily move the ``_simu_sigma`` buffer onto the model's device
        # the first time we reach here. Once moved, subsequent forwards
        # (and CUDA-Graph captures) skip the device transfer — calling
        # ``.to()`` on a CPU buffer inside a capture region triggers a
        # ``cudaErrorStreamCaptureUnsupported``.
        if (
            self._simu_sigma.device != adp_t.device
            or self._simu_sigma.dtype != adp_t.dtype
        ):
            self._simu_sigma = self._simu_sigma.to(
                device=adp_t.device,
                dtype=adp_t.dtype,
            )
        return adp_simu_math(adp_t, pair_indices, self._simu_sigma)

    def stats(self) -> Dict[str, Any]:
        """Get SIMU restraint statistics.

        Returns:
            An empty dict when ``adp_b_differences()`` yields no values.
            Otherwise a dict of ``stat`` entries: loss, count, RMS/mean/max
            absolute B difference, and mean/RMS z-score (|ΔB| / simu_sigma).
        """
        # Statistics are for reporting only: avoid building an autograd
        # graph for the differences and for the forward() loss.
        with torch.no_grad():
            b_diffs = self.restraints.adp_b_differences()
            if len(b_diffs) == 0:
                return {}

            b_diffs_abs = b_diffs.abs()
            z_scores = b_diffs_abs / self.simu_sigma
            loss = self.forward()

            return {
                "loss": stat(loss.item(), VERBOSITY_STANDARD),
                "count": stat(len(b_diffs), VERBOSITY_DEBUG),
                "rms_delta_b": stat(
                    torch.sqrt((b_diffs**2).mean()).item(), VERBOSITY_DETAILED
                ),
                "mean_delta_b": stat(b_diffs_abs.mean().item(), VERBOSITY_DETAILED),
                "max_delta_b": stat(b_diffs_abs.max().item(), VERBOSITY_DETAILED),
                "mean_z": stat(z_scores.mean().item(), VERBOSITY_DEBUG),
                "rms_z": stat(
                    torch.sqrt((z_scores**2).mean()).item(), VERBOSITY_DETAILED
                ),
            }