# Source code for torchref.model.parameter_wrappers

"""
Wrapper classes for crystallographic model parameters (occupancies, xyz coordinates, B-factors, etc.).
"""

from typing import Optional, Union

import torch
from torch import nn

from torchref.utils.caching import CachedForwardMixin
from torchref.utils.device_mixin import DeviceMixin


class _AssembleMixedTensor(torch.autograd.Function):
    """Differentiable scatter of refinable rows into a copy of the fixed rows.

    ``forward`` copies ``fixed`` and overwrites the rows listed in
    ``indices`` with ``refinable`` (``index_copy_`` along dim 0).
    ``backward`` hands ``grad_output.index_select(0, indices)`` to
    ``refinable`` and ``None`` to the other two inputs.

    Why a custom op: autograd's default backward for
    ``result[idx] = refinable`` lowers to ``aten::_index_put_impl_``, whose
    backward radix-sorts the indices and then does an atomic scatter.
    Profiling on A100/1DAW showed that dominating the model SF backward
    (~370 µs/iter across six MixedTensors), while the true gradient with
    respect to the refinable rows is a single gather.
    """

    @staticmethod
    def forward(ctx, refinable, fixed, indices):
        # ``fixed`` is a module buffer, so scatter into a copy, never into it.
        # ``index_copy_`` returns its receiver, giving the assembled tensor.
        assembled = fixed.clone().index_copy_(0, indices, refinable)
        ctx.save_for_backward(indices)
        return assembled

    @staticmethod
    def backward(ctx, grad_output):
        indices = ctx.saved_tensors[0]
        # Only ``refinable`` receives a gradient: ``fixed`` is a buffer and
        # ``indices`` is an integer tensor.
        grad_refinable = torch.index_select(grad_output, 0, indices)
        return grad_refinable, None, None


