Source code for torchref.utils.utils

from pathlib import Path
from typing import Any, Dict, List, Optional, Union, Tuple

import gemmi
import numpy as np
import pandas as pd
import torch
import json

from torchref.utils.device_mixin import DeviceMovementMixin


class ModuleReference:
    """
    A wrapper class to hold references to PyTorch modules without registering them.

    When you assign a nn.Module to an attribute of another nn.Module, PyTorch
    automatically registers it as a submodule, which adds its parameters to the
    parent's parameter tree. Because this wrapper is not itself an nn.Module,
    it prevents that automatic registration.

    This is useful when you want to:

    - Hold references to modules without including their parameters
    - Avoid circular dependencies in the module tree
    - Reference external modules that should be managed separately

    Attribute lookups the wrapper cannot satisfy itself, and calls, are
    forwarded to the wrapped module.

    Attributes
    ----------
    _wrapped_module : torch.nn.Module
        The wrapped PyTorch module.

    Examples
    --------
    ::

        model = MyModel()
        scaler = Scaler()
        scaler._model = ModuleReference(model)  # Won't register as submodule

        # Access the module via .module property
        output = scaler._model.module(input_data)
    """

    def __init__(self, module):
        """
        Wrap a module to prevent automatic registration.

        Parameters
        ----------
        module : torch.nn.Module
            The PyTorch module to wrap.
        """
        # Store in __dict__ directly to avoid any attribute interception
        object.__setattr__(self, "_wrapped_module", module)

    @property
    def module(self):
        """Access the wrapped module."""
        return object.__getattribute__(self, "_wrapped_module")

    def __getattr__(self, name):
        """
        Forward attribute access to the wrapped module.

        Python only calls this for names not found on the wrapper itself.
        The wrapped module is read directly rather than through the
        ``module`` property. An instance created via ``__new__`` (as
        ``copy.copy``/``copy.deepcopy`` do before restoring state) has no
        ``_wrapped_module`` yet. Going through the property would then
        re-enter ``__getattr__("module")`` and recurse until
        ``RecursionError``; here it raises a plain ``AttributeError``
        instead.

        Raises
        ------
        AttributeError
            If no module is wrapped yet, or the wrapped module lacks ``name``.
        """
        try:
            wrapped = object.__getattribute__(self, "_wrapped_module")
        except AttributeError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(wrapped, name)

    def __call__(self, *args, **kwargs):
        """Forward calls to the wrapped module."""
        return self.module(*args, **kwargs)

    def __repr__(self):
        return f"ModuleReference({self.module.__class__.__name__})"
