torchref.refinement.weighting.es_policy_training module
Evolution Strategy (ES) based policy training for weight prediction.
This approach:
1. Starts from a reasonably initialized policy
2. Spawns perturbations of the policy
3. Evaluates perturbations on structures
4. Scores by population-normalized final Rfree
5. Updates policy by advantage-weighted parameter averaging
Advantages over AWR:
- Directly optimizes final Rfree (not per-step delta)
- Population normalization controls for structure difficulty
- Simpler, more robust to noise in the signal
- class torchref.refinement.weighting.es_policy_training.ESConfig(population_size=20, sigma=0.1, learning_rate=0.01, n_structures_per_gen=10, n_steps=10, normalize_returns=True, antithetic=True)
Bases: object
Configuration for ES training.
- __init__(population_size=20, sigma=0.1, learning_rate=0.01, n_structures_per_gen=10, n_steps=10, normalize_returns=True, antithetic=True)
- class torchref.refinement.weighting.es_policy_training.ESResult(pdb_id, perturbation_idx, epsilon, final_rfree, initial_rfree, delta_rfree, success, error=None)
Bases: object
Results from one ES evaluation.
- __init__(pdb_id, perturbation_idx, epsilon, final_rfree, initial_rfree, delta_rfree, success, error=None)
- class torchref.refinement.weighting.es_policy_training.GenerationResult(generation, results, mean_final_rfree, best_final_rfree, gradient_norm, param_update_norm)
Bases: object
Results from one generation of ES.
- __init__(generation, results, mean_final_rfree, best_final_rfree, gradient_norm, param_update_norm)
- class torchref.refinement.weighting.es_policy_training.ESPolicyTrainer(policy, config=None, device=device(type='cpu'))
Bases: object
Evolution Strategy trainer for weight prediction policy.
Uses Natural Evolution Strategies (NES) with population-normalized fitness to optimize a policy network for weight prediction.
- Parameters:
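Example
    from torchref.refinement.weighting.es_policy_training import (
        ESConfig,
        ESPolicyTrainer,
        create_policy_network,
    )

    policy = create_policy_network()
    trainer = ESPolicyTrainer(policy, ESConfig())
    # Run one generation
    results = trainer.evaluate_generation(structures, run_trajectory_fn)
    trainer.update(results)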
- history: List[GenerationResult]
- sample_perturbations()
Sample perturbation vectors.
If antithetic=True, returns mirrored pairs for variance reduction.
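The mirrored-pair idea can be illustrated with a minimal, standard-library sketch. It is not the trainer's implementation: the function name `sample_antithetic`, the `n_params` and `seed` arguments, and the choice to scale the Gaussian noise by `sigma` at sampling time are all assumptions made for illustration.

```python
import random


def sample_antithetic(population_size, n_params, sigma, seed=None):
    """Return `population_size` perturbation vectors as (+eps, -eps) pairs.

    Hypothetical helper: assumes Gaussian noise with standard deviation
    `sigma`; the real method may return unscaled noise instead.
    """
    if population_size % 2 != 0:
        raise ValueError("antithetic sampling needs an even population size")
    rng = random.Random(seed)
    perturbations = []
    for _ in range(population_size // 2):
        eps = [rng.gauss(0.0, sigma) for _ in range(n_params)]
        perturbations.append(eps)
        perturbations.append([-e for e in eps])  # mirrored partner
    return perturbations


perturbations = sample_antithetic(population_size=20, n_params=4, sigma=0.1, seed=0)
```

Because every vector has an exact mirror image, the perturbations sum to zero in each parameter, which is the property that reduces the variance of the gradient estimate.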
- compute_advantages(results)
Compute advantages for each perturbation.
Normalizes within each structure to remove the effect of structure difficulty. Returns a dict mapping perturbation_idx to advantage.
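One plausible way to normalize within each structure is a per-structure z-score of final Rfree, sign-flipped so that lower Rfree gives a higher advantage, then averaged over structures. The sketch below assumes that scheme; the actual method may use a different normalization (for example rank-based). `EvalResult` and `normalized_advantages` are hypothetical stand-ins, with `EvalResult` carrying only the `ESResult` fields the sketch needs.

```python
from collections import defaultdict
from statistics import mean, pstdev
from typing import NamedTuple


class EvalResult(NamedTuple):
    """Hypothetical stand-in for ESResult with only the fields used here."""
    pdb_id: str
    perturbation_idx: int
    final_rfree: float
    success: bool = True


def normalized_advantages(results, eps=1e-8):
    """Z-score final Rfree within each structure, then average per perturbation."""
    by_structure = defaultdict(list)
    for r in results:
        if r.success:  # failed runs are skipped (an assumption of this sketch)
            by_structure[r.pdb_id].append(r)

    per_perturbation = defaultdict(list)
    for runs in by_structure.values():
        values = [r.final_rfree for r in runs]
        mu, sd = mean(values), pstdev(values)
        for r in runs:
            # Lower Rfree is better, so the sign is flipped.
            z = -(r.final_rfree - mu) / (sd + eps)
            per_perturbation[r.perturbation_idx].append(z)

    return {idx: mean(zs) for idx, zs in per_perturbation.items()}


results = [
    EvalResult("easy", 0, 0.20), EvalResult("easy", 1, 0.22),
    EvalResult("hard", 0, 0.30), EvalResult("hard", 1, 0.34),
]
advantages = normalized_advantages(results)
```

In this toy input, perturbation 0 beats perturbation 1 on both structures, so it receives the higher advantage even though the "hard" structure has worse absolute Rfree values.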
- compute_gradient(epsilons, advantages)
Compute ES gradient estimate.
gradient = 1 / (n * sigma^2) * sum_i(advantage_i * epsilon_i)
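The estimate can be written out in a few lines of plain Python. This sketch reads the formula as a prefactor of 1 / (n * sigma^2) applied to the advantage-weighted sum, which is the usual form when each `epsilon_i` is already scaled by `sigma`; that reading, the function name `es_gradient`, and the explicit `sigma` argument are assumptions.

```python
def es_gradient(epsilons, advantages, sigma):
    """Advantage-weighted sum of perturbations, scaled by 1 / (n * sigma^2)."""
    n = len(epsilons)
    n_params = len(epsilons[0])
    grad = [0.0] * n_params
    for eps, adv in zip(epsilons, advantages):
        for j, e in enumerate(eps):
            grad[j] += adv * e
    scale = 1.0 / (n * sigma ** 2)
    return [scale * g for g in grad]


# One antithetic pair along the first parameter; the +eps member scored better.
epsilons = [[0.1, 0.0], [-0.1, 0.0]]
advantages = [1.0, -1.0]
grad = es_gradient(epsilons, advantages, sigma=0.1)
# grad is approximately [10.0, 0.0]: it points along the better perturbation.
```

A parameter update would then presumably step along this direction scaled by `learning_rate`, since positive advantages mark perturbations that lowered Rfree.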
- evaluate_generation(structures, run_trajectory_fn)
Evaluate one generation of ES.
- Parameters:
structures (list) – List of structure dicts with ‘pdb’, ‘mtz’, ‘pdb_id’ keys
run_trajectory_fn (callable) – Function to run a trajectory with a given policy
- Returns:
Results including gradient update
- Return type:
- torchref.refinement.weighting.es_policy_training.create_policy_network(state_dim=31, hidden_dim=256)
Create the standard policy network architecture.
Uses 31-dimensional state vector matching meta_weighting_4 format:
- Progress (2): progress, improvement_rate
- R-factors (4): rwork, rfree, gap, delta_rfree
- Static features (7): resolution_min, inv_resolution, log_n_atoms, log_n_hkl, data_to_param_ratio, log_wilson_b, adp_cv
- X-ray losses (3): work, test, work/test ratio
- Geometry losses (7): total, bond, angle, torsion, planarity, chiral, nonbonded
- Geometry RMSD (2): bond_rmsd, angle_rmsd
- ADP losses (4): total, simu, locality, KL
- ADP stats (2): mean_adp/wilson_b, adp_std/wilson_b
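As a quick sanity check on the layout, the feature groups can be tabulated and turned into index slices. The sketch below assumes the groups are stored in the order listed, which the documentation does not state; the dictionary keys, the spelled-out ratio names, and the helper `group_slices` are hypothetical.

```python
# Feature groups in documentation order (index order is an assumption).
STATE_GROUPS = {
    "progress": ["progress", "improvement_rate"],
    "r_factors": ["rwork", "rfree", "gap", "delta_rfree"],
    "static": [
        "resolution_min", "inv_resolution", "log_n_atoms", "log_n_hkl",
        "data_to_param_ratio", "log_wilson_b", "adp_cv",
    ],
    "xray_losses": ["work", "test", "work_test_ratio"],
    "geometry_losses": [
        "total", "bond", "angle", "torsion", "planarity", "chiral", "nonbonded",
    ],
    "geometry_rmsd": ["bond_rmsd", "angle_rmsd"],
    "adp_losses": ["total", "simu", "locality", "kl"],
    "adp_stats": ["mean_adp_over_wilson_b", "adp_std_over_wilson_b"],
}


def group_slices(groups):
    """Map each group name to its slice of the flat state vector."""
    slices, start = {}, 0
    for name, features in groups.items():
        slices[name] = slice(start, start + len(features))
        start += len(features)
    return slices, start


slices, state_dim = group_slices(STATE_GROUPS)
```

The group sizes add up to the documented `state_dim=31`, so a state built this way would match the default input width of the policy network.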
- torchref.refinement.weighting.es_policy_training.initialize_policy_sensibly(state_dim=31, hidden_dim=256)
Initialize policy with sensible default weights.
Uses 31-dimensional state vector matching meta_weighting_4 format.
Based on correlation analysis:
- Higher xray weight is better
- Lower torsion/KL weights are better