torchref.config module

Centralized configuration for TorchRef.

Default dtypes can be set via environment variables at import time:

  • TORCHREF_DTYPE_FLOAT: float32 (default) or float64

  • TORCHREF_DTYPE_INT: int32 (default) or int64

  • TORCHREF_DTYPE_COMPLEX: complex64 (default) or complex128
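As a rough illustration of how the environment variables above could be read, here is a minimal sketch. The helper name `read_dtype_name`, the `_ALLOWED` table, and the choice to raise `ValueError` on an unrecognized value are assumptions made for this example, not TorchRef's actual code. It returns dtype names as strings so that it runs without PyTorch; a real implementation would presumably map the name to a `torch.dtype`.

```python
import os

# Hypothetical table of accepted names per variable.
# The first entry of each tuple is the documented default.
_ALLOWED = {
    "TORCHREF_DTYPE_FLOAT": ("float32", "float64"),
    "TORCHREF_DTYPE_INT": ("int32", "int64"),
    "TORCHREF_DTYPE_COMPLEX": ("complex64", "complex128"),
}


def read_dtype_name(var, environ=None):
    """Return the dtype name configured for `var`, or its default when unset."""
    environ = os.environ if environ is None else environ
    allowed = _ALLOWED[var]
    default = allowed[0]
    name = environ.get(var, default).strip().lower()
    if name not in allowed:
        # Assumption: reject unknown values instead of silently using the default.
        raise ValueError(f"{var}={name!r} is not one of {allowed}")
    return name
```

Because the variables are read at import time, they need to be set before `import torchref` runs, for example in the shell that launches the process.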

Default device is auto-detected at import time using cuda -> mps -> cpu. A CUDA device is only picked automatically if it satisfies both:

  • compute capability >= the minimum sm_* compiled into the current PyTorch build (torch.cuda.get_arch_list()), and

  • total VRAM >= _MIN_CUDA_VRAM_GB (10 GB).
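The two gates above can be sketched as a pure function. This is an illustration under stated assumptions, not the body of TorchRef's own check: the function names are hypothetical, the inputs are expected to come from `torch.cuda.get_arch_list()`, `torch.cuda.get_device_capability()`, and `torch.cuda.get_device_properties(i).total_memory`, and treating 1 GB as 1024**3 bytes is a guess.

```python
import re

MIN_CUDA_VRAM_GB = 10  # mirrors _MIN_CUDA_VRAM_GB from the text above


def min_compiled_capability(arch_list):
    """Return the smallest (major, minor) among 'sm_XY' entries, or None if there are none."""
    caps = []
    for arch in arch_list:
        match = re.fullmatch(r"sm_(\d+)(\d)[a-z]?", arch)
        if match:  # skips non-sm entries such as 'compute_90'
            caps.append((int(match.group(1)), int(match.group(2))))
    return min(caps) if caps else None


def cuda_gate_failure(capability, total_memory_bytes, arch_list):
    """Return None if both gates pass, otherwise a message naming the failed gate."""
    floor = min_compiled_capability(arch_list)
    if floor is not None and tuple(capability) < floor:
        major, minor = capability
        return (
            f"compute capability {major}.{minor} is below the build minimum "
            f"sm_{floor[0]}{floor[1]}"
        )
    vram_gb = total_memory_bytes / 1024**3  # assumption: binary gigabytes
    if vram_gb < MIN_CUDA_VRAM_GB:
        return f"VRAM {vram_gb:.1f} GB is below {MIN_CUDA_VRAM_GB} GB"
    return None
```

Returning a message rather than a bare boolean makes it straightforward to emit a warning that names the failing requirement.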

Otherwise auto-detection falls back to MPS or CPU with a warning that names the failing requirement. Override the resolved device with the TORCHREF_DEVICE environment variable ('auto' (default), 'cuda', 'mps', 'cpu'); an explicit value bypasses the capability/VRAM gates but still fails fast if the requested backend is unavailable on this host.
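The difference between 'auto' and an explicit value can be sketched as follows. The function `resolve_device` and its boolean parameters are hypothetical stand-ins for the real availability checks, and the mapping of errors (ValueError for an unknown name, RuntimeError for an unavailable backend) is an assumption; the text only states that an explicit request fails fast.

```python
def resolve_device(requested, cuda_available, mps_available, cuda_passes_gates):
    """Resolve a TORCHREF_DEVICE-style value to 'cuda', 'mps', or 'cpu'."""
    requested = requested.strip().lower()

    if requested == "auto":
        # Auto-detection order: cuda -> mps -> cpu, with CUDA subject to the gates.
        if cuda_available and cuda_passes_gates:
            return "cuda"
        if mps_available:
            return "mps"
        return "cpu"

    if requested not in ("cuda", "mps", "cpu"):
        raise ValueError(f"unknown device {requested!r}")

    # Explicit values skip the capability/VRAM gates but not the availability check.
    if requested == "cuda" and not cuda_available:
        raise RuntimeError("CUDA was requested but is not available on this host")
    if requested == "mps" and not mps_available:
        raise RuntimeError("MPS was requested but is not available on this host")
    return requested
```

Note that an explicit 'cuda' succeeds on a small or old GPU that 'auto' would skip, which is the documented bypass.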

Users can also change dtypes/device at runtime via attribute assignment:

    import torch
    import torchref

    torchref.dtypes.float = torch.float64
    torchref.device.current = torch.device('cpu')

Or read current values:

    torchref.dtypes.float      # torch.float32
    torchref.device.current    # torch.device('cuda')

MPS caveat: Apple’s MPS backend does not support float64 / complex128. If the resolved device is MPS and the configured float dtype is float64, a warning is emitted at import time. Either set TORCHREF_DTYPE_FLOAT=float32 or TORCHREF_DEVICE=cpu to silence it.
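A check of this kind might look like the sketch below. The function name, the message text, and the use of `RuntimeWarning` are assumptions for illustration; the source only states that a warning is emitted at import time.

```python
import warnings


def warn_if_mps_float64(device_type, float_dtype_name):
    """Warn when float64 is configured together with the MPS backend.

    Returns True if a warning was emitted, False otherwise.
    """
    if device_type == "mps" and float_dtype_name == "float64":
        warnings.warn(
            "float64 is not supported on MPS; set TORCHREF_DTYPE_FLOAT=float32 "
            "or TORCHREF_DEVICE=cpu",
            RuntimeWarning,  # assumption: the actual warning category is not documented
            stacklevel=2,
        )
        return True
    return False
```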

class torchref.config.DtypeConfig

Bases: object

Dtype configuration with property-based access.

Access dtypes as attributes:

    dtypes.float    # get current float dtype
    dtypes.int      # get current int dtype
    dtypes.complex  # get current complex dtype

Set dtypes via assignment:

    dtypes.float = torch.float64
    dtypes.int = torch.int64
    dtypes.complex = torch.complex128
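The property-based access pattern described above can be sketched with a small stand-in class. `DtypeConfigSketch` is hypothetical and stores dtype names as strings so that it runs without PyTorch; whether the real setters validate their input is not documented, so the validation here is an assumption.

```python
class DtypeConfigSketch:
    """Hypothetical stand-in for a dtype config with get/set properties."""

    _FLOATS = ("float32", "float64")
    _INTS = ("int32", "int64")
    _COMPLEXES = ("complex64", "complex128")

    def __init__(self):
        # Defaults match the documented environment-variable defaults.
        self._float = "float32"
        self._int = "int32"
        self._complex = "complex64"

    @staticmethod
    def _check(value, allowed):
        # Assumption: reject values of the wrong kind rather than store them.
        if value not in allowed:
            raise ValueError(f"{value!r} is not one of {allowed}")
        return value

    @property
    def float(self):
        return self._float

    @float.setter
    def float(self, value):
        self._float = self._check(value, self._FLOATS)

    @property
    def int(self):
        return self._int

    @int.setter
    def int(self, value):
        self._int = self._check(value, self._INTS)

    @property
    def complex(self):
        return self._complex

    @complex.setter
    def complex(self, value):
        self._complex = self._check(value, self._COMPLEXES)
```

Keeping the state on one shared instance means that code reading `dtypes.float` after an assignment sees the new value without re-importing anything.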

__init__()
property float: dtype

Get the current default float dtype.

property int: dtype

Get the current default int dtype.

property complex: dtype

Get the current default complex dtype.

torchref.config.get_float_dtype()

Get the current default float dtype.

torchref.config.get_int_dtype()

Get the current default int dtype.

torchref.config.get_complex_dtype()

Get the current default complex dtype.

class torchref.config.DeviceConfig

Bases: object

Device configuration with property-based access.

Resolved once at import time using cuda -> mps -> cpu via _auto_detect_device(), which gates CUDA on both compute capability and a minimum of _MIN_CUDA_VRAM_GB of VRAM. Override the resolved default via the TORCHREF_DEVICE environment variable ('auto' (default), 'cuda', 'mps', or 'cpu'); explicit values bypass the auto-selection gates and instead raise if the requested backend is unavailable on this host.

    device.current          # get the active device
    device.current = "cpu"  # set at runtime (string or torch.device)

Setter behaviour mirrors the env-var override: a bad value raises ValueError / RuntimeError rather than silently falling back, so callers can decide how to recover.
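To show what "callers can decide how to recover" might look like, here is a minimal sketch of a validating setter plus a caller-side fallback. `DeviceConfigSketch` and `set_device_or_fallback` are hypothetical names, devices are plain strings so the example runs without PyTorch, device indices such as 'cuda:0' are ignored, and the split between ValueError (unknown name) and RuntimeError (backend unavailable) is an assumption.

```python
class DeviceConfigSketch:
    """Hypothetical stand-in for a device config with a validating setter."""

    _KNOWN = ("cuda", "mps", "cpu")

    def __init__(self, available=("cpu",)):
        self._available = set(available)  # backends usable on this host
        self._current = "cpu"

    @property
    def current(self):
        return self._current

    @current.setter
    def current(self, value):
        name = str(value).strip().lower()
        if name not in self._KNOWN:
            raise ValueError(f"unknown device {name!r}")
        if name not in self._available:
            raise RuntimeError(f"{name} is not available on this host")
        self._current = name


def set_device_or_fallback(config, preferred, fallback="cpu"):
    """Try the preferred device; on an unavailable backend, use the fallback."""
    try:
        config.current = preferred
    except RuntimeError:
        config.current = fallback
    return config.current
```

Only RuntimeError is caught here, so a misspelled device name still surfaces as a ValueError instead of being masked by the fallback.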

__init__()
property current: device

Get the current default device.

torchref.config.get_default_device()

Get the current default device.