class MixedTensor(DeviceMixin, CachedForwardMixin, nn.Module):
    """
    A wrapper class for tensors with mixed fixed and refinable elements.

    Stores a mask indicating which elements can be refined and maintains both
    fixed and refinable components separately. The full tensor is
    reconstructed on-the-fly when accessed.

    For tensors with more than one dimension the mask is 1-D and selects
    whole rows along the first dimension (e.g. the atoms of an ``(N, 3)``
    coordinate tensor).

    Parameters
    ----------
    initial_values : torch.Tensor, optional
        Initial tensor values for all elements. Optional for empty init.
    refinable_mask : torch.Tensor, optional
        Boolean mask indicating which elements can be refined.
        If None, all elements are refinable.
    requires_grad : bool, optional
        Whether refinable parameters should have gradients. Default is True.
    dtype : torch.dtype, optional
        Data type for the tensor. Default is same as initial_values.
    device : torch.device, optional
        Device for the tensor. Default is same as initial_values.
    name : str, optional
        Optional name for this parameter (useful for debugging/logging).

    Attributes
    ----------
    refinable_mask : torch.Tensor
        Boolean mask indicating refinable elements.
    fixed_mask : torch.Tensor
        Boolean mask indicating fixed elements (inverse of refinable_mask).
    fixed_values : torch.Tensor
        Buffer containing fixed values.
    refinable_params : nn.Parameter
        Parameter containing refinable values.

    Examples
    --------
    Empty initialization for state_dict loading::

        mixed = MixedTensor()
        mixed.load_state_dict(torch.load('mixed.pt'))

    Full initialization with values::

        mask = torch.zeros(100, dtype=torch.bool)
        mask[20:30] = True
        initial_values = torch.randn(100)
        mixed = MixedTensor(initial_values, refinable_mask=mask, requires_grad=True)
        optimizer = torch.optim.Adam([mixed.refinable_params], lr=0.01)
    """

    def __init__(
        self,
        initial_values: Optional[torch.Tensor] = None,
        refinable_mask: Optional[torch.Tensor] = None,
        requires_grad: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        name: Optional[str] = None,
    ):
        """
        Initialize a MixedTensor.

        If initial_values is provided, fully initializes the tensor. If not
        provided (empty init), creates a shell ready for load_state_dict().

        Parameters
        ----------
        initial_values : torch.Tensor, optional
            Initial tensor values for all elements. Optional for empty init.
        refinable_mask : torch.Tensor, optional
            Mask indicating which elements can be refined; it is converted to
            bool. If None, all elements are refinable.
        requires_grad : bool, optional
            Whether refinable parameters should have gradients. Default is True.
        dtype : torch.dtype, optional
            Data type for the tensor. Default is same as initial_values.
        device : torch.device, optional
            Device for the tensor. Default is same as initial_values.
        name : str, optional
            Optional name for this parameter (useful for debugging/logging).

        Raises
        ------
        ValueError
            If refinable_mask has the wrong shape: it must equal
            ``initial_values.shape`` for 1-D tensors, or be 1-D of length
            ``initial_values.shape[0]`` otherwise.
        """
        super().__init__()
        self._name = name

        if initial_values is None:
            # Empty shell: register the buffer slots (as None) and a
            # placeholder parameter so the attribute names exist.
            # NOTE(review): plain nn.Module skips None buffers when loading a
            # state_dict, and the 0-element placeholder will not match a saved
            # refinable_params shape; presumably a mixin or the caller handles
            # this (and re-runs _build_index_cache()) -- confirm.
            self.register_buffer("refinable_mask", None)
            self.register_buffer("fixed_mask", None)
            self.register_buffer("fixed_values", None)
            self.register_buffer("_shape", None)
            self.refinable_params = nn.Parameter(
                torch.empty(0), requires_grad=requires_grad
            )
            # Set *every* cache attribute to the empty state; forward() reads
            # _all_refinable and _refinable_idx_1d unconditionally.
            self._build_index_cache()
            return

        if dtype is None:
            dtype = initial_values.dtype
        if device is None:
            device = initial_values.device
        initial_values = initial_values.to(dtype=dtype, device=device)

        if refinable_mask is None:
            # Default: everything refinable. Tensors with ndim > 1 get a
            # per-row mask over the first dimension.
            if initial_values.ndim > 1:
                refinable_mask = torch.ones(
                    initial_values.shape[0], dtype=torch.bool, device=device
                )
            else:
                refinable_mask = torch.ones_like(initial_values, dtype=torch.bool)
        else:
            # Force bool: an integer 0/1 tensor would otherwise be treated as
            # row *indices* by the advanced indexing below.
            refinable_mask = refinable_mask.to(device=device, dtype=torch.bool)

        if initial_values.ndim > 1:
            if (
                refinable_mask.ndim != 1
                or refinable_mask.shape[0] != initial_values.shape[0]
            ):
                raise ValueError(
                    f"For {initial_values.ndim}D tensor with shape {initial_values.shape}, "
                    f"refinable_mask must be 1D with shape ({initial_values.shape[0]},), "
                    f"got shape {refinable_mask.shape}"
                )
        elif refinable_mask.shape != initial_values.shape:
            raise ValueError(
                f"refinable_mask shape {refinable_mask.shape} must match "
                f"initial_values shape {initial_values.shape}"
            )

        # Masks are buffers: they follow the module across devices and into
        # the state_dict, but are never optimized.
        self.register_buffer("refinable_mask", refinable_mask)
        self.register_buffer("fixed_mask", ~refinable_mask)

        # Force row-major contiguous layout: pandas DataFrame selections (used
        # for xyz / u init) hand us column-major (..., 3) arrays whose stride
        # is (1, N). Downstream consumers (e.g. Triton kernels) assume the
        # canonical stride.
        self.register_buffer(
            "fixed_values", initial_values.clone().detach().contiguous()
        )

        # Refinable values live in a Parameter so optimizers can update them.
        self.refinable_params = nn.Parameter(
            initial_values[refinable_mask].clone().detach(),
            requires_grad=requires_grad,
        )

        # Full shape, stored as a (persistent) buffer for reconstruction.
        self.register_buffer("_shape", torch.tensor(initial_values.shape))

        # Pre-compute integer indices so forward() avoids boolean indexing.
        self._build_index_cache()

    def _build_index_cache(self):
        """
        Pre-compute integer indices from refinable_mask to avoid GPU sync.

        Sets ``_has_refinable``, ``_refinable_indices`` (tuple form from
        ``nonzero(as_tuple=True)``, kept for callers that read it directly),
        ``_refinable_idx_1d`` (1-D int64 index tensor used by ``forward``)
        and ``_all_refinable``. Re-run after any change to ``refinable_mask``
        or ``refinable_params``.
        """
        # Start from the "nothing refinable" state and upgrade below.
        self._has_refinable = False
        self._refinable_indices = None
        self._refinable_idx_1d = None
        self._all_refinable = False

        mask = self.refinable_mask
        params = self.refinable_params
        if mask is None or mask.numel() == 0 or params is None or params.numel() == 0:
            return
        # A single host sync here keeps forward() sync-free.
        if not bool(mask.any().item()):
            return

        self._has_refinable = True
        self._refinable_indices = mask.nonzero(as_tuple=True)
        # index_copy_ / index_select take a 1-D LongTensor.
        self._refinable_idx_1d = self._refinable_indices[0]
        self._all_refinable = bool(mask.numel() == int(params.shape[0]))

    def forward(self) -> torch.Tensor:
        """
        Reconstruct and return the full tensor.

        Three fast paths, in priority order:

        1. **All atoms refinable** — return ``refinable_params`` directly
           (no clone, no scatter). Common in standard refinement where no
           atoms are frozen. Saves a full-tensor clone + an ``index_put_``
           per call and replaces the ``index_put`` backward (sort + atomic
           scatter) with a no-op identity.
        2. **No refinable atoms** — return ``fixed_values.clone()`` (the clone
           preserves the caller-must-not-mutate contract even though the
           result is detached from autograd).
        3. **Mixed** — go through :class:`_AssembleMixedTensor`, whose
           backward is a single ``index_select`` (gather) instead of PyTorch's
           default ``index_put_`` backward (radix-sort + scatter).
        """
        if self._all_refinable:
            # `.clone()` turns the Parameter into a plain Tensor (otherwise
            # the CachedForwardMixin trips ``nn.Module.__setattr__`` when
            # caching a Parameter under a non-Parameter attribute slot).
            # The clone's backward is identity, which is exactly the cheap
            # path we want — gradient flows straight to ``refinable_params``
            # without going through PyTorch's index_put backward
            # (sort + atomic scatter).
            return self.refinable_params.clone()
        if not self._has_refinable or self.refinable_params.numel() == 0:
            return self.fixed_values.clone()
        return _AssembleMixedTensor.apply(
            self.refinable_params,
            self.fixed_values,
            self._refinable_idx_1d,
        )

    def _storage_values(self) -> torch.Tensor:
        """
        Return a detached, contiguous copy of the full tensor in storage space.

        This is ``fixed_values`` with ``refinable_params`` written over the
        rows selected by ``refinable_mask``. Unlike ``forward()``, which a
        subclass may override to transform the stored values (e.g.
        PositiveMixedTensor returns ``exp`` of its log-space storage), this
        always yields the representation that the buffers hold. Mutating
        methods therefore read and write one consistent space.
        """
        full = self.fixed_values.detach().clone(memory_format=torch.contiguous_format)
        if self.refinable_params.numel() > 0:
            full[self.refinable_mask] = self.refinable_params.detach()
        return full

    def __getitem__(self, key) -> torch.Tensor:
        """
        Get values at specified indices/mask from the full tensor.

        Parameters
        ----------
        key : int, slice, torch.Tensor, or tuple
            Index specification. Can be:

            - int: Single element
            - slice: Range of elements (e.g., 5:10, :, ::2)
            - torch.Tensor: Boolean mask or integer indices
            - tuple: Multi-dimensional indexing

        Returns
        -------
        torch.Tensor
            Selected values from the full tensor.

        Examples
        --------
        ::

            model.b[5]  # Get B-factor for atom 5
            model.b[5:10]  # Get B-factors for atoms 5-9
            model.b[mask]  # Get B-factors where mask is True
            model.xyz[:, 0]  # Get all x-coordinates

        Notes
        -----
        Subclasses may override _get_values() to customize value retrieval.
        """
        return self._get_values(key)

    def _get_values(self, key) -> torch.Tensor:
        """
        Internal method to get values. Override in subclasses for custom behavior.

        Parameters
        ----------
        key : indexing key
            Index specification.

        Returns
        -------
        torch.Tensor
            Selected values from the full tensor (module call output).
        """
        return self()[key]

    def __setitem__(self, key, value) -> None:
        """
        Set values at specified indices/mask.

        This method updates both fixed_values and refinable_params at the
        specified positions. Supports various indexing styles including
        slices, boolean masks, and integer indices.

        Parameters
        ----------
        key : int, slice, torch.Tensor, or tuple
            Index specification. Can be:

            - int: Single element
            - slice: Range of elements (e.g., 5:10, :, ::2)
            - torch.Tensor: Boolean mask or integer indices
            - tuple: Multi-dimensional indexing
        value : torch.Tensor, float, or int
            Values to assign. Can be:

            - Scalar: Broadcast to all selected positions
            - Tensor: Must match the shape of selected region

        Examples
        --------
        ::

            model.b[:] = 30.0  # Set all B-factors to 30
            model.b[5:10] = 25.0  # Set B-factors 5-9 to 25
            model.b[mask] = new_values  # Set B-factors where mask is True
            model.xyz[mask] = new_coords  # Set coordinates for masked atoms
            model.xyz[:, 0] += 1.0  # Shift all x-coordinates (read-modify-write)

        Notes
        -----
        This method modifies the tensor in-place. The refinable_params
        parameter is replaced with a new Parameter containing the updated
        values, which may affect optimizer state.

        Subclasses may override _set_values() to customize value handling
        (e.g., PositiveMixedTensor converts to log-space).
        """
        if isinstance(value, torch.Tensor):
            value = value.to(dtype=self.dtype, device=self.device)
        else:
            value = torch.tensor(value, dtype=self.dtype, device=self.device)
        self._set_values(key, value)

    def _set_values(self, key, value: torch.Tensor) -> None:
        """
        Internal method to set values. Override in subclasses for custom behavior.

        Parameters
        ----------
        key : indexing key
            Index specification (already validated).
        value : torch.Tensor
            Values to assign (already converted to tensor with correct
            dtype/device), in storage space.
        """
        current = self._storage_values()
        current[key] = value
        # ``current`` is already a private copy, so it can become the buffer.
        self.fixed_values = current
        if self.refinable_mask.any():
            # Boolean indexing copies, so the parameter does not alias the buffer.
            self.refinable_params = nn.Parameter(
                current[self.refinable_mask],
                requires_grad=self.refinable_params.requires_grad,
            )

    def set(self, values: torch.Tensor, mask: torch.Tensor) -> None:
        """
        Set values at positions specified by a boolean mask.

        Updates both fixed_values and refinable_params at the positions
        specified by the mask. This is useful for applying coordinate shifts,
        B-factor corrections, or any other updates to specific atoms.

        Parameters
        ----------
        values : torch.Tensor
            New values to assign. Shape must be
            ``(n_selected, *shape[1:])`` where ``n_selected = mask.sum()``;
            e.g. ``(n_selected,)`` for 1-D tensors and ``(n_selected, 3)``
            for coordinates.
        mask : torch.Tensor
            Mask of shape (n_atoms,) indicating which elements to update; it
            is converted to bool. True positions receive the new values.

        Raises
        ------
        ValueError
            If mask is not 1-D or doesn't match the tensor's first dimension,
            if values shape doesn't match the number of selected elements, or
            if a subclass's ``_set_values`` rejects the values (e.g.
            non-positive values for PositiveMixedTensor).

        Examples
        --------
        ::

            # Update coordinates for selected atoms
            mask = model.get_selection_mask("chain A")
            new_coords = original_coords[mask] + shift
            model.xyz.set(new_coords, mask)

            # Update B-factors for specific residues
            mask = model.get_selection_mask("resseq 10:20")
            new_b = torch.ones(mask.sum()) * 30.0
            model.b.set(new_b, mask)

        Notes
        -----
        This method modifies the tensor in-place. The refinable_params
        parameter is replaced with a new Parameter containing the updated
        values, which may affect optimizer state.
        """
        # Check rank first so a 0-d mask gives a ValueError, not an IndexError.
        if mask.ndim != 1:
            raise ValueError(f"Mask must be 1D, got shape {mask.shape}")
        if mask.shape[0] != self.shape[0]:
            raise ValueError(
                f"Mask shape {mask.shape} must match tensor's first dimension {self.shape[0]}"
            )

        mask = mask.to(device=self.device, dtype=torch.bool)

        n_selected = int(mask.sum().item())
        expected_shape = (n_selected, *self.shape[1:])
        if values.shape != expected_shape:
            raise ValueError(
                f"Values shape {values.shape} doesn't match expected shape {expected_shape} "
                f"for {n_selected} selected elements"
            )

        # Route through _set_values so subclasses apply their own storage
        # conversion (PositiveMixedTensor stores logs); writing values
        # straight into the buffers would bypass it.
        self._set_values(mask, values.to(dtype=self.dtype, device=self.device))

    @property
    def shape(self):
        """Return the shape of the full tensor."""
        return tuple(self._shape.tolist())

    @property
    def dtype(self):
        """Return the dtype of the tensor."""
        return self.fixed_values.dtype

    @property
    def device(self):
        """Return the device of the tensor."""
        return self.fixed_values.device

    def get_refinable_count(self) -> int:
        """Return the number of True entries in refinable_mask (rows for ndim > 1)."""
        return self.refinable_mask.sum().item()

    def get_fixed_count(self) -> int:
        """Return the number of True entries in fixed_mask (rows for ndim > 1)."""
        return self.fixed_mask.sum().item()

    def update_fixed_values(self, new_values: torch.Tensor):
        """
        Update the fixed values (does not affect refinable parameters).

        The whole ``fixed_values`` buffer is replaced. Entries at currently
        refinable positions are stored too, but forward() takes those
        positions from ``refinable_params`` while they stay refinable.

        Parameters
        ----------
        new_values : torch.Tensor
            New tensor values. Only fixed positions will be visible.

        Raises
        ------
        ValueError
            If new_values shape doesn't match tensor shape.
        """
        if new_values.shape != self.shape:
            raise ValueError(
                f"new_values shape {new_values.shape} must match "
                f"tensor shape {self.shape}"
            )
        # Copy (later mutation of new_values must not leak into the buffer)
        # into the contiguous layout that __init__ guarantees.
        self.fixed_values = (
            new_values.detach()
            .to(dtype=self.dtype, device=self.device)
            .clone(memory_format=torch.contiguous_format)
        )

    def update_refinable_mask(
        self, new_mask: torch.Tensor, reset_refinable: bool = False
    ):
        """
        Update which elements are refinable.

        This is an advanced operation that modifies the refinable/fixed split.
        Newly refinable elements start from their current values, and elements
        refinable under both masks keep their current values.

        Parameters
        ----------
        new_mask : torch.Tensor
            New 1-D mask indicating refinable elements; converted to bool and
            moved to this tensor's device.
        reset_refinable : bool, optional
            If True, first overwrite ``fixed_values`` with the current value
            of every element, so elements that become fixed keep their current
            (possibly refined) values. If False, ``fixed_values`` is left
            untouched, so elements leaving the refinable set revert to their
            stored fixed values. Default is False.

        Raises
        ------
        ValueError
            If new_mask is not 1-D or doesn't match the first dimension.
        """
        if new_mask.ndim != 1 or new_mask.shape[0] != self.shape[0]:
            raise ValueError(
                f"new_mask shape {new_mask.shape} must match "
                f"tensor shape {self.shape}"
            )

        # Read with the *old* mask, before it is replaced.
        current = self._storage_values()

        new_mask = new_mask.to(device=self.device, dtype=torch.bool)
        self.refinable_mask = new_mask
        # Derive from the converted mask so fixed_mask shares its device/dtype.
        self.fixed_mask = ~new_mask

        if reset_refinable:
            self.fixed_values = current

        self.refinable_params = nn.Parameter(
            current[new_mask], requires_grad=self.refinable_params.requires_grad
        )
        self._build_index_cache()

    def detach(self) -> torch.Tensor:
        """Return a detached copy of the full tensor."""
        return self.forward().detach()

    def clone(self) -> "MixedTensor":
        """
        Create a deep copy of this MixedTensor.

        The copy is a plain MixedTensor built from ``forward()`` output with a
        cloned mask. For subclasses such as PositiveMixedTensor that means the
        copy holds output-space values and is not itself the subclass.
        """
        return MixedTensor(
            self.forward().detach(),
            self.refinable_mask.clone(),
            requires_grad=self.refinable_params.requires_grad,
            dtype=self.dtype,
            device=self.device,
            name=self.name,
        )

    def copy(self) -> "MixedTensor":
        """
        Create a deep copy of this MixedTensor.

        Creates a complete independent copy with all buffers and parameters.
        Alias for clone().

        Returns
        -------
        MixedTensor
            New MixedTensor instance with copied data.
        """
        return self.clone()

    def clip(self, min_value=None, max_value=None) -> "MixedTensor":
        """
        Return a new MixedTensor with values clipped to [min_value, max_value].

        Either bound may be None. The mask, requires_grad flag, dtype, device
        and name are copied; this instance is not modified.
        """
        clipped = self.forward().detach()
        if min_value is not None:
            clipped = torch.clamp(clipped, min=min_value)
        if max_value is not None:
            clipped = torch.clamp(clipped, max=max_value)
        return MixedTensor(
            clipped,
            self.refinable_mask.clone(),
            requires_grad=self.refinable_params.requires_grad,
            dtype=self.dtype,
            device=self.device,
            name=self.name,
        )

    def to(self, *args, **kwargs):
        """
        Move via :class:`DeviceMixin` and rebuild the index cache.

        NOTE(review): nn.Module implements ``.cuda()``/``.cpu()`` via
        ``_apply`` rather than ``.to``; unless DeviceMixin covers them, the
        cached index tensors may stay on the old device -- confirm.
        """
        result = super().to(*args, **kwargs)
        result._build_index_cache()
        return result

    def parameters(self, recurse: bool = True):
        """
        Yield the refinable parameter tensor for optimizers.

        ``refinable_params`` is yielded even when it has no elements (all
        fixed). ``recurse`` is accepted for signature compatibility with
        ``nn.Module.parameters`` and is ignored.
        """
        yield self.refinable_params

    def _selection_mask(self, selection) -> torch.Tensor:
        """
        Convert a refine()/fix() selection into a bool mask like refinable_mask.

        Boolean tensors are validated and moved to this tensor's device. Any
        other selection (slice, int, integer tensor, tuple) is scattered as
        True into an all-False mask.

        Raises
        ------
        ValueError
            If a boolean selection has an incompatible shape.
        """
        if isinstance(selection, torch.Tensor) and selection.dtype == torch.bool:
            if len(self.shape) > 1:
                # Multi-dimensional tensors are masked per row.
                if len(selection.shape) != 1 or selection.shape[0] != self.shape[0]:
                    raise ValueError(
                        f"Boolean selection shape {selection.shape} must be 1D "
                        f"matching first dimension {self.shape[0]} for multi-dimensional "
                        f"tensor with shape {self.shape}"
                    )
            elif selection.shape != self.shape:
                raise ValueError(
                    f"Boolean selection shape {selection.shape} must match "
                    f"tensor shape {self.shape}"
                )
            return selection.to(device=self.device)

        selected = torch.zeros_like(self.refinable_mask)
        selected[selection] = True
        return selected

    def refine(
        self, selection: Union[slice, torch.Tensor, tuple], reset_values: bool = False
    ):
        """
        Make a selection of the tensor refinable.

        Parameters
        ----------
        selection : slice, torch.Tensor, or tuple
            Selection indicating which elements should become refinable.
            Can be:

            - Boolean tensor of same shape as the full tensor (1-D over the
              first dimension for multi-dimensional tensors)
            - Slice object (e.g., slice(10, 20))
            - Tuple of indices for multidimensional tensors
            - Integer indices
        reset_values : bool, optional
            If True, first overwrite ``fixed_values`` with the current value
            of every element (baking in refined values of already-refinable
            elements). Newly refinable elements start from their current
            values either way. Default is False.

        Raises
        ------
        ValueError
            If a boolean selection has an incompatible shape.

        Examples
        --------
        ::

            mixed.refine(slice(10, 20))  # Make elements 10-19 refinable
            mixed.refine(mask)  # Make elements where mask is True refinable
        """
        # Read with the *old* mask, before it is replaced.
        current = self._storage_values()

        new_mask = self.refinable_mask | self._selection_mask(selection)
        self.refinable_mask = new_mask
        self.fixed_mask = ~new_mask

        if reset_values:
            self.fixed_values = current

        self.refinable_params = nn.Parameter(
            current[new_mask], requires_grad=self.refinable_params.requires_grad
        )
        self._build_index_cache()

    def fix(
        self,
        selection: Union[slice, torch.Tensor, tuple],
        freeze_at_current: bool = True,
    ):
        """
        Make a selection of the tensor fixed (non-refinable).

        Parameters
        ----------
        selection : slice, torch.Tensor, or tuple
            Selection indicating which elements should become fixed. Can be:

            - Boolean tensor of same shape as the full tensor (1-D over the
              first dimension for multi-dimensional tensors)
            - Slice object (e.g., slice(10, 20))
            - Tuple of indices for multidimensional tensors
            - Integer indices
        freeze_at_current : bool, optional
            If True (default), overwrite ``fixed_values`` with the current
            value of every element, so the selection is frozen at its current
            values. If False, the selected elements revert to their stored
            fixed values.

        Raises
        ------
        ValueError
            If a boolean selection has an incompatible shape.

        Examples
        --------
        ::

            mixed.fix(slice(10, 20))  # Fix elements 10-19
            mixed.fix(mask)  # Fix elements where mask is True
        """
        # Read with the *old* mask, before it is replaced.
        current = self._storage_values()

        new_mask = self.refinable_mask & ~self._selection_mask(selection)
        self.refinable_mask = new_mask
        self.fixed_mask = ~new_mask

        if freeze_at_current:
            self.fixed_values = current

        requires_grad = self.refinable_params.requires_grad
        if new_mask.any():
            new_params = current[new_mask]
        else:
            # All fixed: keep the historical 1-D empty placeholder.
            new_params = torch.tensor([], dtype=self.dtype, device=self.device)
        self.refinable_params = nn.Parameter(new_params, requires_grad=requires_grad)
        self._build_index_cache()

    def refine_all(self):
        """Make all elements refinable."""
        self.refine(torch.ones_like(self.refinable_mask))

    def fix_all(self, freeze_at_current: bool = True):
        """Make all elements fixed."""
        self.fix(torch.ones_like(self.refinable_mask), freeze_at_current=freeze_at_current)

    @property
    def name(self) -> Optional[str]:
        """Return the name of this parameter."""
        return self._name

    @name.setter
    def name(self, value: str):
        """Set the name of this parameter."""
        self._name = value

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        name_str = f"'{self.name}', " if self.name is not None else ""
        if self.fixed_values is None:
            # Empty-init shell: shape/dtype/device do not exist yet.
            return f"{cls_name}({name_str}uninitialized)"
        return (
            f"{cls_name}({name_str}shape={self.shape}, dtype={self.dtype}, "
            f"device={self.device}, refinable={self.get_refinable_count()}, "
            f"fixed={self.get_fixed_count()})"
        )

    def __str__(self) -> str:
        """More detailed string representation."""
        cls_name = type(self).__name__
        name_str = f" '{self.name}'" if self.name is not None else ""
        if self.fixed_values is None:
            return f"{cls_name}{name_str}: uninitialized"
        return (
            f"{cls_name}{name_str}:\n"
            f" Shape: {self.shape}\n"
            f" Dtype: {self.dtype}\n"
            f" Device: {self.device}\n"
            f" Refinable: {self.get_refinable_count()} / {self.refinable_mask.numel()}\n"
            f" Fixed: {self.get_fixed_count()} / {self.refinable_mask.numel()}\n"
            f" Requires grad: {self.refinable_params.requires_grad}"
        )