class CIFReader:
    """
    A dictionary-like reader for CIF/mmCIF files.

    Loops are stored as pandas DataFrames. Other data is stored in a
    hierarchical dictionary structure.

    A loop is stored under its category (``'_atom_site.id'`` ->
    ``'atom_site'``). The full tag names are kept as DataFrame columns.
    A single data item ``_cat.attr value`` is stored as
    ``data['cat']['attr'] = value``. A tag without a ``.`` is stored directly
    as ``data[tag] = value``. All parsed values are strings.

    Attributes
    ----------
    data : dict
        Dictionary storing parsed CIF data.
    filepath : pathlib.Path or None
        Path to the loaded CIF file.
    """

    # Line prefixes that start a new loop / data block / save frame.
    _BLOCK_KEYWORDS = ("loop_", "data_", "save_")
    # CIF reserved words: a value starting with one of these must be quoted.
    _RESERVED_PREFIXES = ("data_", "save_", "loop_", "stop_", "global_")
    # Characters that may not start an unquoted CIF value.
    _RESERVED_FIRST_CHARS = "_#$;[]"

    def __init__(self, filepath: Optional[str] = None):
        """
        Initialize CIF reader.

        Parameters
        ----------
        filepath : str, optional
            Path to CIF file to load immediately.
        """
        self.data = {}
        self.filepath = None
        if filepath:
            self.load(filepath)

    def load(self, filepath: str):
        """
        Load and parse a CIF file.

        Parameters
        ----------
        filepath : str
            Path to CIF file.
        """
        self.filepath = Path(filepath)
        with open(filepath, "r", encoding="utf-8", errors="ignore") as f:
            content = f.read()
        self._parse(content)

    def _parse(self, content: str):
        """
        Parse CIF file content.

        Parameters
        ----------
        content : str
            String content of CIF file.
        """
        lines = content.split("\n")
        i = 0
        while i < len(lines):
            line = lines[i].strip()

            # Skip empty lines and comments
            if not line or line.startswith("#"):
                i += 1
                continue

            # Check for data block (usually just one in mmCIF)
            if line.startswith("data_"):
                i += 1
                continue

            # Check for loop
            if line.startswith("loop_"):
                i = self._parse_loop(lines, i + 1)
                continue

            # Parse single key-value pairs
            if line.startswith("_"):
                i = self._parse_keyvalue(lines, i)
                continue

            i += 1

    def _parse_loop(self, lines: List[str], start_idx: int) -> int:
        """
        Parse a loop structure into a pandas DataFrame.

        The header is the run of ``_tag`` lines after ``loop_``. Values are
        then consumed token by token, where a ``;`` text field counts as one
        token, and grouped into rows of ``len(columns)`` values. Blank and
        comment lines inside the body are skipped. The body ends at the next
        line starting with ``_``, ``loop_``, ``data_`` or ``save_``.

        Parameters
        ----------
        lines : list of str
            All lines of the file.
        start_idx : int
            Starting line index (after 'loop_').

        Returns
        -------
        int
            Index of the next line to process.
        """
        # Collect column names (tag lines), skipping blanks and comments.
        columns: List[str] = []
        i = start_idx
        while i < len(lines):
            stripped = lines[i].strip()
            if not stripped or stripped.startswith("#"):
                i += 1
                continue
            if not stripped.startswith("_"):
                break
            columns.append(stripped)
            i += 1

        if not columns:
            return i

        n_columns = len(columns)
        data_rows: List[List[str]] = []
        current_row: List[str] = []
        text_lines: Optional[List[str]] = None  # lines of an open ';' text field

        def add_value(value: str) -> None:
            """Append one value and close the row once it has every column."""
            current_row.append(value)
            if len(current_row) == n_columns:
                data_rows.append(list(current_row))
                current_row.clear()

        while i < len(lines):
            line = lines[i]

            if text_lines is not None:
                if line.startswith(";"):
                    # Closing delimiter: the whole text field is one value.
                    add_value("\n".join(text_lines))
                    text_lines = None
                else:
                    text_lines.append(line)
                i += 1
                continue

            stripped = line.strip()
            if stripped.startswith(("_",) + self._BLOCK_KEYWORDS):
                break
            if line.startswith(";"):
                # The text field starts right after the ';' on this line.
                text_lines = [line[1:]]
            elif stripped and not stripped.startswith("#"):
                # Blank lines are plain whitespace in CIF and do not end the loop.
                for token in self._tokenize_line(stripped):
                    add_value(token)
            i += 1

        if current_row:
            # Keep a trailing incomplete row; pandas fills its missing cells with nulls.
            data_rows.append(current_row)

        # Store under the category of the first column (e.g. _atom_site.id -> atom_site)
        if data_rows:
            category = self._extract_category(columns[0])
            if category:
                self.data[category] = pd.DataFrame(data_rows, columns=columns)

        return i

    def _parse_keyvalue(self, lines: List[str], start_idx: int) -> int:
        """
        Parse a single key-value pair.

        Three layouts are handled:

        - the value on the tag line (``_cat.attr value``; one pair of matching
          outer quotes is removed);
        - a ``;``-delimited text field starting on the next line (text after
          the opening ``;`` is kept when non-blank);
        - a single token on the next line (mmCIF commonly puts long quoted
          values such as ``_struct.title`` there).

        Parameters
        ----------
        lines : list of str
            All lines of the file.
        start_idx : int
            Starting line index.

        Returns
        -------
        int
            Index of the next line to process.
        """
        line = lines[start_idx].strip()
        parts = line.split(None, 1)
        key = parts[0]

        # Handle single line key-value
        if len(parts) == 2:
            value = parts[1]
            if len(value) >= 2 and value[0] == value[-1] and value[0] in ("'", '"'):
                value = value[1:-1]
            self._store_keyvalue(key, value)
            return start_idx + 1

        next_idx = start_idx + 1
        if next_idx >= len(lines):
            return next_idx
        next_line = lines[next_idx]

        # Handle multiline (text field) values
        if next_line.startswith(";"):
            first = next_line[1:]
            value_lines = [first] if first.strip() else []
            i = next_idx + 1
            while i < len(lines) and not lines[i].startswith(";"):
                value_lines.append(lines[i])
                i += 1
            self._store_keyvalue(key, "\n".join(value_lines))
            return i + 1

        # Handle a value placed alone on the following line
        stripped = next_line.strip()
        if stripped and not stripped.startswith(("#", "_") + self._BLOCK_KEYWORDS):
            tokens = self._tokenize_line(stripped)
            if tokens:
                self._store_keyvalue(key, tokens[0])
                return next_idx + 1

        return next_idx

    def _tokenize_line(self, line: str) -> List[str]:
        """
        Tokenize a data line following CIF 1.1 quoting rules.

        A quote character opens a quoted value only at the start of a token.
        It closes the value only when followed by whitespace or the end of the
        line. So ``C1'`` and ``'a dog's life'`` are single tokens, and ``''``
        yields an empty string. A ``#`` at the start of a token begins a
        comment that runs to the end of the line.

        Parameters
        ----------
        line : str
            Line to tokenize.

        Returns
        -------
        list of str
            List of tokens.
        """
        tokens: List[str] = []
        pos = 0
        length = len(line)
        while pos < length:
            char = line[pos]
            if char.isspace():
                pos += 1
                continue
            if char == "#":
                break
            if char in ("'", '"'):
                end = pos + 1
                while end < length and not (
                    line[end] == char and (end + 1 == length or line[end + 1].isspace())
                ):
                    end += 1
                tokens.append(line[pos + 1 : end])
                pos = end + 1
                continue
            end = pos
            while end < length and not line[end].isspace():
                end += 1
            tokens.append(line[pos:end])
            pos = end
        return tokens

    def _extract_category(self, key: str) -> str:
        """
        Extract category from a CIF key (e.g., '_atom_site.id' -> 'atom_site').

        Parameters
        ----------
        key : str
            CIF key.

        Returns
        -------
        str
            Category name.
        """
        if key.startswith("_"):
            key = key[1:]
        if "." in key:
            return key.split(".")[0]
        return key

    def _store_keyvalue(self, key: str, value: str):
        """
        Store a key-value pair in the hierarchical dictionary.

        If the category already holds a loop (DataFrame), the value is dropped.

        Parameters
        ----------
        key : str
            CIF key (e.g., '_entry.id').
        value : str
            Value to store.
        """
        # Extract category and attribute
        if key.startswith("_"):
            key = key[1:]
        if "." in key:
            category, attribute = key.split(".", 1)
            if category not in self.data:
                self.data[category] = {}
            if isinstance(self.data[category], dict):
                self.data[category][attribute] = value
        else:
            self.data[key] = value

    @classmethod
    def _format_value(cls, value: Any) -> str:
        """
        Format one value as a CIF token, quoting only when required.

        ``None``/NaN become ``?`` (CIF "unknown"), and empty strings become
        ``''``. Values that contain whitespace or quotes, start with a
        reserved character, or start with a reserved word are quoted: with
        ``'`` when the value has no ``'``, otherwise with ``"``. Values that
        contain a newline, or both quote characters, are returned as a ``;``
        text field. The text field starts with a newline so its ``;`` lands
        at the beginning of a line.
        """
        if value is None or (isinstance(value, float) and value != value):
            return "?"
        text = str(value)
        if "\n" in text:
            return f"\n;{text}\n;\n"
        if text == "":
            return "''"
        needs_quotes = (
            any(ch.isspace() for ch in text)
            or "'" in text
            or '"' in text
            or text[0] in cls._RESERVED_FIRST_CHARS
            or text.lower().startswith(cls._RESERVED_PREFIXES)
        )
        if not needs_quotes:
            return text
        if "'" not in text:
            return f"'{text}'"
        if '"' not in text:
            return f'"{text}"'
        return f"\n;{text}\n;\n"

    def write(self, filepath: str):
        """
        Write the CIF data back to a file.

        Key-value groups are written first, then loops. Both are in sorted
        category order. Values are quoted so that this class's parser reads
        them back unchanged.

        Parameters
        ----------
        filepath : str
            Output file path.
        """
        with open(filepath, "w", encoding="utf-8") as f:
            f.write("data_structure\n")
            f.write("#\n")

            # Write single key-value pairs first
            for category, content in sorted(self.data.items()):
                if not isinstance(content, dict):
                    continue
                for key, value in sorted(content.items()):
                    token = self._format_value(value)
                    if token.startswith("\n;"):
                        # Text field: tag on its own line, then ';' ... ';'
                        f.write(f"_{category}.{key}\n;\n{str(value)}\n;\n")
                    else:
                        f.write(f"_{category}.{key} {token}\n")
                f.write("#\n")

            # Write loops (DataFrames)
            for category, content in sorted(self.data.items()):
                if not isinstance(content, pd.DataFrame):
                    continue
                f.write("loop_\n")
                for col in content.columns:
                    f.write(f"{col}\n")
                # itertuples keeps each column's own dtype; iterrows would
                # upcast mixed int/float rows (writing 1 as 1.0).
                for row in content.itertuples(index=False, name=None):
                    row_text = " ".join(self._format_value(val) for val in row)
                    f.write(row_text if row_text.endswith("\n") else row_text + "\n")
                f.write("#\n")

    # Dictionary-like interface

    def __getitem__(self, key: str) -> Union[pd.DataFrame, Dict, Any]:
        """Get item by key."""
        return self.data[key]

    def __setitem__(self, key: str, value: Union[pd.DataFrame, Dict, Any]):
        """Set item by key."""
        self.data[key] = value

    def __contains__(self, key: str) -> bool:
        """Check if key exists."""
        return key in self.data

    def __len__(self) -> int:
        """Return number of top-level categories."""
        return len(self.data)

    def keys(self):
        """Return dictionary keys."""
        return self.data.keys()

    def values(self):
        """Return dictionary values."""
        return self.data.values()

    def items(self):
        """Return dictionary items."""
        return self.data.items()

    def get(self, key: str, default=None):
        """Get item with default value."""
        return self.data.get(key, default)

    def __repr__(self) -> str:
        """String representation."""
        categories = list(self.keys())
        loops = [k for k, v in self.items() if isinstance(v, pd.DataFrame)]
        dicts = [k for k, v in self.items() if isinstance(v, dict)]
        return (
            f"CIFReader(categories={len(categories)}, "
            f"loops={len(loops)}, "
            f"key-value_groups={len(dicts)})"
        )

    def summary(self):
        """Print a summary of the CIF contents."""
        print(f"CIF File: {self.filepath}")
        print(f"Total categories: {len(self.data)}")
        print("\nLoops (DataFrames):")
        for key, value in sorted(self.items()):
            if isinstance(value, pd.DataFrame):
                print(f"  {key}: {len(value)} rows × {len(value.columns)} columns")
        print("\nKey-Value Groups (Dictionaries):")
        for key, value in sorted(self.items()):
            if isinstance(value, dict):
                print(f"  {key}: {len(value)} items")
