# Source code for torchref.utils.device_resolution

"""Centralized device resolution for multi-module constructors.

Many TorchRef constructors accept several device-bearing inputs
(``model`` + ``data``, or ``data`` + ``data_reference`` + ``model``,
etc.).  Each used to have its own ad-hoc rule for picking a device,
which led to silent bugs whenever inputs disagreed (e.g.  passing
``device='cpu'`` to ``Scaler`` while ``data`` lived on cuda would
still leave the ``s`` / ``bins`` buffers on cuda).

``resolve_device`` collapses N device sources into one with a fixed,
documented precedence.
"""

from __future__ import annotations

import warnings
from typing import Any, Optional, Union

import torch

from torchref.config import get_default_device


def _canonical(device: torch.device) -> torch.device:
    """Return ``device`` with its default index filled in.

    ``torch.device('cuda') != torch.device('cuda:0')`` even though both
    point to the same physical device.  Materialising an empty tensor
    on the device is the cheapest way to get the canonical form.
    """
    return torch.empty(0, device=device).device


def resolve_device(
    *modules: Any,
    device: Optional[Union[torch.device, str]] = None,
) -> torch.device:
    """Resolve a single device from N device-bearing modules.

    Each ``module`` must expose ``.device`` and accept ``.to(device)``.
    ``.to`` must move the object *in place*, because its return value is
    ignored here.  This requirement is met by ``torch.nn.Module``
    subclasses that define a ``.device`` attribute (the base
    ``torch.nn.Module`` has none).  It is also met by
    ``torchref.utils.DeviceMixin`` non-Module subclasses such as ``Cell``.
    Plain tensors are *not* supported: ``Tensor.to`` returns a copy and
    leaves the original tensor where it was.

    ``None`` entries are skipped silently so empty-init paths can pass
    through optional submodules -- ``resolve_device(model, data)`` works
    whether or not ``data`` is ``None``.

    Resolution order
    ----------------
    1. If ``device`` is given, every non-``None`` module is moved to it and
       it is returned as a :class:`torch.device`.  It is returned as given,
       not canonicalised, so ``'cuda'`` comes back as
       ``device(type='cuda')``.  No warning is emitted, since the caller
       has made an explicit choice.
    2. Otherwise, after dropping ``None`` entries, if no modules remain,
       :func:`torchref.config.get_default_device` is returned.
    3. Otherwise the target is the first remaining module's device, in
       canonical form (index filled in, e.g. ``cuda`` ->
       ``cuda:<current index>``).  Any other module on a different device
       is moved to the target, and a :class:`UserWarning` is emitted once
       for the call.

    The "first module wins" rule is intentional: callers express
    precedence by argument order.

    Parameters
    ----------
    *modules
        Device-bearing modules.  ``None`` entries are skipped.
    device : torch.device or str, optional
        Explicit override.  If provided, all non-``None`` modules are moved
        to it and it is returned.

    Returns
    -------
    torch.device
        The resolved device.

    Examples
    --------
    Empty call returns the configured default::

        >>> resolve_device()  # doctest: +SKIP
        device(type='cpu')

    Explicit override moves everything::

        >>> resolve_device(model, data, device='cpu')  # doctest: +SKIP
        device(type='cpu')

    Auto-reconcile with first-wins precedence::

        >>> resolve_device(cuda_model, cpu_data)  # doctest: +SKIP
        device(type='cuda')  # cpu_data has been moved to cuda
    """
    if device is not None:
        # torch.device() accepts both strings and torch.device instances.
        resolved = torch.device(device)
        for m in modules:
            if m is not None:
                m.to(resolved)
        return resolved

    present = [m for m in modules if m is not None]
    if not present:
        return get_default_device()

    # Canonicalise before comparing, so an un-indexed device ('cuda') and
    # its indexed form are not reported as a mismatch.
    target = _canonical(present[0].device)
    inconsistent = [m for m in present[1:] if _canonical(m.device) != target]
    if inconsistent:
        device_list = [m.device for m in present]
        # stacklevel=2 attributes the warning to resolve_device's caller.
        warnings.warn(
            f"resolve_device: inputs on differing devices {device_list}; "
            f"moving {len(inconsistent)} module(s) to {target} "
            "(the first input's device).",
            stacklevel=2,
        )
        for m in inconsistent:
            m.to(target)
    return target