from torch import nn
import torch
from typing import Dict, List, Optional, Union, Tuple
import numpy as np
from torchref.utils.device_mixin import DeviceMixin
class occupancy_unrestrained(DeviceMixin, nn.Module):
    """
    Unrestrained occupancy model where each state at each timepoint is independent.

    This is the most flexible model but may lead to physically unrealistic solutions
    since occupancies at different timepoints are not coupled.

    Each (state, timepoint) pair owns one free logit. Occupancies are obtained
    by a softmax over the state axis, so they are non-negative and sum to 1
    at every timepoint.

    Parameters
    ----------
    nstates : int
        Number of structural states
    time : list or array-like
        Time points for the experiment
    """

    def __init__(self, nstates, time):
        super(occupancy_unrestrained, self).__init__()
        self.nstates = nstates
        # One logit per (state, timepoint); zeros give uniform occupancies 1/nstates.
        self.logits = nn.Parameter(torch.zeros(nstates, len(time)))
        self.time = time

    def forward(self):
        """
        Compute occupancies.

        Returns
        -------
        torch.Tensor
            Shape [nstates, n_timepoints]; each column sums to 1.
        """
        # torch.softmax subtracts the per-column maximum before exponentiating,
        # so large logits cannot overflow to inf (exp(x)/sum(exp(x)) gives NaN).
        return torch.softmax(self.logits, dim=0)

    def plot_occupancies(self, path, log_scale=False, figsize=(10, 6)):
        """
        Plot the occupancy of every state versus time and save it to ``path``.

        Parameters
        ----------
        path : str
            Output file for the figure.
        log_scale : bool
            Use a logarithmic time axis (only if some time point is > 0).
        figsize : tuple
            Matplotlib figure size.
        """
        import matplotlib.pyplot as plt

        occupancies = self().detach().cpu().numpy()
        # self.time may be a list, a NumPy array or a (possibly CUDA) tensor;
        # np.array() cannot convert a CUDA tensor directly.
        if isinstance(self.time, torch.Tensor):
            time_points = self.time.detach().cpu().numpy()
        else:
            time_points = np.array(self.time)

        fig = plt.figure(figsize=figsize)
        try:
            for i in range(self.nstates):
                plt.plot(time_points, occupancies[i], label=f'State {i}', linewidth=2)
            plt.xlabel('Time')
            plt.ylabel('Occupancy')
            plt.legend()
            positive = time_points > 0
            if log_scale and positive.any():
                plt.xscale('log')
                # Start the axis at the smallest positive time (log(0) is undefined).
                plt.xlim(left=time_points[positive].min())
            plt.grid(True, alpha=0.3)
            plt.tight_layout()
            plt.savefig(path, dpi=150)
        finally:
            # Always release the figure, even if saving fails, so repeated calls
            # during refinement do not accumulate open figures.
            plt.close(fig)