def save_map(array, cell, filename):
    """
    Save a 3D map to a CCP4 file in space group P1.

    Parameters
    ----------
    array : numpy.ndarray, torch.Tensor, or array-like
        3D array representing the map. Any array-like accepted by
        ``numpy.array`` works, e.g. nested lists. The data is copied and
        converted to float32.
    cell : list, tuple, numpy.ndarray, torch.Tensor, or gemmi.UnitCell
        Unit cell parameters [a, b, c, alpha, beta, gamma].
    filename : str or os.PathLike
        Output CCP4 file name.

    Returns
    -------
    bool
        True if save was successful.

    Raises
    ------
    ValueError
        If ``array`` is not 3-dimensional.
    """
    if isinstance(array, torch.Tensor):
        np_map = array.detach().cpu().numpy().astype(np.float32)
    else:
        # np.array always copies (like the former .astype call) and also
        # accepts nested lists and other array-likes.
        np_map = np.array(array, dtype=np.float32)

    if np_map.ndim != 3:
        raise ValueError(f"Map array must be 3-dimensional, got shape {np_map.shape}")

    # Normalise the cell to a plain list of six numbers.
    if isinstance(cell, gemmi.UnitCell):
        cell_params = cell.parameters
    elif isinstance(cell, (np.ndarray, torch.Tensor)):
        cell_params = cell.tolist()
    else:
        cell_params = list(cell)

    map_ccp = gemmi.Ccp4Map()
    map_ccp.grid = gemmi.FloatGrid(
        np_map, gemmi.UnitCell(*cell_params), gemmi.find_spacegroup_by_name("P1")
    )
    map_ccp.setup(0.0)
    map_ccp.update_ccp4_header()
    # str() lets pathlib.Path (and other os.PathLike) filenames through.
    map_ccp.write_ccp4_map(str(filename))
    print("Map saved successfully")
    return True
