Source code for torchref.refinement.weighting.component_weighting

from typing import TYPE_CHECKING, Any, Dict, List

import torch
from torch import nn

from torchref.config import get_default_device
from torchref.refinement.weighting.base_weighting import BaseWeighting
from torchref.utils.device_mixin import DeviceMixin
from torchref.utils.stats import (
    VERBOSITY_DEBUG,
    VERBOSITY_ESSENTIAL,
    VERBOSITY_STANDARD,
    StatEntry,
    stat,
)

if TYPE_CHECKING:
    from torchref.refinement.loss_state import LossState

from torchref.utils import TensorDict

# Backward-compatibility alias: ``WeightingScheme`` is kept as a second name
# for ``BaseWeighting`` (it is the very same class object, not a subclass).
WeightingScheme = BaseWeighting


class ResolutionWeighting(BaseWeighting):
    """
    Base prior strength with optional resolution-dependent correction.

    With proper NLL sums and perfectly calibrated sigmas, w=1 would be the
    pure Bayesian answer. However, an empirical sweep on 50 structures
    (1.15-3.0 A) showed the monomer library sigmas lead to too-loose geometry
    at w=1. The optimum found by that sweep was::

        w_geometry ~= 10   (from median Rfree minimum)
        w_adp      ~= 1    (lower is better for most resolutions)

    Geometry and ADP have different base weights because the monomer library
    geometry sigmas are tighter than the ADP restraint sigmas are loose, so
    they need different compensation.

    Optional resolution dependence::

        w_geometry = base_w_geometry * (d_min / d_ref) ** alpha
        w_adp      = base_w_adp      * (d_min / d_ref) ** alpha

    ``alpha = 0`` disables the resolution correction. The sweep found the
    resolution dependence was weak (<20% effect), so ``alpha = 0`` is the
    default. Both weights are clamped to ``[0.1, 100.0]`` before being
    returned.

    Parameters
    ----------
    device : torch.device, optional
        Computation device.
    base_w_geometry : float, optional
        Base geometry weight. Default 1.0.
        NOTE(review): the sweep described above favours ~10, but the
        constructor default is 1.0; pass ``base_w_geometry=10.0`` explicitly
        to use the sweep value -- confirm which default is intended.
    base_w_adp : float, optional
        Base ADP weight. Default 1.0 (from empirical sweep).
    d_ref : float, optional
        Reference resolution for the optional power-law correction.
        Default 2.0 A.
    alpha : float, optional
        Resolution sensitivity. Default 0.0 (disabled).
    """

    name = "resolution_weighting"

    # Resolution assumed when the state carries no "resolution_min" entry.
    # Shared by forward() and stats() so the reported d_min is always the one
    # the reported weights were actually computed from.
    _FALLBACK_D_MIN = 2.0

    # Hard limits applied to every weight returned by forward().
    _MIN_WEIGHT = 0.1
    _MAX_WEIGHT = 100.0

    def __init__(
        self,
        device: torch.device = None,
        base_w_geometry: float = 1.0,
        base_w_adp: float = 1.0,
        d_ref: float = 2.0,
        alpha: float = 0.0,
    ):
        super().__init__(device)
        # Stored as buffers (not parameters): they are registered on the
        # module but are not exposed to an optimizer.
        self.register_buffer("base_w_geometry", torch.tensor(base_w_geometry))
        self.register_buffer("base_w_adp", torch.tensor(base_w_adp))
        self.register_buffer("d_ref", torch.tensor(d_ref))
        self.register_buffer("alpha", torch.tensor(alpha))

    def forward(self, state: "LossState") -> Dict[str, float]:
        """
        Compute base weights with optional resolution correction.

        Parameters
        ----------
        state : LossState
            Queried for ``"resolution_min"``; 2.0 is used when it is absent.

        Returns
        -------
        dict
            ``{"geometry": float, "adp": float}``, each clamped to
            ``[0.1, 100.0]``.
        """
        d_min = state.get("resolution_min", self._FALLBACK_D_MIN)
        if not isinstance(d_min, torch.Tensor):
            d_min = torch.tensor(d_min, device=self.device)

        res_factor = (d_min / self.d_ref) ** self.alpha

        w_geom = (
            torch.clamp(
                self.base_w_geometry * res_factor,
                self._MIN_WEIGHT,
                self._MAX_WEIGHT,
            )
            .detach()
            .item()
        )
        w_adp = (
            torch.clamp(
                self.base_w_adp * res_factor,
                self._MIN_WEIGHT,
                self._MAX_WEIGHT,
            )
            .detach()
            .item()
        )
        return {"geometry": w_geom, "adp": w_adp}

    def stats(self, state: "LossState" = None) -> Dict[str, StatEntry]:
        """
        Report the current weights and the settings that produced them.

        Without a state the unclamped base weights are reported and ``d_min``
        is shown as 0.0. With a state the weights come from ``forward`` and
        ``d_min`` uses the same fallback as ``forward``.
        """
        if state is not None:
            # Same fallback as forward(): previously this used 0.0, so a
            # missing resolution was reported as 0.0 next to weights that had
            # been computed for 2.0.
            d_min = state.get("resolution_min", self._FALLBACK_D_MIN)
            w = self.forward(state)
        else:
            d_min = 0.0
            w = {
                "geometry": self.base_w_geometry.item(),
                "adp": self.base_w_adp.item(),
            }

        return {
            "resolution_w_geom": stat(w["geometry"], VERBOSITY_ESSENTIAL),
            "resolution_w_adp": stat(w["adp"], VERBOSITY_ESSENTIAL),
            "d_min": stat(d_min, VERBOSITY_STANDARD),
            "d_ref": stat(self.d_ref.item(), VERBOSITY_DEBUG),
            "alpha": stat(self.alpha.item(), VERBOSITY_DEBUG),
        }
