"""
Policy-based component weighting for crystallographic refinement.
This module implements a weighting scheme that uses a trained neural network
policy to predict component weights from the current refinement state.
Supports both:
- Training mode: Sample from policy distribution for exploration
- Evaluation mode: Use mean predictions for deterministic refinement
Integrates with LossState for hierarchical weight management.
"""
from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from torchref.config import get_default_device, get_float_dtype
from torchref.refinement.weighting.base_weighting import BaseWeighting
from torchref.utils.stats import (
VERBOSITY_DEBUG,
VERBOSITY_STANDARD,
stat,
)
if TYPE_CHECKING:
from torchref.refinement.loss_state import LossState
# Component names (must match policy network output order).
# PolicyComponentWeighting.forward() pairs the i-th network output with the
# i-th entry of this list, so reordering or inserting entries changes which
# loss term each predicted weight is applied to.
COMPONENTS = [
    "xray",
    "bond",
    "angle",
    "torsion",
    "planarity",
    "chiral",
    "nonbonded",
    "simu",
    "locality",
    "KL",
]
# Mapping from policy component names to LossState hierarchical names.
# forward() uses this to key the weight dictionary it returns; every entry of
# COMPONENTS needs a key here or forward() raises KeyError.
COMPONENT_TO_LOSS_STATE = {
    "xray": "xray",
    "bond": "geometry/bond",
    "angle": "geometry/angle",
    "torsion": "geometry/torsion",
    "planarity": "geometry/planarity",
    "chiral": "geometry/chiral",
    "nonbonded": "geometry/nonbonded",
    "simu": "adp/simu",
    "locality": "adp/locality",
    "KL": "adp/KL",
}
# Reverse mapping (LossState hierarchical name -> policy component name).
# Built once at import time; valid because the values above are unique.
LOSS_STATE_TO_COMPONENT = {v: k for k, v in COMPONENT_TO_LOSS_STATE.items()}
[docs]
@dataclass
class StepState:
    """Snapshot of the refinement metrics observed at one step of a trajectory.

    According to the original author's note, the layout matches the
    31-feature format used in meta_weighting_4.

    The sixteen scalar fields are required and positional; the two
    dictionaries default to fresh empty dicts per instance.
    """

    # --- Position within the refinement run ---
    step: int
    progress: float  # normalized: step / n_steps
    improvement_rate: float  # (initial_rfree - rfree) / initial_rfree

    # --- R-factors ---
    rwork: float
    rfree: float
    rfree_gap: float  # rfree - rwork; serves as an overfitting indicator
    delta_rfree: float

    # --- Aggregate loss terms ---
    xray_loss: float
    xray_loss_test: float
    xray_work_test_ratio: float  # xray_work / xray_test
    geometry_loss: float
    adp_loss: float

    # --- Geometry and ADP summaries ---
    bond_rmsd: float
    angle_rmsd: float
    mean_adp_normalized: float  # mean_adp / wilson_b
    adp_std_normalized: float  # adp_std / wilson_b

    # --- Per-component values ---
    component_losses: Dict[str, float] = field(default_factory=dict)
    component_weights: Dict[str, float] = field(default_factory=dict)
[docs]
@dataclass
class StepRecord:
    """One state-action-reward tuple recorded for a single refinement step."""

    # Metrics snapshot the policy was evaluated on.
    state: StepState
    # Actions, expressed in log-space, keyed by component name.
    log_weights: Dict[str, float]
    # Effective (linear-space) weights that were used.
    weights: Dict[str, float]
    # Per-step reward (-delta_rfree).
    reward: float
[docs]
@dataclass
class TrajectoryData:
    """Everything recorded about one refinement trajectory, for training use."""

    # --- Identification of the input data (required) ---
    pdb_id: str
    structure_path: str
    sf_path: str

    # --- Recorded steps; each instance gets its own list ---
    steps: List[StepRecord] = field(default_factory=list)

    # --- R-factors at the start and end of the trajectory ---
    initial_rfree: float = 0.0
    final_rfree: float = 0.0
    initial_rwork: float = 0.0
    final_rwork: float = 0.0

    # --- Run metadata ---
    total_time: float = 0.0
    random_seed: Optional[int] = None
    policy_version: Optional[str] = None

    # --- Outcome ---
    success: bool = True
    error_message: Optional[str] = None
