torchref.utils.device_mixin module
DeviceMixin - unified device and dtype movement for TorchRef.
Provides a single mixin that hijacks .to(), .cuda(), .cpu() (and
indirectly .float(), .double(), .half() since they all funnel
through nn.Module._apply) for both nn.Module subclasses and plain
Python classes. The mixin recursively traverses the object graph, moving:
- parameters, buffers and child nn.Module instances (via the standard nn.Module._apply machinery when applicable),
- raw torch.Tensor attributes stored directly on self,
- non-Module sub-objects that expose _apply,
- tensors nested inside list/tuple/dict attributes,
- unregistered nn.Module instances held as plain attributes.
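One of the cases above is tensors nested inside list/tuple/dict attributes. The sketch below shows that recursion, assuming containers are rebuilt rather than mutated in place. map_nested and its is_leaf predicate are hypothetical names, not part of torchref. Leaves are kept generic so the example runs without PyTorch; in a torch setting, is_leaf would typically be lambda v: isinstance(v, torch.Tensor) and fn a device or dtype conversion.

```python
from typing import Any, Callable


def map_nested(value: Any, fn: Callable[[Any], Any],
               is_leaf: Callable[[Any], bool]) -> Any:
    """Hypothetical helper: apply fn to every leaf inside nested containers."""
    if is_leaf(value):
        return fn(value)
    if isinstance(value, list):
        # Lists are rebuilt here rather than mutated in place.
        return [map_nested(v, fn, is_leaf) for v in value]
    if isinstance(value, tuple):
        # Tuples are immutable, so a new tuple must be built.
        return tuple(map_nested(v, fn, is_leaf) for v in value)
    if isinstance(value, dict):
        # Keys are left untouched; only values are converted.
        return {k: map_nested(v, fn, is_leaf) for k, v in value.items()}
    # Anything else (strings, None, other objects) passes through unchanged.
    return value


attrs = {"weights": [1, (2, 3)], "label": "x"}
print(map_nested(attrs, lambda v: v * 10, lambda v: isinstance(v, int)))
# {'weights': [10, (20, 30)], 'label': 'x'}
```

Because containers are rebuilt, code holding a reference to the old list still sees the unconverted values, and tuple subclasses such as namedtuples come back as plain tuples. An implementation could instead mutate lists in place or reconstruct the original container type.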
A thread-local visited set keyed by id() makes traversal cycle-safe
(e.g. Target.refinement -> Refinement.targets -> Target). After moving,
any node exposing reset_forward_cache() or reset_cache() has that method
called, invalidating its cached results.
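Usage:

    from dataclasses import dataclass
    import torch
    from torch import nn
    from torchref.utils.device_mixin import DeviceMixin

    class MyModule(DeviceMixin, nn.Module):
        ...  # nothing else required for device movement

    @dataclass
    class MyDataclass(DeviceMixin):
        _data: torch.Tensor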
The legacy name DeviceMovementMixin is kept as an alias for
DeviceMixin.
- class torchref.utils.device_mixin.DeviceMixin
Bases: object

Unified device/dtype movement.
Inherit alongside nn.Module (place it before nn.Module in the MRO):

    class Foo(DeviceMixin, nn.Module): ...
Or use it on a plain Python class or dataclass:

    @dataclass
    class Bar(DeviceMixin):
        data: torch.Tensor
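Placing DeviceMixin before nn.Module matters because nn.Module's .to(), .cuda() and similar methods call self._apply(...), which Python resolves along the MRO. The sketch below illustrates the effect with a stand-in FakeModule instead of nn.Module and a hypothetical SketchMixin. Reading "when applicable" as "when a base class defines _apply" is an assumption, not something the source confirms.

```python
class FakeModule:
    """Stand-in for nn.Module: its _apply would move registered state."""

    def _apply(self, fn):
        self.log.append("module")
        return self


class SketchMixin:
    """Hypothetical mixin: defer to a base-class _apply when one exists."""

    def _apply(self, fn):
        parent = getattr(super(), "_apply", None)
        if parent is not None:  # Assumed meaning of "when applicable".
            parent(fn)
        self.log.append("mixin")  # Plain-attribute handling would go here.
        return self


class Good(SketchMixin, FakeModule):  # Mixin before the module base.
    def __init__(self):
        self.log = []


class Bad(FakeModule, SketchMixin):  # Module base first: mixin is shadowed.
    def __init__(self):
        self.log = []


class Plain(SketchMixin):  # Plain class: no base _apply to call.
    def __init__(self):
        self.log = []


print(Good()._apply(None).log)   # ['module', 'mixin']
print(Bad()._apply(None).log)    # ['module']
print(Plain()._apply(None).log)  # ['mixin']
```

With the module base listed first, its _apply shadows the mixin's entirely, so plain tensor attributes would silently stay behind.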
All of .to(), .cuda(), .cpu(), .float(), .double() and .half() route through _apply(), which:
- invokes nn.Module._apply when applicable, so parameters, buffers and child modules are moved by the standard PyTorch path,
- walks self.__dict__ to pick up plain tensor attributes, nested containers and non-Module sub-objects,
- calls whichever of reset_forward_cache() and reset_cache() is defined.
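The routing described above, where every public move or cast funnels into one _apply, can be sketched for a plain class as follows. RoutingSketch and Holder are hypothetical and the example requires PyTorch. Its _apply only handles tensors stored directly on self, leaving out the nested containers, sub-objects, cycle guard and cache reset listed above.

```python
from dataclasses import dataclass

import torch


class RoutingSketch:
    """Hypothetical sketch: each public move/cast method builds a per-tensor
    function and hands it to a single _apply, as nn.Module does."""

    def _apply(self, fn):
        # Minimal walk: only tensors stored directly on self (no recursion).
        for name, value in list(vars(self).items()):
            if isinstance(value, torch.Tensor):
                setattr(self, name, fn(value))
        return self

    def to(self, *args, **kwargs):
        # Caveat: unlike nn.Module.to, a dtype argument here also casts
        # integer tensors, because every tensor is passed to Tensor.to as-is.
        return self._apply(lambda t: t.to(*args, **kwargs))

    def cuda(self, device=None):
        return self._apply(lambda t: t.cuda(device))

    def cpu(self):
        return self._apply(lambda t: t.cpu())

    # Floating-point casts leave integer/bool tensors untouched, as nn.Module does.
    def float(self):
        return self._apply(lambda t: t.float() if t.is_floating_point() else t)

    def double(self):
        return self._apply(lambda t: t.double() if t.is_floating_point() else t)

    def half(self):
        return self._apply(lambda t: t.half() if t.is_floating_point() else t)


@dataclass
class Holder(RoutingSketch):
    data: torch.Tensor
    index: torch.Tensor


holder = Holder(torch.zeros(3), torch.arange(3)).double()
print(holder.data.dtype, holder.index.dtype)  # torch.float64 torch.int64
```

Keeping each conversion in a small closure means every public method is a one-liner, and any traversal improvement made in _apply benefits all of them at once.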
- torchref.utils.device_mixin.DeviceMovementMixin
alias of DeviceMixin