import math
import torch
from torch.optim import Optimizer
class LangevinSA(Optimizer):
    """BAOAB Langevin dynamics integrator with simulated annealing.

    Implements the BAOAB splitting scheme (Leimkuhler & Matthews, 2013)
    for gradient-guided exploration with thermodynamically correct noise.
    One gradient evaluation per step via staggered B steps: the first B
    half-kick of step k reuses the gradient stored by the final B
    half-kick of step k-1.

    Adaptive masses from EMA of squared gradients provide automatic scale
    invariance across all parameter types (xyz, B-factors, occupancies,
    torsions, etc.).

    Call :meth:`calibrate` before the main loop to probe parameter stiffness
    and warm up the adaptive masses without moving the structure.

    The temperature schedule is read from the first param group only.

    Args:
        params: Iterable of parameters or param groups.
        dt: Integration timestep.
        friction: Friction coefficient gamma. Controls thermalization speed.
        T_initial: Starting temperature.
        T_final: Final temperature.
        total_steps: Total number of annealing steps.
        cooling_schedule: 'exponential' or 'linear'.
        adaptive_masses: Use EMA of grad² as per-element masses.
        mass_beta: EMA decay for adaptive masses.
        mass_eps: Floor for adaptive masses (numerical stability).
        gradient_clip: Optional max gradient norm (per-parameter).
        max_step_size: Maximum displacement per element per full step.
            Velocities are clamped so |v * dt| <= max_step_size.
            ``None`` disables velocity clamping.
        loss_rollback_factor: A step is rolled back to the best
            configuration seen so far when its loss exceeds
            ``best + (loss_rollback_factor - 1) * |best|``. For positive
            losses this is simply ``loss > loss_rollback_factor * best``.
            The ``|best|`` form keeps the rule meaningful for negative
            losses. ``None`` disables this check. Non-finite losses are
            always rolled back.

    Raises:
        ValueError: On non-positive ``dt``, ``friction``, temperatures or
            ``max_step_size``, on ``total_steps < 1``, on an unknown
            ``cooling_schedule``, or on ``loss_rollback_factor < 1``.
    """

    def __init__(
        self,
        params,
        dt=0.01,
        friction=10.0,
        T_initial=2500.0,
        T_final=0.01,
        total_steps=1000,
        cooling_schedule="exponential",
        adaptive_masses=True,
        mass_beta=0.999,
        mass_eps=1e-8,
        gradient_clip=None,
        max_step_size=0.1,
        loss_rollback_factor=3.0,
    ):
        if dt <= 0:
            raise ValueError(f"dt must be positive, got {dt}")
        if friction <= 0:
            raise ValueError(f"friction must be positive, got {friction}")
        if T_initial <= 0 or T_final <= 0:
            raise ValueError("Temperatures must be positive")
        if total_steps < 1:
            raise ValueError(f"total_steps must be >= 1, got {total_steps}")
        if cooling_schedule not in ("exponential", "linear"):
            raise ValueError(
                f"cooling_schedule must be 'exponential' or 'linear', "
                f"got '{cooling_schedule}'"
            )
        if max_step_size is not None and max_step_size <= 0:
            raise ValueError(
                f"max_step_size must be positive or None, got {max_step_size}"
            )
        if loss_rollback_factor is not None and loss_rollback_factor < 1:
            raise ValueError(
                f"loss_rollback_factor must be >= 1 or None, "
                f"got {loss_rollback_factor}"
            )
        defaults = dict(
            dt=dt,
            friction=friction,
            T_initial=T_initial,
            T_final=T_final,
            total_steps=total_steps,
            cooling_schedule=cooling_schedule,
            adaptive_masses=adaptive_masses,
            mass_beta=mass_beta,
            mass_eps=mass_eps,
            gradient_clip=gradient_clip,
            max_step_size=max_step_size,
        )
        super().__init__(params, defaults)
        # Rollback is optimizer-wide: best-loss tracking spans all groups.
        self.loss_rollback_factor = loss_rollback_factor
        self._current_step = 0
        self._calibrated = False
        # Best finite loss seen by step() and the matching parameter
        # snapshot (keyed by id(p)); None until the first accepted step.
        self._best_loss = None
        self._best_params = None
        # Global lower bound on per-element masses set by calibrate();
        # None means adaptive masses are not floored.
        self._mass_floor = None

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------
    def _iter_all_params(self):
        """Yield every parameter of every group, in group order."""
        for group in self.param_groups:
            yield from group["params"]

    def _clone_params(self):
        """Return ``{id(p): copy of p.data}`` for all parameters."""
        return {id(p): p.data.clone() for p in self._iter_all_params()}

    def _restore_params(self, saved):
        """Restore parameters from a ``_clone_params`` snapshot.

        Uses ``p.data`` assignment rather than an in-place copy, so the
        data_ptr changes (the original comments rely on this to invalidate
        a CachedForwardMixin cache). The snapshot is cloned so it stays
        reusable.
        """
        for p in self._iter_all_params():
            p.data = saved[id(p)].clone()

    def _reset_dynamics(self):
        """Zero all velocities and forget stored gradients."""
        for p in self._iter_all_params():
            state = self.state[p]
            state["velocity"].zero_()
            state["prev_grad"] = None

    # ------------------------------------------------------------------
    # Temperature schedule
    # ------------------------------------------------------------------
    def _get_temperature(self):
        """Compute the temperature at the current step.

        The schedule fraction ``t`` goes from 0 at step 0 to 1 at step
        ``total_steps - 1``, and is held at 1 afterwards. Only the first
        param group's schedule settings are used.
        """
        group = self.param_groups[0]
        T_i = group["T_initial"]
        T_f = group["T_final"]
        N = group["total_steps"]
        t = min(self._current_step, N - 1) / max(N - 1, 1)
        if group["cooling_schedule"] == "exponential":
            # Geometric interpolation: T_i * (T_f / T_i) ** t
            return T_i * math.exp(math.log(T_f / T_i) * t)
        # linear
        return T_i + (T_f - T_i) * t

    @property
    def temperature(self):
        """Current temperature from the annealing schedule."""
        return self._get_temperature()

    @property
    def current_step(self):
        """Number of step() calls made so far (rolled-back steps included)."""
        return self._current_step

    @property
    def total_steps(self):
        """Length of the annealing schedule (from the first param group)."""
        return self.param_groups[0]["total_steps"]

    @property
    def kinetic_energy(self):
        """Sum of 0.5 * m * v^2 over all parameters (diagnostic).

        Parameters without a velocity yet are skipped. A missing mass is
        treated as unit mass.
        """
        ke = 0.0
        for p in self._iter_all_params():
            state = self.state[p]
            if "velocity" not in state:
                continue
            v = state["velocity"]
            m = state.get("mass")
            if m is None:
                ke += 0.5 * (v * v).sum().item()
            else:
                ke += 0.5 * (m * v * v).sum().item()
        return ke

    # ------------------------------------------------------------------
    # Calibration: probe stiffness then rollback
    # ------------------------------------------------------------------
    @torch.no_grad()
    def calibrate(self, closure, n_steps=10, verbose=True):
        """Probe parameter stiffness over n_steps, then rollback.

        Applies small cumulative random perturbations (std 1e-3 per step)
        and collects gradient statistics. It then sets the adaptive masses
        to ``sqrt(mean grad²) + mass_eps`` and floors them at 1% of the
        global median of the informed masses. The floor is remembered and
        re-applied by every later adaptive-mass update in :meth:`step`.
        Finally it restores all parameters to their original values and
        initialises velocities from Maxwell-Boltzmann with the new masses.

        Args:
            closure: Same closure as for ``step()`` — must zero_grad,
                compute loss, call backward, and return loss.
            n_steps: Number of probing steps. Steps with a non-finite loss
                do not contribute gradient statistics.
            verbose: Print a summary of the resulting masses.
        """
        params = list(self._iter_all_params())
        snapshots = {id(p): p.data.clone() for p in params}
        grad_sq_sum = {id(p): torch.zeros_like(p.data) for p in params}
        n_valid = {id(p): 0 for p in params}

        for _ in range(n_steps):
            # Perturb via data assignment (not in-place add) so that
            # data_ptr changes and CachedForwardMixin sees a cache miss.
            for p in params:
                p.data = p.data + torch.randn_like(p.data) * 1e-3
            with torch.enable_grad():
                loss = closure()
            if not torch.isfinite(loss):
                continue
            for p in params:
                if p.grad is not None:
                    grad_sq_sum[id(p)].add_(p.grad.detach() ** 2)
                    n_valid[id(p)] += 1

        # Restore via assignment to change data_ptr (cache invalidation).
        self._restore_params(snapshots)

        # --- compute masses from calibration ---
        T = self._get_temperature()
        # First pass: raw masses, plus every element that saw a gradient.
        informed = []
        for group in self.param_groups:
            eps = group["mass_eps"]
            for p in group["params"]:
                state = self.state[p]
                avg_grad_sq = grad_sq_sum[id(p)] / max(n_valid[id(p)], 1)
                mass = avg_grad_sq.sqrt() + eps
                state["grad_sq_avg"] = avg_grad_sq
                state["mass"] = mass
                nonzero = mass[avg_grad_sq > 0]
                if nonzero.numel() > 0:
                    informed.append(nonzero)

        # Global median of informed masses -> 1% of it is the mass floor.
        median_mass = torch.cat(informed).median().item() if informed else 1.0
        mass_floor = 0.01 * median_mass
        self._mass_floor = mass_floor

        # Second pass: clamp per-element masses and init velocities.
        n_clamped = 0
        for p in params:
            state = self.state[p]
            m = state["mass"]
            n_clamped += (m < mass_floor).sum().item()
            m.clamp_(min=mass_floor)
            # Maxwell-Boltzmann velocity with correct masses
            state["velocity"] = torch.randn_like(p.data) * (T / m).sqrt()
            state["prev_grad"] = None
        self._calibrated = True

        if not verbose:
            return
        total_elements = sum(p.numel() for p in params)
        print(
            f"LangevinSA calibrated over {n_steps} steps. "
            f"Median mass={median_mass:.2e}, floor={mass_floor:.2e}, "
            f"clamped {int(n_clamped)}/{total_elements} elements. "
            f"Mass range per param:"
        )
        for p in params:
            m = self.state[p]["mass"]
            print(
                f"  [{p.shape}] mass: "
                f"min={m.min().item():.2e}, "
                f"median={m.median().item():.2e}, "
                f"max={m.max().item():.2e}"
            )

    # ------------------------------------------------------------------
    # BAOAB step
    # ------------------------------------------------------------------
    def _baoa(self, T, first_step):
        """Apply the B-A-O-A sub-steps in place for every parameter.

        B uses the gradient stored by the previous step's final kick (it is
        skipped on the very first step, or when no gradient is stored). O
        is the exact Ornstein-Uhlenbeck update at temperature ``T``.
        Velocities are clamped after O when ``max_step_size`` is set.
        """
        for group in self.param_groups:
            dt = group["dt"]
            half_dt = 0.5 * dt
            alpha = math.exp(-group["friction"] * dt)
            noise_frac = 1.0 - alpha * alpha
            max_step = group["max_step_size"]
            max_v = None if max_step is None else max_step / dt
            adaptive = group["adaptive_masses"]
            for p in group["params"]:
                state = self.state[p]
                # Lazily initialise state (calibrate() may have done this).
                if "velocity" not in state:
                    state["prev_grad"] = None
                    if adaptive:
                        state["grad_sq_avg"] = torch.ones_like(p.data)
                    state["mass"] = None
                    state["velocity"] = torch.randn_like(p.data) * math.sqrt(T)
                v = state["velocity"]
                prev_grad = state["prev_grad"]
                m = state["mass"]

                # B: half-kick from stored gradient
                if not first_step and prev_grad is not None:
                    kick = prev_grad if m is None else prev_grad / m
                    v.add_(kick, alpha=-half_dt)

                # A: half-drift (p.add_ bumps _version so caches notice)
                p.add_(v, alpha=half_dt)

                # O: Ornstein-Uhlenbeck thermostat
                noise = torch.randn_like(v)
                if m is None:
                    sigma = math.sqrt(T * noise_frac)
                else:
                    sigma = ((T / m) * noise_frac).sqrt()
                v.mul_(alpha).add_(noise * sigma)

                # Velocity clamping: bound displacement per step
                if max_v is not None:
                    v.clamp_(-max_v, max_v)

                # A: half-drift
                p.add_(v, alpha=half_dt)

    def _should_rollback(self, loss):
        """Return True if ``loss`` is non-finite or exceeds the rollback threshold.

        The threshold is ``best + (factor - 1) * |best|``. This equals
        ``factor * best`` for positive best losses and stays meaningful
        when the best loss is zero or negative.
        """
        if not torch.isfinite(loss):
            return True
        factor = self.loss_rollback_factor
        if factor is None or self._best_loss is None:
            return False
        best = self._best_loss
        return loss.item() > best + (factor - 1.0) * abs(best)

    def _final_kick(self):
        """Final B half-kick from the fresh gradients; store them for the next step.

        The gradient is optionally norm-clipped per parameter. When adaptive
        masses are on, the grad² EMA and masses are updated first, and the
        masses are floored at the calibration floor if one exists.
        """
        floor = self._mass_floor
        for group in self.param_groups:
            half_dt = 0.5 * group["dt"]
            adaptive = group["adaptive_masses"]
            beta = group["mass_beta"]
            eps = group["mass_eps"]
            clip = group["gradient_clip"]
            for p in group["params"]:
                state = self.state[p]
                if p.grad is None:
                    # No gradient this step: forget the old one so it is
                    # not re-applied by the next step's first B kick.
                    state["prev_grad"] = None
                    continue
                grad = p.grad.detach()

                # Optional gradient clipping
                if clip is not None:
                    grad_norm = grad.norm()
                    if grad_norm > clip:
                        grad = grad * (clip / grad_norm)

                # Update adaptive mass (floored like in calibrate()).
                if adaptive:
                    sq = state["grad_sq_avg"]
                    sq.mul_(beta).addcmul_(grad, grad, value=1.0 - beta)
                    mass = sq.sqrt() + eps
                    if floor is not None:
                        mass.clamp_(min=floor)
                    state["mass"] = mass

                # B: half-kick from new gradient
                m = state["mass"]
                kick = grad if m is None else grad / m
                state["velocity"].add_(kick, alpha=-half_dt)

                # Store gradient for next step's first B. Clone because
                # p.grad may be zeroed in place by the next closure.
                state["prev_grad"] = grad.clone()

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one BAOAB Langevin dynamics step.

        Tracks the best-loss configuration. If the loss is non-finite, or
        exceeds the threshold set by ``loss_rollback_factor`` relative to
        the best loss seen so far, the step is rolled back: parameters are
        restored to the best configuration (or to the pre-step positions
        if none exists yet), velocities are zeroed and stored gradients are
        discarded. This prevents the dynamics from permanently damaging
        the structure while still allowing uphill exploration.

        Args:
            closure: A callable that re-evaluates the model and returns the
                loss. The closure must call ``loss.backward()`` before
                returning.

        Returns:
            The loss value from the closure evaluation.

        Raises:
            RuntimeError: If ``closure`` is None.
        """
        if closure is None:
            raise RuntimeError("LangevinSA requires a closure")
        T = self._get_temperature()
        first_step = self._current_step == 0

        # Pre-step positions are only needed as a rollback target when no
        # best configuration has been recorded yet.
        snapshots = self._clone_params() if self._best_params is None else None

        self._baoa(T, first_step)

        # ---- Evaluate loss + gradient at new position ----
        with torch.enable_grad():
            loss = closure()

        # ---- NaN / loss-explosion protection ----
        if self._should_rollback(loss):
            target = snapshots if self._best_params is None else self._best_params
            self._restore_params(target)
            self._reset_dynamics()
            self._current_step += 1
            return loss

        # ---- Track best configuration ----
        loss_val = loss.item()
        if self._best_loss is None or loss_val < self._best_loss:
            self._best_loss = loss_val
            self._best_params = self._clone_params()

        self._final_kick()
        self._current_step += 1
        return loss