[docs] class PositiveMixedTensor(MixedTensor): """ A MixedTensor subclass ensuring all values are positive via log-space parametrization. Useful for parameters that must be strictly positive (e.g., B-factors, scale factors, sigma values). Values are stored as logarithms internally and converted to normal space via exp() when accessed. Reparametrization:: internal_value = log(desired_value) output_value = exp(internal_value) This ensures output_value > 0 always, with smooth gradient flow. Parameters ---------- initial_values : torch.Tensor, optional Initial tensor values in NORMAL space. Optional for empty init. refinable_mask : torch.Tensor, optional Boolean mask indicating which elements can be refined. requires_grad : bool, optional Whether refinable parameters should have gradients. Default is True. dtype : torch.dtype, optional Data type for the tensor. device : torch.device, optional Device for the tensor. name : str, optional Optional name for this parameter. epsilon : float, optional Small value to add before taking log to avoid log(0). Default is 1e-1. Examples -------- Empty initialization for state_dict loading:: b = PositiveMixedTensor() b.load_state_dict(torch.load('b_factors.pt')) Full initialization with values:: initial_b = torch.tensor([20.0, 30.0, 15.0]) b = PositiveMixedTensor(initial_b) output = b() # Returns exp(log_b) = positive values assert (b() > 0).all() """
[docs] def __init__( self, initial_values: torch.Tensor = None, refinable_mask: Optional[torch.Tensor] = None, requires_grad: bool = True, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, name: Optional[str] = None, epsilon: float = 1e-1, ): """ Initialize a PositiveMixedTensor. If initial_values is provided, fully initializes the tensor. If not provided (empty init), creates a shell ready for load_state_dict(). Parameters ---------- initial_values : torch.Tensor, optional Initial tensor values in NORMAL space. Optional for empty init. refinable_mask : torch.Tensor, optional Boolean mask indicating which elements can be refined. requires_grad : bool, optional Whether refinable parameters should have gradients. Default is True. dtype : torch.dtype, optional Data type for the tensor. device : torch.device, optional Device for the tensor. name : str, optional Optional name for this parameter. epsilon : float, optional Small value to add before taking log to avoid log(0). Default is 1e-1. Raises ------ ValueError If any initial values are not positive. """ # Store epsilon self.epsilon = epsilon # Empty initialization if initial_values is None: super().__init__(None, refinable_mask, requires_grad, dtype, device, name) return # Full initialization - clip initial values to be positive initial_values = torch.clamp(initial_values, min=epsilon) # Store epsilon as buffer (not parameter) self.epsilon = epsilon # Convert initial values to log space log_initial_values = torch.log(initial_values.clamp(min=epsilon)) # Initialize parent class with log-space values super().__init__( initial_values=log_initial_values, refinable_mask=refinable_mask, requires_grad=requires_grad, dtype=dtype, device=device, name=name, )
[docs] def forward(self) -> torch.Tensor: """ Return the full tensor in NORMAL space. Applies exponential transformation to the log-space values. Returns ------- torch.Tensor Tensor with positive values. """ # Get log-space values from parent log_values = super().forward() # Convert to normal space via exp return torch.exp(log_values)
def _set_values(self, key, value: torch.Tensor) -> None:
    """
    Internal method to set values, converting them to log space.

    Values are provided in NORMAL space (positive values). Only the
    positions selected by ``key`` are rewritten. Every other element keeps
    its exact stored log value. An exp/log round trip would otherwise clamp
    values that have refined below ``epsilon`` back up to ``epsilon``.

    Parameters
    ----------
    key : indexing key
        Index specification (already validated).
    value : torch.Tensor or scalar
        Values to assign in NORMAL space. Must be positive. Python scalars
        and sequences are accepted and converted with ``torch.as_tensor``.

    Raises
    ------
    ValueError
        If any values are not positive.
    """
    # Validate in the caller's dtype, before any down-casting could
    # underflow a tiny positive value to 0.
    value = torch.as_tensor(value)
    if (value <= 0).any():
        raise ValueError("All values must be positive for PositiveMixedTensor")

    # Work directly on the log-space storage. Clone so the in-place update
    # below cannot alias whatever tensor get_log_values() returned.
    current_log = self.get_log_values().detach().clone()
    value = value.to(dtype=current_log.dtype, device=current_log.device)
    current_log[key] = torch.log(value.clamp(min=self.epsilon))

    # Full log tensor becomes the fixed storage.
    self.fixed_values = current_log.clone()

    # Re-extract refinable parameters (log space), keeping requires_grad.
    if self.refinable_mask.any():
        new_refinable = current_log[self.refinable_mask].clone()
        self.refinable_params = nn.Parameter(
            new_refinable, requires_grad=self.refinable_params.requires_grad
        )
def fix(self, mask: torch.Tensor, freeze_at_current: bool = True):
    """
    Fix (freeze) specific elements.

    When freezing at the current values, the values are read directly from
    the internal log-space storage. This avoids an exp/log round trip that
    would clamp elements refined below ``epsilon`` back up to ``epsilon``.

    Parameters
    ----------
    mask : torch.Tensor
        Boolean mask indicating which elements to fix.
    freeze_at_current : bool, optional
        If True, freeze at current values. Default is True.
    """
    if freeze_at_current:
        # Current log-space values, without building a computation graph.
        with torch.no_grad():
            current_log = self.get_log_values()
            if current_log.ndim > 1:
                # Boolean indexing selects along the leading dimension;
                # torch.where would broadcast a 1-D mask along the last axis.
                self.fixed_values[mask] = current_log[mask]
            else:
                self.fixed_values = torch.where(mask, current_log, self.fixed_values)

    # Values are already stored above, so the parent must not re-freeze them.
    super().fix(mask, freeze_at_current=False)
def refine(self, mask: torch.Tensor):
    """
    Make specific elements refinable, preserving their current values.

    The current values are copied into ``fixed_values`` straight from the
    log-space storage (no exp/log round trip, so nothing is clamped to
    ``epsilon``) before the parent class repartitions the elements.

    Parameters
    ----------
    mask : torch.Tensor
        Boolean mask indicating which elements to make refinable.
    """
    # Current log-space values, without building a computation graph.
    with torch.no_grad():
        current_log = self.get_log_values()
        if current_log.ndim > 1:
            # Boolean indexing selects along the leading dimension;
            # torch.where would broadcast a 1-D mask along the last axis.
            self.fixed_values[mask] = current_log[mask]
        else:
            self.fixed_values = torch.where(mask, current_log, self.fixed_values)

    # Parent handles the actual fixed -> refinable bookkeeping.
    super().refine(mask)
