# Source code for torchref.refinement.optimizers.simulated_annealing

"""
Simulated Annealing optimizer for crystallographic refinement.

Implements a per-parameter Metropolis-Hastings acceptance criterion
for fine-grained exploration of parameter space.
"""

import math
from typing import Callable, Dict, List, Literal, Optional, Union

import torch

from torchref.refinement.loss_state import LossState


def _normalize_perturbation_scales(
    perturbation_scale: Union[float, List[float], Dict[int, float]],
    n_params: int,
) -> list:
    """Expand ``perturbation_scale`` into a list with one scale per parameter.

    Raises
    ------
    ValueError
        If a list/tuple is given whose length differs from ``n_params``.
    TypeError
        If ``perturbation_scale`` is not a number, list/tuple, or dict.
    """
    if isinstance(perturbation_scale, (int, float)):
        return [float(perturbation_scale)] * n_params
    if isinstance(perturbation_scale, (list, tuple)):
        if len(perturbation_scale) != n_params:
            raise ValueError(
                f"perturbation_scale list length ({len(perturbation_scale)}) "
                f"must match params length ({n_params})"
            )
        return [float(s) for s in perturbation_scale]
    if isinstance(perturbation_scale, dict):
        # Indices absent from the mapping fall back to 0.01.
        return [perturbation_scale.get(i, 0.01) for i in range(n_params)]
    raise TypeError(
        f"perturbation_scale must be float, list, or dict, got {type(perturbation_scale)}"
    )


def _temperature_schedule(
    T_initial: float,
    T_final: float,
    n_steps: int,
    cooling_schedule: str,
) -> List[float]:
    """Validate the annealing settings and return the temperature of every step.

    Step ``k`` (``0 <= k < n_steps``) uses the schedule evaluated at
    ``k / n_steps``. The first step therefore runs at ``T_initial``, and
    ``T_final`` would be the temperature of step ``n_steps``, which is not run.

    Raises
    ------
    ValueError
        If the raised condition holds for any of these checks:
        - ``cooling_schedule`` is not ``"exponential"`` or ``"linear"``.
        - ``n_steps`` is negative.
        - ``T_initial`` is not positive.
        - ``T_final`` is not positive (exponential schedule) or is negative
          (linear schedule).
    """
    if cooling_schedule not in ("exponential", "linear"):
        raise ValueError(
            "cooling_schedule must be 'exponential' or 'linear', "
            f"got {cooling_schedule!r}"
        )
    if n_steps < 0:
        raise ValueError(f"n_steps must be non-negative, got {n_steps}")
    if not T_initial > 0:
        raise ValueError(f"T_initial must be positive, got {T_initial}")
    if cooling_schedule == "exponential":
        # A geometric schedule can never reach 0, and a negative ratio would
        # make the fractional power below complex.
        if not T_final > 0:
            raise ValueError(
                f"T_final must be positive for the exponential schedule, got {T_final}"
            )
    elif T_final < 0:
        raise ValueError(f"T_final must be non-negative, got {T_final}")

    if n_steps == 0:
        # Nothing to anneal; also avoids dividing by n_steps below.
        return []
    if cooling_schedule == "exponential":
        # cooling_rate ** n_steps == T_final / T_initial.
        cooling_rate = (T_final / T_initial) ** (1.0 / n_steps)
        return [T_initial * (cooling_rate ** step) for step in range(n_steps)]
    return [
        T_initial - (T_initial - T_final) * (step / n_steps)
        for step in range(n_steps)
    ]


def _metropolis_accept(delta_E: float, T: float) -> bool:
    """Decide whether a move changing the energy by ``delta_E`` is accepted at ``T``.

    Behaviour by case:
    - Downhill moves are always accepted, without consuming random numbers.
    - Other moves are accepted with probability ``exp(-delta_E / T)``, using one
      draw from torch's global RNG. A zero change is therefore always
      accepted, and a NaN ``delta_E`` is always rejected.
    - ``T <= 0`` is handled as a greedy quench (non-downhill moves rejected)
      rather than dividing by zero.
    """
    if delta_E < 0:
        return True
    if T <= 0:
        return False
    return torch.rand(1).item() < math.exp(-delta_E / T)


