Source code for torchref.refinement.optimizers.internal_coord_sa

"""
Internal Coordinate Simulated Annealing optimizer.

Implements batch Metropolis SA for internal coordinates with:
- Auto-calibrated temperature from energy landscape sampling
- Single scale parameter controlling perturbation magnitude
- Best structure tracking throughout optimization
- Torsion wrapping to [-pi, pi]

Usage:
    model.use_internal_coordinates(n_aa_per_segment=3)
    ic = model.xyz

    # Torsion-only refinement (recommended for conformational exploration):
    state, best_loss = optimize_internal_coord_sa(
        state=state,
        params=[ic.torsions],
        scale=1.0,
        ...
    )

    # All internal coordinates:
    state, best_loss = optimize_internal_coord_sa(
        state=state,
        params=[ic.torsions, ic.bond_lengths, ic.angles,
                ic.segment_positions, ic.segment_orientations],
        scale=1.0,
        ...
    )
"""

import math
from typing import Callable, Dict, List, Optional, Tuple, Any

import torch

from torchref.refinement.loss_state import LossState


def _is_angle_like(param: torch.Tensor) -> bool:
    """
    Guess whether a parameter tensor stores angle-like values.

    The test is purely magnitude based: a non-empty tensor is reported as
    angle-like when its largest absolute entry is below 10.0.  The rationale
    is that angles (torsions, bond angles) live in [-pi, pi] or [0, pi],
    whereas position-like parameters (segment_positions) usually have larger
    magnitudes.

    Parameters
    ----------
    param : torch.Tensor
        Tensor to classify.

    Returns
    -------
    bool
        True if the tensor is non-empty and max(|param|) < 10.0,
        False otherwise (an empty tensor is never angle-like).
    """
    # An empty tensor has no maximum to inspect.
    if param.numel() == 0:
        return False

    # Pure inspection of the values; no autograd bookkeeping is needed.
    with torch.no_grad():
        largest_magnitude = param.abs().max().item()

    # Angles are bounded by pi (~3.14), positions are typically > 10 Angstrom
    return largest_magnitude < 10.0


def _wrap_torsions(param: torch.Tensor) -> torch.Tensor:
    """
    Wrap angle values (radians) to [-pi, pi].

    Each element is replaced by atan2(sin(x), cos(x)), i.e. the equivalent
    angle inside the principal range.  A new tensor is returned; the input is
    not modified.
    """
    sine = torch.sin(param)
    cosine = torch.cos(param)
    return torch.atan2(sine, cosine)