def set(self, values: torch.Tensor, mask: torch.Tensor) -> None:
    """
    Set values at positions specified by a boolean mask.

    Values are provided in NORMAL space (e.g., actual B-factors) and
    automatically converted to log space for internal storage. Unselected
    elements keep their exact stored log values.

    Parameters
    ----------
    values : torch.Tensor
        New values to assign in NORMAL space (positive values). The shape
        must be ``(n_selected, *self.shape[1:])`` where
        ``n_selected = mask.sum()``; for a 1-D tensor that is
        ``(n_selected,)``.
    mask : torch.Tensor
        1-D boolean mask over the tensor's first dimension. True positions
        receive the new values.

    Raises
    ------
    ValueError
        If the mask is not 1-D, if its length doesn't match the tensor's
        first dimension, if values' shape doesn't match the selection, or if
        any values are not positive.

    Examples
    --------
    ::

        # Update B-factors for selected atoms
        mask = model.get_selection_mask("name CA")
        new_b = torch.ones(mask.sum()) * 30.0  # Set CA B-factors to 30
        model.b.set(new_b, mask)

    Notes
    -----
    This method modifies the tensor in place. Values are automatically
    converted to log space internally to maintain the positivity constraint.
    """
    # Check rank first: mask.shape[0] is undefined for a 0-d mask.
    if mask.ndim != 1:
        raise ValueError(f"Mask must be 1D, got shape {mask.shape}")
    if mask.shape[0] != self.shape[0]:
        raise ValueError(
            f"Mask shape {mask.shape} must match tensor's first dimension {self.shape[0]}"
        )

    # Move mask to correct device
    mask = mask.to(device=self.device, dtype=torch.bool)

    # Each selected row receives a full trailing slice of the tensor.
    n_selected = mask.sum().item()
    expected_shape = (n_selected, *tuple(self.shape[1:]))
    if values.shape != expected_shape:
        raise ValueError(
            f"Values shape {values.shape} doesn't match expected shape {expected_shape} "
            f"for {n_selected} selected elements"
        )

    # Ensure values are positive
    values = values.to(dtype=self.dtype, device=self.device)
    if (values <= 0).any():
        raise ValueError("All values must be positive for PositiveMixedTensor")

    log_values = torch.log(values.clamp(min=self.epsilon))

    # Current storage in LOG space; clone so the in-place write cannot alias.
    current_log = self.get_log_values().detach().clone()
    current_log[mask] = log_values

    # Full log tensor becomes the fixed storage.
    self.fixed_values = current_log.clone()

    # Re-extract refinable parameters (in log space)
    if self.refinable_mask.any():
        new_refinable = current_log[self.refinable_mask].clone()
        self.refinable_params = nn.Parameter(
            new_refinable, requires_grad=self.refinable_params.requires_grad
        )
def get_log_values(self) -> torch.Tensor:
    """
    Expose the raw internal (log-space) representation.

    Useful for debugging, or for code that needs to work directly in the
    parametrization space rather than on the positive public values.

    Returns
    -------
    torch.Tensor
        The tensor assembled by the parent class, before the ``exp`` that
        :meth:`forward` applies.
    """
    log_space = super().forward()
    return log_space
def update_refinable_mask(
    self, new_mask: torch.Tensor, reset_refinable: bool = False
):
    """
    Update which elements are refinable, preserving the current values.

    The current values are taken directly from the log-space storage, so
    repartitioning never perturbs them (no exp/log round trip, no clamp to
    ``epsilon``).

    Parameters
    ----------
    new_mask : torch.Tensor
        New mask indicating refinable elements. It is converted to a bool
        tensor on this tensor's device; integer masks are treated as
        nonzero == refinable.
    reset_refinable : bool, optional
        Accepted for API compatibility but currently ignored. Current values
        are always preserved, regardless of this flag. Default is False.

    Raises
    ------
    ValueError
        If ``new_mask``'s first dimension doesn't match the tensor's.
    """
    if new_mask.shape[0] != self.shape[0]:
        raise ValueError(
            f"new_mask shape {new_mask.shape} must match "
            f"tensor shape {self.shape}"
        )

    # Normalize once. An int mask would otherwise be used for fancy indexing
    # below, and `~` would be a bitwise NOT. fixed_mask must also live on
    # the same device as refinable_mask.
    new_mask = new_mask.to(device=self.device, dtype=torch.bool)

    # Current values, read straight from log-space storage.
    with torch.no_grad():
        current_log = self.get_log_values().detach().clone()

    self.refinable_mask = new_mask
    self.fixed_mask = ~new_mask

    # The full log tensor becomes the fixed storage.
    self.fixed_values = current_log.clone()

    # Refinable portion (log space), keeping the existing requires_grad flag.
    self.refinable_params = nn.Parameter(
        current_log[new_mask].clone(),
        requires_grad=self.refinable_params.requires_grad,
    )

    # Rebuild index cache after mask change
    self._build_index_cache()
def copy(self) -> "PositiveMixedTensor":
    """
    Create a deep copy of this PositiveMixedTensor.

    The copy is built through the constructor, so masks, dtype, device, name
    and epsilon match. Its storage is then overwritten with this tensor's
    exact log-space values. Rebuilding from normal-space values alone would
    round-trip through exp/log and clamp any value that has refined below
    ``epsilon`` up to ``epsilon``.

    Returns
    -------
    PositiveMixedTensor
        New PositiveMixedTensor instance with copied data.
    """
    with torch.no_grad():
        current_log = self.get_log_values().detach().clone()

    new_tensor = PositiveMixedTensor(
        initial_values=torch.exp(current_log),
        refinable_mask=self.refinable_mask.clone(),
        requires_grad=self.refinable_params.requires_grad,
        dtype=self.dtype,
        device=self.device,
        name=self._name,
        epsilon=self.epsilon,
    )

    # Replace the constructor's (clamped) storage with the exact log values.
    new_tensor.fixed_values = current_log.clone()
    if new_tensor.refinable_mask.any():
        new_tensor.refinable_params = nn.Parameter(
            current_log[new_tensor.refinable_mask].clone(),
            requires_grad=self.refinable_params.requires_grad,
        )
    return new_tensor
def __repr__(self) -> str:
    """One-line summary: name, shape, dtype, device, counts and epsilon."""
    prefix = "" if self.name is None else f"'{self.name}', "
    fields = (
        f"shape={self.shape}, dtype={self.dtype}, "
        f"device={self.device}, refinable={self.get_refinable_count()}, "
        f"fixed={self.get_fixed_count()}, epsilon={self.epsilon}"
    )
    return f"PositiveMixedTensor({prefix}{fields})"
def __str__(self) -> str:
    """Multi-line, human-readable description of the tensor's state."""
    label = "" if self.name is None else f" '{self.name}'"
    total = self.refinable_mask.numel()
    lines = [
        f"PositiveMixedTensor{label}:",
        f" Shape: {self.shape}",
        f" Dtype: {self.dtype}",
        f" Device: {self.device}",
        f" Refinable: {self.get_refinable_count()} / {total}",
        f" Fixed: {self.get_fixed_count()} / {total}",
        f" Requires grad: {self.refinable_params.requires_grad}",
        " Parametrization: log space (output = exp(internal))",
        f" Epsilon: {self.epsilon}",
    ]
    return "\n".join(lines)