[docs]
class PolicyComponentWeighting(BaseWeighting):
    """
    Policy-based component weighting using a trained neural network.

    This weighting scheme uses a policy network to predict component weights
    from the current refinement state. Supports both training (sampling) and
    evaluation (deterministic) modes.

    ``forward()`` reads its inputs through the LossState it is given. The
    trajectory helpers (``reset_trajectory``, ``start_recording``,
    ``stop_recording``) additionally call ``self.refinement.get_rfactor()``;
    ``self.refinement`` is not assigned in this class, so it is presumably
    provided by the base class or attached by the caller (confirm).

    Parameters
    ----------
    device : torch.device, optional
        Computation device. When None (default), ``get_default_device()`` is
        queried at construction time.
    policy_path : str, optional
        Path to trained policy checkpoint (.pt file). If None, uses random init.
    sample : bool, optional
        Whether to sample from policy distribution (default: False for eval)
    temperature : float, optional
        Sampling temperature for exploration (default: 1.0)
        - 0.0: deterministic (use mean predictions)
        - > 0.0: sample from distribution with scaled standard deviation
    n_steps : int, optional
        Total number of steps in refinement cycle (default: 10).
    verbose : int, optional
        Verbosity level (default: 0).

    Attributes
    ----------
    policy : nn.Module
        Policy network (loaded or randomly initialized)
    sample : bool
        Whether to sample from predicted distribution
    temperature : float
        Sampling temperature for exploration
    current_step : int
        Current refinement step index
    last_log_weights : dict
        Most recent mean log-space weights predicted by the policy
    last_log_sigmas : dict
        Most recent predicted log standard deviations
    last_weights : dict
        Most recent linear-space weights (after sampling and clipping)
    trajectory : TrajectoryData
        Current trajectory being recorded (if recording enabled)

    Examples
    --------
    ::

        # Evaluation mode (deterministic)
        policy = PolicyComponentWeighting(device=device, policy_path='policy.pt', sample=False)

        # Training mode (sampling for exploration)
        policy = PolicyComponentWeighting(device=device, policy_path='policy.pt', sample=True, temperature=1.0)

        # Use with LossState (forward() returns a weight dict, so use
        # apply_to_state() to get the updated state back)
        state = refinement.create_loss_state()
        state = policy.apply_to_state(state)
    """

    name = "policy_weighting"

    def __init__(
        self,
        device: Optional[torch.device] = None,
        policy_path: Optional[str] = None,
        sample: bool = False,
        temperature: float = 1.0,
        n_steps: int = 10,
        verbose: int = 0,
    ):
        # Resolve the default device at call time. A call expression used as a
        # parameter default is evaluated only once, when the class is defined,
        # and would ignore later changes to the configured default device.
        if device is None:
            device = get_default_device()
        super().__init__(device)

        # Sampling parameters
        self.sample = sample
        self.temperature = temperature
        self.n_steps = n_steps
        self.verbose = verbose

        # Load or create policy network
        self.policy = self._create_policy_network()
        if policy_path is not None:
            self._load_policy(policy_path)
        self.policy.to(self.device)
        self.policy.eval()

        # State tracking
        self.current_step = 0
        self.prev_rfree = None
        self.initial_rfree = None

        # Last predictions (for trajectory recording)
        self.last_log_weights: Dict[str, float] = {}
        self.last_weights: Dict[str, float] = {}
        self.last_log_sigmas: Dict[str, float] = {}

        # Trajectory recording
        self.trajectory: Optional[TrajectoryData] = None
        self._recording = False

        # For backward compatibility
        self.predicted_weights = self.last_weights
        self.predicted_uncertainties = {}

        # Static features - computed lazily from first state.meta
        self.static_features: Optional[Dict[str, float]] = None

        if verbose > 0:
            param_count = sum(p.numel() for p in self.policy.parameters())
            print("PolicyComponentWeighting initialized")
            print(f" Policy: {policy_path or 'random initialization'}")
            print(f" Parameters: {param_count:,}")
            print(f" Sample mode: {sample}")
            print(f" Temperature: {temperature}")
            print(f" N steps: {n_steps}")

    def _create_policy_network(self) -> nn.Module:
        """Create the policy network architecture.

        Uses a 31-dimensional state vector (per the original note, matching
        the meta_weighting_4 format). The output head emits one log-weight
        and one log-sigma per entry of COMPONENTS.
        """
        # Tie the head width to COMPONENTS so the slicing in forward() cannot
        # drift out of sync with the component list (20 outputs for the
        # current 10 components, as before).
        n_components = len(COMPONENTS)

        class PolicyNetwork(nn.Module):
            def __init__(self, state_dim=31, hidden_dim=256, n_outputs=n_components):
                super().__init__()
                self.n_outputs = n_outputs
                # Sub-module attribute names are part of the checkpoint
                # format (state_dict keys); do not rename them.
                self.input_norm = nn.LayerNorm(state_dim)
                self.fc1 = nn.Linear(state_dim, hidden_dim)
                self.ln1 = nn.LayerNorm(hidden_dim)
                self.relu = nn.ReLU()
                self.fc_out = nn.Linear(hidden_dim, 2 * n_outputs)  # weights + sigmas
                self.tanh = nn.Tanh()
                # tanh bounds the head to [-1, 1]; these scales set the
                # reachable range of log-weights and log-sigmas.
                self.log_weight_scale = 3.0
                self.log_sigma_scale = 2.0

            def forward(self, x):
                x = self.input_norm(x)
                x = self.fc1(x)
                x = self.ln1(x)
                x = self.relu(x)
                x = self.fc_out(x)
                x = self.tanh(x)
                log_weights = x[:, : self.n_outputs] * self.log_weight_scale
                log_sigmas = x[:, self.n_outputs :] * self.log_sigma_scale
                return log_weights, log_sigmas

        return PolicyNetwork()

    def _compute_static_features(self, state: "LossState"):
        """Compute static features from the state (once, on first forward).

        Populates ``self.static_features`` with eight entries:

        - resolution_min: value read from the state (default 2.0)
        - inv_resolution: 1 / resolution_min
        - log_n_atoms: log of the number of atoms
        - log_n_hkl: log of the number of reflections
        - data_to_param_ratio: n_hkl / n_atoms
        - log_wilson_b: log of the Wilson B-factor
        - adp_cv: adp_std / mean_adp (coefficient of variation)
        - wilson_b: raw Wilson B, kept for B-factor normalization

        Subsequent calls are no-ops until ``static_features`` is set back to
        None. Inputs that would make a logarithm or a division blow up
        (zero counts, zero resolution, non-positive Wilson B) are floored
        first so the features stay finite.
        """
        if self.static_features is not None:
            return  # Already computed

        # Get static data from the state, with fallbacks for missing keys
        n_atoms = state.get("n_atoms", 1)
        n_hkl = state.get("n_hkl", 1)
        resolution_min = state.get("resolution_min", 2.0)
        wilson_b = state.get("wilson_b", 20.0)
        mean_adp = state.get("mean_adp", 20.0)
        adp_std = state.get("adp_std", 5.0)

        # Floors applied only where the raw value feeds a log or a divisor;
        # values already above the floor pass through unchanged.
        tiny = 1e-6
        safe_n_atoms = max(n_atoms, 1)
        safe_n_hkl = max(n_hkl, 1)
        safe_resolution = max(resolution_min, tiny)
        safe_wilson_b = max(wilson_b, tiny)

        self.static_features = {
            "resolution_min": resolution_min,
            "inv_resolution": 1.0 / safe_resolution,
            "log_n_atoms": np.log(safe_n_atoms),
            "log_n_hkl": np.log(safe_n_hkl),
            "data_to_param_ratio": n_hkl / safe_n_atoms,
            "log_wilson_b": np.log(safe_wilson_b),
            "adp_cv": adp_std / (mean_adp + 1e-6),  # Coefficient of variation
            "wilson_b": wilson_b,  # Keep for B-factor normalization
        }

    def _load_policy(self, checkpoint_path: str):
        """Load policy weights from a checkpoint file.

        Accepts either a dict containing a ``"policy_state_dict"`` entry or a
        bare state dict. The network is left in eval mode.
        """
        # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
        # Python objects, which can execute code. Only load checkpoints from
        # trusted sources. Kept as-is because existing checkpoints may contain
        # non-tensor objects; switch to weights_only=True once confirmed safe.
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        if "policy_state_dict" in checkpoint:
            self.policy.load_state_dict(checkpoint["policy_state_dict"])
        else:
            self.policy.load_state_dict(checkpoint)
        self.policy.eval()

    def set_training_mode(self, sample: bool = True, temperature: float = 1.0):
        """Enable training mode (sampling from policy distribution)."""
        self.sample = sample
        self.temperature = temperature

    def set_eval_mode(self):
        """Enable evaluation mode (deterministic predictions).

        Only ``sample`` is switched off; ``temperature`` is left untouched.
        """
        self.sample = False

    def reset_trajectory(
        self, pdb_id: str = "", structure_path: str = "", sf_path: str = ""
    ):
        """Reset per-trajectory state ahead of a new trajectory.

        This does not start recording; use ``start_recording`` for that.
        The ``pdb_id``, ``structure_path`` and ``sf_path`` arguments are
        accepted for interface compatibility and are not used here.

        ``static_features`` is intentionally left as-is (it is only computed
        once per instance).
        """
        self.current_step = 0
        self.prev_rfree = None
        self.initial_rfree = None

        # Clear every cached prediction, including the sigma cache and the
        # backward-compatibility aliases, so stats() cannot report values
        # from the previous trajectory.
        self.last_log_weights = {}
        self.last_weights = {}
        self.last_log_sigmas = {}
        self.predicted_weights = self.last_weights
        self.predicted_uncertainties = {}

        # Capture initial R-free for improvement_rate calculation
        with torch.no_grad():
            _, rfree = self.refinement.get_rfactor()
        self.initial_rfree = float(rfree)

    def start_recording(
        self,
        pdb_id: str,
        structure_path: str,
        sf_path: str,
        seed: Optional[int] = None,
        policy_version: Optional[str] = None,
    ):
        """Start recording a new trajectory.

        Resets per-trajectory state, creates a fresh TrajectoryData and
        stores the current R-work / R-free as its initial values.
        """
        self.reset_trajectory()
        self.trajectory = TrajectoryData(
            pdb_id=pdb_id,
            structure_path=structure_path,
            sf_path=sf_path,
            random_seed=seed,
            policy_version=policy_version,
        )
        self._recording = True

        # Record initial R-factors
        with torch.no_grad():
            rwork, rfree = self.refinement.get_rfactor()
        self.trajectory.initial_rwork = float(rwork)
        self.trajectory.initial_rfree = float(rfree)

    def stop_recording(self) -> Optional[TrajectoryData]:
        """Stop recording and return the trajectory data.

        Returns None when no recording is in progress. Otherwise the final
        R-factors are stored on the trajectory, recording is switched off and
        the instance drops its reference to the returned object.
        """
        if not self._recording or self.trajectory is None:
            return None

        # Record final R-factors
        with torch.no_grad():
            rwork, rfree = self.refinement.get_rfactor()
        self.trajectory.final_rwork = float(rwork)
        self.trajectory.final_rfree = float(rfree)

        self._recording = False
        trajectory = self.trajectory
        self.trajectory = None
        return trajectory

    def increment_step(self):
        """Increment the step counter after a refinement step."""
        self.current_step += 1

    def forward(self, state: "LossState") -> Dict[str, float]:
        """
        Predict component weights from LossState.

        Side effects: updates ``last_log_weights``, ``last_log_sigmas``,
        ``last_weights`` and the ``predicted_*`` aliases, and appends a
        StepRecord to the active trajectory when recording is enabled.
        ``current_step`` is not advanced here; call ``increment_step``.

        Parameters
        ----------
        state : LossState
            LossState with meta and _losses populated.

        Returns
        -------
        dict
            Dictionary of {loss_state_name: weight} for all components
        """
        # Compute static features on first call (lazy init)
        self._compute_static_features(state)

        # Initialize initial_rfree on first call
        if self.initial_rfree is None:
            self.initial_rfree = state.get("rfree", 0.25)

        # Extract state features. extract_state_features is defined outside
        # this class body; it is expected to return (feature tensor, StepState).
        features, step_state = self.extract_state_features(state)
        features = features.unsqueeze(0)  # Add batch dimension

        # Get policy predictions
        with torch.no_grad():
            log_weights, log_sigmas = self.policy(features)
        log_weights = log_weights[0]  # Remove batch dimension
        log_sigmas = log_sigmas[0]

        # Store predictions for trajectory recording. One tolist() per tensor
        # replaces a separate .item() call (and device sync) per component.
        # NOTE(review): these are the distribution means. In sampling mode the
        # log-weights actually drawn below are not recorded, so
        # StepRecord.log_weights holds the mean rather than the sampled
        # action - confirm this is what the training code expects.
        self.last_log_weights = dict(zip(COMPONENTS, log_weights.tolist()))
        self.last_log_sigmas = dict(zip(COMPONENTS, log_sigmas.tolist()))

        # Sample or use mean
        if self.sample and self.temperature > 0:
            # Sample from a normal with mean log_weight and standard
            # deviation temperature * exp(log_sigma)
            sigmas = torch.exp(log_sigmas) * self.temperature
            sampled_log_weights = torch.normal(log_weights, sigmas)
            weights_tensor = torch.exp(sampled_log_weights)
        else:
            # Deterministic: use mean
            weights_tensor = torch.exp(log_weights)

        # Clip to reasonable range
        weights_tensor = torch.clamp(weights_tensor, 0.01, 100.0)

        # Build output dictionaries
        self.last_weights = dict(zip(COMPONENTS, weights_tensor.tolist()))

        # Return with LossState naming convention
        weights = {
            COMPONENT_TO_LOSS_STATE[comp]: self.last_weights[comp]
            for comp in COMPONENTS
        }

        # Update step_state with weights for recording
        step_state.component_weights = self.last_weights.copy()

        # Record step if trajectory recording is enabled
        if self._recording and self.trajectory is not None:
            # Negative delta = improvement = positive reward
            reward = -step_state.delta_rfree
            reward = max(-0.05, min(0.05, reward))  # Clip to [-0.05, 0.05]
            step_record = StepRecord(
                state=step_state,
                log_weights=self.last_log_weights.copy(),
                weights=self.last_weights.copy(),
                reward=reward,
            )
            self.trajectory.steps.append(step_record)

        # For backward compatibility
        self.predicted_weights = self.last_weights
        self.predicted_uncertainties = {
            comp: np.exp(self.last_log_sigmas[comp]) for comp in COMPONENTS
        }
        return weights

    def compute_weights(self, state: "LossState") -> Dict[str, float]:
        """
        Compute component weights from LossState.

        Deprecated: just call the method forward() or the instance directly.

        Parameters
        ----------
        state : LossState
            LossState with meta and _losses populated.

        Returns
        -------
        dict
            Dictionary of {loss_state_name: weight} for all components
        """
        return self.forward(state)

    def apply_to_state(self, state: "LossState") -> "LossState":
        """
        Apply policy-predicted weights to a LossState.

        This is the main integration point with the LossState architecture.
        Call this after creating a LossState to apply policy weights. The
        state is updated in place through ``state.set_weight`` and returned.

        Parameters
        ----------
        state : LossState
            Loss state to update with policy weights.

        Returns
        -------
        LossState
            State with policy weights applied.

        Example
        -------
        ::

            state = refinement.create_loss_state()
            state = policy_weighting.apply_to_state(state)
            loss = state.aggregate()
        """
        weights = self.forward(state)
        for name, weight in weights.items():
            state.set_weight(name, weight)
        return state

    def stats(self, state: Optional["LossState"] = None) -> Dict[str, Any]:
        """
        Return statistics for reporting.

        Parameters
        ----------
        state : LossState, optional
            Accepted for interface compatibility; not used.

        Returns
        -------
        dict
            Statistics dictionary with predicted weights and uncertainties
        """
        stats_dict = {}

        # Predicted weights
        if self.last_weights:
            for comp, weight in self.last_weights.items():
                stats_dict[f"policy_weight_{comp}"] = stat(weight, VERBOSITY_STANDARD)

        # Predicted sigmas (uncertainties), converted out of log-space
        if self.last_log_sigmas:
            for comp, log_sigma in self.last_log_sigmas.items():
                stats_dict[f"policy_sigma_{comp}"] = stat(
                    np.exp(log_sigma), VERBOSITY_DEBUG
                )

        # Sampling parameters
        stats_dict["sample_mode"] = stat(self.sample, VERBOSITY_DEBUG)
        stats_dict["temperature"] = stat(self.temperature, VERBOSITY_DEBUG)
        stats_dict["current_step"] = stat(self.current_step, VERBOSITY_DEBUG)
        return stats_dict
