Source code for torchref.refinement.targets.combined

"""

Combined targets for refinement (e.g., geometry + ADP).

Integrate multiple component targets into a single combined target
using nn.ModuleDict for clean organization and easy access.

Combined targets are integrated into lossState via ``add_to_state``, which is
forwarded to every component target."""

from typing import TYPE_CHECKING, Any, Dict

import torch
from torch import nn

from torchref.config import get_default_device
from torchref.refinement.targets.base import Target, ModelTarget
from torchref.refinement.targets.geometry import (
    BondTarget, AngleTarget, TorsionTarget, PlanarityTarget,
    ChiralTarget, NonBondedHTarget, RamachandranTarget,
)
from torchref.refinement.targets.adp import (
    ADPSimilarityTarget, ADPLocalityTarget, ADPEntropyTarget,
)
from torchref.utils.stats import (
    VERBOSITY_DETAILED,
    filter_stats,
)

if TYPE_CHECKING:
    from torchref.model.model import Model
    from torchref.refinement.base_refinement import Refinement


class CombinedTargets(Target):
    """
    Base class for combined targets.

    Uses nn.ModuleDict to store component targets for clean organization
    and easy access via dictionary-style notation.

    Subclasses should override `_create_targets()` to define their
    component targets.

    Parameters
    ----------
    verbose : int, optional
        Verbosity level. Default is 0.

    Attributes
    ----------
    _targets : nn.ModuleDict
        Dictionary of component targets.
    """

    def __init__(self, verbose: int = 0):
        """
        Initialize CombinedTargets.

        Parameters
        ----------
        verbose : int, optional
            Verbosity level. Default is 0.
        """
        super().__init__(verbose=verbose)
        # The components are built by the subclass hook and registered as
        # sub-modules by wrapping them in a ModuleDict.
        self._targets = nn.ModuleDict(self._create_targets())

    def _create_targets(self) -> Dict[str, "Target"]:
        """
        Create and return component targets as a dictionary.

        Subclasses must override this method to define their component
        targets.

        Returns
        -------
        Dict[str, Target]
            Dictionary mapping target names to Target instances.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError("Subclasses must implement _create_targets() method.")

    def targets(self) -> nn.ModuleDict:
        """Return registered sub-targets as ModuleDict."""
        return self._targets

    def __getitem__(self, key: str) -> "Target":
        """Get a target by name using dictionary-style access."""
        return self._targets[key]

    def __contains__(self, key: str) -> bool:
        """Check if a target exists."""
        return key in self._targets

    def keys(self):
        """Return target names."""
        return self._targets.keys()

    def values(self):
        """Return target instances."""
        return self._targets.values()

    def items(self):
        """Return (name, target) pairs."""
        return self._targets.items()

    def target_losses(self) -> Dict[str, torch.Tensor]:
        """
        Get individual component losses (without weights).

        Every registered target is called once; the results are keyed by
        the name the target was registered under.
        """
        return {name: target() for name, target in self._targets.items()}

    def forward(self) -> torch.Tensor:
        """
        Compute total combined target loss.

        Returns
        -------
        torch.Tensor
            Sum of all component losses. When no component is registered a
            zero tensor is returned on the device given by
            ``get_default_device()``, mirroring the no-model fallback of
            ``CombinedModelTargets.forward``.
        """
        losses = list(self.target_losses().values())
        if not losses:
            return torch.tensor(0.0, device=get_default_device())
        return torch.stack(losses).sum()

    def stats(self) -> Dict[str, Any]:
        """
        Get statistics from all registered targets.

        Targets without a ``stats`` attribute, or whose ``stats()`` result is
        falsy (e.g. an empty dict), are left out of the returned mapping.
        """
        statistics = {}
        for name, target in self._targets.items():
            if hasattr(target, "stats"):
                target_stats = target.stats()
                if target_stats:
                    statistics[name] = target_stats
        return statistics

    def get(self) -> dict:
        """Get individual component losses."""
        return self.target_losses()

    def add_to_state(self, state):
        """
        Forward ``state`` to ``add_to_state`` of every component target.

        Returns
        -------
        The same ``state`` object that was passed in.
        """
        for target in self._targets.values():
            target.add_to_state(state)
        return state
class CombinedModelTargets(ModelTarget):
    """
    Base class for combined targets that only need Model (geometry/ADP targets).

    Uses nn.ModuleDict to store component targets for clean organization
    and easy access via dictionary-style notation.

    Subclasses should override `_create_targets()` to define their
    component targets.

    Parameters
    ----------
    model : Model, optional
        Reference to Model object.
    verbose : int, optional
        Verbosity level. Default is 0.

    Attributes
    ----------
    _targets : nn.ModuleDict
        Dictionary of component targets.
    """

    def __init__(self, model: "Model" = None, verbose: int = 0):
        """
        Initialize CombinedModelTargets.

        Parameters
        ----------
        model : Model, optional
            Reference to Model object.
        verbose : int, optional
            Verbosity level. Default is 0.
        """
        super().__init__(model, verbose)
        # The components are built by the subclass hook (which may read
        # self.model / self.verbose) and registered via a ModuleDict.
        self._targets = nn.ModuleDict(self._create_targets())

    def _create_targets(self) -> Dict[str, "Target"]:
        """
        Create and return component targets as a dictionary.

        Subclasses must override this method to define their component
        targets.

        Returns
        -------
        Dict[str, Target]
            Dictionary mapping target names to Target instances.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError("Subclasses must implement _create_targets() method.")

    def targets(self) -> nn.ModuleDict:
        """Return registered sub-targets as ModuleDict."""
        return self._targets

    def __getitem__(self, key: str) -> "Target":
        """Get a target by name using dictionary-style access."""
        return self._targets[key]

    def __contains__(self, key: str) -> bool:
        """Check if a target exists."""
        return key in self._targets

    def keys(self):
        """Return target names."""
        return self._targets.keys()

    def values(self):
        """Return target instances."""
        return self._targets.values()

    def items(self):
        """Return (name, target) pairs."""
        return self._targets.items()

    def target_losses(self) -> Dict[str, torch.Tensor]:
        """
        Get individual component losses (without weights).

        Every registered target is called once; the results are keyed by
        the name the target was registered under.
        """
        return {name: target() for name, target in self._targets.items()}

    def forward(self) -> torch.Tensor:
        """
        Compute total combined target loss.

        Returns
        -------
        torch.Tensor
            Sum of all component losses. When no component is registered a
            zero tensor is returned, placed on the device of
            ``self.model.xyz()`` if a model is set and on
            ``get_default_device()`` otherwise.
        """
        losses = list(self.target_losses().values())
        if not losses:
            # Explicit None check: truthiness of a module-like object may be
            # affected by __len__/__bool__ and is not the question asked here.
            if self.model is not None:
                device = self.model.xyz().device
            else:
                device = get_default_device()
            return torch.tensor(0.0, device=device)
        return torch.stack(losses).sum()

    def stats(self) -> Dict[str, Any]:
        """
        Get statistics from all registered targets.

        Targets without a ``stats`` attribute, or whose ``stats()`` result is
        falsy (e.g. an empty dict), are left out of the returned mapping.
        """
        statistics = {}
        for name, target in self._targets.items():
            if hasattr(target, "stats"):
                target_stats = target.stats()
                if target_stats:
                    statistics[name] = target_stats
        return statistics

    def get(self) -> dict:
        """Get individual component losses."""
        return self.target_losses()

    def add_to_state(self, state):
        """
        Forward ``state`` to ``add_to_state`` of every component target.

        Returns
        -------
        The same ``state`` object that was passed in.
        """
        for target in self._targets.values():
            target.add_to_state(state)
        return state