class CholeskyMixedTensor(MixedTensor):
    """A MixedTensor for anisotropic ADPs (U tensors) kept positive-definite.

    The six U components (u11, u22, u33, u12, u13, u23) are stored internally
    as the six free parameters of a lower-triangular Cholesky factor ``L``,
    and the public tensor is reconstructed as ``U = L Lᵀ``. The diagonal of
    ``L`` is mapped through ``exp(x) + epsilon``, which is strictly positive.
    ``U`` is therefore positive-definite by construction for *any* value of
    the free parameters, so unconstrained optimisation (e.g. LBFGS line
    search) can never drive ``U`` indefinite. An indefinite ``U`` would make
    the per-atom anisotropic B-matrix singular. Its inverse and the Gaussian
    exponent would then blow up, and the structure-factor FFT would return
    NaN.

    This is the anisotropic analogue of :class:`PositiveMixedTensor`, which
    keeps isotropic B positive the same way. Rows that are entirely
    non-finite (isotropic atoms carry ``U = NaN``) are passed through
    unchanged in both directions, preserving the iso/aniso split.

    Notes
    -----
    The eigen-decomposition / Cholesky needed to map ``U -> L`` runs only at
    construction and when U-space values are assigned (``set`` /
    ``_set_values``, and then only for the rows being assigned).

    Freezing, unfreezing, repartitioning and copying work directly on the
    stored Cholesky parameters. They therefore never perturb values through
    a ``U -> L`` round trip; the eigenvalue clamp in that round trip can
    alter a ``U`` whose smallest eigenvalue is below ``epsilon**2``.

    The forward (hot) path is just ``exp`` and a handful of products, so
    gradients flow cleanly to ``refinable_params`` with no matrix
    factorisation in the autograd graph.
    """

    def __init__(
        self,
        initial_values: torch.Tensor = None,
        refinable_mask: Optional[torch.Tensor] = None,
        requires_grad: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        name: Optional[str] = None,
        epsilon: float = 1e-3,
    ):
        """
        Initialize from U-space values, or create an empty shell.

        Parameters
        ----------
        initial_values : torch.Tensor, optional
            U components with a trailing dimension of 6, ordered
            (u11, u22, u33, u12, u13, u23). Rows that are not entirely
            finite are stored as NaN. Omit for an empty shell (e.g. before
            ``load_state_dict``).
        refinable_mask : torch.Tensor, optional
            Mask of refinable elements, forwarded to :class:`MixedTensor`.
        requires_grad : bool, optional
            Whether refinable parameters should have gradients. Default True.
        dtype : torch.dtype, optional
            Data type for the storage.
        device : torch.device, optional
            Device for the storage.
        name : str, optional
            Optional name for this parameter.
        epsilon : float, optional
            Floor added to each Cholesky diagonal (``L_ii = exp(x) + eps``).
            Input eigenvalues are also clamped to at least ``eps**2``.
            Default is 1e-3.
        """
        # Plain attribute; the transforms read it, including before the
        # parent initializer runs.
        self.epsilon = epsilon
        if initial_values is None:
            super().__init__(None, refinable_mask, requires_grad, dtype, device, name)
            return
        raw = self._u6_to_raw6(initial_values)
        super().__init__(
            initial_values=raw,
            refinable_mask=refinable_mask,
            requires_grad=requires_grad,
            dtype=dtype,
            device=device,
            name=name,
        )

    # ------------------------------------------------------------------
    # U (6-vector) <-> Cholesky free-parameter (6-vector) transforms.
    # Both operate on (..., 6) tensors and pass NaN rows through untouched.
    # ------------------------------------------------------------------
    @staticmethod
    def _u6_to_matrix(U: torch.Tensor) -> torch.Tensor:
        """Expand (..., 6) U components into symmetric (..., 3, 3) matrices."""
        M = U.new_zeros(*U.shape[:-1], 3, 3)
        M[..., 0, 0] = U[..., 0]
        M[..., 1, 1] = U[..., 1]
        M[..., 2, 2] = U[..., 2]
        M[..., 0, 1] = M[..., 1, 0] = U[..., 3]
        M[..., 0, 2] = M[..., 2, 0] = U[..., 4]
        M[..., 1, 2] = M[..., 2, 1] = U[..., 5]
        return M

    def _u6_to_raw6(self, U: torch.Tensor) -> torch.Tensor:
        """U components -> Cholesky free parameters [log(L_ii - eps); L_offdiag].

        The output order is ``[raw_11, raw_22, raw_33, L21, L31, L32]``. Rows
        that are not entirely finite come back as all-NaN. A Cholesky
        diagonal at or below ``eps`` is mapped to ``log(1e-12)``, which
        reconstructs to approximately ``eps``.
        """
        eps = self.epsilon
        finite = torch.isfinite(U).all(dim=-1)
        M = self._u6_to_matrix(torch.nan_to_num(U, nan=0.0))
        # Non-finite rows get an identity placeholder so eigh/cholesky succeed;
        # their output is replaced with NaN at the end.
        eye = torch.eye(3, dtype=M.dtype, device=M.device).expand_as(M)
        M = torch.where(finite[..., None, None], M, eye)
        # Project to positive-definite: symmetrise, clamp eigenvalues off zero.
        # No-op for well-conditioned deposited U; rescues marginally non-PD input.
        M = 0.5 * (M + M.transpose(-1, -2))
        w, V = torch.linalg.eigh(M)
        w = w.clamp(min=eps * eps)
        M = (V * w.unsqueeze(-2)) @ V.transpose(-1, -2)
        L = torch.linalg.cholesky(M)
        diag = torch.stack([L[..., 0, 0], L[..., 1, 1], L[..., 2, 2]], dim=-1)
        off = torch.stack([L[..., 1, 0], L[..., 2, 0], L[..., 2, 1]], dim=-1)
        raw_diag = torch.log((diag - eps).clamp(min=1e-12))  # invert exp(x)+eps
        raw = torch.cat([raw_diag, off], dim=-1)
        nan = torch.full_like(raw, float("nan"))
        return torch.where(finite.unsqueeze(-1), raw, nan)

    def _raw6_to_u6(self, raw: torch.Tensor) -> torch.Tensor:
        """Cholesky free parameters -> U components (U = L Lᵀ). PD by construction."""
        eps = self.epsilon
        diag, off = raw[..., :3], raw[..., 3:]
        L11 = torch.exp(diag[..., 0]) + eps
        L22 = torch.exp(diag[..., 1]) + eps
        L33 = torch.exp(diag[..., 2]) + eps
        L21, L31, L32 = off[..., 0], off[..., 1], off[..., 2]
        U11 = L11 * L11
        U22 = L21 * L21 + L22 * L22
        U33 = L31 * L31 + L32 * L32 + L33 * L33
        U12 = L21 * L11
        U13 = L31 * L11
        U23 = L31 * L21 + L32 * L22
        return torch.stack([U11, U22, U33, U12, U13, U23], dim=-1)

    def forward(self) -> torch.Tensor:
        """Return the full U tensor (positive-definite per finite row)."""
        return self._raw6_to_u6(super().forward())

    def _set_values(self, key, value: torch.Tensor) -> None:
        """Set U-space values at ``key``; stored internally as Cholesky params.

        Only rows touched by ``key`` are converted back to Cholesky space.
        All other rows keep their exact stored parameters.
        """
        with torch.no_grad():
            current_u = self.forward().detach()
            current_u[key] = value
            # Mark every element `key` addresses, then reduce to whole rows,
            # since the U -> L transform is per row of 6 components.
            touched = torch.zeros(
                current_u.shape, dtype=torch.bool, device=current_u.device
            )
            touched[key] = True
            rows = touched.any(dim=-1)
            raw = super().forward().detach().clone()
            if rows.any():
                raw[rows] = self._u6_to_raw6(current_u[rows])
        self.fixed_values = raw.clone()
        if self.refinable_mask.any():
            self.refinable_params = nn.Parameter(
                raw[self.refinable_mask].clone(),
                requires_grad=self.refinable_params.requires_grad,
            )

    def fix(self, mask: torch.Tensor, freeze_at_current: bool = True):
        """Freeze rows, storing their current Cholesky parameters verbatim."""
        if freeze_at_current:
            with torch.no_grad():
                raw = super().forward()
                self.fixed_values[mask] = raw[mask]
        super().fix(mask, freeze_at_current=False)

    def refine(self, mask: torch.Tensor):
        """Make rows refinable, preserving their current Cholesky parameters."""
        with torch.no_grad():
            raw = super().forward()
            self.fixed_values[mask] = raw[mask]
        super().refine(mask)

    def set(self, values: torch.Tensor, mask: torch.Tensor) -> None:
        """Set U-space values for masked rows (converted to Cholesky internally)."""
        self._set_values(mask, values)

    def update_refinable_mask(
        self, new_mask: torch.Tensor, reset_refinable: bool = False
    ):
        """Repartition refinable/fixed elements, preserving values exactly.

        The base implementation re-stores ``forward()`` output directly. Here
        that would double-transform, writing U back into Cholesky-parameter
        storage. Instead, the stored Cholesky parameters are repartitioned
        as-is.

        ``new_mask`` is converted to a bool tensor on this tensor's device.
        ``reset_refinable`` is accepted for API compatibility and ignored.
        """
        if new_mask.shape[0] != self.shape[0]:
            raise ValueError(
                f"new_mask shape {new_mask.shape} must match tensor shape {self.shape}"
            )
        # Bool + same device, so `~` and indexing behave as masks.
        new_mask = new_mask.to(device=self.device, dtype=torch.bool)
        with torch.no_grad():
            current_raw = super().forward().detach().clone()
        self.refinable_mask = new_mask
        self.fixed_mask = ~new_mask
        self.fixed_values = current_raw.clone()
        new_refinable = current_raw[new_mask].clone()
        self.refinable_params = nn.Parameter(
            new_refinable, requires_grad=self.refinable_params.requires_grad
        )
        self._build_index_cache()

    def copy(self) -> "CholeskyMixedTensor":
        """Deep-copy, preserving the Cholesky parametrization exactly.

        The copy is built through ``__init__``, so it stays a
        positive-definite :class:`CholeskyMixedTensor` rather than degrading
        to a plain unconstrained :class:`MixedTensor`. Its storage is then
        overwritten with this tensor's exact Cholesky parameters, so no
        ``U -> L`` round-trip error is introduced.
        """
        with torch.no_grad():
            current_raw = super().forward().detach().clone()
        new_tensor = CholeskyMixedTensor(
            initial_values=self._raw6_to_u6(current_raw),
            refinable_mask=self.refinable_mask.clone(),
            requires_grad=self.refinable_params.requires_grad,
            dtype=self.dtype,
            device=self.device,
            name=self._name,
            epsilon=self.epsilon,
        )
        new_tensor.fixed_values = current_raw.clone()
        if new_tensor.refinable_mask.any():
            new_tensor.refinable_params = nn.Parameter(
                current_raw[new_tensor.refinable_mask].clone(),
                requires_grad=self.refinable_params.requires_grad,
            )
        return new_tensor
class OccupancyTensor(MixedTensor):
    """
    A specialized MixedTensor for occupancy parameters in crystallographic refinement.

    Handles specific constraints and requirements for occupancy: value bounds
    [0, 1] via sigmoid reparameterization, atom sharing, and sum-to-1
    constraints for alternative conformations.

    Features:

    - Values bounded between 0 and 1 using sigmoid reparameterization
      (storage is in logit space).
    - Atoms can share occupancies (e.g., all atoms in a residue).
    - Alternative conformations sum to 1.0 via normalization in forward().
    - Memory-efficient collapsed storage: one parameter per sharing group.
      Logits of atoms that share a group are averaged when collapsing.
    - Fully vectorized collapse/expand operations.

    Parameters
    ----------
    initial_values : torch.Tensor, optional
        Initial occupancy values for ALL atoms (must be in [0, 1] when
        ``use_sigmoid`` is True). Optional for empty init.
    sharing_groups : torch.Tensor, optional
        Tensor of shape (n_atoms,) where each value is the collapsed index
        for that atom. If None, each atom has independent occupancy.
    altloc_groups : list of tuple, optional
        List of tuples of atom index lists representing alternative
        conformations. Each tuple contains the atom indices for each
        conformation; all atoms of one conformation must share a collapsed
        index.
    refinable_mask : torch.Tensor, optional
        Boolean mask for which ATOMS can be refined (in full tensor space).
    requires_grad : bool, optional
        Whether refinable parameters should have gradients. Default is True.
    dtype : torch.dtype, optional
        Data type for the tensor.
    device : torch.device, optional
        Device for the tensor.
    name : str, optional
        Optional name for this parameter.
    use_sigmoid : bool, optional
        If True, use sigmoid parameterization to bound values to [0,1].
        Default is True.

    Attributes
    ----------
    expansion_mask : torch.Tensor
        Maps atoms to collapsed indices.
    linked_occ_sizes : list
        Sorted list of the altloc group sizes (numbers of conformations)
        present.
    collapse_counts : torch.Tensor
        Count of atoms per collapsed index.

    Examples
    --------
    ::

        sharing_groups = torch.tensor([0, 0, 1, 1, 2, 2])
        occ = OccupancyTensor(
            initial_values=torch.tensor([1.0, 1.0, 0.7, 0.7, 0.3, 0.3]),
            sharing_groups=sharing_groups,
            altloc_groups=[([2, 3], [4, 5])],
        )
        result = occ()  # Atoms 2-3 and 4-5 will sum to 1.0
    """