[docs]
def trajectory_to_dict(trajectory: TrajectoryData) -> Dict[str, Any]:
    """Flatten a TrajectoryData into a plain dictionary meant for JSON output.

    The trajectory-level fields are copied as-is, followed by a ``"steps"``
    list in which every step's state is expanded with ``dataclasses.asdict``.
    """
    # Trajectory-level fields, in the order they appear in the output.
    scalar_fields = (
        "pdb_id",
        "structure_path",
        "sf_path",
        "initial_rfree",
        "final_rfree",
        "initial_rwork",
        "final_rwork",
        "total_time",
        "random_seed",
        "policy_version",
        "success",
        "error_message",
    )
    payload: Dict[str, Any] = {
        key: getattr(trajectory, key) for key in scalar_fields
    }

    serialized_steps = []
    for record in trajectory.steps:
        entry = {
            "state": asdict(record.state),
            "log_weights": record.log_weights,
            "weights": record.weights,
            "reward": record.reward,
        }
        serialized_steps.append(entry)

    # "steps" goes last, after the trajectory-level fields.
    payload["steps"] = serialized_steps
    return payload
# Public API of this module: the names exported by
# `from <module> import *` and listed by documentation tools.
__all__ = [
    "PolicyComponentWeighting",
    "StepState",
    "StepRecord",
    "TrajectoryData",
    "trajectory_to_dict",
    "COMPONENTS",
    "COMPONENT_TO_LOSS_STATE",
    "LOSS_STATE_TO_COMPONENT",
]