torchref.refinement.weighting.base_weighting module
Base weighting class using LossState for all data access.
All weighting schemes inherit from BaseWeighting and receive their data through LossState rather than through direct references to the refinement object.
- class torchref.refinement.weighting.base_weighting.BaseWeighting(device=None, **kwargs)
Bases: DeviceMixin, Module, ABC
Abstract base class for weighting schemes using LossState.
Weighting schemes:
  - are set up without state (only the device and hyperparameters);
  - receive state only when computing weights via forward();
  - return a weights dict and do not modify the state directly.
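The contract above can be illustrated with a minimal sketch, assuming PyTorch is installed. The classes SketchWeighting and FixedRatioWeighting, and the dict keys "data_loss" and "restraint_loss", are hypothetical and not part of torchref: a plain dict stands in for LossState, whose fields are not described in this section, and DeviceMixin is omitted. The sketch only shows the pattern: the constructor stores the device and hyperparameters, forward() reads the state it is handed, and a new weights dict is returned while the state is left untouched.

```python
from abc import ABC, abstractmethod
from typing import Dict

import torch
from torch import nn


class SketchWeighting(nn.Module, ABC):
    """Hypothetical stand-in for BaseWeighting (DeviceMixin omitted)."""

    def __init__(self, device=None):
        super().__init__()
        # Only the device is stored here; no LossState is held between calls.
        self.device = torch.device(device) if device is not None else torch.device("cpu")

    @abstractmethod
    def forward(self, state) -> Dict[str, torch.Tensor]:
        """Compute weights from the given state and return them as a dict."""


class FixedRatioWeighting(SketchWeighting):
    """Hypothetical scheme: scale the restraint term to a fixed fraction of the data term."""

    def __init__(self, ratio: float = 0.5, device=None):
        super().__init__(device=device)
        self.ratio = ratio  # plain hyperparameter, set at construction time

    def forward(self, state) -> Dict[str, torch.Tensor]:
        # `state` is a plain dict standing in for LossState (assumption).
        data = state["data_loss"].detach()
        restraint = state["restraint_loss"].detach()
        # Weight chosen so that weight * restraint == ratio * data.
        restraint_weight = self.ratio * data / restraint.clamp_min(1e-12)
        # Return a new dict; the incoming state is not modified.
        return {
            "data": torch.ones((), device=self.device),
            "restraint": restraint_weight.to(self.device),
        }


scheme = FixedRatioWeighting(ratio=0.5)
state = {"data_loss": torch.tensor(4.0), "restraint_loss": torch.tensor(2.0)}
weights = scheme(state)
print(weights["restraint"].item())  # 1.0
```

Because the scheme keeps no reference to the state, the same instance can be called on whatever state the refinement loop supplies at each step.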
All tunable parameters should be registered as buffers using register_buffer(), so that they can be read and modified through the module's state_dict().
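As a sketch of this buffer convention, assuming PyTorch and a hypothetical BufferedWeighting class (not part of torchref, with a plain dict again standing in for LossState), a hyperparameter registered with register_buffer() appears as a state_dict() key and can be overwritten with load_state_dict(), while staying out of parameters() so that an optimizer does not update it.

```python
from typing import Dict

import torch
from torch import nn


class BufferedWeighting(nn.Module):
    """Hypothetical scheme whose tunable hyperparameter is stored as a buffer."""

    def __init__(self, ratio: float = 0.5, device=None):
        super().__init__()
        # A buffer is saved in state_dict() but is not returned by parameters(),
        # so an optimizer will not update it.
        self.register_buffer("ratio", torch.tensor(float(ratio), device=device))

    def forward(self, state) -> Dict[str, torch.Tensor]:
        # `state` is a plain dict standing in for LossState (assumption).
        data = state["data_loss"].detach()
        restraint = state["restraint_loss"].detach()
        return {"restraint": self.ratio * data / restraint.clamp_min(1e-12)}


scheme = BufferedWeighting(ratio=0.5)
print(list(scheme.state_dict()))  # ['ratio']

# Retune the hyperparameter through state_dict notation.
scheme.load_state_dict({"ratio": torch.tensor(0.25)})
state = {"data_loss": torch.tensor(4.0), "restraint_loss": torch.tensor(2.0)}
print(scheme(state)["restraint"].item())  # 0.5
```

Storing the value as a 0-dimensional tensor buffer also means it moves with the module when .to(device) is called.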
- Parameters:
device (torch.device, optional) – Computation device. Defaults to the configured device.current.
- device
Computation device.
- Type: