torchref.refinement.lbfgs_refinement module

LBFGS-based refinement framework for crystallographic structure refinement.

This module provides an LBFGS optimizer-based refinement approach, which in practice converges much faster than first-order optimizers (Adam, SGD, etc.). LBFGS typically reaches near-convergence in just 1-2 macro cycles.

The refinement composes three pieces:

  • A persistent LossState built once via complete_loss_state().

  • Persistent LBFGS optimizers (one per parameter group — xyz, adp+u+occupancy, and the joint set). These are created lazily on first use and reused across macro cycles so the construction cost is paid once.

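  • Scaler refinement via refine_scaler(), invoked independently between body-parameter refinements. By default it runs LossState.step() with a fresh LBFGS optimizer over the scaler parameters; with use_lossstate_scaler=False it instead runs the scaler's own local loss and LBFGS step via Scaler.refine_lbfgs().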
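The persistent, lazily created optimizers in the second item can be sketched as a small cache keyed by parameter-group name. This is a minimal sketch under assumptions: OptimizerCache and the group names are hypothetical, the tensors are placeholders for real model parameters, and LBFGS_DEFAULTS is copied from the class attribute documented further down on the assumption that it is passed unchanged to torch.optim.LBFGS.

```python
import torch

# Copied from the documented LBFGS_DEFAULTS class attribute; forwarding it
# unchanged to torch.optim.LBFGS is an assumption of this sketch.
LBFGS_DEFAULTS = {"history_size": 100, "line_search_fn": "strong_wolfe", "lr": 1.0, "max_iter": 20}


class OptimizerCache:
    """Hypothetical helper: one persistent LBFGS optimizer per parameter group."""

    def __init__(self, groups):
        # groups maps a group name to the list of tensors that group optimizes.
        self._groups = groups
        self._optimizers = {}

    def get(self, name):
        # Construct the optimizer on the first request only; later calls
        # (e.g. in later macro cycles) return the same object.
        if name not in self._optimizers:
            self._optimizers[name] = torch.optim.LBFGS(self._groups[name], **LBFGS_DEFAULTS)
        return self._optimizers[name]


# Placeholder tensors standing in for coordinates and B-factors.
xyz = torch.zeros(10, 3, requires_grad=True)
b_factors = torch.full((10,), 20.0, requires_grad=True)

cache = OptimizerCache({"xyz": [xyz], "adp": [b_factors], "joint": [xyz, b_factors]})
first = cache.get("xyz")
again = cache.get("xyz")  # same object, no second construction
```

The same tensor may appear in several groups (here xyz is in both "xyz" and "joint"); PyTorch allows a tensor to be registered with more than one optimizer.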

Each body step clears the LBFGS curvature history of its own optimizer before running. This is necessary because (a) the Hessian approximation does not transfer across mode transitions (xyz → adp), and (b) scaler updates between refine_xyz and refine_adp change parameters that feed the x-ray target, so the stored curvature information is stale.
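One way to clear the curvature history, assuming the optimizers are plain torch.optim.LBFGS instances, is to empty optimizer.state: torch.optim.LBFGS keeps its stored direction/step pairs (old_dirs, old_stps, ro, H_diag, ...) there, so the next step() starts again from a steepest-descent direction. reset_lbfgs_history is a hypothetical helper name; torchref may implement the reset differently.

```python
import torch


def reset_lbfgs_history(optimizer: torch.optim.LBFGS) -> None:
    # Hypothetical helper: torch.optim.LBFGS stores its whole curvature history
    # in optimizer.state, so emptying it makes the next step() start fresh.
    optimizer.state.clear()


x = torch.tensor([3.0, -2.0], requires_grad=True)
target = torch.tensor([1.0, 2.0])
opt = torch.optim.LBFGS([x], history_size=100, line_search_fn="strong_wolfe", max_iter=20)


def closure():
    opt.zero_grad()
    loss = ((x - target) ** 2).sum()
    loss.backward()
    return loss


opt.step(closure)                  # builds up curvature history
had_history = len(opt.state) > 0
reset_lbfgs_history(opt)           # e.g. before switching from xyz to adp mode
cleared = len(opt.state) == 0
opt.step(closure)                  # restarts from a steepest-descent direction
```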

class torchref.refinement.lbfgs_refinement.LBFGSRefinement(*args, target_mode='bhattacharyya', sigma_m_scale=1.0, use_lossstate_scaler=True, **kwargs)

Bases: Refinement

LBFGS-based refinement subclass using the L-BFGS optimizer for fast convergence.

L-BFGS (Limited-memory BFGS) is a quasi-Newton optimization method that approximates the inverse Hessian from a short history of recent position and gradient differences, which leads to much faster convergence than first-order methods.
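The inverse-Hessian approximation is usually applied with the standard two-loop recursion, which turns the stored pairs s_i = x_{i+1} - x_i and y_i = g_{i+1} - g_i into a search direction without ever forming a matrix. The pure-Python sketch below is the textbook algorithm, not torchref code; torch.optim.LBFGS uses the same idea internally, with history_size bounding the number of stored pairs. The function names are illustrative.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def two_loop_direction(grad, s_hist, y_hist):
    """Return -H @ grad using the L-BFGS two-loop recursion.

    s_hist[i] = x_{i+1} - x_i and y_hist[i] = g_{i+1} - g_i, oldest pair first.
    """
    q = list(grad)
    rhos = [1.0 / dot(y, s) for s, y in zip(s_hist, y_hist)]
    alphas = []
    # First loop: newest pair to oldest.
    for s, y, rho in reversed(list(zip(s_hist, y_hist, rhos))):
        alpha = rho * dot(s, q)
        alphas.append(alpha)
        q = [qi - alpha * yi for qi, yi in zip(q, y)]
    # Initial inverse-Hessian scale gamma = s'y / y'y from the newest pair.
    gamma = dot(s_hist[-1], y_hist[-1]) / dot(y_hist[-1], y_hist[-1]) if s_hist else 1.0
    r = [gamma * qi for qi in q]
    # Second loop: oldest pair to newest (alphas were stored newest first).
    for (s, y, rho), alpha in zip(zip(s_hist, y_hist, rhos), reversed(alphas)):
        beta = rho * dot(y, r)
        r = [ri + (alpha - beta) * si for ri, si in zip(r, s)]
    return [-ri for ri in r]


# Quadratic with Hessian diag(1, 10): two exact secant pairs recover -H^{-1} g.
direction = two_loop_direction([1.0, 1.0], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 10.0]])
```

With no stored pairs the recursion reduces to plain steepest descent, which is why clearing the history restarts the optimizer from a gradient step.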

Key advantages:

  • Typically converges in 1-2 macro cycles (vs 5+ for Adam)

  • Typically reaches better final R-factors

  • More stable convergence

  • Chooses the step size automatically via line search (strong Wolfe in LBFGS_DEFAULTS)
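The line-search behaviour can be seen with a plain torch.optim.LBFGS on an ill-conditioned quadratic. This is a sketch, assuming the LBFGS_DEFAULTS class attribute documented further down is forwarded unchanged to torch.optim.LBFGS; the quadratic stands in for the real refinement loss. lr=1.0 is only the initial trial step, and the strong-Wolfe search adjusts it.

```python
import torch

# Same values as the documented LBFGS_DEFAULTS class attribute (assumed to be
# forwarded unchanged to torch.optim.LBFGS).
LBFGS_DEFAULTS = {"history_size": 100, "line_search_fn": "strong_wolfe", "lr": 1.0, "max_iter": 20}

torch.manual_seed(0)
curvatures = torch.logspace(0, 3, 6, dtype=torch.float64)  # 1 ... 1000: ill-conditioned
target = torch.randn(6, dtype=torch.float64)
x = torch.zeros(6, dtype=torch.float64, requires_grad=True)

optimizer = torch.optim.LBFGS([x], **LBFGS_DEFAULTS)


def closure():
    # The strong-Wolfe line search may evaluate the loss several times per
    # iteration, so the forward and backward passes live inside a closure.
    optimizer.zero_grad()
    loss = 0.5 * (curvatures * (x - target) ** 2).sum()
    loss.backward()
    return loss


for cycle in range(3):  # each step() call runs up to max_iter L-BFGS iterations
    optimizer.step(closure)
```

Because closure() can run several times per iteration, any side effects inside it (logging, counters) also run more than once per step.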

Parameters:
  • target_mode (str, optional) – X-ray target mode (‘gaussian’, ‘ls’, ‘ml’, or ‘bhattacharyya’). Default is ‘bhattacharyya’.

  • *args – Passed to parent Refinement class.

  • **kwargs – Passed to parent Refinement class.

target_mode

Current X-ray target mode.

Type:

str

Examples

Basic usage:

from torchref.refinement import LBFGSRefinement

refinement = LBFGSRefinement(
    data_file='data.mtz',
    pdb='model.pdb',
    target_mode='ml'
)
refinement.refine(macro_cycles=2)

LBFGS_DEFAULTS = {'history_size': 100, 'line_search_fn': 'strong_wolfe', 'lr': 1.0, 'max_iter': 20}

__init__(*args, target_mode='bhattacharyya', sigma_m_scale=1.0, use_lossstate_scaler=True, **kwargs)

Initialize LBFGS refinement.

Parameters:
  • target_mode (str, optional) – X-ray target mode (‘gaussian’, ‘ls’, ‘ml’, ‘bhattacharyya’). Default is ‘bhattacharyya’.

  • sigma_m_scale (float, optional) – Global multiplier for σ_m in the Bhattacharyya target only. Ignored for other target modes. Default 1.0.

  • use_lossstate_scaler (bool, optional) – If True (default), refine_scaler() uses the full LossState with the body’s x-ray target, so scaler and body steps share one consistent loss. If False, it falls back to Scaler.refine_lbfgs(), which minimizes a standalone nll_xray and can pull the scales in a different direction than the body optimization.

  • *args – Passed to parent Refinement class.

  • **kwargs – Passed to parent Refinement class.
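To show what a σ_m multiplier can do in a Bhattacharyya-type target, the sketch below uses the closed-form Bhattacharyya distance between two 1-D Gaussians: one for the observation (width σ_obs) and one for the model (width sigma_m_scale · σ_m). This is an illustrative assumption only; torchref's actual Bhattacharyya target is not shown here and may use different distributions. The function name is hypothetical.

```python
import math


def bhattacharyya_gaussian(mu_obs, sigma_obs, mu_calc, sigma_m, sigma_m_scale=1.0):
    """Hypothetical: Bhattacharyya distance between N(mu_obs, sigma_obs^2)
    and N(mu_calc, (sigma_m_scale * sigma_m)^2). Both widths must be > 0."""
    var_obs = sigma_obs ** 2
    var_model = (sigma_m_scale * sigma_m) ** 2
    total = var_obs + var_model
    # Mean term: squared misfit, down-weighted by the combined variance.
    mean_term = (mu_obs - mu_calc) ** 2 / (4.0 * total)
    # Variance term: penalizes mismatched widths, so inflating sigma_m is not free.
    var_term = 0.5 * math.log(total / (2.0 * math.sqrt(var_obs * var_model)))
    return mean_term + var_term


same = bhattacharyya_gaussian(5.0, 1.0, 5.0, 1.0)          # identical Gaussians
misfit = bhattacharyya_gaussian(0.0, 1.0, 2.0, 1.0)        # mean_term only
inflated = bhattacharyya_gaussian(0.0, 1.0, 2.0, 1.0, sigma_m_scale=2.0)
```

In this form a larger sigma_m_scale softens the penalty on a given misfit, while the log term keeps the distance from collapsing toward zero as σ_m grows.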

xray_loss()

Compute X-ray loss using the instantiated target.

Returns:

X-ray loss on work set.

Return type:

torch.Tensor
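The "work set" is the set of reflections not held out for the R-free test set. A minimal sketch of a least-squares ('ls'-style) loss restricted to the work set follows; the function, its arguments, and the boolean free-flag convention are hypothetical, and the real xray_loss() operates on torch tensors with the instantiated target.

```python
def work_set_ls_loss(f_obs, f_calc, free_flags, scale=1.0):
    """Hypothetical least-squares x-ray loss over the work set only.

    free_flags[i] is True for reflections reserved for R-free; they are skipped
    so the free set stays an unbiased check on the model.
    """
    residuals = [
        (fo - scale * fc) ** 2
        for fo, fc, is_free in zip(f_obs, f_calc, free_flags)
        if not is_free
    ]
    return sum(residuals) / len(residuals)


loss = work_set_ls_loss([10.0, 20.0, 30.0], [10.0, 18.0, 0.0], [False, False, True])
```

The third reflection has a large misfit but is flagged free, so it does not contribute to the loss.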

refine_scaler()

Refine scaler parameters against the full refinement loss.

Builds the body LossState via complete_loss_state(), constructs a fresh LBFGS optimizer over list(self.scaler.parameters()), and delegates to LossState.step(). Because LossState.step() disables requires_grad on every loss leaf outside the optimizer’s intent set, the xyz, adp, u, and occupancy parameters are pinned for the duration and only the scaler parameters move.
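The pinning pattern can be sketched in plain PyTorch: disable requires_grad on every leaf the optimizer does not own, step, then restore the flags. This is a hypothetical helper illustrating the idea, not LossState.step() itself; the tensors stand in for body and scaler parameters. Leaves outside the optimizer would not be updated anyway, but freezing them also skips computing their gradients.

```python
import torch


def step_only_optimizer_params(optimizer, all_leaves, closure):
    """Hypothetical: freeze leaves the optimizer does not own, step, then restore."""
    owned = {id(p) for group in optimizer.param_groups for p in group["params"]}
    saved = [(leaf, leaf.requires_grad) for leaf in all_leaves]
    try:
        for leaf in all_leaves:
            if id(leaf) not in owned:
                leaf.requires_grad_(False)
        return optimizer.step(closure)
    finally:
        # Restore the original flags even if the step raises.
        for leaf, flag in saved:
            leaf.requires_grad_(flag)


xyz = torch.zeros(3, requires_grad=True)        # stands in for body parameters
log_scale = torch.zeros(1, requires_grad=True)  # stands in for a scaler parameter
opt = torch.optim.LBFGS([log_scale], line_search_fn="strong_wolfe")


def closure():
    opt.zero_grad()
    loss = ((xyz - 1.0) ** 2).sum() + ((log_scale - 0.5) ** 2).sum()
    loss.backward()
    return loss


step_only_optimizer_params(opt, [xyz, log_scale], closure)
```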

The critical property is that the x-ray target used here is the same one that the body steps refine_xyz() and refine_adp() see. The legacy Scaler.refine_lbfgs() minimizes a standalone nll_xray + U^2 penalty, which can pull the scales in a different direction than a bhattacharyya or ml body loss and leaves the body chasing a scaler that disagrees with its own objective.

When use_lossstate_scaler is False, this method falls back to the legacy Scaler.refine_lbfgs() path.

Returns:

LossState with history if use_lossstate_scaler is True, otherwise the metrics dict from Scaler.refine_lbfgs().

Return type:

LossState or dict

refine_xyz()

Refine Cartesian coordinates jointly with scaler parameters.

Scaler parameters (log_scale, U, solvent terms) are included in the same LBFGS call as xyz. The joint curvature lets xyz steps treat the scaler as an anchor: residuals the scaler can absorb do not have to be chased by atomic motion. The adp/scaler_U and adp/scaler_log_scale priors also act on every step, so no scaler parameter drifts between refine_xyz and refine_adp calls.
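A toy version of the joint step: one LBFGS optimizer over a "position" parameter and a log-scale parameter, with a small quadratic prior on the log-scale. All names, the data, and the prior weight are hypothetical stand-ins for xyz, the scaler, and the adp/scaler_log_scale prior. Because the scale absorbs the overall amplitude mismatch, the position converges to its true value instead of being distorted to compensate.

```python
import torch

t = torch.linspace(0.0, 1.0, 50, dtype=torch.float64)
y_obs = 2.0 * (t - 0.5)  # "observed" data: true scale 2, true shift 0.5

shift = torch.zeros(1, dtype=torch.float64, requires_grad=True)      # stands in for xyz
log_scale = torch.zeros(1, dtype=torch.float64, requires_grad=True)  # stands in for a scaler term
prior_weight = 1e-3                                                  # hypothetical prior strength

# A single optimizer over both blocks, so the curvature couples them.
optimizer = torch.optim.LBFGS([shift, log_scale], line_search_fn="strong_wolfe", max_iter=20)


def closure():
    optimizer.zero_grad()
    y_calc = torch.exp(log_scale) * (t - shift)
    # Data term plus a weak prior pulling log_scale toward 0.
    loss = ((y_calc - y_obs) ** 2).mean() + prior_weight * (log_scale ** 2).sum()
    loss.backward()
    return loss


for _ in range(5):
    optimizer.step(closure)
```

The prior shifts the fitted scale slightly below 2, which is the intended trade-off: it keeps the scale from wandering when the data constrain it weakly.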

Returns:

State with history containing before/after loss values.

Return type:

LossState

refine_adp()

Refine ADP / U / occupancy jointly with scaler parameters.

Scaler parameters (log_scale, U, solvent terms) are included in the same LBFGS call as the ADP-block body parameters so the joint curvature can slide along the atomic-B / scaler-U degeneracy ridge together with the adp/scaler_U regularizer. XYZ is left frozen.
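The atomic-B / scaler-U degeneracy can be checked numerically with a toy isotropic structure factor, where each atom carries a Debye-Waller factor exp(-B s²/4) (s = 1/d) and an overall term exp(-B_overall s²/4) plays the role of the scaler (for an isotropic term, B = 8π²U). Adding δ to every atomic B and subtracting δ from the overall B leaves every F unchanged, so the data alone cannot fix the split; a regularizer such as adp/scaler_U is what selects a point on the ridge. This 1-D model is an illustration, not torchref's structure-factor code.

```python
import cmath
import math


def structure_factor(h, atoms, b_overall, cell=10.0):
    """Toy isotropic structure factor for a 1-D 'crystal' with repeat length `cell`.

    atoms is a list of (f0, x, b_atom): scattering factor, fractional coordinate,
    atomic B. b_overall stands in for the scaler's overall B term.
    """
    s = abs(h) / cell  # s = 1/d for reflection h
    total = 0j
    for f0, x, b_atom in atoms:
        total += f0 * math.exp(-b_atom * s * s / 4.0) * cmath.exp(2j * math.pi * h * x)
    return math.exp(-b_overall * s * s / 4.0) * total


atoms = [(6.0, 0.10, 15.0), (7.0, 0.35, 20.0), (8.0, 0.70, 25.0)]
delta = 5.0
shifted = [(f0, x, b + delta) for f0, x, b in atoms]

# Every reflection is identical: the shift moves along the degeneracy ridge.
differences = [
    abs(structure_factor(h, atoms, 10.0) - structure_factor(h, shifted, 10.0 - delta))
    for h in range(1, 6)
]
```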

Returns:

State with history containing before/after loss values.

Return type:

LossState

refine_joint()

Joint LBFGS over every refinable parameter in one step.

Optimizes xyz, adp, u, occupancy, and every scaler parameter (log_scale, anisotropic U, solvent terms) in a single LBFGS call. The joint curvature couples all of them through the same x-ray target and through the adp/scaler_U and adp/scaler_log_scale priors. Unlike alternating refine_xyz → refine_adp, neither half has a “frozen partner” that could lock the step into a locally bad direction.
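Why alternating blocks can stall is visible on a coupled two-variable quadratic f(x, y) = ½(x² + y²) + ρxy − x − y. Exact minimization over one block with the other frozen shrinks the error only by ρ² per sweep, while a joint step that uses the full curvature lands on the minimum at once. The joint step here is an exact Newton solve, the ideal that L-BFGS approximates, not L-BFGS itself; the coupling ρ is a stand-in for strongly correlated parameter blocks.

```python
def alternating_sweeps(rho, tol=1e-8, max_sweeps=100000):
    """Exact block-coordinate minimization of 0.5*(x^2 + y^2) + rho*x*y - x - y."""
    x_star = y_star = 1.0 / (1.0 + rho)
    x = y = 0.0
    for sweep in range(1, max_sweeps + 1):
        x = 1.0 - rho * y  # minimize over x with y frozen
        y = 1.0 - rho * x  # minimize over y with x frozen
        if max(abs(x - x_star), abs(y - y_star)) < tol:
            return sweep
    return max_sweeps


def joint_newton(rho):
    """One joint Newton step: solve [[1, rho], [rho, 1]] @ (x, y) = (1, 1)."""
    det = 1.0 - rho * rho
    x = (1.0 - rho) / det
    y = (1.0 - rho) / det
    return x, y


weak = alternating_sweeps(0.1)     # weak coupling: a handful of sweeps
strong = alternating_sweeps(0.99)  # strong coupling: hundreds of sweeps
x_joint, y_joint = joint_newton(0.99)
```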

Returns:

State with history containing before/after loss values.

Return type:

LossState

run_training_trajectory(policy_weighting, n_steps=10, pdb_id='', structure_path='', sf_path='', seed=None, policy_version=None)

Run a training trajectory with policy-guided refinement.

This method runs a sequence of refinement steps, using a policy to select component weights. It records state-action-reward tuples for training the policy with advantage-weighted regression (AWR) or similar algorithms.
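How recorded state-action-reward tuples are typically consumed by AWR: compute discounted returns per step, subtract a baseline to get advantages, and weight each action's regression loss by exp(advantage / β), clipped for stability. This is a generic sketch of the AWR weighting; Transition, the discount gamma, beta, and the clip value are hypothetical and do not describe the fields of TrajectoryData.

```python
import math
from dataclasses import dataclass
from typing import Sequence


@dataclass
class Transition:
    state: Sequence[float]   # hypothetical: features of the refinement state
    action: Sequence[float]  # hypothetical: component weights chosen by the policy
    reward: float            # hypothetical: e.g. improvement of a validation metric


def discounted_returns(rewards, gamma=0.99):
    # Accumulate from the last step backwards: G_t = r_t + gamma * G_{t+1}.
    returns = [0.0] * len(rewards)
    running = 0.0
    for i in reversed(range(len(rewards))):
        running = rewards[i] + gamma * running
        returns[i] = running
    return returns


def awr_weights(returns, baselines, beta=1.0, max_weight=20.0):
    # Clip the exponent rather than the result so large advantages cannot overflow.
    cap = math.log(max_weight)
    return [math.exp(min((r - b) / beta, cap)) for r, b in zip(returns, baselines)]


trajectory = [Transition([0.0], [1.0], 1.0), Transition([0.1], [0.5], 0.0), Transition([0.2], [0.2], 1.0)]
returns = discounted_returns([step.reward for step in trajectory], gamma=0.5)
```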

Parameters:
  • policy_weighting (PolicyComponentWeighting) – Policy weighting scheme (should be in training mode with sampling).

  • n_steps (int, optional) – Number of refinement steps in the trajectory (default: 10).

  • pdb_id (str, optional) – PDB identifier for recording.

  • structure_path (str, optional) – Path to structure file for recording.

  • sf_path (str, optional) – Path to structure factors file for recording.

  • seed (int, optional) – Random seed for reproducibility.

  • policy_version (str, optional) – Version identifier of the policy being used.

Returns:

Complete trajectory with state-action-reward tuples.

Return type:

TrajectoryData

run_training_trajectory_joint(policy_weighting, n_steps=10, pdb_id='', structure_path='', sf_path='', seed=None, policy_version=None)

Run a training trajectory with joint XYZ+ADP refinement.

Similar to run_training_trajectory() but refines xyz, adp, u, and occupancy together in each step. The LBFGS curvature history is reset at the start of each policy step because the weight updates invalidate any prior Hessian approximation.

Parameters:
  • policy_weighting (PolicyComponentWeighting) – Policy weighting scheme (should be in training mode).

  • n_steps (int, optional) – Number of refinement steps (default: 10).

  • pdb_id (str, optional) – PDB identifier for trajectory recording.

  • structure_path (str, optional) – Path to structure file for trajectory recording.

  • sf_path (str, optional) – Path to structure factors file for trajectory recording.

  • seed (int, optional) – Random seed for reproducibility.

  • policy_version (str, optional) – Policy version identifier.

Returns:

Complete trajectory with state-action-reward tuples.

Return type:

TrajectoryData

refine(macro_cycles=5)

Run the full LBFGS refinement cycle (ADP + XYZ).

Parameters:

macro_cycles (int, optional) – Number of refinement cycles to perform. Default is 5.

Returns:

History dictionary with all metrics per cycle (hierarchical structure).

Return type:

dict
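A hypothetical sketch of how a macro-cycle driver could compose the step methods documented above into a hierarchical history dict. The stub class, the per-cycle step order (scaler, xyz, adp), and the history keys are all assumptions for illustration; the real refine() may order or combine the steps differently (for example, by using refine_joint()) and records richer metrics.

```python
class StubRefinement:
    """Hypothetical stand-in exposing the step methods documented above."""

    def __init__(self):
        self.loss = 10.0

    def _fake_step(self):
        before = self.loss
        self.loss *= 0.5  # pretend each step halves the loss
        return {"before": before, "after": self.loss}

    def refine_scaler(self):
        return self._fake_step()

    def refine_xyz(self):
        return self._fake_step()

    def refine_adp(self):
        return self._fake_step()


def run_macro_cycles(refinement, macro_cycles=5):
    # Hypothetical hierarchical layout: history[cycle][step] -> metrics.
    history = {}
    for cycle in range(macro_cycles):
        history[cycle] = {
            "scaler": refinement.refine_scaler(),
            "xyz": refinement.refine_xyz(),
            "adp": refinement.refine_adp(),
        }
    return history


history = run_macro_cycles(StubRefinement(), macro_cycles=2)
```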

refine_everything(macro_cycles=5)

Run the full LBFGS refinement cycle (ADP + XYZ) without weight screening.

Parameters:

macro_cycles (int, optional) – Number of refinement cycles to perform. Default is 5.

Returns:

History dictionary with all metrics per cycle (hierarchical structure).

Return type:

dict