torchref.utils.hyperparameters module

Hyperparameter Registration for PyTorch Modules.

This module provides a mechanism for tracking hyperparameters separately from parameters (trainable weights) and buffers (persistent state). This allows for:

  1. Easy access to all hyperparameters via module.hyperparameters()

  2. Separate state_dict for hyperparameters via module.hyperparameter_state_dict()

  3. Clean separation between model weights, state, and configuration

Design pattern:

  • Hyperparameters are registered internally as buffers, so they are tracked across devices.

  • They are also tracked in a separate _hyperparameters set.

  • This lets them be filtered out or included as needed.
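The pattern above can be illustrated with standard PyTorch alone (`nn.Module.register_buffer`). This is a minimal sketch, not torchref's implementation. `SketchHyperparameterMixin` and `Target` are hypothetical names, and the `hyperparameters()` method here is deliberately non-recursive to keep it short.

```python
import torch
import torch.nn as nn


class SketchHyperparameterMixin:
    """Hypothetical minimal sketch of the buffer-plus-name-set pattern."""

    def __init__(self):
        # Names of hyperparameters registered directly on this module.
        self._hyperparameters = set()

    def register_hyperparameter(self, name, value, persistent=True):
        # Stored as a buffer, so .to(...) moves it and state_dict() can include it.
        self.register_buffer(f"_hp_{name}", torch.tensor(float(value)), persistent=persistent)
        # Tracked separately, so hyperparameters can be listed on their own.
        self._hyperparameters.add(name)

    def hyperparameters(self):
        # Local (non-recursive) iteration only, to keep the sketch short.
        for name in sorted(self._hyperparameters):
            yield name, getattr(self, f"_hp_{name}")


class Target(SketchHyperparameterMixin, nn.Module):
    def __init__(self, sigma=1.0):
        nn.Module.__init__(self)                  # sets up _buffers; needed before register_buffer
        SketchHyperparameterMixin.__init__(self)  # then the separate name set
        self.register_hyperparameter("sigma", sigma)
        self.weight = nn.Parameter(torch.zeros(3))  # an ordinary trainable weight


t = Target(sigma=2.5)
hp_names = [name for name, _ in t.hyperparameters()]
param_names = [name for name, _ in t.named_parameters()]
buffer_names = [name for name, _ in t.named_buffers()]
```

The hyperparameter shows up as a buffer named `_hp_sigma`. The name set is what allows it to be reported separately from the trainable `weight`.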

Usage:

import torch.nn as nn
from torchref.utils.hyperparameters import HyperparameterMixin

class MyModule(HyperparameterMixin, nn.Module):
    def __init__(self):
        super().__init__()
        self.register_hyperparameter('learning_rate', 0.01)
        self.register_hyperparameter('sigma', 2.0)

module = MyModule()

# Access all hyperparameters
for name, value in module.hyperparameters():
    print(f"{name}: {value}")

# Get only the hyperparameter state dict
hp_state = module.hyperparameter_state_dict()

# Load hyperparameters
module.load_hyperparameter_state_dict(hp_state)

class torchref.utils.hyperparameters.HyperparameterMixin

Bases: object

Mixin class providing hyperparameter registration and tracking.

Use this with nn.Module to add hyperparameter tracking capabilities. Hyperparameters are stored as buffers but tracked separately.

Examples

class MyTarget(HyperparameterMixin, nn.Module):
    def __init__(self, sigma=1.0, target_value=0.0):
        nn.Module.__init__(self)
        HyperparameterMixin.__init__(self)
        self.register_hyperparameter('sigma', sigma)
        self.register_hyperparameter('target_value', target_value)

target = MyTarget(sigma=2.5)
dict(target.hyperparameters())
# {'sigma': tensor(2.5), 'target_value': tensor(0.0)}

__init__()

Initialize hyperparameter tracking.

register_hyperparameter(name, value, persistent=True)

Register a hyperparameter.

Hyperparameters are stored as buffers (so they move with .to(device)) but tracked separately for easy access and state_dict operations.

Parameters:
  • name (str) – Name of the hyperparameter. It is stored internally with a '_hp_' prefix (for example, '_hp_sigma').

  • value (float) – Initial value of the hyperparameter.

  • persistent (bool, optional) – If True, the hyperparameter is included in the module's state_dict. Default is True.
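The "stored as buffers" and persistent behavior rests on PyTorch's own `nn.Module.register_buffer(name, tensor, persistent=...)`. The sketch below calls that directly on a bare `nn.Module`. It uses the `_hp_` prefix described above, and the buffer name `_hp_scratch` is hypothetical. A dtype conversion stands in for a device move so the example runs without a GPU.

```python
import torch
import torch.nn as nn

m = nn.Module()
# Persistent buffer: follows .to() and appears in state_dict().
m.register_buffer("_hp_sigma", torch.tensor(2.0), persistent=True)
# Non-persistent buffer (hypothetical name): follows .to() but is left out of state_dict().
m.register_buffer("_hp_scratch", torch.tensor(0.5), persistent=False)

keys = sorted(m.state_dict().keys())

# Module.to() converts floating-point buffers as well as parameters.
m = m.to(torch.float64)
```

Both buffers end up as float64, but only the persistent one is saved.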

Examples

self.register_hyperparameter('sigma', 2.0)
self.sigma  # Access via a property (if defined); the underlying buffer is self._hp_sigma
# 2.0

get_hyperparameter(name)

Get a hyperparameter value.

Parameters:

name (str) – Name of the hyperparameter.

Returns:

The hyperparameter tensor.

Return type:

torch.Tensor

set_hyperparameter(name, value)

Set a hyperparameter value.

Parameters:
  • name (str) – Name of the hyperparameter.

  • value (float) – New value.
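The documentation above does not say how the new value is written. One plausible approach writes into the existing buffer in place with `Tensor.fill_`, which preserves the dtype and device the buffer picked up from earlier `.to()` calls. This is a hypothetical free-function sketch. It is an assumption that torchref updates in place rather than replacing the buffer.

```python
import torch
import torch.nn as nn


def set_hyperparameter(module, name, value):
    """Hypothetical sketch: overwrite a registered hyperparameter buffer in place."""
    buf = getattr(module, f"_hp_{name}")  # assumes the '_hp_' naming convention
    # fill_ keeps the buffer's existing dtype and device; no re-registration needed.
    buf.fill_(float(value))


m = nn.Module()
m.register_buffer("_hp_sigma", torch.tensor(2.0, dtype=torch.float64))
before = m._hp_sigma
set_hyperparameter(m, "sigma", 3)
```

With this approach, the buffer is the same tensor object before and after the call. Any optimizer or other code that holds a reference to it therefore sees the new value.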

hyperparameters(recurse=True)

Return an iterator over module hyperparameters.

Parameters:

recurse (bool, optional) – If True, include hyperparameters from submodules. Default is True.

Yields:

tuple – (name, tensor) pairs for each hyperparameter.

Examples

for name, hp in module.hyperparameters():
    print(f"{name}: {hp.item()}")

named_hyperparameters(prefix='', recurse=True)

Return an iterator over module hyperparameters with names.

Same as hyperparameters(), but accepts a prefix that is prepended to every name.

Parameters:
  • prefix (str, optional) – Prefix to prepend to names. Default is '' (empty string).

  • recurse (bool, optional) – If True, include hyperparameters from submodules. Default is True.

Yields:

tuple – (name, tensor) pairs for each hyperparameter.
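When recurse is True, one natural way to produce dotted names is to walk `nn.Module.named_modules(prefix=...)`. That method yields `(name, submodule)` pairs starting with the module itself. The function below is a hypothetical sketch. It assumes each module keeps its names in a `_hyperparameters` set and its values in `_hp_<name>` buffers, as described in the design notes.

```python
import torch
import torch.nn as nn


def named_hyperparameters(module, prefix="", recurse=True):
    """Hypothetical sketch: yield dotted (name, tensor) pairs for hyperparameters."""
    modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)]
    for mod_prefix, mod in modules:
        # Assumes each module stores its hyperparameter names in `_hyperparameters`.
        for name in sorted(getattr(mod, "_hyperparameters", ())):
            full_name = f"{mod_prefix}.{name}" if mod_prefix else name
            yield full_name, getattr(mod, f"_hp_{name}")


def make_leaf(sigma):
    leaf = nn.Module()
    leaf.register_buffer("_hp_sigma", torch.tensor(sigma))
    leaf._hyperparameters = {"sigma"}
    return leaf


root = nn.Module()
root.register_buffer("_hp_scale", torch.tensor(1.0))
root._hyperparameters = {"scale"}
root.child = make_leaf(2.0)  # registered as a submodule named 'child'

plain = [n for n, _ in named_hyperparameters(root)]
prefixed = [n for n, _ in named_hyperparameters(root, prefix="model")]
local_only = [n for n, _ in named_hyperparameters(root, recurse=False)]
```

Here `plain` is `['scale', 'child.sigma']`, and `prefixed` adds `model.` in front of each name.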

hyperparameter_state_dict(prefix='')

Return a dictionary containing only hyperparameters.

Unlike state_dict(), which includes all parameters and buffers, this method returns only the hyperparameters.

Parameters:

prefix (str, optional) – Prefix to prepend to names. Default is '' (empty string).

Returns:

Dictionary mapping hyperparameter names to tensors.

Return type:

dict

Examples

hp_state = module.hyperparameter_state_dict()
torch.save(hp_state, 'hyperparameters.pt')
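To contrast with the full `state_dict()`, the following hypothetical sketch keeps only the tracked hyperparameters. It assumes the same `_hyperparameters` set and `_hp_<name>` buffers as above. It also assumes the keys are the plain hyperparameter names rather than the `_hp_`-prefixed buffer names; the documentation does not say which. An in-memory `io.BytesIO` replaces the file path so the round trip runs anywhere.

```python
import io

import torch
import torch.nn as nn


def hyperparameter_state_dict(module, prefix=""):
    """Hypothetical sketch: collect only tracked hyperparameters, keyed by name."""
    out = {}
    for mod_name, mod in module.named_modules(prefix=prefix):
        for name in sorted(getattr(mod, "_hyperparameters", ())):
            key = f"{mod_name}.{name}" if mod_name else name
            # detach().clone() so the saved values do not alias the live buffers.
            out[key] = getattr(mod, f"_hp_{name}").detach().clone()
    return out


model = nn.Linear(2, 2)
model.register_buffer("_hp_sigma", torch.tensor(2.5))
model._hyperparameters = {"sigma"}

full_keys = sorted(model.state_dict())   # weights, bias and the buffer
hp_state = hyperparameter_state_dict(model)

# Round trip through torch.save / torch.load using an in-memory buffer.
buf = io.BytesIO()
torch.save(hp_state, buf)
buf.seek(0)
restored = torch.load(buf)
```

The full state dict contains `weight`, `bias` and `_hp_sigma`, while the hyperparameter dict contains only `sigma`.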

load_hyperparameter_state_dict(state_dict, strict=True)

Load hyperparameters from a state dict.

Parameters:
  • state_dict (dict) – Dictionary of hyperparameter name -> tensor.

  • strict (bool, optional) – If True, raise an error on missing or unexpected keys. Default is True.

Examples

hp_state = torch.load('hyperparameters.pt')
module.load_hyperparameter_state_dict(hp_state)
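The strict flag can be read as a key comparison before any values are copied. The function below is a hypothetical sketch, not torchref's code. Raising `KeyError` and returning the `(missing, unexpected)` lists are assumptions, because the documentation names neither the error type nor a return value.

```python
import torch
import torch.nn as nn


def load_hyperparameter_state_dict(module, state_dict, strict=True):
    """Hypothetical sketch: copy values into existing buffers, checking keys if strict."""
    expected = set(getattr(module, "_hyperparameters", ()))
    provided = set(state_dict)
    missing = sorted(expected - provided)
    unexpected = sorted(provided - expected)
    if strict and (missing or unexpected):
        # The real method's exception type is not documented; KeyError is an assumption.
        raise KeyError(f"missing keys: {missing}, unexpected keys: {unexpected}")
    for name in expected & provided:
        # copy_ writes into the existing buffer, preserving its dtype and device.
        getattr(module, f"_hp_{name}").copy_(torch.as_tensor(state_dict[name]))
    return missing, unexpected


m = nn.Module()
m.register_buffer("_hp_sigma", torch.tensor(1.0))
m._hyperparameters = {"sigma"}

load_hyperparameter_state_dict(m, {"sigma": torch.tensor(4.0)})

rejected = False
try:
    load_hyperparameter_state_dict(m, {"sigma": 5.0, "typo": 1.0})  # strict: unknown key
except KeyError:
    rejected = True

# Non-strict: known keys are loaded and the unknown one is only reported.
report = load_hyperparameter_state_dict(m, {"sigma": 5.0, "typo": 1.0}, strict=False)
```

Because the check runs before copying, a rejected strict load leaves every buffer unchanged.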

hyperparameter_dict()

Return hyperparameters as a simple Python dict of floats.

Useful for logging, serialization to JSON, etc.

Returns:

Dictionary mapping hyperparameter names to float values.

Return type:

dict

Examples

import json
params = module.hyperparameter_dict()
json.dumps(params)  # JSON serializable
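The float conversion matters because `json` cannot serialize tensors, but it can serialize the Python numbers returned by `Tensor.item()`. The function below is a hypothetical sketch that takes `(name, tensor)` pairs, such as those yielded by hyperparameters(), and returns plain floats.

```python
import json

import torch


def hyperparameter_dict(named_hps):
    """Hypothetical sketch: convert (name, 0-dim tensor) pairs into plain floats."""
    # .item() returns a Python number; float() normalizes it for JSON output.
    return {name: float(t.item()) for name, t in named_hps}


hps = [("sigma", torch.tensor(2.5)), ("target_value", torch.tensor(0.0))]
params = hyperparameter_dict(hps)
encoded = json.dumps(params)

# Passing the tensors directly would fail: json raises TypeError for torch.Tensor.
try:
    json.dumps(dict(hps))
    tensor_ok = True
except TypeError:
    tensor_ok = False
```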

print_hyperparameters(prefix='')

Print all hyperparameters in a formatted way.

Parameters:

prefix (str, optional) – Prefix for indentation. Default is '' (empty string).

torchref.utils.hyperparameters.create_hyperparameter_property(name)

Create a property for convenient hyperparameter access.

Use this to create getter/setter properties that access the underlying registered hyperparameter.

Parameters:

name (str) – Name of the hyperparameter.

Returns:

A property object for the hyperparameter.

Return type:

property

Examples

class MyModule(HyperparameterMixin, nn.Module):
    sigma = create_hyperparameter_property('sigma')

    def __init__(self, sigma=1.0):
        super().__init__()
        self.register_hyperparameter('sigma', sigma)

m = MyModule(sigma=2.5)
m.sigma  # Uses the property
# 2.5
m.sigma = 3.0  # Sets via property
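A property factory like this can be built from Python's built-in `property(fget, fset, doc)`. The code below is a hypothetical sketch, not torchref's implementation. It assumes the value lives in a `_hp_<name>` buffer and that the getter returns a Python float, as the `# 2.5` output above suggests. The class `Smoother` is also hypothetical. `nn.Module.__setattr__` passes plain floats through to ordinary attribute assignment, so `m.sigma = 3.0` reaches the property setter instead of creating a new instance attribute.

```python
import torch
import torch.nn as nn


def create_hyperparameter_property(name):
    """Hypothetical sketch of a property factory over a '_hp_<name>' buffer."""
    attr = f"_hp_{name}"

    def getter(self):
        # Returns a Python float (assumption based on the documented example output).
        return getattr(self, attr).item()

    def setter(self, value):
        # Writes into the existing buffer so its dtype and device are kept.
        getattr(self, attr).fill_(float(value))

    return property(getter, setter, doc=f"Hyperparameter '{name}'.")


class Smoother(nn.Module):
    sigma = create_hyperparameter_property("sigma")

    def __init__(self, sigma=1.0):
        super().__init__()
        self.register_buffer("_hp_sigma", torch.tensor(float(sigma)))


m = Smoother(sigma=2.5)
first = m.sigma
m.sigma = 3.0  # routed through the property setter
```

Because the value stays in a registered buffer, it still appears in `state_dict()` and follows `.to()`. The property adds only a more convenient access path.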