Source code for torchref.config

"""
Centralized configuration for TorchRef.

Default dtypes can be set via environment variables at import time:
- TORCHREF_DTYPE_FLOAT: float32 (default) or float64
- TORCHREF_DTYPE_INT: int32 (default) or int64
- TORCHREF_DTYPE_COMPLEX: complex64 (default) or complex128

Default device is auto-detected at import time using cuda -> mps -> cpu.
A CUDA device is only picked automatically if it satisfies *both*:
  * compute capability >= the minimum sm_* compiled into the current
    PyTorch build (``torch.cuda.get_arch_list()``), and
  * total VRAM >= ``_MIN_CUDA_VRAM_GB`` (10 GB).
Otherwise auto-detection falls back to MPS or CPU with a warning that
names the failing requirement. Override the resolved device with the
TORCHREF_DEVICE environment variable ('auto' (default), 'cuda', 'mps',
'cpu'); an explicit value bypasses the capability/VRAM gates but still
fails fast if the requested backend is unavailable on this host.

Users can also change dtypes/device at runtime via attribute assignment:
    import torchref
    torchref.dtypes.float = torch.float64
    torchref.device.current = torch.device('cpu')

Or read current values:
    torchref.dtypes.float   # torch.float32
    torchref.device.current # torch.device('cuda')

MPS caveat: Apple's MPS backend does not support float64 / complex128. If
the resolved device is MPS and the configured float dtype is float64, a
warning is emitted at import time. Either set TORCHREF_DTYPE_FLOAT=float32
or TORCHREF_DEVICE=cpu to silence it.
"""

import os
import warnings

import torch

# Map strings to torch dtypes
_FLOAT_DTYPE_MAP = {
    "float32": torch.float32,
    "float64": torch.float64,
}

_INT_DTYPE_MAP = {
    "int32": torch.int32,
    "int64": torch.int64,
}

_COMPLEX_DTYPE_MAP = {
    "complex64": torch.complex64,
    "complex128": torch.complex128,
}


class DtypeConfig:
    """
    Dtype configuration with property-based access.

    Initial values come from the ``TORCHREF_DTYPE_FLOAT``,
    ``TORCHREF_DTYPE_INT`` and ``TORCHREF_DTYPE_COMPLEX`` environment
    variables (case-insensitive, surrounding whitespace ignored; an unset
    or empty variable selects float32 / int32 / complex64 respectively).

    Access dtypes as attributes:
        dtypes.float    # get current float dtype
        dtypes.int      # get current int dtype
        dtypes.complex  # get current complex dtype

    Set dtypes via assignment:
        dtypes.float = torch.float64
        dtypes.int = torch.int64
        dtypes.complex = torch.complex128
    """

    def __init__(self):
        """Resolve the three default dtypes from the environment.

        Raises
        ------
        ValueError
            If any of the variables holds a name that is not a key of the
            corresponding lookup table.
        """
        # Resolve all three before assigning anything, so a bad value never
        # leaves a partially initialised instance behind.
        float_dtype = self._from_env(
            "TORCHREF_DTYPE_FLOAT", "float32", _FLOAT_DTYPE_MAP
        )
        int_dtype = self._from_env("TORCHREF_DTYPE_INT", "int32", _INT_DTYPE_MAP)
        complex_dtype = self._from_env(
            "TORCHREF_DTYPE_COMPLEX", "complex64", _COMPLEX_DTYPE_MAP
        )

        self._float = float_dtype
        self._int = int_dtype
        self._complex = complex_dtype

    @staticmethod
    def _from_env(var_name: str, default: str, table: dict) -> torch.dtype:
        """Look up the dtype named by environment variable ``var_name``.

        The raw value is stripped and lower-cased before the lookup. An
        unset or empty/whitespace-only variable is treated as ``default``
        (``export VAR=`` is a common way to "clear" a variable and should
        not break the import).

        Raises
        ------
        ValueError
            If the normalised value is not a key of ``table``.
        """
        name = os.environ.get(var_name, "").strip().lower() or default
        if name not in table:
            raise ValueError(
                f"Invalid {var_name}: {name}. "
                f"Valid values: {list(table.keys())}"
            )
        return table[name]

    @property
    def float(self) -> torch.dtype:
        """Get the current default float dtype."""
        return self._float

    @float.setter
    def float(self, dtype: torch.dtype) -> None:
        """Set the default float dtype for all future operations.

        Raises
        ------
        ValueError
            If ``dtype`` is neither ``torch.float32`` nor ``torch.float64``.
        """
        if dtype not in (torch.float32, torch.float64):
            raise ValueError(
                f"Invalid float dtype: {dtype}. Use torch.float32 or torch.float64."
            )
        self._float = dtype

    @property
    def int(self) -> torch.dtype:
        """Get the current default int dtype."""
        return self._int

    @int.setter
    def int(self, dtype: torch.dtype) -> None:
        """Set the default int dtype for all future operations.

        Raises
        ------
        ValueError
            If ``dtype`` is neither ``torch.int32`` nor ``torch.int64``.
        """
        if dtype not in (torch.int32, torch.int64):
            raise ValueError(
                f"Invalid int dtype: {dtype}. Use torch.int32 or torch.int64."
            )
        self._int = dtype

    @property
    def complex(self) -> torch.dtype:
        """Get the current default complex dtype."""
        return self._complex

    @complex.setter
    def complex(self, dtype: torch.dtype) -> None:
        """Set the default complex dtype for all future operations.

        Raises
        ------
        ValueError
            If ``dtype`` is neither ``torch.complex64`` nor
            ``torch.complex128``.
        """
        if dtype not in (torch.complex64, torch.complex128):
            raise ValueError(
                f"Invalid complex dtype: {dtype}. Use torch.complex64 or torch.complex128."
            )
        self._complex = dtype

    def __repr__(self) -> str:
        return f"DtypeConfig(float={self._float}, int={self._int}, complex={self._complex})"
