"""
Random Weighting Scheme for Policy Training Data Collection.
This module implements a random weighting scheme that samples component weights
from log-normal distributions around default values. This is used to generate
diverse training trajectories for initializing a policy network via AWR
(Advantage-Weighted Regression).
The key insight is that by randomly sampling weights during refinement and
recording the resulting R-free trajectories, we can train a policy to predict
optimal weights based on refinement state.
Weight Sampling Strategy
------------------------
Two-level sampling for stable yet diverse trajectories:
1. **Trajectory-level base weights**: Sampled once at initialization with larger
sigma (trajectory_sigma). These define the "character" of the trajectory.
2. **Step-level perturbations**: Small perturbations applied at each step with
smaller sigma (step_sigma). These add local variation while keeping the
trajectory coherent.
For each component at each step:
base_log_weight ~ N(default, trajectory_sigma) [sampled once]
step_perturbation ~ N(0, step_sigma) [sampled each step]
log_weight = base_log_weight + step_perturbation
weight = exp(log_weight)
References
----------
- "Advantage-Weighted Regression" (Peng et al., 2019)
- design_choices.md in meta_weighting_2/
"""
from typing import TYPE_CHECKING, Any, Dict, Optional

import numpy as np
import torch
from torch import nn

from torchref.config import get_default_device
from torchref.refinement.weighting.base_weighting import BaseWeighting
from torchref.refinement.weighting.component_weighting import (
    ComponentWeighting,
    ResolutionWeighting,
)
from torchref.utils.stats import (
    VERBOSITY_DETAILED,
    VERBOSITY_ESSENTIAL,
    VERBOSITY_STANDARD,
    StatEntry,
    stat,
)

if TYPE_CHECKING:
    from torchref.refinement.loss_state import LossState
# Mean of every component's sampling distribution, expressed in log-space.
# The linear weight is exp(log_weight), so 0.0 corresponds to a weight of 1.0.
DEFAULT_LOG_WEIGHTS = dict(
    xray=0.0,
    # Geometry components
    bond=0.0,
    angle=0.0,
    torsion=0.0,
    planarity=0.0,
    chiral=0.0,
    nonbonded=0.0,
    # ADP components
    simu=0.0,
    locality=0.0,
    KL=0.0,
)

# Trajectory-level sigmas (the larger of the two levels - they set the
# "character" of a trajectory). Every component currently shares one value;
# swap in an explicit literal here to tune components individually.
DEFAULT_TRAJECTORY_SIGMAS = dict.fromkeys(DEFAULT_LOG_WEIGHTS, 1.5)

# Step-level sigmas (the smaller of the two levels - local perturbations
# applied on top of the trajectory base weight). Uniform across components.
DEFAULT_STEP_SIGMAS = dict.fromkeys(DEFAULT_LOG_WEIGHTS, 0.3)

