torchref.refinement.weighting.base_weighting module

Base weighting class using LossState for all data access.

All weighting schemes inherit from BaseWeighting and receive their data through LossState rather than direct refinement references.

class torchref.refinement.weighting.base_weighting.BaseWeighting(device=None, **kwargs)

Bases: DeviceMixin, Module, ABC

Abstract base class for weighting schemes using LossState.

Weighting schemes:

- Are set up without state (only device and hyperparameters)
- Receive state only when computing weights via forward()
- Return a weights dict (do not modify state directly)

All tunable parameters should be registered as buffers using register_buffer() so they can be accessed/modified via state_dict notation.
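The buffer convention above can be sketched with plain PyTorch. This is a minimal illustration, not the torchref implementation: ScaledWeighting and its scale hyperparameter are hypothetical names, and the class derives from torch.nn.Module directly so the snippet runs without torchref. A real scheme would subclass BaseWeighting instead.

```python
import torch
from torch import nn


class ScaledWeighting(nn.Module):
    """Hypothetical stand-in; a real scheme would subclass BaseWeighting."""

    name = "scaled_weighting"

    def __init__(self, scale: float = 0.5):
        super().__init__()
        # A buffer is stored in state_dict but is not a trainable parameter.
        self.register_buffer("scale", torch.tensor(float(scale)))


weighting = ScaledWeighting(scale=0.5)

# Read the hyperparameter through state_dict notation.
initial_scale = weighting.state_dict()["scale"].item()

# Modify it by loading an updated state_dict.
weighting.load_state_dict({"scale": torch.tensor(0.25)})
updated_scale = weighting.scale.item()
```

Because the value lives in a buffer, it follows the module across devices with .to() and is saved and restored with the rest of the module state.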

Parameters:

device (torch.device, optional) – Computation device. Defaults to the configured device.current.

name

Unique name for this weighting scheme.

Type:

str

device

Computation device.

Type:

torch.device

name: str = 'base_weighting'
__init__(device=None, **kwargs)
abstractmethod forward(state)

Compute weights from the current LossState.

Access data via state["key"] or state.get("key", default). Metadata (rwork, rfree, etc.) is in state.meta. Cached losses are in state._losses.

Parameters:

state (LossState) – Current loss state with meta and _losses populated.

Returns:

Dictionary mapping component names to weights.

Return type:

Dict[str, float]
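The forward() contract described above can be sketched as follows. Everything here is an assumption made for illustration: FakeLossState is a hypothetical stand-in that mimics only the access patterns named in this section (state["key"], state.get, state.meta, state._losses), the "eps" key is invented, and the inverse-loss rule is not a scheme documented by torchref. Plain classes are used so the snippet runs without torchref or PyTorch.

```python
from typing import Any, Dict, Optional


class FakeLossState:
    """Hypothetical stand-in for LossState (not the torchref class)."""

    def __init__(self, data: Dict[str, Any], meta: Dict[str, float],
                 losses: Dict[str, float]) -> None:
        self._data = dict(data)
        self.meta = dict(meta)        # e.g. rwork, rfree
        self._losses = dict(losses)   # cached loss value per component

    def __getitem__(self, key: str) -> Any:
        return self._data[key]

    def get(self, key: str, default: Optional[Any] = None) -> Any:
        return self._data.get(key, default)


class InverseLossWeighting:
    """Illustrative scheme; a real one would subclass BaseWeighting."""

    name = "inverse_loss"

    def forward(self, state: FakeLossState) -> Dict[str, float]:
        # Optional value with a fallback, read via state.get(...).
        eps = state.get("eps", 1e-8)
        # Build a fresh dict; the state is only read, never written.
        return {
            component: 1.0 / max(loss, eps)
            for component, loss in state._losses.items()
        }


# meta is populated as forward() expects, although this rule does not use it.
state = FakeLossState(
    data={"eps": 1e-6},
    meta={"rwork": 0.21, "rfree": 0.25},
    losses={"xray": 2.0, "geometry": 0.5},
)
weights = InverseLossWeighting().forward(state)
```

Returning a new dictionary rather than writing into the state keeps the scheme free of side effects, which matches the rule that weighting schemes do not modify state directly.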

stats(state=None)

Return statistics for reporting.

Parameters:

state (LossState, optional) – If provided, can pull data from LossState.

Returns:

Statistics dictionary with StatEntry objects.

Return type:

Dict[str, StatEntry]