"""
Evolution Strategy (ES) based policy training for weight prediction.
This approach:
1. Starts from a reasonably initialized policy
2. Spawns perturbations of the policy
3. Evaluates perturbations on structures
4. Scores by population-normalized final Rfree
5. Updates policy by advantage-weighted parameter averaging
Advantages over AWR:
- Directly optimizes final Rfree (not per-step delta)
- Population normalization controls for structure difficulty
- Simpler, more robust to noise in the signal
"""
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional
import numpy as np
import torch
import torch.nn as nn
from torchref.config import get_default_device
[docs]
@dataclass
class ESConfig:
    """Hyper-parameters for evolution-strategy (ES) training.

    Attributes
    ----------
    population_size : int
        Number of perturbations evaluated per structure.
    sigma : float
        Standard deviation of the parameter perturbations.
    learning_rate : float
        Step size applied to the estimated gradient.
    n_structures_per_gen : int
        Number of structures evaluated per generation.
    n_steps : int
        Refinement steps per trajectory.
    normalize_returns : bool
        If True, final Rfree values are z-scored within each structure.
    antithetic : bool
        If True, perturbations are sampled as mirrored (+eps, -eps) pairs.
    """

    population_size: int = 20
    sigma: float = 0.1
    learning_rate: float = 0.01
    n_structures_per_gen: int = 10
    n_steps: int = 10
    normalize_returns: bool = True
    antithetic: bool = True
[docs]
@dataclass
class ESResult:
    """Outcome of evaluating one perturbation on one structure.

    Attributes
    ----------
    pdb_id : str
        Identifier of the structure that was evaluated.
    perturbation_idx : int
        Index of the perturbation within the sampled population.
    epsilon : np.ndarray
        The noise vector that was added to the policy parameters.
    final_rfree : float
        Rfree at the end of the trajectory.
    initial_rfree : float
        Rfree at the start of the trajectory.
    delta_rfree : float
        Change in Rfree over the trajectory.
    success : bool
        Whether the trajectory completed successfully.
    error : str, optional
        Error message for a failed trajectory, otherwise None.
    """

    pdb_id: str
    perturbation_idx: int
    epsilon: np.ndarray
    final_rfree: float
    initial_rfree: float
    delta_rfree: float
    success: bool
    error: Optional[str] = None
[docs]
@dataclass
class GenerationResult:
    """Summary of one ES generation.

    Attributes
    ----------
    generation : int
        Index of the generation this summary belongs to.
    results : list of ESResult
        Every per-structure, per-perturbation evaluation of the generation.
    mean_final_rfree : float
        Mean final Rfree over the evaluations counted in the statistics.
    best_final_rfree : float
        Lowest final Rfree over the evaluations counted in the statistics.
    gradient_norm : float
        Euclidean norm of the estimated gradient.
    param_update_norm : float
        Euclidean norm of the parameter step that was applied.
    """

    generation: int
    results: List[ESResult]
    mean_final_rfree: float
    best_final_rfree: float
    gradient_norm: float
    param_update_norm: float
