torchref.refinement.targets.similarity module

Coordinate Similarity Target for Difference Refinement

Implements a spike-and-slab prior on per-atom displacements between the dark and light models. The loss is quadratic for small displacements (likely noise) and flattens to a plateau for large displacements (likely genuine conformational changes), where it exerts essentially no penalty or gradient.

Per-atom coordinate uncertainty sigma is derived from B-factors:

sigma = sqrt(B / (8*pi^2))
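A minimal sketch of this conversion in plain Python, assuming B is an isotropic B-factor in Å^2 so that sigma comes out in Å. The function name `b_to_sigma` is hypothetical and not part of torchref.

```python
import math

def b_to_sigma(b_factor: float) -> float:
    """Per-atom coordinate uncertainty from an isotropic B-factor.

    Uses sigma = sqrt(B / (8 * pi^2)); B in A^2 gives sigma in A.
    """
    if b_factor < 0:
        raise ValueError("B-factor must be non-negative")
    return math.sqrt(b_factor / (8.0 * math.pi ** 2))

# A typical B-factor of 20 A^2 corresponds to sigma of roughly 0.5 A.
print(b_to_sigma(20.0))
```

Since 8*pi^2 ≈ 79, a B-factor of about 79 Å^2 corresponds to sigma = 1 Å.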

Reference: design_doc_sim_loss.md

class torchref.refinement.targets.similarity.CoordinateSimilarityTarget(model_dark=None, model_light=None, alpha=2.0, verbose=0)

Bases: Target

Spike-and-slab similarity restraint between dark and light models.

For each atom, two hypotheses are considered:

  • Static (prob 1-p): atom did not move, displacement is noise

  • Moved (prob p): atom genuinely displaced
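A small plain-Python sketch of the two-hypothesis posterior, assuming an unnormalized Gaussian likelihood exp(-d^2/(2*sigma^2)) for the static hypothesis and a flat likelihood of 1 for the moved one (any normalization constants are taken to be absorbed into alpha). With alpha = log((1-p)/p), Bayes' rule reduces to sigmoid(alpha - d^2/(2*sigma^2)). The function names are illustrative only.

```python
import math

def posterior_static(d: float, sigma: float, p_moved: float) -> float:
    """P(static | d) by Bayes' rule.

    Assumes unnormalized likelihoods: exp(-d^2 / (2 sigma^2)) for the
    static (spike) hypothesis and a constant 1 for the moved (slab) one.
    """
    like_static = math.exp(-d * d / (2.0 * sigma * sigma))
    like_moved = 1.0
    num = (1.0 - p_moved) * like_static
    return num / (num + p_moved * like_moved)

def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))

p = 0.12                             # prior probability that an atom moved
alpha = math.log((1.0 - p) / p)      # log prior odds of the static hypothesis

# Both columns agree: the posterior equals sigmoid(alpha - d^2 / (2 sigma^2)).
for d in (0.0, 0.5, 1.0, 2.0):
    x = d * d / (2.0 * 0.5 ** 2)
    print(d, posterior_static(d, 0.5, p), sigmoid(alpha - x))
```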

The loss is the negative log marginal likelihood:

L(d) = -logsumexp(-d^2/(2*sigma^2) + alpha, 0)

where d = ||xyz_light - xyz_dark|| is the per-atom displacement and sigma = sqrt(B / (8*pi^2)) is the per-atom coordinate uncertainty derived from B-factors.
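A minimal PyTorch sketch of this per-atom loss, assuming the displacements d and uncertainties sigma are already available as tensors of equal shape. `spike_slab_loss` is a hypothetical helper, not torchref's internal function.

```python
import torch

def spike_slab_loss(d: torch.Tensor, sigma: torch.Tensor, alpha: float = 2.0) -> torch.Tensor:
    """Per-atom loss L(d) = -logsumexp(-d^2 / (2 sigma^2) + alpha, 0)."""
    log_static = -d.pow(2) / (2.0 * sigma.pow(2)) + alpha  # spike: Gaussian noise around zero
    log_moved = torch.zeros_like(log_static)                # slab: flat, log-weight 0
    # logsumexp over the two hypotheses stays numerically stable for any d / sigma
    return -torch.logsumexp(torch.stack([log_static, log_moved]), dim=0)

d = torch.tensor([0.0, 0.5, 1.0, 3.0])  # displacements (A)
sigma = torch.full_like(d, 0.5)         # sigma for B of roughly 20 A^2
loss = spike_slab_loss(d, sigma)
print(loss)
```

Because logsumexp(z, 0) equals softplus(z), the same quantity can also be written as -torch.nn.functional.softplus(alpha - d**2 / (2*sigma**2)). The loss equals -log(1 + e^alpha) at d = 0 and rises toward 0 as d grows.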

Gradient: dL/dd = d/sigma^2 * sigmoid(-d^2/(2*sigma^2) + alpha). This is an L2 restraint weighted by the posterior probability that the atom is static.
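The gradient expression can be checked against autograd. This PyTorch sketch, using a hypothetical `spike_slab_loss` helper and arbitrary test values, differentiates the loss with respect to d and compares the result with d/sigma^2 * sigmoid(alpha - d^2/(2*sigma^2)).

```python
import torch

def spike_slab_loss(d, sigma, alpha=2.0):
    x = d.pow(2) / (2.0 * sigma.pow(2))
    return -torch.logsumexp(torch.stack([alpha - x, torch.zeros_like(x)]), dim=0)

alpha = 2.0
sigma = torch.full((4,), 0.5, dtype=torch.float64)
d = torch.tensor([0.1, 0.7, 1.0, 2.0], dtype=torch.float64, requires_grad=True)

# Autograd derivative of the summed per-atom loss with respect to each d.
loss = spike_slab_loss(d, sigma, alpha).sum()
(autograd_grad,) = torch.autograd.grad(loss, d)

# Closed form: plain L2 gradient d / sigma^2, weighted by P(static | d).
with torch.no_grad():
    p_static = torch.sigmoid(alpha - d.pow(2) / (2.0 * sigma.pow(2)))
    closed_form = d / sigma.pow(2) * p_static

print(torch.allclose(autograd_grad, closed_form))
```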

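Behavior:

  • d << sigma: ~0.5 * d^2 / sigma^2 up to an additive constant (quadratic, tight restraint); the exact prefactor is sigmoid(alpha), which approaches 1 for large alpha

  • d >> sigma * sqrt(2*alpha): plateaus (essentially no penalty or gradient for genuine moves)

  • Crossover at d ~ sigma * sqrt(2*alpha), where the posterior probability that the atom is static is 0.5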
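A quick numerical check of these regimes in plain Python (no torchref dependency), using sigma = 0.5 Å and the default alpha = 2.0. Expanding L around d = 0 gives L(d) ≈ L(0) + sigmoid(alpha) * 0.5 * d^2 / sigma^2, so the small-d ratio below should sit near sigmoid(2) ≈ 0.88 rather than exactly 1. The helper name `loss` is illustrative.

```python
import math

def loss(d: float, sigma: float, alpha: float) -> float:
    """Scalar L(d) = -logsumexp(alpha - d^2 / (2 sigma^2), 0), computed stably."""
    a = alpha - d * d / (2.0 * sigma * sigma)
    m = max(a, 0.0)  # shift by the larger exponent before exponentiating
    return -(m + math.log(math.exp(a - m) + math.exp(-m)))

sigma, alpha = 0.5, 2.0
l0 = loss(0.0, sigma, alpha)                # offset at d = 0: -log(1 + e^alpha)
crossover = sigma * math.sqrt(2.0 * alpha)  # here alpha - d^2 / (2 sigma^2) = 0

# Small d: rise above l0 compared with the pure quadratic 0.5 * d^2 / sigma^2.
d_small = 0.05
ratio = (loss(d_small, sigma, alpha) - l0) / (0.5 * d_small ** 2 / sigma ** 2)

# Large d: the loss approaches 0, i.e. the plateau.
far = loss(10.0 * sigma, sigma, alpha)

print(f"l0={l0:.4f} crossover={crossover:.2f} ratio={ratio:.3f} far={far:.2e}")
```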

Parameters:
  • model_dark (Model) – Dark (ground state) model. B-factors and coordinates are detached.

  • model_light (Model) – Light (excited state) model. Coordinates carry gradients.

  • alpha (float, optional) – Log prior odds of the static hypothesis, alpha = log((1-p)/p). Higher values mean stronger denoising. Default is 2.0 (crossover at ~2*sigma).

  • verbose (int, optional) – Verbosity level. Default is 0.
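Because alpha is a log prior odds, it maps one-to-one onto the prior move probability p, and the crossover distance scales as sqrt(2*alpha) in units of sigma. The helpers below are hypothetical conveniences for choosing alpha, not part of the torchref API.

```python
import math

def alpha_from_prior(p_moved: float) -> float:
    """Log prior odds of 'static' for a prior move probability p_moved."""
    return math.log((1.0 - p_moved) / p_moved)

def prior_from_alpha(alpha: float) -> float:
    """Inverse mapping: prior probability that an atom moved."""
    return 1.0 / (1.0 + math.exp(alpha))

def crossover_in_sigma(alpha: float) -> float:
    """Displacement, in units of sigma, where P(static | d) = 0.5."""
    if alpha <= 0.0:
        # P(static | d) <= sigmoid(alpha) <= 0.5 for every d, so no crossover above d = 0
        return 0.0
    return math.sqrt(2.0 * alpha)

for alpha in (0.5, 2.0, 4.5, 8.0):
    print(f"alpha={alpha:4.1f}  p_moved={prior_from_alpha(alpha):.4f}  "
          f"crossover={crossover_in_sigma(alpha):.2f} sigma")
```

For example, alpha = 4.5 places the crossover at 3*sigma and corresponds to a prior move probability of about 1%.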

name: str = 'similarity'
__init__(model_dark=None, model_light=None, alpha=2.0, verbose=0)

Initialize the target. model_dark, model_light, and alpha are documented in the class description.

Parameters:

verbose (int, optional) – Verbosity level. Default is 0.

property model_dark: Model

Get dark model.

property model_light: Model

Get light model.

property alpha: float

Get alpha as float.

forward()

Compute spike-and-slab similarity loss.

Returns:

Scalar mean loss over all matched atom pairs.

Return type:

torch.Tensor
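A self-contained sketch of what forward() computes, assuming the matched atoms are already paired as (N, 3) coordinate tensors and that sigma comes from the dark model's B-factors (the class parameters state that dark B-factors and coordinates are detached). The function name, tensor layout, and the clamp on sigma^2 are assumptions of this sketch; the real method reads its inputs from the Model objects.

```python
import math
import torch

def similarity_forward(xyz_dark, xyz_light, b_dark, alpha=2.0):
    """Hypothetical stand-in for forward(): mean spike-and-slab loss over matched atoms.

    xyz_dark, xyz_light: (N, 3) coordinates of matched atoms; b_dark: (N,) B-factors.
    """
    xyz_dark = xyz_dark.detach()  # dark coordinates receive no gradient
    b_dark = b_dark.detach()      # dark B-factors receive no gradient
    # sigma^2 = B / (8 pi^2); clamping guards against B = 0 (an assumption of this sketch)
    sigma2 = (b_dark / (8.0 * math.pi ** 2)).clamp_min(1e-8)
    d2 = (xyz_light - xyz_dark).pow(2).sum(dim=-1)  # squared displacement per atom
    log_static = alpha - d2 / (2.0 * sigma2)
    per_atom = -torch.logsumexp(torch.stack([log_static, torch.zeros_like(log_static)]), dim=0)
    return per_atom.mean()

torch.manual_seed(0)
xyz_dark = torch.randn(5, 3, dtype=torch.float64, requires_grad=True)
noise = 0.05 * torch.randn(5, 3, dtype=torch.float64)
xyz_light = (xyz_dark.detach() + noise).requires_grad_(True)
b_dark = torch.full((5,), 20.0, dtype=torch.float64)

loss = similarity_forward(xyz_dark, xyz_light, b_dark)
loss.backward()
print(loss.item(), xyz_dark.grad, xyz_light.grad.shape)  # dark grad stays None
```

Working with the squared displacement rather than the norm sidesteps the undefined gradient of ||·|| at zero displacement, and the loss only ever needs d^2.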

stats()

Get similarity restraint statistics.