"""
Hyperparameter Registration for PyTorch Modules.
This module provides a mechanism for tracking hyperparameters separately from
parameters (trainable weights) and buffers (persistent state). This allows for:
1. Easy access to all hyperparameters via `module.hyperparameters()`
2. Separate state_dict for hyperparameters via `module.hyperparameter_state_dict()`
3. Clean separation between model weights, state, and configuration
Design Pattern:
- Hyperparameters are registered as buffers internally (for device tracking)
- They are also tracked in a separate `_hyperparameter_names` set
- This allows filtering them out or including them as needed
Usage:
    class MyModule(HyperparameterMixin, nn.Module):
        def __init__(self):
            nn.Module.__init__(self)
            HyperparameterMixin.__init__(self)
            self.register_hyperparameter('learning_rate', 0.01)
            self.register_hyperparameter('sigma', 2.0)

    module = MyModule()

    # Access all hyperparameters
    for name, value in module.hyperparameters():
        print(f"{name}: {value}")

    # Get just hyperparameter state dict
    hp_state = module.hyperparameter_state_dict()

    # Load hyperparameters
    module.load_hyperparameter_state_dict(hp_state)
"""
from collections import OrderedDict
from typing import Dict, Iterator, Tuple
import torch
[docs]
class HyperparameterMixin:
    """
    Mixin class providing hyperparameter registration and tracking.

    Use this with ``nn.Module`` to add hyperparameter tracking capabilities.
    Every hyperparameter is stored as a float32 buffer named ``_hp_<name>``
    and its public name is recorded in the ``_hyperparameter_names`` set, so
    hyperparameters can be listed, saved and restored separately from
    parameters and ordinary buffers.

    Both initialization styles are supported: a single ``super().__init__()``
    when the mixin is listed before ``nn.Module``, or explicit
    ``nn.Module.__init__(self)`` followed by
    ``HyperparameterMixin.__init__(self)``.

    Examples
    --------
    ::

        class MyTarget(HyperparameterMixin, nn.Module):
            def __init__(self, sigma=1.0, target_value=0.0):
                super().__init__()
                self.register_hyperparameter('sigma', sigma)
                self.register_hyperparameter('target_value', target_value)

        target = MyTarget(sigma=2.5)
        dict(target.hyperparameters())
        # {'sigma': tensor(2.5), 'target_value': tensor(0.0)}
    """

    def __init__(self, *args, **kwargs):
        """
        Initialize hyperparameter tracking.

        The call is forwarded to the next class in the MRO (normally
        ``nn.Module``) so that ``super().__init__()`` in a subclass
        initializes the module machinery as well. The forward is skipped when
        the instance already has a ``_buffers`` entry, i.e. when
        ``nn.Module.__init__`` was already called explicitly; running it a
        second time would discard everything registered in between.

        Parameters
        ----------
        *args, **kwargs
            Passed through unchanged to the next ``__init__`` in the MRO.
        """
        if "_buffers" not in self.__dict__:
            super().__init__(*args, **kwargs)
        # Use object.__setattr__ to avoid triggering nn.Module's __setattr__.
        # An existing set is kept so earlier registrations are not lost.
        if not hasattr(self, "_hyperparameter_names"):
            object.__setattr__(self, "_hyperparameter_names", set())

    def register_hyperparameter(
        self, name: str, value: float, persistent: bool = True
    ) -> None:
        """
        Register a hyperparameter.

        Hyperparameters are stored as buffers (so they move with .to(device))
        but tracked separately for easy access and state_dict operations.

        Parameters
        ----------
        name : str
            Name of the hyperparameter. Will be prefixed with '_hp_' internally.
        value : float
            Initial value of the hyperparameter. Stored as a float32 tensor.
        persistent : bool, optional
            If True, hyperparameter will be part of module's state_dict.
            Default is True.

        Examples
        --------
        ::

            self.register_hyperparameter('sigma', 2.0)
            self.get_hyperparameter('sigma')
            # tensor(2.)
        """
        # Initialize tracking set if __init__ of the mixin was never run
        if not hasattr(self, "_hyperparameter_names"):
            object.__setattr__(self, "_hyperparameter_names", set())

        buffer_name = f"_hp_{name}"
        tensor_value = torch.tensor(value, dtype=torch.float32)
        self.register_buffer(buffer_name, tensor_value, persistent=persistent)

        # Record the name only after the buffer was accepted
        self._hyperparameter_names.add(name)

    def get_hyperparameter(self, name: str) -> torch.Tensor:
        """
        Get a hyperparameter value.

        Parameters
        ----------
        name : str
            Name of the hyperparameter (without the internal '_hp_' prefix).

        Returns
        -------
        torch.Tensor
            The hyperparameter tensor (the buffer itself, not a copy).
        """
        return getattr(self, f"_hp_{name}")

    def set_hyperparameter(self, name: str, value: float) -> None:
        """
        Set a hyperparameter value.

        The existing buffer is overwritten in place, so its device and dtype
        are preserved.

        Parameters
        ----------
        name : str
            Name of the hyperparameter (without the internal '_hp_' prefix).
        value : float
            New value.
        """
        getattr(self, f"_hp_{name}").fill_(value)

    def hyperparameters(
        self, recurse: bool = True
    ) -> Iterator[Tuple[str, torch.Tensor]]:
        """
        Return an iterator over module hyperparameters.

        This module's own hyperparameters come first, followed (when
        ``recurse`` is True) by those of each submodule in ``named_modules()``
        order, named ``<module path>.<hyperparameter>``. Within one module the
        names are yielded in sorted order: the names live in a ``set``, whose
        iteration order for strings is not stable between interpreter runs.

        Parameters
        ----------
        recurse : bool, optional
            If True, include hyperparameters from submodules. Default is True.

        Yields
        ------
        tuple
            (name, tensor) pairs for each hyperparameter.

        Examples
        --------
        ::

            for name, hp in module.hyperparameters():
                print(f"{name}: {hp.item()}")
        """
        for name in sorted(getattr(self, "_hyperparameter_names", ())):
            yield name, getattr(self, f"_hp_{name}")

        if not recurse:
            return

        for module_name, module in self.named_modules():
            if module is self:
                continue
            # Submodules without the mixin simply have no names to report
            for hp_name in sorted(getattr(module, "_hyperparameter_names", ())):
                full_name = f"{module_name}.{hp_name}" if module_name else hp_name
                yield full_name, getattr(module, f"_hp_{hp_name}")

    def named_hyperparameters(
        self, prefix: str = "", recurse: bool = True
    ) -> Iterator[Tuple[str, torch.Tensor]]:
        """
        Return an iterator over module hyperparameters with names.

        Same as hyperparameters() but allows prefix specification. The prefix
        is prepended verbatim; no separator is inserted.

        Parameters
        ----------
        prefix : str, optional
            Prefix to prepend to names. Default is ''.
        recurse : bool, optional
            If True, include hyperparameters from submodules. Default is True.

        Yields
        ------
        tuple
            (name, tensor) pairs for each hyperparameter.
        """
        for name, hp in self.hyperparameters(recurse=recurse):
            yield f"{prefix}{name}", hp

    def hyperparameter_state_dict(self, prefix: str = "") -> Dict[str, torch.Tensor]:
        """
        Return a dictionary containing only hyperparameters.

        Unlike state_dict() which includes all buffers and parameters,
        this only returns the hyperparameters. The tensors are clones, so
        modifying the result does not affect the module.

        Parameters
        ----------
        prefix : str, optional
            Prefix to prepend to names. Default is ''.

        Returns
        -------
        dict
            Ordered dictionary mapping hyperparameter names to tensors, in
            the order produced by hyperparameters().

        Examples
        --------
        ::

            hp_state = module.hyperparameter_state_dict()
            torch.save(hp_state, 'hyperparameters.pt')
        """
        result = OrderedDict()
        for name, hp in self.named_hyperparameters(prefix=prefix, recurse=True):
            result[name] = hp.clone()
        return result

    def load_hyperparameter_state_dict(
        self, state_dict: Dict[str, torch.Tensor], strict: bool = True
    ) -> None:
        """
        Load hyperparameters from a state dict.

        Keys are validated before anything is written, so a strict load that
        fails leaves the module untouched. With ``strict=False`` unknown keys
        are ignored and hyperparameters absent from ``state_dict`` keep their
        current values.

        Parameters
        ----------
        state_dict : dict
            Dictionary of hyperparameter name -> value. Values may be
            tensors (as returned by hyperparameter_state_dict()) or plain
            numbers (as returned by hyperparameter_dict()).
        strict : bool, optional
            If True, raise error on missing/unexpected keys. Default is True.

        Raises
        ------
        RuntimeError
            If ``strict`` is True and ``state_dict`` lacks a registered
            hyperparameter or contains an unknown one.

        Examples
        --------
        ::

            hp_state = torch.load('hyperparameters.pt')
            module.load_hyperparameter_state_dict(hp_state)
        """
        # All hyperparameter names (with module paths), in a stable order
        expected = [name for name, _ in self.named_hyperparameters(recurse=True)]
        expected_set = set(expected)

        unexpected_keys = [key for key in state_dict if key not in expected_set]
        missing_keys = [name for name in expected if name not in state_dict]

        if strict:
            if missing_keys:
                raise RuntimeError(f"Missing hyperparameters: {missing_keys}")
            if unexpected_keys:
                raise RuntimeError(f"Unexpected hyperparameters: {unexpected_keys}")

        for key, value in state_dict.items():
            if key not in expected_set:
                continue
            # "a.b.sigma" -> owner path "a.b", hyperparameter "sigma"
            module_path, _, hp_name = key.rpartition(".")
            owner = self
            if module_path:
                for part in module_path.split("."):
                    owner = getattr(owner, part)
            scalar = value.item() if isinstance(value, torch.Tensor) else float(value)
            owner.set_hyperparameter(hp_name, scalar)

    def hyperparameter_dict(self) -> Dict[str, float]:
        """
        Return hyperparameters as a simple Python dict of floats.

        Useful for logging, serialization to JSON, etc.

        Returns
        -------
        dict
            Dictionary mapping hyperparameter names to float values.

        Examples
        --------
        ::

            params = module.hyperparameter_dict()
            import json
            json.dumps(params)  # JSON serializable
        """
        return {name: hp.item() for name, hp in self.hyperparameters(recurse=True)}

    def print_hyperparameters(self, prefix: str = "") -> None:
        """
        Print all hyperparameters in a formatted way, sorted by name.

        Parameters
        ----------
        prefix : str, optional
            Prefix for indentation. Default is ''.
        """
        print(f"{prefix}Hyperparameters:")
        # Sort on the name only so tensors are never compared with each other
        entries = sorted(self.hyperparameters(recurse=True), key=lambda item: item[0])
        for name, hp in entries:
            print(f"{prefix} {name}: {hp.item():.6g}")
[docs]
def create_hyperparameter_property(name: str) -> property:
    """
    Build a read/write property backed by a registered hyperparameter.

    Reading the property calls ``get_hyperparameter(name)`` on the instance
    and returns the result of ``.item()`` on it; assigning to the property
    calls ``set_hyperparameter(name, value)``.

    Parameters
    ----------
    name : str
        Name of the hyperparameter the property forwards to.

    Returns
    -------
    property
        A property object for the hyperparameter, documented as
        ``"Hyperparameter: <name>"``.

    Examples
    --------
    ::

        class MyModule(HyperparameterMixin, nn.Module):
            sigma = create_hyperparameter_property('sigma')

            def __init__(self, sigma=1.0):
                nn.Module.__init__(self)
                HyperparameterMixin.__init__(self)
                self.register_hyperparameter('sigma', sigma)

        m = MyModule(sigma=2.5)
        m.sigma        # Uses the property
        # 2.5
        m.sigma = 3.0  # Sets via property
    """

    def _read(instance):
        stored = instance.get_hyperparameter(name)
        return stored.item()

    def _write(instance, new_value):
        instance.set_hyperparameter(name, new_value)

    return property(fget=_read, fset=_write, doc=f"Hyperparameter: {name}")