def _calibrate_temperature(
    state: LossState,
    params: List[torch.Tensor],
    scale: float,
    n_samples: int = 20,
    target_acceptance: float = 0.5,
) -> float:
    """
    Auto-calibrate initial temperature based on energy landscape.

    Applies ``n_samples`` random trial perturbations to ``params``, records
    the loss increase of every uphill trial, and picks the temperature at
    which the median uphill increase would be accepted with probability
    ``target_acceptance``.

    For Metropolis: P(accept) = exp(-delta_E / T)
    => T = -delta_E / ln(target_acceptance)

    Parameters
    ----------
    state : LossState
        Object whose ``aggregate()`` returns the current loss as a tensor.
    params : List[torch.Tensor]
        Tensors that are perturbed in place for each trial and restored to
        their saved values afterwards.
    scale : float
        Perturbation scale. Tensors classified as angle-like receive Gaussian
        noise with sigma ``scale * 0.1``; all others ``scale * 0.01``.
    n_samples : int
        Number of trial perturbations. Default 20.
    target_acceptance : float
        Desired acceptance probability for the median uphill move. It is
        passed to ``math.log`` and used as a divisor, so it is expected to lie
        strictly between 0 and 1 (1.0 would divide by zero). Default 0.5.

    Returns
    -------
    float
        Calibrated temperature, never below 0.01. Returns 1.0 when no trial
        increased the loss.
    """
    uphill_deltas = []

    with torch.no_grad():
        # Snapshot used to undo every trial perturbation.
        saved_values = [p.data.clone() for p in params]
        current_loss = state.aggregate().item()

        # Every trial starts from the restored snapshot, so the angle/position
        # classification cannot change between trials: compute it once.
        noise_sigmas = [
            scale * 0.1 if _is_angle_like(p) else scale * 0.01 for p in params
        ]

        for _ in range(n_samples):
            try:
                # Apply random perturbations
                for param, sigma in zip(params, noise_sigmas):
                    param.data.add_(torch.randn_like(param) * sigma)

                # Compute energy change
                delta = state.aggregate().item() - current_loss
            finally:
                # Restore even if the loss evaluation raised, so the caller's
                # parameters are never left in a perturbed state.
                for param, saved in zip(params, saved_values):
                    param.data.copy_(saved)

            # Only consider positive deltas for calibration
            if delta > 0:
                uphill_deltas.append(delta)

    if not uphill_deltas:
        # All moves were downhill - use a small default
        return 1.0

    # Use median positive delta for robust calibration
    # (upper median when the count is even).
    median_delta = sorted(uphill_deltas)[len(uphill_deltas) // 2]

    # T such that exp(-median_delta / T) = target_acceptance
    # => T = -median_delta / ln(target_acceptance)
    T = -median_delta / math.log(target_acceptance)

    return max(T, 0.01)  # Ensure positive temperature


def optimize_internal_coord_sa(
    state: LossState,
    params: List[torch.Tensor],
    scale: float = 1.0,
    T_initial: Optional[float] = None,
    T_final: float = 0.01,
    n_steps: int = 5000,
    cooling_schedule: str = "exponential",
    verbose: int = 0,
    callback: Optional[Callable] = None,
) -> Tuple[LossState, float]:
    """
    Universal Simulated Annealing for internal coordinates.

    Perturbs the provided parameter tensors using batch Metropolis:
    all parameters are perturbed together, single accept/reject decision.
    The best parameter values seen are tracked and copied back into
    ``params`` before returning.

    Parameters
    ----------
    state : LossState
        Configured loss state with targets. Its ``aggregate()`` is evaluated
        once per step.
    params : List[torch.Tensor]
        Parameter tensors to optimize (modified in place). For internal
        coordinates:
        - Torsion-only: [ic.torsions]
        - All coords: [ic.torsions, ic.bond_lengths, ic.angles,
                       ic.segment_positions, ic.segment_orientations]
        Each tensor is classified once, up front, as angle-like or
        position-like from the values it holds at call time; angle-like
        tensors are wrapped to [-pi, pi] after every perturbation.
    scale : float
        Master scale for perturbations. Default 1.0.
        - Angle-like params: scale * 0.1 rad (~6 degrees at scale=1.0)
        - Position-like params: scale * 0.01 A
    T_initial : float, optional
        Initial temperature. If None, auto-calibrated from energy landscape.
    T_final : float
        Final temperature. Default 0.01.
    n_steps : int
        Number of SA steps. Default 5000. A value <= 0 performs no steps.
    cooling_schedule : str
        'exponential' or 'linear'. Default 'exponential'. Any value other
        than 'exponential' selects the linear schedule.
    verbose : int
        Verbosity level.
    callback : callable, optional
        Called after each step: callback(step, T, loss, best_loss).

    Returns
    -------
    state : LossState
        The same state object that was passed in.
    best_loss : float
        Best (lowest) loss found during optimization.

    Raises
    ------
    ValueError
        If ``params`` is empty, or if the exponential schedule is requested
        with ``T_initial <= 0`` or ``T_final < 0``.
    TypeError
        If an element of ``params`` is not a ``torch.Tensor``.
    """
    if len(params) == 0:
        raise ValueError("No parameters provided to optimize")

    for p in params:
        if not isinstance(p, torch.Tensor):
            raise TypeError(f"params must be torch.Tensor, got {type(p)}")

    # Identify param types (angle-like vs position-like) and the matching
    # Gaussian noise sigma; both stay fixed for the whole run.
    is_angle = [_is_angle_like(p) for p in params]
    sigmas = [scale * 0.1 if angle else scale * 0.01 for angle in is_angle]

    if verbose > 0:
        print("Internal Coord SA Parameter Summary:")
        total_params = 0
        for i, p in enumerate(params):
            n = p.numel()
            total_params += n
            ptype = "angle" if is_angle[i] else "position"
            pert = sigmas[i]
            print(f" Param {i}: {n:6d} params ({ptype:8s}), perturbation={pert:.4f}")
        print(f" Total: {total_params} parameters")

    # Auto-calibrate temperature if not provided
    if T_initial is None:
        T_initial = _calibrate_temperature(state, params, scale)
        if verbose > 0:
            print(f" Auto-calibrated T_initial: {T_initial:.4f}")

    # Cooling rate
    exponential = cooling_schedule == "exponential"
    cooling_rate = None
    if exponential:
        # (T_final / T_initial) ** (1 / n_steps) is only a real number for a
        # positive T_initial and a non-negative T_final.
        if T_initial <= 0 or T_final < 0:
            raise ValueError(
                "Exponential cooling requires T_initial > 0 and T_final >= 0, "
                f"got T_initial={T_initial}, T_final={T_final}"
            )
        # With no steps there is nothing to cool; skipping this also avoids
        # dividing by zero for n_steps == 0.
        if n_steps > 0:
            cooling_rate = (T_final / T_initial) ** (1.0 / n_steps)

    # Initialize
    with torch.no_grad():
        current_loss = state.aggregate().item()
    best_loss = current_loss
    best_params = [p.data.clone() for p in params]

    T = T_initial
    n_accepted = 0
    n_proposed = 0
    log_every = max(1, n_steps // 10)

    for step in range(n_steps):
        # Update temperature
        if exponential:
            T = T_initial * (cooling_rate ** step)
        else:
            T = T_initial - (T_initial - T_final) * (step / n_steps)

        n_proposed += 1

        with torch.no_grad():
            # Save current values so a rejected move can be undone
            saved_values = [p.data.clone() for p in params]

            # Apply perturbations
            for param, sigma, angle in zip(params, sigmas, is_angle):
                param.data.add_(torch.randn_like(param) * sigma)

                # Wrap torsions to [-pi, pi]
                if angle:
                    param.data.copy_(_wrap_torsions(param.data))

            # Evaluate new loss
            new_loss = state.aggregate().item()

        delta_E = new_loss - current_loss

        # Metropolis criterion: downhill moves are always taken; uphill moves
        # are taken with probability exp(-delta_E / T) while T is positive.
        accept = False
        if delta_E < 0:
            accept = True
        elif T > 0 and torch.rand(1).item() < math.exp(-delta_E / T):
            accept = True

        if accept:
            current_loss = new_loss
            n_accepted += 1

            # Track best
            if current_loss < best_loss:
                best_loss = current_loss
                with torch.no_grad():
                    best_params = [p.data.clone() for p in params]
        else:
            # Reject - restore
            with torch.no_grad():
                for param, saved in zip(params, saved_values):
                    param.data.copy_(saved)

        # Progress logging
        if verbose > 0 and (step + 1) % log_every == 0:
            accept_rate = n_accepted / max(1, n_proposed)
            print(f"Step {step+1}/{n_steps}, T={T:.4f}, Loss={current_loss:.4f}, "
                  f"Best={best_loss:.4f}, Accept={accept_rate:.1%}")

        # Callback
        if callback is not None:
            callback(step, T, current_loss, best_loss)

    # Restore best parameters
    with torch.no_grad():
        for param, best in zip(params, best_params):
            param.data.copy_(best)

    if verbose > 0:
        final_accept = n_accepted / max(1, n_proposed)
        print("\nInternal Coord SA Complete:")
        print(f" Final loss: {best_loss:.4f}")
        print(f" Acceptance rate: {final_accept:.1%}")

    return state, best_loss
def _calibrate_temperature_gradient(
    loss_fn: Callable[[], torch.Tensor],
    params: List[torch.Tensor],
    scale: float,
    n_samples: int = 20,
    target_acceptance: float = 0.5,
) -> float:
    """
    Auto-calibrate temperature for gradient-based SA.

    Samples typical gradient · perturbation values and sets T_initial so
    that the median positive value gives target_acceptance probability:
    T = -median / ln(target_acceptance).

    The parameters are never modified here, so the gradient is computed a
    single time and reused for every sampled perturbation. This assumes
    ``loss_fn`` is deterministic for fixed parameters.

    Parameters
    ----------
    loss_fn : callable
        Returns a differentiable scalar tensor; ``backward()`` is called on
        the result to populate ``p.grad``.
    params : List[torch.Tensor]
        Tensors whose ``.grad`` is read. Existing gradients are zeroed before
        the backward pass; tensors whose ``.grad`` is still None afterwards
        are ignored.
    scale : float
        Perturbation scale. Tensors classified as angle-like use Gaussian
        noise with sigma ``scale * 0.1``; all others ``scale * 0.01``.
    n_samples : int
        Number of random perturbations to sample. Default 20.
    target_acceptance : float
        Desired acceptance probability for the median uphill move; passed to
        ``math.log`` and used as a divisor, so expected to lie strictly
        between 0 and 1. Default 0.5.

    Returns
    -------
    float
        Calibrated temperature, never below 0.01. Returns 1.0 when no sampled
        perturbation had a positive g · δ (including ``n_samples <= 0``).
    """
    if n_samples <= 0:
        # Nothing to sample; do not evaluate the loss at all.
        return 1.0

    # Compute gradient (once: params do not change while sampling)
    for p in params:
        if p.grad is not None:
            p.grad.zero_()
    loss = loss_fn()
    loss.backward()

    # Only parameters that actually received a gradient contribute.
    graded = [
        (p, scale * 0.1 if _is_angle_like(p) else scale * 0.01)
        for p in params
        if p.grad is not None
    ]

    uphill_deltas = []
    for _ in range(n_samples):
        # Generate random perturbations and compute g · δ
        total_delta = 0.0
        for p, sigma in graded:
            perturbation = torch.randn_like(p) * sigma
            # g · δ for this parameter
            total_delta += (p.grad * perturbation).sum().item()

        if total_delta > 0:
            uphill_deltas.append(total_delta)

    if not uphill_deltas:
        return 1.0

    # Upper median when the count is even.
    median_delta = sorted(uphill_deltas)[len(uphill_deltas) // 2]
    T = -median_delta / math.log(target_acceptance)
    return max(T, 0.01)
def optimize_gradient_sa(
    loss_fn: Callable[[], torch.Tensor],
    params: List[torch.Tensor],
    scale: float = 1.0,
    T_initial: Optional[float] = None,
    T_final: float = 0.01,
    n_steps: int = 5000,
    cooling_schedule: str = "exponential",
    per_parameter: bool = True,
    verbose: int = 0,
    callback: Optional[Callable] = None,
) -> Tuple[float, float]:
    """
    Gradient-based Simulated Annealing.

    Uses gradient information for Metropolis acceptance criterion:
    - ΔE ≈ gradient · perturbation (first-order Taylor approximation)
    - Accept if g·δ < 0 (moved downhill)
    - Accept with probability exp(-g·δ / T) if g·δ >= 0 and T > 0

    The true loss is not used for acceptance. It is evaluated only
    periodically (every max(1, n_steps // 20) steps and at the last step) to
    track the best parameters, which are copied back into ``params`` before
    returning.

    Parameters
    ----------
    loss_fn : callable
        Loss function returning a differentiable scalar tensor.
    params : List[torch.Tensor]
        Parameter tensors to optimize (modified in place). ``requires_grad``
        is switched on for any tensor that does not have it. Each tensor is
        classified once, up front, as angle-like or position-like; angle-like
        tensors are wrapped to [-pi, pi] after an update.
    scale : float
        Perturbation scale. Default 1.0.
        - Angle-like: scale * 0.1 rad
        - Position-like: scale * 0.01 A
    T_initial : float, optional
        Initial temperature. Auto-calibrated if None.
    T_final : float
        Final temperature. Default 0.01.
    n_steps : int
        Number of SA steps. Default 5000. A value <= 0 performs no steps.
    cooling_schedule : str
        'exponential' or 'linear'. Default 'exponential'. Any value other
        than 'exponential' selects the linear schedule.
    per_parameter : bool
        If True, accept/reject each tensor element independently.
        If False, use batch acceptance (sum of g·δ for all params).
        Default True.
    verbose : int
        Verbosity level.
    callback : callable, optional
        Called after each step: callback(step, T, loss, accept_rate), where
        ``loss`` is the most recently evaluated loss (it is refreshed only at
        the periodic evaluations and verbose log points, not every step).

    Returns
    -------
    best_loss : float
        Lowest loss seen among the initial and periodic evaluations.
    final_loss : float
        Loss re-evaluated after the best parameters have been restored.

    Raises
    ------
    ValueError
        If ``params`` is empty, or if the exponential schedule is requested
        with ``T_initial <= 0`` or ``T_final < 0``.
    """
    if len(params) == 0:
        raise ValueError("No parameters provided")

    # Ensure requires_grad
    for p in params:
        if not p.requires_grad:
            p.requires_grad_(True)

    # Identify param types and the matching noise sigma (fixed for the run)
    is_angle = [_is_angle_like(p) for p in params]
    sigmas = [scale * 0.1 if angle else scale * 0.01 for angle in is_angle]

    if verbose > 0:
        print("Gradient-based SA Parameter Summary:")
        total_params = 0
        for i, p in enumerate(params):
            n = p.numel()
            total_params += n
            ptype = "angle" if is_angle[i] else "position"
            pert = sigmas[i]
            print(f" Param {i}: {n:6d} params ({ptype:8s}), perturbation={pert:.4f}")
        print(f" Total: {total_params} parameters")
        print(f" Mode: {'per-parameter' if per_parameter else 'batch'}")

    # Auto-calibrate temperature
    if T_initial is None:
        T_initial = _calibrate_temperature_gradient(loss_fn, params, scale)
        if verbose > 0:
            print(f" Auto-calibrated T_initial: {T_initial:.4f}")

    # Cooling rate
    exponential = cooling_schedule == "exponential"
    cooling_rate = None
    if exponential:
        # (T_final / T_initial) ** (1 / n_steps) is only a real number for a
        # positive T_initial and a non-negative T_final.
        if T_initial <= 0 or T_final < 0:
            raise ValueError(
                "Exponential cooling requires T_initial > 0 and T_final >= 0, "
                f"got T_initial={T_initial}, T_final={T_final}"
            )
        # Skipped for n_steps <= 0: no steps are run, and n_steps == 0 would
        # otherwise divide by zero.
        if n_steps > 0:
            cooling_rate = (T_final / T_initial) ** (1.0 / n_steps)

    # Initial loss
    with torch.no_grad():
        current_loss = loss_fn().item()
    best_loss = current_loss
    best_params = [p.data.clone() for p in params]

    T = T_initial
    total_accepted = 0
    total_proposed = 0
    eval_every = max(1, n_steps // 20)
    log_every = max(1, n_steps // 10)

    for step in range(n_steps):
        # Update temperature
        if exponential:
            T = T_initial * (cooling_rate ** step)
        else:
            T = T_initial - (T_initial - T_final) * (step / n_steps)

        # Zero gradients and compute loss
        for p in params:
            if p.grad is not None:
                p.grad.zero_()
        loss = loss_fn()
        loss.backward()

        # Generate perturbations and apply based on gradient
        step_accepted = 0
        step_proposed = 0

        if per_parameter:
            # Per-parameter acceptance
            with torch.no_grad():
                for i, param in enumerate(params):
                    if param.grad is None:
                        continue

                    # Generate perturbation
                    delta = torch.randn_like(param) * sigmas[i]

                    # Compute g · δ for each element
                    g_dot_delta = param.grad * delta

                    # Metropolis criterion per element
                    if T > 0:
                        # Accept if g·δ < 0 (downhill) or with prob exp(-g·δ/T)
                        accept_prob = torch.where(
                            g_dot_delta < 0,
                            torch.ones_like(g_dot_delta),
                            torch.exp(-g_dot_delta / T),
                        )
                    else:
                        # Same rule as the batch branch: at non-positive T
                        # only downhill elements may move. Without this, a
                        # negative T would accept every uphill element and
                        # T == 0 would yield NaN probabilities.
                        accept_prob = (g_dot_delta < 0).to(g_dot_delta.dtype)
                    accept_mask = torch.rand_like(accept_prob) < accept_prob

                    # Apply accepted perturbations
                    param.data.add_(delta * accept_mask.float())

                    # Wrap angles
                    if is_angle[i]:
                        param.data.copy_(_wrap_torsions(param.data))

                    step_accepted += accept_mask.sum().item()
                    step_proposed += param.numel()
        else:
            # Batch acceptance: sum of g·δ for all params
            with torch.no_grad():
                deltas = []
                total_g_dot_delta = 0.0

                for i, param in enumerate(params):
                    if param.grad is None:
                        # Placeholder keeps deltas aligned with params.
                        deltas.append(None)
                        continue

                    delta = torch.randn_like(param) * sigmas[i]
                    deltas.append(delta)
                    total_g_dot_delta += (param.grad * delta).sum().item()

                # Single Metropolis decision
                step_proposed = 1
                accept = False
                if total_g_dot_delta < 0:
                    accept = True
                elif T > 0 and torch.rand(1).item() < math.exp(-total_g_dot_delta / T):
                    accept = True

                if accept:
                    step_accepted = 1
                    for i, (param, delta) in enumerate(zip(params, deltas)):
                        if delta is not None:
                            param.data.add_(delta)
                            if is_angle[i]:
                                param.data.copy_(_wrap_torsions(param.data))

        total_accepted += step_accepted
        total_proposed += step_proposed

        # Evaluate actual loss periodically (for tracking, not for acceptance)
        if (step + 1) % eval_every == 0 or step == n_steps - 1:
            with torch.no_grad():
                current_loss = loss_fn().item()
                if current_loss < best_loss:
                    best_loss = current_loss
                    best_params = [p.data.clone() for p in params]

        # Progress logging
        if verbose > 0 and (step + 1) % log_every == 0:
            accept_rate = total_accepted / max(1, total_proposed)
            with torch.no_grad():
                current_loss = loss_fn().item()
            print(f"Step {step+1}/{n_steps}, T={T:.4f}, Loss={current_loss:.4f}, "
                  f"Best={best_loss:.4f}, Accept={accept_rate:.1%}")

        # Callback
        if callback is not None:
            accept_rate = total_accepted / max(1, total_proposed)
            callback(step, T, current_loss, accept_rate)

    # Restore best parameters, then report the loss at those parameters
    with torch.no_grad():
        for param, best in zip(params, best_params):
            param.data.copy_(best)
        final_loss = loss_fn().item()

    if verbose > 0:
        final_accept = total_accepted / max(1, total_proposed)
        print("\nGradient SA Complete:")
        print(f" Best loss: {best_loss:.4f}")
        print(f" Final loss: {final_loss:.4f}")
        print(f" Overall acceptance: {final_accept:.1%}")

    return best_loss, final_loss
def refine_sa_lbfgs(
    model,
    loss_fn: Callable[[], torch.Tensor],
    n_sa_steps: int = 5000,
    sa_scale: float = 1.0,
    sa_T_initial: Optional[float] = None,
    sa_T_final: float = 0.01,
    sa_params: Optional[List[str]] = None,
    n_lbfgs_cycles: int = 20,
    lbfgs_lr: float = 1.0,
    lbfgs_max_iter: int = 20,
    verbose: int = 1,
    sa_callback: Optional[Callable] = None,
    lbfgs_callback: Optional[Callable] = None,
) -> Dict[str, Any]:
    """
    Combined SA + LBFGS pipeline for internal coordinate refinement.

    Pipeline:
    1. SA phase: Global exploration using internal coordinates
       (torsions by default)
    2. LBFGS phase: Local optimization to refine best structure, run over
       every non-empty internal coordinate tensor (not only those used for SA)

    Parameters
    ----------
    model : ModelFT
        Model with internal coordinates enabled
        (model.use_internal_coordinates()). ``model.xyz`` must expose
        ``torsions``, ``bond_lengths``, ``angles``, ``segment_positions`` and
        ``segment_orientations``.
    loss_fn : callable
        Loss function that returns a scalar tensor: loss = loss_fn()
    n_sa_steps : int
        Number of SA steps. Default 5000.
    sa_scale : float
        SA perturbation scale. Default 1.0.
    sa_T_initial : float, optional
        Initial SA temperature. Auto-calibrated if None.
    sa_T_final : float
        Final SA temperature. Default 0.01.
    sa_params : List[str], optional
        Which internal coord params to use for SA. Options:
        'torsions', 'bond_lengths', 'angles', 'segment_positions',
        'segment_orientations'
        Default: ['torsions'] (torsion-only exploration)
    n_lbfgs_cycles : int
        Number of LBFGS optimization cycles. Default 20.
    lbfgs_lr : float
        LBFGS learning rate. Default 1.0.
    lbfgs_max_iter : int
        Max iterations per LBFGS cycle. Default 20.
    verbose : int
        Verbosity level. Default 1. Values > 1 also print per-cycle LBFGS
        losses.
    sa_callback : callable, optional
        Callback for SA phase: callback(step, T, loss, best_loss)
    lbfgs_callback : callable, optional
        Callback for LBFGS phase: callback(cycle, loss)

    Returns
    -------
    dict with keys:
        'initial_loss': Loss before any optimization
        'sa_loss': Best loss reported by the SA phase
        'final_loss': Loss after LBFGS phase
        'sa_history': list of {'step', 'T', 'loss', 'best'} dicts, one per
            SA step (always recorded, with or without ``sa_callback``)
        'lbfgs_history': list of {'cycle', 'loss'} dicts, one per LBFGS
            cycle (always recorded, with or without ``lbfgs_callback``)

    Raises
    ------
    ValueError
        If the model has no internal coordinates, if ``sa_params`` contains
        an unknown name, or if every selected tensor is empty.
    """
    # Check model has internal coordinates
    ic = model.xyz
    if not hasattr(ic, 'torsions'):
        raise ValueError(
            "Model must have internal coordinates enabled. "
            "Call model.use_internal_coordinates() first."
        )

    # Default: torsion-only SA
    if sa_params is None:
        sa_params = ['torsions']

    # Build parameter list
    param_map = {
        'torsions': ic.torsions,
        'bond_lengths': ic.bond_lengths,
        'angles': ic.angles,
        'segment_positions': ic.segment_positions,
        'segment_orientations': ic.segment_orientations,
    }

    params = []
    for name in sa_params:
        if name not in param_map:
            raise ValueError(f"Unknown param type: {name}. "
                             f"Options: {list(param_map.keys())}")
        p = param_map[name]
        # Empty tensors carry nothing to optimize.
        if p.numel() > 0:
            params.append(p)

    if len(params) == 0:
        raise ValueError("No parameters selected for SA optimization")

    # Setup LossState for SA
    state = LossState(device=next(model.parameters()).device)
    state.register_target("combined", loss_fn)

    # Initial loss
    with torch.no_grad():
        initial_loss = loss_fn().item()

    if verbose > 0:
        print("=" * 60)
        print("SA + LBFGS Refinement Pipeline")
        print("=" * 60)
        print(f"Initial loss: {initial_loss:.4f}")
        print(f"\nSA Phase: {n_sa_steps} steps, scale={sa_scale}")
        print(f" Optimizing: {sa_params}")

    # ============================================
    # Phase 1: Simulated Annealing
    # ============================================
    sa_history = []

    def _sa_callback(step, T, loss, best_loss):
        # Record every step, then forward to the user's callback if any.
        sa_history.append({'step': step, 'T': T, 'loss': loss, 'best': best_loss})
        if sa_callback is not None:
            sa_callback(step, T, loss, best_loss)

    state, sa_best_loss = optimize_internal_coord_sa(
        state=state,
        params=params,
        scale=sa_scale,
        T_initial=sa_T_initial,
        T_final=sa_T_final,
        n_steps=n_sa_steps,
        verbose=verbose,
        callback=_sa_callback,
    )

    if verbose > 0:
        print(f"\nSA Complete: {initial_loss:.4f} -> {sa_best_loss:.4f}")

    # ============================================
    # Phase 2: LBFGS Local Optimization
    # ============================================
    if verbose > 0:
        print(f"\nLBFGS Phase: {n_lbfgs_cycles} cycles")

    # Use all internal coord parameters for LBFGS
    all_params = [p for p in param_map.values() if p.numel() > 0]

    optimizer = torch.optim.LBFGS(
        all_params,
        lr=lbfgs_lr,
        max_iter=lbfgs_max_iter,
        history_size=5,
        line_search_fn="strong_wolfe",
    )

    lbfgs_history = []

    def closure():
        # LBFGS re-evaluates the loss and its gradient through this closure.
        optimizer.zero_grad()
        loss = loss_fn()
        loss.backward()
        return loss

    for cycle in range(n_lbfgs_cycles):
        optimizer.step(closure)

        with torch.no_grad():
            current_loss = loss_fn().item()

        lbfgs_history.append({'cycle': cycle, 'loss': current_loss})

        if lbfgs_callback is not None:
            lbfgs_callback(cycle, current_loss)

        if verbose > 1:
            print(f" LBFGS Cycle {cycle+1}/{n_lbfgs_cycles}: Loss={current_loss:.4f}")

    with torch.no_grad():
        final_loss = loss_fn().item()

    if verbose > 0:
        print(f"LBFGS Complete: {sa_best_loss:.4f} -> {final_loss:.4f}")
        print(f"\n{'='*60}")
        print(f"Total improvement: {initial_loss:.4f} -> {final_loss:.4f}")
        if initial_loss != 0:
            improvement = (initial_loss - final_loss) / initial_loss * 100
            print(f" ({improvement:.1f}% reduction)")
        else:
            # A relative reduction is undefined for a zero starting loss;
            # computing it would raise ZeroDivisionError.
            print(" (relative reduction undefined: initial loss is zero)")
        print("=" * 60)

    return {
        'initial_loss': initial_loss,
        'sa_loss': sa_best_loss,
        'final_loss': final_loss,
        'sa_history': sa_history,
        'lbfgs_history': lbfgs_history,
    }