torchref.utils.device_resolution module
Centralized device resolution for multi-module constructors.
Many TorchRef constructors accept several device-bearing inputs
(model + data, or data + data_reference + model,
etc.). Each constructor used to apply its own ad-hoc rule for picking a device,
which led to silent bugs whenever the inputs disagreed. For example, passing
device='cpu' to Scaler while data lived on cuda would
still leave the s / bins buffers on cuda.
resolve_device collapses N device sources into one with a fixed,
documented precedence.
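To make the failure mode concrete, here is a minimal sketch of the "resolve once, then allocate" pattern a constructor can follow. ToyScaler, FakeData and DEFAULT_DEVICE are hypothetical names, devices are plain strings so the sketch runs without torch, and this is not TorchRef's actual Scaler.

```python
from typing import Optional

DEFAULT_DEVICE = "cpu"  # hypothetical stand-in for the configured default device


class FakeData:
    """Hypothetical device-bearing input; devices are plain strings, not torch.device."""

    def __init__(self, device: str) -> None:
        self.device = device

    def to(self, device: str) -> "FakeData":
        self.device = device  # move in place and return self, like nn.Module.to
        return self


class ToyScaler:
    """Hypothetical constructor that decides its device exactly once."""

    def __init__(self, data: Optional[FakeData], device: Optional[str] = None) -> None:
        # Single resolution point: explicit override, else the data's device, else the default.
        if device is not None:
            target = device
        elif data is not None:
            target = data.device
        else:
            target = DEFAULT_DEVICE
        if data is not None:
            data.to(target)
        self.data = data
        # Stand-in for allocating the s / bins buffers: they follow the resolved device.
        self.buffer_device = target
```

With ToyScaler(FakeData("cuda"), device="cpu"), both the data and the buffer stand-in end up on "cpu", which is the case the ad-hoc rules got wrong.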
- torchref.utils.device_resolution.resolve_device(*modules, device=None)
Resolve a single device from N device-bearing modules.
Each module must expose .device and accept .to(device) (satisfied by torch.nn.Module and by torchref.utils.DeviceMixin non-Module subclasses such as Cell). None entries are skipped silently so empty-init paths can pass through optional submodules: resolve_device(model, data) works whether or not data is None.
Resolution order
1. If device is given, every non-None module is moved to it and it is returned. No warning is emitted (the caller has made an explicit choice).
2. Otherwise, after dropping None entries, if no modules remain, torchref.config.get_default_device() is returned.
3. The first remaining module’s device is the target. Any other module on a different device is moved to the target, and a UserWarning is emitted once for the call.
The “first module wins” rule is intentional: callers express precedence by argument order.
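The resolution order can be sketched in plain Python as follows. This is an illustrative reimplementation, not the TorchRef source: resolve_device_sketch, Stub and DEFAULT_DEVICE are hypothetical, devices are compared as plain strings rather than torch.device objects, and the warning text is made up.

```python
import warnings
from typing import Any, Optional

DEFAULT_DEVICE = "cpu"  # hypothetical stand-in for torchref.config.get_default_device()


class Stub:
    """Hypothetical device-bearing object for trying the sketch without torch."""

    def __init__(self, device: str) -> None:
        self.device = device

    def to(self, device: str) -> "Stub":
        self.device = device  # in-place move, returning self like nn.Module.to
        return self


def resolve_device_sketch(*modules: Any, device: Optional[str] = None) -> str:
    present = [m for m in modules if m is not None]  # None entries are skipped silently

    # Rule 1: an explicit override moves every module and emits no warning.
    if device is not None:
        for m in present:
            m.to(device)
        return device

    # Rule 2: nothing left to inspect, so fall back to the default.
    if not present:
        return DEFAULT_DEVICE

    # Rule 3: the first remaining module wins; reconcile the others.
    target = present[0].device
    moved = 0
    for m in present[1:]:
        if m.device != target:
            m.to(target)
            moved += 1
    if moved:
        # A single warning per call, however many modules were moved.
        warnings.warn(
            f"device mismatch: moved {moved} module(s) to {target}",
            UserWarning,
            stacklevel=2,
        )
    return target
```

Because the first non-None argument sets the target, resolve_device_sketch(Stub("cpu"), Stub("cuda")) returns "cpu", while swapping the two arguments returns "cuda".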
- param *modules:
Device-bearing modules. None entries are skipped.
- param device:
Explicit override. If provided, all non-None modules are moved to it and it is returned.
torch.device or str, optional
- returns:
The resolved device.
- rtype:
torch.device
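The duck-typed contract on *modules can be written down as a typing.Protocol. A minimal sketch, assuming only the .device / .to(device) requirement stated above: DeviceBearing, Cellish and present_modules are hypothetical names, and the TypeError check is an illustrative addition that resolve_device itself is not documented to perform.

```python
from typing import Any, List, Optional, Protocol, runtime_checkable


@runtime_checkable
class DeviceBearing(Protocol):
    """Hypothetical structural type: anything with .device and .to(device)."""

    @property
    def device(self) -> Any: ...

    def to(self, device: Any) -> Any: ...


class Cellish:
    """Hypothetical non-Module object satisfying the contract (compare DeviceMixin)."""

    def __init__(self, device: str = "cpu") -> None:
        self._device = device

    @property
    def device(self) -> str:
        return self._device

    def to(self, device: str) -> "Cellish":
        self._device = device
        return self


def present_modules(*modules: Optional[DeviceBearing]) -> List[DeviceBearing]:
    """Drop None entries and reject objects that do not fit the contract."""
    out: List[DeviceBearing] = []
    for m in modules:
        if m is None:
            continue  # optional submodules may legitimately be absent
        if not isinstance(m, DeviceBearing):
            raise TypeError(f"{type(m).__name__} lacks .device or .to()")
        out.append(m)
    return out
```

For example, present_modules(Cellish(), None) returns a one-element list, while present_modules(object()) raises TypeError. Note that runtime_checkable only checks that the attributes exist, not their signatures.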
Examples
Empty call returns the configured default:
>>> resolve_device()
device(type='cpu')
Explicit override moves everything:
>>> resolve_device(model, data, device='cpu')
device(type='cpu')
Auto-reconcile with first-wins precedence:
>>> resolve_device(cuda_model, cpu_data)  # cpu_data has been moved to cuda
device(type='cuda')