# Source code for torchref.cli.refine

#!/usr/bin/env python3 -u

"""
Command-line script for LBFGS crystallographic refinement using torchref.

Supports the Bhattacharyya overlap target by default; the Gaussian
(``gaussian``), least-squares (``ls``) and maximum-likelihood (``ml``) targets
remain available via ``--xray-mode``.

Examples
--------
::

    # Default: Bhattacharyya target, joint XYZ+ADP+scaler LBFGS
    torchref.refine -m model.pdb -sf reflections.mtz -o output_dir/

    # 10 refinement cycles
    torchref.refine -m model.pdb -sf reflections.mtz -o output/ -n 10

    # Separated XYZ then ADP cycles
    torchref.refine -m model.pdb -sf reflections.mtz -o output/ --mode refine

    # Legacy maximum-likelihood target
    torchref.refine -m model.pdb -sf reflections.mtz -o output/ --xray-mode ml
"""

import argparse
import json
import sys
from pathlib import Path

import torch

from torchref.cli._common import (
    add_dmin_arg,
    add_general_args,
    add_metadata_args,
    add_n_cycles_arg,
    add_outdir_arg,
    add_output_format_args,
    add_single_model_args,
    add_weights_arg,
    build_column_names,
    configure_unbuffered_output,
    parse_weights,
    register_timing,
    parse_device_str,
    validate_files,
    write_refinement_outputs,
)
from torchref.utils.serialization import convert_to_serializable

# Executed at import time, before any output is produced.
# NOTE(review): judging by its name (and the ``-u`` in the shebang) this
# switches stdout/stderr to unbuffered mode so progress appears promptly —
# confirm against torchref.cli._common.
configure_unbuffered_output()

# Import stats module early to patch json with StatEntry encoder
import torchref.utils.stats  # noqa: F401,E402


