torchref.refinement.weighting.component_weighting module

class torchref.refinement.weighting.component_weighting.BaseWeighting(device=None, **kwargs)

Bases: DeviceMixin, Module, ABC

Abstract base class for weighting schemes using LossState.

Weighting schemes:

  • Are set up without state (only device and hyperparameters).

  • Receive state only when computing weights via forward().

  • Return a weights dict and do not modify state directly.

All tunable parameters should be registered as buffers using register_buffer() so they can be accessed/modified via state_dict notation.

Parameters:

device (torch.device, optional) – Computation device. Defaults to the configured device.current.

name

Unique name for this weighting scheme.

Type:

str

device

Computation device.

Type:

torch.device

name: str = 'base_weighting'
__init__(device=None, **kwargs)
abstractmethod forward(state)

Compute weights from the current LossState.

Access data via state["key"] or state.get("key", default). Metadata (rwork, rfree, etc.) is in state.meta. Cached losses are in state._losses.

Parameters:

state (LossState) – Current loss state with meta and _losses populated.

Returns:

Dictionary mapping component names to weights.

Return type:

Dict[str, float]
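The forward() contract can be illustrated with a plain-Python sketch. SimpleLossState and RfreeScaledWeighting are hypothetical stand-ins, not torchref classes, and the weight formula is invented purely for illustration. The sketch relies only on what this docstring states: dictionary-style access via state["key"] and state.get(), metadata in state.meta, and a returned weights dict with the state left untouched. A real scheme would subclass BaseWeighting (an nn.Module) and register its hyperparameters as buffers.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class SimpleLossState:
    """Hypothetical stand-in for LossState: key/value data plus a meta dict."""

    data: Dict[str, Any] = field(default_factory=dict)
    meta: Dict[str, float] = field(default_factory=dict)

    def __getitem__(self, key: str) -> Any:
        return self.data[key]

    def get(self, key: str, default: Any = None) -> Any:
        return self.data.get(key, default)


class RfreeScaledWeighting:
    """Hypothetical scheme: hyperparameters at setup, state only in forward()."""

    name = "rfree_scaled_weighting"

    def __init__(self, scale: float = 2.0) -> None:
        # Setup takes hyperparameters only; no state is stored here.
        self.scale = scale

    def forward(self, state: SimpleLossState) -> Dict[str, float]:
        # Read metadata without modifying the state.
        rfree = state.meta.get("rfree", 0.0)
        # Return a fresh dict of weights instead of writing into the state.
        return {"adp": 1.0 + self.scale * rfree}


state = SimpleLossState(meta={"rwork": 0.20, "rfree": 0.25})
weights = RfreeScaledWeighting().forward(state)
print(weights)  # {'adp': 1.5}
```

Keeping forward() free of side effects is what lets several schemes read the same state and have their outputs combined afterwards.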

stats(state=None)

Return statistics for reporting.

Parameters:

state (LossState, optional) – If provided, can pull data from LossState.

Returns:

Statistics dictionary with StatEntry objects.

Return type:

Dict[str, StatEntry]

torchref.refinement.weighting.component_weighting.WeightingScheme

alias of BaseWeighting

class torchref.refinement.weighting.component_weighting.ResolutionWeighting(device=None, base_w_geometry=1.0, base_w_adp=1.0, d_ref=2.0, alpha=0.0)

Bases: BaseWeighting

Base prior strength with optional resolution-dependent correction.

With proper NLL sums and perfectly calibrated sigmas, w=1 would be the pure Bayesian answer. However, an empirical sweep on 50 structures (1.15 to 3.0 Å) showed that the monomer library sigmas lead to overly loose geometry at w=1. The optimal base weights are:

  • w_geometry ≈ 10 (from the median Rfree minimum)

  • w_adp ≈ 1 (lower is better for most resolutions)

Geometry and ADP use different base weights because the monomer library geometry sigmas and the ADP restraint sigmas are miscalibrated by different amounts, so each needs its own compensation.

Optional resolution dependence:

w_geometry = base_w_geometry * (d_min / d_ref) ^ alpha
w_adp = base_w_adp * (d_min / d_ref) ^ alpha

alpha=0 disables resolution correction. The sweep found the resolution dependence was weak (<20% effect), so alpha=0 is the default.
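The resolution formula can be restated as a small function. This is a pure-Python sketch of the formula above, not the torchref implementation, and the component keys "geometry" and "adp" are assumed names.

```python
def resolution_weights(
    d_min: float,
    base_w_geometry: float = 1.0,
    base_w_adp: float = 1.0,
    d_ref: float = 2.0,
    alpha: float = 0.0,
) -> dict:
    """Power-law prior weights: base * (d_min / d_ref) ** alpha."""
    # With alpha == 0 the correction term is exactly 1.0, i.e. disabled.
    correction = (d_min / d_ref) ** alpha
    return {
        "geometry": base_w_geometry * correction,
        "adp": base_w_adp * correction,
    }


# alpha = 0 (the default): the base weights apply at any resolution.
print(resolution_weights(3.0, base_w_geometry=10.0))
# → {'geometry': 10.0, 'adp': 1.0}

# alpha = 1: a 4.0 Å dataset gets twice the base weight relative to d_ref = 2.0 Å.
print(resolution_weights(4.0, base_w_geometry=10.0, alpha=1.0))
# → {'geometry': 20.0, 'adp': 2.0}
```

A positive alpha strengthens the priors for lower-resolution data (larger d_min), where the X-ray term carries less information.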

Parameters:
  • device (torch.device, optional) – Computation device.

  • base_w_geometry (float, optional) – Base geometry weight. Default 1.0; the empirical sweep suggests a value around 10.

  • base_w_adp (float, optional) – Base ADP weight. Default 1.0 (from empirical sweep).

  • d_ref (float, optional) – Reference resolution for the optional power-law correction. Default 2.0 Å.

  • alpha (float, optional) – Resolution sensitivity. Default 0.0 (disabled).

name: str = 'resolution_weighting'
__init__(device=None, base_w_geometry=1.0, base_w_adp=1.0, d_ref=2.0, alpha=0.0)
forward(state)

Compute base weights with optional resolution correction.

stats(state=None)

Return statistics for reporting.

Parameters:

state (LossState, optional) – If provided, can pull data from LossState.

Returns:

Statistics dictionary with StatEntry objects.

Return type:

Dict[str, StatEntry]

class torchref.refinement.weighting.component_weighting.OverfittingWeighting(device=None, target_gap=0.05, min_weight=1.0, sharpness=30.0, geom_share=1.0, smoothing=0.8)

Bases: BaseWeighting

Dynamic overfitting correction based on Rfree - Rwork gap.

Uses R-factors (scale-invariant and normalized per reflection) rather than NLL values, which are not comparable between the work and test sets once the NLLs are summed: a sum grows with the number of reflections, and the two sets differ in size.

When the Rfree-Rwork gap exceeds target_gap, the scheme increases regularization exponentially. The full correction is applied to the ADP weights, while geom_share controls how much of it reaches the geometry weights (the default, 1.0, applies it fully to both). Lowering geom_share is motivated by crystallographic practice: overfitting is typically driven by B-factors, which have one parameter per atom and relatively weak restraints, rather than by coordinates, which the geometry prior holds tightly.

Effective correction:

factor = min_weight + exp(sharpness * (gap - target_gap))
w_adp *= factor
w_geometry *= 1 + geom_share * (factor - 1)

With geom_share = 0.2, only 20% of the overfitting correction is applied to geometry, keeping most of the effect on ADP.

Tunable parameters (registered as buffers):

  • target_gap: R-factor gap threshold. Default 0.05 (5%).

  • min_weight: base correction factor. Default 1.0.

  • sharpness: steepness of the exponential response. Default 30.0.

  • geom_share: fraction of the correction applied to geometry. Default 1.0.

  • smoothing: EMA smoothing factor (0 to 1). Default 0.8.
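The correction factor and the geom_share split can be sketched as plain functions that follow the formulas above. The ema() helper is an assumption: the docstring only says smoothing is an EMA factor between 0 and 1, so treating it as the weight kept on the previous value (and applying it to the gap) is a guess, not documented torchref behaviour. The component keys are also assumed.

```python
import math
from typing import Dict, Optional


def overfitting_factor(
    rwork: float,
    rfree: float,
    target_gap: float = 0.05,
    min_weight: float = 1.0,
    sharpness: float = 30.0,
) -> float:
    """factor = min_weight + exp(sharpness * (gap - target_gap))."""
    gap = rfree - rwork
    return min_weight + math.exp(sharpness * (gap - target_gap))


def apply_correction(
    weights: Dict[str, float], factor: float, geom_share: float = 1.0
) -> Dict[str, float]:
    """Full factor on ADP; geometry gets geom_share of (factor - 1)."""
    return {
        "adp": weights["adp"] * factor,
        "geometry": weights["geometry"] * (1.0 + geom_share * (factor - 1.0)),
    }


def ema(previous: Optional[float], current: float, smoothing: float = 0.8) -> float:
    """Hypothetical EMA: `smoothing` is the weight kept on the previous value."""
    if previous is None:
        return current
    return smoothing * previous + (1.0 - smoothing) * current


# Gap equal to target_gap: exp(0) == 1, so factor ≈ min_weight + 1 = 2.
factor = overfitting_factor(rwork=0.20, rfree=0.25)
weights = apply_correction({"geometry": 10.0, "adp": 1.0}, factor, geom_share=0.2)
print(round(factor, 6))  # 2.0
print({k: round(v, 6) for k, v in weights.items()})  # {'adp': 2.0, 'geometry': 12.0}
```

With geom_share = 0.2 the ADP weight doubles while the geometry weight rises by only 20%, matching the split described above.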

name: str = 'overfitting_weighting'
__init__(device=None, target_gap=0.05, min_weight=1.0, sharpness=30.0, geom_share=1.0, smoothing=0.8)
forward(state)

Compute overfitting correction weights from R-factor gap.

stats(state=None)

Return statistics for reporting.

Parameters:

state (LossState, optional) – If provided, can pull data from LossState.

Returns:

Statistics dictionary with StatEntry objects.

Return type:

Dict[str, StatEntry]

class torchref.refinement.weighting.component_weighting.ManualWeighting(weights, device=None)

Bases: BaseWeighting

Apply fixed manual weights.

This scheme needs no state data; it simply returns the preset weights.

name: str = 'manual_weighting'
__init__(weights, device=None)
forward(state)

Return manual weights (state is not used).

class torchref.refinement.weighting.component_weighting.ComponentWeighting(device=None, weights=None, component_weights=None, schemes=None, initial_xray_loss=None)

Bases: DeviceMixin, Module

Combines multiple weighting schemes using nn.ModuleDict.

Holds weighting schemes but does NOT hold a refinement reference. Weights are computed via forward(state), which receives a LossState.

Default schemes:

  • 'resolution': ResolutionWeighting, base prior strength with optional resolution-dependent correction

  • 'overfitting': OverfittingWeighting, counteracts overfitting via the Rfree-Rwork gap

Parameters:
  • device (torch.device, optional) – Computation device.

  • weights (dict, optional) – Manual weight overrides.

  • component_weights (dict, optional) – Manual weight overrides for specific components.

  • schemes (list of BaseWeighting, optional) – Additional custom weighting schemes.

schemes

Dictionary of weighting schemes.

Type:

nn.ModuleDict

__init__(device=None, weights=None, component_weights=None, schemes=None, initial_xray_loss=None)
__getitem__(key)

Get a scheme by name using dictionary-style access.

__contains__(key)

Check if a scheme exists.

keys()

Return scheme names.

values()

Return scheme instances.

items()

Return (name, scheme) pairs.

add_scheme(name, scheme)

Add a new weighting scheme.

forward(state)

Compute weights from all schemes.

Returns the combined weights: when several schemes return the same key, their weights are multiplied. Does NOT modify state; it only returns the computed weights.
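The container methods and the multiplicative combination rule can be sketched without torch. SchemeCollection is hypothetical: it uses a plain dict and callables where ComponentWeighting uses nn.ModuleDict and BaseWeighting modules, and the component keys are assumed. Only "multiplicative for shared keys" comes from the docstring; passing through keys that a single scheme returns (an implicit factor of 1.0) is an assumption.

```python
from typing import Callable, Dict, Iterator, List, Tuple

# A "scheme" here is any callable taking a state and returning {component: weight}.
Scheme = Callable[[object], Dict[str, float]]


class SchemeCollection:
    """Hypothetical stand-in for ComponentWeighting's container behaviour."""

    def __init__(self) -> None:
        self._schemes: Dict[str, Scheme] = {}

    def add_scheme(self, name: str, scheme: Scheme) -> None:
        self._schemes[name] = scheme

    def __getitem__(self, key: str) -> Scheme:
        return self._schemes[key]

    def __contains__(self, key: object) -> bool:
        return key in self._schemes

    def keys(self) -> List[str]:
        return list(self._schemes)

    def items(self) -> Iterator[Tuple[str, Scheme]]:
        return iter(self._schemes.items())

    def forward(self, state: object) -> Dict[str, float]:
        combined: Dict[str, float] = {}
        for _, scheme in self.items():
            for key, weight in scheme(state).items():
                # Shared keys multiply; a key seen for the first time starts from 1.0.
                combined[key] = combined.get(key, 1.0) * weight
        return combined


weighting = SchemeCollection()
weighting.add_scheme("resolution", lambda state: {"geometry": 10.0, "adp": 1.0})
weighting.add_scheme("overfitting", lambda state: {"geometry": 1.2, "adp": 2.0})
weighting.add_scheme("manual", lambda state: {"xray": 1.0})
print("overfitting" in weighting)  # True
print(weighting.forward(state=None))
# → {'geometry': 12.0, 'adp': 2.0, 'xray': 1.0}
```

Multiplying rather than adding lets each scheme act as an independent scale factor, so a correction such as the overfitting factor composes with the base prior strength.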

total_loss_from_state(state)

Compute total weighted loss from a LossState.
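The docstring does not spell out how the total is formed. A common reading, assumed here and not confirmed by the source, is a weighted sum in which components without an explicit weight count with weight 1.0. The function name and the loss values below are hypothetical, for illustration only.

```python
from typing import Dict


def total_weighted_loss(losses: Dict[str, float], weights: Dict[str, float]) -> float:
    """Hypothetical: sum of weight * loss, with weight 1.0 for unweighted terms."""
    return sum(weights.get(name, 1.0) * value for name, value in losses.items())


# Made-up loss values; weights as produced by a combined forward() call.
losses = {"xray": 1000.0, "geometry": 50.0, "adp": 20.0}
weights = {"geometry": 12.0, "adp": 2.0}
print(total_weighted_loss(losses, weights))  # 1640.0
```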

stats(state=None)

Return statistics for reporting.