torchref.utils.device_resolution module

Centralized device resolution for multi-module constructors.

Many TorchRef constructors accept several device-bearing inputs (model + data, or data + data_reference + model, etc.). Each used to have its own ad-hoc rule for picking a device, which led to silent bugs whenever inputs disagreed. For example, passing device='cpu' to Scaler while data lived on cuda still left the s and bins buffers on cuda.

resolve_device collapses N device sources into one with a fixed, documented precedence.

torchref.utils.device_resolution.resolve_device(*modules, device=None)

Resolve a single device from N device-bearing modules.

Each module must expose .device and accept .to(device). This is satisfied by TorchRef's torch.nn.Module subclasses that expose .device and by non-Module torchref.utils.DeviceMixin subclasses such as Cell; a plain torch.nn.Module does not define .device on its own. None entries are skipped silently so that empty-init paths can pass optional submodules through: resolve_device(model, data) works whether or not data is None.
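The contract above is structural (duck-typed) rather than tied to a base class. As a minimal sketch, it could be written down as a typing.Protocol. DeviceBearing and ToyCell are hypothetical names for illustration, not TorchRef classes, and devices are plain strings here to keep the example free of a torch dependency.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class DeviceBearing(Protocol):
    """Hypothetical structural type: anything with .device and .to(device)."""

    @property
    def device(self) -> str: ...

    def to(self, device: str) -> "DeviceBearing": ...


class ToyCell:
    """Hypothetical non-Module object, standing in for something like Cell."""

    def __init__(self, device: str = "cpu") -> None:
        self._device = device

    @property
    def device(self) -> str:
        return self._device

    def to(self, device: str) -> "ToyCell":
        # Record the new device and return self, mirroring nn.Module.to()
        self._device = device
        return self


cell = ToyCell()
print(isinstance(cell, DeviceBearing))  # True: has .device and .to
print(isinstance(None, DeviceBearing))  # False: None entries must be skipped
```

Because the check is structural, any object with both members qualifies, whether or not it inherits from torch.nn.Module.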

Resolution order

  1. If device is given, every non-None module is moved to device, and device is returned as a torch.device. No warning is emitted, because the caller has made an explicit choice.

  2. Otherwise, after dropping None entries, if no modules remain, torchref.config.get_default_device() is returned.

  3. The first remaining module’s device is the target. Any other module on a different device is moved to the target, and a single UserWarning is emitted for the call, however many modules were moved.

The “first module wins” rule is intentional: callers express precedence by argument order.
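The three rules can be sketched in plain Python. This is a hypothetical reimplementation for illustration, not TorchRef's source: resolve_device_sketch, ToyModule, ToyScaler, and the get_default_device stand-in are assumed names, and devices are strings rather than torch.device objects.

```python
import warnings
from typing import Optional


def get_default_device() -> str:
    # Hypothetical stand-in for torchref.config.get_default_device()
    return "cpu"


class ToyModule:
    """Hypothetical device-bearing object exposing .device and .to()."""

    def __init__(self, device: str) -> None:
        self.device = device

    def to(self, device: str) -> "ToyModule":
        self.device = device
        return self


def resolve_device_sketch(*modules, device: Optional[str] = None) -> str:
    """Sketch of the documented precedence; not TorchRef's actual code."""
    present = [m for m in modules if m is not None]  # None entries are skipped

    # Rule 1: an explicit device wins; move everything, emit no warning.
    if device is not None:
        for m in present:
            m.to(device)
        return device

    # Rule 2: nothing left to inspect, so fall back to the configured default.
    if not present:
        return get_default_device()

    # Rule 3: first module wins; move mismatches and warn once per call.
    target = present[0].device
    moved = False
    for m in present[1:]:
        if m.device != target:
            m.to(target)
            moved = True
    if moved:
        warnings.warn(
            f"device-bearing inputs disagreed; moved them to {target}",
            UserWarning,
            stacklevel=2,
        )
    return target


class ToyScaler:
    """Hypothetical constructor: argument order (model, data) sets precedence."""

    def __init__(self, model, data=None, device=None):
        self.device = resolve_device_sketch(model, data, device=device)
        # Buffers such as s / bins would be allocated on self.device here.


model, data = ToyModule("cuda"), ToyModule("cpu")
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    print(resolve_device_sketch(model, data))  # cuda
print(data.device, len(caught))  # cuda 1
```

Passing model before data in ToyScaler is what gives the model's device precedence; reversing the arguments would make data win instead.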

Parameters:

  *modules: Device-bearing modules. None entries are skipped.

  device (torch.device or str, optional): Explicit override. If provided, all non-None modules are moved to it, and it is returned as a torch.device.

Returns:

  The resolved device.

Return type:

  torch.device

Examples

Empty call returns the configured default (CPU in this example):

>>> resolve_device()
device(type='cpu')

Explicit override moves everything:

>>> resolve_device(model, data, device='cpu')
device(type='cpu')

Auto-reconcile with first-wins precedence; cpu_data is moved to cuda and a UserWarning is emitted:

>>> resolve_device(cuda_model, cpu_data)
device(type='cuda')