class OverfittingWeighting(BaseWeighting):
    """
    Dynamic overfitting correction based on the Rfree - Rwork gap.

    Uses R-factors (scale-invariant, per-reflection-normalized) rather than
    NLL values, which are incomparable between work/test sets after switching
    to summed NLLs.

    When the Rfree-Rwork gap exceeds ``target_gap``, regularization is
    increased exponentially. The raw correction is smoothed with an
    exponential moving average (EMA) that is kept in the ``weight_reg``
    buffer and updated on every ``forward`` call.

    Effective correction::

        target     = min_weight + exp(sharpness * (gap - target_gap))
        weight_reg = smoothing * weight_reg + (1 - smoothing) * target
        w_adp      = weight_reg
        w_geometry = 1 + geom_share * (weight_reg - 1)

    ``geom_share`` controls how much of the correction reaches the geometry
    weight. In crystallographic refinement, overfitting is typically driven
    by B-factors (one parameter per atom, relatively weak restraints) rather
    than coordinates (held tightly by the geometry prior), so a small share
    such as ``geom_share = 0.2`` keeps most of the effect on ADP.

    NOTE(review): the constructor default is ``geom_share = 1.0``, i.e.
    geometry receives the *same* factor as ADP unless a smaller share is
    passed explicitly -- confirm that this default is intended.

    Tunable parameters (stored as buffers):

    - target_gap: R-factor gap threshold. Default 0.05 (5%).
    - min_weight: base correction factor. Default 1.0.
    - sharpness: exponential response steepness. Default 30.0.
    - geom_share: fraction of correction applied to geometry. Default 1.0.
    - smoothing: EMA smoothing factor (0-1). Default 0.8.
    """

    name = "overfitting_weighting"

    def __init__(
        self,
        device: torch.device = None,
        target_gap: float = 0.05,
        min_weight: float = 1.0,
        sharpness: float = 30.0,
        geom_share: float = 1.0,
        smoothing: float = 0.8,
    ):
        super().__init__(device)
        self.register_buffer("target_gap", torch.tensor(target_gap))
        self.register_buffer("min_weight", torch.tensor(min_weight))
        self.register_buffer("sharpness", torch.tensor(sharpness))
        self.register_buffer("geom_share", torch.tensor(geom_share))
        self.register_buffer("smoothing", torch.tensor(smoothing))
        # EMA state of the correction factor; starts at the neutral value 1.
        self.register_buffer("weight_reg", torch.tensor(1.0))

    def forward(self, state: "LossState") -> Dict[str, float]:
        """
        Compute overfitting correction weights from the R-factor gap.

        Each call advances the EMA held in ``weight_reg``; the method is
        therefore stateful.

        Parameters
        ----------
        state : LossState
            Queried for ``"rwork"`` and ``"rfree"`` (0.0 when absent).

        Returns
        -------
        dict
            ``{"geometry": float, "adp": float}`` multiplicative factors.
        """
        rwork = state.get("rwork", 0.0)
        rfree = state.get("rfree", 0.0)

        if not isinstance(rwork, torch.Tensor):
            rwork = torch.tensor(rwork, device=self.device)
        if not isinstance(rfree, torch.Tensor):
            rfree = torch.tensor(rfree, device=self.device)

        gap = rfree - rwork
        target_weight = (
            self.min_weight
            + torch.exp(self.sharpness * (gap - self.target_gap))
        ).detach()

        # A NaN/inf target (e.g. from an undefined R-factor) would poison the
        # EMA permanently, because every later update mixes in the old value.
        # Keep the previous smoothed weight in that case.
        if bool(torch.isfinite(target_weight).all()):
            self.weight_reg = (
                self.smoothing * self.weight_reg
                + (1 - self.smoothing) * target_weight
            )

        adp_factor = self.weight_reg.detach().item()
        # Apply only a share of the correction to geometry.
        geom_factor = 1.0 + self.geom_share.item() * (adp_factor - 1.0)

        return {
            "geometry": geom_factor,
            "adp": adp_factor,
        }

    def stats(self, state: "LossState" = None) -> Dict[str, StatEntry]:
        """
        Report the smoothed correction factor, its settings and the R-factors.

        ``rwork``/``rfree`` are read from ``state`` and reported as 0.0 when
        no state is given or the entries are missing. Does not advance the
        EMA.
        """
        if state is not None:
            rwork = state.get("rwork", 0.0)
            rfree = state.get("rfree", 0.0)
        else:
            rwork = 0.0
            rfree = 0.0

        return {
            "overfitting_weight": stat(self.weight_reg.item(), VERBOSITY_ESSENTIAL),
            "target_gap": stat(self.target_gap.item(), VERBOSITY_DEBUG),
            "min_weight": stat(self.min_weight.item(), VERBOSITY_DEBUG),
            "sharpness": stat(self.sharpness.item(), VERBOSITY_DEBUG),
            "rwork": stat(rwork, VERBOSITY_STANDARD),
            "rfree": stat(rfree, VERBOSITY_STANDARD),
        }