# Global singleton instance dtypes = DtypeConfig() # Convenience functions for internal use (avoid repeated attribute lookups)
def get_float_dtype() -> torch.dtype:
    """Get the current default float dtype.

    Reads ``dtypes.float`` on every call, so a runtime assignment to
    ``dtypes.float`` is reflected by the next call.
    """
    return dtypes.float
def get_int_dtype() -> torch.dtype:
    """Get the current default int dtype.

    Reads ``dtypes.int`` on every call, so a runtime assignment to
    ``dtypes.int`` is reflected by the next call.
    """
    return dtypes.int
def get_complex_dtype() -> torch.dtype:
    """Get the current default complex dtype.

    Reads ``dtypes.complex`` on every call, so a runtime assignment to
    ``dtypes.complex`` is reflected by the next call.
    """
    return dtypes.complex
# ---------------------------------------------------------------------------
# Device configuration
# ---------------------------------------------------------------------------

_VALID_DEVICE_TYPES = ("cuda", "mps", "cpu")

# Minimum GPU VRAM (in GB) required for CUDA to be picked by auto-detection.
# Smaller GPUs typically can't fit useful refinement workloads and surprise
# users with OOMs, so we fall back to CPU instead.
_MIN_CUDA_VRAM_GB = 10


def _min_supported_capability(arch_list):
    """Return the smallest ``(major, minor)`` among the ``sm_*`` entries.

    Entries that do not start with ``sm_`` (e.g. ``compute_xx``) and entries
    whose numeric part cannot be parsed are ignored. Returns ``None`` when no
    entry could be parsed.
    """
    parsed = []
    for entry in arch_list:
        if not entry.startswith("sm_"):
            continue
        digits = entry[3:]
        try:
            # "sm_70" -> (7, 0); "sm_100" -> (10, 0): the last digit is the
            # minor version, everything before it is the major version.
            parsed.append((int(digits[:-1]), int(digits[-1])))
        except (ValueError, IndexError):
            continue
    return min(parsed) if parsed else None