def __init__(
    self,
    initial_values: torch.Tensor = None,
    sharing_groups: Optional[torch.Tensor] = None,
    altloc_groups: Optional[list] = None,
    refinable_mask: Optional[torch.Tensor] = None,
    requires_grad: bool = True,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    name: Optional[str] = None,
    use_sigmoid: bool = True,
):
    """
    Initialize an OccupancyTensor with collapsed storage and altloc support.

    If initial_values is provided, fully initializes the tensor. If not
    provided (empty init), creates a shell ready for load_state_dict().

    Parameters
    ----------
    initial_values : torch.Tensor, optional
        Initial occupancy values for ALL atoms (must be in [0, 1] when
        ``use_sigmoid`` is True, otherwise ValueError). Optional for empty
        init.
    sharing_groups : torch.Tensor, optional
        Tensor of shape (n_atoms,) where each value is the collapsed index
        for that atom. If None, each atom has independent occupancy.
        Example: tensor([0, 0, 0, 1, 1, 2]) means atoms 0,1,2 share one
        occupancy, atoms 3,4 share another, and atom 5 is independent.
    altloc_groups : list of tuple, optional
        List of tuples of atom index lists representing alternative
        conformations. Example: [([10,11], [12,13])] means atoms 10,11
        (conf A) and 12,13 (conf B) are altlocs that sum to 1.0.
    refinable_mask : torch.Tensor, optional
        Boolean mask for which ATOMS can be refined (in full tensor space).
        If any atom in a group is refinable, the entire group becomes
        refinable.
    requires_grad : bool, optional
        Whether refinable parameters should have gradients. Default is True.
    dtype : torch.dtype, optional
        Data type for the tensor. Defaults to ``initial_values.dtype``.
    device : torch.device, optional
        Device for the tensor. Defaults to ``initial_values.device``.
    name : str, optional
        Optional name for this parameter. Defaults to "occupancy".
    use_sigmoid : bool, optional
        If True, use sigmoid parameterization to bound values to [0,1].
        Default is True.

    Raises
    ------
    ValueError
        If ``use_sigmoid`` and any initial value is outside [0, 1], or if
        ``refinable_mask`` length doesn't match ``initial_values``.
    """
    # Store configuration
    self.use_sigmoid = use_sigmoid

    # Initialize Module first (required before register_buffer).
    # NOTE(review): this calls nn.Module.__init__ directly, so
    # MixedTensor.__init__ (and any mixin initializers it would run) is
    # skipped. Presumably intentional; confirm.
    nn.Module.__init__(self)
    self._name = name or "occupancy"

    # Empty initialization
    if initial_values is None:
        self._full_shape = 0
        self._collapsed_shape = 0
        # NOTE(review): collapse_counts / linked_occ_* buffers and
        # linked_occ_sizes are not created on this path. Loading a state_dict
        # saved from a fully-initialized tensor may report unexpected keys
        # unless handled elsewhere. Verify.
        self.register_buffer("refinable_mask", None)
        self.register_buffer("fixed_mask", None)
        self.register_buffer("fixed_values", None)
        self.register_buffer("_shape", None)
        self.register_buffer("expansion_mask", None)
        self.refinable_params = nn.Parameter(
            torch.empty(0), requires_grad=requires_grad
        )
        # Same attributes forward() reads; mirrors what _build_index_cache sets.
        self._has_refinable = False
        self._refinable_indices = None
        return

    # Full initialization
    self._full_shape = initial_values.shape[0]

    if dtype is None:
        dtype = initial_values.dtype
    if device is None:
        device = initial_values.device

    # Ensure initial_values is on the correct device and dtype
    initial_values = initial_values.to(device=device, dtype=dtype)

    # Validate initial values are in valid range
    if self.use_sigmoid:
        if torch.any(initial_values < 0) or torch.any(initial_values > 1):
            raise ValueError("Initial occupancy values must be in range [0, 1]")

    # Process sharing groups and altlocs, create expansion mask.
    # Sets _collapsed_shape, expansion_mask, linked_occ_* and collapse_counts.
    self._setup_sharing_groups_and_expansion(
        initial_values, sharing_groups, altloc_groups, device
    )

    # Convert initial occupancies to logit space (full space). The clamp keeps
    # logit finite for exact 0 / 1 occupancies.
    if self.use_sigmoid:
        clamped_values = torch.clamp(initial_values, min=1e-6, max=1 - 1e-6)
        logit_values = torch.logit(clamped_values)
    else:
        logit_values = initial_values.clone()

    # Collapse logit values using vectorized operation (mean per group)
    collapsed_logits = self._collapse_values_vectorized(logit_values)

    # Handle refinable mask - collapse it too
    if refinable_mask is not None:
        if refinable_mask.shape[0] != self._full_shape:
            raise ValueError(
                f"refinable_mask shape {refinable_mask.shape} must match "
                f"initial_values shape {initial_values.shape}"
            )
        # Collapse the refinable mask (a group is refinable if any atom is)
        collapsed_refinable_mask = self._collapse_mask_vectorized(
            refinable_mask.to(device=device)
        )
    else:
        collapsed_refinable_mask = torch.ones(
            self._collapsed_shape, dtype=torch.bool, device=device
        )

    # Note: With the new normalization approach, all altloc members are refinable
    # The sum-to-1 constraint is enforced during forward() via normalization

    # Store masks as buffers
    self.register_buffer("refinable_mask", collapsed_refinable_mask)
    self.register_buffer("fixed_mask", ~collapsed_refinable_mask)

    # Store fixed values as buffer (collapsed)
    self.register_buffer("fixed_values", collapsed_logits.clone().detach())

    # Store refinable values as parameter (collapsed, excluding placeholders)
    refinable_values = collapsed_logits[collapsed_refinable_mask].clone().detach()
    self.refinable_params = nn.Parameter(
        refinable_values, requires_grad=requires_grad
    )

    # Store collapsed shape
    self.register_buffer("_shape", torch.tensor([self._collapsed_shape]))

    # Pre-compute index cache to avoid boolean indexing at runtime
    self._build_index_cache()
def _setup_sharing_groups_and_expansion(
    self,
    initial_values: torch.Tensor,
    sharing_groups: Optional[torch.Tensor],
    altloc_groups: Optional[list],
    device: torch.device,
):
    """
    Setup sharing groups, altlocs, and create expansion mask.

    Creates the index tensor for efficient collapse/expand operations and
    processes altloc groups into tensors grouped by number of conformations.

    Parameters
    ----------
    initial_values : torch.Tensor
        Initial occupancy values for all atoms (only its length is used).
    sharing_groups : torch.Tensor or None
        Tensor of shape (n_atoms,) giving collapsed index for each atom.
    altloc_groups : list or None
        List of tuples of atom index lists for alternative conformations.
    device : torch.device
        Device to place tensors on.

    Raises
    ------
    ValueError
        If ``sharing_groups`` is not of shape (n_atoms,), if an altloc group
        has fewer than 2 conformations, or if a conformation has no atoms.
    AssertionError
        If the atoms of one conformation map to different collapsed indices.
    """
    n_atoms = initial_values.shape[0]

    if sharing_groups is None:
        # No sharing - each atom maps to its own index
        expansion_mask = torch.arange(n_atoms, dtype=torch.long, device=device)
        self._collapsed_shape = n_atoms
    else:
        expansion_mask = sharing_groups.to(device=device, dtype=torch.long)
        # A length mismatch would otherwise make scatter_add_ silently use only
        # a prefix of the values, and forward() return the wrong length.
        if expansion_mask.ndim != 1 or expansion_mask.shape[0] != n_atoms:
            raise ValueError(
                f"sharing_groups shape {tuple(expansion_mask.shape)} must be "
                f"({n_atoms},) to match initial_values"
            )
        self._collapsed_shape = expansion_mask.max().item() + 1 if n_atoms > 0 else 0

    self.register_buffer("expansion_mask", expansion_mask)

    # linked_occupancies[n] = list of rows of n collapsed indices, one row per
    # altloc group with n conformations.
    linked_occupancies = {}
    if altloc_groups is not None:
        for altloc_idx, conf_groups in enumerate(altloc_groups):
            n_conformations = len(conf_groups)
            if n_conformations < 2:
                raise ValueError(
                    f"Altloc group {altloc_idx} must have at least 2 conformations"
                )

            collapsed_indices = []
            for conf_atoms in conf_groups:
                if isinstance(conf_atoms, (list, tuple)):
                    conf_atoms = torch.tensor(
                        conf_atoms, dtype=torch.long, device=device
                    )
                else:
                    conf_atoms = conf_atoms.to(device=device, dtype=torch.long)
                if conf_atoms.numel() == 0:
                    raise ValueError(
                        f"Altloc group {altloc_idx}, conformation "
                        f"{len(collapsed_indices)} has no atoms"
                    )

                # Vectorized consistency check (one lookup instead of a
                # per-atom .item() loop): all atoms must share one index.
                mapped = expansion_mask[conf_atoms]
                collapsed_idx = mapped[0].item()
                mismatched = mapped != collapsed_idx
                if bool(mismatched.any()):
                    pos = int(torch.nonzero(mismatched)[0, 0].item())
                    atom_idx = conf_atoms[pos]
                    atom_collapsed_idx = mapped[pos].item()
                    raise AssertionError(
                        f"Altloc group {altloc_idx}, conformation {len(collapsed_indices)}: "
                        f"atom {atom_idx} maps to collapsed index {atom_collapsed_idx}, "
                        f"but first atom maps to {collapsed_idx}. "
                        f"All atoms in a conformation must share the same collapsed index."
                    )
                collapsed_indices.append(collapsed_idx)

            linked_occupancies.setdefault(n_conformations, []).append(collapsed_indices)

    # Register one (N_groups, n_conf) buffer per group size:
    # 'linked_occ_2', 'linked_occ_3', ...
    for n_conf, groups in linked_occupancies.items():
        tensor = torch.tensor(groups, dtype=torch.long, device=device)
        self.register_buffer(f"linked_occ_{n_conf}", tensor)

    # Store which sizes we have
    self.linked_occ_sizes = sorted(linked_occupancies.keys())

    # counts[i] = number of atoms that map to collapsed index i
    counts = torch.zeros(self._collapsed_shape, dtype=torch.long, device=device)
    counts.scatter_add_(0, expansion_mask, torch.ones_like(expansion_mask))
    self.register_buffer("collapse_counts", counts)


def _collapse_values_vectorized(self, full_values: torch.Tensor) -> torch.Tensor:
    """
    Collapse full tensor to collapsed storage (mean per collapsed index).

    Parameters
    ----------
    full_values : torch.Tensor
        Tensor in full space (one value per atom).

    Returns
    -------
    torch.Tensor
        Tensor in collapsed space, with the same dtype as ``full_values``.
        Indices with no atoms get 0.
    """
    collapsed_sum = torch.zeros(
        self._collapsed_shape, dtype=full_values.dtype, device=full_values.device
    )
    collapsed_sum.scatter_add_(0, self.expansion_mask, full_values)

    # Divide in the values' own dtype (a float32 divisor would promote
    # half-precision inputs); clamp avoids division by zero for empty slots.
    counts = self.collapse_counts.to(dtype=collapsed_sum.dtype).clamp(min=1)
    return collapsed_sum / counts


def _collapse_mask_vectorized(self, full_mask: torch.Tensor) -> torch.Tensor:
    """
    Collapse a mask to collapsed storage: True if ANY mapped atom is True.

    Parameters
    ----------
    full_mask : torch.Tensor
        Boolean mask in full space.

    Returns
    -------
    torch.Tensor
        Boolean mask in collapsed space.
    """
    collapsed_sum = torch.zeros(
        self._collapsed_shape, dtype=torch.float, device=full_mask.device
    )
    collapsed_sum.scatter_add_(0, self.expansion_mask, full_mask.float())
    # A positive sum means at least one atom in that collapsed slot was True.
    return collapsed_sum > 0


def _collapse_values(self, full_values: torch.Tensor) -> torch.Tensor:
    """
    Legacy collapse function - redirects to vectorized version.

    Kept for backward compatibility.
    """
    return self._collapse_values_vectorized(full_values)


def _collapse_mask(self, full_mask: torch.Tensor) -> torch.Tensor:
    """
    Legacy collapse mask function - redirects to vectorized version.

    Kept for backward compatibility.
    """
    return self._collapse_mask_vectorized(full_mask)


def _expand_values(self, collapsed_values: torch.Tensor) -> torch.Tensor:
    """
    Expand collapsed storage to full tensor using expansion mask.

    Parameters
    ----------
    collapsed_values : torch.Tensor
        Tensor in collapsed space.

    Returns
    -------
    torch.Tensor
        Tensor in full space (one value per atom).
    """
    return collapsed_values[self.expansion_mask]
def forward(self) -> torch.Tensor:
    """
    Reconstruct full occupancy tensor with sigmoid and altloc constraints.

    The collapsed logits (fixed + refinable) are assembled with the same
    custom autograd op MixedTensor uses (cheap ``index_select`` backward
    instead of index_put's sort-and-scatter). Then, if enabled, the sigmoid
    is applied. Each altloc group is normalized to sum to 1, and the result
    is expanded back to one value per atom.

    Returns
    -------
    torch.Tensor
        Full occupancy tensor of shape (n_atoms,), with values in [0, 1]
        when ``use_sigmoid`` is True.
    """
    # Collapsed logits; integer indices avoid a boolean-indexing GPU sync.
    if self._has_refinable and self.refinable_params.numel() > 0:
        logits = _AssembleMixedTensor.apply(
            self.refinable_params, self.fixed_values, self._refinable_indices
        )
    else:
        # Fresh copy so callers can never mutate the fixed_values buffer.
        logits = self.fixed_values.clone()

    # `logits` is already a fresh tensor and is never modified in place, so
    # no extra clone is needed on the non-sigmoid path.
    collapsed_occs = torch.sigmoid(logits) if self.use_sigmoid else logits

    # Normalize within each altloc group, one pass per group size
    # (2-way, 3-way, ...). Group members are always read from the
    # un-normalized occupancies.
    linked_sizes = getattr(self, "linked_occ_sizes", None) or []
    if linked_sizes:
        updated_occs = collapsed_occs
        for n_conf in linked_sizes:
            # (N_groups, n_conf) collapsed indices
            linked_indices = getattr(self, f"linked_occ_{n_conf}")
            linked_occs = collapsed_occs[linked_indices]
            sums = linked_occs.sum(dim=1, keepdim=True).clamp(min=1e-10)
            normalized_occs = linked_occs / sums
            # Out-of-place scatter back: autograd-friendly, no in-place write.
            updated_occs = updated_occs.index_copy(
                0, linked_indices.flatten(), normalized_occs.flatten()
            )
        collapsed_occs = updated_occs

    return self._expand_values(collapsed_occs).contiguous()