# Hard bounds applied to sampled log-weights so that exp() never produces
# extreme linear weights.
LOG_WEIGHT_MIN, LOG_WEIGHT_MAX = -3.0, 3.0
class RandomWeightingScheme(BaseWeighting):
    """
    Weighting scheme with two-level random sampling for stable trajectories.

    Sampling strategy:

    1. Base weights sampled once at initialization (trajectory_sigma)
    2. Small perturbations applied at each step (step_sigma)

    This creates diverse trajectories while keeping them coherent within
    a single refinement run. Every sampled log-weight is clipped to
    ``[LOG_WEIGHT_MIN, LOG_WEIGHT_MAX]`` before it is stored.

    Parameters
    ----------
    device : torch.device, optional
        Computation device, forwarded unchanged to ``BaseWeighting``.
    default_log_weights : dict, optional
        Default log-space weights for each component (mean of distribution).
        Default is a copy of DEFAULT_LOG_WEIGHTS. The keys of this dict
        define the set of components handled by the scheme.
    trajectory_sigmas : dict, optional
        Standard deviations for trajectory-level base weight sampling.
        Default is a copy of DEFAULT_TRAJECTORY_SIGMAS (1.5 for all).
        A component missing from the dict uses a sigma of 0.8.
    step_sigmas : dict, optional
        Standard deviations for step-level perturbations.
        Default is a copy of DEFAULT_STEP_SIGMAS (0.3 for all).
        A component missing from the dict uses a sigma of 0.1.
    seed : int, optional
        Random seed for reproducibility. Default is None, in which case
        NumPy seeds the generator itself.

    Attributes
    ----------
    base_log_weights : dict
        Trajectory-level base weights (sampled once at init).
    current_log_weights : dict
        Current weights including step perturbation.
    current_weights : dict
        Current weights (exp of log_weights).
    sample_count : int
        Number of times step perturbations have been applied.
    """

    name = "random_weighting"

    # Sigmas used for a component that is absent from the corresponding
    # sigma dict (only possible with caller-supplied dicts).
    # NOTE(review): these differ from the 1.5 / 0.3 values in the module-level
    # default tables - confirm whether they should be brought in line.
    _FALLBACK_TRAJECTORY_SIGMA = 0.8
    _FALLBACK_STEP_SIGMA = 0.1

    def __init__(
        self,
        device: Optional[torch.device] = None,
        default_log_weights: Optional[Dict[str, float]] = None,
        trajectory_sigmas: Optional[Dict[str, float]] = None,
        step_sigmas: Optional[Dict[str, float]] = None,
        seed: Optional[int] = None,
    ):
        super().__init__(device)

        # Fall back to private copies of the module defaults so the shared
        # tables are never aliased by an instance.
        self.default_log_weights = default_log_weights or DEFAULT_LOG_WEIGHTS.copy()
        self.trajectory_sigmas = trajectory_sigmas or DEFAULT_TRAJECTORY_SIGMAS.copy()
        self.step_sigmas = step_sigmas or DEFAULT_STEP_SIGMAS.copy()

        # RandomState(None) seeds itself, so no branch on `seed` is needed.
        self.rng = np.random.RandomState(seed)

        # One pair of scalar buffers per component: the trajectory base value
        # and the current (base + perturbation) value. float() forces a
        # floating-point buffer even when the caller supplied an int default;
        # an integer buffer would truncate every value later written by fill_.
        for name, default in self.default_log_weights.items():
            self.register_buffer(
                f"base_log_weight_{name}", torch.tensor(float(default))
            )
            self.register_buffer(f"log_weight_{name}", torch.tensor(float(default)))

        # Bookkeeping buffers.
        self.register_buffer("_sample_count", torch.tensor(0, dtype=torch.long))
        self.register_buffer("_trajectory_initialized", torch.tensor(False))

        # Plain-Python mirrors of the buffers, filled by the two calls below.
        self.base_log_weights = {}
        self.current_log_weights = {}
        self.current_weights = {}

        self._sample_base_weights()
        self._update_weight_tensors()

    def _sample_base_weights(self) -> None:
        """
        Draw a new trajectory-level base log-weight for every component.

        Each draw is ``N(default_log_weight, trajectory_sigma)`` clipped to
        ``[LOG_WEIGHT_MIN, LOG_WEIGHT_MAX]``. The result is written to both
        the base buffer and the current buffer, and mirrored into
        ``self.base_log_weights``.
        """
        for name, default_log in self.default_log_weights.items():
            sigma = self.trajectory_sigmas.get(name, self._FALLBACK_TRAJECTORY_SIGMA)
            sampled = self.rng.normal(default_log, sigma)
            base_log_weight = float(
                np.clip(sampled, LOG_WEIGHT_MIN, LOG_WEIGHT_MAX)
            )

            # A fresh trajectory starts with no perturbation: current == base.
            getattr(self, f"base_log_weight_{name}").fill_(base_log_weight)
            getattr(self, f"log_weight_{name}").fill_(base_log_weight)
            self.base_log_weights[name] = base_log_weight

        self._trajectory_initialized.fill_(True)

    def _update_weight_tensors(self) -> None:
        """
        Rebuild ``current_log_weights`` / ``current_weights`` from the buffers.

        The linear weight is computed with ``torch.exp`` on the buffer itself,
        so it carries the buffer's precision.
        """
        log_weights = {}
        weights = {}
        for name in self.default_log_weights:
            log_w = getattr(self, f"log_weight_{name}")
            log_weights[name] = float(log_w.item())
            weights[name] = float(torch.exp(log_w).item())
        self.current_log_weights = log_weights
        self.current_weights = weights

    def apply_step_perturbation(self) -> Dict[str, float]:
        """
        Apply small perturbation to base weights for current step.

        For every component a perturbation ``N(0, step_sigma)`` is added to
        the trajectory base log-weight (not to the previous step's value), the
        sum is clipped to ``[LOG_WEIGHT_MIN, LOG_WEIGHT_MAX]`` and stored as the
        current log-weight. The sample counter is incremented once per call.

        Returns
        -------
        dict
            Dictionary of perturbed weights (linear space, a copy).
        """
        for name in self.default_log_weights:
            step_sigma = self.step_sigmas.get(name, self._FALLBACK_STEP_SIGMA)
            perturbation = self.rng.normal(0.0, step_sigma)
            log_weight = np.clip(
                self.base_log_weights[name] + perturbation,
                LOG_WEIGHT_MIN,
                LOG_WEIGHT_MAX,
            )
            getattr(self, f"log_weight_{name}").fill_(log_weight)

        self._sample_count.add_(1)
        self._update_weight_tensors()
        return self.current_weights.copy()

    def resample_trajectory(self) -> Dict[str, float]:
        """
        Resample trajectory-level base weights.

        Call this to start a completely new trajectory with different
        base weights. The step-perturbation counter is reset to zero.

        Returns
        -------
        dict
            Dictionary of the new weights in linear space (a copy); with no
            perturbation applied yet these equal exp(base log-weight).
        """
        self._sample_base_weights()
        self._update_weight_tensors()
        self._sample_count.fill_(0)
        return self.current_weights.copy()

    def resample_weights(self) -> Dict[str, float]:
        """
        Apply step perturbation (for backwards compatibility).

        This is equivalent to apply_step_perturbation().

        Returns
        -------
        dict
            Dictionary of perturbed weights.
        """
        return self.apply_step_perturbation()

    def forward(self, state: "LossState" = None) -> Dict[str, float]:
        """
        Return current weights.

        Parameters
        ----------
        state : LossState, optional
            Current loss state (not used by this scheme, but required by interface).

        Returns
        -------
        dict
            Dictionary mapping component names to weight values (a copy).
        """
        return self.current_weights.copy()

    @property
    def sample_count(self) -> int:
        """Get number of step perturbations applied."""
        return self._sample_count.item()

    def get_log_weights(self) -> Dict[str, float]:
        """
        Get current log-space weights (base + perturbation).

        Returns
        -------
        dict
            Dictionary of current log-space weights (a copy).
        """
        return self.current_log_weights.copy()

    def get_base_log_weights(self) -> Dict[str, float]:
        """
        Get trajectory-level base log-space weights.

        Returns
        -------
        dict
            Dictionary of base log-space weights (a copy).
        """
        return self.base_log_weights.copy()

    def get_weights(self) -> Dict[str, float]:
        """
        Get current weights (linear space).

        Returns
        -------
        dict
            Dictionary of current weights (a copy).
        """
        return self.current_weights.copy()

    def stats(self, state: "LossState" = None) -> Dict[str, StatEntry]:
        """
        Return statistics for reporting.

        Parameters
        ----------
        state : LossState, optional
            Accepted for interface compatibility; not used by this scheme.

        Returns
        -------
        dict
            Statistics dictionary with StatEntry objects: ``sample_count``
            plus, per component, ``weight_<name>``, ``log_weight_<name>`` and
            ``base_log_weight_<name>``.
        """
        stats = {
            "sample_count": stat(self.sample_count, VERBOSITY_STANDARD),
        }
        for name, weight in self.current_weights.items():
            stats[f"weight_{name}"] = stat(weight, VERBOSITY_STANDARD)
            stats[f"log_weight_{name}"] = stat(
                self.current_log_weights[name], VERBOSITY_DETAILED
            )
            stats[f"base_log_weight_{name}"] = stat(
                self.base_log_weights.get(name, 0.0), VERBOSITY_DETAILED
            )
        return stats