from torchref.kinetic.kinetics import KineticModel
class occupancies_kinetics(DeviceMixin, nn.Module):
    """
    Kinetics-constrained occupancy model.

    This model uses a kinetic scheme to constrain occupancies at different timepoints.
    Instead of independent parameters for each timepoint, the occupancies are derived
    from rate constants, efficiencies, and a kinetic flow chart.

    This provides several advantages over unrestrained refinement:

    1. Physical constraints: occupancies follow kinetic laws
    2. Reduced parameters: n_rates instead of n_states * n_timepoints
    3. Extrapolation: can predict occupancies at unmeasured timepoints
    4. Interpretability: rate constants have physical meaning

    Refinement Considerations
    -------------------------
    Kinetic refinement has unique challenges compared to normal refinement:

    1. **Parameter Scale Separation**: Rate constants can span many orders of magnitude
       (e.g., ps to ms). Use log-parameterization and consider different learning rates.
    2. **Identifiability**: Some parameters may be correlated:
       - Rate and efficiency can compensate for each other
       - Back-reactions can create degenerate solutions
       Consider using efficiency constraints or regularization.
    3. **Local Minima**: The optimization landscape can have multiple minima
       corresponding to different kinetic interpretations. Initialize carefully.
    4. **Gradient Flow**: Matrix exponentials can have vanishing/exploding gradients.
       The underlying KineticModel is expected to clip extreme values for stability.
    5. **Regularization**: Consider adding priors on:
       - Rate constants (log-normal centered on expected timescales)
       - Efficiencies (Beta distribution favoring high efficiency)
       - Smoothness of rate changes if doing temperature-dependent fitting

    Parameters
    ----------
    flow_chart : str
        Kinetic scheme, e.g., "A->B,B->C,C->D" or "A->B,B->A,B->C" (with back-reaction)
    time : list or tensor
        Time points at which to evaluate occupancies. Non-tensor input is
        converted to a float32 tensor; it is stored as the ``time`` buffer.
    rate_constants : dict or None, optional
        Initial rate constants as {"A->B": value, ...}. If None, uses smart initialization.
    efficiencies : dict or None, optional
        Initial efficiencies as {"A->B": value, ...}. Default: all 1.0.
        Efficiencies are not refined by this wrapper (``freeze_efficiencies``
        and ``unfreeze_efficiencies`` are no-ops).
    instrument_function : str, optional
        Instrument response function type, passed through to KineticModel.
        Default: 'none'
    instrument_width : float, optional
        Instrument response function width (Gaussian sigma). Default: 10
    light_activated : bool, optional
        If True, products returning to ground state become inactive. Default: False
    state_mapping : dict or None, optional
        Mapping from kinetic states to structural model indices.
        E.g., {"A": 0, "B": 1, "C": 2, "D": 3} or {"A": 0, "B": 1, "C": 1, "D": 2}
        The latter allows multiple kinetic states to map to the same structure.
        If None, assumes sequential mapping in KineticModel state order (A=0, B=1, ...).
        Indices must be non-negative; kinetic states left out of the mapping are
        dropped (a warning is issued because occupancies then no longer sum to 1).
    regularization : dict or None, optional
        Regularization settings (the dict is copied, never modified):
        - 'rate_prior_weight': weight for log-normal prior on rates (default: 0.0, disabled)
        - 'rate_prior_mean': mean of log-rate prior
          (default: log(3 / (t_max - t_min)), or 0.0 if the time range is zero)
        - 'rate_prior_std': std of log-rate prior (default: 2.0)
        - 'efficiency_prior_weight': accepted but currently ignored, because
          efficiencies are not refined
    verbose : int, optional
        Verbosity level. Default: 1

    Examples
    --------
    >>> # Simple sequential kinetics: A -> B -> C -> D
    >>> occ = occupancies_kinetics(
    ...     flow_chart="A->B,B->C,C->D",
    ...     time=torch.linspace(0, 100, 50),
    ...     rate_constants={"A->B": 1.0, "B->C": 0.1, "C->D": 0.01}
    ... )
    >>> occupancies = occ()  # Shape: [n_states, n_timepoints]

    >>> # With back-reaction
    >>> occ = occupancies_kinetics(
    ...     flow_chart="A->B,B->A,B->C",
    ...     time=times,
    ...     light_activated=True  # Products returning to A become inactive
    ... )

    >>> # Mapping multiple kinetic states to same structure
    >>> occ = occupancies_kinetics(
    ...     flow_chart="A->B,B->C,C->D",
    ...     time=times,
    ...     state_mapping={"A": 0, "B": 1, "C": 1, "D": 0}  # B and C share structure
    ... )
    """

    def __init__(
        self,
        flow_chart: str,
        time: Union[List, torch.Tensor],
        rate_constants: Optional[Dict[str, float]] = None,
        efficiencies: Optional[Dict[str, float]] = None,
        instrument_function: str = 'none',
        instrument_width: float = 10.0,
        light_activated: bool = False,
        state_mapping: Optional[Dict[str, int]] = None,
        regularization: Optional[Dict[str, float]] = None,
        verbose: int = 1
    ):
        super(occupancies_kinetics, self).__init__()
        self.verbose = verbose

        # Convert time to tensor if needed
        if not isinstance(time, torch.Tensor):
            time = torch.tensor(time, dtype=torch.float32)
        self.register_buffer('time', time)

        # Constructor options that are not part of the KineticModel state dict;
        # kept so state_dict_kinetics() can persist them for faithful reloading.
        self._kinetics_config = {
            'instrument_function': instrument_function,
            'instrument_width': instrument_width,
            'light_activated': light_activated,
        }

        # Initialize the kinetic model
        self.kinetics = KineticModel(
            flow_chart=flow_chart,
            timepoints=time,
            rate_constants=rate_constants,
            efficiencies=efficiencies,
            instrument_function=instrument_function,
            instrument_width=instrument_width,
            light_activated=light_activated,
            verbose=verbose
        )

        # Setup state mapping (kinetic states to structural model indices)
        self._setup_state_mapping(state_mapping)

        # Setup regularization
        self._setup_regularization(regularization)

        if self.verbose:
            print(f"\nKinetic occupancy model initialized:")
            print(f"  Flow chart: {flow_chart}")
            print(f"  Time points: {len(time)} (range: {time.min():.2f} to {time.max():.2f})")
            print(f"  Kinetic states: {self.kinetics.n_states}")
            print(f"  Structural models: {self.nstates}")
            print(f"  Mapping: {self.state_mapping}")

    def _setup_state_mapping(self, state_mapping: Optional[Dict[str, int]]):
        """
        Build the [n_structural, n_kinetic] 0/1 matrix that sums kinetic
        populations into structural-model occupancies.

        This allows:
        1. Multiple kinetic intermediates to share the same structure
        2. Reordering of states if kinetic and structural order differ

        Raises
        ------
        KeyError
            If ``state_mapping`` names a state that is not in the flow chart.
        ValueError
            If ``state_mapping`` is empty or contains a negative index.
        """
        import warnings

        kinetic_states = self.kinetics.states
        if state_mapping is None:
            # Default: sequential mapping (A=0, B=1, C=2, ...)
            state_mapping = {state: i for i, state in enumerate(kinetic_states)}
        else:
            # Copy so later changes to the caller's dict cannot desynchronise
            # self.state_mapping from the registered mapping_matrix.
            state_mapping = dict(state_mapping)

        if not state_mapping:
            raise ValueError("state_mapping must map at least one kinetic state")
        unknown = [s for s in state_mapping if s not in self.kinetics.state_to_idx]
        if unknown:
            raise KeyError(
                f"state_mapping refers to unknown kinetic states {unknown}; "
                f"known states: {list(kinetic_states)}"
            )
        negative = {s: i for s, i in state_mapping.items() if i < 0}
        if negative:
            # A negative index would silently wrap around to the last row.
            raise ValueError(f"state_mapping indices must be non-negative, got {negative}")
        unmapped = [s for s in kinetic_states if s not in state_mapping]
        if unmapped:
            warnings.warn(
                f"Kinetic states {unmapped} are not in state_mapping; their population "
                "is dropped and structural occupancies will not sum to 1.",
                stacklevel=3,
            )

        self.state_mapping = state_mapping
        # Number of structural states = highest referenced index + 1
        self.nstates = max(state_mapping.values()) + 1

        # Create mapping matrix: [n_structural, n_kinetic]
        # This sums kinetic populations that map to the same structure
        mapping_matrix = torch.zeros(self.nstates, len(kinetic_states))
        for kinetic_state, struct_idx in state_mapping.items():
            kinetic_idx = self.kinetics.state_to_idx[kinetic_state]
            mapping_matrix[struct_idx, kinetic_idx] = 1.0
        self.register_buffer('mapping_matrix', mapping_matrix)

    def _setup_regularization(self, regularization: Optional[Dict[str, float]]):
        """Merge user regularization settings with defaults (without mutating the input)."""
        defaults = {
            'rate_prior_weight': 0.0,        # Disabled by default
            'rate_prior_mean': None,         # Auto-determine from time range
            'rate_prior_std': 2.0,
            'efficiency_prior_weight': 0.0,  # Accepted, currently unused
        }
        # Fresh dict: user keys override defaults, and the caller's dict is
        # never modified (it previously received defaults/derived values).
        self.regularization = {**defaults, **(regularization or {})}

        # Auto-determine rate prior mean from time range if not specified
        if self.regularization['rate_prior_mean'] is None:
            t_range = self.time.max() - self.time.min()
            if t_range > 0:
                # Center prior on rate that gives observable kinetics in the time range
                self.regularization['rate_prior_mean'] = float(np.log(3.0 / t_range.item()))
            else:
                self.regularization['rate_prior_mean'] = 0.0  # k = 1

    def forward(self) -> torch.Tensor:
        """
        Compute occupancies at all timepoints.

        Returns
        -------
        occupancies : torch.Tensor
            Occupancy of each structural state at each timepoint.
            Shape: [n_structural_states, n_timepoints]
        """
        # Kinetic populations: [n_timepoints, n_kinetic_states]
        kinetic_populations = self.kinetics()
        # Match dtypes in case KineticModel computes in a different precision
        # than the float32 mapping buffer (a no-op when they already agree).
        mapping = self.mapping_matrix.to(kinetic_populations.dtype)
        # [n_struct, n_kin] @ [n_kin, n_time] -> [n_struct, n_time]
        return mapping @ kinetic_populations.T

    def get_regularization_loss(self) -> torch.Tensor:
        """
        Compute regularization loss for kinetic parameters.

        Currently only a log-normal prior on the rate constants is applied:
        ``weight * sum((log k - mean)^2) / (2 * std^2)``.
        ``efficiency_prior_weight`` is ignored because efficiencies are not refined.

        Returns
        -------
        reg_loss : torch.Tensor
            Scalar regularization loss to be added to the main loss
            (zero when ``rate_prior_weight`` is 0).
        """
        log_rates = self.kinetics.log_rate_constants
        reg_loss = log_rates.new_zeros(())

        # Rate constant prior (log-normal)
        rate_weight = self.regularization.get('rate_prior_weight', 0.0)
        if rate_weight > 0:
            rate_mean = self.regularization['rate_prior_mean']
            rate_std = self.regularization['rate_prior_std']
            rate_prior_loss = torch.sum((log_rates - rate_mean) ** 2) / (2 * rate_std ** 2)
            reg_loss = reg_loss + rate_weight * rate_prior_loss

        return reg_loss

    def get_rate_constants(self) -> Dict[str, float]:
        """Get current rate constants."""
        return self.kinetics.get_rate_constants()

    def get_efficiencies(self) -> Dict[str, float]:
        """Get current efficiencies."""
        return self.kinetics.get_efficiencies()

    def get_time_constants(self) -> Dict[str, float]:
        """Get current time constants (1/k_eff)."""
        return self.kinetics.get_time_constants()

    def set_rate_constant(self, transition: str, value: float):
        """Set a specific rate constant."""
        self.kinetics.set_rate_constant(transition, value)

    def freeze_rates(self, transitions: Optional[List[str]] = None):
        """
        Freeze rate constants (exclude from optimization).

        Parameters
        ----------
        transitions : list of str or None
            Transitions to freeze (e.g., ["A->B"]). If None, freezes all.
            Partial freezing is not implemented: a warning is emitted and
            ALL rate constants are frozen.
        """
        if transitions is not None:
            # Per-transition freezing would need gradient masking hooks.
            import warnings
            warnings.warn(
                "Partial rate freezing not yet implemented. Use full freeze or manual gradient masking.",
                stacklevel=2,
            )
        self.kinetics.log_rate_constants.requires_grad = False

    def unfreeze_rates(self):
        """Unfreeze all rate constants."""
        self.kinetics.log_rate_constants.requires_grad = True

    def freeze_efficiencies(self):
        """No-op: efficiencies are always frozen at 1.0."""
        pass

    def unfreeze_efficiencies(self):
        """No-op: efficiencies are always frozen at 1.0."""
        pass

    def freeze_instrument(self):
        """Freeze instrument function width (if the kinetic model has one)."""
        if self.kinetics.log_instrument_width is not None:
            self.kinetics.log_instrument_width.requires_grad = False

    def unfreeze_instrument(self):
        """Unfreeze instrument function width (if the kinetic model has one)."""
        if self.kinetics.log_instrument_width is not None:
            self.kinetics.log_instrument_width.requires_grad = True

    def get_parameter_groups(self, base_lr: float = 1e-3) -> List[Dict]:
        """
        Get parameter groups with appropriate learning rates.

        - Rate constants (log-space): ``base_lr``
        - Instrument width (if present): ``0.1 * base_lr`` (often well-constrained)

        Efficiencies are not included because they are not refined.

        Parameters
        ----------
        base_lr : float
            Base learning rate

        Returns
        -------
        param_groups : list of dict
            Parameter groups suitable for torch optimizers
        """
        param_groups = [
            {
                'params': [self.kinetics.log_rate_constants],
                'lr': base_lr,
                'name': 'rate_constants'
            },
        ]
        if self.kinetics.log_instrument_width is not None:
            param_groups.append({
                'params': [self.kinetics.log_instrument_width],
                'lr': base_lr * 0.1,  # Very conservative for instrument width
                'name': 'instrument_width'
            })
        return param_groups

    def print_parameters(self):
        """Print current kinetic parameters."""
        self.kinetics.print_parameters()

    def plot_occupancies(
        self,
        path: str,
        log_scale: bool = True,
        show_kinetic_states: bool = False,
        figsize: Tuple[int, int] = (10, 6),
        title: Optional[str] = None
    ):
        """
        Plot state occupancies over time.

        Parameters
        ----------
        path : str
            Output path for the plot
        log_scale : bool
            Whether to use log scale for time axis (only if some time is > 0)
        show_kinetic_states : bool
            If True, shows all kinetic states. If False, shows mapped structural states.
        figsize : tuple
            Figure size
        title : str or None
            Custom title; defaults to the flow chart
        """
        import matplotlib.pyplot as plt

        with torch.no_grad():
            if show_kinetic_states:
                # Raw kinetic populations: [n_time, n_kinetic]
                populations = self.kinetics().cpu().numpy()
                state_names = self.kinetics.states
            else:
                # Structural occupancies: [n_time, n_structural]
                populations = self().T.cpu().numpy()
                state_names = [f'Struct {i}' for i in range(self.nstates)]
        time_points = self.time.cpu().numpy()

        fig = plt.figure(figsize=figsize)
        try:
            for i, name in enumerate(state_names):
                plt.plot(time_points, populations[:, i], label=name, linewidth=2)
            plt.xlabel('Time', fontsize=12)
            plt.ylabel('Occupancy', fontsize=12)
            plt.legend(fontsize=10, loc='best')
            plt.grid(True, alpha=0.3)
            positive = time_points > 0
            if log_scale and positive.any():
                plt.xscale('log')
                plt.xlim(left=time_points[positive].min() * 0.9)
            if title is None:
                title = f'Kinetic Occupancies: {self.kinetics.flow_chart}'
            plt.title(title, fontsize=14)
            plt.tight_layout()
            plt.savefig(path, dpi=150, bbox_inches='tight')
        finally:
            # Release the figure even if plotting/saving fails.
            plt.close(fig)
        if self.verbose:
            print(f"Plot saved to: {path}")

    def plot_comparison(
        self,
        target_occupancies: torch.Tensor,
        path: str,
        log_scale: bool = True,
        figsize: Tuple[int, int] = (12, 5)
    ):
        """
        Plot comparison between current and target occupancies.

        Useful for debugging refinement by comparing to known ground truth
        or to unrestrained refinement results.

        Parameters
        ----------
        target_occupancies : torch.Tensor or array-like
            Target occupancies, shape [n_states, n_timepoints]. Tensors that
            require grad (e.g. the output of an unrestrained model) are detached.
        path : str
            Output path for the plot
        log_scale : bool
            Whether to use log scale for time axis
        figsize : tuple
            Figure size
        """
        import matplotlib.pyplot as plt

        with torch.no_grad():
            current = self().cpu().numpy()  # [n_states, n_time]
        # torch.no_grad() does not clear requires_grad on tensors created
        # outside it, so detach explicitly before converting to NumPy.
        if isinstance(target_occupancies, torch.Tensor):
            target = target_occupancies.detach().cpu().numpy()
        else:
            target = np.asarray(target_occupancies)
        time_points = self.time.cpu().numpy()
        n_states = current.shape[0]
        use_log = log_scale and (time_points > 0).any()

        fig, axes = plt.subplots(1, 2, figsize=figsize)
        try:
            # Left: current (kinetic model)
            ax1 = axes[0]
            for i in range(n_states):
                ax1.plot(time_points, current[i], label=f'State {i}', linewidth=2)
            ax1.set_xlabel('Time')
            ax1.set_ylabel('Occupancy')
            ax1.set_title('Kinetic Model')
            ax1.legend()
            ax1.grid(True, alpha=0.3)
            if use_log:
                ax1.set_xscale('log')

            # Right: target (only states that exist in both)
            ax2 = axes[1]
            for i in range(min(n_states, target.shape[0])):
                ax2.plot(time_points, target[i], label=f'State {i}', linewidth=2)
            ax2.set_xlabel('Time')
            ax2.set_ylabel('Occupancy')
            ax2.set_title('Target')
            ax2.legend()
            ax2.grid(True, alpha=0.3)
            if use_log:
                ax2.set_xscale('log')

            plt.tight_layout()
            plt.savefig(path, dpi=150, bbox_inches='tight')
        finally:
            plt.close(fig)

    def state_dict_kinetics(self) -> Dict:
        """
        Get a dict with everything needed to rebuild this model via
        ``load_from_kinetics_state``.

        Includes the constructor options (instrument function/width,
        light activation) that the KineticModel state dict alone does not
        carry. Mutable settings are copied so later edits do not leak.
        """
        return {
            'flow_chart': self.kinetics.flow_chart,
            'time': self.time,
            'state_mapping': dict(self.state_mapping),
            'regularization': dict(self.regularization),
            'kinetics_state_dict': self.kinetics.state_dict(),
            **getattr(self, '_kinetics_config', {}),
        }

    @classmethod
    def load_from_kinetics_state(cls, state: Dict, verbose: int = 1) -> 'occupancies_kinetics':
        """
        Load model from saved kinetic state.

        Dicts saved by older versions (without instrument/light-activation
        entries) still load; the constructor defaults are used for the
        missing options.
        """
        optional = {
            key: state[key]
            for key in ('instrument_function', 'instrument_width', 'light_activated')
            if key in state
        }
        model = cls(
            flow_chart=state['flow_chart'],
            time=state['time'],
            state_mapping=state['state_mapping'],
            regularization=state['regularization'],
            verbose=verbose,
            **optional,
        )
        model.kinetics.load_state_dict(state['kinetics_state_dict'])
        return model