def _set_values(self, key, value: torch.Tensor) -> None: """ Internal method to set occupancy values with sigmoid reparameterization. Values are provided in NORMAL space (occupancies in [0, 1]) and automatically converted to logit-space for internal storage. Parameters ---------- key : indexing key Index specification (in full atom space). value : torch.Tensor Occupancy values to assign. Must be in [0, 1]. Raises ------ ValueError If any values are not in [0, 1]. Notes ----- This operates in FULL atom space. Values are collapsed to internal storage using the sharing groups. For atoms that share occupancies, the value is averaged across the group. """ # Validate values are in valid range if self.use_sigmoid: if (value < 0).any() or (value > 1).any(): raise ValueError("Occupancy values must be in range [0, 1]") # Get current full occupancies current_full = self.forward().detach() # Update in full space current_full[key] = value # Convert to logit space if self.use_sigmoid: clamped_values = torch.clamp(current_full, min=1e-6, max=1 - 1e-6) logit_values = torch.logit(clamped_values) else: logit_values = current_full.clone() # Collapse to internal storage collapsed_logits = self._collapse_values_vectorized(logit_values) # Update fixed_values buffer self.fixed_values = collapsed_logits.clone() # Re-extract refinable parameters if self.refinable_mask.any(): new_refinable = collapsed_logits[self.refinable_mask].clone() self.refinable_params = nn.Parameter( new_refinable, requires_grad=self.refinable_params.requires_grad ) @property def shape(self): """Return the shape of the FULL tensor (not collapsed).""" return (self._full_shape,) @property def collapsed_shape(self): """Return the shape of the collapsed internal storage.""" return (self._collapsed_shape,)
def clamp(
    self, min_value: float = 0.0, max_value: float = 1.0
) -> "OccupancyTensor":
    """
    Clamp occupancy values to a range and return a new OccupancyTensor.

    The current full-space occupancies (``forward()``, detached) are
    clamped. They are then used to build a fresh OccupancyTensor with the
    same sharing groups, refinable atoms, ``requires_grad``, dtype, device,
    name and ``use_sigmoid`` setting. Linked altloc groups are rebuilt from
    the ``linked_occ_<n>`` buffers and passed on too, exactly as ``copy()``
    does, so the sum-to-1 constraint is not lost.

    Parameters
    ----------
    min_value : float, optional
        Minimum occupancy value. Default is 0.0.
    max_value : float, optional
        Maximum occupancy value. Default is 1.0.

    Returns
    -------
    OccupancyTensor
        New OccupancyTensor with clamped values.

    Notes
    -----
    The new tensor's forward() re-normalizes linked groups, so clamped
    occupancies of linked atoms may be rescaled. This assumes the
    constructor rebuilds the linked buffers from ``altloc_groups``, which
    ``copy()`` also relies on.
    """
    current_occ = self.forward().detach()
    clamped_occ = torch.clamp(current_occ, min=min_value, max=max_value)

    # Refinable mask in full atom space.
    full_refinable_mask = self._expand_values(self.refinable_mask.float()).bool()

    # Rebuild altloc groups (one atom-index list per conformation) from the
    # linked_occ_<n> buffers; previously these were silently dropped here.
    altloc_groups = []
    if hasattr(self, "linked_occ_sizes"):
        for n_conf in self.linked_occ_sizes:
            linked_indices = getattr(self, f"linked_occ_{n_conf}")  # (N_groups, n_conf)
            for group_collapsed_indices in linked_indices:
                altloc_groups.append(
                    tuple(
                        (self.expansion_mask == collapsed_idx)
                        .nonzero(as_tuple=False)
                        .squeeze(-1)
                        .tolist()
                        for collapsed_idx in group_collapsed_indices
                    )
                )

    return OccupancyTensor(
        initial_values=clamped_occ,
        sharing_groups=self.expansion_mask.clone(),
        altloc_groups=altloc_groups if altloc_groups else None,
        refinable_mask=full_refinable_mask,
        requires_grad=self.refinable_params.requires_grad,
        dtype=self.dtype,
        device=self.device,
        name=self.name,
        use_sigmoid=self.use_sigmoid,
    )
def set_group_occupancy(self, group_idx: int, value: float):
    """
    Set the occupancy for all atoms in a specific collapsed group.

    Parameters
    ----------
    group_idx : int
        Collapsed index of the group.
    value : float
        Occupancy value to set (must be in [0, 1]).

    Raises
    ------
    ValueError
        If group_idx is out of range or value is not in [0, 1] (NaN
        included).

    Notes
    -----
    With ``use_sigmoid`` the value is stored as
    ``logit(clip(value, 1e-6, 1 - 1e-6))``. Without it the raw value is
    stored, matching how ``_set_values`` treats the two modes. If the group
    is refinable, ``refinable_params`` is updated in place.
    """
    import math

    if group_idx < 0 or group_idx >= self._collapsed_shape:
        raise ValueError(f"Invalid group index {group_idx}")
    # Chained comparison is False for NaN, so NaN is rejected as well.
    if not (0 <= value <= 1):
        raise ValueError(f"Occupancy value must be in [0, 1], got {value}")

    if self.use_sigmoid:
        clamped_value = min(max(float(value), 1e-6), 1 - 1e-6)
        stored_value = math.log(clamped_value / (1 - clamped_value))
    else:
        # Without the sigmoid reparameterization storage is in occupancy space.
        stored_value = float(value)
    stored_tensor = torch.tensor(stored_value, dtype=self.dtype, device=self.device)

    # Current collapsed storage (fixed + refinable), then overwrite the group.
    result = self.fixed_values.clone()
    result[self.refinable_mask] = self.refinable_params.data
    result[group_idx] = stored_tensor

    self.fixed_values = result.detach()
    if self.refinable_mask[group_idx]:
        # In-place so optimizers holding refinable_params keep tracking it.
        self.refinable_params.data = result[self.refinable_mask].clone()
def get_group_occupancy(self, group_idx: int) -> float:
    """
    Return the current occupancy of one collapsed group as a Python float.

    Parameters
    ----------
    group_idx : int
        Collapsed index of the group.

    Returns
    -------
    float
        Occupancy of the first atom whose ``expansion_mask`` entry equals
        ``group_idx``.

    Raises
    ------
    ValueError
        If group_idx is outside ``[0, collapsed size)``.
    """
    if group_idx < 0 or group_idx >= self._collapsed_shape:
        raise ValueError(f"Invalid group index {group_idx}")

    full_occ = self.forward()
    # All atoms of a group are expanded from one collapsed slot, so the
    # first member is representative.
    members = torch.nonzero(self.expansion_mask == group_idx, as_tuple=True)[0]
    first_member = members[0].item()
    return full_occ[first_member].item()
def freeze(self, mask: Optional[torch.Tensor] = None):
    """
    Make occupancy groups non-refinable.

    The mask is given in FULL atom space and collapsed with
    ``_collapse_mask_vectorized``. Every collapsed slot that receives at
    least one selected atom is frozen, because all atoms of a slot share one
    parameter. Current refinable values are first written back into
    ``fixed_values``, so no value is lost.

    Parameters
    ----------
    mask : torch.Tensor, optional
        Boolean mask of shape (n_atoms,) selecting atoms to freeze.
        ``None`` freezes everything.

    Raises
    ------
    ValueError
        If ``mask.shape[0]`` differs from the number of atoms.

    Notes
    -----
    Slots that stay refinable get a fresh ``nn.Parameter`` that keeps the
    previous ``requires_grad`` flag. If nothing stays refinable, an empty
    Parameter with ``requires_grad=False`` is installed. The index cache is
    rebuilt at the end.

    Examples
    --------
    ::

        freeze_mask = torch.zeros(n_atoms, dtype=torch.bool)
        freeze_mask[0:11] = True
        occ.freeze(freeze_mask)   # freeze atoms 0-10
        occ.freeze()              # freeze all atoms
    """
    if mask is not None:
        if mask.shape[0] != self._full_shape:
            raise ValueError(
                f"Freeze mask must have shape ({self._full_shape},) to match full atom space, "
                f"got shape {mask.shape}"
            )
        atoms_to_freeze = mask.to(device=self.device, dtype=torch.bool)
    else:
        atoms_to_freeze = torch.ones(self._full_shape, dtype=torch.bool, device=self.device)

    frozen_slots = self._collapse_mask_vectorized(atoms_to_freeze)

    # Snapshot the full collapsed state before the parameter is replaced.
    snapshot = self.fixed_values.clone()
    snapshot[self.refinable_mask] = self.refinable_params.data

    still_refinable = self.refinable_mask & ~frozen_slots
    self.fixed_values = snapshot.clone().detach()

    if not still_refinable.any():
        self.refinable_params = nn.Parameter(
            torch.empty(0, dtype=self.dtype, device=self.device),
            requires_grad=False,
        )
    else:
        self.refinable_params = nn.Parameter(
            snapshot[still_refinable].clone().detach(),
            requires_grad=self.refinable_params.requires_grad,
        )

    self.refinable_mask = still_refinable
    self.fixed_mask = ~still_refinable
    self._build_index_cache()
def unfreeze(self, mask: Optional[torch.Tensor] = None):
    """
    Make occupancy groups refinable.

    The mask is given in FULL atom space and collapsed with
    ``_collapse_mask_vectorized``. Every collapsed slot that receives at
    least one selected atom becomes refinable, together with every slot
    that already was. Current values are snapshotted into ``fixed_values``
    first.

    Parameters
    ----------
    mask : torch.Tensor, optional
        Boolean mask of shape (n_atoms,) selecting atoms to unfreeze.
        ``None`` unfreezes everything.

    Raises
    ------
    ValueError
        If ``mask.shape[0]`` differs from the number of atoms.

    Notes
    -----
    The new ``nn.Parameter`` always has ``requires_grad=True``. If nothing
    is refinable, an empty gradient-free Parameter is installed instead. The
    index cache is rebuilt at the end.

    Examples
    --------
    ::

        unfreeze_mask = torch.zeros(n_atoms, dtype=torch.bool)
        unfreeze_mask[100:201] = True
        occ.unfreeze(unfreeze_mask)   # unfreeze atoms 100-200
        occ.unfreeze()                # unfreeze all atoms
    """
    if mask is not None:
        if mask.shape[0] != self._full_shape:
            raise ValueError(
                f"Unfreeze mask must have shape ({self._full_shape},) to match full atom space, "
                f"got shape {mask.shape}"
            )
        atoms_to_release = mask.to(device=self.device, dtype=torch.bool)
    else:
        atoms_to_release = torch.ones(self._full_shape, dtype=torch.bool, device=self.device)

    released_slots = self._collapse_mask_vectorized(atoms_to_release)

    # Snapshot the full collapsed state before the parameter is replaced.
    snapshot = self.fixed_values.clone()
    if self.refinable_mask.any():
        snapshot[self.refinable_mask] = self.refinable_params.data

    now_refinable = self.refinable_mask | released_slots
    self.fixed_values = snapshot.clone().detach()

    if not now_refinable.any():
        self.refinable_params = nn.Parameter(
            torch.empty(0, dtype=self.dtype, device=self.device),
            requires_grad=False,
        )
    else:
        self.refinable_params = nn.Parameter(
            snapshot[now_refinable].clone().detach(),
            requires_grad=True,  # unfrozen parameters always get gradients
        )

    self.refinable_mask = now_refinable
    self.fixed_mask = ~now_refinable
    self._build_index_cache()
