torchref.refinement.weighting.es_policy_training module

Evolution Strategy (ES) based policy training for weight prediction.

This approach:

  1. Starts from a reasonably initialized policy.

  2. Spawns perturbations of the policy.

  3. Evaluates the perturbations on structures.

  4. Scores each perturbation by population-normalized final Rfree.

  5. Updates the policy by advantage-weighted parameter averaging.
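The five steps can be sketched end to end on a toy problem. This is a minimal NumPy sketch under stated assumptions, not the torchref implementation: the "policy" is a bare parameter vector, the hypothetical toy_rfree function stands in for running a refinement trajectory and reading off the final Rfree (lower is better), and the update is plain gradient ascent on the negated, per-structure normalized Rfree.

```python
import numpy as np

TARGET = np.array([1.0, -2.0, 0.5])  # hypothetical optimum of the toy objective


def toy_rfree(theta: np.ndarray, difficulty: float) -> float:
    # Hypothetical stand-in for "refine one structure, return final Rfree":
    # each structure has a baseline difficulty, the parameters add a penalty.
    return difficulty + 0.01 * float(np.sum((theta - TARGET) ** 2))


def es_generation(theta, difficulties, rng, pop=20, sigma=0.1, lr=0.01):
    # Step 2: antithetic perturbations eps_i ~ N(0, sigma^2 I), paired with -eps_i
    half = rng.normal(0.0, sigma, size=(pop // 2, theta.size))
    eps = np.concatenate([half, -half])
    # Step 3: evaluate every perturbed policy on every structure -> (n_structures, pop)
    scores = np.array([[toy_rfree(theta + e, d) for e in eps] for d in difficulties])
    # Step 4: z-score within each structure (row); negate so lower Rfree -> higher advantage
    z = (scores - scores.mean(axis=1, keepdims=True)) / (scores.std(axis=1, keepdims=True) + 1e-8)
    adv = -z.mean(axis=0)
    # Step 5: advantage-weighted average of the perturbations (NES gradient estimate)
    grad = adv @ eps / (len(eps) * sigma**2)
    return theta + lr * grad
```

Because each row is normalized before averaging, the per-structure baselines in difficulties cancel out, which is the purpose of step 4.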

Advantages over AWR:

  • Directly optimizes the final Rfree (not the per-step delta)

  • Population normalization controls for structure difficulty

  • Simpler and more robust to noise in the signal

class torchref.refinement.weighting.es_policy_training.ESConfig(population_size=20, sigma=0.1, learning_rate=0.01, n_structures_per_gen=10, n_steps=10, normalize_returns=True, antithetic=True)

Bases: object

Configuration for ES training.

population_size: int = 20
sigma: float = 0.1
learning_rate: float = 0.01
n_structures_per_gen: int = 10
n_steps: int = 10
normalize_returns: bool = True
antithetic: bool = True
__init__(population_size=20, sigma=0.1, learning_rate=0.01, n_structures_per_gen=10, n_steps=10, normalize_returns=True, antithetic=True)
class torchref.refinement.weighting.es_policy_training.ESResult(pdb_id, perturbation_idx, epsilon, final_rfree, initial_rfree, delta_rfree, success, error=None)

Bases: object

Results from one ES evaluation.

pdb_id: str
perturbation_idx: int
epsilon: ndarray
final_rfree: float
initial_rfree: float
delta_rfree: float
success: bool
error: str | None = None
__init__(pdb_id, perturbation_idx, epsilon, final_rfree, initial_rfree, delta_rfree, success, error=None)
class torchref.refinement.weighting.es_policy_training.GenerationResult(generation, results, mean_final_rfree, best_final_rfree, gradient_norm, param_update_norm)

Bases: object

Results from one generation of ES.

generation: int
results: List[ESResult]
mean_final_rfree: float
best_final_rfree: float
gradient_norm: float
param_update_norm: float
__init__(generation, results, mean_final_rfree, best_final_rfree, gradient_norm, param_update_norm)
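The summary fields of a generation can be read as statistics over its ESResult entries. The sketch below is hypothetical: ToyResult mirrors only the ESResult fields it needs, and it assumes failed runs are excluded and that "best" means the lowest final Rfree, since lower Rfree is better.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class ToyResult:
    # Hypothetical mirror of the ESResult fields used here
    final_rfree: float
    success: bool
    error: Optional[str] = None


def summarize(results: List[ToyResult]) -> Tuple[float, float]:
    # Assumption: only successful runs count; lower Rfree is better, so best = min
    ok = [r.final_rfree for r in results if r.success]
    mean_final = sum(ok) / len(ok)
    best_final = min(ok)
    return mean_final, best_final
```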
class torchref.refinement.weighting.es_policy_training.ESPolicyTrainer(policy, config=None, device=device(type='cpu'))

Bases: object

Evolution Strategy trainer for weight prediction policy.

Uses Natural Evolution Strategies (NES) with population-normalized fitness to optimize a policy network for weight prediction.

Parameters:
  • policy (nn.Module) – Policy network to optimize (outputs log-weights)

  • config (ESConfig) – Training configuration

  • device (str or torch.device) – Computation device (defaults to CPU)

Example

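from torchref.refinement.weighting.es_policy_training import (
    ESConfig, ESPolicyTrainer, create_policy_network,
)

policy = create_policy_network()
trainer = ESPolicyTrainer(policy, ESConfig())

# Run one generation; `structures` and `run_trajectory_fn` are supplied by the caller
results = trainer.evaluate_generation(structures, run_trajectory_fn)
trainer.update(results)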
__init__(policy, config=None, device=device(type='cpu'))
history: List[GenerationResult]
get_flat_params()

Get flattened policy parameters.

set_flat_params(flat_params)

Set policy parameters from flattened array.
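Flattening and restoring parameters is what lets ES treat a network as a single vector. A minimal sketch using PyTorch's own helpers, written as free functions for illustration; the ESResult.epsilon field being an ndarray suggests the real methods exchange NumPy arrays, whereas this sketch keeps tensors for brevity.

```python
import torch
from torch import nn
from torch.nn.utils import parameters_to_vector, vector_to_parameters


def get_flat_params(policy: nn.Module) -> torch.Tensor:
    # Concatenate every parameter tensor, in module order, into one detached 1-D vector
    return parameters_to_vector(policy.parameters()).detach()


def set_flat_params(policy: nn.Module, flat: torch.Tensor) -> None:
    # Slice `flat` back into the parameters, in the same order as get_flat_params
    with torch.no_grad():
        vector_to_parameters(flat, policy.parameters())
```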

sample_perturbations()

Sample perturbation vectors.

If antithetic=True, returns mirrored pairs for variance reduction.
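Antithetic sampling draws half of the population and mirrors it. The sketch below assumes Gaussian perturbations with standard deviation sigma and that population_size counts individual perturbations (so it must be even when antithetic); both are assumptions, not documented behaviour.

```python
import numpy as np


def sample_perturbations(n_params, population_size, sigma, antithetic, rng):
    """Return a (population_size, n_params) array of perturbation vectors."""
    if antithetic:
        if population_size % 2:
            raise ValueError("antithetic sampling needs an even population_size")
        half = rng.normal(0.0, sigma, size=(population_size // 2, n_params))
        return np.concatenate([half, -half])  # mirrored pairs (eps, -eps)
    return rng.normal(0.0, sigma, size=(population_size, n_params))
```

With mirrored pairs, each pair contributes (F(θ+ε) - F(θ-ε))·ε to the gradient sum, so any part of the fitness that is symmetric in ε, such as a constant baseline, cancels exactly.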

create_perturbed_policy(epsilon)

Create a copy of policy with perturbed parameters.
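A perturbed copy can be built with a deep copy plus an in-place parameter shift. In this hedged sketch the function is a free-standing illustration, and epsilon is assumed to be already scaled by sigma (consistent with the 1/(n·sigma²) factor in compute_gradient), so it is added directly.

```python
import copy

import torch
from torch import nn
from torch.nn.utils import parameters_to_vector, vector_to_parameters


def create_perturbed_policy(policy: nn.Module, epsilon: torch.Tensor) -> nn.Module:
    # Deep-copy so the base policy's parameters are left untouched
    perturbed = copy.deepcopy(policy)
    with torch.no_grad():
        flat = parameters_to_vector(perturbed.parameters())
        # Assumption: epsilon already includes the sigma scaling
        vector_to_parameters(flat + epsilon, perturbed.parameters())
    return perturbed
```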

compute_advantages(results)

Compute advantages for each perturbation.

Normalizes final Rfree within each structure to remove the effect of structure difficulty. Returns a dict {perturbation_idx: advantage}.
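Per-structure normalization can be sketched as a z-score of final Rfree within each pdb_id, negated so that lower Rfree yields a higher advantage, then averaged over structures. This is a hypothetical sketch: the tuple input format, the exclusion of failed runs, and the averaging over structures are assumptions.

```python
from collections import defaultdict
from statistics import mean, pstdev


def compute_advantages(results):
    """results: iterable of (pdb_id, perturbation_idx, final_rfree, success) tuples."""
    by_structure = defaultdict(dict)
    for pdb_id, idx, rfree, success in results:
        if success:  # assumption: failed runs are skipped
            by_structure[pdb_id][idx] = rfree
    per_idx = defaultdict(list)
    for scores in by_structure.values():
        vals = list(scores.values())
        mu, sd = mean(vals), pstdev(vals)
        for idx, rfree in scores.items():
            # Negate: lower Rfree than this structure's mean -> positive advantage
            per_idx[idx].append(-(rfree - mu) / (sd + 1e-8))
    return {idx: mean(v) for idx, v in per_idx.items()}
```

In the check below, structure "2xyz" is harder (higher Rfree overall), yet both structures vote identically for perturbation 1.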

compute_gradient(epsilons, advantages)

Compute ES gradient estimate.

gradient = 1/(n * sigma^2) * sum_i(advantage_i * epsilon_i), where n is the number of perturbations
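The estimate translates directly into one matrix-vector product. A minimal sketch, assuming each epsilon_i is already scaled by sigma (epsilon_i ~ N(0, sigma² I)), which is what makes the 1/(n·sigma²) factor correct; sigma is passed explicitly here as an assumption about where it comes from.

```python
import numpy as np


def compute_gradient(epsilons: np.ndarray, advantages: np.ndarray, sigma: float) -> np.ndarray:
    # epsilons: (n, n_params), assumed drawn from N(0, sigma^2 I)
    # advantages: (n,), higher is better
    n = len(advantages)
    return advantages @ epsilons / (n * sigma**2)
```

Fed raw fitness values, this is an unbiased estimate of the gradient of the Gaussian-smoothed objective; the check uses a linear fitness w·ε, for which the estimate should recover w. With population-normalized advantages the direction information is kept but the overall scale changes.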

update(generation_result)

Apply gradient update to policy.
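A plain gradient-ascent step, together with the two norms that GenerationResult records, might look like the following. This is an assumption: the trainer could use a different optimizer, and the function name apply_update is hypothetical.

```python
import numpy as np


def apply_update(theta: np.ndarray, gradient: np.ndarray, learning_rate: float):
    # Assumption: vanilla gradient ascent on the advantage-weighted objective
    step = learning_rate * gradient
    gradient_norm = float(np.linalg.norm(gradient))
    param_update_norm = float(np.linalg.norm(step))
    return theta + step, gradient_norm, param_update_norm
```

Under this plain-SGD assumption, param_update_norm is simply learning_rate × gradient_norm.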

evaluate_generation(structures, run_trajectory_fn)

Evaluate one generation of ES.

Parameters:
  • structures (list) – List of structure dicts with ‘pdb’, ‘mtz’, ‘pdb_id’ keys

  • run_trajectory_fn (callable) – Function to run a trajectory with a given policy

Returns:

Results for this generation, including the gradient-update statistics (gradient_norm, param_update_norm)

Return type:

GenerationResult
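The evaluation loop runs every perturbed policy on every structure and records failures instead of aborting. In this hypothetical sketch, Outcome mirrors a subset of the ESResult fields, and run_trajectory_fn(policy, structure) is assumed to return the final Rfree; the real callback signature is not documented here.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class Outcome:
    # Hypothetical subset of the ESResult fields
    pdb_id: str
    perturbation_idx: int
    final_rfree: float
    success: bool
    error: Optional[str] = None


def evaluate(structures: List[dict], policies: list, run_trajectory_fn: Callable) -> List[Outcome]:
    out = []
    for s in structures:
        for idx, policy in enumerate(policies):
            try:
                # Assumed callback contract: returns the trajectory's final Rfree
                rfree = float(run_trajectory_fn(policy, s))
                out.append(Outcome(s["pdb_id"], idx, rfree, True))
            except Exception as exc:
                # Keep going; mark this run as failed and keep the message
                out.append(Outcome(s["pdb_id"], idx, float("nan"), False, str(exc)))
    return out
```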

save_checkpoint(path, state_dim=31, hidden_dim=256)

Save training checkpoint.

load_checkpoint(path)

Load training checkpoint.
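The state_dim and hidden_dim arguments of save_checkpoint suggest that the architecture sizes are stored next to the weights so that a loader can rebuild the network. A hedged sketch of that pattern with torch.save and torch.load; make_policy, its layer layout, and n_out are hypothetical stand-ins, not the torchref architecture.

```python
import io

import torch
from torch import nn


def make_policy(state_dim: int, hidden_dim: int, n_out: int = 4) -> nn.Module:
    # Hypothetical stand-in architecture; n_out is an assumption
    return nn.Sequential(nn.Linear(state_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, n_out))


def save_checkpoint(f, policy, generation, state_dim=31, hidden_dim=256):
    # Store the sizes alongside the weights so load_checkpoint can rebuild the network
    torch.save({"state_dict": policy.state_dict(), "generation": generation,
                "state_dim": state_dim, "hidden_dim": hidden_dim}, f)


def load_checkpoint(f):
    ckpt = torch.load(f)
    policy = make_policy(ckpt["state_dim"], ckpt["hidden_dim"])
    policy.load_state_dict(ckpt["state_dict"])
    return policy, ckpt["generation"]
```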

torchref.refinement.weighting.es_policy_training.create_policy_network(state_dim=31, hidden_dim=256)

Create the standard policy network architecture.

Uses a 31-dimensional state vector matching the meta_weighting_4 format:

  • Progress (2): progress, improvement_rate

  • R-factors (4): rwork, rfree, gap, delta_rfree

  • Static features (7): resolution_min, inv_resolution, log_n_atoms, log_n_hkl, data_to_param_ratio, log_wilson_b, adp_cv

  • X-ray losses (3): work, test, work/test ratio

  • Geometry losses (7): total, bond, angle, torsion, planarity, chiral, nonbonded

  • Geometry RMSD (2): bond_rmsd, angle_rmsd

  • ADP losses (4): total, simu, locality, KL

  • ADP stats (2): mean_adp/wilson_b, adp_std/wilson_b
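The group sizes above add up to 31, and a state vector can be assembled from named features group by group. In this sketch the flat ordering follows the list order, and the individual feature keys (for example xray_work or geom_total) are hypothetical labels chosen to keep names unique; defaulting missing features to 0.0 is also an assumption.

```python
import numpy as np

# Group sizes follow the documented list; the key names and ordering are assumptions
FEATURE_GROUPS = {
    "progress": ["progress", "improvement_rate"],
    "rfactors": ["rwork", "rfree", "gap", "delta_rfree"],
    "static": ["resolution_min", "inv_resolution", "log_n_atoms", "log_n_hkl",
               "data_to_param_ratio", "log_wilson_b", "adp_cv"],
    "xray_losses": ["xray_work", "xray_test", "xray_work_test_ratio"],
    "geometry_losses": ["geom_total", "bond", "angle", "torsion",
                        "planarity", "chiral", "nonbonded"],
    "geometry_rmsd": ["bond_rmsd", "angle_rmsd"],
    "adp_losses": ["adp_total", "simu", "locality", "kl"],
    "adp_stats": ["mean_adp_over_wilson_b", "adp_std_over_wilson_b"],
}
FEATURE_NAMES = [name for names in FEATURE_GROUPS.values() for name in names]


def build_state(features: dict) -> np.ndarray:
    # Missing features default to 0.0 (assumption)
    return np.array([features.get(n, 0.0) for n in FEATURE_NAMES], dtype=np.float32)
```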

torchref.refinement.weighting.es_policy_training.initialize_policy_sensibly(state_dim=31, hidden_dim=256)

Initialize a policy with sensible default weights.

Uses the 31-dimensional state vector in the meta_weighting_4 format.

Based on correlation analysis:

  • A higher X-ray weight is better.

  • Lower torsion and KL weights are better.
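One way to encode such correlations is to bias the output layer so that the initial log-weights already favour a higher X-ray weight and lower torsion and KL weights. This is a hypothetical sketch: the output order in WEIGHT_NAMES, the bias values, and the two-layer architecture are assumptions, not the torchref network.

```python
import torch
from torch import nn

# Hypothetical output order of the log-weight head
WEIGHT_NAMES = ["xray", "bond", "angle", "torsion", "kl"]
# Hypothetical starting log-weights: raise xray, lower torsion and KL
INITIAL_LOG_WEIGHTS = {"xray": 0.5, "torsion": -0.5, "kl": -0.5}


def initialize_policy(state_dim: int = 31, hidden_dim: int = 256) -> nn.Sequential:
    policy = nn.Sequential(nn.Linear(state_dim, hidden_dim), nn.ReLU(),
                           nn.Linear(hidden_dim, len(WEIGHT_NAMES)))
    head = policy[-1]
    with torch.no_grad():
        # Shrink the output weights so the bias dominates early in training
        head.weight.mul_(0.01)
        head.bias.copy_(torch.tensor([INITIAL_LOG_WEIGHTS.get(n, 0.0) for n in WEIGHT_NAMES]))
    return policy
```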