# Source code for torchref.refinement.logger

"""
Refinement Logger - Separate logging concerns from refinement logic.

Provides verbosity-aware statistics recording and comparison for refinement workflows.
Integrates with LossState to capture and display refinement progress.
"""

import re
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Pattern

import torch

from torchref.refinement.loss_state import LossState
from torchref.utils.stats import (
    VERBOSITY_STANDARD,
    filter_stats,
)


@dataclass
class Logger:
    """
    Refinement logging with verbosity-aware stat reporting.

    Integrates with LossState to record, compare, and display refinement
    stats. Supports regex-based filtering of target names.

    Parameters
    ----------
    state : LossState
        The loss state to monitor.
    verbose : int
        Verbosity level (0=essential, 1=standard, 2=detailed, 3=debug).
        Passed to ``filter_stats`` when stats are displayed; it is not
        applied when stats are recorded.
    pattern : str
        Regex pattern for filtering target names. Default ".*" matches all.
        Matching uses ``re.search``, so the pattern may match anywhere in the
        name. Examples: "xray.*" for X-ray targets only, "geometry/bond" for
        a specific target.

    Examples
    --------
    Basic usage::

        from torchref.refinement import LossState, Logger

        state = LossState(device=device)
        state.register_target("xray/work", xray_target)
        state.register_target("geometry/bond", bond_target)

        logger = Logger(state, verbose=1)

        # Record before refinement
        logger.record(label="before_xyz")

        # ... run refinement ...

        # Record after and compare
        logger.record(label="after_xyz")
        logger.compare(title="XYZ Refinement")

    Filtering by pattern::

        # Show only X-ray targets
        logger.current(pattern="xray.*")

        # Create a logger that only tracks geometry
        geom_logger = Logger(state, pattern="geometry.*")
    """

    state: LossState
    verbose: int = VERBOSITY_STANDARD
    pattern: str = ".*"
    # Chronological list of {"stats": ..., "label": ...} entries.
    _records: List[Dict[str, Any]] = field(default_factory=list)
    # Maps a record label to its index in ``_records``.
    _labels: Dict[str, int] = field(default_factory=dict)
    # Compiled form of ``pattern``; refreshed lazily in ``_matches_pattern``.
    _compiled_pattern: Pattern = field(init=False, repr=False)

    def __post_init__(self):
        """Compile the instance pattern once construction is complete."""
        self._compiled_pattern = re.compile(self.pattern)

    def record(self, label: Optional[str] = None) -> Dict[str, Any]:
        """
        Record current refinement state.

        Gathers stats from all targets in LossState, stores in history.
        Uses the instance's pattern filter.

        Parameters
        ----------
        label : str, optional
            Label for this record (e.g., "before_xyz", "after_xyz"). Reusing
            a label makes it point at the newest record.

        Returns
        -------
        dict
            The recorded stats dictionary. Besides one entry per matching
            target that reports stats, it holds the keys "losses",
            "weights" and "group_totals".
        """
        stats: Dict[str, Any] = {}

        # Logging must not build autograd graphs, so every read from the
        # loss state happens under no_grad.
        with torch.no_grad():
            # Aggregate to ensure losses are computed
            self.state.aggregate()

            # Collect stats from each target (filtered by pattern)
            for name, target in self.state.targets.items():
                if not self._matches_pattern(name):
                    continue
                if hasattr(target, "stats"):
                    target_stats = target.stats()
                    if target_stats:
                        stats[name] = target_stats

            # Add loss values (filtered by pattern)
            stats["losses"] = {
                name: loss.item()
                for name, loss in self.state._losses.items()
                if self._matches_pattern(name)
            }

            # Add weights (filtered by pattern)
            stats["weights"] = {
                name: weight
                for name, weight in self.state.weights.items()
                if self._matches_pattern(name)
            }

            # Add group totals
            stats["group_totals"] = self.state.get_group_totals()

        # Store with optional label
        self._records.append({"stats": stats, "label": label})
        if label:
            self._labels[label] = len(self._records) - 1

        return stats

    def compare(
        self,
        label_before: Optional[str] = None,
        label_after: Optional[str] = None,
        pattern: Optional[str] = None,
        title: str = "Refinement Comparison",
    ) -> None:
        """
        Print comparison between two recorded states.

        If labels not provided, compares last two records. A label that is
        given but unknown is treated the same way as a missing label: the
        positional default is used instead of raising.

        Parameters
        ----------
        label_before : str, optional
            Label of "before" state. Default: second-to-last record.
        label_after : str, optional
            Label of "after" state. Default: last record.
        pattern : str, optional
            Regex to filter which targets to display. Default: use instance
            pattern.
        title : str, optional
            Title for the comparison output. Default: "Refinement Comparison".
        """
        before = self._resolve_stats(label_before, -2)
        after = self._resolve_stats(label_after, -1)

        # Filter by verbosity
        before_filtered = filter_stats(before, self.verbose)
        after_filtered = filter_stats(after, self.verbose)

        # Apply regex pattern filter
        before_filtered = self._filter_by_pattern(before_filtered, pattern)
        after_filtered = self._filter_by_pattern(after_filtered, pattern)

        # Print formatted comparison table
        self._print_comparison(before_filtered, after_filtered, title)

    def current(
        self, pattern: Optional[str] = None, title: str = "Current State"
    ) -> None:
        """
        Print the current refinement state.

        Uses latest recorded state, or records new one if none exist.

        Parameters
        ----------
        pattern : str, optional
            Regex to filter which targets to display. Default: use instance
            pattern.
        title : str, optional
            Title for the output. Default: "Current State".
        """
        if not self._records:
            self.record()

        latest = self._records[-1]["stats"]
        shown = filter_stats(latest, self.verbose)
        shown = self._filter_by_pattern(shown, pattern)
        self._print_current(shown, title)

    def get_record(self, label: str) -> Optional[Dict[str, Any]]:
        """
        Get a specific recorded state by label.

        Parameters
        ----------
        label : str
            The label to look up.

        Returns
        -------
        dict or None
            The recorded stats dictionary, or None if label not found.
        """
        index = self._labels.get(label)
        if index is None:
            return None
        return self._records[index]["stats"]

    def clear(self) -> None:
        """Clear all recorded history."""
        self._records.clear()
        self._labels.clear()

    @property
    def history(self) -> List[Dict[str, Any]]:
        """Access full recording history.

        Returns the internal list itself (not a copy); each entry is a dict
        with the keys "stats" and "label".
        """
        return self._records

    # =========================================================================
    # Private Methods
    # =========================================================================

    def _resolve_stats(
        self, label: Optional[str], fallback_index: int
    ) -> Dict[str, Any]:
        """
        Look up recorded stats by label, falling back to a position.

        Parameters
        ----------
        label : str or None
            Label to look up. Ignored when falsy or unknown.
        fallback_index : int
            Negative index into the history used when the label does not
            resolve (-1 for the last record, -2 for the one before it).

        Returns
        -------
        dict
            The stats dictionary, or an empty dict when the history is too
            short for ``fallback_index``.
        """
        if label and label in self._labels:
            return self._records[self._labels[label]]["stats"]
        if len(self._records) >= -fallback_index:
            return self._records[fallback_index]["stats"]
        return {}

    def _matches_pattern(self, name: str, pattern: Optional[str] = None) -> bool:
        """
        Check if target name matches the filter pattern.

        Parameters
        ----------
        name : str
            The name to check.
        pattern : str, optional
            Pattern to use. If None, uses instance's compiled pattern.

        Returns
        -------
        bool
            True if name matches the pattern.
        """
        if pattern is not None:
            return re.search(pattern, name) is not None

        # The dataclass is mutable: if ``pattern`` was reassigned after
        # construction, refresh the compiled form so the new value is used.
        if self._compiled_pattern.pattern != self.pattern:
            self._compiled_pattern = re.compile(self.pattern)
        return self._compiled_pattern.search(name) is not None

    def _filter_by_pattern(
        self, stats: Dict, pattern: Optional[str] = None
    ) -> Dict:
        """
        Filter stats dict by regex pattern on keys.

        A key that matches is kept together with its whole value. A key that
        does not match is kept only if its value is a dict that still has
        entries after being filtered recursively.

        Parameters
        ----------
        stats : dict
            Stats dictionary to filter.
        pattern : str, optional
            Pattern to use. If None, uses instance pattern.

        Returns
        -------
        dict
            Filtered dictionary.
        """
        filtered = {}
        for key, val in stats.items():
            if self._matches_pattern(key, pattern):
                filtered[key] = val
            elif isinstance(val, dict):
                # Recurse into nested dicts
                nested = self._filter_by_pattern(val, pattern)
                if nested:
                    filtered[key] = nested
        return filtered

    def _group_by_hierarchy(self, keys_and_values: Dict) -> Dict[str, Dict]:
        """
        Group flat keys by hierarchical prefix.

        Keys are split at the first "/"; keys without a "/" are placed in
        the group "other".

        Parameters
        ----------
        keys_and_values : dict
            Flat dictionary with potentially hierarchical keys.

        Returns
        -------
        dict
            Grouped dictionary: {"xray": {"work": 3.2, "test": 3.4}, ...}

        Examples
        --------
        Input: {"xray/work": 3.2, "xray/test": 3.4, "geometry/bond": 0.02}
        Output: {"xray": {"work": 3.2, "test": 3.4}, "geometry": {"bond": 0.02}}
        """
        groups: Dict[str, Dict] = {}
        for key, val in keys_and_values.items():
            group, sep, component = key.partition("/")
            if not sep:
                group, component = "other", key
            groups.setdefault(group, {})[component] = val
        return groups

    def _print_comparison(self, before: Dict, after: Dict, title: str) -> None:
        """
        Format and print before/after comparison table.

        Parameters
        ----------
        before : dict
            Stats from before state.
        after : dict
            Stats from after state.
        title : str
            Title for the comparison.
        """
        width = 68
        separator = "─" * width

        print(f"\n{separator}")
        print(f" {title}")
        print(separator)

        # Get losses and weights
        before_weights = before.get("weights", {})
        after_weights = after.get("weights", {})
        before_totals = before.get("group_totals", {})
        after_totals = after.get("group_totals", {})

        # Group losses by hierarchy
        before_grouped = self._group_by_hierarchy(before.get("losses", {}))
        after_grouped = self._group_by_hierarchy(after.get("losses", {}))

        # Union of groups so one-sided entries are still shown
        for group in sorted(set(before_grouped) | set(after_grouped)):
            # Prefer the "after" weight, then "before", then 1.0
            weight = after_weights.get(group, before_weights.get(group, 1.0))
            print(f"\n {group} (weight: {weight:.4f})")
            print(f" {'─' * (width - 4)}")
            print(f" {'Metric':<24} {'Before':>10} {'After':>10} {'Change':>10}")

            before_group = before_grouped.get(group, {})
            after_group = after_grouped.get(group, {})

            for comp in sorted(set(before_group) | set(after_group)):
                self._print_comparison_row(
                    comp, before_group.get(comp), after_group.get(comp)
                )

        # Print group totals
        if before_totals or after_totals:
            print("\n Group Totals")
            print(f" {'─' * (width - 4)}")
            print(f" {'Group':<24} {'Before':>10} {'After':>10} {'Change':>10}")

            for group in sorted(set(before_totals) | set(after_totals)):
                self._print_comparison_row(
                    group, before_totals.get(group), after_totals.get(group)
                )

        print(separator)

    def _print_current(self, stats: Dict, title: str) -> None:
        """
        Format and print current state.

        Parameters
        ----------
        stats : dict
            Stats dictionary.
        title : str
            Title for the output.
        """
        width = 68
        separator = "─" * width

        print(f"\n{separator}")
        print(f" {title}")
        print(separator)

        # Get losses, weights, and totals
        weights = stats.get("weights", {})
        totals = stats.get("group_totals", {})

        # Group losses by hierarchy
        grouped = self._group_by_hierarchy(stats.get("losses", {}))

        for group in sorted(grouped):
            # Groups without a recorded weight are displayed with 1.0
            weight = weights.get(group, 1.0)
            print(f"\n {group} (weight: {weight:.4f})")
            print(f" {'─' * (width - 4)}")

            for comp, val in sorted(grouped[group].items()):
                formatted = self._format_value(val)
                print(f" {comp:<32} {formatted:>10}")

        # Print group totals
        if totals:
            print("\n Group Totals")
            print(f" {'─' * (width - 4)}")
            for group, val in sorted(totals.items()):
                formatted = self._format_value(val)
                print(f" {group:<32} {formatted:>10}")

        print(separator)

    def _print_comparison_row(
        self, label: str, before: Optional[float], after: Optional[float]
    ) -> None:
        """
        Print a single comparison row.

        Parameters
        ----------
        label : str
            Row label.
        before : float or None
            Value before.
        after : float or None
            Value after.
        """
        # _format_value already renders None as "-"
        bstr = self._format_value(before)
        astr = self._format_value(after)

        if before is not None and after is not None:
            cstr = self._format_value(after - before, show_sign=True)
        else:
            cstr = "-"

        print(f" {label:<24} {bstr:>10} {astr:>10} {cstr:>10}")

    def _format_value(
        self, val: Optional[float], show_sign: bool = False
    ) -> str:
        """
        Format a numeric value for display.

        Parameters
        ----------
        val : float or None
            Value to format. None is rendered as "-".
        show_sign : bool
            If True, show + sign for positive values.

        Returns
        -------
        str
            Formatted string: scientific notation for non-zero magnitudes
            below 1e-4, one decimal for magnitudes of 1000 or more, four
            decimals otherwise.
        """
        if val is None:
            return "-"

        magnitude = abs(val)
        if magnitude < 0.0001 and val != 0:
            # Handle very small values
            text = f"{val:.2e}"
        elif magnitude >= 1000:
            # Handle large values
            text = f"{val:.1f}"
        else:
            # Standard precision
            text = f"{val:.4f}"

        # Negative values already carry their sign; zero gets none.
        if show_sign and val > 0:
            return f"+{text}"
        return text
# Public API of this module: only the Logger class is exported by
# ``from ... import *``.
__all__ = ["Logger"]