torchref.utils.hyperparameters module
Hyperparameter Registration for PyTorch Modules.
This module provides a mechanism for tracking hyperparameters separately from parameters (trainable weights) and buffers (persistent state). This allows for:
- Easy access to all hyperparameters via module.hyperparameters()
- A separate state dict for hyperparameters via module.hyperparameter_state_dict()
- Clean separation between model weights, state, and configuration
Design pattern:
- Hyperparameters are registered as buffers internally, so they follow device moves.
- They are also tracked in a separate _hyperparameters set.
- This allows them to be filtered out or included as needed.
Usage:

    from torch import nn
    from torchref.utils.hyperparameters import HyperparameterMixin

    class MyModule(HyperparameterMixin, nn.Module):
        def __init__(self):
            super().__init__()
            self.register_hyperparameter('learning_rate', 0.01)
            self.register_hyperparameter('sigma', 2.0)

    module = MyModule()

    # Access all hyperparameters
    for name, value in module.hyperparameters():
        print(f"{name}: {value}")

    # Get just the hyperparameter state dict
    hp_state = module.hyperparameter_state_dict()

    # Load hyperparameters
    module.load_hyperparameter_state_dict(hp_state)
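To make the design pattern concrete, here is a minimal sketch of how such a mixin could work. It is not torchref's implementation: `MiniHyperparameterMixin` is a hypothetical name, it stores each value under the hyperparameter's own buffer name (the real class may use a different internal name), and it implements only `register_hyperparameter()` and `hyperparameters()`.

```python
import torch
from torch import nn


class MiniHyperparameterMixin:
    """Hypothetical, simplified stand-in for HyperparameterMixin; not torchref's code."""

    def register_hyperparameter(self, name, value, persistent=True):
        # Create the tracking set on first use; nn.Module stores a plain set
        # as an ordinary attribute (it is not a Parameter, Module or buffer).
        if "_hyperparameters" not in self.__dict__:
            self._hyperparameters = set()
        # 1) Store the value as a buffer so it follows .to(device) / .to(dtype).
        tensor = torch.as_tensor(value, dtype=torch.get_default_dtype())
        self.register_buffer(name, tensor, persistent=persistent)
        # 2) Record the name so hyperparameters can be told apart from other buffers.
        self._hyperparameters.add(name)

    def hyperparameters(self, recurse=True):
        # Yield (qualified_name, tensor) pairs for tracked buffers only.
        modules = self.named_modules() if recurse else [("", self)]
        for mod_name, mod in modules:
            for name in sorted(mod.__dict__.get("_hyperparameters", ())):
                qualified = f"{mod_name}.{name}" if mod_name else name
                yield qualified, getattr(mod, name)


class MyModule(MiniHyperparameterMixin, nn.Module):
    def __init__(self):
        super().__init__()  # resolves to nn.Module.__init__ (the mixin defines none)
        self.register_hyperparameter("learning_rate", 0.01)
        self.register_hyperparameter("sigma", 2.0)
        # An ordinary buffer: part of state_dict(), but not a hyperparameter.
        self.register_buffer("running_mean", torch.zeros(3))


module = MyModule()
print([name for name, _ in module.hyperparameters()])  # ['learning_rate', 'sigma']
print(sorted(module.state_dict()))  # ['learning_rate', 'running_mean', 'sigma']
```

The extra `running_mean` buffer shows the point of the separate set: `state_dict()` sees all three buffers, while `hyperparameters()` yields only the two tracked ones.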
- class torchref.utils.hyperparameters.HyperparameterMixin
Bases: object
Mixin class providing hyperparameter registration and tracking.
Use this with nn.Module to add hyperparameter tracking capabilities. Hyperparameters are stored as buffers but tracked separately.
Examples
    from torch import nn
    from torchref.utils.hyperparameters import HyperparameterMixin

    class MyTarget(HyperparameterMixin, nn.Module):
        def __init__(self, sigma=1.0, target_value=0.0):
            nn.Module.__init__(self)
            HyperparameterMixin.__init__(self)
            self.register_hyperparameter('sigma', sigma)
            self.register_hyperparameter('target_value', target_value)

    target = MyTarget(sigma=2.5)
    dict(target.hyperparameters())
    # {'sigma': tensor(2.5), 'target_value': tensor(0.0)}
- register_hyperparameter(name, value, persistent=True)
Register a hyperparameter.
Hyperparameters are stored as buffers (so they move with .to(device)) but tracked separately for easy access and state_dict operations.
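- Parameters:
name (str) – Name of the hyperparameter.
value – Initial value of the hyperparameter.
persistent (bool, optional) – Default is True.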
Examples
    self.register_hyperparameter('sigma', 2.0)
    self.sigma  # Access via property (if defined) or via _hp_sigma
    # 2.0
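The "moves with .to(device)" claim rests on ordinary `nn.Module.register_buffer()` behaviour, which the sketch below shows on a plain module (`BufferDemo` is a hypothetical name). A dtype cast stands in for a device move so it runs on CPU. Whether `register_hyperparameter()` forwards its `persistent` flag to `register_buffer()` is an assumption this page does not confirm; the second buffer only shows what `persistent=False` means for a buffer.

```python
import torch
from torch import nn


class BufferDemo(nn.Module):
    def __init__(self):
        super().__init__()
        # persistent=True (the default): saved in state_dict()
        self.register_buffer("sigma", torch.tensor(2.0), persistent=True)
        # persistent=False: still moves with .to(), but is left out of state_dict()
        self.register_buffer("scratch", torch.tensor(0.5), persistent=False)


demo = BufferDemo().to(torch.float64)  # a dtype cast stands in for a device move
print(demo.sigma.dtype, demo.scratch.dtype)  # torch.float64 torch.float64
print(list(demo.state_dict()))  # ['sigma']
```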
- get_hyperparameter(name)
Get a hyperparameter value.
- Parameters:
name (str) – Name of the hyperparameter.
- Returns:
The hyperparameter tensor.
- Return type:
torch.Tensor
- hyperparameters(recurse=True)
Return an iterator over module hyperparameters.
- Parameters:
recurse (bool, optional) – If True, include hyperparameters from submodules. Default is True.
- Yields:
tuple – (name, tensor) pairs for each hyperparameter.
Examples
    for name, hp in module.hyperparameters():
        print(f"{name}: {hp.item()}")
- named_hyperparameters(prefix='', recurse=True)
Return an iterator over module hyperparameters as (name, tensor) pairs.
Same as hyperparameters(), but also accepts a prefix that is prepended to each name.
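One plausible way to combine `prefix` and `recurse` is to walk `nn.Module.named_modules()`, which already joins module names with dots. The sketch assumes each module keeps its hyperparameter names in a `_hyperparameters` set, as the module overview describes; `named_hyperparameters_sketch` and `make_module` are hypothetical helpers, not torchref functions.

```python
import torch
from torch import nn


def named_hyperparameters_sketch(module, prefix="", recurse=True):
    """Hypothetical helper: yield (qualified_name, tensor) for tracked hyperparameters.

    Assumes every module lists its hyperparameter buffer names in a
    `_hyperparameters` set.
    """
    # named_modules() already joins `prefix` and submodule names with dots.
    modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)]
    for mod_prefix, mod in modules:
        for name in sorted(getattr(mod, "_hyperparameters", ())):
            yield (f"{mod_prefix}.{name}" if mod_prefix else name), getattr(mod, name)


def make_module(**hps):
    # Hypothetical fixture: a bare module whose buffers are all hyperparameters.
    mod = nn.Module()
    mod._hyperparameters = set(hps)
    for name, value in hps.items():
        mod.register_buffer(name, torch.tensor(float(value)))
    return mod


root = make_module(sigma=1.0)
root.target = make_module(weight=0.5)  # registered as a submodule

print([n for n, _ in named_hyperparameters_sketch(root)])
# ['sigma', 'target.weight']
print([n for n, _ in named_hyperparameters_sketch(root, prefix="model")])
# ['model.sigma', 'model.target.weight']
print([n for n, _ in named_hyperparameters_sketch(root, recurse=False)])
# ['sigma']
```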
- hyperparameter_state_dict(prefix='')
Return a dictionary containing only hyperparameters.
Unlike state_dict(), which includes all parameters and buffers, this returns only the hyperparameters.
- Parameters:
prefix (str, optional) – Prefix to prepend to names. Default is ''.
- Returns:
Dictionary mapping hyperparameter names to tensors.
- Return type:
dict[str, torch.Tensor]
Examples
    import torch

    hp_state = module.hyperparameter_state_dict()
    torch.save(hp_state, 'hyperparameters.pt')
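Conceptually, a hyperparameter-only state dict is `state_dict()` filtered down to the tracked names. The hypothetical `hyperparameter_state_dict_sketch` below passes those names in explicitly (standing in for the mixin's internal set, an assumption) and saves to an in-memory buffer instead of a file.

```python
import io

import torch
from torch import nn


def hyperparameter_state_dict_sketch(module, hp_names, prefix=""):
    """Hypothetical helper: keep only the hyperparameter entries of state_dict()."""
    # state_dict() returns parameters and persistent buffers; drop everything else.
    return {prefix + key: value for key, value in module.state_dict().items() if key in hp_names}


model = nn.Linear(2, 1)  # contributes 'weight' and 'bias' parameters
model.register_buffer("sigma", torch.tensor(2.0))  # plays the role of a hyperparameter

hp_state = hyperparameter_state_dict_sketch(model, {"sigma"})
print(list(model.state_dict()))  # ['weight', 'bias', 'sigma']
print(list(hp_state))  # ['sigma']

# Round trip through an in-memory buffer instead of 'hyperparameters.pt'.
buffer = io.BytesIO()
torch.save(hp_state, buffer)
buffer.seek(0)
restored = torch.load(buffer)
print(restored["sigma"].item())  # 2.0
```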
- load_hyperparameter_state_dict(state_dict, strict=True)
Load hyperparameters from a state dict.
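- Parameters:
state_dict (dict) – Mapping of hyperparameter names to tensors, such as the output of hyperparameter_state_dict().
strict (bool, optional) – Default is True.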
Examples
    import torch

    hp_state = torch.load('hyperparameters.pt')
    module.load_hyperparameter_state_dict(hp_state)
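This page does not say what `strict` does. The sketch below assumes it mirrors `nn.Module.load_state_dict()`, where strict loading rejects missing or unexpected keys. `load_hyperparameters_sketch` is a hypothetical helper that copies values into the existing buffers in place rather than replacing them.

```python
import torch
from torch import nn


def load_hyperparameters_sketch(module, state_dict, hp_names, strict=True):
    """Hypothetical helper: copy saved hyperparameter values into existing buffers.

    Assumes strict=True behaves like nn.Module.load_state_dict(strict=True):
    missing or unexpected keys raise an error instead of being skipped.
    """
    missing = sorted(set(hp_names) - set(state_dict))
    unexpected = sorted(set(state_dict) - set(hp_names))
    if strict and (missing or unexpected):
        raise KeyError(f"missing={missing}, unexpected={unexpected}")
    with torch.no_grad():
        for name in sorted(set(hp_names) & set(state_dict)):
            # In-place copy keeps the existing buffer's device, dtype and registration.
            getattr(module, name).copy_(state_dict[name])
    return missing, unexpected


module = nn.Module()
module.register_buffer("sigma", torch.tensor(1.0))

load_hyperparameters_sketch(module, {"sigma": torch.tensor(3.0)}, {"sigma"})
print(module.sigma.item())  # 3.0

# With strict=False, mismatched keys are reported and skipped.
print(load_hyperparameters_sketch(module, {"extra": torch.tensor(0.0)}, {"sigma"}, strict=False))
# (['sigma'], ['extra'])
```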
- hyperparameter_dict()
Return hyperparameters as a simple Python dict of floats.
Useful for logging, serialization to JSON, etc.
- Returns:
Dictionary mapping hyperparameter names to float values.
- Return type:
dict[str, float]
Examples
    import json

    params = module.hyperparameter_dict()
    json.dumps(params)  # JSON serializable
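The float conversion matters because `json` cannot encode tensors. A minimal sketch of that conversion, using a hypothetical `to_float_dict` helper over `(name, tensor)` pairs like those `hyperparameters()` yields:

```python
import json

import torch


def to_float_dict(named_tensors):
    """Hypothetical helper: turn (name, scalar tensor) pairs into plain floats."""
    # .item() converts a one-element tensor into a Python number, which json can encode.
    return {name: float(tensor.item()) for name, tensor in named_tensors}


pairs = [("sigma", torch.tensor(2.5)), ("target_value", torch.tensor(0.0))]
params = to_float_dict(pairs)
print(json.dumps(params))  # {"sigma": 2.5, "target_value": 0.0}

# Passing the tensors directly fails, which is why the conversion is needed.
try:
    json.dumps(dict(pairs))
except TypeError as err:
    print("TypeError:", err)
```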
- torchref.utils.hyperparameters.create_hyperparameter_property(name)
Create a property for convenient hyperparameter access.
Use this to create getter/setter properties that access the underlying registered hyperparameter.
- Parameters:
name (str) – Name of the hyperparameter.
- Returns:
A property object for the hyperparameter.
- Return type:
property
Examples
    from torch import nn
    from torchref.utils.hyperparameters import (
        HyperparameterMixin,
        create_hyperparameter_property,
    )

    class MyModule(HyperparameterMixin, nn.Module):
        sigma = create_hyperparameter_property('sigma')

        def __init__(self, sigma=1.0):
            super().__init__()
            self.register_hyperparameter('sigma', sigma)

    m = MyModule(sigma=2.5)
    m.sigma  # Uses the property
    # 2.5
    m.sigma = 3.0  # Sets via property
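A property factory like this can be sketched with Python's built-in `property`. The helper `make_hp_property` is hypothetical, and the `_hp_<name>` buffer naming is an assumption taken from the `_hp_sigma` mention in the register_hyperparameter() example. Keeping the buffer name distinct from the property name matters because a class-level property takes precedence over `nn.Module`'s buffer lookup, so a buffer with the same name could not be reached through ordinary attribute access.

```python
import torch
from torch import nn


def make_hp_property(name):
    """Hypothetical stand-in for create_hyperparameter_property; not torchref's code."""
    # Assumed naming scheme: the buffer is stored as `_hp_<name>`.
    buffer_name = f"_hp_{name}"

    def getter(self):
        return getattr(self, buffer_name)

    def setter(self, value):
        # Update in place so the buffer keeps its device, dtype and registration.
        with torch.no_grad():
            getattr(self, buffer_name).fill_(float(value))

    return property(getter, setter, doc=f"Hyperparameter {name!r}.")


class MyModule(nn.Module):
    sigma = make_hp_property("sigma")

    def __init__(self, sigma=1.0):
        super().__init__()
        self.register_buffer("_hp_sigma", torch.tensor(float(sigma)))


m = MyModule(sigma=2.5)
print(m.sigma.item())  # 2.5
m.sigma = 3.0  # routed to the property setter, not stored as a plain attribute
print(m.sigma.item())  # 3.0
print(list(m.state_dict()))  # ['_hp_sigma']
```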