"""
Coordinate Similarity Target for Difference Refinement
Implements a spike-and-slab prior on per-atom displacements between
dark and light models. The loss is quadratic for small displacements
(likely noise) and levels off to a constant plateau for large
displacements (likely genuine conformational changes), so large moves
are essentially not penalized.
Per-atom coordinate uncertainty sigma is derived from B-factors:
sigma = sqrt(B / (8*pi^2))
Reference: design_doc_sim_loss.md
"""
import torch
from typing import TYPE_CHECKING, Dict
from .base import Target
from torchref.utils.stats import (
VERBOSITY_DEBUG,
VERBOSITY_DETAILED,
VERBOSITY_STANDARD,
StatEntry,
stat,
)
if TYPE_CHECKING:
from torchref.model.model import Model
[docs]
class CoordinateSimilarityTarget(Target):
    """
    Spike-and-slab similarity restraint between dark and light models.

    For each matched atom, two hypotheses are considered:

    - **Static** (prob 1-p): atom did not move, displacement is noise
    - **Moved** (prob p): atom genuinely displaced

    The per-atom loss is the negative log marginal likelihood (up to an
    additive constant, with the prior odds folded into ``alpha``)::

        L(d) = -logsumexp(-d^2/(2*sigma^2) + alpha, 0)

    where d = ||xyz_light - xyz_dark|| and sigma = sqrt(B / (8*pi^2)) is
    the per-atom coordinate uncertainty derived from the dark model's
    B-factors (sigma^2 is floored at ``_MIN_SIGMA_SQ``).

    Gradient with respect to d::

        dL/dd = d/sigma^2 * sigmoid(-d^2/(2*sigma^2) + alpha)

    i.e. an L2 restraint weighted by the posterior probability that the
    atom is static.

    Behavior:

    - d << sigma: L(d) ~ L(0) + 0.5 * sigmoid(alpha) * d^2 / sigma^2
      (quadratic, tight restraint)
    - d >> sigma: L(d) -> 0 and the gradient vanishes (no penalty for
      genuine moves)
    - Crossover (posterior P(static) = 0.5) at d = sigma * sqrt(2*alpha)

    The per-atom loss is bounded, with L(0) = -log(1 + exp(alpha)) < 0, so
    the summed loss is negative; this constant offset does not affect
    gradients.

    Parameters
    ----------
    model_dark : Model, optional
        Dark (ground state) model. B-factors and coordinates are detached.
    model_light : Model, optional
        Light (excited state) model. Coordinates carry gradients.
    alpha : float, optional
        Log prior odds of the static hypothesis. Higher values mean
        stronger denoising. Default is 2.0 (crossover at 2*sigma).
    verbose : int, optional
        Verbosity level. Default is 0.

    Notes
    -----
    If either model is None, no atom map is built: ``forward()`` returns a
    zero tensor and ``stats()`` returns an empty dict.
    """

    name: str = "similarity"

    # Floor on sigma^2 so atoms with zero/negative/tiny B-factors do not
    # cause a division by zero or an arbitrarily stiff restraint.
    _MIN_SIGMA_SQ: float = 1e-4

    def __init__(
        self,
        model_dark: "Model" = None,
        model_light: "Model" = None,
        alpha: float = 2.0,
        verbose: int = 0,
    ):
        super().__init__(verbose=verbose)
        self.add_module("_model_dark", model_dark)
        self.add_module("_model_light", model_light)
        # float() so an integer alpha does not create an int64 buffer that
        # would silently truncate later assignments through the setter.
        self.register_buffer("_alpha", torch.tensor(float(alpha)))
        # Always register the index buffers so forward()/stats() behave as
        # an empty target (instead of raising) when no models are given.
        self._set_atom_map([], [])
        if model_dark is not None and model_light is not None:
            self._build_atom_map()

    @property
    def model_dark(self) -> "Model":
        """Get dark model."""
        return self._model_dark

    @property
    def model_light(self) -> "Model":
        """Get light model."""
        return self._model_light

    @property
    def alpha(self) -> float:
        """Get alpha (log prior odds of the static hypothesis) as float."""
        return self._alpha.item()

    @alpha.setter
    def alpha(self, value: float):
        """Set alpha in place, keeping the buffer's device and dtype."""
        self._alpha.fill_(value)

    def _set_atom_map(self, idx_dark, idx_light):
        """(Re-)register the matched-atom index buffers on the module's device.

        Parameters
        ----------
        idx_dark, idx_light : sequence of int
            Equal-length row positions into the dark / light models' atom
            tables; element i of each refers to the same atom.
        """
        device = self._alpha.device
        self.register_buffer(
            "_idx_dark", torch.tensor(idx_dark, dtype=torch.long, device=device)
        )
        self.register_buffer(
            "_idx_light", torch.tensor(idx_light, dtype=torch.long, device=device)
        )

    @staticmethod
    def _atom_keys(df):
        """Return one identity string per row of an atom table.

        The key is ``chainid_resseq_icode_name_altloc``; icode, name and
        altloc are whitespace-stripped so padding differences still match.
        """
        return (
            df["chainid"].astype(str)
            + "_"
            + df["resseq"].astype(str)
            + "_"
            + df["icode"].astype(str).str.strip()
            + "_"
            + df["name"].astype(str).str.strip()
            + "_"
            + df["altloc"].astype(str).str.strip()
        )

    def _build_atom_map(self):
        """Match atoms between dark and light models by identity.

        Each atom is keyed by (chainid, resseq, icode, name, altloc). Keys
        that occur more than once within either model cannot be paired
        unambiguously and are excluded (with a warning). The remaining keys
        are inner-joined, and the resulting row positions are stored in the
        ``_idx_dark`` / ``_idx_light`` buffers in dark-model order.

        Warns when no atoms match, or when fewer than 90% of the smaller
        model's atoms match.
        """
        import warnings

        import pandas as pd

        pdb_dark = self._model_dark.pdb
        pdb_light = self._model_light.pdb
        n_dark = len(pdb_dark)
        n_light = len(pdb_light)

        tables = []
        n_ambiguous = 0
        for pdb in (pdb_dark, pdb_light):
            # Positional row index, independent of the DataFrame's own index.
            table = pd.DataFrame(
                {"_key": self._atom_keys(pdb).to_numpy(), "_idx": range(len(pdb))}
            )
            # A merge on repeated keys is many-to-many (cartesian product)
            # and would pair unrelated atoms, so drop ambiguous keys outright.
            ambiguous = table["_key"].duplicated(keep=False)
            n_ambiguous += int(ambiguous.sum())
            tables.append(table.loc[~ambiguous])
        table_dark, table_light = tables

        if n_ambiguous:
            warnings.warn(
                f"CoordinateSimilarityTarget: {n_ambiguous} atoms with non-unique "
                "(chainid, resseq, icode, name, altloc) keys were excluded"
            )

        # Inner merge keeps only atoms present in both models.
        merged = pd.merge(
            table_dark,
            table_light,
            on="_key",
            suffixes=("_dark", "_light"),
        )
        n_matched = len(merged)

        if n_matched == 0:
            warnings.warn(
                "CoordinateSimilarityTarget: no matching atoms between "
                "dark and light models"
            )
            self._set_atom_map([], [])
            return

        n_min = min(n_dark, n_light)
        match_rate = n_matched / n_min
        if match_rate < 0.9:
            warnings.warn(
                f"CoordinateSimilarityTarget: only {n_matched}/{n_min} "
                f"atoms matched ({match_rate:.0%})"
            )
        if self.verbose >= 1:
            print(
                f" Similarity target: {n_matched} matched atoms "
                f"(dark={n_dark}, light={n_light})"
            )
        self._set_atom_map(
            merged["_idx_dark"].to_numpy(), merged["_idx_light"].to_numpy()
        )

    def _per_atom_terms(self):
        """Evaluate per-atom quantities for all matched pairs.

        Dark coordinates and B-factors are detached (frozen reference);
        light coordinates keep their autograd history.

        Returns
        -------
        delta_sq : torch.Tensor
            Squared displacement ||xyz_light - xyz_dark||^2 per matched atom.
        sigma_sq : torch.Tensor
            B / (8*pi^2) from the dark model, floored at ``_MIN_SIGMA_SQ``.
        z_static : torch.Tensor
            Log-odds of the static hypothesis,
            -delta_sq / (2*sigma_sq) + alpha.
        """
        # Assumes xyz() is (n_atoms, 3) and adp() holds one isotropic B per
        # atom, both row-aligned with each model's pdb table — TODO confirm.
        xyz_dark = self._model_dark.xyz()
        xyz_light = self._model_light.xyz()
        pos_dark = xyz_dark[self._idx_dark].detach()
        pos_light = xyz_light[self._idx_light]
        delta_sq = (pos_light - pos_dark).pow(2).sum(dim=-1)

        B = self._model_dark.adp()[self._idx_dark].detach()
        sigma_sq = torch.clamp(B / (8.0 * torch.pi**2), min=self._MIN_SIGMA_SQ)

        z_static = -0.5 * delta_sq / sigma_sq + self._alpha
        return delta_sq, sigma_sq, z_static

    @staticmethod
    def _spike_slab_loss(z_static: torch.Tensor) -> torch.Tensor:
        """Return sum(-log(1 + exp(z_static))), computed stably via logaddexp."""
        return -torch.logaddexp(z_static, torch.zeros_like(z_static)).sum()

    def forward(self) -> torch.Tensor:
        """Compute the spike-and-slab similarity loss.

        Returns
        -------
        torch.Tensor
            Scalar loss summed (not averaged) over all matched atom pairs.
            Gradients flow only into the light model's coordinates. A zero
            tensor is returned when no atoms are matched.
        """
        if self._idx_dark.numel() == 0:
            return torch.tensor(0.0, device=self._alpha.device)
        _, _, z_static = self._per_atom_terms()
        return self._spike_slab_loss(z_static)

    def stats(self) -> Dict[str, StatEntry]:
        """Get similarity restraint statistics.

        Returns an empty dict when no atoms are matched. ``n_moved`` counts
        atoms whose posterior P(static) is below 0.5; distances are in the
        models' coordinate units.
        """
        if self._idx_dark.numel() == 0:
            return {}
        with torch.no_grad():
            delta_sq, sigma_sq, z_static = self._per_atom_terms()
            # Plain sqrt: no gradient is taken here, so an epsilon is not
            # needed and would bias zero displacements upward.
            distances = torch.sqrt(delta_sq)
            sigma = torch.sqrt(sigma_sq)
            # Posterior P(static) = sigmoid(-delta^2/(2*sigma^2) + alpha)
            p_static = torch.sigmoid(z_static)
            n_moved = (p_static < 0.5).sum().item()
            loss = self._spike_slab_loss(z_static)
            rms_dist = torch.sqrt(delta_sq.mean())
        return {
            "loss": stat(loss.item(), VERBOSITY_STANDARD),
            "n_matched": stat(self._idx_dark.numel(), VERBOSITY_DEBUG),
            "n_moved": stat(n_moved, VERBOSITY_DETAILED),
            "rms_dist": stat(rms_dist.item(), VERBOSITY_DETAILED),
            "mean_dist": stat(distances.mean().item(), VERBOSITY_DETAILED),
            "max_dist": stat(distances.max().item(), VERBOSITY_DETAILED),
            "mean_sigma": stat(sigma.mean().item(), VERBOSITY_DETAILED),
            "alpha": stat(self._alpha.item(), VERBOSITY_DEBUG),
        }