Source code for torchref.kinetic.kinetics

import torch
from torch.nn import Module as nnModule
from torch.nn import Parameter
from torchref.utils.device_mixin import DeviceMixin
from torchref.config import get_float_dtype
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple, Optional, Union
import numpy as np


[docs] class KineticModel(DeviceMixin, nnModule): """ Configurable PyTorch module for fitting kinetic behavior. Supports arbitrary kinetic schemes defined by relational strings: - "A->B,B->C" (sequential) - "A->B,B->A,B->C" (with back reactions) - "A->B,A->C,B->D,C->D" (parallel pathways) - "A->B,B->C,D" (D is non-reactive state) Each transition has TWO parameters: - Reactivity constant k (rate) - Reaction efficiency η (0 to 1, controls maximum conversion) States can have baseline occupancy offsets (default: 0, not refined). The initial transfer is driven by photoabsorption with quasi-instant conversion around zero, with spread accounted for by an instrument function. Parameters ---------- flow_chart : str Relational string describing the kinetic scheme using comma-separated transitions. Standalone states (non-reactive) can be included without transitions. Example: "A->B,B->C" or "A->B,B->A,B->C,D" (D is non-reactive) timepoints : torch.Tensor or array-like Time points at which to evaluate the kinetics rate_constants : dict or list, optional Initial rate constants. Can be: - Dict mapping "A->B" to float value - List of floats (same order as transitions in flow_chart) - None (random initialization) efficiencies : dict or list, optional Initial reaction efficiencies (0-1). Same format as rate_constants. Default: all 1.0 (100% efficient) instrument_function : str, optional Type of instrument response function. Options: 'gaussian', 'none' Default: 'gaussian' instrument_width : float, optional Width parameter for the instrument function (e.g., sigma for gaussian) Default: 0.1 initial_state : str, optional Which state starts with population 1. Default: first state in flow chart light_activated : bool, optional If True, treats this as a light-activated reaction where the initial photoexcitation can only happen once. Products returning to the initial state become inactive (A*) and cannot undergo photoactivation again. Default: False verbose : int, optional Verbosity level. Default: 1 """
[docs] def __init__( self, flow_chart: str, timepoints, rate_constants: Optional[Union[Dict[str, float], List[float]]] = None, efficiencies: Optional[Union[Dict[str, float], List[float]]] = None, instrument_function: str = 'gaussian', instrument_width: float = 10, initial_state: Optional[str] = None, light_activated: bool = False, activation_level: Optional[float] = 0.5, verbose: int = 1, ): super(KineticModel, self).__init__() self.flow_chart = flow_chart self.verbose = verbose # Convert timepoints to tensor if not isinstance(timepoints, torch.Tensor): timepoints = torch.tensor(timepoints, dtype=get_float_dtype()) self.register_buffer('timepoints', timepoints) # Parse flow chart to extract states and transitions self.states, self.transitions = self._parse_flow_chart(flow_chart) self.n_states = len(self.states) self.state_to_idx = {state: idx for idx, state in enumerate(self.states)} self.n_transitions = len(self.transitions) # Handle light-activated reactions self.light_activated = light_activated if light_activated: # Identify initial state if initial_state is None: initial_state = self.states[0] # Create an inactive version of the initial state (e.g., A -> A*) inactive_state = initial_state + '*' # Add A* as a new state self.states.append(inactive_state) self.n_states += 1 self.state_to_idx[inactive_state] = len(self.states) - 1 # Modify transitions: anything that returns to initial_state now goes to inactive_state modified_transitions = [] self._light_activated_remapping = {} # Track redirected transitions for from_state, to_state in self.transitions: if to_state == initial_state and from_state != initial_state: # Redirect back-reactions to A* instead of A modified_transitions.append((from_state, inactive_state)) # Remember the mapping for rate constant initialization old_key = f"{from_state}->{initial_state}" new_key = f"{from_state}->{inactive_state}" self._light_activated_remapping[new_key] = old_key else: modified_transitions.append((from_state, to_state)) self.transitions = modified_transitions self.n_transitions = len(self.transitions) if self.verbose: print(f"Light-activated mode: {initial_state} products → {inactive_state} (cannot re-photoactivate)") else: self._light_activated_remapping = {} if self.verbose: print(f"Identified {self.n_states} states: {self.states}") print(f"Identified {self.n_transitions} transitions: {self.transitions}") # Store instrument width for smart initialization self._instrument_width = instrument_width # Initialize rate constants (k) as learnable parameters with smart defaults init_log_k = self._initialize_rate_constants(rate_constants, instrument_width, timepoints) self.log_rate_constants = Parameter(init_log_k) # Efficiencies are frozen at 1.0 (not refinable). # They are degenerate with rate constants for sequential models. self.register_buffer( 'efficiencies', torch.ones(self.n_transitions), ) # Initial population if initial_state is None: initial_state = self.states[0] self.initial_state = initial_state initial_populations = torch.zeros(self.n_states) initial_populations[self.state_to_idx[initial_state]] = 1.0 self.register_buffer('initial_populations', initial_populations) # Instrument function parameters (now refinable) self.instrument_function = instrument_function if instrument_function == 'gaussian': # Store log of width to ensure positivity (refinable) self.log_instrument_width = Parameter( torch.tensor(np.log(instrument_width), dtype=get_float_dtype()) ) elif instrument_function == 'none': self.log_instrument_width = None else: raise ValueError(f"Unknown instrument function: {instrument_function}") # Baseline occupancy offsets (default: all zeros, not refined) # These are constant offsets added to the populations # Smart initialization: initial state (A) has 50% baseline (unreactive fraction) self.baseline_occupancies = torch.zeros(self.n_states) initial_state_idx = self.state_to_idx[initial_state] self.baseline_occupancies[initial_state_idx] = 1-activation_level self.register_buffer('_baseline_occupancies', self.baseline_occupancies) self._baseline_refinable = {} # Track which baselines are refinable if self.verbose: print(f"Baseline initialization: State {initial_state} = {1-activation_level} ({activation_level*100}% reactive)")
def _initialize_parameter( self, values: Optional[Union[Dict[str, float], List[float]]], default_value: float = 1.0, transform: str = 'log' ) -> torch.Tensor: """ Initialize a parameter from various input formats. Parameters ---------- values : dict, list, or None Initial values default_value : float Default value if values is None transform : str 'log' for log-transformation, 'none' for no transformation Returns ------- tensor : torch.Tensor Initialized parameter tensor """ init_values = torch.ones(self.n_transitions) * default_value if values is not None: if isinstance(values, dict): # Dictionary mapping "A->B" to value for idx, (from_state, to_state) in enumerate(self.transitions): key = f"{from_state}->{to_state}" if key in values: init_values[idx] = values[key] elif key in self._light_activated_remapping: # Check if this transition was redirected (e.g., O->A* was O->A) original_key = self._light_activated_remapping[key] if original_key in values: init_values[idx] = values[original_key] elif isinstance(values, (list, tuple)): # List of values in order if len(values) != self.n_transitions: raise ValueError(f"Expected {self.n_transitions} values, got {len(values)}") init_values = torch.tensor(values, dtype=get_float_dtype()) else: raise ValueError("values must be dict, list, or None") # Apply transformation if transform == 'log': return torch.log(init_values) else: return init_values def _initialize_rate_constants( self, rate_constants: Optional[Union[Dict[str, float], List[float]]], instrument_width: float, timepoints: torch.Tensor ) -> torch.Tensor: """ Initialize rate constants with smart defaults based on observability constraints. Rules: 1. First transition (photoabsorption): quasi-instant, limited by instrument function τ_1 = σ/3, so k_1 = 3/σ 2. For observable states: 2*k_in ≈ k_out (state reaches ~50% occupancy) 3. Scale rates based on timeframe to ensure observability Parameters ---------- rate_constants : dict, list, or None User-provided rate constants (override smart defaults) instrument_width : float Instrument function width (σ) timepoints : torch.Tensor Time points for the experiment Returns ------- log_k : torch.Tensor Log-transformed rate constants """ if rate_constants is not None: # User provided values - use the standard initialization return self._initialize_parameter(rate_constants, default_value=1.0, transform='log') # Smart initialization based on observability init_k = torch.ones(self.n_transitions) # Determine timeframe t_max = timepoints.max().item() t_min = timepoints[timepoints > 0].min().item() if (timepoints > 0).any() else 1e-3 time_range = t_max - t_min if time_range < 1e-6: # Handle case of single or very closely spaced timepoints time_range = max(t_max, 1.0) # Build a state connectivity map # For each state, track: incoming rates, outgoing rates state_in_indices = {state: [] for state in self.states} state_out_indices = {state: [] for state in self.states} for idx, (from_state, to_state) in enumerate(self.transitions): state_out_indices[from_state].append(idx) state_in_indices[to_state].append(idx) # Identify the first transition (from initial state) initial_state = self.states[0] # Will be set properly later first_transition_indices = state_out_indices[initial_state] # Initialize first transition(s): quasi-instant, limited by instrument function # τ = σ/3, so k = 3/σ if instrument_width > 0: k_first = 3.0 / instrument_width else: k_first = 10.0 # Fallback if no instrument function for idx in first_transition_indices: init_k[idx] = k_first if self.verbose > 1: from_s, to_s = self.transitions[idx] print(f" First transition {from_s}->{to_s}: k = {k_first:.3f} (τ = {1/k_first:.3f})") # For remaining transitions: apply observability constraint # Work through the chain, ensuring 2*k_in ≈ k_out processed = set(first_transition_indices) # Process states in order of connectivity for iteration in range(self.n_transitions): made_progress = False for state in self.states: in_indices = state_in_indices[state] out_indices = state_out_indices[state] # Skip if no outgoing transitions if not out_indices: continue # Check if we have incoming rates already set incoming_set = [idx for idx in in_indices if idx in processed] outgoing_unset = [idx for idx in out_indices if idx not in processed] if incoming_set and outgoing_unset: # Calculate average incoming rate avg_k_in = torch.mean(init_k[incoming_set]).item() # Observability: 2*k_in ≈ k_out for state to reach ~50% occupancy # This ensures the state is observable k_out = avg_k_in / 3.0 # Also consider timeframe - states should be observable within the time range # Ensure τ_out is within the observable window tau_out = 1.0 / k_out if tau_out > time_range: # Too slow, speed it up to be observable k_out = 2.0 / time_range elif tau_out < t_min: # Too fast, slow it down k_out = 1.0 / t_min for idx in outgoing_unset: init_k[idx] = k_out processed.add(idx) made_progress = True if self.verbose > 1: from_s, to_s = self.transitions[idx] print(f" Transition {from_s}->{to_s}: k = {k_out:.3f} (τ = {1/k_out:.3f})") if not made_progress: break # Handle any remaining unprocessed transitions (disconnected or cyclic) for idx in range(self.n_transitions): if idx not in processed: # Use time-range based default k_default = 1.0 / (time_range / 3.0) # Observable in middle third of range init_k[idx] = k_default if self.verbose > 1: from_s, to_s = self.transitions[idx] print(f" Transition {from_s}->{to_s}: k = {k_default:.3f} (τ = {1/k_default:.3f}) [default]") if self.verbose: print(f"Smart initialization:") print(f" Time range: {t_min:.3f} to {t_max:.3f} (Δt = {time_range:.3f})") print(f" Instrument width σ = {instrument_width:.3f}") print(f" First transition k = {k_first:.3f} (τ = {1/k_first:.3f})") return torch.log(init_k) def _parse_flow_chart(self, flow_chart: str) -> Tuple[List[str], List[Tuple[str, str]]]: """ Parse relational flow chart string to extract states and transitions. New format: "A->B,B->C,C->D,C->A" Comma-separated list of transitions. Standalone states (non-reactive) can be included: "A->B,B->C,D" Parameters ---------- flow_chart : str Flow chart string like "A->B,B->C,C->D" or "A->B,B->C,D" where D is a non-reactive state Returns ------- states : List[str] Ordered list of unique states transitions : List[Tuple[str, str]] List of (from_state, to_state) tuples """ # Split by comma to get individual transitions or standalone states transition_strings = [t.strip() for t in flow_chart.split(',')] states_set = set() transitions = [] for trans_str in transition_strings: if '->' in trans_str: # It's a transition parts = trans_str.split('->') if len(parts) != 2: raise ValueError(f"Invalid transition format: '{trans_str}'. Expected exactly one '->'") from_state = parts[0].strip() to_state = parts[1].strip() if not from_state or not to_state: raise ValueError(f"Empty state name in transition: '{trans_str}'") states_set.add(from_state) states_set.add(to_state) transitions.append((from_state, to_state)) else: # It's a standalone (non-reactive) state state_name = trans_str.strip() if not state_name: raise ValueError("Empty state name in flow chart") states_set.add(state_name) # Sort states to ensure consistent ordering states = sorted(states_set) return states, transitions def _build_rate_matrix(self, rate_constants: torch.Tensor, efficiencies: torch.Tensor) -> torch.Tensor: """ Build rate matrix K from rate constants and efficiencies. The rate matrix K is defined such that: dP/dt = K @ P where P is the population vector. Effective rate = k * η (rate constant * efficiency) K[i, j] = effective rate from state j to state i (for i != j) K[i, i] = -sum of rates leaving state i Parameters ---------- rate_constants : torch.Tensor Rate constants k for each transition (must be positive) efficiencies : torch.Tensor Reaction efficiencies η for each transition (0 to 1) Returns ------- K : torch.Tensor Rate matrix of shape (n_states, n_states) """ K = torch.zeros(self.n_states, self.n_states, device=rate_constants.device) # Effective rates = k * η effective_rates = rate_constants * efficiencies # Fill off-diagonal elements for rate_idx, (from_state, to_state) in enumerate(self.transitions): from_idx = self.state_to_idx[from_state] to_idx = self.state_to_idx[to_state] # K[to, from] = effective rate (gain to 'to' from 'from') K[to_idx, from_idx] += effective_rates[rate_idx] # Fill diagonal elements (conservation of probability) for i in range(self.n_states): K[i, i] = -torch.sum(K[:, i]) return K def _solve_kinetics(self, rate_matrix: torch.Tensor) -> torch.Tensor: """ Solve kinetic equations using matrix exponential. P(t) = exp(K * t) @ P(0) for t >= 0 P(t) = P(0) for t < 0 torch.matrix_exp uses Padé approximation with scaling-and-squaring, which handles large ||K*t|| safely. No element-wise clipping is needed (clipping would destroy the row-sum-to-zero structure of the rate matrix). Parameters ---------- rate_matrix : torch.Tensor Rate matrix K of shape (n_states, n_states) Returns ------- populations : torch.Tensor Population of each state at each timepoint Shape: (n_timepoints, n_states) """ populations = [] orig_dtype = rate_matrix.dtype # Use float64 for matrix exponential to maintain precision # when ||K*t|| is large (e.g. fast rates × long times). K64 = rate_matrix.double() P0_64 = self.initial_populations.double() for t in self.timepoints: t_val = t.item() if torch.is_tensor(t) else t if t_val < 0: P_t = P0_64.clone() else: Kt = K64 * t_val exp_Kt = torch.matrix_exp(Kt) P_t = exp_Kt @ P0_64 populations.append(P_t) populations = torch.stack(populations, dim=0).to(orig_dtype) return populations def _apply_instrument_function(self, populations: torch.Tensor) -> torch.Tensor: """ Apply instrument response function to account for time resolution. Performs convolution in real time space using a kernel matrix that accounts for the actual time differences between measurement points. This is essential for non-uniformly spaced time grids (e.g. logarithmic). For Gaussian IRF: S(t_i) = Σ_j P(t_j) * G(t_i - t_j) * w_j / Σ_j G(t_i - t_j) * w_j where G(Δt) = exp(-Δt²/(2σ²)) and w_j are trapezoidal quadrature weights. Parameters ---------- populations : torch.Tensor Raw populations, shape (n_timepoints, n_states) Returns ------- populations_conv : torch.Tensor Populations after convolution, shape (n_timepoints, n_states) """ if self.instrument_function == 'none': return populations elif self.instrument_function == 'gaussian': sigma = torch.exp(self.log_instrument_width) t = self.timepoints # (n_timepoints,) # Time difference matrix: dt[i,j] = t_i - t_j dt = t.unsqueeze(0) - t.unsqueeze(1) # (n_t, n_t) # Gaussian kernel evaluated at actual time differences K = torch.exp(-0.5 * (dt / sigma) ** 2) # (n_t, n_t) # Trapezoidal quadrature weights for non-uniform grid n_t = len(t) w = torch.zeros(n_t, device=t.device, dtype=t.dtype) if n_t > 1: w[0] = (t[1] - t[0]) / 2 w[-1] = (t[-1] - t[-2]) / 2 if n_t > 2: w[1:-1] = (t[2:] - t[:-2]) / 2 else: w[0] = 1.0 # Weight kernel by quadrature weights K = K * w.unsqueeze(0) # (n_t, n_t) # Normalize each row so weights sum to 1 K = K / K.sum(dim=1, keepdim=True) # Apply convolution: populations_conv = K @ populations # Ensure matching dtypes (timepoints may differ from populations) populations_conv = torch.matmul(K.to(populations.dtype), populations) return populations_conv else: raise ValueError(f"Unknown instrument function: {self.instrument_function}")
[docs] def forward(self) -> torch.Tensor: """ Forward pass: compute populations at all timepoints. Returns ------- populations : torch.Tensor Population of each state at each timepoint Shape: (n_timepoints, n_states) """ # Get rate constants (ensure positivity via exp) rate_constants = torch.exp(self.log_rate_constants) # Build rate matrix (efficiencies frozen at 1.0) rate_matrix = self._build_rate_matrix(rate_constants, self.efficiencies) # Solve kinetics (dynamic populations, sum to 1) populations = self._solve_kinetics(rate_matrix) # Apply instrument function populations = self._apply_instrument_function(populations) # Rescale and add baseline occupancies to maintain total population = 1 if hasattr(self, '_baseline_occupancies'): # Build baseline tensor that includes refinable parameters in the # computation graph (so gradients flow through them). baseline = self._baseline_occupancies.clone() for state, param_name in self._baseline_refinable.items(): idx = self.state_to_idx[state] param = getattr(self, param_name) baseline[idx] = torch.sigmoid(param) # Calculate total baseline occupancy total_baseline = baseline.sum() # Rescale dynamic populations to (1 - total_baseline) # This ensures that dynamic + baseline = 1 reactive_fraction = 1.0 - total_baseline populations = populations * reactive_fraction # Add baseline occupancies populations = populations + baseline.unsqueeze(0) return populations
[docs] def set_baseline( self, state: str, occupancy: float, refinable: bool = False ): """ Set baseline occupancy offset for a state. Baseline occupancies are constant offsets added to the population of a state. This is useful for non-reactive background states. Parameters ---------- state : str Name of the state occupancy : float Baseline occupancy value (offset) refinable : bool, optional If True, this baseline becomes a refinable parameter. If False (default), it remains constant. Examples -------- >>> model.set_baseline('D', 0.1, refinable=False) # Constant 10% background >>> model.set_baseline('E', 0.05, refinable=True) # Refinable baseline """ if state not in self.state_to_idx: raise ValueError(f"State '{state}' not found in model. Available states: {self.states}") idx = self.state_to_idx[state] if refinable: # Create a refinable parameter if it doesn't exist param_name = f'_baseline_{state}' if not hasattr(self, param_name): # Use logit transformation to keep baseline in (0, 1) # baseline = sigmoid(logit_baseline) init_val = torch.clamp(torch.tensor(occupancy, dtype=get_float_dtype()), 0.01, 0.99) logit_val = torch.log(init_val / (1 - init_val)) setattr(self, param_name, Parameter(logit_val)) self._baseline_refinable[state] = param_name else: # Update existing parameter param = getattr(self, param_name) init_val = torch.clamp(torch.tensor(occupancy, dtype=get_float_dtype()), 0.01, 0.99) with torch.no_grad(): param.data = torch.log(init_val / (1 - init_val)) else: # Set as constant (non-refinable) self._baseline_occupancies[idx] = occupancy # Remove from refinable dict if it was there if state in self._baseline_refinable: param_name = self._baseline_refinable.pop(state) if hasattr(self, param_name): delattr(self, param_name)
[docs] def get_baselines(self) -> Dict[str, float]: """ Get current baseline occupancies for all states. Returns ------- baselines : Dict[str, float] Dictionary mapping state names to baseline occupancies """ baselines = {} for state, idx in self.state_to_idx.items(): # Check if it's refinable if state in self._baseline_refinable: param_name = self._baseline_refinable[state] param = getattr(self, param_name) baselines[state] = float(torch.sigmoid(param).detach().cpu().numpy()) else: # Use the constant value baselines[state] = float(self._baseline_occupancies[idx].cpu().numpy()) return baselines
def _update_baselines_from_refinable(self): """ Update baseline occupancies tensor from refinable parameters. Called internally during forward pass if needed. """ for state, param_name in self._baseline_refinable.items(): idx = self.state_to_idx[state] param = getattr(self, param_name) self._baseline_occupancies[idx] = torch.sigmoid(param)
[docs] def get_rate_constants(self) -> Dict[str, float]: """ Get current rate constants (k) as a dictionary. Returns ------- rate_dict : Dict[str, float] Dictionary mapping transition strings to rate constants """ rate_constants = torch.exp(self.log_rate_constants).detach().cpu().numpy() rate_dict = {} for idx, (from_state, to_state) in enumerate(self.transitions): key = f"{from_state}->{to_state}" rate_dict[key] = float(rate_constants[idx]) return rate_dict
[docs] def set_rate_constant(self, transition: str, value: float): """ Set rate constant for a specific transition. Parameters ---------- transition : str Transition string in the format "A->B" value : float New rate constant value (must be positive) """ if '->' not in transition: raise ValueError(f"Invalid transition format: '{transition}'. Expected 'A->B'") parts = transition.split('->') if len(parts) != 2: raise ValueError(f"Invalid transition format: '{transition}'. Expected exactly one '->'") from_state = parts[0].strip() to_state = parts[1].strip() # Find index of the transition for idx, (f_state, t_state) in enumerate(self.transitions): if f_state == from_state and t_state == to_state: with torch.no_grad(): self.log_rate_constants[idx] = torch.log(torch.tensor(value, dtype=get_float_dtype())) return raise ValueError(f"Transition '{transition}' not found in model.")
[docs] def get_efficiencies(self) -> Dict[str, float]: """ Get current reaction efficiencies (η) as a dictionary. Returns ------- eff_dict : Dict[str, float] Dictionary mapping transition strings to efficiencies (0-1) """ eff_dict = {} for idx, (from_state, to_state) in enumerate(self.transitions): key = f"{from_state}->{to_state}" eff_dict[key] = float(self.efficiencies[idx].item()) return eff_dict
[docs] def get_effective_rates(self) -> Dict[str, float]: """ Get effective rates (k * η) as a dictionary. Returns ------- eff_rate_dict : Dict[str, float] Dictionary mapping transition strings to effective rates """ rate_dict = self.get_rate_constants() eff_dict = self.get_efficiencies() eff_rate_dict = {key: rate_dict[key] * eff_dict[key] for key in rate_dict} return eff_rate_dict
[docs] def get_time_constants(self) -> Dict[str, float]: """ Get time constants (1/k_eff) for each transition. Returns ------- time_dict : Dict[str, float] Dictionary mapping transition strings to time constants """ eff_rate_dict = self.get_effective_rates() time_dict = {key: 1.0/rate if rate > 1e-10 else float('inf') for key, rate in eff_rate_dict.items()} return time_dict
[docs] def parameters(self) -> Dict[str, torch.Tensor]: """ Get all flexible (learnable) parameters as a dictionary. Returns ------- params : Dict[str, torch.Tensor] Dictionary mapping parameter names to their tensors: - 'log_rate_constants': log-transformed rate constants - 'log_instrument_width': log-transformed instrument width (if refinable) - 'baseline_{state}': refinable baseline for specific states (if any) Examples -------- >>> model = KineticModel(...) >>> params = model.parameters() >>> print(params.keys()) >>> # Use with optimizer: optimizer = torch.optim.Adam(params.values(), lr=0.01) """ params = { 'log_rate_constants': self.log_rate_constants, } if self.log_instrument_width is not None: params['log_instrument_width'] = self.log_instrument_width # Add refinable baselines if hasattr(self, '_baseline_refinable'): for state in sorted(self._baseline_refinable.keys()): param_name = self._baseline_refinable[state] params[f'baseline_{state}'] = getattr(self, param_name) return params
[docs] def print_parameters(self): """Print current model parameters.""" print("\n" + "="*50) print(f"Kinetic Model: {self.flow_chart}") print("="*50) print("\nRate Constants (k):") for key, val in self.get_rate_constants().items(): print(f" {key}: {val:.6f}") print("\nEfficiencies (η):") for key, val in self.get_efficiencies().items(): print(f" {key}: {val:.4f}") print("\nEffective Rates (k*η):") for key, val in self.get_effective_rates().items(): print(f" {key}: {val:.6f}") print("\nTime Constants (1/k_eff):") for key, val in self.get_time_constants().items(): if val == float('inf'): print(f" {key}: ∞") else: print(f" {key}: {val:.6f}") if self.instrument_function == 'gaussian': sigma = torch.exp(self.log_instrument_width).item() print(f"\nInstrument Function: Gaussian (σ = {sigma:.6f})") # Print baselines if any are non-zero baselines = self.get_baselines() if any(val != 0.0 for val in baselines.values()): print("\nBaseline Occupancies:") for state, val in baselines.items(): if val != 0.0: refinable_marker = " (refinable)" if state in self._baseline_refinable else "" print(f" {state}: {val:.6f}{refinable_marker}") print("="*50 + "\n")
[docs] def plot_occupancies( self, outpath: str, times: Optional[torch.Tensor] = None, log: bool = False, figsize: Tuple[int, int] = (10, 6), dpi: int = 150, title: Optional[str] = None ): """ Plot state occupancies over time and save to file. Parameters ---------- outpath : str Path to save the plot (e.g., 'kinetics.png') log : bool, optional If True, use log scale for x-axis. Default: False figsize : Tuple[int, int], optional Figure size (width, height). Default: (10, 6) dpi : int, optional DPI for saving figure. Default: 150 title : str, optional Custom title for the plot. If None, uses flow chart string """ # Compute populations if times is not None: # Temporarily override timepoints original_timepoints = self.timepoints self.timepoints = times with torch.no_grad(): populations = self().detach().cpu().numpy() t = self.timepoints.cpu().numpy() # Create figure plt.figure(figsize=figsize) # Plot each state (combine A and A* if in light-activated mode) plotted_states = [] plotted_populations = [] if self.light_activated: # Find initial state and its inactive version initial_state = self.states[0] inactive_state = initial_state + '*' # Combine A and A* populations if inactive_state in self.states: idx_active = self.state_to_idx[initial_state] idx_inactive = self.state_to_idx[inactive_state] combined_pop = populations[:, idx_active] + populations[:, idx_inactive] plotted_states.append(initial_state) plotted_populations.append(combined_pop) # Add other states (excluding A and A*) for i, state in enumerate(self.states): if state not in [initial_state, inactive_state]: plotted_states.append(state) plotted_populations.append(populations[:, i]) else: # Fallback if something went wrong plotted_states = self.states plotted_populations = [populations[:, i] for i in range(len(self.states))] else: # Normal mode: plot all states separately plotted_states = self.states plotted_populations = [populations[:, i] for i in range(len(self.states))] # Plot for state, pop in zip(plotted_states, plotted_populations): plt.plot(t, pop, label=f'State {state}', linewidth=2) # Formatting plt.xlabel('Time', fontsize=12) plt.ylabel('Occupancy', fontsize=12) if title is None: title = f'State Occupancies: {self.flow_chart}' plt.title(title, fontsize=14) plt.legend(fontsize=10, loc='best') plt.grid(True, alpha=0.3) plt.axvline(x=0, color='k', linestyle='--', alpha=0.3, linewidth=1) # Set log scale if requested if log: # Only use log scale for positive time values if (t > 0).any(): plt.xscale('log') # Adjust x-limits to show only positive times pos_times = t[t > 0] if len(pos_times) > 0: plt.xlim(left=pos_times.min()) plt.tight_layout() # Save figure plt.savefig(outpath, dpi=dpi, bbox_inches='tight') print(f"Plot saved to: {outpath}") plt.close()
[docs] def visualize(self, outpath: str, **kwargs): """ Alias for plot_occupancies for convenience. Parameters ---------- outpath : str Path to save the plot **kwargs Additional arguments passed to plot_occupancies """ self.plot_occupancies(outpath, **kwargs)