class TotalGeometryTarget(CombinedModelTargets):
    """
    Computes weighted sum of all geometry restraint NLLs.

    Uses nn.ModuleDict to store component targets:

    - 'bond': BondTarget
    - 'angle': AngleTarget
    - 'torsion': TorsionTarget
    - 'planarity': PlanarityTarget
    - 'chiral': ChiralTarget
    - 'nonbonded': NonBondedHTarget (includes riding hydrogen VDW)
    - 'ramachandran': RamachandranTarget

    The torsion weight is reduced because:

    1. Protein torsions naturally deviate from ideal (Ramachandran plot)
    2. Side chain rotamers have discrete populations, not single ideals
    3. High torsion weight can over-constrain the structure

    The nonbonded weight is very low because:

    1. PROLSQ repulsion is already steep (E ~ violation^4)
    2. Most contacts should be satisfied by covalent geometry
    3. High VDW weight can prevent proper packing

    Set weight to 0 to disable a component.

    Parameters
    ----------
    model : Model, optional
        Reference to Model object.
    verbose : int, optional
        Verbosity level. Default is 0.

    Examples
    --------
    ::

        geom_target = TotalGeometryTarget(model)
        loss = geom_target()
        bond_loss = geom_target['bond']()
        for name, target in geom_target.items():
            print(f"{name}: {target()}")
    """

    def _create_targets(self) -> Dict[str, Target]:
        """
        Create geometry component targets.

        Returns
        -------
        Dict[str, Target]
            Dictionary of geometry targets.
        """
        print("Initializing TotalGeometryTarget with component targets...")
        return {
            "bond": BondTarget(self.model, self.verbose),
            "angle": AngleTarget(self.model, self.verbose),
            "torsion": TorsionTarget(self.model, self.verbose),
            "planarity": PlanarityTarget(self.model, self.verbose),
            "chiral": ChiralTarget(self.model, self.verbose),
            "nonbonded": NonBondedHTarget(self.model, verbose=self.verbose),
            "ramachandran": RamachandranTarget(self.model, self.verbose),
        }

    def get_metrics(self, verbosity: int = VERBOSITY_DETAILED) -> Dict[str, float]:
        """
        Get all geometry metrics as a flat dictionary for logging/reporting.

        Parameters
        ----------
        verbosity : int, optional
            Verbosity level for filtering. Default is VERBOSITY_DETAILED.

        Returns
        -------
        dict
            Dictionary with validation metrics from all component targets.
            Loss entries are converted with ``.item()`` when they are
            tensors; statistics are passed through as returned by
            ``filter_stats``.
        """
        metrics = {}

        # Total loss (always include)
        total_loss = self.forward()
        metrics["geom_total_loss"] = (
            total_loss.item() if torch.is_tensor(total_loss) else total_loss
        )

        # Per-component losses. Note that this evaluates every component a
        # second time (forward() above already called target_losses()).
        for name, loss in self.target_losses().items():
            loss_val = loss.item() if torch.is_tensor(loss) else loss
            metrics[f"geom_{name}_loss"] = loss_val

        # Per-component statistics, filtered by verbosity and flattened to
        # "geom_<component>_<stat>" keys.
        filtered_stats = filter_stats(self.stats(), verbosity)
        for name, target_stats in filtered_stats.items():
            for stat_name, stat_val in target_stats.items():
                metrics[f"geom_{name}_{stat_name}"] = stat_val

        return metrics

    def print_statistics(self):
        """
        Print REFMAC-style geometry statistics with losses.

        One row is printed per component. A failure while formatting a single
        row is reported on that row instead of being silently dropped, and
        ``self.verbose`` is always restored, even if evaluating the losses
        raises.
        """
        # Temporarily disable verbose to prevent duplicate output during loss calculation
        saved_verbose = self.verbose
        self.verbose = 0
        try:
            print("\n" + "=" * 90)
            print("Geometry Restraint Statistics (REFMAC-style)")
            print("=" * 90)

            # Show component targets
            print(f"Components: {', '.join(self._targets.keys())}")
            print("-" * 90)
            print(
                f"{'Restraint Type':<25} {'N':>8} {'RMS Delta':>12} {'RMS Z':>10} {'Av(Sigma)':>12} {'Loss':>12}"
            )
            print("-" * 90)

            # Get losses and stats using parent methods
            losses = self.target_losses()
            all_stats = self.stats()

            for name, loss in losses.items():
                try:
                    loss_val = loss.item() if torch.is_tensor(loss) else loss
                    stats = all_stats.get(name, {})

                    # Format based on available stats
                    if "n" in stats:
                        n = stats["n"]
                        rms_delta = stats.get("rms_delta", 0.0)
                        rms_z = stats.get("rms_z", 0.0)
                        mean_sigma = stats.get("mean_sigma", 0.0)
                        print(
                            f"{name:<25} {n:>8} {rms_delta:>12.4f} {rms_z:>10.2f} {mean_sigma:>12.4f} {loss_val:>12.4f}"
                        )
                    elif "n_violations" in stats:
                        # NonBonded format
                        # NOTE(review): this branch is only reached when "n"
                        # is absent, so n is always 0 here and pct_viol is
                        # always 0.0. Confirm the stats schema of the
                        # nonbonded target before reordering the branches.
                        n = stats.get("n", 0)
                        n_viol = stats.get("n_violations", 0)
                        pct_viol = 100.0 * n_viol / n if n > 0 else 0.0
                        print(
                            f"{name:<25} {n:>8} pairs, {n_viol:>6} violations ({pct_viol:>5.1f}%)"
                        )
                        print(
                            f"{' RMS violation (Å)':<25} {stats.get('rms_violation', 0.0):>12.4f} Max: {stats.get('max_violation', 0.0):.4f} Å"
                        )
                        print(f"{' Loss':<25} {loss_val:>12.4f}")
                    else:
                        print(
                            f"{name:<25} {'':>8} {'':>12} {'':>10} {'':>12} {loss_val:>12.4f}"
                        )
                except Exception as e:
                    # Keep the report going, but make the failure visible
                    # (same row format as TotalADPTarget.print_statistics).
                    print(f"{name:<25} Error: {e}")

            # Total loss
            print("-" * 90)
            total_loss = self.forward().item()
        finally:
            # Restore verbose now
            self.verbose = saved_verbose

        print(
            f"{'TOTAL GEOMETRY LOSS':<25} {'':>8} {'':>12} {'':>10} {'':>12} {total_loss:>12.4f}"
        )
        print("=" * 90)
        print("Target: RMS Z should be ~1.0 for well-refined structure")
        print(" Phenix typical: Bond RMS ~0.007Å, Angle RMS ~1.2°")
        print("=" * 90 + "\n")
