torchref.refinement.weighting.random_weighting module

Random Weighting Scheme for Policy Training Data Collection.

This module implements a random weighting scheme that samples component weights from log-normal distributions around default values; that is, log-weights are drawn from normal distributions centered on the default log-weights. The sampled weights generate diverse training trajectories for initializing a policy network via AWR (Advantage-Weighted Regression).

The key insight is that by randomly sampling weights during refinement and recording the resulting R-free trajectories, we can train a policy to predict optimal weights based on refinement state.
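To make the AWR connection concrete, here is a minimal sketch of the exponentiated-advantage sample weights that AWR uses for its weighted regression step. It is not taken from torchref. The function name, the use of R-free improvement as the return, the mean-return baseline (AWR itself uses a learned value function), and the weight clipping are all assumptions made for illustration.

```python
import math


def awr_sample_weights(returns, beta=1.0, max_weight=20.0):
    """Exponentiated-advantage weights for a weighted regression step.

    `returns` holds one scalar per recorded sample (here assumed to be an
    R-free improvement). The baseline is the mean return, which is a
    simplifying assumption standing in for a learned value function.
    """
    baseline = sum(returns) / len(returns)
    # Clip so that a single high-advantage sample cannot dominate the fit.
    return [min(math.exp((r - baseline) / beta), max_weight) for r in returns]


# Hypothetical R-free improvements for four recorded samples.
improvements = [0.010, 0.002, -0.004, 0.020]
weights = awr_sample_weights(improvements, beta=0.01)
```

Samples whose sampled weights led to a better-than-average outcome receive a larger regression weight, so the fitted policy is pulled toward those actions.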

Weight Sampling Strategy

Two-level sampling for stable yet diverse trajectories:

  1. Trajectory-level base weights: Sampled once at initialization with larger sigma (trajectory_sigma). These define the “character” of the trajectory.

  2. Step-level perturbations: Small perturbations applied at each step with smaller sigma (step_sigma). These add local variation while keeping the trajectory coherent.

For each component at each step:

    base_log_weight   ~ N(default, trajectory_sigma)   [sampled once]
    step_perturbation ~ N(0, step_sigma)               [sampled each step]
    log_weight        = base_log_weight + step_perturbation
    weight            = exp(log_weight)
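The sampling rule above can be sketched with the standard library alone. This is an illustration under stated assumptions, not the module's implementation: the component names, the default log-weight values, and the helper names are hypothetical, while the sigmas follow the documented defaults (0.8 and 0.1).

```python
import math
import random

# Hypothetical component names and default log-weights, for illustration only.
DEFAULT_LOG_WEIGHTS = {"xray": 0.0, "geometry": 0.0, "adp": -1.0}
TRAJECTORY_SIGMA = 0.8  # documented default for trajectory-level sampling
STEP_SIGMA = 0.1        # documented default for step-level perturbations


def sample_base_log_weights(rng, defaults, sigma):
    """Draw trajectory-level base log-weights (once per trajectory)."""
    return {name: rng.gauss(mean, sigma) for name, mean in defaults.items()}


def perturb(rng, base_log_weights, sigma):
    """Add a fresh step-level perturbation to every base log-weight."""
    log_weights = {
        name: base + rng.gauss(0.0, sigma)
        for name, base in base_log_weights.items()
    }
    weights = {name: math.exp(lw) for name, lw in log_weights.items()}
    return log_weights, weights


rng = random.Random(42)
base = sample_base_log_weights(rng, DEFAULT_LOG_WEIGHTS, TRAJECTORY_SIGMA)
# Each step perturbs around the same base, so the trajectory stays coherent.
trajectory = [perturb(rng, base, STEP_SIGMA) for _ in range(5)]
```

Because each perturbation is added to the fixed base rather than to the previous step, the weights do not drift over a run.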

References

  • “Advantage-Weighted Regression” (Peng et al., 2019)

  • design_choices.md in meta_weighting_2/

class torchref.refinement.weighting.random_weighting.RandomWeightingScheme(device=None, default_log_weights=None, trajectory_sigmas=None, step_sigmas=None, seed=None)[source]

Bases: BaseWeighting

Weighting scheme with two-level random sampling for stable trajectories.

Sampling strategy:

  1. Base weights are sampled once at initialization (trajectory_sigma).

  2. Small perturbations are applied at each step (step_sigma).

This creates diverse trajectories while keeping them coherent within a single refinement run.

Parameters:
  • device (torch.device, optional) – Computation device. Defaults to the configured device.current.

  • default_log_weights (dict, optional) – Default log-space weights for each component (mean of distribution). Default is DEFAULT_LOG_WEIGHTS.

  • trajectory_sigmas (dict, optional) – Standard deviations for trajectory-level base weight sampling. Default is DEFAULT_TRAJECTORY_SIGMAS (0.8 for all).

  • step_sigmas (dict, optional) – Standard deviations for step-level perturbations. Default is DEFAULT_STEP_SIGMAS (0.1 for all).

  • seed (int, optional) – Random seed for reproducibility. Default is None.

base_log_weights

Trajectory-level base log-space weights (sampled once at init).

Type:

dict

current_log_weights

Current log-space weights (base plus step perturbation).

Type:

dict

current_weights

Current linear-space weights (exp of current_log_weights).

Type:

dict

sample_count

Number of times step perturbations have been applied.

Type:

int

name: str = 'random_weighting'
__init__(device=None, default_log_weights=None, trajectory_sigmas=None, step_sigmas=None, seed=None)[source]
apply_step_perturbation()[source]

Apply small perturbation to base weights for current step.

Returns:

Dictionary of perturbed weights.

Return type:

dict

resample_trajectory()[source]

Resample trajectory-level base weights.

Call this to start a completely new trajectory with different base weights.

Returns:

Dictionary of new base weights.

Return type:

dict

resample_weights()[source]

Apply step perturbation (for backwards compatibility).

This is equivalent to apply_step_perturbation().

Returns:

Dictionary of perturbed weights.

Return type:

dict

forward(state=None)[source]

Return current weights.

Parameters:

state (LossState, optional) – Current loss state (not used by this scheme, but required by interface).

Returns:

Dictionary mapping component names to weight values.

Return type:

dict

property sample_count: int

Get number of step perturbations applied.

get_log_weights()[source]

Get current log-space weights (base + perturbation).

Returns:

Dictionary of current log-space weights.

Return type:

dict

get_base_log_weights()[source]

Get trajectory-level base log-space weights.

Returns:

Dictionary of base log-space weights.

Return type:

dict

get_weights()[source]

Get current weights (linear space).

Returns:

Dictionary of current weights.

Return type:

dict

stats(state=None)[source]

Return statistics for reporting.

Parameters:

state (LossState, optional) – If provided, can pull data from LossState (not used by this scheme).

Returns:

Statistics dictionary with StatEntry objects.

Return type:

dict

class torchref.refinement.weighting.random_weighting.RandomComponentWeighting(device=None, default_log_weights=None, trajectory_sigmas=None, step_sigmas=None, seed=None, resample_each_step=True, initial_xray_loss=None)[source]

Bases: ComponentWeighting

Component weighting with two-level random sampling for data collection.

This is a drop-in replacement for ComponentWeighting that samples weights using a two-level strategy:

  1. Trajectory-level base weights (sampled once at init)

  2. Step-level perturbations (applied at each update)

The weighting combines:

  1. XrayScaleWeighting, which normalizes the X-ray loss to a consistent scale

  2. RandomWeightingScheme, which performs the two-level random weight sampling

Parameters:
  • device (torch.device, optional) – Computation device. Defaults to the configured device.current.

  • default_log_weights (dict, optional) – Default log-space weights for each component.

  • trajectory_sigmas (dict, optional) – Standard deviations for trajectory-level sampling (default 0.8).

  • step_sigmas (dict, optional) – Standard deviations for step-level perturbations (default 0.1).

  • seed (int, optional) – Random seed for reproducibility.

  • resample_each_step (bool, optional) – If True, apply step perturbation at each compute_weights() call. If False, only apply perturbation when resample_weights() is called. Default is True.

  • initial_xray_loss (float, optional) – Initial X-ray loss for XrayScaleWeighting.

Examples

import torch
from torchref.refinement.weighting import RandomComponentWeighting

device = torch.device("cpu")
weighting = RandomComponentWeighting(device=device, seed=42)
# Base weights are sampled once at init
base = weighting.get_base_log_weights()
# `state` is a LossState with meta and _losses populated (see compute_weights)
# With resample_each_step=True (the default), each call applies a small perturbation
weights = weighting.compute_weights(state)
current = weighting.get_sampled_log_weights()  # base + perturbation
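The effect of resample_each_step can be illustrated with a toy stand-in that does not depend on torchref. The class below is hypothetical and only mirrors the documented behavior: with the flag set, every compute_weights() call perturbs around the base weights; without it, weights change only when resample_weights() is called explicitly. Starting from the unperturbed base before the first perturbation is an assumption.

```python
import math
import random


class ToyRandomWeights:
    """Stand-in illustrating the resample_each_step flag; not torchref code."""

    def __init__(self, base_log_weights, step_sigma=0.1, seed=None,
                 resample_each_step=True):
        self.base = dict(base_log_weights)
        self.step_sigma = step_sigma
        self.resample_each_step = resample_each_step
        self.rng = random.Random(seed)
        # Assumption: before any perturbation, current weights equal the base.
        self.current = dict(self.base)
        self.sample_count = 0

    def resample_weights(self):
        """Perturb around the base weights, not around the previous step."""
        self.current = {
            name: value + self.rng.gauss(0.0, self.step_sigma)
            for name, value in self.base.items()
        }
        self.sample_count += 1
        return self.get_weights()

    def get_weights(self):
        """Return the current weights in linear space."""
        return {name: math.exp(value) for name, value in self.current.items()}

    def compute_weights(self):
        """Optionally perturb, then return the current weights."""
        if self.resample_each_step:
            self.resample_weights()
        return self.get_weights()


auto = ToyRandomWeights({"geometry": 0.0}, seed=0)
manual = ToyRandomWeights({"geometry": 0.0}, seed=0, resample_each_step=False)
auto_weights = [auto.compute_weights()["geometry"] for _ in range(3)]
manual_weights = [manual.compute_weights()["geometry"] for _ in range(3)]
```

Setting the flag to False is useful when the caller wants to hold weights fixed across several optimizer steps and perturb them only at chosen points.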
__init__(device=None, default_log_weights=None, trajectory_sigmas=None, step_sigmas=None, seed=None, resample_each_step=True, initial_xray_loss=None)[source]
property random_scheme: RandomWeightingScheme

Get the random weighting scheme.

resample_trajectory()[source]

Resample trajectory-level base weights for a new trajectory.

Call this to start a completely new trajectory with different base weights.

Returns:

Dictionary of new base weights from the random scheme.

Return type:

dict

resample_weights()[source]

Apply step perturbation.

Returns:

Dictionary of newly perturbed weights.

Return type:

dict

compute_weights(state)[source]

Compute weights from all schemes.

If resample_each_step is True, a step perturbation is applied first. Returns the combined weights, where a key shared by several schemes receives the product of their values. Does not modify state; it only returns the computed weights.

Parameters:

state (LossState) – State with meta and _losses populated.

Returns:

Dictionary of computed weights for each component.

Return type:

dict
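As a rough sketch of what "multiplicative for shared keys" means, the helper below merges several weight dictionaries and multiplies the values of any key that appears more than once. The function name and the example values are hypothetical; the actual combination logic lives in ComponentWeighting and may differ in detail.

```python
def combine_weights(*weight_dicts):
    """Merge weight dicts, multiplying values for keys present in several."""
    combined = {}
    for weights in weight_dicts:
        for name, value in weights.items():
            combined[name] = combined.get(name, 1.0) * value
    return combined


# Hypothetical outputs of the two schemes.
xray_scale = {"xray": 0.5}
random_sampled = {"xray": 2.0, "geometry": 1.5}
combined = combine_weights(xray_scale, random_sampled)
```

Here the "xray" component ends up with the product of both contributions, while "geometry" keeps the single value it received.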

get_sampled_log_weights()[source]

Get the current log-space weights (base + perturbation).

These are the “actions” taken by the random policy.

Returns:

Dictionary of log-space weights.

Return type:

dict

get_base_log_weights()[source]

Get the trajectory-level base log-space weights.

Returns:

Dictionary of base log-space weights.

Return type:

dict

get_sampled_weights()[source]

Get the weights from the random scheme (before xray scaling).

Returns:

Dictionary of sampled weights.

Return type:

dict

stats(state=None)[source]

Return statistics for reporting.

Parameters:

state (LossState, optional) – If provided, pull data from state.meta.

Returns:

Stats dictionary with StatEntry objects.

Return type:

dict