"""Rigid bond (DELU) ADP restraint target: torchref.refinement.targets.adp.rigid_bond."""

from typing import TYPE_CHECKING, Any, Dict

import numpy as np
import torch

from torchref.utils.stats import (
    VERBOSITY_DEBUG,
    VERBOSITY_DETAILED,
    VERBOSITY_STANDARD,
    StatEntry,
    stat,
)

from .base import ADPTarget

if TYPE_CHECKING:
    from torchref.model.model import Model


class RigidBondTarget(ADPTarget):
    """
    Rigid Bond restraint (DELU in SHELX, Hirshfeld test).

    Based on Hirshfeld's rigid bond test (Acta Cryst. A32, 239, 1976).
    For a truly rigid bond, the mean-square displacement amplitudes (MSDA)
    of the two bonded atoms along the bond direction should be equal,
    because in a rigid bond the atoms move together.

    For anisotropic ADPs (U tensors)::

        z_12 = l_12^T U_1 l_12 / |l_12|²   (MSDA of atom 1 along bond)
        z_21 = l_21^T U_2 l_21 / |l_21|²   (MSDA of atom 2 along bond)
        Δz = z_12 - z_21                   should be ~0

    For isotropic B-factors, the difference in B_iso (converted to U_iso)
    is used as a proxy::

        Δz = (B_1 - B_2) / (8π²)

    This differs from SIMU (ADPSimilarityTarget), which restrains the full
    ADP tensors to be similar. Rigid bond only restrains the component
    along the bond direction.

    Energy: E = w * Δz²

    NLL (summed over all restrained bonds)::

        NLL = 0.5 * (Δz / σ)² + log(σ) + 0.5 * log(2π)

    The loss (``forward``) and the reported statistics
    (``get_delta_z_stats`` / ``stats``) always use the same isotropic or
    anisotropic Δz definition.

    References
    ----------
    - Hirshfeld, F.L. (1976). Acta Cryst. A32, 239.
    - cctbx/adp_restraints/rigid_bond.h

    Parameters
    ----------
    model : Model
        Reference to Model object.
    sigma : float, optional
        Target standard deviation for Δz. Default is 0.004 Å².
        Hirshfeld found typical values of 0.001 Å² for good structures.
        Must be strictly positive.
    use_aniso : bool, optional
        If True and model has anisotropic ADPs, use proper tensor
        calculation. Default is True.
    verbose : int, optional
        Verbosity level. Default is 0.

    Raises
    ------
    ValueError
        If ``sigma`` is not strictly positive (log(σ) would be undefined).
    """

    name: str = "adp/delu"

    def __init__(
        self,
        model: "Model" = None,
        sigma: float = 0.004,
        use_aniso: bool = True,
        verbose: int = 0,
    ):
        if not sigma > 0:
            # sigma <= 0 (or NaN) makes the NLL's log(σ) / division undefined.
            raise ValueError(f"sigma must be strictly positive, got {sigma!r}")
        super().__init__(model, verbose)
        self.sigma = sigma
        self.use_aniso = use_aniso

    def forward(self) -> torch.Tensor:
        """
        Compute the rigid bond restraint.

        For isotropic refinement, uses B-factor differences along bonds.
        For anisotropic refinement, computes proper MSDA differences.

        Returns
        -------
        torch.Tensor
            Scalar Gaussian NLL summed over all restrained bonds (zero when
            there are no bond restraints).
        """
        if self._uses_aniso():
            return self._compute_aniso_rigid_bond()
        return self._compute_iso_rigid_bond()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _uses_aniso(self) -> bool:
        """Return True when the anisotropic (tensor) Δz should be used."""
        return self.use_aniso and getattr(self.model, "u_aniso", None) is not None

    def _bond_index_groups(self):
        """Yield the non-empty bond index arrays of every restraint group except 'all'."""
        if "bond" not in self.restraints.restraints:
            return
        for origin, restraint_group in self.restraints.restraints["bond"].items():
            # NOTE(review): 'all' presumably aggregates the other groups and is
            # skipped to avoid double counting -- confirm against the builder.
            if origin == "all":
                continue
            indices = restraint_group.get("indices")
            if indices is not None and len(indices) > 0:
                yield indices

    @staticmethod
    def _iso_delta_z(indices, adp: torch.Tensor) -> torch.Tensor:
        """
        Isotropic Δz for one group of bonds.

        For isotropic ADPs, U_iso = B / (8π²) and the MSDA along any
        direction equals U_iso, so Δz = (B1 - B2) / (8π²).
        """
        return (adp[indices[:, 0]] - adp[indices[:, 1]]) / (8.0 * np.pi**2)

    @staticmethod
    def _aniso_delta_z(
        indices, xyz: torch.Tensor, u_aniso: torch.Tensor
    ) -> torch.Tensor:
        """
        Anisotropic Δz = uᵀ U_1 u - uᵀ U_2 u for one group of bonds.

        ``u`` is the unit vector along the bond. ``u_aniso`` stores the
        symmetric U tensor as (U11, U22, U33, U12, U13, U23). The quadratic
        form is evaluated directly from those six components::

            uᵀ U u = ux²U11 + uy²U22 + uz²U33
                     + 2(ux·uy U12 + ux·uz U13 + uy·uz U23)

        Because it is linear in U, Δz = uᵀ (U_1 - U_2) u.
        """
        idx1 = indices[:, 0]
        idx2 = indices[:, 1]

        bond_vec = xyz[idx2] - xyz[idx1]  # (n_bonds, 3)
        # Small epsilon keeps the sqrt gradient finite for zero-length bonds.
        bond_length = torch.sqrt((bond_vec**2).sum(dim=-1, keepdim=True) + 1e-8)
        unit = bond_vec / bond_length
        ux, uy, uz = unit.unbind(dim=-1)

        # Weights of the six stored U components in the quadratic form.
        weights = torch.stack(
            (ux * ux, uy * uy, uz * uz, 2.0 * ux * uy, 2.0 * ux * uz, 2.0 * uy * uz),
            dim=-1,
        )  # (n_bonds, 6)
        du = u_aniso[idx1] - u_aniso[idx2]  # (n_bonds, 6)
        return (weights * du).sum(dim=-1)

    def _collect_delta_z(self, aniso: bool) -> torch.Tensor:
        """
        Concatenate Δz over all bond restraint groups.

        Parameters
        ----------
        aniso : bool
            Use the anisotropic MSDA definition (requires ``model.u_aniso``)
            instead of the isotropic B-factor proxy.

        Returns
        -------
        torch.Tensor
            1-D tensor of Δz values; empty if there are no bond restraints.
        """
        xyz = self.model.xyz()
        groups = list(self._bond_index_groups())
        if not groups:
            return xyz.new_zeros(0)

        if aniso:
            # Match the coordinate dtype/device (the tensor form used xyz.dtype).
            u_aniso = self.model.u_aniso.to(device=xyz.device, dtype=xyz.dtype)
            parts = [self._aniso_delta_z(idx, xyz, u_aniso) for idx in groups]
        else:
            adp = self.model.adp()
            parts = [self._iso_delta_z(idx, adp) for idx in groups]
        return torch.cat(parts, dim=0)

    def _gaussian_nll(self, delta_z: torch.Tensor) -> torch.Tensor:
        """Sum of Gaussian NLL terms for Δz with standard deviation ``self.sigma``."""
        if delta_z.numel() == 0:
            return delta_z.new_zeros(())
        log_norm = float(np.log(self.sigma)) + 0.5 * float(np.log(2.0 * np.pi))
        nll = 0.5 * (delta_z / self.sigma) ** 2 + log_norm
        return nll.sum()

    def _summarize(self, delta_z: torch.Tensor) -> Dict[str, float]:
        """Summary statistics of |Δz| and the corresponding Z-scores."""
        count = int(delta_z.shape[0]) if delta_z.dim() > 0 else 0
        if count == 0:
            return {
                "count": 0,
                "mean": 0.0,
                "std": 0.0,
                "max": 0.0,
                "min": 0.0,
                "rms": 0.0,
                "mean_z": 0.0,
                "rms_z": 0.0,
            }

        delta_z_abs = delta_z.abs()
        # Z-scores (deviation / sigma)
        z_scores = delta_z_abs / self.sigma
        # Unbiased std is undefined (NaN) for a single sample.
        std = delta_z.std().item() if count > 1 else 0.0
        return {
            "count": count,
            "mean": delta_z_abs.mean().item(),
            "std": std,
            "max": delta_z_abs.max().item(),
            "min": delta_z_abs.min().item(),
            "rms": torch.sqrt((delta_z**2).mean()).item(),
            "mean_z": z_scores.mean().item(),
            "rms_z": torch.sqrt((z_scores**2).mean()).item(),
        }

    # ------------------------------------------------------------------
    # Kept private entry points (forced iso / aniso path)
    # ------------------------------------------------------------------

    def _compute_iso_rigid_bond(self) -> torch.Tensor:
        """
        Rigid bond NLL for isotropic B-factors.

        The MSDA along any direction is B/(8π²), so Δz is the B difference
        scaled by 1/(8π²).
        """
        return self._gaussian_nll(self._collect_delta_z(aniso=False))

    def _compute_aniso_rigid_bond(self) -> torch.Tensor:
        """
        Rigid bond NLL for anisotropic ADPs.

        For each bond, with ``u`` the unit vector from atom 1 to atom 2::

            z_12 = uᵀ U_1 u,  z_21 = uᵀ U_2 u,  Δz = z_12 - z_21
        """
        return self._gaussian_nll(self._collect_delta_z(aniso=True))

    # ------------------------------------------------------------------
    # Reporting
    # ------------------------------------------------------------------

    def get_delta_z_stats(self) -> Dict[str, float]:
        """
        Get statistics of Δz values for analysis.

        Uses the same isotropic/anisotropic Δz definition as ``forward()``.

        Returns
        -------
        dict
            Keys ``count``, ``mean``, ``std``, ``max``, ``min``, ``rms``
            (of |Δz|, except ``std`` and ``rms`` which use signed Δz) and
            ``mean_z``, ``rms_z`` (Z-scores |Δz|/σ). All values are zero when
            there are no restraints. ``std`` is 0.0 for a single bond.
        """
        with torch.no_grad():
            delta_z = self._collect_delta_z(self._uses_aniso())
        return self._summarize(delta_z)

    def stats(self) -> Dict[str, Any]:
        """
        Get rigid bond restraint statistics.

        Returns statistics including Δz values along bonds, or an empty dict
        when there are no bond restraints.
        """
        delta_z_stats = self.get_delta_z_stats()
        if delta_z_stats.get("count", 0) == 0:
            return {}

        # Reporting only: no autograd graph is needed for the loss value.
        with torch.no_grad():
            loss = self.forward()

        return {
            "loss": stat(loss.item(), VERBOSITY_STANDARD),
            "count": stat(delta_z_stats["count"], VERBOSITY_DEBUG),
            "rms": stat(delta_z_stats["rms"], VERBOSITY_DETAILED),
            "mean": stat(delta_z_stats["mean"], VERBOSITY_DETAILED),
            "max": stat(delta_z_stats["max"], VERBOSITY_DETAILED),
            "rms_z": stat(delta_z_stats["rms_z"], VERBOSITY_DETAILED),
            "std": stat(delta_z_stats["std"], VERBOSITY_DEBUG),
            "min": stat(delta_z_stats["min"], VERBOSITY_DEBUG),
            "mean_z": stat(delta_z_stats["mean_z"], VERBOSITY_DEBUG),
        }