class TotalADPTarget(CombinedModelTargets):
    """
    Total ADP restraint target combining similarity, locality, and
    KL-divergence components.

    Component targets are held in an nn.ModuleDict under these names:

    - 'simu': ADPSimilarityTarget (SIMU-like bond similarity)
    - 'locality': ADPLocalityTarget (spatial smoothness)
    - 'KL': ADPEntropyTarget (KL divergence regularization)

    B-factors follow a LOG-NORMAL distribution (B > 0, right-skewed).
    If B ~ LogNormal(μ, σ), then log(B) ~ Normal(μ, σ).

    The three parts are:

    1. **Similarity restraint (SIMU-like)**: Bond-based B-factor similarity

       - Enforces bonded atoms have similar B-factors
       - Based on covalent bond topology (strongest local constraint)

    2. **Locality restraint**: Spatial smoothness - nearby atoms should
       have similar B

       - Uses K-NN with distance-based sigma (d² scaling)
       - Medium-range spatial correlation

    3. **KL divergence**: Controls the spread of B-factor distribution

       - Prevents overfitting by controlling distribution width

    Log-normal distribution properties, for log(B) ~ N(μ, σ):

    - Mean of B: exp(μ + σ²/2)
    - Mode of B: exp(μ - σ²)
    - For typical proteins: σ_logB ≈ 0.3-0.5 (in log space)

    Parameters
    ----------
    model : Model
        Reference to Model object.
    verbose : int, optional
        Verbosity level. Default is 0.

    Examples
    --------
    ::

        adp_target = TotalADPTarget(model)
        loss = adp_target()
        simu_loss = adp_target['simu']()
        for name, target in adp_target.items():
            print(f"{name}: {target()}")
    """

    def _create_targets(self) -> Dict[str, Target]:
        """
        Build the ADP component targets.

        Returns
        -------
        Dict[str, Target]
            The 'simu', 'locality' and 'KL' targets, in that order.
        """
        print("Initializing TotalADPTarget with component targets...")
        components: Dict[str, Target] = {}
        components["simu"] = ADPSimilarityTarget(self.model, verbose=self.verbose)
        components["locality"] = ADPLocalityTarget(self.model, verbose=self.verbose)
        components["KL"] = ADPEntropyTarget(self.model, verbose=self.verbose)
        return components

    def print_statistics(self) -> None:
        """
        Print comprehensive ADP restraint statistics.

        Writes a table with one row per registered component (its loss,
        followed by that component's statistics) and a final total row.
        A failure while printing a component is reported on its own row.
        """
        heavy_rule = "=" * 90
        light_rule = "-" * 90

        print("\n" + heavy_rule)
        print("ADP RESTRAINT STATISTICS")
        print(heavy_rule)

        # Table header
        print(f"\n{'COMPONENT LOSSES':^90}")
        print(light_rule)
        print(f"{'Component':<25} {'Loss':>15}")
        print(light_rule)

        # Losses are evaluated before the statistics are collected
        component_losses = self.target_losses()
        component_stats = self.stats()

        for name, loss in component_losses.items():
            try:
                if torch.is_tensor(loss):
                    loss_val = loss.item()
                else:
                    loss_val = loss
                print(f"{name:<25} {loss_val:>15.4f}")

                # Floats get four decimals, everything else is right-aligned as is
                for stat_name, stat_val in component_stats.get(name, {}).items():
                    spec = ">15.4f" if isinstance(stat_val, float) else ">15"
                    print(f" {stat_name:<23} {stat_val:{spec}}")
            except Exception as e:
                print(f"{name:<25} Error: {e}")

        print(light_rule)
        total_loss = self.forward().item()
        print(f"{'TOTAL ADP LOSS':<25} {total_loss:>15.4f}")
        print(heavy_rule + "\n")

    def get_metrics(self, verbosity: int = VERBOSITY_DETAILED) -> Dict[str, float]:
        """
        Collect all ADP metrics in one flat dictionary for logging/reporting.

        Parameters
        ----------
        verbosity : int, optional
            Verbosity level handed to ``filter_stats``. Default is
            VERBOSITY_DETAILED.

        Returns
        -------
        dict
            'adp_total_loss', one 'adp_<component>_loss' entry per component,
            and one 'adp_<component>_<stat>' entry per filtered statistic.
            Tensor losses are converted with ``.item()``.
        """

        def as_number(value):
            # Tensors become Python numbers; anything else passes through.
            return value.item() if torch.is_tensor(value) else value

        # Total loss first, so it is always present
        metrics = {"adp_total_loss": as_number(self.forward())}

        # Individual component losses
        metrics.update(
            {
                f"adp_{name}_loss": as_number(loss)
                for name, loss in self.target_losses().items()
            }
        )

        # Component statistics, filtered by verbosity and flattened
        per_target = filter_stats(self.stats(), verbosity)
        metrics.update(
            {
                f"adp_{name}_{stat_name}": stat_val
                for name, target_stats in per_target.items()
                for stat_name, stat_val in target_stats.items()
            }
        )
        return metrics