import torch.nn as nn
class TensorDict(nn.Module):
    """
    A dictionary-like container for PyTorch tensors that:

    - Supports standard dict syntax (``td[k]``, ``td[k] = v``, ``del td[k]``,
      ``k in td``, ``len(td)``, iteration over keys)
    - Automatically moves with the module
    - Registers tensors as buffers so they are included in state_dict

    Each entry ``key`` is stored as a buffer named ``_buf_<key>``. Insertion
    order is tracked in ``self._keys``.
    """

    def __init__(self, initial_dict: Optional[Dict[str, torch.Tensor]] = None):
        super().__init__()
        self._keys = []
        if initial_dict:
            for k, v in initial_dict.items():
                self[k] = v

    @staticmethod
    def _buffer_name(key: str) -> str:
        """Return the buffer attribute name used to store ``key``."""
        return f"_buf_{key}"

    def __setitem__(self, key: str, tensor: torch.Tensor):
        """
        Store ``tensor`` under ``key``.

        A new key registers a new buffer. An existing key is updated in place
        when shape and dtype both match, which keeps the buffer's identity and
        device. Otherwise the buffer is re-registered, moved to the device of
        the buffer it replaces. An in-place ``copy_`` would silently cast to
        the old dtype, e.g. truncating float64 data into an int buffer.
        """
        name = self._buffer_name(key)
        if not hasattr(self, name):
            # Register as buffer
            self.register_buffer(name, tensor)
            self._keys.append(key)
            return

        existing = getattr(self, name)
        if (
            existing is not None
            and tensor is not None
            and existing.shape == tensor.shape
            and existing.dtype == tensor.dtype
        ):
            # Update existing buffer in-place (same shape and dtype)
            existing.data.copy_(tensor)
            return

        if existing is not None and tensor is not None:
            tensor = tensor.to(existing.device)
        # Shape or dtype changed - re-register the buffer with the new tensor
        delattr(self, name)
        self.register_buffer(name, tensor)

    def __getitem__(self, key: str) -> torch.Tensor:
        name = self._buffer_name(key)
        if not hasattr(self, name):
            raise KeyError(key)
        return getattr(self, name)

    def __delitem__(self, key: str):
        """Remove ``key`` and its buffer; raise ``KeyError`` if absent."""
        if key not in self._keys:
            raise KeyError(key)
        delattr(self, self._buffer_name(key))
        self._keys.remove(key)

    def __contains__(self, key: str) -> bool:
        return key in self._keys

    def __iter__(self):
        """Iterate over keys in insertion order (over a snapshot of them)."""
        return iter(list(self._keys))

    def keys(self):
        return self._keys.copy()

    def values(self):
        return [getattr(self, self._buffer_name(k)) for k in self._keys]

    def items(self):
        return [(k, getattr(self, self._buffer_name(k))) for k in self._keys]

    def __len__(self):
        return len(self._keys)

    def __repr__(self):
        body = ", ".join(
            f"{k}: {getattr(self, self._buffer_name(k))}" for k in self._keys
        )
        return "TensorDict({" + body + "})"

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """
        Override to dynamically register buffers during loading.

        Any ``<prefix>_buf_<key>`` entry without a matching buffer gets a
        zero placeholder of the same shape/dtype, so the standard loader then
        copies the stored values into it.
        """
        buffer_prefix = self._buffer_name("")
        local_keys = [k for k in state_dict.keys() if k.startswith(prefix + buffer_prefix)]
        for key in local_keys:
            buffer_name = key[len(prefix) :]
            original_key = buffer_name[len(buffer_prefix) :]
            if not hasattr(self, buffer_name):
                tensor = state_dict[key]
                self.register_buffer(buffer_name, torch.zeros_like(tensor))
                self._keys.append(original_key)
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
class TensorMasks(DeviceMovementMixin, dict):
    """
    A dictionary for managing boolean mask tensors with device support.

    This is a lightweight dict subclass that:

    - Ensures all tensors are boolean dtype
    - Supports device movement via to(), cuda(), cpu()
    - Provides combined mask via __call__()

    Every mutating dict operation (``[]=``, ``del``, ``pop``, ``popitem``,
    ``clear``, ``update``, ``setdefault``) invalidates the cached combined
    mask. ``update`` and ``setdefault`` go through ``__setitem__``, so the
    boolean/non-empty/device checks apply to them too.

    Parameters
    ----------
    data : dict, optional
        Initial mask data.
    device : str or torch.device, optional
        Device for tensors. Defaults to the configured device.current.

    Examples
    --------
    ::

        masks = TensorMasks(device='cuda')
        masks['valid'] = torch.ones(100, dtype=torch.bool)
        masks['rfree'] = rfree_flags > 0
        combined = masks()  # Get combined mask (AND of all)
        masks.cpu()  # Move all to CPU
    """

    def __init__(self, data=None, device=None):
        super().__init__()
        if device is None:
            from torchref.config import get_default_device

            device = get_default_device()
        self.device = torch.device(device)
        self._cache = None
        self._updated = True

        # Initialize with provided data
        if data:
            for k, v in data.items():
                self[k] = v

    def __setitem__(self, key: str, tensor: torch.Tensor):
        """Set mask tensor, ensuring boolean dtype and correct device."""
        if tensor is not None:
            if tensor.dtype != torch.bool:
                raise ValueError(
                    f"Mask '{key}' must be boolean dtype, got {tensor.dtype}"
                )
            if tensor.sum() == 0:
                raise ValueError(f"Mask '{key}' cannot be all False, this would mask all data.")
            tensor = tensor.to(self.device)
        super().__setitem__(key, tensor)
        self._updated = True

    def __delitem__(self, key: str):
        """Remove a mask and invalidate the cached combined mask."""
        super().__delitem__(key)
        self.reset_cache()

    def pop(self, key, *default):
        """``dict.pop`` that also invalidates the cached combined mask."""
        value = super().pop(key, *default)
        self.reset_cache()
        return value

    def popitem(self):
        """``dict.popitem`` that also invalidates the cached combined mask."""
        item = super().popitem()
        self.reset_cache()
        return item

    def clear(self):
        """Remove all masks and invalidate the cached combined mask."""
        super().clear()
        self.reset_cache()

    def update(self, *args, **kwargs):
        """``dict.update`` routed through ``__setitem__`` (validation + cache reset)."""
        for key, value in dict(*args, **kwargs).items():
            self[key] = value

    def setdefault(self, key, default=None):
        """``dict.setdefault`` routed through ``__setitem__`` for new keys."""
        if key not in self:
            self[key] = default
        return self[key]

    def _apply(self, fn):
        """Move mask tensors stored as ``dict`` items and invalidate the cache.

        ``TensorMasks`` is a ``dict`` subclass, so its mask tensors live in the
        dict's own storage, **not** in ``self.__dict__``. The standard
        :class:`DeviceMixin` ``__dict__`` walk would otherwise miss them and
        only move the cached combined mask, leaving the per-key masks on the
        previous device.
        """
        # Walk the dict storage and move each mask tensor.
        for k in list(self.keys()):
            v = self[k]
            if isinstance(v, torch.Tensor):
                dict.__setitem__(self, k, fn(v))
        # Invalidate the combined-mask cache so the next call to ``self()``
        # recomputes from the moved masks rather than returning the stale
        # combined tensor.
        self._cache = None
        self._updated = True
        # Refresh the ``device`` tracker so future ``__setitem__`` calls
        # (which migrate incoming tensors to ``self.device``) land correctly.
        for v in self.values():
            if isinstance(v, torch.Tensor):
                self.device = v.device
                break
        return self

    def reset_cache(self) -> None:
        """Invalidate the cached combined mask."""
        self._cache = None
        self._updated = True

    def __call__(self) -> torch.Tensor:
        """
        Return combined mask (AND of all masks).

        Returns
        -------
        torch.Tensor
            Combined boolean mask, or None if no masks.
        """
        if not self:
            return None
        if self._updated or self._cache is None:
            self._cache = self._get_combined_mask()
            self._updated = False
        return self._cache

    def _get_combined_mask(self) -> torch.Tensor:
        """Compute combined mask using logical AND."""
        masks = [v for v in self.values() if v is not None]
        if not masks:
            return None
        combined = masks[0].clone()
        for m in masks[1:]:
            combined &= m
        return combined

    def __repr__(self):
        mask_info = ", ".join(
            f"'{k}': shape={v.shape}" for k, v in self.items() if v is not None
        )
        return f"TensorMasks({{{mask_info}}}, device={self.device})"
