# Source code for torchref.utils.device_mixin

"""
DeviceMixin - unified device and dtype movement for TorchRef.

Provides a single mixin that overrides ``.to()``, ``.cuda()`` and ``.cpu()``
for both ``nn.Module`` subclasses and plain Python classes. On ``nn.Module``
subclasses ``.float()``, ``.double()`` and ``.half()`` are covered as well,
because ``nn.Module`` funnels them through ``_apply``. The mixin recursively
traverses the object graph, moving:

* parameters, buffers and child ``nn.Module`` instances (via the standard
  ``nn.Module._apply`` machinery when applicable),
* raw ``torch.Tensor`` attributes stored directly on ``self``,
* non-Module sub-objects that expose ``_apply``,
* tensors nested inside ``list`` / ``tuple`` / ``dict`` attributes,
* unregistered ``nn.Module`` instances held as plain attributes.

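A thread-local visited set keyed by ``id()`` makes traversal cycle-safe
(e.g. ``Target.refinement -> Refinement.targets -> Target``). After moving,
each ``DeviceMixin`` instance updates its ``device`` / ``_device`` /
``dtype_float`` / ``_dtype`` tracker attributes (when present) and calls
``reset_forward_cache()`` and ``reset_cache()`` if either is defined.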

Usage::

    class MyModule(DeviceMixin, nn.Module):
        ...  # nothing else required for device movement

    @dataclass
    class MyDataclass(DeviceMixin):
        _data: torch.Tensor

The legacy name ``DeviceMovementMixin`` is kept as an alias for
``DeviceMixin``.
"""

from __future__ import annotations

import copy
import inspect
import threading

import torch
from torch import nn

# ---------------------------------------------------------------------------
# Thread-local traversal state (cycle detection across one top-level .to())
# ---------------------------------------------------------------------------

_traversal_state = threading.local()

# Attribute names on nn.Module that hold params / buffers / children and are
# already moved by ``nn.Module._apply``. Skipping them in the __dict__ walk
# avoids double-work and pointless recursion into the bookkeeping dicts.
_NN_MODULE_INTERNALS = frozenset(
    {
        "_parameters",
        "_buffers",
        "_modules",
        "_non_persistent_buffers_set",
        "_backward_hooks",
        "_backward_pre_hooks",
        "_forward_hooks",
        "_forward_hooks_with_kwargs",
        "_forward_pre_hooks",
        "_forward_pre_hooks_with_kwargs",
        "_state_dict_hooks",
        "_state_dict_pre_hooks",
        "_load_state_dict_pre_hooks",
        "_load_state_dict_post_hooks",
        "_is_full_backward_hook",
        "training",
    }
)


def _current_visited():
    """Fetch the visited-id set of the in-progress traversal (``None`` if idle)."""
    try:
        return _traversal_state.visited
    except AttributeError:
        return None


def _enter_traversal():
    """Start a traversal on this thread unless one is already running.

    Returns the token ``"owner"`` when this call created the visited set; the
    caller must hand it to :func:`_exit_traversal` when done. Returns ``None``
    when a traversal is already active, so nested callers reuse the existing
    set and leave teardown to the outermost owner.
    """
    if getattr(_traversal_state, "visited", None) is not None:
        return None
    _traversal_state.visited = set()
    return "owner"


def _exit_traversal(token):
    """Clear the thread's visited set, but only for the owning token."""
    if token != "owner":
        # Nested participant: the outermost caller tears the state down.
        return
    _traversal_state.visited = None


def _parse_to_args(args, kwargs):
    """Parse loose ``.to()`` arguments into ``(device, dtype)``.

    Supports the common forms used in TorchRef and in ``torch.Tensor.to``:
    positional ``torch.dtype`` / device-like / ``None``, and the explicit
    keyword arguments ``device=`` and ``dtype=``.
    """
    device = kwargs.get("device", None)
    dtype = kwargs.get("dtype", None)
    for a in args:
        if isinstance(a, torch.dtype):
            dtype = a
        elif isinstance(a, torch.device):
            device = a
        elif isinstance(a, str):
            device = a
        elif isinstance(a, int):
            device = a
    return device, dtype


def _apply_to_obj(val, fn, visited):
    """Apply ``fn`` to any tensors inside ``val``, recursing as needed.

    Returns the (possibly new) value to store back. ``nn.Module`` children
    that have already been visited (registered submodules of the current
    parent) are skipped; unregistered ``nn.Module`` attributes are traversed.
    """
    if isinstance(val, torch.Tensor):
        # Do NOT short-circuit on visited for tensors: the same source
        # tensor may be aliased from multiple attribute slots (e.g. a
        # ``Cell._data`` referenced both as ``cell._data`` and as
        # ``submodule.cell_params``). Each slot needs an independent
        # ``fn(val)`` invocation so its attribute is updated to the
        # moved tensor. Re-applying ``fn`` to an already-converted
        # tensor is a cheap no-op when the target matches.
        return fn(val)

    if isinstance(val, nn.Module):
        if id(val) in visited:
            return val
        # Do NOT add ``id(val)`` to ``visited`` here — ``val._apply`` does
        # that itself. Adding it first would make the inner call short-circuit
        # before it actually moves the module's tensors.
        val._apply(fn)
        return val

    # Non-Module sub-objects that implement the _apply contract.
    apply_method = getattr(val, "_apply", None)
    if callable(apply_method) and not isinstance(val, type):
        if id(val) in visited:
            return val
        try:
            apply_method(fn)
            return val
        except TypeError:
            # Object has an _apply with a different signature; skip silently.
            pass

    if isinstance(val, list):
        new_list = [_apply_to_obj(v, fn, visited) for v in val]
        if all(a is b for a, b in zip(new_list, val)):
            return val
        return new_list

    if isinstance(val, tuple):
        new_tuple = tuple(_apply_to_obj(v, fn, visited) for v in val)
        if all(a is b for a, b in zip(new_tuple, val)):
            return val
        return new_tuple

    if isinstance(val, dict):
        replaced = False
        new_dict = {}
        for k, v in val.items():
            nv = _apply_to_obj(v, fn, visited)
            if nv is not v:
                replaced = True
            new_dict[k] = nv
        return new_dict if replaced else val

    return val


def _invalidate_caches(obj):
    """Call ``reset_forward_cache`` / ``reset_cache`` if present."""
    reset_fwd = getattr(obj, "reset_forward_cache", None)
    if callable(reset_fwd):
        try:
            reset_fwd()
        except Exception:
            pass
    reset_full = getattr(obj, "reset_cache", None)
    if callable(reset_full):
        try:
            reset_full()
        except Exception:
            pass


def _representative_tensor(obj):
    """Return one tensor that reflects the current device/dtype of *obj*.

    Prefers buffers (their dtype tracks ``dtype_float`` for crystallographic
    code) and falls back to parameters, then to a plain ``torch.Tensor``
    attribute found in ``__dict__``.
    """
    if isinstance(obj, nn.Module):
        for buf in obj.buffers():
            return buf
        for param in obj.parameters():
            return param
    for val in obj.__dict__.values():
        if isinstance(val, torch.Tensor):
            return val
        if hasattr(val, "_data") and isinstance(getattr(val, "_data"), torch.Tensor):
            return val._data
    return None


def _refresh_device_trackers(obj):
    """Refresh ``device`` / ``_device`` / ``dtype_float`` / ``_dtype`` trackers.

    Only updates attributes whose current value is already a ``torch.device``
    / ``torch.dtype`` (or ``None`` / ``str`` / ``int``) — never overwrites
    unrelated attributes that happen to share the same name. The tracker
    controls where new tensors are subsequently allocated, so it must follow
    ``.to()`` moves.
    """
    rep = _representative_tensor(obj)
    if rep is None:
        return

    for attr_name in ("device", "_device"):
        if attr_name not in obj.__dict__:
            continue
        current = obj.__dict__[attr_name]
        if current is None or isinstance(current, (torch.device, str, int)):
            obj.__dict__[attr_name] = rep.device

    if rep.is_floating_point() or rep.is_complex():
        for attr_name in ("dtype_float", "_dtype"):
            if attr_name not in obj.__dict__:
                continue
            current = obj.__dict__[attr_name]
            if current is None or isinstance(current, torch.dtype):
                obj.__dict__[attr_name] = rep.dtype


def _safe_setattr(obj, name, value):
    """Update an attribute without triggering nn.Module.__setattr__ side effects."""
    try:
        obj.__dict__[name] = value
    except (AttributeError, TypeError):
        object.__setattr__(obj, name, value)


# ---------------------------------------------------------------------------
# Unified mixin
# ---------------------------------------------------------------------------


class DeviceMixin:
    """Unified device/dtype movement.

    Inherit alongside ``nn.Module`` (place before ``nn.Module`` in the MRO)::

        class Foo(DeviceMixin, nn.Module):
            ...

    Or use on a plain Python class / dataclass::

        @dataclass
        class Bar(DeviceMixin):
            data: torch.Tensor

    ``.to()``, ``.cuda()`` and ``.cpu()`` work on both kinds of class. On
    ``nn.Module`` subclasses ``.float()``, ``.double()`` and ``.half()`` work
    too, because ``nn.Module`` routes them through :meth:`_apply`, which:

    1. invokes ``nn.Module._apply`` when applicable so parameters, buffers and
       child modules are moved by the standard PyTorch path,
    2. walks ``self.__dict__`` to pick up plain tensor attributes, nested
       containers and non-Module sub-objects,
    3. refreshes ``device`` / ``_device`` / ``dtype_float`` / ``_dtype``
       tracker attributes,
    4. calls ``reset_forward_cache()`` and ``reset_cache()`` if either is
       defined.

    For plain (non-Module) classes ``.to()`` follows the ``nn.Module.to``
    convention: a requested dtype is applied only to floating-point and
    complex tensors. Integer/bool tensors (e.g. indices) keep their dtype and
    only change device.
    """

    # ---- to / cuda / cpu -------------------------------------------------

    def to(self, *args, **kwargs):  # type: ignore[override]
        """Move and/or cast this object's tensors; returns ``self``.

        ``nn.Module`` subclasses delegate to ``nn.Module.to``. Plain classes
        parse ``device`` / ``dtype`` from the arguments and honour the
        ``non_blocking=`` keyword.
        """
        token = _enter_traversal()
        try:
            if isinstance(self, nn.Module):
                return super().to(*args, **kwargs)
            device, dtype = _parse_to_args(args, kwargs)
            if device is None and dtype is None:
                return self
            non_blocking = bool(kwargs.get("non_blocking", False))

            def fn(t):
                # Mirror nn.Module.to: never cast integer/bool tensors.
                castable = t.is_floating_point() or t.is_complex()
                return t.to(
                    device=device,
                    dtype=dtype if castable else None,
                    non_blocking=non_blocking,
                )

            return self._apply(fn)
        finally:
            _exit_traversal(token)

    def cuda(self, device=None):  # type: ignore[override]
        """Move to CUDA; an ``int`` device is treated as a CUDA ordinal."""
        if device is None:
            device = "cuda"
        elif isinstance(device, int):
            device = f"cuda:{device}"
        return self.to(device=device)

    def cpu(self):  # type: ignore[override]
        """Move to the CPU."""
        return self.to(device="cpu")

    # ---- core traversal --------------------------------------------------

    def _apply(self, fn, recurse=True):  # type: ignore[override]
        """Apply the tensor map ``fn`` to everything reachable from ``self``.

        A thread-local visited set (created here when no traversal is
        active) ensures each mixin instance is processed at most once per
        top-level call, which makes cyclic object graphs safe. Returns
        ``self``.

        ``recurse=False`` is forwarded to ``nn.Module._apply`` (PyTorch
        versions that support it). Registered children are then left
        untouched and are not marked visited.
        """
        visited = _current_visited()
        token = None
        if visited is None:
            token = _enter_traversal()
            visited = _current_visited()
        try:
            if id(self) in visited:
                return self
            visited.add(id(self))
            is_module = isinstance(self, nn.Module)

            # 1. Standard nn.Module traversal (params, buffers, child modules).
            if is_module:
                if recurse:
                    # Called without the keyword so PyTorch versions whose
                    # ``_apply`` predates ``recurse`` also work (recursion is
                    # their only behaviour). Errors are not retried: a retry
                    # would skip already-visited children and could mask a
                    # half-finished move.
                    super()._apply(fn)
                    # Registered children were just moved; mark them so that
                    # plain attributes / back-references pointing at them do
                    # not trigger redundant traversal in step 2.
                    for child in self.children():
                        visited.add(id(child))
                else:
                    # Children are NOT moved here, so they stay unvisited.
                    super()._apply(fn, recurse=False)

            # 2. __dict__ walk for plain tensors and non-Module sub-objects.
            # nn.Module bookkeeping names are only meaningful on modules.
            skip = _NN_MODULE_INTERNALS if is_module else frozenset()
            for name, val in list(self.__dict__.items()):
                if name in skip:
                    continue
                new_val = _apply_to_obj(val, fn, visited)
                if new_val is not val:
                    _safe_setattr(self, name, new_val)

            # 3. Refresh ``device`` / ``_device`` / ``dtype`` trackers so
            #    subsequent tensor allocations target the new device.
            _refresh_device_trackers(self)

            # 4. Invalidate caches.
            _invalidate_caches(self)
            return self
        finally:
            if token is not None:
                _exit_traversal(token)
# ---------------------------------------------------------------------------
# Backwards-compatibility aliases
# ---------------------------------------------------------------------------

# ``DeviceMovementMixin`` (still imported by older code) and
# ``_NonModuleDeviceMixin`` (briefly a separate non-Module variant) both
# resolve to the unified mixin, which handles Module and non-Module classes
# alike, so existing subclasses automatically pick up the current behaviour.
DeviceMovementMixin = _NonModuleDeviceMixin = DeviceMixin