torchref.refinement.base_refinement module

Base class for crystallographic refinement.

class torchref.refinement.base_refinement.Refinement(data_file=None, pdb=None, cif=None, verbose=1, max_res=None, device=None, nbins=10, manual_weights=None, component_weights=None, column_names=None)

Bases: DeviceMixin, DebugMixin, Module

Refinement class to handle the overall crystallographic refinement process.

Supports two initialization patterns:

  1. Empty initialization (for state_dict loading):

    import torch
    from torchref.refinement.base_refinement import Refinement

    refinement = Refinement()  # Creates empty shell with submodules
    refinement.load_state_dict(torch.load('refinement.pt'))
    
  2. Full initialization with file paths:

    refinement = Refinement(data_file='data.mtz', pdb='model.pdb')
    
Parameters:
  • data_file (str, optional) – Path to MTZ or CIF file containing reflection data.

  • pdb (str, optional) – Path to PDB or CIF file containing initial model.

  • cif (str, optional) – Path to CIF file for restraints.

  • verbose (int, optional) – Verbosity level. Default is 1.

  • max_res (float, optional) – Maximum resolution for reflections.

  • device (torch.device, optional) – Computation device. Defaults to the configured device.current.

  • weighter (LossWeightingModule, optional) – Loss weighting module. Creates default if None.

  • nbins (int, optional) – Number of resolution bins. Default is 10.

device

Computation device.

Type:

torch.device

verbose

Verbosity level.

Type:

int

reflection_data

Reflection data container.

Type:

ReflectionData

model

Structure factor model (includes lazy restraints via model.restraints).

Type:

ModelFT

scaler

Scale factor calculator.

Type:

Scaler

weighter

Loss weighting module.

Type:

LossWeightingModule

__init__(data_file=None, pdb=None, cif=None, verbose=1, max_res=None, device=None, nbins=10, manual_weights=None, component_weights=None, column_names=None)

Initialize Refinement.

If data_file and pdb are provided, fully initializes the refinement. If not provided (empty init), creates a shell with empty submodules ready for load_state_dict().

Parameters:
  • data_file (str, optional) – Path to MTZ or CIF file containing reflection data.

  • pdb (str, optional) – Path to PDB or CIF file containing initial model.

  • cif (str, optional) – Path to CIF file for restraints.

  • verbose (int, optional) – Verbosity level. Default is 1.

  • max_res (float, optional) – Maximum resolution for reflections.

  • device (torch.device, optional) – Computation device. Defaults to the configured device.current.

  • weighter (LossWeightingModule, optional) – Loss weighting module. Creates default if None.

  • nbins (int, optional) – Number of resolution bins. Default is 10.

set_xray_target_mode(mode)

Change the X-ray target mode.

Parameters:

mode (str) – X-ray target mode. Options are ‘gaussian’, ‘ls’, or ‘ml’.

property data

Expose reflection_data as ‘data’ for weighting module compatibility.

Returns:

The reflection data container.

Return type:

ReflectionData

property loss_state: LossState

Get or create the persistent LossState.

The LossState is created once and reused across refinement cycles. Targets are registered once; weights are updated each cycle.

Returns:

The persistent loss state with targets registered.

Return type:

LossState

property logger: Logger

Get or create the Logger for this refinement.

Returns:

Logger instance linked to the persistent LossState.

Return type:

Logger

reset_loss_state()

Reset the persistent LossState and Logger.

Call this if targets need to be re-registered (e.g., after changing target modes or reinitializing targets).
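The loss_state and logger properties follow a get-or-create (lazy initialization) pattern: the object is built on first access, reused afterwards, and reset_loss_state() discards it so the next access rebuilds it. A minimal sketch of that pattern; StubLossState and PersistentStateHolder are hypothetical stand-ins, not torchref classes.

```python
class StubLossState:
    """Hypothetical stand-in for torchref's LossState."""

    def __init__(self):
        self.targets = {}


class PersistentStateHolder:
    """Hypothetical sketch of the get-or-create pattern (not torchref code)."""

    def __init__(self):
        self._loss_state = None
        self.created = 0  # how many states have been built so far

    @property
    def loss_state(self):
        # Build the state on first access only; later accesses reuse it.
        if self._loss_state is None:
            self._loss_state = StubLossState()
            self.created += 1
        return self._loss_state

    def reset_loss_state(self):
        # Drop the cached state; the next access re-creates it, which is
        # where targets would be registered again.
        self._loss_state = None


holder = PersistentStateHolder()
first = holder.loss_state
holder.reset_loss_state()
second = holder.loss_state
```

Caching on the instance keeps registered targets (and any logged history) alive across cycles, while the explicit reset gives a single place to force re-registration.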

get_scales()
setup_scaler()
parameters(recurse=True)

Return unique parameters from this module and all submodules.

Uses the default Module.parameters() to gather parameters, then removes duplicates while preserving order to avoid passing the same tensor multiple times to the optimizer.

Parameters:

recurse (bool, optional) – If True, yields parameters of this module and all submodules. Default is True.

Returns:

List of unique parameter tensors.

Return type:

list
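The deduplication step keeps the first occurrence of each tensor and compares tensors by object identity, not by value, so two distinct parameters that happen to hold equal numbers are both kept. A minimal sketch; unique_by_identity is a hypothetical helper, not a torchref function.

```python
import torch
from torch import nn


def unique_by_identity(params):
    """Return params with repeated tensor objects removed, order preserved.

    Hypothetical helper for illustration. Tensors are compared by id(),
    never by value.
    """
    seen = set()
    unique = []
    for p in params:
        if id(p) not in seen:
            seen.add(id(p))
            unique.append(p)
    return unique


# The same Parameter reached twice (e.g. shared between submodules).
a = nn.Parameter(torch.zeros(3))
b = nn.Parameter(torch.zeros(3))  # equal values, but a different object
params = unique_by_identity([a, b, a])
optimizer = torch.optim.SGD(params, lr=0.1)
```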

get_fcalc(hkl=None, recalc=False)
get_fcalc_scaled(hkl=None, recalc=False)
adp_loss()

Compute total ADP loss using TotalADPTarget.

This combines:

  • Bond-based similarity (SIMU-like)

  • Spread control (tighter than KL)

  • Bounds penalty

Returns:

Total ADP loss value.

Return type:

torch.Tensor

get_F_calc(hkl=None, recalc=False)
get_F_calc_scaled(hkl=None, recalc=False)
nll_xray()

Compute X-ray negative log-likelihood for work and test sets.

Returns:

Tuple of (work_nll, test_nll) tensors.

Return type:

tuple of torch.Tensor

xray_loss_work()

Compute X-ray loss on work set using instantiated target.

Returns:

X-ray loss on work set.

Return type:

torch.Tensor

xray_loss_test()

Compute X-ray loss on test set using instantiated target.

Returns:

X-ray loss on test set.

Return type:

torch.Tensor
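The work/test split can be illustrated with the simplest of the target modes, 'ls' (presumably least squares on amplitudes). The sketch below is not torchref's X-ray target: it assumes observed amplitudes, complex calculated structure factors and a boolean free-set mask, and fits a single overall scale k on the work reflections only, so the test set never influences the fit. ls_loss_work_test is hypothetical.

```python
import torch


def ls_loss_work_test(f_obs, f_calc, free_mask):
    """Least-squares amplitude loss, reported separately for work and test.

    Hypothetical sketch. Assumes:
      f_obs:     real tensor of observed amplitudes |Fobs|
      f_calc:    complex tensor of calculated structure factors
      free_mask: bool tensor, True for test (free) reflections
    """
    f_calc_amp = f_calc.abs()
    work = ~free_mask
    # Closed-form scale: argmin_k sum over work of (Fo - k*|Fc|)^2.
    k = (f_obs[work] * f_calc_amp[work]).sum() / (f_calc_amp[work] ** 2).sum()
    resid = f_obs - k * f_calc_amp
    work_loss = (resid[work] ** 2).mean()
    test_loss = (resid[free_mask] ** 2).mean()
    return work_loss, test_loss


# Toy data: |Fcalc| is half of |Fobs|; the phases do not affect amplitudes.
f_obs = torch.tensor([2.0, 4.0, 6.0, 8.0])
f_calc = torch.polar(f_obs / 2, torch.tensor([0.3, 1.2, -2.0, 0.7]))
free_mask = torch.tensor([False, False, False, True])
work_loss, test_loss = ls_loss_work_test(f_obs, f_calc, free_mask)
```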

bond_loss()

Compute bond length NLL via geometry_target.

Returns:

Bond length NLL loss.

Return type:

torch.Tensor

angle_loss()

Compute angle NLL via geometry_target.

Returns:

Angle NLL loss.

Return type:

torch.Tensor

torsion_loss()

Compute torsion angle NLL via geometry_target.

Returns:

Torsion angle NLL loss.

Return type:

torch.Tensor

geometry_loss()

Compute total geometry NLL using TotalGeometryTarget.

Returns:

Total geometry NLL loss.

Return type:

torch.Tensor
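The bond, angle and torsion terms are described as negative log-likelihoods. For bonds, a common form treats each length as Gaussian around its ideal value; the sketch below assumes that form, which the source does not confirm for geometry_target. bond_length_nll is hypothetical.

```python
import math

import torch


def bond_length_nll(xyz, pairs, ideal, sigma):
    """Gaussian NLL of bond lengths (hypothetical sketch, not torchref's target).

    xyz:   (N, 3) atom coordinates
    pairs: (M, 2) long tensor of bonded atom indices
    ideal: (M,) ideal bond lengths
    sigma: (M,) restraint standard deviations
    Each bond length d is modelled as d ~ Normal(ideal, sigma).
    """
    d = (xyz[pairs[:, 0]] - xyz[pairs[:, 1]]).norm(dim=-1)
    z = (d - ideal) / sigma
    # -log N(d | ideal, sigma) = 0.5 z^2 + log(sigma) + 0.5 log(2 pi)
    nll = 0.5 * z**2 + torch.log(sigma) + 0.5 * math.log(2 * math.pi)
    return nll.sum()


xyz = torch.tensor([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [1.5, 1.5, 0.0]],
                   requires_grad=True)
pairs = torch.tensor([[0, 1], [1, 2]])
ideal = torch.tensor([1.5, 1.5])
sigma = torch.tensor([0.02, 0.02])
loss = bond_length_nll(xyz, pairs, ideal, sigma)
loss.backward()  # gradients flow back to the coordinates
```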

loss()

Compute total loss using LossState pipeline.

Creates a LossState, populates meta, caches losses, updates weights, and returns the aggregated weighted loss.

Returns:

Total weighted loss.

Return type:

torch.Tensor
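The final aggregation step amounts to a weighted sum over named loss terms. A sketch under the assumption that losses and weights are plain dicts keyed by the hierarchical names LossState uses (e.g. 'geometry/bond', 'adp/simu'); aggregate_weighted is hypothetical and is not LossState.aggregate().

```python
import torch


def aggregate_weighted(losses, weights):
    """Weighted sum of named loss terms (hypothetical sketch).

    losses:  dict mapping hierarchical names to scalar tensors
    weights: dict mapping the same names to float weights; a name with
             no entry gets weight 1.0 (an assumption of this sketch)
    """
    total = torch.zeros(())
    for name, value in losses.items():
        total = total + weights.get(name, 1.0) * value
    return total


x = torch.tensor(1.0, requires_grad=True)
losses = {"geometry/bond": 2.0 * x, "adp/simu": 3.0 * x}
weights = {"geometry/bond": 0.5}
total = aggregate_weighted(losses, weights)  # 0.5 * 2 + 1.0 * 3
total.backward()
```

Because the sum stays a single autograd graph, one backward() call propagates every weighted term to the shared parameters.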

setup_component_weighting()

Set up component weighting using ResolutionWeighting and OverfittingWeighting.

populate_state_meta(state)

Populate LossState.meta with all model-level data.

Called once per macro cycle before weighting schemes are applied. This is the single location where refinement data is extracted into state.

Parameters:

state (LossState) – State to populate with meta data.

Returns:

State with meta populated.

Return type:

LossState

update_weights(state, multiply=False)

Compute weights from component_weighting and update state. Weights are clipped to [0.01, 100.0] to avoid extreme values.

Parameters:
  • state (LossState) – State with meta populated.

  • multiply (bool, optional) – If True, multiply existing weights by computed weights. If False, replace existing weights with computed weights.

Returns:

State with weights updated.

Return type:

LossState
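A minimal sketch of the replace-versus-multiply update and the [0.01, 100.0] clip. It assumes weights are plain floats keyed by target name and that the clip is applied to the final weight after any multiplication (the source does not say which comes first); a missing current weight is treated as 1.0 when multiplying. update_weight_table is hypothetical.

```python
def update_weight_table(current, computed, multiply=False, lo=0.01, hi=100.0):
    """Hypothetical sketch of update_weights' replace/multiply + clip logic."""
    updated = dict(current)
    for name, w in computed.items():
        new = current.get(name, 1.0) * w if multiply else w
        updated[name] = min(max(new, lo), hi)  # clip to [lo, hi]
    return updated


# Replace mode: an extreme computed weight is clipped to the upper bound.
replaced = update_weight_table({"geometry/bond": 2.0}, {"geometry/bond": 500.0})
# Multiply mode: the existing weight is scaled by the computed one.
multiplied = update_weight_table({"geometry/bond": 2.0},
                                 {"geometry/bond": 0.25}, multiply=True)
```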

create_loss_state()

Create a configured LossState for optimization.

Deprecated: use the loss_state property instead for the persistent state. This method is kept for backwards compatibility.

Sets up a LossState with all targets registered as callables with hierarchical naming (e.g., ‘geometry/bond’, ‘adp/simu’). Weights are applied from component_weighting.

Usage:

    import torch
    from torchref.utils import validate_loss

    state = refinement.create_loss_state()
    params = list(refinement.parameters())
    optimizer = torch.optim.LBFGS(params)

    # Log initial state
    state.aggregate(log_values=True)

    # In an LBFGS closure, wrap with validate_loss so non-finite
    # losses warn and reject the step instead of poisoning the run.
    def closure():
        optimizer.zero_grad()
        loss = state.aggregate()
        loss.backward()
        ok = validate_loss(
            loss, state=state, parameters=params,
            context="my_refinement", raise_on_fail=False,
        )
        if not ok:
            # Discard the bad gradients and report an infinite loss
            for p in params:
                if p.grad is not None:
                    p.grad.zero_()
            return torch.full_like(loss.detach(), float("inf"))
        return loss

    optimizer.step(closure)

    # Log final state
    state.new_entry()
    state.aggregate(log_values=True)

Returns:

Configured LossState with targets and weights.

Return type:

LossState

complete_loss_state()

Update and return the persistent LossState.

Updates the persistent LossState with current meta, target info, cached losses, and weights. The state is reused across cycles.

The cached set of active parameter leaves is not refreshed here. Stale leaves do not affect correctness: if a leaf in the set has had its Parameter object replaced externally (e.g. by Model.freeze), _freeze_graph_extras simply ignores it, at the cost of a small amount of wasted backward work. If you call Model.freeze or Model.unfreeze between LossState creation and a refinement step, call state.refresh_loss_leaves() explicitly.

Returns:

Complete LossState with targets, meta, losses, and weights.

Return type:

LossState

xray_loss()

Compute X-ray loss on work set.

Returns:

X-ray loss on work set.

Return type:

torch.Tensor

restraints_loss()

Compute total geometry restraints loss.

Returns:

Total geometry restraints loss.

Return type:

torch.Tensor

collect_metrics()

Collect all metrics from component_weighting.stats().

This is the standard method for gathering refinement metrics for logging. All statistics come from the centralized component_weighting module. The full, unfiltered statistics are returned; filtering is done at display time.

Returns:

Dictionary with all metrics (unfiltered, with StatEntry objects).

Return type:

dict

add_target_info_to_state(state)

Add target information from geometry and ADP targets to LossState.meta.

Deprecated: this method is no longer needed. Use complete_loss_state() instead, which handles all state setup in one call.

Parameters:

state (LossState) – Current loss state. Meta will be updated with target info.

Returns:

Updated loss state (unchanged).

Return type:

LossState

get_rfactor()
update_outliers(z_threshold=4.0)
plot_fcalc_vs_fobs(outpath='fcalc_vs_fobs.png')
write_out_mtz(out_mtz_path='refined_output.mtz')
collect_deposition_metadata(metadata=None)

Collect refinement statistics into a RefinementMetadata object.

Reuses existing statistics from collect_metrics(), get_rfactor(), and reflection data attributes.

Parameters:

metadata (RefinementMetadata, optional) – Existing metadata to merge with (e.g. from input file pass-through). Refinement statistics take precedence over pass-through values.

Returns:

Metadata populated with final refinement statistics.

Return type:

RefinementMetadata
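The precedence rule (refinement statistics win over pass-through values) can be sketched with plain dicts standing in for RefinementMetadata. merge_metadata and the field names are hypothetical, and the sketch additionally assumes that a None statistic means "not computed" and should leave the pass-through value in place.

```python
def merge_metadata(passthrough, refinement_stats):
    """Hypothetical sketch: overlay refinement statistics on pass-through data.

    Values computed by the refinement overwrite values carried over from
    the input file; None is treated as "not computed" (an assumption).
    """
    merged = dict(passthrough)
    for key, value in refinement_stats.items():
        if value is not None:
            merged[key] = value
    return merged


passthrough = {"wavelength": 0.98, "r_work": 0.25}
stats = {"r_work": 0.19, "r_free": 0.23, "data_completeness": None}
merged = merge_metadata(passthrough, stats)
```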

write_out_pdb(out_pdb_path='refined_output.pdb', metadata=None)

Write refined PDB with optional metadata header.

Parameters:
  • out_pdb_path (str) – Output PDB file path.

  • metadata (RefinementMetadata, optional) – Metadata for PDB header. If None, auto-collected from refinement.

write_out_cif(out_cif_path='refined_output.cif', metadata=None)

Write refined coordinates as mmCIF with metadata.

Parameters:
  • out_cif_path (str) – Output mmCIF file path.

  • metadata (RefinementMetadata, optional) – Metadata for mmCIF categories. If None, auto-collected from refinement.

save_state(path)

Save the complete state of the refinement to a file.

Parameters:

path (str) – Path to save the state dictionary to.

load_state(path, strict=True)

Load the complete state of the refinement from a file.

Parameters:
  • path (str) – Path to load the state dictionary from.

  • strict (bool, optional) – Whether to strictly enforce that keys match. Default is True.
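Assuming save_state() and load_state() wrap torch.save(self.state_dict(), path) and load_state_dict() (an assumption; the source only says they save and load the complete state), the strict flag behaves as in plain PyTorch. A toy module (Toy is hypothetical, standing in for Refinement) shows the difference, using an in-memory buffer in place of a file path:

```python
import io

import torch
from torch import nn


class Toy(nn.Module):
    """Hypothetical stand-in for a much richer module such as Refinement."""

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(1.0))


src = Toy()
with torch.no_grad():
    src.scale.fill_(2.5)

# Save and reload the state dict (a file path works the same way).
buf = io.BytesIO()
torch.save(src.state_dict(), buf)
buf.seek(0)
state = torch.load(buf)

# strict=True (the default) requires the keys to match exactly.
dst = Toy()
dst.load_state_dict(state)

# strict=False tolerates mismatched keys and reports them instead of raising.
state_extra = dict(state, extra_buffer=torch.zeros(1))
result = dst.load_state_dict(state_extra, strict=False)
```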

classmethod create_from_state_dict(state_dict, device=device(type='cpu'), verbose=1)

Create a fully initialized Refinement from a state dictionary.

This is the recommended way to restore a Refinement from a saved state. It creates the proper submodules using their respective create_from_state_dict methods, then calls PyTorch’s default load_state_dict.

Parameters:
  • state_dict (dict) – State dictionary from torch.save(refinement.state_dict(), …) or from loading a checkpoint file.

  • device (torch.device, optional) – Device to place tensors on. Defaults to the configured device.current.

  • verbose (int, optional) – Verbosity level. Default is 1.

Returns:

Fully initialized instance with restored state.

Return type:

Refinement

Examples

Save and load refinement state:

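import torch
from torchref.refinement.base_refinement import Refinement

# Save (requires an already initialized refinement)
torch.save(refinement.state_dict(), 'refinement.pt')

# Load
state = torch.load('refinement.pt')
refinement = Refinement.create_from_state_dict(state)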

# Continue refinement
rwork, rfree = refinement.get_rfactor()
print(f"Restored at R-work={rwork:.4f}, R-free={rfree:.4f}")