torchref.utils.device_mixin module

DeviceMixin - unified device and dtype movement for TorchRef.

Provides a single mixin that intercepts .to(), .cuda() and .cpu() for both nn.Module subclasses and plain Python classes; .float(), .double() and .half() are covered indirectly, since they all funnel through nn.Module._apply. The mixin recursively traverses the object graph, moving:

  • parameters, buffers and child nn.Module instances (via the standard nn.Module._apply machinery when applicable),

  • raw torch.Tensor attributes stored directly on self,

  • non-Module sub-objects that expose _apply,

  • tensors nested inside list / tuple / dict attributes,

  • unregistered nn.Module instances held as plain attributes.

A thread-local visited set keyed by id() makes the traversal cycle-safe (e.g. Target.refinement -> Refinement.targets -> Target). After the move, the mixin calls reset_forward_cache() and reset_cache() on any node that defines them, so that cached values are invalidated.
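The cycle-safe traversal described above can be sketched in plain Python. This is an illustration of the technique only, not TorchRef's implementation: the names apply_to_leaves, _visit, _local and leaf_type are hypothetical, and float values stand in for tensors so that the sketch runs without PyTorch.

```python
import threading

# Hypothetical per-thread holder for the visited set (assumption, not TorchRef code).
_local = threading.local()


def apply_to_leaves(obj, fn, leaf_type):
    """Apply fn to every leaf_type value reachable from obj and return obj."""
    outermost = not hasattr(_local, "visited")
    if outermost:
        _local.visited = set()
    try:
        return _visit(obj, fn, leaf_type)
    finally:
        if outermost:
            del _local.visited  # cleared only when the outermost call returns


def _visit(obj, fn, leaf_type):
    if isinstance(obj, leaf_type):
        return fn(obj)
    if isinstance(obj, (str, bytes, int, float, bool, type(None))):
        return obj
    if id(obj) in _local.visited:
        return obj  # already handled or still in progress: this breaks cycles
    _local.visited.add(id(obj))
    if isinstance(obj, list):
        obj[:] = [_visit(v, fn, leaf_type) for v in obj]
    elif isinstance(obj, dict):
        for key in list(obj):
            obj[key] = _visit(obj[key], fn, leaf_type)
    elif isinstance(obj, tuple):
        return tuple(_visit(v, fn, leaf_type) for v in obj)  # tuples are rebuilt
    elif hasattr(obj, "__dict__"):
        for name, value in list(vars(obj).items()):
            setattr(obj, name, _visit(value, fn, leaf_type))
    return obj


class Target:
    pass


class Refinement:
    pass


# Build the cycle Target.refinement -> Refinement.targets -> Target.
target, refinement = Target(), Refinement()
target.refinement = refinement
refinement.targets = [target]
target.scale = 1.5
refinement.weights = {"a": 2.0, "b": (3.0, "x")}

apply_to_leaves(target, lambda x: x * 2, float)
```

Without the visited set, the walk from target would re-enter target through refinement.targets and recurse without end. Keeping the set in a threading.local means that concurrent traversals on different threads do not share state.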

Usage:

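from dataclasses import dataclass
import torch
from torch import nn
from torchref.utils.device_mixin import DeviceMixin

class MyModule(DeviceMixin, nn.Module):
    ...  # nothing else required for device movement

@dataclass
class MyDataclass(DeviceMixin):
    _data: torch.Tensor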

The legacy name DeviceMovementMixin is kept as an alias for DeviceMixin.

class torchref.utils.device_mixin.DeviceMixin

Bases: object

Unified device/dtype movement.

Inherit alongside nn.Module, listing DeviceMixin before nn.Module in the bases so that it comes first in the MRO:

class Foo(DeviceMixin, nn.Module):
    ...
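Why the base order matters can be seen in a small pure-Python sketch. The classes below are hypothetical stand-ins (Module plays the role of nn.Module, Mixin the role of DeviceMixin); the only fact relied on is that nn.Module defines its own _apply, so whichever base is listed first wins method resolution.

```python
class Module:
    """Stand-in for nn.Module: it defines _apply itself, and cpu() calls it."""

    def _apply(self, fn):
        return ["Module._apply"]

    def cpu(self):
        return self._apply(lambda t: t)


class Mixin:
    """Stand-in for a mixin that wraps _apply and then defers to the next class."""

    def _apply(self, fn):
        calls = ["Mixin._apply"]
        parent = getattr(super(), "_apply", None)  # None for plain classes
        if parent is not None:
            calls += parent(fn)
        return calls


class MixinFirst(Mixin, Module):
    pass


class ModuleFirst(Module, Mixin):
    pass


class Plain(Mixin):
    pass


mixin_first_calls = MixinFirst().cpu()
module_first_calls = ModuleFirst().cpu()
plain_calls = Plain()._apply(lambda t: t)
```

With the mixin listed first, cpu() reaches the mixin's _apply and then the base implementation. With the order reversed, the base class's _apply is found first and the mixin's version never runs.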

Or use on a plain Python class / dataclass:

@dataclass
class Bar(DeviceMixin):
    data: torch.Tensor

All of .to(), .cuda(), .cpu(), .float(), .double(), .half() route through _apply(), which:

  1. invokes nn.Module._apply when applicable so parameters, buffers and child modules are moved by the standard PyTorch path,

  2. walks self.__dict__ to pick up plain tensor attributes, nested containers and non-Module sub-objects,

  3. calls whichever of reset_forward_cache() and reset_cache() the object defines.
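For a plain Python class, steps 2 and 3 could look roughly like the sketch below. It is a simplified illustration under stated assumptions, not TorchRef's code: PlainDeviceMixin, _convert and Cell are hypothetical names, step 1 (nn.Module._apply) is left out, and there is no cycle protection.

```python
import torch


class PlainDeviceMixin:
    """Hypothetical stand-in for plain classes; not TorchRef's DeviceMixin."""

    def to(self, *args, **kwargs):
        return self._apply(lambda t: t.to(*args, **kwargs))

    def cpu(self):
        return self._apply(lambda t: t.cpu())

    def _apply(self, fn):
        # Step 2: walk the instance attributes and convert what can be converted.
        for name, value in list(vars(self).items()):
            setattr(self, name, _convert(value, fn))
        # Step 3: call whichever cache-reset hooks the object defines.
        for hook in ("reset_forward_cache", "reset_cache"):
            method = getattr(self, hook, None)
            if callable(method):
                method()
        return self


def _convert(value, fn):
    if isinstance(value, torch.Tensor):
        return fn(value)
    if isinstance(value, (list, tuple)):
        # Rebuilds the container; namedtuples would need special handling.
        return type(value)(_convert(v, fn) for v in value)
    if isinstance(value, dict):
        return {k: _convert(v, fn) for k, v in value.items()}
    if hasattr(value, "_apply"):
        value._apply(fn)  # sub-objects that expose _apply move themselves
    return value


class Cell(PlainDeviceMixin):
    def __init__(self):
        self.lengths = torch.ones(3)
        self.extras = {"angles": [torch.zeros(3)]}
        self.cache = "stale"

    def reset_cache(self):
        self.cache = None


# A dtype change is used so the example runs on a CPU-only machine.
cell = Cell().to(torch.float64)
```

Because to() and cpu() both reduce to a call to _apply with a per-tensor function, any further conversion method can be added as a one-line wrapper.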

to(*args, **kwargs)
cuda(device=None)
cpu()
torchref.utils.device_mixin.DeviceMovementMixin

alias of DeviceMixin