def _split_group_into_residues(group: pd.DataFrame) -> List[list]:
    """Split a group's rows (in file order) into per-residue lists of index labels.

    A new residue starts when the original ``resseq`` changes or when an
    atom ``(name, altloc)`` repeats within the current residue. This keeps
    e.g. water O/H1/H2 together even when every water has the same resseq.
    """
    residues: List[list] = []
    current: list = []
    seen = set()
    previous_resseq = None
    for label, resseq, name, altloc in zip(
        group.index, group["resseq"], group["name"], group["altloc"]
    ):
        # Normalise missing altlocs so NaN compares equal to itself in the set.
        atom_key = (name, None if pd.isna(altloc) else altloc)
        if current and (resseq != previous_resseq or atom_key in seen):
            residues.append(current)
            current = []
            seen = set()
        current.append(label)
        seen.add(atom_key)
        previous_resseq = resseq
    if current:
        residues.append(current)
    return residues


def sanitize_pdb_dataframe(pdb: pd.DataFrame, verbose: int = 0) -> pd.DataFrame:
    """
    Sanitize a PDB DataFrame to ensure unique atom identifiers.

    This function fixes common issues in PDB/CIF files:

    1. HETATM records (especially waters) with duplicate resseq values (e.g., all 0)
    2. Residue names longer than 3 characters (truncates to 3)
    3. Ensures unique (chainid, resseq, name, altloc) combinations

    Any (chainid, resname, ATOM) group that contains duplicate identifiers is
    renumbered. Its residues get consecutive resseq values starting after
    the chain's current maximum. Atoms stay in the same residue when they
    are consecutive in file order, share the original resseq, and do not
    repeat an atom (name, altloc).

    Parameters
    ----------
    pdb : pandas.DataFrame
        DataFrame with PDB data (must have columns: ATOM, chainid, resseq,
        name, altloc, resname, serial).
    verbose : int, default 0
        Verbosity level (0=silent, 1=info, 2=debug).

    Returns
    -------
    pandas.DataFrame
        Sanitized copy with unique atom identifiers. The input is not
        modified and the index is preserved.

    Examples
    --------
    ::

        from torchref.model import Model
        from torchref.utils import sanitize_pdb_dataframe

        model = Model()
        model.load_cif('structure.cif')
        model.pdb = sanitize_pdb_dataframe(model.pdb, verbose=1)
    """
    pdb = pdb.copy()
    id_columns = ["chainid", "resseq", "name", "altloc"]

    if verbose > 0:
        print("Sanitizing PDB DataFrame...")
        print(f" Initial atoms: {len(pdb)}")

    # 1. Standardize residue names to max 3 characters
    long_resnames = pdb["resname"].str.len() > 3
    if long_resnames.any():
        n_long = long_resnames.sum()
        if verbose > 0:
            unique_long = pdb.loc[long_resnames, "resname"].unique()
            print(
                f" Truncating {n_long} atoms with resname > 3 chars: {unique_long[:5]}"
            )
        pdb.loc[long_resnames, "resname"] = pdb.loc[long_resnames, "resname"].str[:3]

    # 2. Fix duplicate atom identifiers by reassigning resseq
    dup_mask = pdb.duplicated(subset=id_columns, keep=False)
    if dup_mask.any():
        n_dup = dup_mask.sum()
        if verbose > 0:
            print(f" Found {n_dup} atoms with duplicate identifiers")

        # Work on a positional index so label-based assignment is unambiguous
        # even if the caller's index has duplicate labels; restored below.
        original_index = pdb.index
        pdb = pdb.reset_index(drop=True)

        # Group by (chainid, resname, ATOM) so we only renumber within the
        # same molecule type and chain.
        for (chainid, resname, atom_type), group in pdb.groupby(
            ["chainid", "resname", "ATOM"]
        ):
            if not group.duplicated(subset=id_columns, keep=False).any():
                continue

            # Start numbering after the current maximum resseq in this chain.
            max_resseq = pdb.loc[pdb["chainid"] == chainid, "resseq"].max()
            new_resseq_start = (
                max_resseq + 1 if pd.notna(max_resseq) and max_resseq > 0 else 1
            )

            residues = _split_group_into_residues(group)
            residue_counter = new_resseq_start
            for residue_labels in residues:
                pdb.loc[residue_labels, "resseq"] = residue_counter
                residue_counter += 1

            if verbose > 1:
                n_fixed = len(residues)
                print(
                    f" Fixed {n_fixed} {resname} residues in chain {chainid} (resseq {new_resseq_start}-{residue_counter-1})"
                )

        pdb.index = original_index

        # Verify duplicates are fixed
        final_dup_mask = pdb.duplicated(subset=id_columns, keep=False)
        if final_dup_mask.any():
            remaining_dups = final_dup_mask.sum()
            if verbose > 0:
                print(
                    f" WARNING: Still have {remaining_dups} duplicate identifiers after sanitization"
                )
                dups = pdb[final_dup_mask].sort_values(["chainid", "resseq", "name"])
                print(
                    dups[
                        [
                            "ATOM",
                            "serial",
                            "name",
                            "resname",
                            "chainid",
                            "resseq",
                            "altloc",
                        ]
                    ].head(10)
                )
        elif verbose > 0:
            print(" ✓ All duplicate identifiers resolved")
    elif verbose > 0:
        print(" ✓ No duplicate atom identifiers found")

    if verbose > 0:
        print(f" Final atoms: {len(pdb)}")

    return pdb