[docs]
class ESPolicyTrainer:
    """
    Evolution Strategy trainer for weight prediction policy.

    Uses Natural Evolution Strategies (NES) with population-normalized
    fitness to optimize a policy network for weight prediction.

    Parameters
    ----------
    policy : nn.Module
        Policy network to optimize (outputs log-weights)
    config : ESConfig, optional
        Training configuration. A default ``ESConfig()`` is used when omitted.
    device : str or torch.device, optional
        Computation device. When omitted, ``get_default_device()`` is queried
        at construction time.

    Example
    -------
    ::

        policy = create_policy_network()
        trainer = ESPolicyTrainer(policy, ESConfig())
        # Run one generation. The parameter update is applied and the
        # generation is recorded inside evaluate_generation(), so update()
        # must not be called again afterwards.
        gen_result = trainer.evaluate_generation(structures, run_trajectory_fn)
    """

    def __init__(
        self,
        policy: nn.Module,
        config: Optional["ESConfig"] = None,
        device=None,
    ):
        # Resolve the default device per instance instead of once at import
        # time, so a default that changes later is honoured.
        if device is None:
            device = get_default_device()
        self.policy = policy
        self.config = config if config is not None else ESConfig()
        self.device = torch.device(device)
        self.policy.to(self.device)
        # Total number of scalar parameters in the policy
        self.n_params = sum(p.numel() for p in policy.parameters())
        # Per-parameter shapes, in parameters() order
        self._param_shapes = [p.shape for p in policy.parameters()]
        # Training history
        self.generation = 0
        self.history: List["GenerationResult"] = []

    @staticmethod
    def _assign_flat_params(module: nn.Module, flat_params: np.ndarray) -> None:
        """Write a flat parameter vector into ``module`` in parameters() order.

        Raises
        ------
        ValueError
            If the vector length differs from the module's parameter count.
        """
        flat = np.asarray(flat_params).reshape(-1)
        expected = sum(p.numel() for p in module.parameters())
        if flat.size != expected:
            raise ValueError(
                f"Expected {expected} parameter values, got {flat.size}"
            )
        offset = 0
        for p in module.parameters():
            # numel() is always a Python int; np.prod(shape) yields the float
            # 1.0 for 0-dim parameters, which cannot be used as a slice bound.
            size = p.numel()
            chunk = flat[offset : offset + size].reshape(tuple(p.shape))
            p.data.copy_(torch.tensor(chunk, dtype=p.dtype, device=p.device))
            offset += size

    def get_flat_params(self) -> np.ndarray:
        """Get flattened policy parameters as a single 1-D NumPy array."""
        pieces = [
            p.detach().cpu().numpy().reshape(-1) for p in self.policy.parameters()
        ]
        if not pieces:
            # np.concatenate rejects an empty list
            return np.zeros(0)
        return np.concatenate(pieces)

    def set_flat_params(self, flat_params: np.ndarray):
        """Set policy parameters from a flattened array.

        Raises
        ------
        ValueError
            If ``flat_params`` does not hold exactly ``n_params`` values.
        """
        self._assign_flat_params(self.policy, flat_params)

    def sample_perturbations(self) -> List[np.ndarray]:
        """
        Sample ``population_size`` perturbation vectors scaled by ``sigma``.

        If antithetic=True, the list is laid out as
        ``[e_0, ..., e_{k-1}, -e_0, ..., -e_{k-1}]`` with ``k = n // 2``.
        For an odd population size one extra unpaired sample is appended so
        that exactly ``population_size`` vectors are always returned.
        """
        n = self.config.population_size
        sigma = self.config.sigma

        def draw() -> np.ndarray:
            return np.random.randn(self.n_params) * sigma

        if not self.config.antithetic:
            return [draw() for _ in range(n)]

        half = [draw() for _ in range(n // 2)]
        epsilons = half + [-e for e in half]
        if n % 2:
            epsilons.append(draw())
        return epsilons

    def create_perturbed_policy(self, epsilon: np.ndarray) -> nn.Module:
        """Create a deep copy of the policy with ``epsilon`` added to its parameters."""
        import copy

        perturbed = copy.deepcopy(self.policy)
        self._assign_flat_params(perturbed, self.get_flat_params() + epsilon)
        return perturbed

    def compute_advantages(self, results: List["ESResult"]) -> Dict[int, float]:
        """
        Compute advantages for each perturbation.

        Scores are computed within each structure (z-scored when
        ``normalize_returns`` is set) to remove the structure difficulty
        effect, negated so that lower Rfree means higher advantage, and then
        averaged over all structures on which the perturbation produced a
        usable result. Failed runs and non-finite Rfree values are ignored.

        Returns {perturbation_idx: advantage}
        """
        # Group usable results by structure
        by_structure: Dict[str, list] = {}
        for r in results:
            if r.success and np.isfinite(r.final_rfree):
                by_structure.setdefault(r.pdb_id, []).append(r)

        totals: Dict[int, float] = {}
        counts: Dict[int, int] = {}
        for struct_results in by_structure.values():
            rfrees = np.array([r.final_rfree for r in struct_results], dtype=float)
            if self.config.normalize_returns:
                spread = rfrees.std()
                if spread > 1e-8:
                    scores = -(rfrees - rfrees.mean()) / spread
                else:
                    # No spread within this structure: it carries no signal
                    scores = np.zeros_like(rfrees)
            else:
                scores = -rfrees
            # Accumulate rather than overwrite, so every structure
            # contributes to the advantage of a perturbation.
            for r, score in zip(struct_results, scores):
                idx = r.perturbation_idx
                totals[idx] = totals.get(idx, 0.0) + float(score)
                counts[idx] = counts.get(idx, 0) + 1

        return {idx: totals[idx] / counts[idx] for idx in totals}

    def compute_gradient(
        self, epsilons: List[np.ndarray], advantages: Dict[int, float]
    ) -> np.ndarray:
        """
        Compute ES gradient estimate.

        gradient = 1 / (n * sigma^2) * sum_i(advantage_i * epsilon_i)

        where ``n`` is the number of perturbations. Perturbations without an
        advantage entry contribute nothing to the sum. A zero gradient is
        returned when the population is empty or ``sigma`` is zero.
        """
        n = len(epsilons)
        sigma = self.config.sigma
        gradient = np.zeros(self.n_params)
        if n == 0 or sigma == 0:
            # Nothing was explored; avoid a 0/0 that would turn into NaN
            return gradient
        for idx, eps in enumerate(epsilons):
            if idx in advantages:
                gradient += advantages[idx] * eps
        # Normalize by population size and sigma squared
        gradient /= n * sigma**2
        return gradient

    def update(self, generation_result: "GenerationResult"):
        """Record a finished generation.

        The parameter step itself is applied inside ``evaluate_generation``;
        this method only advances the generation counter and appends the
        summary to ``history``.
        """
        self.generation += 1
        self.history.append(generation_result)

    def _evaluate_perturbation(
        self, struct, pdb_id, idx, epsilon, run_trajectory_fn
    ) -> "ESResult":
        """Run one trajectory with a perturbed policy and wrap the outcome."""
        perturbed_policy = self.create_perturbed_policy(epsilon)
        try:
            final_rfree, initial_rfree, success, error = run_trajectory_fn(
                perturbed_policy,
                struct["pdb"],
                struct["mtz"],
                self.config.n_steps,
            )
            if success:
                final = final_rfree
                delta = final_rfree - initial_rfree
            else:
                final = float("inf")
                delta = 0
        except Exception as exc:
            # Best effort: one failing trajectory must not abort the whole
            # generation, so it is recorded as an unsuccessful result.
            return ESResult(
                pdb_id=pdb_id,
                perturbation_idx=idx,
                epsilon=epsilon,
                final_rfree=float("inf"),
                initial_rfree=0,
                delta_rfree=0,
                success=False,
                error=str(exc),
            )
        return ESResult(
            pdb_id=pdb_id,
            perturbation_idx=idx,
            epsilon=epsilon,
            final_rfree=final,
            initial_rfree=initial_rfree,
            delta_rfree=delta,
            success=success,
            error=error,
        )

    def evaluate_generation(
        self,
        structures: List[
            Dict[str, str]
        ],  # [{'pdb': path, 'mtz': path, 'pdb_id': id}, ...]
        run_trajectory_fn,  # callable(policy, pdb, mtz, n_steps) -> (final_rfree, initial_rfree, success, error)
    ) -> "GenerationResult":
        """
        Evaluate one generation of ES and apply the resulting update.

        The same set of perturbations is evaluated on every structure. The
        policy parameters are then moved by ``learning_rate * gradient`` and
        the generation is recorded via ``update``.

        Parameters
        ----------
        structures : list
            List of structure dicts with 'pdb', 'mtz' and optionally 'pdb_id'
            keys (the stem of the 'pdb' path is used when 'pdb_id' is absent)
        run_trajectory_fn : callable
            Called as ``run_trajectory_fn(policy, pdb, mtz, n_steps)`` and
            expected to return
            ``(final_rfree, initial_rfree, success, error)``

        Returns
        -------
        GenerationResult
            Per-run results plus summary statistics. Mean and best Rfree
            cover successful runs with a finite final Rfree and are ``inf``
            when there are none.
        """
        epsilons = self.sample_perturbations()
        results = []
        for struct in structures:
            pdb_id = struct.get("pdb_id", Path(struct["pdb"]).stem)
            for idx, epsilon in enumerate(epsilons):
                results.append(
                    self._evaluate_perturbation(
                        struct, pdb_id, idx, epsilon, run_trajectory_fn
                    )
                )

        # Compute advantages and gradient
        advantages = self.compute_advantages(results)
        gradient = self.compute_gradient(epsilons, advantages)

        # Update parameters
        step = self.config.learning_rate * gradient
        self.set_flat_params(self.get_flat_params() + step)

        # Compute statistics over the runs that also feed the gradient
        usable = [
            r.final_rfree
            for r in results
            if r.success and np.isfinite(r.final_rfree)
        ]
        mean_rfree = float(np.mean(usable)) if usable else float("inf")
        best_rfree = float(min(usable)) if usable else float("inf")

        gen_result = GenerationResult(
            generation=self.generation,
            results=results,
            mean_final_rfree=mean_rfree,
            best_final_rfree=best_rfree,
            gradient_norm=float(np.linalg.norm(gradient)),
            param_update_norm=float(np.linalg.norm(step)),
        )
        self.update(gen_result)
        return gen_result

    def save_checkpoint(self, path: str, state_dim: int = 31, hidden_dim: int = 256):
        """Save training checkpoint.

        Stores the policy state dict, a copy of the config fields, the
        generation counter, the parameter count, the given ``state_dim`` and
        ``hidden_dim``, and the scalar summaries of every recorded
        generation (per-run results are not stored).
        """
        checkpoint = {
            "policy_state_dict": self.policy.state_dict(),
            "config": dict(vars(self.config)),
            "generation": self.generation,
            "n_params": self.n_params,
            "state_dim": state_dim,
            "hidden_dim": hidden_dim,
            "history": [
                {
                    "generation": h.generation,
                    "mean_final_rfree": float(h.mean_final_rfree),
                    "best_final_rfree": float(h.best_final_rfree),
                    "gradient_norm": float(h.gradient_norm),
                    "param_update_norm": float(h.param_update_norm),
                }
                for h in self.history
            ],
        }
        torch.save(checkpoint, path)

    def load_checkpoint(self, path: str):
        """Load training checkpoint.

        Restores the policy weights and the generation counter. The stored
        config and history are informational and are not restored.
        """
        # SECURITY: weights_only=False performs full unpickling, which can
        # execute arbitrary code. Only load checkpoints from trusted sources.
        # It is kept because checkpoints written by earlier versions may
        # contain NumPy scalars in their history entries.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        self.policy.load_state_dict(checkpoint["policy_state_dict"])
        self.generation = checkpoint.get("generation", 0)
[docs]
def create_policy_network(
    state_dim: int = 31, hidden_dim: int = 256, n_weights: int = 10
) -> nn.Module:
    """Create the standard policy network architecture.

    The default 31-dimensional state vector matches the meta_weighting_4 format:
    - Progress (2): progress, improvement_rate
    - R-factors (4): rwork, rfree, gap, delta_rfree
    - Static features (7): resolution_min, inv_resolution, log_n_atoms,
    log_n_hkl, data_to_param_ratio, log_wilson_b, adp_cv
    - X-ray losses (3): work, test, work/test ratio
    - Geometry losses (7): total, bond, angle, torsion, planarity, chiral, nonbonded
    - Geometry RMSD (2): bond_rmsd, angle_rmsd
    - ADP losses (4): total, simu, locality, KL
    - ADP stats (2): mean_adp/wilson_b, adp_std/wilson_b

    Parameters
    ----------
    state_dim : int
        Size of the input state vector.
    hidden_dim : int
        Width of the single hidden layer.
    n_weights : int
        Number of log-weights predicted. The output layer has
        ``2 * n_weights`` units: log-weights followed by log-sigmas.

    Returns
    -------
    nn.Module
        Network whose ``forward(x)`` returns ``(log_weights, log_sigmas)``,
        each with ``n_weights`` entries on the last axis, bounded by tanh to
        [-3, 3] and [-2, 2] respectively. ``x`` may carry any number of
        leading batch dimensions, including none.
    """

    class PolicyNetwork(nn.Module):
        def __init__(self):
            super().__init__()
            self.input_norm = nn.LayerNorm(state_dim)
            self.fc1 = nn.Linear(state_dim, hidden_dim)
            self.ln1 = nn.LayerNorm(hidden_dim)
            self.relu = nn.ReLU()
            # n_weights log-weights followed by n_weights log-sigmas
            self.fc_out = nn.Linear(hidden_dim, 2 * n_weights)
            self.tanh = nn.Tanh()
            self.log_weight_scale = 3.0
            self.log_sigma_scale = 2.0

        def forward(self, x):
            x = self.input_norm(x)
            x = self.relu(self.ln1(self.fc1(x)))
            x = self.tanh(self.fc_out(x))
            # Slice the last axis so batched and unbatched inputs both work
            log_weights = x[..., :n_weights] * self.log_weight_scale
            log_sigmas = x[..., n_weights:] * self.log_sigma_scale
            return log_weights, log_sigmas

    return PolicyNetwork()
[docs]
def initialize_policy_sensibly(state_dim: int = 31, hidden_dim: int = 256) -> nn.Module:
    """
    Build a policy network whose output biases encode default loss weights.

    Uses 31-dimensional state vector matching meta_weighting_4 format.

    Based on correlation analysis:
    - Higher xray weight is better
    - Lower torsion/KL weights are better

    The first 10 output biases are set to ``atanh(target / 3.0)`` (with the
    ratio clamped to [-0.95, 0.95]), so that a pre-activation equal to the
    bias alone maps through tanh and the 3.0 log-weight scale back to the
    target log-weight. The last 10 biases (log-sigmas) are set to zero.

    NOTE(review): ``fc_out.weight`` is left at its default initialisation
    here, so the actual initial outputs only approximate these targets;
    confirm whether the weights are meant to be scaled down as well.
    """
    net = create_policy_network(state_dim=state_dim, hidden_dim=hidden_dim)

    # Must agree with the log_weight_scale used by the policy network
    log_weight_scale = 3.0

    # Target log-weight per loss term, in output order
    defaults_by_term = (
        ("xray", 1.5),  # high
        ("bond", 0.5),  # medium
        ("angle", 0.0),  # medium
        ("torsion", -0.5),  # low
        ("planarity", 0.0),  # medium
        ("chiral", 0.0),  # medium
        ("nonbonded", 0.0),  # medium
        ("simu", 0.5),  # medium
        ("locality", 0.0),  # medium
        ("KL", -0.5),  # low
    )

    with torch.no_grad():
        target_log_weights = torch.tensor([value for _, value in defaults_by_term])
        # Move the targets into tanh-output space and keep atanh finite
        squashed = (target_log_weights / log_weight_scale).clamp(-0.95, 0.95)
        out_bias = net.fc_out.bias.data
        out_bias[:10] = torch.atanh(squashed)
        # Zero bias for the log-sigma half of the output
        out_bias[10:] = 0.0
    return net