class ManualWeighting(BaseWeighting):
    """
    Weighting scheme that hands back fixed, user-supplied weights.

    The weights are captured once at construction time and stored as tensors
    in a ``TensorDict``; the loss state passed to ``forward`` is never
    consulted, so the scheme simply returns the preset weights.
    """

    name = "manual_weighting"

    def __init__(self, weights: Dict[str, float], device: torch.device = None):
        super().__init__(device)
        # Convert every value to a tensor on this scheme's device.
        tensors = {}
        for key, value in weights.items():
            tensors[key] = torch.tensor(value, device=self.device)
        self.manual_weights = TensorDict(tensors)

    def forward(self, state: "LossState") -> Dict[str, float]:
        """Return the stored weights as plain Python numbers (state unused)."""
        result = {}
        for key, tensor in self.manual_weights.items():
            result[key] = tensor.item()
        return result
class ComponentWeighting(DeviceMixin, nn.Module):
    """
    Combines multiple weighting schemes using nn.ModuleDict.

    Holds weighting schemes but does NOT hold a refinement reference.
    Weighting is computed via ``forward(state)``, which receives a LossState.

    Default schemes:

    - 'overfitting': OverfittingWeighting - prevents overfitting via the
      Rfree gap.

    ``ResolutionWeighting`` is currently NOT installed by default (its entry
    in ``__init__`` is commented out); add it with ``add_scheme`` or through
    ``schemes`` if it is wanted.

    Parameters
    ----------
    device : torch.device, optional
        Computation device. Falls back to ``get_default_device()``.
    weights : dict, optional
        Manual weight overrides.
    component_weights : dict, optional
        Manual weight overrides for specific components. Merged on top of
        ``weights``, so it wins when both contain the same key. If either
        dict is non-empty a single ``ManualWeighting`` is registered under
        the key ``'manual'``.
    schemes : list of BaseWeighting, optional
        Additional custom weighting schemes. Each is registered under its
        ``name`` attribute (or ``custom_<index>`` if it has none); a scheme
        whose key is already present replaces the earlier entry.
    initial_xray_loss : float, optional
        Legacy parameter, ignored.

    Attributes
    ----------
    schemes : nn.ModuleDict
        Dictionary of weighting schemes.
    """

    def __init__(
        self,
        device: torch.device = None,
        weights: Dict[str, float] = None,
        component_weights: Dict[str, float] = None,
        schemes: List[BaseWeighting] = None,
        # Legacy parameter, ignored
        initial_xray_loss: float = None,
    ):
        super().__init__()
        self.device = device or get_default_device()

        schemes_dict = {
            # "resolution": ResolutionWeighting(device),
            "overfitting": OverfittingWeighting(device),
        }

        # Add manual weights if provided (component_weights takes precedence).
        manual_weights_dict = {}
        if weights:
            manual_weights_dict.update(weights)
        if component_weights:
            manual_weights_dict.update(component_weights)
        if manual_weights_dict:
            schemes_dict["manual"] = ManualWeighting(manual_weights_dict, device)

        # Add additional schemes
        if schemes:
            for i, scheme in enumerate(schemes):
                key = getattr(scheme, "name", f"custom_{i}")
                schemes_dict[key] = scheme

        self.schemes = nn.ModuleDict(schemes_dict)

    def __getitem__(self, key: str) -> BaseWeighting:
        """Get a scheme by name using dictionary-style access."""
        return self.schemes[key]

    def __contains__(self, key: str) -> bool:
        """Check if a scheme exists."""
        return key in self.schemes

    def keys(self):
        """Return scheme names."""
        return self.schemes.keys()

    def values(self):
        """Return scheme instances."""
        return self.schemes.values()

    def items(self):
        """Return (name, scheme) pairs."""
        return self.schemes.items()

    def add_scheme(self, name: str, scheme: BaseWeighting):
        """Add a new weighting scheme (replaces any scheme stored as ``name``)."""
        self.schemes[name] = scheme

    def forward(self, state: "LossState") -> Dict[str, float]:
        """
        Compute weights from all schemes.

        Returns combined weights (multiplicative for shared keys).
        Does NOT modify state - just returns the computed weights.
        """
        combined = {}
        for scheme in self.schemes.values():
            scheme_weights = scheme.forward(state)
            for k, v in scheme_weights.items():
                # First occurrence seeds the entry; later schemes multiply in.
                combined[k] = combined[k] * v if k in combined else v
        return combined

    def total_loss_from_state(self, state: "LossState") -> torch.Tensor:
        """Compute total weighted loss from a LossState (``state.aggregate()``)."""
        return state.aggregate()

    def stats(self, state: "LossState" = None) -> Dict[str, Any]:
        """
        Return statistics for reporting.

        The result maps each scheme name to that scheme's own ``stats(state)``
        (schemes returning nothing are omitted). When a state is given, the
        entries ``"weights"`` (from ``state.weights``) and ``"xray"``
        (``rwork``/``rfree``, 0.0 when missing) are added.
        """
        report = {}
        for name, scheme in self.schemes.items():
            scheme_stats = scheme.stats(state)
            if scheme_stats:
                report[name] = scheme_stats

        if state is not None:
            # The original wrapped ``v`` in a conditional whose two branches
            # were identical; the value is passed through unchanged.
            report["weights"] = {
                k: stat(v, VERBOSITY_STANDARD) for k, v in state.weights.items()
            }
            rwork = state.get("rwork", 0.0)
            rfree = state.get("rfree", 0.0)
            report["xray"] = {
                "rwork": stat(rwork, VERBOSITY_ESSENTIAL),
                "rfree": stat(rfree, VERBOSITY_ESSENTIAL),
            }

        return report
# Public API of this module.
__all__ = [
    # Base class
    "BaseWeighting",
    # Weighting classes
    "WeightingScheme",  # Alias for backward compatibility
    "ResolutionWeighting",
    "OverfittingWeighting",
    "ManualWeighting",
    "ComponentWeighting",
]