def optimize_simulated_annealing(
    state: "LossState",
    params: List[torch.Tensor],
    T_initial: float = 1.0,
    T_final: float = 0.01,
    n_steps: int = 1000,
    perturbation_scale: Union[float, List[float], Dict[int, float]] = 0.01,
    absolute_scale: bool = False,
    cooling_schedule: Literal["exponential", "linear"] = "exponential",
    verbose: int = 0,
    callback: Optional[Callable[[int, float, float, List[torch.Tensor]], object]] = None,
) -> "LossState":
    """
    Run simulated annealing optimization on a LossState.

    The energy is ``state.aggregate().item()``. In every step, each tensor in
    ``params`` that has ``requires_grad=True`` is perturbed as a whole with
    Gaussian noise. The move is then accepted or rejected on its own with the
    Metropolis criterion. Rejected tensors are restored in place to their
    previous values. Tensors with ``requires_grad=False`` are left untouched.

    Parameters
    ----------
    state : LossState
        Configured loss state with targets and weights. Must provide
        ``aggregate()`` (optionally with ``log_values=True``) and ``new_entry()``.
    params : list of torch.Tensor
        Parameters to optimize. They are modified in place.
    T_initial : float
        Initial temperature; must be positive. Default is 1.0.
    T_final : float
        Target temperature of the schedule. Must be positive for the
        exponential schedule and non-negative for the linear schedule.
        Step ``k`` uses the schedule at ``k / n_steps``, so the last step runs
        one schedule increment above ``T_final``. Default is 0.01.
    n_steps : int
        Number of SA steps; must be non-negative (0 performs no moves).
        Default is 1000.
    perturbation_scale : float, list/tuple, or dict
        Scale factor for perturbations. Can be:
        - float: uniform scale for all parameters
        - list/tuple: per-parameter scales (same length as params)
        - dict: mapping from parameter index to scale (missing indices use 0.01)
        If absolute_scale=False (default), the scale is multiplied by the
        parameter's mean absolute value (plus 1e-8).
        If absolute_scale=True, the scale is used directly as the standard
        deviation. Default is 0.01.
    absolute_scale : bool
        If True, perturbation_scale is used as absolute std dev for noise.
        If False (default), perturbation_scale is relative to parameter magnitude.
    cooling_schedule : {"exponential", "linear"}
        Cooling schedule. Default is "exponential".
    verbose : int
        If > 0, print progress about every ``n_steps // 10`` steps and a
        final acceptance summary. Default is 0.
    callback : callable, optional
        Called after each step as ``callback(step, T, loss, params)``, where
        ``loss`` is the current accepted energy. Useful for collecting
        snapshots during optimization.

    Returns
    -------
    LossState
        The same ``state`` object. ``state.aggregate(log_values=True)`` is
        called once before annealing, and once more after ``state.new_entry()``
        at the end. This presumably records the before/after loss values in
        the state's history.

    Raises
    ------
    ValueError
        For a perturbation_scale list of the wrong length, an unknown
        cooling_schedule, negative n_steps, or out-of-range temperatures.
        Raised before ``state`` is touched.
    TypeError
        If perturbation_scale has an unsupported type.

    Notes
    -----
    Noise and acceptance draws come from torch's global RNG
    (``torch.randn_like`` / ``torch.rand``). Seed it for reproducible runs.
    """
    params = list(params)

    # Validate everything up front so a bad argument never leaves a
    # half-logged state behind.
    scales = _normalize_perturbation_scales(perturbation_scale, len(params))
    temperatures = _temperature_schedule(T_initial, T_final, n_steps, cooling_schedule)

    # Log initial state
    state.aggregate(log_values=True)

    # Compute initial loss
    with torch.no_grad():
        current_loss = state.aggregate().item()

    n_accepted = 0
    n_total = 0
    log_every = max(1, n_steps // 10)

    for step, T in enumerate(temperatures):
        # Per-parameter Metropolis steps
        for param_idx, param in enumerate(params):
            if not param.requires_grad:
                continue
            n_total += 1
            scale = scales[param_idx]

            # Snapshot so a rejected move can be undone exactly.
            saved = param.detach().clone()

            with torch.no_grad():
                if absolute_scale:
                    noise_scale = scale
                else:
                    # Relative to magnitude; the epsilon still gives an
                    # all-zero parameter a (tiny) non-zero perturbation.
                    noise_scale = scale * (param.abs().mean() + 1e-8)
                param.add_(torch.randn_like(param) * noise_scale)
                new_loss = state.aggregate().item()

            if _metropolis_accept(new_loss - current_loss, T):
                current_loss = new_loss
                n_accepted += 1
            else:
                with torch.no_grad():
                    param.copy_(saved)

        # Progress logging
        if verbose > 0 and (step + 1) % log_every == 0:
            accept_rate = n_accepted / max(1, n_total)
            print(f"Step {step+1}/{n_steps}, T={T:.4f}, "
                  f"Loss={current_loss:.6f}, Accept={accept_rate:.2%}")

        # Callback for snapshots
        if callback is not None:
            callback(step, T, current_loss, params)

    # Log final state
    state.new_entry()
    state.aggregate(log_values=True)

    if verbose > 0:
        print(f"SA complete: {n_accepted}/{n_total} moves accepted "
              f"({n_accepted/max(1,n_total):.1%})")

    return state