torchref.refinement.weighting.random_weighting module
Random Weighting Scheme for Policy Training Data Collection.
This module implements a random weighting scheme that samples component weights from log-normal distributions around default values. This is used to generate diverse training trajectories for initializing a policy network via AWR (Advantage-Weighted Regression).
The key insight is that by randomly sampling weights during refinement and recording the resulting R-free trajectories, we can train a policy to predict optimal weights based on refinement state.
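AWR itself is downstream of this module, but a small sketch shows how one recorded R-free trajectory could become per-step regression weights. Several choices here are illustrative assumptions and not torchref behaviour: the reward is the drop in R-free, the return is the undiscounted return-to-go, a mean-return baseline stands in for AWR's learned value function, and the beta and clipping values are made up. The function name awr_step_weights is hypothetical.

```python
import math
from typing import List


def awr_step_weights(rfree: List[float], beta: float = 0.01,
                     max_weight: float = 20.0) -> List[float]:
    """Per-step AWR-style weights for one recorded R-free trajectory (hypothetical).

    Illustrative assumptions, not torchref behaviour:
    - the reward at step t is the drop in R-free, rfree[t] - rfree[t + 1];
    - the undiscounted return-to-go therefore telescopes to rfree[t] - rfree[-1];
    - a mean-return baseline stands in for the learned value function V(s).
    """
    returns = [rfree[t] - rfree[-1] for t in range(len(rfree) - 1)]
    baseline = sum(returns) / len(returns)
    # AWR weights each recorded (state, action) pair by exp(advantage / beta), clipped.
    return [min(math.exp((g - baseline) / beta), max_weight) for g in returns]
```

Steps whose sampled weights were followed by a larger R-free drop receive larger weights, so a policy regressed on the recorded log-weights is pulled toward the better-performing actions.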
Weight Sampling Strategy
Two-level sampling for stable yet diverse trajectories:
Trajectory-level base weights: Sampled once at initialization with larger sigma (trajectory_sigma). These define the “character” of the trajectory.
Step-level perturbations: Small perturbations applied at each step with smaller sigma (step_sigma). These add local variation while keeping the trajectory coherent.
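- For each component at each step:
base_log_weight ~ N(default, trajectory_sigma)   [sampled once per trajectory]
step_perturbation ~ N(0, step_sigma)   [sampled at every step]
log_weight = base_log_weight + step_perturbation
weight = exp(log_weight)
Here N(mean, sigma) is a normal distribution in log space. Because the two draws are independent, log_weight ~ N(default, sqrt(trajectory_sigma^2 + step_sigma^2)) at any single step, so each weight is log-normal with median exp(default).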
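The two-level recipe above can be sketched in plain Python. This is a minimal illustration using the standard-library random module, not torchref's tensor implementation. For brevity it uses one scalar sigma per level instead of the per-component sigma dicts, and the function name sample_two_level is hypothetical.

```python
import math
import random
from typing import Dict, List


def sample_two_level(
    default_log_weights: Dict[str, float],
    trajectory_sigma: float = 0.8,
    step_sigma: float = 0.1,
    n_steps: int = 5,
    seed: int = 0,
) -> List[Dict[str, float]]:
    """Return the per-step weights of one trajectory (hypothetical sketch)."""
    rng = random.Random(seed)
    # Trajectory level: one base log-weight per component, drawn once.
    base = {k: rng.gauss(mu, trajectory_sigma) for k, mu in default_log_weights.items()}
    steps = []
    for _ in range(n_steps):
        # Step level: a small zero-mean perturbation around the fixed base.
        log_w = {k: b + rng.gauss(0.0, step_sigma) for k, b in base.items()}
        # Exponentiate to get strictly positive, log-normally distributed weights.
        steps.append({k: math.exp(v) for k, v in log_w.items()})
    return steps
```

With the default sigmas, the base draw moves a weight by roughly a factor of exp(0.8) ≈ 2.2 per standard deviation, while each step perturbation moves it by only about exp(0.1) ≈ 1.1, which is what keeps a single trajectory coherent.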
References
“Advantage-Weighted Regression” (Peng et al., 2019)
design_choices.md in meta_weighting_2/
- class torchref.refinement.weighting.random_weighting.RandomWeightingScheme(device=None, default_log_weights=None, trajectory_sigmas=None, step_sigmas=None, seed=None)
Bases: BaseWeighting
Weighting scheme with two-level random sampling for stable trajectories.
Sampling strategy:
1. Base weights are sampled once at initialization (trajectory_sigma).
2. Small perturbations are applied at each step (step_sigma).
This creates diverse trajectories while keeping them coherent within a single refinement run.
- Parameters:
device (torch.device, optional) – Computation device. Defaults to the configured device.current.
default_log_weights (dict, optional) – Default log-space weights for each component (mean of distribution). Default is DEFAULT_LOG_WEIGHTS.
trajectory_sigmas (dict, optional) – Standard deviations for trajectory-level base weight sampling. Default is DEFAULT_TRAJECTORY_SIGMAS (0.8 for all).
step_sigmas (dict, optional) – Standard deviations for step-level perturbations. Default is DEFAULT_STEP_SIGMAS (0.1 for all).
seed (int, optional) – Random seed for reproducibility. Default is None.
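To illustrate the state this class keeps and how its documented methods relate, here is a hypothetical stand-in, ToyRandomWeightingScheme, written with the standard library. It is a sketch under stated assumptions, not the torchref implementation: partial sigma dicts are assumed to fall back to the documented defaults (0.8 and 0.1) for missing components, resample_trajectory() is assumed to reset the step perturbation to zero, and plain floats replace tensors.

```python
import math
import random
from typing import Dict, Optional


class ToyRandomWeightingScheme:
    """Hypothetical stand-in mirroring RandomWeightingScheme's documented methods."""

    def __init__(self, default_log_weights: Dict[str, float],
                 trajectory_sigmas: Optional[Dict[str, float]] = None,
                 step_sigmas: Optional[Dict[str, float]] = None,
                 seed: Optional[int] = None):
        self.defaults = dict(default_log_weights)
        # Assumption: missing components fall back to the documented defaults.
        self.trajectory_sigmas = {k: 0.8 for k in self.defaults}
        self.trajectory_sigmas.update(trajectory_sigmas or {})
        self.step_sigmas = {k: 0.1 for k in self.defaults}
        self.step_sigmas.update(step_sigmas or {})
        self.rng = random.Random(seed)
        self.resample_trajectory()

    def resample_trajectory(self) -> Dict[str, float]:
        # New base log-weights; assumption: the step perturbation resets to zero.
        self.base = {k: self.rng.gauss(mu, self.trajectory_sigmas[k])
                     for k, mu in self.defaults.items()}
        self.perturbation = {k: 0.0 for k in self.defaults}
        return {k: math.exp(v) for k, v in self.base.items()}

    def apply_step_perturbation(self) -> Dict[str, float]:
        # Draw a fresh perturbation each call; the base is left untouched.
        self.perturbation = {k: self.rng.gauss(0.0, self.step_sigmas[k])
                             for k in self.defaults}
        return {k: math.exp(v) for k, v in self.get_log_weights().items()}

    # Backwards-compatible alias, mirroring the documented resample_weights().
    resample_weights = apply_step_perturbation

    def get_log_weights(self) -> Dict[str, float]:
        return {k: self.base[k] + self.perturbation[k] for k in self.defaults}

    def get_base_log_weights(self) -> Dict[str, float]:
        return dict(self.base)
```

The design point it shows is the split of responsibilities: only resample_trajectory() changes the base, so repeated step perturbations always stay centred on the same trajectory-level weights.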
- __init__(device=None, default_log_weights=None, trajectory_sigmas=None, step_sigmas=None, seed=None)
- apply_step_perturbation()
Apply a small perturbation to the base weights for the current step.
- Returns:
Dictionary of perturbed weights.
- Return type:
dict
- resample_trajectory()
Resample trajectory-level base weights.
Call this to start a completely new trajectory with different base weights.
- Returns:
Dictionary of new base weights.
- Return type:
dict
- resample_weights()
Apply step perturbation (for backwards compatibility).
This is equivalent to apply_step_perturbation().
- Returns:
Dictionary of perturbed weights.
- Return type:
dict
- get_log_weights()
Get current log-space weights (base + perturbation).
- Returns:
Dictionary of current log-space weights.
- Return type:
dict
- get_base_log_weights()
Get trajectory-level base log-space weights.
- Returns:
Dictionary of base log-space weights.
- Return type:
dict
- class torchref.refinement.weighting.random_weighting.RandomComponentWeighting(device=None, default_log_weights=None, trajectory_sigmas=None, step_sigmas=None, seed=None, resample_each_step=True, initial_xray_loss=None)
Bases: ComponentWeighting
Component weighting with two-level random sampling for data collection.
This is a drop-in replacement for ComponentWeighting that samples weights using a two-level strategy:
1. Trajectory-level base weights (sampled once at initialization).
2. Step-level perturbations (applied at each update).
The weighting combines:
1. XrayScaleWeighting, which normalizes the X-ray loss to a consistent scale.
2. RandomWeightingScheme, which performs the two-level random weight sampling.
- Parameters:
device (torch.device, optional) – Computation device. Defaults to the configured device.current.
default_log_weights (dict, optional) – Default log-space weights for each component.
trajectory_sigmas (dict, optional) – Standard deviations for trajectory-level sampling (default 0.8).
step_sigmas (dict, optional) – Standard deviations for step-level perturbations (default 0.1).
seed (int, optional) – Random seed for reproducibility.
resample_each_step (bool, optional) – If True, apply step perturbation at each compute_weights() call. If False, only apply perturbation when resample_weights() is called. Default is True.
initial_xray_loss (float, optional) – Initial X-ray loss for XrayScaleWeighting.
Examples
from torchref.refinement.weighting import RandomComponentWeighting

weighting = RandomComponentWeighting(device=device, seed=42)

# Base weights are sampled at init
base = weighting.get_base_log_weights()

# Each compute_weights call applies a small perturbation
weights = weighting.compute_weights(state)
current = weighting.get_sampled_log_weights()  # base + perturbation
- __init__(device=None, default_log_weights=None, trajectory_sigmas=None, step_sigmas=None, seed=None, resample_each_step=True, initial_xray_loss=None)
- property random_scheme: RandomWeightingScheme
Get the random weighting scheme.
- resample_trajectory()
Resample trajectory-level base weights for a new trajectory.
Call this to start a completely new trajectory with different base weights.
- Returns:
Dictionary of new base weights from the random scheme.
- Return type:
dict
- resample_weights()
Apply step perturbation.
- Returns:
Dictionary of newly perturbed weights.
- Return type:
dict
- compute_weights(state)
Compute weights from both component schemes (XrayScaleWeighting and RandomWeightingScheme).
If resample_each_step is True, a step perturbation is applied first. Returns the combined weights; a key produced by both schemes gets the product of the two values. Does not modify state; it only returns the computed weights.
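The "multiplicative for shared keys" rule can be sketched as a plain dict merge. This minimal sketch assumes each scheme yields a flat dict of float weights keyed by component name; the real method works on torchref's scheme objects and may return tensors. The function name combine_multiplicative is hypothetical.

```python
from typing import Dict


def combine_multiplicative(*weight_dicts: Dict[str, float]) -> Dict[str, float]:
    """Merge weight dicts; a key present in several dicts gets the product (hypothetical)."""
    combined: Dict[str, float] = {}
    for weights in weight_dicts:
        for key, value in weights.items():
            # Keys seen for the first time start from the multiplicative identity.
            combined[key] = combined.get(key, 1.0) * value
    return combined
```

Building a fresh dict rather than updating an input mirrors the documented guarantee that compute_weights() only returns weights and leaves its inputs untouched.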
- get_sampled_log_weights()
Get the current log-space weights (base + perturbation).
These are the “actions” taken by the random policy.
- Returns:
Dictionary of log-space weights.
- Return type:
dict
- get_base_log_weights()
Get the trajectory-level base log-space weights.
- Returns:
Dictionary of base log-space weights.
- Return type:
dict
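During data collection, the sampled log-weights serve as the random policy's actions. A minimal sketch of recording one trajectory follows, assuming three hypothetical callables: get_state (refinement-state features), get_action (for example, a wrapper that calls compute_weights(state) and then get_sampled_log_weights()), and refine_step (runs one refinement step with those weights and returns R-free). The names collect_trajectory, Features, LogWeights, and Record are likewise illustrative.

```python
from typing import Callable, Dict, List, Tuple

Features = Dict[str, float]
LogWeights = Dict[str, float]
Record = Tuple[Features, LogWeights, float]


def collect_trajectory(
    get_state: Callable[[], Features],
    get_action: Callable[[Features], LogWeights],
    refine_step: Callable[[LogWeights], float],
    n_steps: int,
) -> List[Record]:
    """Record (state, action, R-free) for each step of one refinement run (hypothetical)."""
    records: List[Record] = []
    for _ in range(n_steps):
        state = get_state()            # refinement-state features before the step
        action = get_action(state)     # sampled log-weights, i.e. the "action"
        rfree = refine_step(action)    # run one step with these weights, report R-free
        records.append((state, action, rfree))
    return records
```

Calling resample_trajectory() between runs, then collecting again, yields the diverse but internally coherent trajectories described at the top of the module.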