class occupancies_kinetics_multiexperiment(DeviceMixin, nn.Module):
    """
    Kinetics-constrained occupancy model for multiple experiments.

    This handles the case where you have multiple datasets with different
    conditions (e.g., different temperatures, different excitation wavelengths)
    but want to share some kinetic parameters across them.

    Parameters
    ----------
    flow_chart : str
        Kinetic scheme (shared across experiments)
    experiments : list of dict
        Each dict contains:
        - 'time': time points for this experiment (required)
        - 'rate_constants': initial rate constants for this experiment (optional)
        - 'name': optional name for the experiment (stored in ``experiment_names``)
        Any of 'efficiencies', 'instrument_function', 'instrument_width',
        'light_activated', 'state_mapping' and 'regularization' present in the
        dict are forwarded to ``occupancies_kinetics``.
    shared_rates : list of str or None
        Transitions whose rate constants should be shared across experiments,
        e.g. ["A->B", "C->D"]. If None, all rates are experiment-specific.
        Only sharing of *all* transitions is currently supported: the first
        experiment's rate parameter is then reused by every experiment. A
        partial list triggers a warning and rates stay independent.
    verbose : int
        Verbosity level
    """

    # Per-experiment keys forwarded verbatim to occupancies_kinetics when present.
    _FORWARDED_KEYS = (
        'efficiencies',
        'instrument_function',
        'instrument_width',
        'light_activated',
        'state_mapping',
        'regularization',
    )

    def __init__(
        self,
        flow_chart: str,
        experiments: List[Dict],
        shared_rates: Optional[List[str]] = None,
        verbose: int = 1
    ):
        super(occupancies_kinetics_multiexperiment, self).__init__()
        self.flow_chart = flow_chart
        self.n_experiments = len(experiments)
        self.shared_rates = list(shared_rates or [])
        self.verbose = verbose
        self.experiment_names = [
            exp.get('name', f'experiment_{i}') for i, exp in enumerate(experiments)
        ]

        # Create separate kinetic models for each experiment
        self.kinetic_models = nn.ModuleList()
        for i, exp in enumerate(experiments):
            extra = {key: exp[key] for key in self._FORWARDED_KEYS if key in exp}
            model = occupancies_kinetics(
                flow_chart=flow_chart,
                time=exp['time'],
                rate_constants=exp.get('rate_constants'),
                verbose=verbose if i == 0 else 0,  # only report the first model
                **extra,
            )
            self.kinetic_models.append(model)

        # Setup shared parameters if requested
        if self.shared_rates:
            self._setup_shared_rates()

    def _setup_shared_rates(self):
        """
        Tie rate-constant parameters across experiments.

        When ``shared_rates`` covers every transition of the flow chart, each
        secondary experiment's ``log_rate_constants`` parameter is replaced by
        the primary (first) experiment's parameter object. All experiments then
        read, and accumulate gradients into, one tensor (``parameters()`` lists
        it once). The secondaries' initial rate constants are discarded.

        Sharing only a subset of transitions would require per-element tying
        inside KineticModel; that is not implemented, so a warning is emitted
        and the rates remain independent. Unknown transition names also warn.
        """
        import warnings

        if self.n_experiments < 2:
            return  # nothing to share with

        primary = self.kinetic_models[0].kinetics
        all_transitions = [f"{from_s}->{to_s}" for from_s, to_s in primary.transitions]

        unknown = [t for t in self.shared_rates if t not in all_transitions]
        if unknown:
            warnings.warn(
                f"shared_rates contains transitions not in the flow chart: {unknown}",
                stacklevel=3,
            )

        if not set(all_transitions) <= set(self.shared_rates):
            warnings.warn(
                "Partial rate sharing is not implemented; rate constants remain "
                "independent per experiment. List every transition in shared_rates "
                "to tie them.",
                stacklevel=3,
            )
            return

        shared_param = primary.log_rate_constants
        if not isinstance(shared_param, nn.Parameter):
            # Only registered Parameters can be re-bound safely on another module.
            warnings.warn(
                "log_rate_constants is not an nn.Parameter; rates cannot be shared.",
                stacklevel=3,
            )
            return

        for model in list(self.kinetic_models)[1:]:
            # Re-registering the same Parameter object ties values and gradients.
            model.kinetics.log_rate_constants = shared_param
        if self.verbose:
            print(f"Sharing all rate constants across {self.n_experiments} experiments")

    def forward(self, experiment_idx: int) -> torch.Tensor:
        """
        Get occupancies for a specific experiment.

        Parameters
        ----------
        experiment_idx : int
            Index of the experiment

        Returns
        -------
        occupancies : torch.Tensor
            Shape [n_states, n_timepoints]
        """
        return self.kinetic_models[experiment_idx]()

    def forward_all(self) -> List[torch.Tensor]:
        """
        Get occupancies for all experiments.

        Returns
        -------
        occupancies_list : list of torch.Tensor
            List of occupancy tensors, one per experiment, in input order
        """
        return [model() for model in self.kinetic_models]