def _select_cuda_device():
    """Return the CUDA device auto-detection should use, or ``None``.

    A device qualifies when its compute capability is >= the minimum sm_*
    compiled into the current PyTorch build and its total VRAM is >=
    ``_MIN_CUDA_VRAM_GB``. The current CUDA device is examined first; if it
    qualifies, the un-indexed ``torch.device("cuda")`` is returned. If only
    another visible device qualifies, that device is returned with an
    explicit index, so that the selected device is the one that actually
    passed the checks rather than whichever device happens to be current.

    If the supported arch list cannot be introspected (older torch builds)
    or contains no parsable ``sm_*`` entry, ``torch.cuda.is_available()`` is
    trusted and ``torch.device("cuda")`` is returned without further checks.

    When CUDA is available but no device qualifies, a single warning naming
    the requirements is emitted and ``None`` is returned.
    """
    if not torch.cuda.is_available():
        return None

    try:
        arch_list = torch.cuda.get_arch_list()
    except Exception:
        # If we cannot introspect supported archs, fall back to trusting
        # is_available() (older torch versions).
        return torch.device("cuda")

    min_supported = _min_supported_capability(arch_list) if arch_list else None
    if min_supported is None:
        return torch.device("cuda")

    min_vram_bytes = _MIN_CUDA_VRAM_GB * (1024**3)

    try:
        current = torch.cuda.current_device()
    except Exception:
        current = 0

    # Stable sort: the current device first, the rest in index order.
    candidates = sorted(
        range(torch.cuda.device_count()), key=lambda idx: idx != current
    )
    for idx in candidates:
        try:
            cap = torch.cuda.get_device_capability(idx)
        except Exception:
            continue
        if cap < min_supported:
            continue
        try:
            total_mem = torch.cuda.get_device_properties(idx).total_memory
        except Exception:
            # Unknown VRAM is treated as "too small".
            total_mem = 0
        if total_mem >= min_vram_bytes:
            if idx == current:
                return torch.device("cuda")
            return torch.device("cuda", idx)

    warnings.warn(
        "TorchRef: no detected CUDA GPU meets the auto-selection requirements "
        f"(compute capability >= {min_supported[0]}.{min_supported[1]} and "
        f">= {_MIN_CUDA_VRAM_GB} GB VRAM; PyTorch build supports sm_*: "
        f"{arch_list}). Falling back to CPU. Set TORCHREF_DEVICE=cuda "
        "explicitly to override.",
        stacklevel=3,
    )
    return None


def _cuda_is_usable() -> bool:
    """Return True iff at least one visible CUDA device is suitable for
    auto-selection as the default TorchRef device.

    A device qualifies when all of the following hold:

    * ``torch.cuda.is_available()`` is True.
    * Its compute capability is >= the minimum sm_* compiled into the
      current PyTorch wheel (introspected via
      ``torch.cuda.get_arch_list()``).
    * Its total VRAM is >= ``_MIN_CUDA_VRAM_GB``.

    If introspection fails on an older torch build (no ``get_arch_list``)
    or the arch list is empty, we trust ``is_available()`` and return True
    without the capability check.

    On failure a single ``warnings.warn`` explains which requirement was
    missed.
    """
    return _select_cuda_device() is not None