def freeze_all(self):
    """
    Freeze every occupancy parameter.

    Shorthand for ``freeze`` with no mask.
    """
    self.freeze(mask=None)
def unfreeze_all(self):
    """
    Unfreeze every occupancy parameter.

    Shorthand for ``unfreeze`` with no mask.
    """
    self.unfreeze(mask=None)
def get_refinable_atoms(self) -> torch.Tensor:
    """
    Expand the collapsed refinable mask to FULL atom space.

    Returns
    -------
    torch.Tensor
        Boolean tensor with one entry per atom. An entry is True when the
        collapsed slot the atom maps to is refinable (the slot may be shared
        with other atoms).
    """
    slot_flags = self.refinable_mask.float()
    return self._expand_values(slot_flags) != 0
def get_frozen_atoms(self) -> torch.Tensor:
    """
    Expand the collapsed fixed mask to FULL atom space.

    Returns
    -------
    torch.Tensor
        Boolean tensor with one entry per atom. An entry is True when the
        collapsed slot the atom maps to is frozen.
    """
    slot_flags = self.fixed_mask.float()
    return self._expand_values(slot_flags) != 0
def get_refinable_count(self) -> int:
    """
    Count refinable slots in COMPRESSED space.

    This counts collapsed groups, not atoms; use
    ``get_refinable_atoms().sum()`` for an atom count.

    Returns
    -------
    int
        Number of True entries in ``refinable_mask``.
    """
    n_refinable = torch.sum(self.refinable_mask)
    return n_refinable.item()
def get_fixed_count(self) -> int:
    """
    Count fixed slots in COMPRESSED space.

    This counts collapsed groups, not atoms; use
    ``get_frozen_atoms().sum()`` for an atom count.

    Returns
    -------
    int
        Number of True entries in ``fixed_mask``.
    """
    n_fixed = torch.sum(self.fixed_mask)
    return n_fixed.item()
def update_refinable_mask(
    self, new_mask: torch.Tensor, in_compressed_space: bool = False
):
    """
    Replace the refinable mask outright.

    Unlike freeze/unfreeze, which modify the existing mask, this sets the
    refinable set to exactly ``new_mask``. Current values are snapshotted
    into ``fixed_values`` first, then ``refinable_params`` is rebuilt from
    them.

    Parameters
    ----------
    new_mask : torch.Tensor
        Boolean tensor indicating which parameters should be refinable.
        If in_compressed_space=False: shape (n_atoms,) in full atom space.
        If in_compressed_space=True: shape (n_groups,) in compressed space.
    in_compressed_space : bool, optional
        If True, new_mask is already in compressed space. If False
        (default), it is in full atom space and is collapsed with
        ``_collapse_mask_vectorized``.

    Raises
    ------
    ValueError
        If ``new_mask.shape[0]`` does not match the chosen space.

    Notes
    -----
    The stored mask is always a private copy. Later in-place edits to the
    caller's ``new_mask`` therefore cannot desynchronize ``refinable_mask``
    from ``refinable_params``. The new Parameter has ``requires_grad=True``
    (or is empty and gradient-free when nothing is refinable).

    Examples
    --------
    Full atom space::

        atom_mask = torch.zeros(n_atoms, dtype=torch.bool)
        atom_mask[:100] = True
        occ.update_refinable_mask(atom_mask, in_compressed_space=False)

    Compressed space::

        group_mask = torch.zeros(n_groups, dtype=torch.bool)
        group_mask[::2] = True
        occ.update_refinable_mask(group_mask, in_compressed_space=True)
    """
    if not in_compressed_space:
        if new_mask.shape[0] != self._full_shape:
            raise ValueError(
                f"Mask in full atom space must have shape ({self._full_shape},), "
                f"got shape {new_mask.shape}"
            )
        new_mask = new_mask.to(device=self.device, dtype=torch.bool)
        collapsed_mask = self._collapse_mask_vectorized(new_mask)
    else:
        if new_mask.shape[0] != self._collapsed_shape:
            raise ValueError(
                f"Mask in compressed space must have shape ({self._collapsed_shape},), "
                f"got shape {new_mask.shape}"
            )
        # .to() returns the caller's tensor itself when dtype/device already
        # match; clone so the stored mask is never aliased.
        collapsed_mask = new_mask.to(device=self.device, dtype=torch.bool).clone()

    # Snapshot the full collapsed state before the parameter is replaced.
    current_logits = self.fixed_values.clone()
    if self.refinable_mask.any():
        current_logits[self.refinable_mask] = self.refinable_params.data
    self.fixed_values = current_logits.clone().detach()

    if collapsed_mask.any():
        new_refinable_values = current_logits[collapsed_mask].clone().detach()
        self.refinable_params = nn.Parameter(new_refinable_values, requires_grad=True)
    else:
        self.refinable_params = nn.Parameter(
            torch.empty(0, dtype=self.dtype, device=self.device),
            requires_grad=False,
        )

    self.refinable_mask = collapsed_mask
    self.fixed_mask = ~collapsed_mask
    self._build_index_cache()
@staticmethod
def from_residue_groups(
    initial_values: torch.Tensor,
    pdb_dataframe,
    refinable_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> "OccupancyTensor":
    """
    Create an OccupancyTensor where all atoms of a residue share occupancy.

    Rows are grouped by (resname, resseq, chainid, altloc). Each group with
    more than one atom becomes one shared collapsed slot; every other atom
    keeps a slot of its own.

    Parameters
    ----------
    initial_values : torch.Tensor
        Initial occupancy values for all atoms. Its length sets the number
        of atoms.
    pdb_dataframe : pandas.DataFrame
        Must contain the grouping columns 'resname', 'resseq', 'chainid'
        and 'altloc', plus an 'index' column giving each row's atom position
        in ``initial_values``.
    refinable_mask : torch.Tensor, optional
        Mask for refinable atoms.
    **kwargs
        Additional arguments passed to the OccupancyTensor constructor.

    Returns
    -------
    OccupancyTensor
        OccupancyTensor with residue-based sharing groups (name "occupancy").

    Notes
    -----
    Collapsed slots are numbered with the shared residue groups first (in
    groupby order), followed by unshared atoms in ascending atom order.
    ``DataFrame.groupby`` drops rows whose key contains NaN (default
    ``dropna=True``), so such atoms stay unshared.
    """
    grouped = pdb_dataframe.groupby(["resname", "resseq", "chainid", "altloc"])
    n_atoms = len(initial_values)

    # -1 marks atoms not assigned to a shared group. Unshared atoms are only
    # numbered after all shared groups are known: seeding them with their
    # own atom index (an arange) would collide with the 0-based shared-group
    # ids and silently merge unrelated atoms.
    sharing_groups_tensor = torch.full((n_atoms,), -1, dtype=torch.long)
    n_shared = 0
    for _, group in grouped:
        indices = group["index"].tolist()
        if len(indices) > 1:
            sharing_groups_tensor[indices] = n_shared
            n_shared += 1

    unshared = sharing_groups_tensor < 0
    n_unshared = int(unshared.sum())
    sharing_groups_tensor[unshared] = torch.arange(
        n_shared, n_shared + n_unshared, dtype=torch.long
    )

    # Relabel to a dense, order-preserving 0..K-1 range. This is a no-op
    # unless an 'index' value appears in more than one group.
    _, sharing_groups_tensor = torch.unique(
        sharing_groups_tensor, sorted=True, return_inverse=True
    )

    return OccupancyTensor(
        initial_values=initial_values,
        sharing_groups=sharing_groups_tensor,
        refinable_mask=refinable_mask,
        name="occupancy",
        **kwargs,
    )
def copy(self) -> "OccupancyTensor":
    """
    Create an independent OccupancyTensor with the same configuration.

    The copy is built from the current full-space occupancies
    (``forward()``, detached). It carries over:

    - the sharing groups (``expansion_mask``);
    - the refinable atoms (expanded from ``refinable_mask``);
    - the linked altloc groups, reconstructed from the ``linked_occ_<n>``
      buffers;
    - ``requires_grad``, dtype, device, name and ``use_sigmoid``.

    Returns
    -------
    OccupancyTensor
        New OccupancyTensor instance with copied data.
    """
    occ_now = self.forward().detach()
    refinable_atoms = self._expand_values(self.refinable_mask.float()).bool()

    # Each altloc group is a tuple with one atom-index list per conformation.
    altloc_groups = []
    if hasattr(self, "linked_occ_sizes"):
        for n_conf in self.linked_occ_sizes:
            for slot_row in getattr(self, f"linked_occ_{n_conf}"):
                altloc_groups.append(
                    tuple(
                        (self.expansion_mask == slot)
                        .nonzero(as_tuple=False)
                        .squeeze(-1)
                        .tolist()
                        for slot in slot_row
                    )
                )

    return OccupancyTensor(
        initial_values=occ_now,
        sharing_groups=self.expansion_mask.clone(),
        altloc_groups=altloc_groups or None,
        refinable_mask=refinable_atoms,
        requires_grad=self.refinable_params.requires_grad,
        dtype=self.dtype,
        device=self.device,
        name=self._name,
        use_sigmoid=self.use_sigmoid,
    )
def __repr__(self) -> str: name_str = f"'{self.name}', " if self.name is not None else "" n_groups = self._collapsed_shape return ( f"OccupancyTensor({name_str}shape={self.shape}, dtype={self.dtype}, " f"device={self.device}, refinable={self.get_refinable_count()}, " f"fixed={self.get_fixed_count()}, collapsed_groups={n_groups}, " f"use_sigmoid={self.use_sigmoid})" )
[docs] class PassThroughTensor(DeviceMixin, nn.Module): """ A simple parameter wrapper that passes the parameter through unchanged. Useful as a placeholder or for parameters that do not require any special handling. Parameters ---------- initial_values : torch.Tensor Initial tensor values. requires_grad : bool, optional Whether the parameter requires gradients. Default is True. dtype : torch.dtype, optional Data type of the tensor. device : torch.device, optional Device to place the tensor on. name : str, optional Optional name for the parameter. """
def __init__(
    self,
    initial_values: torch.Tensor,
    requires_grad: bool = True,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    name: Optional[str] = None,
):
    """
    Initialize the PassThroughTensor.

    Parameters
    ----------
    initial_values : torch.Tensor
        Initial tensor values (anything accepted by ``torch.as_tensor``).
        The values are copied, so the caller's tensor is never aliased.
    requires_grad : bool, optional
        Whether the parameter requires gradients. Default is True.
    dtype : torch.dtype, optional
        Data type of the tensor. Default keeps the input dtype.
    device : torch.device, optional
        Device to place the tensor on. Default keeps the input device.
    name : str, optional
        Optional name for the parameter.
    """
    # nn.Module.__init__ rejects unknown keyword arguments, so the wrapper's
    # own arguments are consumed here rather than forwarded up the MRO.
    # NOTE(review): assumes DeviceMixin.__init__ needs no arguments - confirm.
    super().__init__()
    values = torch.as_tensor(initial_values, dtype=dtype, device=device)
    # forward() returns self.param, so it must be created here.
    self.param = nn.Parameter(values.detach().clone(), requires_grad=requires_grad)
    self.name = name
def forward(self) -> torch.Tensor:
    """
    Return the wrapped parameter as-is, with no transformation applied.

    Returns
    -------
    torch.Tensor
        ``self.param`` itself (same object, not a copy).
    """
    wrapped = self.param
    return wrapped