class RandomComponentWeighting(ComponentWeighting):
    """
    Component weighting with two-level random sampling for data collection.

    This is a drop-in replacement for ComponentWeighting that samples
    weights using a two-level strategy:

    1. Trajectory-level base weights (sampled once at init)
    2. Step-level perturbations (applied at each update)

    The weighting combines two schemes, multiplying their weights for any
    component both of them report:

    1. ResolutionWeighting (registered under the key ``"resolution"``)
    2. RandomWeightingScheme - two-level random weight sampling
       (registered under the key ``"random"``)

    Parameters
    ----------
    device : torch.device, optional
        Computation device. When omitted, ``get_default_device()`` is used
        for this module's ``device`` attribute.
    default_log_weights : dict, optional
        Default log-space weights for each component.
    trajectory_sigmas : dict, optional
        Standard deviations for trajectory-level sampling
        (DEFAULT_TRAJECTORY_SIGMAS, 1.5 for all, when omitted).
    step_sigmas : dict, optional
        Standard deviations for step-level perturbations
        (DEFAULT_STEP_SIGMAS, 0.3 for all, when omitted).
    seed : int, optional
        Random seed for reproducibility.
    resample_each_step : bool, optional
        If True, apply step perturbation at each compute_weights() call.
        If False, only apply perturbation when resample_weights() is called.
        Default is True.
    initial_xray_loss : float, optional
        Accepted for interface compatibility; the value is not used by this
        class.

    Examples
    --------
    ::

        from torchref.refinement.weighting import RandomComponentWeighting

        weighting = RandomComponentWeighting(device=device, seed=42)

        # Base weights are sampled at init
        base = weighting.get_base_log_weights()

        # Each compute_weights applies small perturbation
        weights = weighting.compute_weights(state)
        current = weighting.get_sampled_log_weights()  # base + perturbation
    """

    def __init__(
        self,
        device: Optional[torch.device] = None,
        default_log_weights: Optional[Dict[str, float]] = None,
        trajectory_sigmas: Optional[Dict[str, float]] = None,
        step_sigmas: Optional[Dict[str, float]] = None,
        seed: Optional[int] = None,
        resample_each_step: bool = True,
        initial_xray_loss: Optional[float] = None,
    ):
        # ComponentWeighting.__init__ is bypassed on purpose: this class
        # installs its own set of schemes, so only nn.Module is initialised.
        nn.Module.__init__(self)

        self.device = device or get_default_device()
        self.resample_each_step = resample_each_step

        # Insertion order is the order in which compute_weights() and stats()
        # visit the schemes.
        self.schemes = nn.ModuleDict(
            {
                "resolution": ResolutionWeighting(device),
                "random": RandomWeightingScheme(
                    device,
                    default_log_weights=default_log_weights,
                    trajectory_sigmas=trajectory_sigmas,
                    step_sigmas=step_sigmas,
                    seed=seed,
                ),
            }
        )

    @property
    def random_scheme(self) -> RandomWeightingScheme:
        """Get the random weighting scheme."""
        return self.schemes["random"]

    def resample_trajectory(self) -> Dict[str, float]:
        """
        Resample trajectory-level base weights for a new trajectory.

        Call this to start a completely new trajectory with different
        base weights.

        Returns
        -------
        dict
            Whatever ``RandomWeightingScheme.resample_trajectory()`` returns
            (the random scheme's new weights).
        """
        return self.random_scheme.resample_trajectory()

    def resample_weights(self) -> Dict[str, float]:
        """
        Apply step perturbation.

        Returns
        -------
        dict
            Dictionary of newly perturbed weights.
        """
        return self.random_scheme.apply_step_perturbation()

    def compute_weights(self, state: "LossState") -> Dict[str, float]:
        """
        Compute weights from all schemes.

        If resample_each_step is True, a step perturbation is applied to the
        random scheme first. The weights of every scheme are then combined:
        a component reported by several schemes gets the product of their
        values. ``state`` is only passed on to each scheme's ``forward``.

        Parameters
        ----------
        state : LossState
            Loss state handed to each scheme's ``forward``.

        Returns
        -------
        dict
            Dictionary of computed weights for each component.
        """
        if self.resample_each_step:
            self.random_scheme.apply_step_perturbation()

        combined = {}
        for scheme in self.schemes.values():
            for key, value in scheme.forward(state).items():
                # First occurrence sets the weight; later ones multiply in.
                combined[key] = combined[key] * value if key in combined else value
        return combined

    def get_sampled_log_weights(self) -> Dict[str, float]:
        """
        Get the current log-space weights (base + perturbation).

        These are the "actions" taken by the random policy.

        Returns
        -------
        dict
            Dictionary of log-space weights.
        """
        return self.random_scheme.get_log_weights()

    def get_base_log_weights(self) -> Dict[str, float]:
        """
        Get the trajectory-level base log-space weights.

        Returns
        -------
        dict
            Dictionary of base log-space weights.
        """
        return self.random_scheme.get_base_log_weights()

    def get_sampled_weights(self) -> Dict[str, float]:
        """
        Get the linear-space weights of the random scheme alone.

        These are the values before they are combined with the weights of
        the other registered schemes.

        Returns
        -------
        dict
            Dictionary of sampled weights.
        """
        return self.random_scheme.get_weights()

    def stats(self, state: "LossState" = None) -> Dict[str, Any]:
        """
        Return statistics for reporting.

        Parameters
        ----------
        state : LossState, optional
            If provided, X-ray losses are read via ``state.get`` and the
            current weights via ``state.weights``.

        Returns
        -------
        dict
            Nested stats dictionary: one entry per scheme that reports any
            stats, ``"sampled_log_weights"``, and - only when ``state`` is
            given - ``"xray"`` and ``"weights"``.
        """
        stats = {}

        # Per-scheme stats, keyed by scheme name; empty results are omitted.
        for name, scheme in self.schemes.items():
            scheme_stats = scheme.stats(state)
            if scheme_stats:
                stats[name] = scheme_stats

        # The sampled log-weights are the "actions" of the random policy.
        stats["sampled_log_weights"] = {
            k: stat(v, VERBOSITY_STANDARD)
            for k, v in self.get_sampled_log_weights().items()
        }

        if state is not None:
            # Missing loss entries are reported as 0.0.
            work_nll = state.get("xray_loss_work", 0.0)
            test_nll = state.get("xray_loss_test", 0.0)
            stats["xray"] = {
                "work_nll": stat(work_nll, VERBOSITY_ESSENTIAL),
                "test_nll": stat(test_nll, VERBOSITY_ESSENTIAL),
            }
            # Values are forwarded to stat() unchanged, whatever their type.
            stats["weights"] = {
                k: stat(v, VERBOSITY_STANDARD) for k, v in state.weights.items()
            }

        return stats
# Public API of this module: the two weighting classes followed by the
# default sampling tables.
__all__ = [
    "RandomWeightingScheme",
    "RandomComponentWeighting",
    "DEFAULT_LOG_WEIGHTS",
    "DEFAULT_TRAJECTORY_SIGMAS",
    "DEFAULT_STEP_SIGMAS",
]