Source code for torchref.refinement.targets.geometry.non_bonded

from typing import TYPE_CHECKING, Any, Dict, Tuple

import numpy as np
import torch

from torchref.utils.stats import (
    VERBOSITY_DEBUG,
    VERBOSITY_DETAILED,
    VERBOSITY_STANDARD,
    StatEntry,
    stat,
)

from .base import GeometryTarget

if TYPE_CHECKING:
    from torchref.model.model import Model


class NonBondedTarget(GeometryTarget):
    r"""
    Non-bonded (van der Waals) restraint target using PROLSQ-style repulsion,
    parameterized as a generalized-Gaussian NLL on the overlap with scale
    :math:`\sigma`.

    Per-pair NLL (``mode='prolsq'``):

    .. math::

        \mathrm{NLL}(v) \;=\; \frac{v^{p}}{p\,\sigma^{p}}
        \;+\; \log\sigma \;+\; \tfrac{1}{2}\log(2\pi)
        \qquad
        v \;=\; \max\!\bigl(0,\, d_{\text{vdw}} + b - d\bigr)

    where :math:`p = r_\text{exp}` (default 4). The shape term
    :math:`v^p/(p\sigma^p)` is algebraically identical to the classical
    PROLSQ energy :math:`c_{\text{rep}}\,v^p` with
    :math:`c_{\text{rep}} = 1/(p\,\sigma^{p})`; exposing :math:`\sigma` makes
    the physics legible (σ is an "effective tolerance" on the overlap) and
    puts the VDW loss on the same NLL footing as bond / angle / planarity.

    **Default sigma = 0.3 Å.** A practical middle ground: stiff enough to
    reliably pull medium-large clashes (~0.4–0.5 Å) out of the refinement,
    but loose enough that starting from a shaken or rough model does not
    generate LBFGS-destabilising gradients. A 0.4 Å MolProbity clash sits at
    ~1.3σ and contributes ~0.8 NLL units; a 0.5 Å severe clash sits at ~1.7σ
    and contributes ~1.9 NLL units. The classical PROLSQ strength
    ``c_rep=16, r_exp=4`` is equivalent to
    :math:`\sigma \approx 0.354\ \text{\AA}`, very close to this default.
    Set ``sigma`` explicitly for deliberate tightening (e.g. 0.13 Å puts a
    0.4 Å clash at ~3σ, ~20 NLL units) or loosening.

    Alternative modes:

    - ``'prolsq'``: generalized-Gaussian NLL with exponent ``r_exp`` (default)
    - ``'gaussian'``: Gaussian NLL on the overlap using per-pair sigmas
    - ``'soft'``: soft repulsion, quadratic in the overlap up to a fixed
      0.5 Å threshold and linear beyond it

    When symmetry information is available (cell and spacegroup on the
    model), also handles contacts between ASU atoms and symmetry-related
    copies, including pure lattice translations (symmetry operator 0 with a
    nonzero cell offset). Symmetry mate positions are recomputed on-the-fly
    from current ASU coordinates so that gradients flow to both atoms in
    each pair.

    Reference: cctbx/geometry_restraints/nonbonded.h, PROLSQ documentation,
    MolProbity clash criterion (Davis et al., NAR 2007).

    Parameters
    ----------
    model : Model, optional
        Reference to Model object.
    mode : str, optional
        Repulsion function type ('prolsq', 'gaussian', 'soft').
        Default is 'prolsq'.
    sigma : float, optional
        Effective tolerance on the overlap in Angstroms; must be positive.
        Default is 0.3.
    r_exp : float, optional
        Exponent of the repulsion term. Default is 4.0.
    c_rep : float, optional
        Back-door repulsion coefficient. If provided, overrides the
        sigma-derived value and the NLL becomes
        :math:`c_{\text{rep}} v^{r_\text{exp}} + \log\sigma
        + \tfrac{1}{2}\log(2\pi)`. Useful for reproducing legacy PROLSQ
        weights. Default is None (derive from ``sigma``).
    buffer : float, optional
        Distance buffer in Angstroms added to the VDW radii sum. Shifts the
        repulsion onset outward so atoms feel repulsion before they clash.
        Default is 0.0.
    rebuild_threshold : float, optional
        Maximum ASU atom displacement (Å) since the last pair-list build
        before :meth:`maintenance` triggers a rebuild. Default is 1.0.
    verbose : int, optional
        Verbosity level. Default is 0.
    scale : float, optional
        Stored as ``self.scale``; not read by this class itself.
        Default is 10.0.
    """

    name: str = "geometry/nonbonded"

    def __init__(
        self,
        model: "Model" = None,
        mode: str = "prolsq",
        sigma: float = 0.3,
        r_exp: float = 4.0,
        c_rep: "float | None" = None,
        buffer: float = 0.0,
        rebuild_threshold: float = 1.0,
        verbose: int = 0,
        scale: float = 10.0,
    ):
        """
        Initialize non-bonded target.

        Parameters
        ----------
        model : Model, optional
            Reference to Model object.
        mode : str, optional
            Repulsion function type ('prolsq', 'gaussian', 'soft').
            Default is 'prolsq'.
        sigma : float, optional
            Effective tolerance on the overlap (Å); must be positive.
            Default 0.3. Used to derive ``c_rep`` when ``c_rep`` is None.
        r_exp : float, optional
            Repulsion exponent. Default is 4.0.
        c_rep : float or None, optional
            Legacy coefficient override. If None (default), derived from
            ``sigma`` as ``1 / (r_exp * sigma ** r_exp)``.
        buffer : float, optional
            Distance buffer in Angstroms added to VDW radii sum.
            Default is 0.0.
        rebuild_threshold : float, optional
            Maximum ASU atom displacement in Angstroms since the last VDW
            pair-list build before :meth:`maintenance` triggers a rebuild.
            Default is 1.0 Å — well inside the ~2.4 Å safety margin of the
            default 6.0 Å cutoff, so newly-formed contacts cannot slip
            through the list.
        verbose : int, optional
            Verbosity level. Default is 0.
        scale : float, optional
            Stored as ``self.scale``; not read by this class itself.
            Default is 10.0.

        Raises
        ------
        ValueError
            If ``sigma`` is not a positive number.
        """
        # Validate before touching any state: sigma <= 0 would divide by
        # zero in the derived c_rep and make log(sigma) nan / -inf.
        self._check_sigma(sigma)
        # NOTE(review): target_value / sigma here are GeometryTarget
        # base-class arguments, distinct from the VDW overlap sigma above —
        # confirm their meaning in .base.
        super().__init__(model, verbose, target_value=0.5, sigma=1.2)
        self.mode = mode
        self.scale = scale

        # Register sigma / r_exp / buffer as buffers so .to(device) moves them.
        self.register_buffer("_sigma_vdw", torch.tensor(float(sigma)))
        self.register_buffer("_r_exp", torch.tensor(float(r_exp)))
        self.register_buffer("_buffer", torch.tensor(float(buffer)))
        self.register_buffer(
            "_rebuild_threshold", torch.tensor(float(rebuild_threshold))
        )

        # c_rep: back-door override; by default derived from sigma so that
        # the PROLSQ shape term equals v^p / (p * sigma^p). Remember which
        # case applies so the r_exp setter knows whether to re-derive it.
        self._c_rep_from_sigma = c_rep is None
        if c_rep is None:
            c_rep_val = self._derived_c_rep(float(sigma), float(r_exp))
        else:
            c_rep_val = float(c_rep)
        self.register_buffer("_c_rep", torch.tensor(c_rep_val))

    @staticmethod
    def _check_sigma(value: float) -> None:
        """Raise ValueError unless ``value`` is a positive number."""
        # `not x > 0` (rather than `x <= 0`) also rejects NaN.
        if not float(value) > 0.0:
            raise ValueError(f"sigma must be positive, got {value!r}")

    @staticmethod
    def _derived_c_rep(sigma: float, r_exp: float) -> float:
        """Return ``1 / (r_exp * sigma**r_exp)``, the sigma-linked coefficient."""
        return 1.0 / (r_exp * sigma**r_exp)

    @property
    def c_rep(self) -> float:
        """Get repulsion coefficient."""
        return self._c_rep.item()

    @c_rep.setter
    def c_rep(self, value: float):
        """Set repulsion coefficient.

        Note: this breaks the internal link
        ``c_rep = 1 / (r_exp * sigma**r_exp)``. After setting c_rep directly,
        ``sigma_vdw`` should be treated as informational only for the
        log(sigma) term; the shape term uses c_rep, and later changes to
        ``r_exp`` no longer re-derive it. Setting ``sigma_vdw`` restores the
        link.
        """
        self._c_rep.fill_(value)
        self._c_rep_from_sigma = False

    @property
    def sigma_vdw(self) -> float:
        """Get the effective overlap tolerance sigma (Å)."""
        return self._sigma_vdw.item()

    @sigma_vdw.setter
    def sigma_vdw(self, value: float):
        """Set sigma (must be positive) and recompute the linked ``c_rep``."""
        self._check_sigma(value)
        self._sigma_vdw.fill_(value)
        self._c_rep.fill_(self._derived_c_rep(float(value), self._r_exp.item()))
        self._c_rep_from_sigma = True

    @property
    def r_exp(self) -> float:
        """Get repulsion exponent."""
        return self._r_exp.item()

    @r_exp.setter
    def r_exp(self, value: float):
        """Set repulsion exponent.

        If ``c_rep`` is still derived from sigma (it was not passed to the
        constructor nor assigned via :attr:`c_rep`), it is recomputed so the
        shape term stays ``v**r_exp / (r_exp * sigma**r_exp)``.
        """
        self._r_exp.fill_(value)
        # getattr default keeps instances built without __init__ working.
        if getattr(self, "_c_rep_from_sigma", True):
            self._c_rep.fill_(
                self._derived_c_rep(self._sigma_vdw.item(), float(value))
            )

    @property
    def buffer(self) -> float:
        """Get distance buffer."""
        return self._buffer.item()

    @buffer.setter
    def buffer(self, value: float):
        """Set distance buffer."""
        self._buffer.fill_(value)

    def maintenance(self) -> None:
        """Rebuild the VDW pair list if any ASU atom drifted too far.

        Fast path: one ``max().item()`` sync on the per-atom displacement
        norm between the current ASU coordinates and the snapshot taken at
        the last VDW build. If the max displacement stays within
        ``_rebuild_threshold`` we return immediately.

        Slow path (only when triggered): delegate to
        ``restraints.rebuild_vdw_restraints`` which refreshes the pair list
        using the original build kwargs and updates the snapshot.

        See :meth:`Target.maintenance` for the general contract.

        Safety invariant: the default build cutoff (6.0 Å) leaves roughly
        ``cutoff - max_vdw_sum ≈ 2.4 Å`` of slack before a previously
        non-contact atom pair could form a new clash. Since both atoms of a
        pair may move, setting ``rebuild_threshold < 2.4 / 2 = 1.2 Å``
        guarantees that no such pair can slip through the list — that is, a
        rebuild fires before the slack is consumed.
        """
        if self._model is None:
            return
        r = self.restraints
        if r is None:
            return
        snapshot = getattr(r, "_last_vdw_build_xyz", None)
        if snapshot is None:
            return

        with torch.no_grad():
            delta = self._model.xyz() - snapshot
            if delta.numel() == 0:
                return  # no atoms: nothing can drift
            max_disp_sq = (delta * delta).sum(dim=-1).max()
            thresh_sq = self._rebuild_threshold * self._rebuild_threshold
            if max_disp_sq.item() <= thresh_sq.item():
                return  # within slack — nothing to do

        if self.verbose > 0:
            max_disp = float(max_disp_sq.item()) ** 0.5
            thresh = float(self._rebuild_threshold.item())
            print(
                f"  VDW rebuild: max drift {max_disp:.2f} Å > "
                f"threshold {thresh:.2f} Å"
            )
        r.rebuild_vdw_restraints()

    def _compute_positions(
        self, xyz: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute atom positions for all VDW pairs, handling symmetry mates.

        If every pair is intra-ASU (all symmetry-operator indices 0 and all
        cell offsets 0, or those arrays absent), positions are direct
        lookups. Otherwise every second atom is pushed through its symmetry
        transform ``R @ frac + t + offset`` in one vectorized pass (operator
        0 is assumed to be the identity, so intra-ASU pairs are unaffected).
        Pairs related by a pure lattice translation — operator 0 with a
        nonzero cell offset — are therefore placed correctly too. Gradients
        flow to both atoms of each pair.

        Parameters
        ----------
        xyz : torch.Tensor
            Current ASU Cartesian coordinates of shape (N, 3).

        Returns
        -------
        pos1 : torch.Tensor
            Positions of first atom in each pair (N_pairs, 3).
        pos2 : torch.Tensor
            Positions of second atom in each pair (N_pairs, 3).
        min_distances : torch.Tensor
            VDW distance threshold for each pair (N_pairs,).
        """
        vdw_data = self.restraints.restraints["vdw"]
        indices = vdw_data["indices"]
        min_distances = vdw_data["min_distances"]
        symop_indices = vdw_data.get("symop_indices")
        cell_offsets = vdw_data.get("cell_offsets")

        # pos1 is always a direct ASU lookup; pos2 starts from the ASU atom
        # and is transformed below if needed (gradients flow either way).
        pos1 = xyz[indices[:, 0]]
        mate_source = xyz[indices[:, 1]]

        has_symop = symop_indices is not None and bool((symop_indices != 0).any())
        has_offset = cell_offsets is not None and bool((cell_offsets != 0).any())
        if not (has_symop or has_offset):
            # Fast path: all pairs are intra-ASU.
            return pos1, mate_source, min_distances

        cell = self.model.cell
        sg = self.model.symmetry

        frac = cell.cartesian_to_fractional(mate_source)  # (N_pairs, 3)
        if symop_indices is None:
            # Offsets only: every pair uses operator 0 (the identity).
            symop_indices = torch.zeros(
                frac.shape[0], dtype=torch.long, device=frac.device
            )

        # Gather per-pair rotation matrices and translations.
        R = sg.matrices[symop_indices].to(frac.dtype)  # (N_pairs, 3, 3)
        t = sg.translations[symop_indices].to(frac.dtype)  # (N_pairs, 3)

        # Batched symmetry transform: R @ frac + t (+ lattice offset).
        frac_transformed = torch.bmm(R, frac.unsqueeze(-1)).squeeze(-1) + t
        if cell_offsets is not None:
            frac_transformed = frac_transformed + cell_offsets.to(frac.dtype)

        pos2 = cell.fractional_to_cartesian(frac_transformed)
        return pos1, pos2, min_distances

    def forward(self) -> torch.Tensor:
        """
        Compute the summed non-bonded loss over all VDW pairs.

        Returns
        -------
        torch.Tensor
            Scalar loss. A zero scalar with the coordinates' device and dtype
            is returned when there are no VDW pairs.

        Raises
        ------
        ValueError
            If ``self.mode`` is not 'prolsq', 'gaussian' or 'soft' (only
            checked when VDW pairs exist).
        """
        from torchref.base.targets.nonbonded import nonbonded_heavy_math

        xyz = self.model.xyz()
        device = xyz.device

        if "vdw" not in self.restraints.restraints:
            return xyz.new_zeros(())
        vdw_data = self.restraints.restraints["vdw"]
        indices = vdw_data.get("indices")
        if indices is None or len(indices) == 0:
            return xyz.new_zeros(())

        # The prolsq branch goes through the math dispatcher — Triton on
        # CUDA fp32, eager otherwise. Other modes (gaussian, soft) keep
        # the inline path below.
        if self.mode == "prolsq":
            return nonbonded_heavy_math(
                xyz,
                indices,
                vdw_data["min_distances"],
                vdw_data.get("symop_indices"),
                vdw_data.get("cell_offsets"),
                self.model.symmetry.matrices,
                self.model.symmetry.translations,
                self.model.cell.fractional_matrix,
                self.model.cell.inv_fractional_matrix,
                self._c_rep,
                self._r_exp,
                float(self._buffer),
                self._sigma_vdw,
            )

        # Compute positions (handles symmetry transparently).
        pos1, pos2, min_distances = self._compute_positions(xyz)

        # Small epsilon inside the sqrt keeps the gradient finite at d=0.
        diff = pos2 - pos1
        actual_distances = torch.sqrt((diff**2).sum(dim=-1) + 1e-8)

        # Overlap: how far the distance falls short of VDW sum + buffer.
        violations = torch.clamp(
            min_distances + self._buffer - actual_distances, min=0.0
        )

        if self.mode == "gaussian":
            sigmas = vdw_data["sigmas"]  # per-pair sigmas, only used here
            log_2pi = torch.log(
                torch.tensor(2.0 * np.pi, device=device, dtype=xyz.dtype)
            )
            nll = 0.5 * (violations / sigmas) ** 2 + torch.log(sigmas) + 0.5 * log_2pi
            return nll.sum()

        if self.mode == "soft":
            threshold = 0.5  # Å - quadratic up to this overlap, linear above
            quadratic_mask = violations <= threshold
            quadratic_energy = self._c_rep * (violations**2)
            # Linear continuation matching value and slope at the threshold.
            linear_energy = self._c_rep * (2 * threshold * violations - threshold**2)
            energy = torch.where(quadratic_mask, quadratic_energy, linear_energy)
            return energy.sum()

        raise ValueError(f"Unknown non-bonded mode: {self.mode}")

    @staticmethod
    def _empty_violations(device) -> Dict[str, torch.Tensor]:
        """Return the empty result dictionary of :meth:`get_violations`."""
        return {
            "indices": torch.empty((0, 2), dtype=torch.long, device=device),
            "violations": torch.empty(0, device=device),
            "distances": torch.empty(0, device=device),
            "min_distances": torch.empty(0, device=device),
        }

    def get_violations(self, threshold: float = 0.0) -> Dict[str, torch.Tensor]:
        """
        Get information about VDW violations.

        A violation here is the shortfall of the pair distance below the
        VDW distance threshold itself; the repulsion ``buffer`` used by
        :meth:`forward` is not included.

        Parameters
        ----------
        threshold : float, optional
            Only report violations greater than this (Å). Default is 0.0.

        Returns
        -------
        dict
            Dictionary with 'indices', 'violations', 'distances',
            'min_distances'; all empty when there are no VDW pairs.
        """
        xyz = self.model.xyz()

        if "vdw" not in self.restraints.restraints:
            return self._empty_violations(xyz.device)
        indices = self.restraints.restraints["vdw"].get("indices")
        if indices is None or len(indices) == 0:
            return self._empty_violations(xyz.device)

        pos1, pos2, min_distances = self._compute_positions(xyz)
        actual_distances = torch.norm(pos2 - pos1, dim=-1)
        violations = torch.clamp(min_distances - actual_distances, min=0.0)

        # Filter by threshold
        mask = violations > threshold
        return {
            "indices": indices[mask],
            "violations": violations[mask],
            "distances": actual_distances[mask],
            "min_distances": min_distances[mask],
        }

    def stats(self) -> Dict[str, Any]:
        """Get non-bonded restraint statistics.

        Violations are measured against the VDW distance threshold without
        the repulsion buffer. Computed under ``torch.no_grad()`` since only
        scalar summaries are returned.
        """
        if "vdw" not in self.restraints.restraints:
            return {}
        vdw_data = self.restraints.restraints["vdw"]
        indices = vdw_data.get("indices")
        if indices is None or len(indices) == 0:
            return {}

        with torch.no_grad():
            xyz = self.model.xyz()
            sigmas = vdw_data["sigmas"]

            pos1, pos2, min_distances = self._compute_positions(xyz)
            actual_distances = torch.norm(pos2 - pos1, dim=-1)

            # Violations: where actual distance < VDW sum
            violations = torch.clamp(min_distances - actual_distances, min=0.0)
            clash_mask = violations > 0
            n_violations = clash_mask.sum().item()

            # RMS / max over clashing pairs only.
            if n_violations > 0:
                clashes = violations[clash_mask]
                rms_violation = torch.sqrt((clashes**2).mean()).item()
                max_violation = clashes.max().item()
            else:
                rms_violation = 0.0
                max_violation = 0.0

            loss = self.forward()

            result = {
                "loss": stat(loss.item(), VERBOSITY_STANDARD),
                "n": stat(len(indices), VERBOSITY_DEBUG),
                "n_violations": stat(n_violations, VERBOSITY_DETAILED),
                "rms_violation": stat(rms_violation, VERBOSITY_DETAILED),
                "max_violation": stat(max_violation, VERBOSITY_DEBUG),
                "mean_sigma": stat(sigmas.mean().item(), VERBOSITY_DEBUG),
            }

            # Symmetry contacts: non-identity operator or nonzero lattice
            # offset. Either array may be absent.
            symop_indices = vdw_data.get("symop_indices")
            cell_offsets = vdw_data.get("cell_offsets")
            is_sym = None
            if symop_indices is not None and len(symop_indices) > 0:
                is_sym = symop_indices != 0
            if cell_offsets is not None and len(cell_offsets) > 0:
                off_mask = (cell_offsets != 0).any(dim=-1)
                is_sym = off_mask if is_sym is None else (is_sym | off_mask)
            if is_sym is not None:
                n_sym = is_sym.sum().item()
                if n_sym > 0:
                    result["n_symmetry"] = stat(n_sym, VERBOSITY_DETAILED)

        return result