def _auto_detect_device() -> torch.device:
    """Pick the best available device: cuda -> mps -> cpu.

    For CUDA the returned device is the one that passed the capability and
    VRAM gates: ``torch.device("cuda")`` when that is the current device,
    otherwise an explicitly indexed device such as ``cuda:1``.
    """
    cuda_device = _select_cuda_device()
    if cuda_device is not None:
        return cuda_device
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
class DeviceConfig:
    """
    Device configuration with property-based access.

    Resolved once at import time using cuda -> mps -> cpu via
    :func:`_auto_detect_device`, which gates CUDA on both compute capability
    and a minimum of ``_MIN_CUDA_VRAM_GB`` of VRAM. Override the resolved
    default via the ``TORCHREF_DEVICE`` environment variable (``'auto'``
    (default), ``'cuda'``, ``'mps'``, or ``'cpu'``); explicit values bypass
    the auto-selection gates and instead raise if the requested backend is
    unavailable on this host. The variable is case-insensitive, surrounding
    whitespace is ignored, and an empty value is treated like ``'auto'``.

        device.current           # get the active device
        device.current = "cpu"   # set at runtime (string or torch.device)

    Setter behaviour mirrors the env-var override: a bad value raises
    ``ValueError`` / ``RuntimeError`` rather than silently falling back, so
    callers can decide how to recover.
    """

    def __init__(self):
        """Resolve the default device from ``TORCHREF_DEVICE`` or auto-detect."""
        # An exported-but-empty variable is treated as unset rather than being
        # handed to torch.device(""), which would fail with an error that does
        # not mention TORCHREF_DEVICE at all.
        override = os.environ.get("TORCHREF_DEVICE", "").strip().lower() or "auto"
        if override == "auto":
            self._device = _auto_detect_device()
        else:
            self._device = self._coerce(override)
        self._warn_if_mps_dtype_mismatch()

    @staticmethod
    def _coerce(value) -> torch.device:
        """Validate and convert a user-supplied device value.

        Parameters
        ----------
        value : torch.device or str
            Device object or a device string accepted by ``torch.device``.

        Returns
        -------
        torch.device

        Raises
        ------
        TypeError
            If ``value`` is neither a ``torch.device`` nor a string.
        ValueError
            If the device type is not one of ``_VALID_DEVICE_TYPES``.
        RuntimeError
            If the requested backend is not available on this host, or a
            CUDA device index is outside the range of visible devices.
            ``torch.device`` itself also raises for unparsable strings.
        """
        if isinstance(value, torch.device):
            dev = value
        elif isinstance(value, str):
            dev = torch.device(value)
        else:
            raise TypeError(
                f"device must be a torch.device or string, got {type(value).__name__}"
            )

        if dev.type not in _VALID_DEVICE_TYPES:
            raise ValueError(
                f"Invalid device type: {dev.type!r}. "
                f"Valid types: {_VALID_DEVICE_TYPES}"
            )

        if dev.type == "cuda":
            if not torch.cuda.is_available():
                raise RuntimeError("CUDA requested but not available on this system.")
            # Fail fast on e.g. "cuda:3" with a single visible GPU instead of
            # deferring the error to the first tensor allocation.
            if dev.index is not None:
                visible = torch.cuda.device_count()
                if dev.index >= visible:
                    raise RuntimeError(
                        f"CUDA device index {dev.index} requested but only "
                        f"{visible} CUDA device(s) are visible on this system."
                    )

        if dev.type == "mps" and not (
            hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        ):
            raise RuntimeError("MPS requested but not available on this system.")

        return dev

    def _warn_if_mps_dtype_mismatch(self) -> None:
        """Warn when the device is MPS while ``dtypes.float`` is float64."""
        if self._device.type == "mps" and dtypes.float == torch.float64:
            warnings.warn(
                "TorchRef default device is MPS but dtypes.float is float64; "
                "MPS does not support float64. Set TORCHREF_DTYPE_FLOAT=float32 "
                "or TORCHREF_DEVICE=cpu to silence this warning.",
                stacklevel=2,
            )

    @property
    def current(self) -> torch.device:
        """Get the current default device."""
        return self._device

    @current.setter
    def current(self, value) -> None:
        """Set the default device for all future operations.

        ``value`` is validated by :meth:`_coerce`; see there for the
        exceptions raised on an invalid or unavailable device.
        """
        self._device = self._coerce(value)
        self._warn_if_mps_dtype_mismatch()

    def __repr__(self) -> str:
        return f"DeviceConfig(current={self._device})"
# Global singleton instance. Constructing it at import time is what resolves
# the default device: DeviceConfig.__init__ reads TORCHREF_DEVICE and either
# auto-detects or validates the explicit override (raising if it is invalid
# or unavailable).
device = DeviceConfig()
def get_default_device() -> torch.device:
    """Get the current default device.

    Reads ``device.current`` on every call, so a runtime assignment to
    ``device.current`` is reflected by the next call.
    """
    return device.current