Source code for torchref.refinement.targets.adp.entropy

from typing import TYPE_CHECKING, Any, Dict, Optional

import torch

from torchref.utils.stats import (
    VERBOSITY_DEBUG,
    VERBOSITY_DETAILED,
    VERBOSITY_STANDARD,
    StatEntry,
    stat,
)

from .base import ADPTarget

if TYPE_CHECKING:
    from torchref.model.model import Model


class ADPEntropyTarget(ADPTarget):
    """ADP entropy (KL-divergence) regularization target.

    The loss itself is delegated to ``Model.adp_kl_divergence_loss()``.
    This class only wires that loss into the ``ADPTarget`` interface and
    provides diagnostic statistics over the model's ADPs.
    """

    # Identifier for this target (how it is used is defined outside this file).
    name: str = "adp/KL"

    def __init__(
        self,
        model: Optional["Model"] = None,
        verbose: int = 0,
        *,
        target_value: float = 0.5,
        sigma: float = 0.5,
    ):
        """Create the target.

        Args:
            model: Model whose ``adp_kl_divergence_loss()`` and ``adp()``
                methods are used. It may be ``None`` at construction time.
                Presumably it is attached later via the base class; verify
                against ``ADPTarget``.
            verbose: Verbosity level, forwarded to ``ADPTarget``.
            target_value: Forwarded to ``ADPTarget``. The default of 0.5
                matches the previously hard-coded value.
            sigma: Forwarded to ``ADPTarget``. The default of 0.5 matches
                the previously hard-coded value.
        """
        super().__init__(model, verbose, target_value=target_value, sigma=sigma)

    def forward(self) -> torch.Tensor:
        """Return the model's ADP KL-divergence loss.

        Returns:
            The tensor produced by ``self.model.adp_kl_divergence_loss()``.
            It is presumably a scalar, since ``stats()`` calls ``.item()``
            on it.
        """
        return self.model.adp_kl_divergence_loss()

    def stats(self) -> Dict[str, Any]:
        """Get KL divergence statistics.

        Returns:
            A dict mapping a statistic name to ``stat(value, verbosity)``.
            It holds the loss, the atom count, and the mean/std/min/max of
            the ADPs and of their logarithm.
        """
        adp = self.model.adp().detach()
        # Diagnostics only: avoid building an autograd graph just to read
        # the loss value.
        with torch.no_grad():
            loss = self.forward()
        # Clamp before the log so zero or negative ADPs cannot produce
        # -inf or NaN in the log statistics.
        log_adp = torch.log(adp.clamp(min=1e-3))
        return {
            "loss": stat(loss.item(), VERBOSITY_STANDARD),
            "n_atoms": stat(len(adp), VERBOSITY_DEBUG),
            "mean_adp": stat(adp.mean().item(), VERBOSITY_DETAILED),
            "std_adp": stat(adp.std().item(), VERBOSITY_DETAILED),
            "min_adp": stat(adp.min().item(), VERBOSITY_DETAILED),
            "max_adp": stat(adp.max().item(), VERBOSITY_DETAILED),
            "mean_log_adp": stat(log_adp.mean().item(), VERBOSITY_DEBUG),
            "std_log_adp": stat(log_adp.std().item(), VERBOSITY_DEBUG),
        }