def _parse_with_parentheses( selection_string: str, pdb_df: pd.DataFrame ) -> torch.Tensor: """ Helper function to handle parentheses in selection strings. Recursively evaluates innermost parentheses first. """ import re # Find innermost parentheses while True: match = re.search(r"\(([^()]+)\)", selection_string) if not match: break # Evaluate the innermost parenthesized expression inner = match.group(1) inner_mask = _parse_without_parentheses(inner, pdb_df) # Replace with a placeholder that we'll substitute back # Use a unique placeholder that won't appear in normal selection placeholder = f"__MASK_{id(inner_mask)}__" selection_string = ( selection_string[: match.start()] + placeholder + selection_string[match.end() :] ) # Store the mask result in a temporary global dict # (not ideal but works for this recursive evaluation) if not hasattr(_parse_with_parentheses, "_mask_cache"): _parse_with_parentheses._mask_cache = {} _parse_with_parentheses._mask_cache[placeholder] = inner_mask # Now parse the expression without parentheses, substituting cached masks return _parse_without_parentheses(selection_string, pdb_df) def _parse_without_parentheses( selection_string: str, pdb_df: pd.DataFrame ) -> torch.Tensor: """ Parse selection string without parentheses. Handles logical operators and basic keywords. 
""" import re selection_string = selection_string.strip() if not selection_string: raise ValueError("Selection string cannot be empty") # Check if this is a cached mask placeholder if selection_string.startswith("__MASK_") and selection_string.endswith("__"): if hasattr(_parse_with_parentheses, "_mask_cache"): return _parse_with_parentheses._mask_cache.get( selection_string, torch.ones(len(pdb_df), dtype=torch.bool) ) return torch.ones(len(pdb_df), dtype=torch.bool) # Handle "all" keyword if selection_string.lower() == "all": return torch.ones(len(pdb_df), dtype=torch.bool) # Parse logical operators (or, and, not) with proper precedence # Priority: not > and > or # First, handle "or" (lowest precedence) if " or " in selection_string.lower(): parts = re.split(r"\s+or\s+", selection_string, flags=re.IGNORECASE) masks = [_parse_without_parentheses(part.strip(), pdb_df) for part in parts] result = masks[0] for mask in masks[1:]: result = result | mask return result # Then, handle "and" if " and " in selection_string.lower(): parts = re.split(r"\s+and\s+", selection_string, flags=re.IGNORECASE) masks = [_parse_without_parentheses(part.strip(), pdb_df) for part in parts] result = masks[0] for mask in masks[1:]: result = result & mask return result # Then, handle "not" if selection_string.lower().startswith("not "): inner_selection = selection_string[4:].strip() return ~_parse_without_parentheses(inner_selection, pdb_df) # Now handle individual selection keywords parts = selection_string.split(None, 1) if len(parts) < 2: raise ValueError(f"Invalid selection syntax: '{selection_string}'") keyword, value = parts[0].lower(), parts[1] # Initialize mask as all False mask = torch.zeros(len(pdb_df), dtype=torch.bool) if keyword == "chain": # Select by chain ID chain_id = value.strip() selected = pdb_df["chainid"] == chain_id mask = torch.tensor(selected.values, dtype=torch.bool) elif keyword == "resseq": # Select by residue sequence number or range if ":" in value: # Range 
selection start, end = value.split(":") start, end = int(start.strip()), int(end.strip()) selected = (pdb_df["resseq"] >= start) & (pdb_df["resseq"] <= end) else: # Single residue resseq_num = int(value.strip()) selected = pdb_df["resseq"] == resseq_num mask = torch.tensor(selected.values, dtype=torch.bool) elif keyword == "resname": # Select by residue name resname = value.strip().upper() selected = pdb_df["resname"].str.upper() == resname mask = torch.tensor(selected.values, dtype=torch.bool) elif keyword == "name": # Select by atom name atom_name = value.strip().upper() selected = pdb_df["name"].str.upper() == atom_name mask = torch.tensor(selected.values, dtype=torch.bool) elif keyword == "element": # Select by element element = value.strip().capitalize() selected = pdb_df["element"].str.capitalize() == element mask = torch.tensor(selected.values, dtype=torch.bool) elif keyword == "altloc": # Select by alternate location altloc = value.strip() selected = pdb_df["altloc"] == altloc mask = torch.tensor(selected.values, dtype=torch.bool) else: raise ValueError(f"Unknown selection keyword: '{keyword}'") return mask
def parse_phenix_selection(selection_string: str, pdb_df: pd.DataFrame) -> torch.Tensor:
    """
    Turn a Phenix-style atom selection into a boolean atom mask.

    Recognised syntax:

    - ``chain <id>``: atoms in a chain (e.g. "chain A")
    - ``resseq <num>``: atoms of one residue number (e.g. "resseq 10")
    - ``resseq <start>:<end>``: inclusive residue range (e.g. "resseq 10:20")
    - ``resname <name>``: atoms by residue name (e.g. "resname ALA")
    - ``name <atom>``: atoms by atom name (e.g. "name CA")
    - ``element <elem>``: atoms by element (e.g. "element C")
    - ``altloc <id>``: atoms by alternate location (e.g. "altloc A")
    - ``all``: every atom
    - ``not <selection>``: negation
    - ``<sel1> and <sel2>``: intersection
    - ``<sel1> or <sel2>``: union
    - ``(selection)``: grouping

    Parameters
    ----------
    selection_string : str
        Phenix-style selection string.
    pdb_df : pandas.DataFrame
        Atom table with columns 'chainid', 'resseq', 'resname', 'name',
        'element', 'altloc'.

    Returns
    -------
    torch.Tensor
        Boolean tensor of shape (n_atoms,); True marks a selected atom.

    Raises
    ------
    ValueError
        If the selection syntax is invalid.

    Examples
    --------
    ::

        # Select chain A
        mask = parse_phenix_selection("chain A", pdb_df)

        # Select residues 10-20 in chain A
        mask = parse_phenix_selection("chain A and resseq 10:20", pdb_df)

        # Select all CA atoms
        mask = parse_phenix_selection("name CA", pdb_df)

        # Select backbone atoms
        mask = parse_phenix_selection("name CA or name C or name N or name O", pdb_df)

        # Select everything except water
        mask = parse_phenix_selection("not resname HOH", pdb_df)

        # Use parentheses for grouping
        mask = parse_phenix_selection("chain A and (name CA or name CB)", pdb_df)
    """
    # If the parenthesis helper keeps a function-level mask cache, empty it
    # so placeholders from an earlier call cannot leak into this one.
    if hasattr(_parse_with_parentheses, "_mask_cache"):
        _parse_with_parentheses._mask_cache.clear()

    parser = (
        _parse_with_parentheses if "(" in selection_string else _parse_without_parentheses
    )
    return parser(selection_string, pdb_df)
def create_selection_mask(
    selection_string: str,
    pdb_df: pd.DataFrame,
    current_mask: Optional[torch.Tensor] = None,
    mode: str = "set",
) -> torch.Tensor:
    """
    Create or modify a refinable mask based on a Phenix-style selection.

    This function allows you to update refinable masks by selecting specific
    atoms using Phenix-style syntax. You can either replace the current mask,
    add to it, or remove from it.

    Parameters
    ----------
    selection_string : str
        Phenix-style selection string.
    pdb_df : pandas.DataFrame
        DataFrame containing atomic data.
    current_mask : torch.Tensor, optional
        Current refinable mask. If None, starts with all False. When given,
        the selection is moved to this mask's device, so a CUDA mask can be
        combined with the CPU selection, and the result lives on that device.
    mode : str, default 'set'
        How to combine with current mask:

        - 'set': Replace mask with selection (default)
        - 'add': Add selection to current mask (OR operation)
        - 'remove': Remove selection from current mask (AND NOT operation)

    Returns
    -------
    torch.Tensor
        Updated boolean mask of shape (n_atoms,).

    Raises
    ------
    ValueError
        If mode is not one of 'set', 'add', 'remove' (checked before parsing),
        if the selection is invalid, or if ``current_mask`` does not have the
        selection's shape in 'add'/'remove' mode.

    Examples
    --------
    ::

        # Create new mask selecting chain A
        mask = create_selection_mask("chain A", pdb_df, mode='set')

        # Add residues 10-20 to existing mask
        mask = create_selection_mask("resseq 10:20", pdb_df, current_mask=mask, mode='add')

        # Remove water from mask
        mask = create_selection_mask("resname HOH", pdb_df, current_mask=mask, mode='remove')
    """
    # Validate the mode first so a typo is reported without parsing work.
    if mode not in ("set", "add", "remove"):
        raise ValueError(f"Invalid mode: '{mode}'. Must be 'set', 'add', or 'remove'")

    # Parse the selection
    selection_mask = parse_phenix_selection(selection_string, pdb_df)

    if current_mask is None:
        current_mask = torch.zeros(len(pdb_df), dtype=torch.bool)
    else:
        # Avoid a CPU/CUDA device mismatch when combining masks.
        selection_mask = selection_mask.to(current_mask.device)
        if mode != "set" and current_mask.shape != selection_mask.shape:
            # Without this check, broadcasting could silently produce a mask
            # of the wrong shape.
            raise ValueError(
                f"current_mask has shape {tuple(current_mask.shape)}, "
                f"expected {tuple(selection_mask.shape)}"
            )

    if mode == "set":
        return selection_mask
    if mode == "add":
        return current_mask | selection_mask
    return current_mask & ~selection_mask
def state_dict_to_json_serializable(sd: Dict[str, torch.Tensor]) -> Dict[str, Any]:
    """
    Convert a state_dict with tensors to a JSON-serializable format.

    Each tensor becomes ``{"data": ..., "dtype": ..., "shape": ...}``:

    - ``data`` is ``[value]`` for single-element tensors and the nested
      ``tolist()`` output otherwise. That output is ``[]`` for empty tensors,
      where ``.item()`` would raise.
    - ``dtype`` is ``str(tensor.dtype)``, e.g. ``"torch.float32"``.
    - ``shape`` is ``list(tensor.shape)``, with 0-dim tensors recorded as
      ``[1]``.

    Parameters
    ----------
    sd : Dict[str, torch.Tensor]
        State dict with tensor values.

    Returns
    -------
    Dict[str, Any]
        JSON-serializable dictionary.
    """
    result = {}
    for k, v in sd.items():
        result[k] = {
            "data": v.tolist() if v.numel() != 1 else [v.item()],
            "dtype": str(v.dtype),
            "shape": list(v.shape) if v.shape else [1],
        }
    return result
def dict_to_state_dict(sd_raw):
    """
    Rebuild a PyTorch state_dict from its serialized-dict form.

    Parameters
    ----------
    sd_raw : dict
        Mapping of name -> ``{"data": ..., "dtype": ..., "shape": ...}``,
        where ``dtype`` is a string such as ``"torch.float32"``.

    Returns
    -------
    dict
        Mapping of name -> ``torch.Tensor`` with the recorded dtype and shape,
        in the same key order as ``sd_raw``.
    """

    def _decode(entry):
        # "torch.float32" -> torch.float32
        dtype_name = entry["dtype"].rpartition(".")[2]
        flat = torch.tensor(entry["data"], dtype=getattr(torch, dtype_name))
        return flat.reshape(entry["shape"])

    return {name: _decode(entry) for name, entry in sd_raw.items()}
def json_to_state_dicts_separate(
    json_path: str,
) -> Tuple[
    Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, torch.Tensor], list
]:
    """
    Split a hyperparameter JSON file into per-target state_dicts.

    Every value is converted with ``torch.tensor``. Keys are routed as follows:

    - ``schemes.*``: component-weighting state.
    - ``_targets.<c>.*`` with ``c`` a geometry component (bond, angle,
      torsion, planarity, chiral, nonbonded): geometry-target state.
    - ``_targets.<c>.*`` with ``c`` an ADP component (simu, locality, KL):
      ADP-target state.
    - Anything else is reported in the unassigned-key list.

    Parameters
    ----------
    json_path : str
        Path to the JSON file containing hyperparameters.

    Returns
    -------
    Tuple[Dict, Dict, Dict, list]
        ``(component_weighting_state, geometry_target_state,
        adp_target_state, unassigned_keys)``.
    """
    with open(json_path) as handle:
        raw = json.load(handle)

    geometry_components = {
        "bond",
        "angle",
        "torsion",
        "planarity",
        "chiral",
        "nonbonded",
    }
    adp_components = {"simu", "locality", "KL"}

    component_weighting_state = {}
    geometry_target_state = {}
    adp_target_state = {}
    unassigned_keys = []

    for key, value in raw.items():
        tensor_value = torch.tensor(value)

        if key.startswith("schemes."):
            component_weighting_state[key] = tensor_value
            continue

        if key.startswith("_targets."):
            # "_targets.bond._sigma" -> "bond"
            component = key.split(".")[1]
            if component in geometry_components:
                geometry_target_state[key] = tensor_value
                continue
            if component in adp_components:
                adp_target_state[key] = tensor_value
                continue

        unassigned_keys.append(key)

    return (
        component_weighting_state,
        geometry_target_state,
        adp_target_state,
        unassigned_keys,
    )
def disable_grad_outside_optimizer(optimized_params, all_params):
    """Set ``requires_grad=False`` on parameters not being optimized.

    Call this once after creating the optimizer. Non-optimized parameters will
    no longer contribute to the autograd graph, which means
    ``ModelFT.get_structure_factor`` will produce structure factors without
    ``grad_fn`` for frozen models — enabling indefinite caching until
    parameters change.

    Parameters are matched by identity (``id``), not by value.

    Parameters
    ----------
    optimized_params : iterable of torch.Tensor or of dict
        Whatever was passed to the optimizer. This can be plain tensors or
        parameter-group dicts with a ``"params"`` entry (e.g.
        ``optimizer.param_groups``). Group dicts are unpacked. Comparing the
        dicts themselves would match nothing and freeze every parameter.
    all_params : iterable of torch.Tensor
        All model parameters (e.g. ``model.parameters()``).
    """
    optimized_ids = set()
    for item in optimized_params:
        if isinstance(item, dict):
            group_params = item["params"]
            # torch.optim accepts a single tensor as a group's "params".
            if isinstance(group_params, torch.Tensor):
                group_params = [group_params]
            optimized_ids.update(id(p) for p in group_params)
        else:
            optimized_ids.add(id(item))

    for p in all_params:
        if id(p) not in optimized_ids:
            p.requires_grad_(False)