torchref.refinement.weighting.policy_weighting module
Policy-based component weighting for crystallographic refinement.
This module implements a weighting scheme that uses a trained neural network policy to predict component weights from the current refinement state.
Supports both:
- Training mode: sample from the policy distribution for exploration
- Evaluation mode: use mean predictions for deterministic refinement
Integrates with LossState for hierarchical weight management.
- class torchref.refinement.weighting.policy_weighting.PolicyComponentWeighting(device=device(type='cpu'), policy_path=None, sample=False, temperature=1.0, n_steps=10, verbose=0)[source]
Bases: BaseWeighting
Policy-based component weighting using a trained neural network.
This weighting scheme uses a policy network to predict component weights from the current refinement state. Supports both training (sampling) and evaluation (deterministic) modes.
Receives all data through LossState.meta and state._losses.
- Parameters:
device (torch.device, optional) – Computation device.
policy_path (str, optional) – Path to a trained policy checkpoint (.pt file). If None, the policy network is randomly initialized.
sample (bool, optional) – Whether to sample from policy distribution (default: False for eval)
temperature (float, optional) – Sampling temperature for exploration (default: 1.0).
- 0.0: deterministic (use mean predictions)
- > 0.0: sample from the distribution with scaled variance
n_steps (int, optional) – Total number of steps in refinement cycle (default: 10).
verbose (int, optional) – Verbosity level (default: 0).
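The interaction between sample and temperature can be illustrated with a minimal standard-library sketch. It is not the module's implementation: the helper name sample_log_weights, the diagonal-Gaussian form of the policy output, and the choice to multiply the standard deviation by the temperature are all assumptions made for illustration.

```python
import math
import random


def sample_log_weights(means, log_stds, sample=False, temperature=1.0, rng=None):
    """Hypothetical helper: choose log-weights from a diagonal Gaussian policy output.

    Assumption: the policy emits one mean and one log-std per loss component,
    and the temperature multiplies the standard deviation.
    """
    # Evaluation mode, or temperature 0.0: return the mean predictions unchanged.
    if not sample or temperature == 0.0:
        return list(means)
    rng = rng or random.Random()
    # Training mode: draw each log-weight around its mean with a scaled spread.
    return [
        rng.gauss(mu, temperature * math.exp(log_std))
        for mu, log_std in zip(means, log_stds)
    ]


# Illustrative placeholder values, not output of a real policy network.
means = [0.0, -1.0, 0.5]
log_stds = [-2.0, -2.0, -2.0]

eval_log_w = sample_log_weights(means, log_stds, sample=False)
train_log_w = sample_log_weights(
    means, log_stds, sample=True, temperature=1.0, rng=random.Random(0)
)

# Assumption: weights are obtained by exponentiating the log-weights.
eval_weights = [math.exp(v) for v in eval_log_w]
```

With sample=False the result is deterministic and repeatable; with sample=True a larger temperature widens the spread around the same means.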
- policy
Policy network (loaded or randomly initialized)
- Type:
nn.Module
- trajectory
Current trajectory being recorded (if recording enabled)
- Type:
Examples
# Evaluation mode (deterministic)
policy = PolicyComponentWeighting(device=device, policy_path='policy.pt', sample=False)

# Training mode (sampling for exploration)
policy = PolicyComponentWeighting(device=device, policy_path='policy.pt', sample=True, temperature=1.0)

# Use with LossState
state = refinement.create_loss_state()
state = policy.forward(state)
- __init__(device=device(type='cpu'), policy_path=None, sample=False, temperature=1.0, n_steps=10, verbose=0)[source]
- set_training_mode(sample=True, temperature=1.0)[source]
Enable training mode (sampling from policy distribution).
- reset_trajectory(pdb_id='', structure_path='', sf_path='')[source]
Reset for a new trajectory and optionally start recording.
- start_recording(pdb_id, structure_path, sf_path, seed=None, policy_version=None)[source]
Start recording a new trajectory.
- extract_state_features(state)[source]
Extract 31-dimensional state features from LossState.
- Parameters:
state (LossState) – LossState with meta and _losses populated.
- Returns:
features (torch.Tensor, shape (31,)) – State feature vector for policy input
step_state (StepState) – State object for trajectory recording
Features extracted (31 total)
- Progress metrics (2) (progress, improvement_rate)
- R-factors (4) (rwork, rfree, gap, delta_rfree)
- Static structure/data features (7) (resolution, inv_resolution, log_n_atoms, log_n_hkl, data_to_param_ratio, log_wilson_b, b_cv)
- X-ray losses (3) (work, test, work/test ratio)
- Geometry losses (7) (total, bond, angle, torsion, planarity, chiral, nonbonded)
- Geometry RMSD (2) (bond_rmsd, angle_rmsd)
- ADP losses (4) (total, simu, locality, KL)
- ADP stats (2) (mean_adp/wilson_b, adp_std/wilson_b)
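The group sizes listed above add up to 31. The following sketch checks that bookkeeping and derives index ranges per group. It assumes the features are laid out in the vector in the order the groups are listed, which the documentation does not state; FEATURE_GROUPS and feature_slices are hypothetical names, and the field labels are paraphrased from the list.

```python
# Feature groups as listed in the documentation; the ordering within the
# vector is an assumption made for this sketch.
FEATURE_GROUPS = [
    ("progress", ["progress", "improvement_rate"]),
    ("r_factors", ["rwork", "rfree", "gap", "delta_rfree"]),
    ("static", ["resolution", "inv_resolution", "log_n_atoms", "log_n_hkl",
                "data_to_param_ratio", "log_wilson_b", "b_cv"]),
    ("xray_losses", ["work", "test", "work_test_ratio"]),
    ("geometry_losses", ["total", "bond", "angle", "torsion",
                         "planarity", "chiral", "nonbonded"]),
    ("geometry_rmsd", ["bond_rmsd", "angle_rmsd"]),
    ("adp_losses", ["total", "simu", "locality", "kl"]),
    ("adp_stats", ["mean_adp_over_wilson_b", "adp_std_over_wilson_b"]),
]


def feature_slices(groups):
    """Return a {group_name: slice} mapping and the total feature count."""
    slices = {}
    start = 0
    for name, fields in groups:
        # Each group occupies a contiguous block right after the previous one.
        slices[name] = slice(start, start + len(fields))
        start += len(fields)
    return slices, start


SLICES, N_FEATURES = feature_slices(FEATURE_GROUPS)
```

Under that ordering assumption, SLICES["r_factors"] would select the four R-factor entries from a feature vector, which can help when inspecting recorded states.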
- Return type:
- compute_weights(state)[source]
Compute component weights from LossState.
Deprecated: call forward() or invoke the instance directly instead.
- apply_to_state(state)[source]
Apply policy-predicted weights to a LossState.
This is the main integration point with the LossState architecture. Call this after creating a LossState to apply policy weights.
- Parameters:
state (LossState) – Loss state to update with policy weights.
- Returns:
State with policy weights applied.
- Return type:
Example
state = refinement.create_loss_state()
state = policy_weighting.apply_to_state(state)
loss = state.aggregate()
- class torchref.refinement.weighting.policy_weighting.StepState(step, progress, improvement_rate, rwork, rfree, rfree_gap, delta_rfree, xray_loss, xray_loss_test, xray_work_test_ratio, geometry_loss, adp_loss, bond_rmsd, angle_rmsd, mean_adp_normalized, adp_std_normalized, component_losses=<factory>, component_weights=<factory>)[source]
Bases: object
State representation at a single refinement step for trajectory recording.
Matches the 31-feature format used in meta_weighting_4.
- __init__(step, progress, improvement_rate, rwork, rfree, rfree_gap, delta_rfree, xray_loss, xray_loss_test, xray_work_test_ratio, geometry_loss, adp_loss, bond_rmsd, angle_rmsd, mean_adp_normalized, adp_std_normalized, component_losses=<factory>, component_weights=<factory>)
- class torchref.refinement.weighting.policy_weighting.StepRecord(state, log_weights, weights, reward)[source]
Bases: object
Record for a single refinement step (state-action-reward tuple).
- __init__(state, log_weights, weights, reward)
- class torchref.refinement.weighting.policy_weighting.TrajectoryData(pdb_id, structure_path, sf_path, steps=<factory>, initial_rfree=0.0, final_rfree=0.0, initial_rwork=0.0, final_rwork=0.0, total_time=0.0, random_seed=None, policy_version=None, success=True, error_message=None)[source]
Bases: object
Complete trajectory data for training.
- steps: List[StepRecord]
- __init__(pdb_id, structure_path, sf_path, steps=<factory>, initial_rfree=0.0, final_rfree=0.0, initial_rwork=0.0, final_rwork=0.0, total_time=0.0, random_seed=None, policy_version=None, success=True, error_message=None)
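The <factory> defaults in the signatures above suggest that these record types are dataclasses whose list and dict fields use default_factory. The sketch below shows that pattern with simplified stand-ins. StepRecordSketch and TrajectorySketch are hypothetical classes with a reduced field set, not the real torchref types, and the values are placeholders.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class StepRecordSketch:
    """Simplified stand-in for a state-action-reward record."""
    state: dict
    log_weights: List[float]
    weights: List[float]
    reward: float


@dataclass
class TrajectorySketch:
    """Simplified stand-in for a trajectory container."""
    pdb_id: str
    structure_path: str
    sf_path: str
    # default_factory gives every instance its own list instead of a shared one.
    steps: List[StepRecordSketch] = field(default_factory=list)
    initial_rfree: float = 0.0
    final_rfree: float = 0.0
    success: bool = True
    error_message: Optional[str] = None


# Placeholder identifiers and numbers, for illustration only.
traj_a = TrajectorySketch("1abc", "model.pdb", "data.mtz")
traj_b = TrajectorySketch("2xyz", "other.pdb", "other.mtz")

traj_a.steps.append(
    StepRecordSketch(state={"rfree": 0.25}, log_weights=[0.0], weights=[1.0], reward=0.01)
)
```

Because each trajectory gets a fresh list, appending a step to traj_a leaves traj_b.steps empty, which is the behavior a recorder needs when several trajectories are collected in one process.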