def _build_parser():
    """Build the ``torchref.refine`` argument parser.

    Groups are added in a fixed order (model, Output, Refinement, Resolution,
    general) because that order determines the ``--help`` layout.

    Returns
    -------
    argparse.ArgumentParser
        The fully configured parser (not yet parsed).
    """
    parser = argparse.ArgumentParser(
        description="Run LBFGS crystallographic refinement with torchref.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Default: Bhattacharyya target, joint XYZ+ADP+scaler LBFGS
  torchref.refine -m model.pdb -sf reflections.mtz -o output_dir/

  # 10 refinement cycles
  torchref.refine -m model.pdb -sf reflections.mtz -o output/ -n 10

  # Separated XYZ then ADP cycles
  torchref.refine -m model.pdb -sf reflections.mtz -o output/ --mode refine

  # Legacy maximum-likelihood target
  torchref.refine -m model.pdb -sf reflections.mtz -o output/ --xray-mode ml
""",
    )
    add_single_model_args(parser)

    output = parser.add_argument_group("Output")
    add_outdir_arg(output)
    add_output_format_args(output)
    add_metadata_args(output)

    refine_group = parser.add_argument_group("Refinement")
    add_n_cycles_arg(refine_group)
    refine_group.add_argument(
        "--mode",
        type=str,
        default="everything",
        choices=["everything", "refine"],
        help='Refinement mode: "everything" for joint XYZ+ADP+scaler LBFGS, '
        '"refine" for separated XYZ then ADP cycles (default: "everything")',
    )
    refine_group.add_argument(
        "--xray-mode",
        type=str,
        default="bhattacharyya",
        choices=["gaussian", "ls", "ml", "bhattacharyya"],
        help="X-ray target function. 'bhattacharyya' (default) uses the "
        "Bhattacharyya overlap loss with first-principles model error "
        "estimation and needs no manual weight tuning.",
    )
    refine_group.add_argument(
        "--sigma-m-scale",
        type=float,
        default=1.0,
        help="Global multiplier applied to σ_m for the Bhattacharyya target. "
        "Ignored for other targets. Default 1.0.",
    )
    add_weights_arg(refine_group)

    res = parser.add_argument_group("Resolution")
    add_dmin_arg(res)

    add_general_args(parser)
    return parser


def _print_header(args, model_path, sf_path, outdir, manual_weights):
    """Print the run-configuration banner to stdout and flush it."""
    print("=" * 80)
    print("TorchRef LBFGS Refinement")
    print("=" * 80)
    print(f"Model: {model_path}")
    print(f"Structure factor: {sf_path}")
    print(f"Output directory: {outdir}")
    print(f"Refinement mode: {args.mode}")
    print(f"X-ray target: {args.xray_mode}")
    if args.xray_mode == "bhattacharyya":
        # --sigma-m-scale only affects the Bhattacharyya target (see its help).
        print(f"sigma_m scale: {args.sigma_m_scale}")
    print(f"Refinement cycles: {args.n_cycles}")
    print(f"Device: {args.device}")
    if args.dmin:
        print(f"Resolution cutoff: {args.dmin:.2f} A")
    if manual_weights:
        print(f"Manual weights: {json.dumps(manual_weights)}")
    print("=" * 80)
    print()
    sys.stdout.flush()


def _optional_float(value):
    """Convert a scalar tensor or number to ``float``; pass ``None`` through.

    ``float()`` accepts both one-element tensors and plain Python numbers,
    unlike ``.item()``, which only exists on tensors.
    """
    return None if value is None else float(value)


def _r_factor(fobs, fmodel, mask):
    """Return ``sum|Fobs - Fmodel| / sum(Fobs)`` over ``mask``.

    Returns ``None`` when ``mask`` selects no reflections. The ratio would
    otherwise be 0/0 = NaN, which ``json.dump`` writes as non-standard JSON.
    """
    if not bool(mask.any()):
        return None
    fo = fobs[mask]
    return float(torch.sum(torch.abs(fo - fmodel[mask])) / torch.sum(fo))


def _compute_final_statistics(refinement):
    """Compute the final R-work/R-free and X-ray NLL values for ``refinement``.

    Returns
    -------
    dict
        Keys ``R_work``, ``R_free``, ``NLL_work``, ``NLL_test``,
        ``n_reflections_work`` and ``n_reflections_test``. An R value is
        ``None`` when its reflection set is empty.

    Raises
    ------
    Exception
        Whatever the refinement accessors raise is propagated unchanged;
        the caller treats these statistics as best-effort.
    """
    work_nll, test_nll = refinement.nll_xray()
    hkl, fobs, _sigma, rfree = refinement.reflection_data()
    fcalc = refinement.get_F_calc_scaled(hkl, recalc=True)

    # As in the original code, ``rfree`` is treated as True for the WORKING
    # set. Coerce to bool so that ``~`` is a logical (not bitwise) negation
    # even if integer flags are returned.
    # NOTE(review): confirm the flag convention of reflection_data().
    work_mask = torch.as_tensor(rfree, dtype=torch.bool, device=fobs.device)
    test_mask = ~work_mask

    # abs() is a no-op for non-negative real amplitudes. It yields |Fcalc| if
    # get_F_calc_scaled returns complex values — TODO confirm its return type.
    f_model = torch.abs(fcalc)

    return {
        "R_work": _r_factor(fobs, f_model, work_mask),
        "R_free": _r_factor(fobs, f_model, test_mask),
        "NLL_work": _optional_float(work_nll),
        "NLL_test": _optional_float(test_nll),
        "n_reflections_work": int(work_mask.sum().item()),
        "n_reflections_test": int(test_mask.sum().item()),
    }


def _format_stat(value, spec):
    """Format ``value`` with ``spec``, or return ``"n/a"`` for ``None``."""
    return "n/a" if value is None else format(value, spec)


def _print_summary(stats, outputs, output_mtz, output_json):
    """Print the final statistics table and the list of written output files."""
    print("\n" + "=" * 80)
    print("Refinement Summary")
    print("=" * 80)
    if stats:
        print(
            f"R-work: {_format_stat(stats['R_work'], '.4f')} "
            f"({stats['n_reflections_work']} reflections)"
        )
        print(
            f"R-free: {_format_stat(stats['R_free'], '.4f')} "
            f"({stats['n_reflections_test']} reflections)"
        )
        print(f"NLL work: {_format_stat(stats['NLL_work'], '.2f')}")
        print(f"NLL test: {_format_stat(stats['NLL_test'], '.2f')}")
    print("=" * 80)
    print("\nOutput files:")
    for kind in ("pdb", "cif"):
        if outputs.get(kind) is not None:
            print(f" - {outputs[kind]}")
    if output_mtz.exists():
        print(f" - {output_mtz}")
    print(f" - {output_json}")
    print("\nRefinement completed successfully!")
    sys.stdout.flush()


def main():
    """Run ``torchref.refine``: parse arguments, refine the model, write outputs.

    Steps: parse and validate the CLI arguments, build an
    ``LBFGSRefinement``, run either ``refine_everything`` (``--mode
    everything``) or ``refine`` (``--mode refine``), then write the model
    file(s), ``refined.mtz`` and ``refinement_history.json`` into the output
    directory.

    Returns
    -------
    int
        0 on success.

    Notes
    -----
    Invalid ``--weights`` and a failed torchref import call ``sys.exit(1)``.
    Missing input files are handled by ``validate_files(exit_on_error=True)``.
    An exception raised during refinement is passed to
    ``refinement.debug_on_error`` and then re-raised with its original
    traceback.
    """
    args = _build_parser().parse_args()
    verbose = args.verbose

    register_timing()

    manual_weights, weights_err = parse_weights(args.weights)
    if weights_err:
        print(f"Error: {weights_err}", file=sys.stderr)
        sys.exit(1)

    model_path = Path(args.model)
    sf_path = Path(args.structure_factor)
    outdir = Path(args.outdir)
    validate_files(
        [(str(model_path), "Model"), (str(sf_path), "Structure factors")],
        exit_on_error=True,
    )
    outdir.mkdir(parents=True, exist_ok=True)

    # Deferred import keeps ``--help`` fast: the refinement stack is heavy.
    try:
        from torchref.refinement.lbfgs_refinement import LBFGSRefinement
    except ImportError as e:
        print(f"Error: Failed to import torchref modules: {e}", file=sys.stderr)
        print("Please ensure torchref is properly installed.", file=sys.stderr)
        sys.exit(1)

    if verbose > 0:
        _print_header(args, model_path, sf_path, outdir, manual_weights)

    device = parse_device_str(args.device)

    if verbose > 0:
        print("Initializing refinement...")
        sys.stdout.flush()

    column_names = build_column_names(args.column_structure_factor, args.column_sigma)
    refinement = LBFGSRefinement(
        data_file=str(sf_path),
        pdb=str(model_path),
        cif=args.cif,
        verbose=verbose,
        max_res=args.dmin,
        device=device,
        column_names=column_names,
        target_mode=args.xray_mode,
        sigma_m_scale=args.sigma_m_scale,
        manual_weights=manual_weights or None,
    )

    if verbose > 0:
        print("Refinement initialized successfully.\n")
        sys.stdout.flush()

    # Run refinement; on failure give the refinement object a chance to dump
    # diagnostics, then re-raise with the original traceback intact.
    try:
        if verbose > 0:
            print(f"Starting refinement with {args.n_cycles} macro cycles...\n")
            sys.stdout.flush()

        if args.mode == "everything":
            refinement.refine_everything(macro_cycles=args.n_cycles)
        else:
            refinement.refine(macro_cycles=args.n_cycles)
        refinement.get_scales()

        if verbose > 0:
            print("\nRefinement completed successfully.")
            sys.stdout.flush()
    except Exception as e:
        refinement.debug_on_error(e)
        raise

    if verbose > 0:
        print(f"\nSaving results to {outdir}...")
        sys.stdout.flush()

    # Refined structure(s) with metadata.
    outputs = write_refinement_outputs(refinement, outdir, args, verbose=verbose)

    # Refined structure factors.
    output_mtz = outdir / "refined.mtz"
    refinement.write_out_mtz(str(output_mtz))
    if verbose > 0:
        print(f" Refined structure factors: {output_mtz}")
        sys.stdout.flush()

    # Refinement history as JSON.
    output_json = outdir / "refinement_history.json"
    history_data = {
        "input_files": {
            "model": str(model_path),
            "structure_factor": str(sf_path),
            "cif": args.cif,
        },
        "parameters": {
            "n_cycles": args.n_cycles,
            "mode": args.mode,
            "xray_mode": args.xray_mode,
            "sigma_m_scale": args.sigma_m_scale,
            "weights": manual_weights or None,
            "dmin": args.dmin,
            "device": str(device),
        },
        "history": getattr(refinement, "history", {}),
        "final_statistics": {},
    }

    # Final statistics are best-effort: model and MTZ files are already on
    # disk, so a failure here must not abort the run. It is, however,
    # reported instead of being silently dropped.
    try:
        history_data["final_statistics"] = _compute_final_statistics(refinement)
    except Exception as e:
        if verbose > 0:
            print(
                f" Warning: Could not compute final statistics: {e}",
                file=sys.stderr,
            )

    with open(output_json, "w", encoding="utf-8") as f:
        json.dump(convert_to_serializable(history_data), f, indent=2)

    if verbose > 0:
        print(f" Refinement history: {output_json}")
        sys.stdout.flush()
        _print_summary(
            history_data["final_statistics"], outputs, output_mtz, output_json
        )

    return 0
if __name__ == "__main__":
    # Use main()'s return value (0 on success) as the process exit status.
    sys.exit(main())