"""
LBFGS-based refinement framework for crystallographic structure refinement.
This module provides an LBFGS optimizer-based refinement approach which has been
shown to converge much faster than first-order optimizers (Adam, SGD, etc.).
LBFGS typically reaches near-convergence in just 1-2 macro cycles.
The refinement composes three pieces:

- A persistent :class:`~torchref.refinement.loss_state.LossState` built once via
  :meth:`~torchref.refinement.base_refinement.Refinement.complete_loss_state`.
- Persistent LBFGS optimizers (one per parameter group — xyz, adp+u+occupancy,
  and the joint set). These are created lazily on first use and reused across
  macro cycles so the construction cost is paid once. The cached optimizers
  are used by ``refine_everything`` and the policy-training trajectories;
  ``refine_xyz``, ``refine_adp``, ``refine_joint`` and the LossState-based
  ``refine_scaler`` construct a fresh LBFGS optimizer on every call.
- Scaler refinement, which runs its own local LossState + LBFGS step via
  :meth:`~torchref.scaling.scaler_base.ScalerBase.refine_lbfgs` and is invoked
  independently between body-parameter refinements. With
  ``use_lossstate_scaler=True`` (the default), ``refine_scaler`` instead steps
  the scaler parameters against the full refinement ``LossState``.

Each body step that reuses a cached optimizer clears that optimizer's LBFGS
curvature history before running. This is necessary because (a) the Hessian
approximation does not transfer across mode transitions (xyz → adp) and
(b) scaler updates between refine_xyz and refine_adp bump the parameters that
feed the xray target, so prior curvature information is stale.
"""
from typing import Optional
import numpy as np
import torch
from torchref.refinement.base_refinement import Refinement
[docs]
class LBFGSRefinement(Refinement):
    """
    LBFGS-based refinement subclass using the L-BFGS optimizer for fast convergence.

    L-BFGS (Limited-memory BFGS) is a quasi-Newton optimization method that
    approximates the Hessian matrix, leading to much faster convergence than
    first-order methods.

    Key advantages:

    - Converges in 1-2 macro cycles (vs 5+ for Adam)
    - Better final R-factors
    - More stable convergence
    - Automatically handles step size via line search

    Parameters
    ----------
    target_mode : str, optional
        X-ray target mode ('gaussian', 'ls', 'ml', or 'bhattacharyya').
        Default is 'bhattacharyya'.
    sigma_m_scale : float, optional
        Global multiplier for σ_m in the Bhattacharyya target only.
        Default is 1.0.
    use_lossstate_scaler : bool, optional
        Whether :meth:`refine_scaler` optimizes the scaler against the full
        refinement ``LossState`` (True, default) or delegates to the
        scaler's own ``refine_lbfgs`` (False).
    *args
        Passed to parent Refinement class.
    **kwargs
        Passed to parent Refinement class.

    Attributes
    ----------
    target_mode : str
        Current X-ray target mode.
    sigma_m_scale : float
        Value passed to the constructor.
    use_lossstate_scaler : bool
        Value passed to the constructor.
    LBFGS_DEFAULTS : dict
        Keyword arguments used for every ``torch.optim.LBFGS`` built by
        this class.

    Examples
    --------
    Basic usage::

        from torchref.refinement import LBFGSRefinement

        refinement = LBFGSRefinement(
            data_file='data.mtz',
            pdb='model.pdb',
            target_mode='ml'
        )
        refinement.refine(macro_cycles=2)
    """

    # Shared construction arguments for every torch.optim.LBFGS instance.
    LBFGS_DEFAULTS = dict(
        lr=1.0,
        max_iter=20,
        history_size=100,
        line_search_fn="strong_wolfe",
    )

    def __init__(
        self,
        *args,
        target_mode: str = "bhattacharyya",
        sigma_m_scale: float = 1.0,
        use_lossstate_scaler: bool = True,
        **kwargs,
    ):
        """
        Initialize LBFGS refinement.

        Parameters
        ----------
        target_mode : str, optional
            X-ray target mode ('gaussian', 'ls', 'ml', 'bhattacharyya').
            Default is 'bhattacharyya'.
        sigma_m_scale : float, optional
            Global multiplier for σ_m in the Bhattacharyya target only.
            Ignored for other target modes. Default 1.0.
        use_lossstate_scaler : bool, optional
            If True (default), :meth:`refine_scaler` uses the full
            :class:`LossState` with the body's x-ray target — so scaler and
            body steps share one consistent loss. If False, falls back to
            ``Scaler.refine_lbfgs`` which minimises a standalone
            ``nll_xray`` and can pull scales in a different direction than
            the body optimization.
        *args
            Passed to parent Refinement class.
        **kwargs
            Passed to parent Refinement class.
        """
        super().__init__(*args, **kwargs)
        # Assigned before the target mode is set so the attribute already
        # exists when set_xray_target_mode runs (original statement order).
        self.sigma_m_scale = sigma_m_scale
        # Set the X-ray target mode (uses the new target system from base class)
        self.set_xray_target_mode(target_mode)
        self.target_mode = target_mode
        self.use_lossstate_scaler = use_lossstate_scaler
        # Lazy persistent optimizers. Built on first access by
        # _lbfgs_for_types so that LBFGSRefinement instances without a
        # loaded model can still be constructed.
        self._persistent_optimizers: dict = {}

    def xray_loss(self):
        """
        Compute X-ray loss using the instantiated target.

        Returns
        -------
        torch.Tensor
            X-ray loss on work set.
        """
        return self.xray_loss_work()

    # =========================================================================
    # Persistent optimizer machinery
    # =========================================================================

    def _lbfgs_for_types(self, types: tuple) -> torch.optim.LBFGS:
        """Return a persistent LBFGS optimizer over the given parameter types.

        Optimizers are cached by the tuple of type names (e.g. ``("xyz",)``
        or ``("adp", "u", "occupancy")``) and reused across refinement
        calls. Curvature history must be cleared by the caller before each
        use via :meth:`_reset_lbfgs_history`.

        Parameters
        ----------
        types : tuple of str
            Parameter type names to include in the optimizer. Any of
            ``"xyz"``, ``"adp"``, ``"u"``, ``"occupancy"``.

        Returns
        -------
        torch.optim.LBFGS
            The cached optimizer, constructed on first call for this key.

        Raises
        ------
        RuntimeError
            If the model yields no parameters for ``types``.
        """
        key = tuple(types)
        opt = self._persistent_optimizers.get(key)
        if opt is None:
            params = self.model.parameters_of_types(types)
            if not params:
                raise RuntimeError(
                    f"No parameters found for types={types}; cannot build LBFGS."
                )
            opt = torch.optim.LBFGS(params, **self.LBFGS_DEFAULTS)
            self._persistent_optimizers[key] = opt
        return opt

    @staticmethod
    def _reset_lbfgs_history(optimizer: torch.optim.Optimizer) -> None:
        """Drop LBFGS curvature state so the next step starts from scratch.

        The LBFGS two-loop recursion depends on recent (s, y) pairs sampled
        under the *same* loss landscape. Between a refine_xyz and refine_adp
        call the active parameter set changes; between any two body calls
        the scaler's separate LBFGS has updated parameters the xray target
        reads from. Either way the stored curvature is stale and can produce
        bad search directions. Clearing state forces a fresh steepest-descent
        direction on the first inner iteration.
        """
        optimizer.state.clear()

    def _lbfgs_step_with_scaler(self, types: tuple, context: str):
        """Run one LBFGS step over body parameters of ``types`` plus the scaler.

        Builds the loss state, concatenates the model parameters of the
        requested types with every scaler parameter, constructs a fresh
        LBFGS optimizer with :attr:`LBFGS_DEFAULTS` and hands it to
        ``state.step``.

        Parameters
        ----------
        types : tuple of str
            Body parameter type names forwarded to
            ``self.model.parameters_of_types``.
        context : str
            Context label forwarded to ``state.step``.

        Returns
        -------
        LossState
            The state object the step was run on.
        """
        state = self.complete_loss_state()
        body = self.model.parameters_of_types(types)
        params = body + list(self.scaler.parameters())
        # A new optimizer per call means there is no curvature history to reset.
        optimizer = torch.optim.LBFGS(params, **self.LBFGS_DEFAULTS)
        state.step(optimizer, context=context)
        return state

    def _lbfgs_metrics(self):
        """Return ``self.collect_metrics()`` evaluated with autograd disabled."""
        with torch.no_grad():
            return self.collect_metrics()

    def _lbfgs_next_history_key(self, prefix: str) -> str:
        """Return the first ``"{prefix}_{i}"`` (i = 1, 2, ...) absent from history."""
        index = 1
        while f"{prefix}_{index}" in self.history:
            index += 1
        return f"{prefix}_{index}"

    # =========================================================================
    # Refinement Methods
    # =========================================================================

    def refine_scaler(self):
        """Refine scaler parameters against the full refinement loss.

        Builds the body :class:`LossState` via
        :meth:`complete_loss_state`, constructs a fresh LBFGS optimizer
        over ``list(self.scaler.parameters())``, and delegates to
        :meth:`LossState.step`. Because ``state.step`` disables
        ``requires_grad`` on every loss leaf outside the optimizer's
        intent set, xyz / adp / u / occupancy are pinned for the duration
        — only scaler parameters move.

        The critical property is that the x-ray target used here is the
        same one the body :meth:`refine_xyz` and :meth:`refine_adp` see.
        The legacy :meth:`Scaler.refine_lbfgs` minimises a standalone
        ``nll_xray`` + ``U^2`` penalty, which can pull scales in a
        different direction than a ``bhattacharyya`` or ``ml`` body loss
        and leaves the body to chase a scaler that disagrees with its own
        objective.

        When ``use_lossstate_scaler`` is False, fall back to the legacy
        :meth:`Scaler.refine_lbfgs` path.

        Returns
        -------
        LossState or dict
            ``LossState`` with history if ``use_lossstate_scaler`` is
            True, otherwise the metrics dict from
            :meth:`Scaler.refine_lbfgs`.
        """
        if not self.use_lossstate_scaler:
            return self.scaler.refine_lbfgs()
        state = self.complete_loss_state()
        scaler_params = list(self.scaler.parameters())
        if not scaler_params:
            # Nothing to optimize; return the untouched state.
            return state
        optimizer = torch.optim.LBFGS(scaler_params, **self.LBFGS_DEFAULTS)
        state.step(optimizer, context="lbfgs_refinement.refine_scaler")
        return state

    def refine_xyz(self):
        """Refine Cartesian coordinates jointly with scaler parameters.

        Scaler parameters (``log_scale``, ``U``, solvent terms) are
        included in the same LBFGS call as ``xyz``. The joint curvature
        lets xyz steps see the scaler as an anchor — residuals the scaler
        can absorb do not have to be chased by atomic motion — and the
        ``adp/scaler_U`` and ``adp/scaler_log_scale`` priors bite on every
        step, so nothing in the scaler drifts between refine_xyz and
        refine_adp calls.

        Returns
        -------
        LossState
            State with history containing before/after loss values.
        """
        return self._lbfgs_step_with_scaler(
            ("xyz",), "lbfgs_refinement.refine_xyz"
        )

    def refine_adp(self):
        """Refine ADP / U / occupancy jointly with scaler parameters.

        Scaler parameters (``log_scale``, ``U``, solvent terms) are
        included in the same LBFGS call as the ADP-block body parameters
        so the joint curvature can slide along the atomic-B / scaler-U
        degeneracy ridge together with the ``adp/scaler_U`` regularizer.
        XYZ is left frozen.

        Returns
        -------
        LossState
            State with history containing before/after loss values.
        """
        return self._lbfgs_step_with_scaler(
            ("adp", "u", "occupancy"), "lbfgs_refinement.refine_adp"
        )

    def refine_joint(self):
        """Joint LBFGS over every refinable parameter in one step.

        Optimizes ``xyz``, ``adp``, ``u``, ``occupancy``, and every
        scaler parameter (``log_scale``, anisotropic ``U``, solvent
        terms) in a single LBFGS call. The joint curvature couples all
        of them through the same x-ray target and through the
        ``adp/scaler_U`` / ``adp/scaler_log_scale`` priors — unlike
        alternating refine_xyz → refine_adp, there's no "frozen partner"
        in either half that could lock the step into a locally bad
        direction.

        Returns
        -------
        LossState
            State with history containing before/after loss values.
        """
        return self._lbfgs_step_with_scaler(
            ("xyz", "adp", "u", "occupancy"), "lbfgs_refinement.refine_joint"
        )

    def _refine_everything_lbfgs_single_cycle(self, nsteps: int = 1):
        """Joint LBFGS over xyz + adp + u + occupancy for one macro cycle.

        Used by :meth:`refine_everything`. Scaler is refined separately
        before the body step.

        Parameters
        ----------
        nsteps : int, optional
            Forwarded to ``state.run``. Default is 1.

        Returns
        -------
        LossState
            The state object the optimizer was run on.
        """
        # NOTE(review): this always takes the scaler's own refine_lbfgs path,
        # regardless of use_lossstate_scaler — confirm that is intended.
        self.scaler.refine_lbfgs()
        state = self.complete_loss_state()
        optimizer = self._lbfgs_for_types(("xyz", "adp", "u", "occupancy"))
        # The cached optimizer may carry curvature from a previous cycle.
        self._reset_lbfgs_history(optimizer)
        state.run(
            optimizer,
            nsteps=nsteps,
            context="lbfgs_refinement._refine_everything_lbfgs_single_cycle",
        )
        return state

    # =========================================================================
    # Training Loop for Policy Learning
    # =========================================================================

    def _lbfgs_policy_trajectory(
        self,
        policy_weighting,
        *,
        types: tuple,
        context: str,
        n_steps: int,
        pdb_id: str,
        structure_path: str,
        sf_path: str,
        seed: Optional[int],
        policy_version: Optional[str],
    ):
        """Shared driver for the policy-guided training trajectories.

        Seeds torch and numpy when ``seed`` is given, starts recording on
        ``policy_weighting``, refines the scaler once, then for each step
        builds the loss state, aggregates it under ``no_grad``, lets the
        policy apply its weights, resets the cached optimizer's history and
        runs one ``state.step``.

        Parameters
        ----------
        policy_weighting : PolicyComponentWeighting
            Policy weighting scheme driving the trajectory.
        types : tuple of str
            Body parameter types for the cached LBFGS optimizer.
        context : str
            Context label forwarded to ``state.step``.
        n_steps, pdb_id, structure_path, sf_path, seed, policy_version
            Same meaning as in :meth:`run_training_trajectory`.

        Returns
        -------
        TrajectoryData
            Whatever ``policy_weighting.stop_recording()`` returns, with
            ``total_time`` and ``success`` filled in.

        Raises
        ------
        Exception
            Any exception raised during the refinement steps is re-raised
            after the recording has been stopped and, when a trajectory is
            available, marked as failed.
        """
        import time

        # perf_counter is monotonic, so the elapsed time cannot be skewed by
        # wall-clock adjustments during the run.
        start_time = time.perf_counter()
        if seed is not None:
            torch.manual_seed(seed)
            np.random.seed(seed)
        policy_weighting.start_recording(
            pdb_id=pdb_id,
            structure_path=structure_path,
            sf_path=sf_path,
            seed=seed,
            policy_version=policy_version,
        )
        try:
            # NOTE(review): uses the scaler's own refine_lbfgs path regardless
            # of use_lossstate_scaler — confirm that is intended.
            self.scaler.refine_lbfgs()
            optimizer = self._lbfgs_for_types(types)
            for step in range(n_steps):
                if self.verbose > 1:
                    print(f"Step {step + 1}/{n_steps}")
                state = self.complete_loss_state()
                # Evaluate once to populate loss cache (feature extraction).
                with torch.no_grad():
                    state.aggregate()
                # Apply policy weights (this also records the step).
                policy_weighting.apply_to_state(state)
                # Policy just rewrote the weights, so the old LBFGS
                # curvature is for a different loss landscape — reset.
                self._reset_lbfgs_history(optimizer)
                state.step(optimizer, context=context)
                policy_weighting.increment_step()
        except Exception as exc:
            trajectory = policy_weighting.stop_recording()
            if trajectory is not None:
                trajectory.success = False
                trajectory.error_message = str(exc)
                trajectory.total_time = time.perf_counter() - start_time
            raise
        # Success bookkeeping lives outside the try block so a failure here
        # cannot trigger a second stop_recording() call in the handler.
        trajectory = policy_weighting.stop_recording()
        trajectory.total_time = time.perf_counter() - start_time
        trajectory.success = True
        return trajectory

    def run_training_trajectory(
        self,
        policy_weighting,
        n_steps: int = 10,
        pdb_id: str = "",
        structure_path: str = "",
        sf_path: str = "",
        seed: Optional[int] = None,
        policy_version: Optional[str] = None,
    ):
        """
        Run a training trajectory with policy-guided refinement.

        This method runs a sequence of refinement steps using a policy
        to select component weights. It records state-action-reward tuples
        for training the policy with AWR or similar algorithms. Only the
        ``xyz`` parameters are handed to the LBFGS optimizer.

        Parameters
        ----------
        policy_weighting : PolicyComponentWeighting
            Policy weighting scheme (should be in training mode with sampling).
        n_steps : int, optional
            Number of refinement steps in the trajectory (default: 10).
        pdb_id : str, optional
            PDB identifier for recording.
        structure_path : str, optional
            Path to structure file for recording.
        sf_path : str, optional
            Path to structure factors file for recording.
        seed : int, optional
            Random seed for reproducibility.
        policy_version : str, optional
            Version identifier of the policy being used.

        Returns
        -------
        TrajectoryData
            Complete trajectory with state-action-reward tuples.

        Raises
        ------
        Exception
            Any exception raised during the steps is re-raised after the
            trajectory has been marked as failed.
        """
        return self._lbfgs_policy_trajectory(
            policy_weighting,
            types=("xyz",),
            context="lbfgs_refinement.run_training_trajectory",
            n_steps=n_steps,
            pdb_id=pdb_id,
            structure_path=structure_path,
            sf_path=sf_path,
            seed=seed,
            policy_version=policy_version,
        )

    def run_training_trajectory_joint(
        self,
        policy_weighting,
        n_steps: int = 10,
        pdb_id: str = "",
        structure_path: str = "",
        sf_path: str = "",
        seed: Optional[int] = None,
        policy_version: Optional[str] = None,
    ):
        """
        Run a training trajectory with joint XYZ+ADP refinement.

        Similar to :meth:`run_training_trajectory` but refines xyz, adp,
        u, and occupancy together in each step. The LBFGS curvature
        history is reset at the start of each policy step because the
        weight updates invalidate any prior Hessian approximation.

        Parameters
        ----------
        policy_weighting : PolicyComponentWeighting
            Policy weighting scheme (should be in training mode).
        n_steps : int, optional
            Number of refinement steps (default: 10).
        pdb_id, structure_path, sf_path : str, optional
            Identifiers for trajectory recording.
        seed : int, optional
            Random seed for reproducibility.
        policy_version : str, optional
            Policy version identifier.

        Returns
        -------
        TrajectoryData
            Complete trajectory with state-action-reward tuples.

        Raises
        ------
        Exception
            Any exception raised during the steps is re-raised after the
            trajectory has been marked as failed.
        """
        return self._lbfgs_policy_trajectory(
            policy_weighting,
            types=("xyz", "adp", "u", "occupancy"),
            context="lbfgs_refinement.run_training_trajectory_joint",
            n_steps=n_steps,
            pdb_id=pdb_id,
            structure_path=structure_path,
            sf_path=sf_path,
            seed=seed,
            policy_version=policy_version,
        )

    def refine(self, macro_cycles=5):
        """
        Run full LBFGS refinement cycle (ADP + XYZ).

        Each macro cycle collects metrics, updates the solvent (when the
        scaler has one), flags outliers, runs :meth:`refine_xyz`, then
        :meth:`refine_adp`, then :meth:`refine_scaler`. Results are stored
        under a new ``"refinement_{i}"`` key in ``self.history``.

        Parameters
        ----------
        macro_cycles : int, optional
            Number of refinement cycles to perform. Default is 5.

        Returns
        -------
        dict
            History dictionary with all metrics per cycle (hierarchical structure).
        """
        master_key = self._lbfgs_next_history_key("refinement")
        self.history[master_key] = []
        # Clear logger history for fresh refinement
        self.logger.clear()

        for cycle in range(macro_cycles):
            cycle_dict = {
                "cycle": cycle + 1,
                "before_scaling": {},
                "after_scaling": {},
                "xyz": {"before": {}, "after": {}, "weights": {}},
                "adp": {"before": {}, "after": {}, "weights": {}},
            }
            if self.verbose > 0:
                print(f"\n{'='*60}")
                print(f"LBFGS Refinement - Cycle {cycle+1}/{macro_cycles}")
                print(f"{'='*60}")

            cycle_dict["before_scaling"] = self._lbfgs_metrics()

            if getattr(self.scaler, "solvent", None) is not None:
                self.scaler.solvent.update_solvent()
            # NOTE(review): indentation was lost in the source this was
            # recovered from; outlier detection is assumed to run on every
            # cycle, not only when a solvent model exists — confirm.
            self.reflection_data.find_outliers(
                self.model, self.scaler, z_threshold=5.0
            )

            after_scaling = self._lbfgs_metrics()
            cycle_dict["after_scaling"] = after_scaling
            if self.verbose > 0:
                print(
                    f"After scaling: Rwork={after_scaling['rwork']:.4f}, "
                    f"Rfree={after_scaling['rfree']:.4f}"
                )

            # --- coordinates ---
            self.logger.record(label="before_xyz")
            # Metrics are snapshots only, so they are collected without
            # autograd, matching the before/after-scaling snapshots above.
            cycle_dict["xyz"]["before"] = self._lbfgs_metrics()
            self.refine_xyz()
            self.logger.record(label="after_xyz")
            cycle_dict["xyz"]["after"] = self._lbfgs_metrics()
            if self.verbose > 0:
                self.logger.compare(
                    label_before="before_xyz",
                    label_after="after_xyz",
                    title="XYZ Refinement",
                )

            # --- ADP / U / occupancy ---
            self.logger.record(label="before_adp")
            cycle_dict["adp"]["before"] = self._lbfgs_metrics()
            self.refine_adp()
            self.logger.record(label="after_adp")
            cycle_dict["adp"]["after"] = self._lbfgs_metrics()
            if self.verbose > 0:
                self.logger.compare(
                    label_before="before_adp",
                    label_after="after_adp",
                    title="ADP Refinement",
                )

            self.refine_scaler()
            self.history[master_key].append(cycle_dict)

        return self.history

    def refine_everything(self, macro_cycles=5):
        """
        Run full LBFGS refinement cycle (ADP + XYZ) without weight screening.

        Unfreezes every model parameter, then for each macro cycle calls
        ``self.get_scales()`` followed by one joint LBFGS cycle over
        xyz + adp + u + occupancy. Results are stored under a new
        ``"refinement_everything_{i}"`` key in ``self.history``; the
        ``"initial"`` entry is overwritten with the starting metrics.

        Parameters
        ----------
        macro_cycles : int, optional
            Number of refinement cycles to perform. Default is 5.

        Returns
        -------
        dict
            History dictionary with all metrics per cycle (hierarchical structure).
        """
        self.model.unfreeze_all()
        master_key = self._lbfgs_next_history_key("refinement_everything")
        self.history[master_key] = []
        self.history["initial"] = self._lbfgs_metrics()
        self.logger.clear()

        for cycle in range(macro_cycles):
            # "before_scaling" is kept for structural parity with refine();
            # nothing in this method populates it.
            cycle_dict = {
                "cycle": cycle + 1,
                "before_scaling": {},
                "after_scaling": {},
                "after_refinement": {},
            }
            if self.verbose > 0:
                print(f"\n{'='*60}")
                print(f"LBFGS Refinement Everything - Cycle {cycle+1}/{macro_cycles}")
                print(f"{'='*60}")

            self.get_scales()
            self.logger.record(label="after_scaling")
            after_scaling = self._lbfgs_metrics()
            cycle_dict["after_scaling"] = after_scaling
            if self.verbose > 0:
                print(
                    f"After scaling: Rwork={after_scaling['rwork']:.4f}, "
                    f"Rfree={after_scaling['rfree']:.4f}"
                )

            self._refine_everything_lbfgs_single_cycle()
            self.logger.record(label="after_refinement")
            after_refinement = self._lbfgs_metrics()
            cycle_dict["after_refinement"] = after_refinement
            if self.verbose > 0:
                print(
                    f"After refinement: Rwork={after_refinement['rwork']:.4f}, "
                    f"Rfree={after_refinement['rfree']:.4f}"
                )
                # NOTE(review): indentation was lost in the source this was
                # recovered from; the comparison is assumed to be gated on
                # verbosity as it is in refine() — confirm.
                self.logger.compare(
                    label_before="after_scaling",
                    label_after="after_refinement",
                    title="Joint XYZ+ADP Refinement",
                )

            self.history[master_key].append(cycle_dict)

        return self.history