# Source code for torchref.model.segmented_internal_coordinates

"""
Segmented internal coordinate parametrization for atomic structures.

This module provides the SegmentedInternalCoordinateTensor class which addresses
the "lever arm problem" in internal coordinate parametrization by breaking the
molecular chain into independent segments, each with its own rigid body parameters.

Key features:
- Segments the molecule into groups of N amino acids (default: 3 per segment)
- Each segment has independent internal coordinates (bonds, angles, torsions)
- Each segment has rigid body parameters (position + orientation)
- Shallow spanning trees within segments (depth ~15-30 instead of ~1000)
- Changes in one segment don't propagate to distant segments
- Fully differentiable reconstruction from internal coordinates
- Parallelized construction for fast initialization
- Fused ring systems (indole in TRP, purines, etc.) treated as single rigid groups

This approach solves the lever arm problem where small torsion changes near the
root of a deep tree cause large displacements at distant atoms.
"""

from collections import defaultdict
from typing import Optional, Union

import numpy as np
import pandas as pd
import torch
import torch.nn as nn

from torchref.base.alignment.rotation import rotation_matrix_euler_zyz
from torchref.utils.caching import CachedForwardMixin
from torchref.utils.device_mixin import DeviceMixin


class SegmentedInternalCoordinateTensor(DeviceMixin, CachedForwardMixin, nn.Module):
    """
    Parameter wrapper using segmented internal coordinates.

    Stores: per-segment bond_lengths, angles, torsions, segment_positions,
    segment_orientations
    Reconstructs: Cartesian xyz on forward()

    This provides a physically meaningful parametrization that avoids the
    lever arm problem by breaking the molecule into independent segments,
    each with shallow spanning trees and rigid body parameters.

    Parameters
    ----------
    initial_xyz : torch.Tensor
        Initial Cartesian coordinates of shape (N, 3).
    pdb : pd.DataFrame
        PDB DataFrame with columns 'chainid', 'resseq', 'name', 'index',
        'resname'.
    n_aa_per_segment : int, optional
        Number of amino acids per segment. Default is 3.
    bond_cutoff : float, optional
        Distance cutoff for bond detection in Angstroms. Default is 2.0.
    cif_dict : dict, optional
        CIF dictionary with per-residue bond definitions
        (``cif_dict[resname]['bonds']`` with 'atom1' and 'atom2' columns).
        When given, bonds come from these definitions and the distance
        cutoff is only used as a fallback. Default is None.
    requires_grad : bool, optional
        Whether parameters should have gradients. Default is True.
    dtype : torch.dtype, optional
        Data type for tensors. Default is same as initial_xyz.
    device : torch.device, optional
        Device for tensors. Default is same as initial_xyz.

    Attributes
    ----------
    n_atoms : int
        Number of atoms.
    n_segments : int
        Number of segments.
    max_depth : int
        Maximum depth in any segment's spanning tree.
    bond_lengths : nn.Parameter
        Bond length parameters in Angstroms.
    angles : nn.Parameter
        Angle parameters in radians.
    torsions : nn.Parameter
        Torsion angle parameters in radians.
    segment_positions : nn.Parameter
        Absolute positions of segment root atoms.
    segment_orientations : nn.Parameter
        ZYZ Euler angle orientations for each segment.
    secondary_root_positions : nn.Parameter
        Absolute positions of the roots of disconnected components inside
        segments (depth-0 atoms that are not segment roots).
    """

    # Residue names treated as protein: the 20 standard amino acids plus
    # selenomethionine (MSE) and selenocysteine (SEC).
    AA_NAMES = frozenset({
        "ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
        "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL",
        "MSE", "SEC",
    })
[docs] def __init__( self, initial_xyz: torch.Tensor, pdb: pd.DataFrame, n_aa_per_segment: int = 3, bond_cutoff: float = 2.0, cif_dict: Optional[dict] = None, requires_grad: bool = True, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, ): """ Initialize SegmentedInternalCoordinateTensor. Parameters ---------- initial_xyz : torch.Tensor Initial Cartesian coordinates of shape (N, 3). pdb : pd.DataFrame PDB DataFrame with columns 'chainid', 'resseq', 'name', 'index', 'resname'. n_aa_per_segment : int, optional Number of amino acids per segment. Default is 3. bond_cutoff : float, optional Distance cutoff for bond detection in Angstroms (used as fallback). Default is 2.0. cif_dict : dict, optional CIF dictionary containing bond definitions per residue type. If provided, bonds are determined from chemical definitions rather than distances, which is more robust for structures with poor geometry. Expected format: cif_dict[resname]['bonds'] is a DataFrame with 'atom1' and 'atom2' columns. requires_grad : bool, optional Whether parameters should have gradients. Default is True. dtype : torch.dtype, optional Data type for tensors. Default is same as initial_xyz. device : torch.device, optional Device for tensors. Default is same as initial_xyz. 
""" super().__init__() if dtype is None: dtype = initial_xyz.dtype if device is None: device = initial_xyz.device self._dtype = dtype self._device = device self.n_atoms = initial_xyz.shape[0] self.bond_cutoff = bond_cutoff self.n_aa_per_segment = n_aa_per_segment self.cif_dict = cif_dict # Move initial coordinates to specified device/dtype initial_xyz = initial_xyz.to(dtype=dtype, device=device) # Store a copy of the PDB DataFrame for reference self.pdb = pdb # Build segments from residue information (vectorized) self._build_segments_fast(pdb) # Build internal coordinate trees for each segment (parallelized) self._build_segment_trees_parallel(initial_xyz) # Extract internal coordinates from initial xyz (vectorized) self._extract_internal_coords(initial_xyz, requires_grad)
@property def dtype(self): """Return the dtype of tensors.""" return self._dtype @property def device(self): """Return the device of tensors.""" return self._device def _build_segments_fast(self, pdb: pd.DataFrame) -> None: """ Group atoms into segments based on residue information (vectorized). Uses pandas groupby operations for efficient residue grouping. Parameters ---------- pdb : pd.DataFrame PDB DataFrame with 'chainid', 'resseq', 'name', 'index', 'resname' columns. """ device = self._device # Vectorized residue classification is_protein = pdb["resname"].str.upper().isin(self.AA_NAMES) # Group by (chainid, resseq) - this is O(N log N) grouped = pdb.groupby(["chainid", "resseq"], sort=False) # Build residue info lists efficiently residue_data = [] for (chain_id, resseq), group in grouped: atom_indices = group["index"].values.tolist() res_name = group["resname"].iloc[0] is_aa = res_name.upper() in self.AA_NAMES residue_data.append({ "chain_id": chain_id, "resseq": resseq, "atom_indices": atom_indices, "is_protein": is_aa, "resname": res_name, }) # Separate protein and non-protein residues protein_residues = [r for r in residue_data if r["is_protein"]] non_protein_residues = [r for r in residue_data if not r["is_protein"]] # Group protein residues by chain and sort chain_residues = {} for r in protein_residues: chain_id = r["chain_id"] if chain_id not in chain_residues: chain_residues[chain_id] = [] chain_residues[chain_id].append(r) for chain_id in chain_residues: chain_residues[chain_id].sort(key=lambda x: x["resseq"]) # Build segments segments = [] segment_root_atoms = [] # Create name-to-index lookups for backbone atoms (vectorized) # Use N atom of first residue as root - this ensures connectivity # since peptide chain goes N->CA->C->N->... 
n_mask = pdb["name"].str.strip() == "N" n_indices = pdb.loc[n_mask, "index"].values n_resseqs = pdb.loc[n_mask, "resseq"].values n_chains = pdb.loc[n_mask, "chainid"].values n_lookup = {(c, r): i for c, r, i in zip(n_chains, n_resseqs, n_indices)} # Fallback to CA if N not found ca_mask = pdb["name"].str.strip() == "CA" ca_indices = pdb.loc[ca_mask, "index"].values ca_resseqs = pdb.loc[ca_mask, "resseq"].values ca_chains = pdb.loc[ca_mask, "chainid"].values ca_lookup = {(c, r): i for c, r, i in zip(ca_chains, ca_resseqs, ca_indices)} # Process each chain for chain_id, residues in chain_residues.items(): n_residues = len(residues) for i in range(0, n_residues, self.n_aa_per_segment): segment_res = residues[i:i + self.n_aa_per_segment] # Collect all atoms in segment segment_atoms = [] for r in segment_res: segment_atoms.extend(r["atom_indices"]) # Use N atom of first residue as root for better connectivity # This ensures BFS can traverse forward through all atoms first_res = segment_res[0] first_key = (first_res["chain_id"], first_res["resseq"]) if first_key in n_lookup: root_atom = n_lookup[first_key] elif first_key in ca_lookup: root_atom = ca_lookup[first_key] else: root_atom = first_res["atom_indices"][0] segments.append(segment_atoms) segment_root_atoms.append(root_atom) # Add non-protein residues as individual segments for r in non_protein_residues: segments.append(r["atom_indices"]) segment_root_atoms.append(r["atom_indices"][0]) self.n_segments = len(segments) # Create mapping tensors (vectorized) atom_to_segment = torch.full((self.n_atoms,), -1, dtype=torch.long, device=device) max_segment_size = max(len(s) for s in segments) if segments else 0 segment_atom_indices = torch.full( (self.n_segments, max_segment_size), -1, dtype=torch.long, device=device ) segment_sizes = torch.zeros(self.n_segments, dtype=torch.long, device=device) # Vectorized assignment using numpy then convert for seg_idx, atom_indices in enumerate(segments): n_atoms_seg = len(atom_indices) 
segment_sizes[seg_idx] = n_atoms_seg segment_atom_indices[seg_idx, :n_atoms_seg] = torch.tensor( atom_indices, dtype=torch.long, device=device ) atom_to_segment[atom_indices] = seg_idx self.register_buffer("atom_to_segment", atom_to_segment) self.register_buffer("segment_atom_indices", segment_atom_indices) self.register_buffer("segment_sizes", segment_sizes) self.register_buffer( "segment_roots", torch.tensor(segment_root_atoms, dtype=torch.long, device=device) ) self._segments = segments def _build_adjacency_from_cif(self, xyz: torch.Tensor) -> torch.Tensor: """ Build adjacency matrix from CIF dictionary bond definitions. Uses chemical bond definitions from the CIF dictionary for intra-residue bonds and adds peptide bonds (C-N) between consecutive residues. Falls back to distance-based detection for atoms not in the dictionary. Parameters ---------- xyz : torch.Tensor Atomic coordinates of shape (N, 3). Returns ------- torch.Tensor Boolean adjacency matrix of shape (N, N). """ device = self._device n_atoms = self.n_atoms pdb = self.pdb # Collect bond pairs as lists for batch processing bond_idx1_list = [] bond_idx2_list = [] # Build fast lookup: create multi-index for vectorized matching # (chainid, resseq, atom_name) -> atom_index pdb_lookup = pdb.set_index(["chainid", "resseq"]) pdb["name_stripped"] = pdb["name"].str.strip() # Group atoms by residue for fast lookup residue_groups = pdb.groupby(["chainid", "resseq"]) # 1. 
Build intra-residue bonds from CIF dictionary (vectorized per residue) atoms_with_cif_bonds = set() for (chainid, resseq), group in residue_groups: resname = group["resname"].iloc[0] if resname not in self.cif_dict: continue cif_residue = self.cif_dict[resname] if "bonds" not in cif_residue: continue cif_bonds = cif_residue["bonds"] # Create atom name to index lookup for this residue name_to_idx = dict(zip(group["name_stripped"], group["index"])) # Vectorized bond matching using pandas merge atom1_names = cif_bonds["atom1"].str.strip().values atom2_names = cif_bonds["atom2"].str.strip().values for a1, a2 in zip(atom1_names, atom2_names): if a1 in name_to_idx and a2 in name_to_idx: idx1 = name_to_idx[a1] idx2 = name_to_idx[a2] bond_idx1_list.append(idx1) bond_idx2_list.append(idx2) atoms_with_cif_bonds.add(idx1) atoms_with_cif_bonds.add(idx2) # 2. Build peptide bonds (C-N between consecutive residues) # Pre-compute C and N atom indices per residue c_atoms = pdb[pdb["name_stripped"] == "C"].set_index(["chainid", "resseq"])["index"] n_atoms_df = pdb[pdb["name_stripped"] == "N"].set_index(["chainid", "resseq"])["index"] for chainid in pdb["chainid"].unique(): chain = pdb[pdb["chainid"] == chainid] chain_aa = chain[chain["resname"].isin(self.AA_NAMES)] resseqs = sorted(chain_aa["resseq"].unique()) for i in range(len(resseqs) - 1): resseq1 = resseqs[i] resseq2 = resseqs[i + 1] # Only add peptide bond if residues are consecutive if resseq2 - resseq1 != 1: continue try: idx_c = c_atoms.loc[(chainid, resseq1)] idx_n = n_atoms_df.loc[(chainid, resseq2)] bond_idx1_list.append(idx_c) bond_idx2_list.append(idx_n) atoms_with_cif_bonds.add(idx_c) atoms_with_cif_bonds.add(idx_n) except KeyError: continue # Convert to tensors and build adjacency adjacency = torch.zeros(n_atoms, n_atoms, dtype=torch.bool, device=device) if bond_idx1_list: idx1 = torch.tensor(bond_idx1_list, dtype=torch.long, device=device) idx2 = torch.tensor(bond_idx2_list, dtype=torch.long, device=device) 
adjacency[idx1, idx2] = True adjacency[idx2, idx1] = True # 3. Fallback: use distance-based detection for atoms without CIF bonds atoms_without_cif = set(range(n_atoms)) - atoms_with_cif_bonds if atoms_without_cif: atoms_list = sorted(atoms_without_cif) atoms_tensor = torch.tensor(atoms_list, dtype=torch.long, device=device) # Compute distances only for atoms without CIF bonds xyz_subset = xyz[atoms_tensor] distances_subset = torch.cdist(xyz_subset, xyz) # Find bonds based on distance bond_mask = (distances_subset < self.bond_cutoff) & (distances_subset > 0.1) # Vectorized addition to adjacency rows, cols = torch.where(bond_mask) adjacency[atoms_tensor[rows], cols] = True adjacency[cols, atoms_tensor[rows]] = True # Clean up temporary column pdb.drop(columns=["name_stripped"], inplace=True, errors="ignore") return adjacency def _build_segment_trees_parallel(self, xyz: torch.Tensor) -> None: """ Build spanning trees for all segments using fully parallel multi-source BFS. Uses vectorized tensor operations to process all segments simultaneously, achieving significant speedup compared to sequential per-segment BFS. Parameters ---------- xyz : torch.Tensor Atomic coordinates of shape (N, 3). 
""" device = self._device n_atoms = self.n_atoms # Initialize per-atom arrays parent_indices = torch.full((n_atoms,), -1, dtype=torch.long, device=device) grandparent_indices = torch.full((n_atoms,), -1, dtype=torch.long, device=device) great_grandparent_indices = torch.full((n_atoms,), -1, dtype=torch.long, device=device) atom_depth = torch.full((n_atoms,), -1, dtype=torch.long, device=device) # Build molecular graph if self.cif_dict is not None: # Use CIF dictionary for bond definitions (more robust for bad geometry) adjacency = self._build_adjacency_from_cif(xyz) else: # Fallback to distance-based bond detection distances = torch.cdist(xyz, xyz) adjacency = (distances < self.bond_cutoff) & (distances > 0.1) # Build compact neighbor list representation for fast lookups (fully vectorized) neighbor_counts = adjacency.sum(dim=1).to(torch.long) max_neighbors = neighbor_counts.max().item() # Create padded neighbor lists using vectorized scatter # Edge list: (source, target) pairs edge_src, edge_tgt = torch.where(adjacency) # Compute position within each source's neighbor list using segment cumsum # Sort edges by source to group them sorted_indices = torch.argsort(edge_src, stable=True) edge_src_sorted = edge_src[sorted_indices] edge_tgt_sorted = edge_tgt[sorted_indices] # Position within source group using vectorized segment cumsum neighbors_padded = torch.full( (n_atoms, max_neighbors), -1, dtype=torch.long, device=device ) if edge_src_sorted.numel() > 0: # Compute cumsum offset for each source atom offsets = torch.zeros(n_atoms + 1, dtype=torch.long, device=device) offsets[1:] = torch.cumsum(neighbor_counts, dim=0) # For each edge, position = edge_index - offset[source] edge_global_idx = torch.arange(len(edge_src_sorted), device=device) pos_in_group = edge_global_idx - offsets[edge_src_sorted] # Scatter targets into padded array neighbors_padded[edge_src_sorted, pos_in_group] = edge_tgt_sorted # ===== PARALLEL MULTI-SOURCE BFS ===== # Key insight: Run BFS from all 
segment roots simultaneously # using vectorized tensor operations # Initialize root atoms (depth 0) atom_depth[self.segment_roots] = 0 # Track visited atoms using depth >= 0 # Initialize current frontier with all roots current_frontier = self.segment_roots.clone() current_depth = 0 max_depth = 0 while current_frontier.numel() > 0: # Get all neighbors of current frontier atoms (vectorized) frontier_neighbors = neighbors_padded[current_frontier] # (F, max_neighbors) # Flatten to get candidate neighbors candidates = frontier_neighbors.reshape(-1) # (F * max_neighbors,) # Create parent mapping: which frontier atom is each candidate from frontier_expanded = current_frontier.unsqueeze(1).expand_as(frontier_neighbors) parents = frontier_expanded.reshape(-1) # (F * max_neighbors,) # Filter valid candidates (not -1 padding) valid_mask = candidates >= 0 candidates = candidates[valid_mask] parents = parents[valid_mask] if candidates.numel() == 0: break # Filter candidates: must be unvisited (depth == -1) unvisited_mask = atom_depth[candidates] == -1 candidates = candidates[unvisited_mask] parents = parents[unvisited_mask] if candidates.numel() == 0: break # Filter candidates: must be in same segment as parent candidate_segments = self.atom_to_segment[candidates] parent_segments = self.atom_to_segment[parents] same_segment_mask = candidate_segments == parent_segments candidates = candidates[same_segment_mask] parents = parents[same_segment_mask] if candidates.numel() == 0: break # Handle duplicates: if same atom is reachable from multiple parents, # keep only the first one (using unique with return_inverse) unique_candidates, inverse_indices = torch.unique( candidates, return_inverse=True ) # For each unique candidate, get the first parent first_occurrence = torch.zeros( unique_candidates.numel(), dtype=torch.long, device=device ) # Scatter the parent indices - last write wins, but all parents are valid # for BFS so this is fine first_occurrence.scatter_(0, inverse_indices, 
parents) # Update depth and parent for new frontier new_depth = current_depth + 1 atom_depth[unique_candidates] = new_depth parent_indices[unique_candidates] = first_occurrence # Update grandparent: grandparent = parent of parent parent_of_new = first_occurrence gp_of_new = parent_indices[parent_of_new] has_grandparent = gp_of_new >= 0 grandparent_indices[unique_candidates[has_grandparent]] = gp_of_new[ has_grandparent ] # Update great-grandparent: ggp = grandparent of parent ggp_of_new = grandparent_indices[parent_of_new] has_ggp = ggp_of_new >= 0 great_grandparent_indices[unique_candidates[has_ggp]] = ggp_of_new[has_ggp] # Update max depth max_depth = new_depth # Next frontier current_frontier = unique_candidates current_depth = new_depth # ===== HANDLE DISCONNECTED COMPONENTS ===== # Chain breaks (missing residues) can create disconnected components # within segments. For each disconnected component, pick a new root # and run BFS from it. max_component_iterations = 100 # Safety limit for _ in range(max_component_iterations): # Find atoms that are still unreached orphan_mask = (atom_depth == -1) & (self.atom_to_segment >= 0) if not orphan_mask.any(): break orphan_indices = torch.where(orphan_mask)[0] # Pick one orphan from each segment as a new root orphan_segments = self.atom_to_segment[orphan_indices] unique_orphan_segments = torch.unique(orphan_segments) # For each segment with orphans, pick the first orphan as new root new_roots = [] for seg in unique_orphan_segments: seg_orphans = orphan_indices[orphan_segments == seg] if seg_orphans.numel() > 0: new_roots.append(seg_orphans[0].item()) if not new_roots: break new_roots = torch.tensor(new_roots, dtype=torch.long, device=device) # Initialize new roots at depth 0 (they become local roots) atom_depth[new_roots] = 0 parent_indices[new_roots] = -1 # Run BFS from new roots current_frontier = new_roots while current_frontier.numel() > 0: frontier_neighbors = neighbors_padded[current_frontier] candidates = 
frontier_neighbors.reshape(-1) frontier_expanded = current_frontier.unsqueeze(1).expand_as(frontier_neighbors) parents = frontier_expanded.reshape(-1) valid_mask = candidates >= 0 candidates = candidates[valid_mask] parents = parents[valid_mask] if candidates.numel() == 0: break unvisited_mask = atom_depth[candidates] == -1 candidates = candidates[unvisited_mask] parents = parents[unvisited_mask] if candidates.numel() == 0: break candidate_segments = self.atom_to_segment[candidates] parent_segments = self.atom_to_segment[parents] same_segment_mask = candidate_segments == parent_segments candidates = candidates[same_segment_mask] parents = parents[same_segment_mask] if candidates.numel() == 0: break unique_candidates, inverse_indices = torch.unique( candidates, return_inverse=True ) first_occurrence = torch.zeros( unique_candidates.numel(), dtype=torch.long, device=device ) first_occurrence.scatter_(0, inverse_indices, parents) # Get depth of parents parent_depths = atom_depth[first_occurrence] new_depths = parent_depths + 1 atom_depth[unique_candidates] = new_depths parent_indices[unique_candidates] = first_occurrence parent_of_new = first_occurrence gp_of_new = parent_indices[parent_of_new] has_grandparent = gp_of_new >= 0 grandparent_indices[unique_candidates[has_grandparent]] = gp_of_new[ has_grandparent ] ggp_of_new = grandparent_indices[parent_of_new] has_ggp = ggp_of_new >= 0 great_grandparent_indices[unique_candidates[has_ggp]] = ggp_of_new[has_ggp] max_depth = max(max_depth, new_depths.max().item()) current_frontier = unique_candidates # Identify secondary roots (depth-0 atoms that are not segment roots) depth0_mask = atom_depth == 0 is_primary_root = torch.zeros(n_atoms, dtype=torch.bool, device=device) is_primary_root[self.segment_roots] = True secondary_root_mask = depth0_mask & ~is_primary_root secondary_root_indices = torch.where(secondary_root_mask)[0] self.register_buffer("secondary_root_indices", secondary_root_indices) self.max_depth = max_depth 
self.register_buffer("parent_indices", parent_indices) self.register_buffer("grandparent_indices", grandparent_indices) self.register_buffer("great_grandparent_indices", great_grandparent_indices) self.register_buffer("atom_depth", atom_depth) # Build parameter index mappings (vectorized) self._build_param_indices_vectorized() # Pre-compute depth indices for fast forward pass self._build_depth_indices() # Detect rings (vectorized) self._detect_rings_vectorized(adjacency) def _build_param_indices_vectorized(self) -> None: """ Build mappings from atoms to parameter indices using vectorized operations. """ device = self._device # Vectorized masks has_bond = self.atom_depth >= 1 has_angle = self.atom_depth >= 2 has_torsion = self.atom_depth >= 3 self.n_bonds = has_bond.sum().item() self.n_angles = has_angle.sum().item() self.n_torsions = has_torsion.sum().item() # Vectorized cumsum for index assignment bond_param_indices = torch.full((self.n_atoms,), -1, dtype=torch.long, device=device) angle_param_indices = torch.full((self.n_atoms,), -1, dtype=torch.long, device=device) torsion_param_indices = torch.full((self.n_atoms,), -1, dtype=torch.long, device=device) # Use cumsum for efficient index computation if has_bond.any(): bond_cumsum = torch.cumsum(has_bond.long(), dim=0) - 1 bond_param_indices[has_bond] = bond_cumsum[has_bond] if has_angle.any(): angle_cumsum = torch.cumsum(has_angle.long(), dim=0) - 1 angle_param_indices[has_angle] = angle_cumsum[has_angle] if has_torsion.any(): torsion_cumsum = torch.cumsum(has_torsion.long(), dim=0) - 1 torsion_param_indices[has_torsion] = torsion_cumsum[has_torsion] self.register_buffer("bond_param_indices", bond_param_indices) self.register_buffer("angle_param_indices", angle_param_indices) self.register_buffer("torsion_param_indices", torsion_param_indices) def _build_depth_indices(self) -> None: """ Pre-compute atom indices for each depth level. 
""" depth_atom_indices = [] depth_parent_indices = [] depth_gp_indices = [] depth_ggp_indices = [] depth_bond_param_indices = [] depth_angle_param_indices = [] depth_torsion_param_indices = [] for d in range(3, self.max_depth + 1): mask = self.atom_depth == d if mask.any(): atom_idx = torch.where(mask)[0] depth_atom_indices.append(atom_idx) depth_parent_indices.append(self.parent_indices[atom_idx]) depth_gp_indices.append(self.grandparent_indices[atom_idx]) depth_ggp_indices.append(self.great_grandparent_indices[atom_idx]) depth_bond_param_indices.append(self.bond_param_indices[atom_idx]) depth_angle_param_indices.append(self.angle_param_indices[atom_idx]) depth_torsion_param_indices.append(self.torsion_param_indices[atom_idx]) else: empty = torch.tensor([], dtype=torch.long, device=self._device) depth_atom_indices.append(empty) depth_parent_indices.append(empty) depth_gp_indices.append(empty) depth_ggp_indices.append(empty) depth_bond_param_indices.append(empty) depth_angle_param_indices.append(empty) depth_torsion_param_indices.append(empty) self._depth_atom_indices = depth_atom_indices self._depth_parent_indices = depth_parent_indices self._depth_gp_indices = depth_gp_indices self._depth_ggp_indices = depth_ggp_indices self._depth_bond_param_indices = depth_bond_param_indices self._depth_angle_param_indices = depth_angle_param_indices self._depth_torsion_param_indices = depth_torsion_param_indices def _detect_rings_vectorized(self, adjacency: torch.Tensor) -> None: """ Detect rings in the molecular graph using vectorized operations. Fused ring systems (like indole in Tryptophan) are automatically merged into single rigid groups. Any rings that share atoms are combined. Parameters ---------- adjacency : torch.Tensor Boolean adjacency matrix of shape (N, N). 
""" device = self._device n_atoms = self.n_atoms ring_member_mask = torch.zeros(n_atoms, dtype=torch.bool, device=device) ring_group_id = torch.full((n_atoms,), -1, dtype=torch.long, device=device) # Find back edges using vectorized operations # An edge (i, j) is a back edge if neither i is parent of j nor j is parent of i edge_i, edge_j = torch.where(adjacency.triu(diagonal=1)) # Check if edges are in spanning tree is_tree_edge = ( (self.parent_indices[edge_i] == edge_j) | (self.parent_indices[edge_j] == edge_i) ) # Check if in same segment same_segment = self.atom_to_segment[edge_i] == self.atom_to_segment[edge_j] # Back edges: not tree edges, same segment back_edge_mask = ~is_tree_edge & same_segment back_edge_i = edge_i[back_edge_mask] back_edge_j = edge_j[back_edge_mask] # First pass: detect individual rings from back edges individual_rings = [] for idx in range(len(back_edge_i)): i = back_edge_i[idx].item() j = back_edge_j[idx].item() # Find path from i to j through tree ancestors_i = set() current = i while current >= 0: ancestors_i.add(current) current = self.parent_indices[current].item() ring_atoms = [] current = j common_ancestor = None while current >= 0: ring_atoms.append(current) if current in ancestors_i: common_ancestor = current break current = self.parent_indices[current].item() if common_ancestor is None: continue # Add path from i to common ancestor current = i while current != common_ancestor: ring_atoms.append(current) current = self.parent_indices[current].item() individual_rings.append(set(ring_atoms)) # Second pass: merge overlapping rings into ring systems # This handles fused rings like indole (TRP), naphthalene, etc. 
ring_systems = self._merge_overlapping_rings(individual_rings) # Build final ring data structures ring_anchors = [] ring_members_list = [] for ring_idx, ring_atoms_set in enumerate(ring_systems): ring_atoms = list(ring_atoms_set) # Mark ring atoms for atom in ring_atoms: ring_member_mask[atom] = True ring_group_id[atom] = ring_idx # Anchor is atom closest to root (smallest depth) depths = self.atom_depth[ring_atoms] anchor_idx = ring_atoms[depths.argmin().item()] ring_anchors.append(anchor_idx) ring_members_list.append(ring_atoms) self.register_buffer("ring_member_mask", ring_member_mask) self.register_buffer("ring_group_id", ring_group_id) self.n_rings = len(ring_anchors) if self.n_rings > 0: self.register_buffer( "ring_anchor_atoms", torch.tensor(ring_anchors, dtype=torch.long, device=device) ) max_ring_size = max(len(r) for r in ring_members_list) ring_members_tensor = torch.full( (self.n_rings, max_ring_size), -1, dtype=torch.long, device=device ) ring_sizes = torch.zeros(self.n_rings, dtype=torch.long, device=device) for i, members in enumerate(ring_members_list): ring_sizes[i] = len(members) ring_members_tensor[i, :len(members)] = torch.tensor( members, dtype=torch.long, device=device ) self.register_buffer("ring_members", ring_members_tensor) self.register_buffer("ring_sizes", ring_sizes) else: self.register_buffer( "ring_anchor_atoms", torch.tensor([], dtype=torch.long, device=device) ) self.register_buffer( "ring_members", torch.tensor([], dtype=torch.long, device=device).reshape(0, 0) ) self.register_buffer( "ring_sizes", torch.tensor([], dtype=torch.long, device=device) ) def _merge_overlapping_rings(self, rings: list) -> list: """ Merge overlapping rings into unified ring systems using union-find. 
        This handles fused ring systems like:
        - Indole in Tryptophan (5-membered + 6-membered fused rings)
        - Purine bases (adenine, guanine)
        - Naphthalene-like systems

        Parameters
        ----------
        rings : list of set
            List of sets, each containing atom indices for a detected ring.

        Returns
        -------
        list of set
            List of merged ring systems, where overlapping rings are combined.
        """
        if len(rings) == 0:
            return []

        n_rings = len(rings)

        # Union-find data structure
        parent = list(range(n_rings))
        rank = [0] * n_rings

        def find(x):
            if parent[x] != x:
                parent[x] = find(parent[x])  # Path compression
            return parent[x]

        def union(x, y):
            px, py = find(x), find(y)
            if px == py:
                return
            # Union by rank
            if rank[px] < rank[py]:
                px, py = py, px
            parent[py] = px
            if rank[px] == rank[py]:
                rank[px] += 1

        # Merge rings that share any atoms
        for i in range(n_rings):
            for j in range(i + 1, n_rings):
                # If rings share any atoms, merge them
                if rings[i] & rings[j]:  # Set intersection
                    union(i, j)

        # Group rings by their root in union-find
        groups = defaultdict(list)
        for i in range(n_rings):
            groups[find(i)].append(i)

        # Create merged ring systems
        ring_systems = []
        for ring_indices in groups.values():
            # Merge all atoms from rings in this group
            merged_atoms = set()
            for ring_idx in ring_indices:
                merged_atoms |= rings[ring_idx]
            ring_systems.append(merged_atoms)

        return ring_systems

    def _extract_internal_coords(
        self, xyz: torch.Tensor, requires_grad: bool = True
    ) -> None:
        """
        Extract internal coordinates from Cartesian coordinates (vectorized).

        Measures one bond length for every atom with ``atom_depth >= 1``, one
        angle for every atom with ``atom_depth >= 2`` and one torsion for
        every atom with ``atom_depth >= 3``, then stores them on the module.
        Because the values are gathered with boolean masks, each parameter
        vector is ordered by ascending atom index.

        Side effects (everything is stored on ``self``):

        - ``nn.Parameter``: ``bond_lengths``, ``angles``, ``torsions``,
          ``segment_positions``, ``segment_orientations`` (all zeros) and
          ``secondary_root_positions``.
        - buffers: ``ring_local_coords``, ``refinable_mask`` (all ``True``)
          and ``fixed_xyz`` (a copy of ``xyz``).
        - whatever ``_setup_shallow_atom_references_vectorized`` registers.

        Parameters
        ----------
        xyz : torch.Tensor
            Atomic coordinates of shape (N, 3).
        requires_grad : bool
            Whether parameters should have gradients.
        """
        device = self._device
        dtype = self._dtype

        # Extract bond lengths (vectorized)
        # Bond length = distance between an atom and its tree parent.
        bond_mask = self.atom_depth >= 1
        if bond_mask.any():
            child_xyz = xyz[bond_mask]
            parent_xyz = xyz[self.parent_indices[bond_mask]]
            bond_lengths = torch.linalg.norm(child_xyz - parent_xyz, dim=-1)
        else:
            bond_lengths = torch.tensor([], dtype=dtype, device=device)

        # Extract angles (vectorized)
        # Angle at the parent, between the directions to the child and to the
        # grandparent.
        angle_mask = self.atom_depth >= 2
        if angle_mask.any():
            child_xyz = xyz[angle_mask]
            parent_xyz = xyz[self.parent_indices[angle_mask]]
            grandparent_xyz = xyz[self.grandparent_indices[angle_mask]]

            v1 = child_xyz - parent_xyz
            v2 = grandparent_xyz - parent_xyz

            # The 1e-10 keeps the denominator non-zero for coincident atoms.
            cos_angles = torch.sum(v1 * v2, dim=-1) / (
                torch.linalg.norm(v1, dim=-1) * torch.linalg.norm(v2, dim=-1) + 1e-10
            )
            # Keep the value inside acos' domain before inverting.
            cos_angles = torch.clamp(cos_angles, -1.0, 1.0)
            angles = torch.acos(cos_angles)
        else:
            angles = torch.tensor([], dtype=dtype, device=device)

        # Extract torsions (vectorized)
        # Dihedral over great-grandparent -> grandparent -> parent -> child.
        torsion_mask = self.atom_depth >= 3
        if torsion_mask.any():
            child_xyz = xyz[torsion_mask]
            parent_xyz = xyz[self.parent_indices[torsion_mask]]
            grandparent_xyz = xyz[self.grandparent_indices[torsion_mask]]
            great_grandparent_xyz = xyz[self.great_grandparent_indices[torsion_mask]]

            torsions = self._compute_torsion_angles(
                great_grandparent_xyz, grandparent_xyz, parent_xyz, child_xyz
            )
        else:
            torsions = torch.tensor([], dtype=dtype, device=device)

        # Extract segment positions
        segment_positions = xyz[self.segment_roots].clone()

        # Extract secondary root positions (disconnected component roots)
        if self.secondary_root_indices.numel() > 0:
            secondary_root_positions = xyz[self.secondary_root_indices].clone()
        else:
            secondary_root_positions = torch.zeros(0, 3, dtype=dtype, device=device)

        # Initialize segment orientations to zero (identity rotation)
        segment_orientations = torch.zeros(
            self.n_segments, 3, dtype=dtype, device=device
        )

        # Extract ring local coordinates (vectorized where possible)
        # Registered as a buffer, not a Parameter, so ring geometry is not
        # exposed to an optimizer through parameters().
        if self.n_rings > 0:
            ring_local_coords = self._extract_ring_local_coords_vectorized(xyz)
            self.register_buffer("ring_local_coords", ring_local_coords)
        else:
            self.register_buffer(
                "ring_local_coords",
                torch.tensor([], dtype=dtype, device=device).reshape(0, 0, 3)
            )

        # Setup shallow atom references (vectorized)
        self._setup_shallow_atom_references_vectorized(xyz)

        # Register parameters
        self.bond_lengths = nn.Parameter(
            bond_lengths.clone(), requires_grad=requires_grad
        )
        self.angles = nn.Parameter(angles.clone(), requires_grad=requires_grad)
        self.torsions = nn.Parameter(torsions.clone(), requires_grad=requires_grad)
        self.segment_positions = nn.Parameter(
            segment_positions.clone(), requires_grad=requires_grad
        )
        self.segment_orientations = nn.Parameter(
            segment_orientations.clone(), requires_grad=requires_grad
        )
        self.secondary_root_positions = nn.Parameter(
            secondary_root_positions.clone(), requires_grad=requires_grad
        )

        # Initialize refinable mask
        # Every atom starts refinable; fixed_xyz keeps the input coordinates.
        self.register_buffer(
            "refinable_mask",
            torch.ones(self.n_atoms, dtype=torch.bool, device=device)
        )
        self.register_buffer("fixed_xyz", xyz.clone())

    def _setup_shallow_atom_references_vectorized(self, xyz: torch.Tensor) -> None:
        """
        Store reference directions for depth-1 and depth-2 atoms (vectorized).

        Registers four buffers:

        - ``depth1_dirs``: one unit vector per depth-1 atom, pointing from its
          parent to the atom.
        - ``depth1_atom_to_dir_idx``: length ``n_atoms``; row of
          ``depth1_dirs`` for each depth-1 atom, ``-1`` for all other atoms.
        - ``depth2_perps``: one unit vector per depth-2 atom, the part of the
          parent-to-atom vector that is perpendicular to the
          grandparent-to-parent bond.
        - ``depth2_atom_to_perp_idx``: length ``n_atoms``; row of
          ``depth2_perps`` for each depth-2 atom, ``-1`` for all other atoms.

        When there are no atoms at a depth, the direction buffer is empty with
        shape (0, 3) and the index buffer is all ``-1``.

        Parameters
        ----------
        xyz : torch.Tensor
            Atomic coordinates.
        """
        device = self._device
        dtype = self._dtype

        # Depth-1 atoms (vectorized)
        depth1_mask = self.atom_depth == 1
        n_depth1 = depth1_mask.sum().item()

        if n_depth1 > 0:
            depth1_indices = torch.where(depth1_mask)[0]

            # Vectorized direction computation
            parent_idx = self.parent_indices[depth1_indices]
            directions = xyz[depth1_indices] - xyz[parent_idx]
            # Clamp so a zero-length bond does not divide by zero.
            norms = torch.linalg.norm(directions, dim=-1, keepdim=True).clamp(min=1e-10)
            depth1_dirs = directions / norms

            # Create index mapping
            depth1_atom_to_dir_idx = torch.full(
                (self.n_atoms,), -1, dtype=torch.long, device=device
            )
            depth1_atom_to_dir_idx[depth1_indices] = torch.arange(
                n_depth1, dtype=torch.long, device=device
            )

            self.register_buffer("depth1_dirs", depth1_dirs)
            self.register_buffer("depth1_atom_to_dir_idx", depth1_atom_to_dir_idx)
        else:
            self.register_buffer(
                "depth1_dirs", torch.zeros(0, 3, dtype=dtype, device=device)
            )
            self.register_buffer(
                "depth1_atom_to_dir_idx",
                torch.full((self.n_atoms,), -1, dtype=torch.long, device=device)
            )

        # Depth-2 atoms (vectorized)
        depth2_mask = self.atom_depth == 2
        n_depth2 = depth2_mask.sum().item()

        if n_depth2 > 0:
            depth2_indices = torch.where(depth2_mask)[0]
            parent_idx = self.parent_indices[depth2_indices]
            grandparent_idx = self.grandparent_indices[depth2_indices]

            # Bond direction (vectorized)
            v_bond = xyz[parent_idx] - xyz[grandparent_idx]
            v_bond_norm = torch.linalg.norm(v_bond, dim=-1, keepdim=True).clamp(min=1e-10)
            v_bond = v_bond / v_bond_norm

            # Child direction
            v_child = xyz[depth2_indices] - xyz[parent_idx]

            # Perpendicular component (vectorized)
            proj = torch.sum(v_child * v_bond, dim=-1, keepdim=True)
            v_perp = v_child - proj * v_bond
            v_perp_norm = torch.linalg.norm(v_perp, dim=-1, keepdim=True)

            # Handle collinear cases
            collinear = v_perp_norm.squeeze(-1) < 1e-10
            if collinear.any():
                # For collinear atoms, use arbitrary perpendicular
                # Seed with the x axis, or the y axis when the bond itself is
                # close to x (|x component| >= 0.9).
                default_perp = torch.zeros_like(v_perp)
                nearly_x = torch.abs(v_bond[:, 0]) >= 0.9
                default_perp[~nearly_x, 0] = 1.0
                default_perp[nearly_x, 1] = 1.0
                # Make perpendicular
                proj_default = torch.sum(default_perp * v_bond, dim=-1, keepdim=True)
                default_perp = default_perp - proj_default * v_bond
                default_perp = default_perp / torch.linalg.norm(
                    default_perp, dim=-1, keepdim=True
                ).clamp(min=1e-10)
                # The fallback is already unit length, so its norm is set to 1
                # to make the division below leave it unchanged.
                v_perp[collinear] = default_perp[collinear]
                v_perp_norm[collinear] = 1.0

            depth2_perps = v_perp / v_perp_norm.clamp(min=1e-10)

            # Create index mapping
            depth2_atom_to_perp_idx = torch.full(
                (self.n_atoms,), -1, dtype=torch.long, device=device
            )
            depth2_atom_to_perp_idx[depth2_indices] = torch.arange(
                n_depth2, dtype=torch.long, device=device
            )

            self.register_buffer("depth2_perps", depth2_perps)
            self.register_buffer("depth2_atom_to_perp_idx", depth2_atom_to_perp_idx)
        else:
            self.register_buffer(
                "depth2_perps", torch.zeros(0, 3, dtype=dtype, device=device)
            )
            self.register_buffer(
                "depth2_atom_to_perp_idx",
                torch.full((self.n_atoms,), -1, dtype=torch.long, device=device)
            )

    def _extract_ring_local_coords_vectorized(self, xyz: torch.Tensor) -> torch.Tensor:
        """
        Extract local coordinates for ring atoms (fully vectorized).

        Each ring gets an orthonormal frame located at its anchor atom
        (``ring_anchor_atoms``):

        - z points from the anchor's parent to the anchor; it falls back to
          (0, 0, 1) when the anchor has a negative parent index or the two
          positions coincide.
        - x is the global x axis (the global y axis when ``|z_x| >= 0.9``)
          with its component along z removed, then normalized.
        - y is ``z x x``.

        The offset of every ring member from the anchor is then expressed as
        its components along (x, y, z).

        Parameters
        ----------
        xyz : torch.Tensor
            Atomic coordinates.

        Returns
        -------
        torch.Tensor
            Ring local coordinates of shape (n_rings, max_ring_size, 3).
            Slots where ``ring_members`` is negative (padding) are zero.
        """
        device = self._device
        dtype = self._dtype

        max_ring_size = self.ring_members.shape[1] if self.n_rings > 0 else 0
        # This zero tensor is only returned by the early exit below; with
        # rings present it is replaced by the bmm result further down.
        ring_local_coords = torch.zeros(
            self.n_rings, max_ring_size, 3, dtype=dtype, device=device
        )

        if self.n_rings == 0:
            return ring_local_coords

        # Get anchor positions
        anchor_pos = xyz[self.ring_anchor_atoms]  # (n_rings, 3)
        parent_idx = self.parent_indices[self.ring_anchor_atoms]

        # Compute z-axis for all rings (vectorized)
        # A negative parent index is treated as "no parent".
        valid_parent = parent_idx >= 0
        parent_pos = torch.zeros_like(anchor_pos)
        parent_pos[valid_parent] = xyz[parent_idx[valid_parent]]

        z_axis = anchor_pos - parent_pos
        z_norm = torch.linalg.norm(z_axis, dim=-1, keepdim=True)
        z_axis = torch.where(
            z_norm > 1e-10,
            z_axis / z_norm.clamp(min=1e-10),
            torch.tensor([0.0, 0.0, 1.0], dtype=dtype, device=device)
        )
        z_axis[~valid_parent] = torch.tensor([0.0, 0.0, 1.0], dtype=dtype, device=device)

        # Compute perpendicular axes (vectorized)
        x_base = torch.zeros_like(z_axis)
        x_base[:, 0] = 1.0
        nearly_x = torch.abs(z_axis[:, 0]) >= 0.9
        x_base[nearly_x, 0] = 0.0
        x_base[nearly_x, 1] = 1.0

        # Gram-Schmidt step: drop the part of x_base that lies along z.
        x_axis = x_base - (x_base * z_axis).sum(dim=-1, keepdim=True) * z_axis
        x_axis = x_axis / torch.linalg.norm(x_axis, dim=-1, keepdim=True).clamp(min=1e-10)
        y_axis = torch.linalg.cross(z_axis, x_axis)

        # Build rotation matrices: R = [x, y, z] as columns
        R = torch.stack([x_axis, y_axis, z_axis], dim=-1)  # (n_rings, 3, 3)

        # Fully vectorized ring member processing
        # Expand ring_members to get all atom positions at once
        # Use -1 padding to create valid mask
        valid_mask = self.ring_members >= 0  # (n_rings, max_ring_size)

        # Replace -1 with 0 for indexing (will be masked out)
        safe_members = self.ring_members.clamp(min=0)

        # Get all member positions: (n_rings, max_ring_size, 3)
        member_pos = xyz[safe_members]

        # Compute offsets from anchor: (n_rings, max_ring_size, 3)
        offsets = member_pos - anchor_pos.unsqueeze(1)

        # Transform to local coords using batched matmul
        # offsets: (n_rings, max_ring_size, 3)
        # R: (n_rings, 3, 3)
        # Result: (n_rings, max_ring_size, 3)
        # Row-vector times R (axes as columns) gives the dot product of each
        # offset with x, y and z.
        ring_local_coords = torch.bmm(
            offsets,
            R  # (n_rings, 3, 3)
        )

        # Zero out invalid entries
        ring_local_coords = ring_local_coords * valid_mask.unsqueeze(-1)

        return ring_local_coords

    @staticmethod
    def _compute_torsion_angles(
        p1: torch.Tensor, p2: torch.Tensor, p3: torch.Tensor, p4: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute torsion angles for batched positions (vectorized).

        The angle is the signed dihedral between the plane through
        (p1, p2, p3) and the plane through (p2, p3, p4), measured about the
        p2 -> p3 axis and evaluated with ``atan2``.

        Parameters
        ----------
        p1, p2, p3, p4 : torch.Tensor
            Atom positions of shape (N, 3).

        Returns
        -------
        torch.Tensor
            Torsion angles in radians, shape (N,).
        """
        b1 = p2 - p1
        b2 = p3 - p2
        b3 = p4 - p3

        # Normals of the two planes.
        n1 = torch.linalg.cross(b1, b2)
        n2 = torch.linalg.cross(b2, b3)

        # Clamped norms avoid a division by zero when three consecutive atoms
        # are collinear (zero-length normal) or two atoms coincide.
        n1 = n1 / torch.linalg.norm(n1, dim=-1, keepdim=True).clamp(min=1e-10)
        n2 = n2 / torch.linalg.norm(n2, dim=-1, keepdim=True).clamp(min=1e-10)

        b2_norm = b2 / torch.linalg.norm(b2, dim=-1, keepdim=True).clamp(min=1e-10)
        # m1 is perpendicular to both n1 and the central bond, so (n1, m1)
        # spans the plane in which n2 is measured.
        m1 = torch.linalg.cross(n1, b2_norm)

        x = torch.sum(n1 * n2, dim=-1)
        y = torch.sum(m1 * n2, dim=-1)

        return torch.atan2(y, x)
[docs] def forward(self) -> torch.Tensor: """ Reconstruct Cartesian xyz from internal coordinates. Uses fully vectorized operations for maximum performance. Returns ------- torch.Tensor Reconstructed Cartesian coordinates of shape (N, 3). """ xyz = torch.zeros( self.n_atoms, 3, dtype=self._dtype, device=self._device ) # Get rotation matrices for all segments (batched) R_matrices = rotation_matrix_euler_zyz( self.segment_orientations ) # (n_segments, 3, 3) # Place segment root atoms (depth=0) xyz[self.segment_roots] = self.segment_positions # Place secondary roots (disconnected component roots) if self.secondary_root_indices.numel() > 0: xyz[self.secondary_root_indices] = self.secondary_root_positions # Place depth-1 atoms (vectorized) xyz = self._place_depth1_atoms(xyz, R_matrices) # Place depth-2 atoms (vectorized) xyz = self._place_depth2_atoms(xyz, R_matrices) # Place depth-3+ atoms using NeRF (vectorized per depth) for depth_idx in range(len(self._depth_atom_indices)): xyz = self._place_atoms_at_depth_fast(xyz, depth_idx) # Place rigid ring atoms (vectorized) xyz = self._place_rigid_rings(xyz) # Apply frozen coordinates if not self.refinable_mask.all(): frozen_mask = ~self.refinable_mask xyz[frozen_mask] = self.fixed_xyz[frozen_mask] return xyz
def _place_depth1_atoms( self, xyz: torch.Tensor, R_matrices: torch.Tensor ) -> torch.Tensor: """ Place atoms at depth 1 (vectorized). Parameters ---------- xyz : torch.Tensor Current coordinates. R_matrices : torch.Tensor Segment rotation matrices of shape (n_segments, 3, 3). Returns ------- torch.Tensor Updated coordinates. """ mask = self.atom_depth == 1 if not mask.any(): return xyz xyz = xyz.clone() atom_indices = torch.where(mask)[0] seg_ids = self.atom_to_segment[atom_indices] # Get actual parent positions from xyz (handles both primary and secondary roots) parent_idx = self.parent_indices[atom_indices] parent_positions = xyz[parent_idx] # Get rotation matrices based on segment R = R_matrices[seg_ids] # Get stored directions dir_indices = self.depth1_atom_to_dir_idx[atom_indices] base_dirs = self.depth1_dirs[dir_indices] # Rotate directions (batched matmul) rotated_dirs = torch.bmm(R, base_dirs.unsqueeze(-1)).squeeze(-1) # Get bond lengths bond_idx = self.bond_param_indices[mask] bond_lengths = self.bond_lengths[bond_idx] # Compute positions relative to actual parent (vectorized) new_positions = parent_positions + bond_lengths.unsqueeze(-1) * rotated_dirs xyz[mask] = new_positions return xyz def _place_depth2_atoms( self, xyz: torch.Tensor, R_matrices: torch.Tensor ) -> torch.Tensor: """ Place atoms at depth 2 (vectorized). Parameters ---------- xyz : torch.Tensor Current coordinates. R_matrices : torch.Tensor Segment rotation matrices. Returns ------- torch.Tensor Updated coordinates. 
""" mask = self.atom_depth == 2 if not mask.any(): return xyz xyz = xyz.clone() # Get positions (vectorized gather) parent_idx = self.parent_indices[mask] grandparent_idx = self.grandparent_indices[mask] parent_pos = xyz[parent_idx] grandparent_pos = xyz[grandparent_idx] # Get parameters bond_idx = self.bond_param_indices[mask] angle_idx = self.angle_param_indices[mask] bond_lengths = self.bond_lengths[bond_idx] angles = self.angles[angle_idx] # Build local frame (vectorized) bc = parent_pos - grandparent_pos bc_norm = torch.linalg.norm(bc, dim=-1, keepdim=True).clamp(min=1e-10) bc_unit = bc / bc_norm # Get perpendicular references atom_indices = torch.where(mask)[0] perp_idx = self.depth2_atom_to_perp_idx[atom_indices] ref_perp = self.depth2_perps[perp_idx] # Apply segment rotation (batched) seg_ids = self.atom_to_segment[atom_indices] R = R_matrices[seg_ids] rotated_ref = torch.bmm(R, ref_perp.unsqueeze(-1)).squeeze(-1) # Make perpendicular to bc_unit (vectorized) proj = torch.sum(rotated_ref * bc_unit, dim=-1, keepdim=True) rotated_ref = rotated_ref - proj * bc_unit ref_norm = torch.linalg.norm(rotated_ref, dim=-1, keepdim=True).clamp(min=1e-10) rotated_ref = rotated_ref / ref_norm # Compute position (vectorized) theta_internal = torch.pi - angles dx = bond_lengths * torch.cos(theta_internal) dy = bond_lengths * torch.sin(theta_internal) new_positions = ( parent_pos + dx.unsqueeze(-1) * bc_unit + dy.unsqueeze(-1) * rotated_ref ) xyz[mask] = new_positions return xyz def _place_atoms_at_depth_fast( self, xyz: torch.Tensor, depth_idx: int ) -> torch.Tensor: """ Place atoms at a depth level using pre-computed indices (vectorized). Parameters ---------- xyz : torch.Tensor Current coordinates. depth_idx : int Index into pre-computed depth arrays. Returns ------- torch.Tensor Updated coordinates. 
""" atom_idx = self._depth_atom_indices[depth_idx] if atom_idx.numel() == 0: return xyz xyz = xyz.clone() # Use pre-computed indices (vectorized gather) parent_idx = self._depth_parent_indices[depth_idx] gp_idx = self._depth_gp_indices[depth_idx] ggp_idx = self._depth_ggp_indices[depth_idx] p1 = xyz[ggp_idx] p2 = xyz[gp_idx] p3 = xyz[parent_idx] # Get parameters bond_idx = self._depth_bond_param_indices[depth_idx] angle_idx = self._depth_angle_param_indices[depth_idx] torsion_idx = self._depth_torsion_param_indices[depth_idx] d = self.bond_lengths[bond_idx] theta = self.angles[angle_idx] phi = self.torsions[torsion_idx] # Build local coordinate frame (vectorized) bc = p3 - p2 bc = bc / torch.linalg.norm(bc, dim=-1, keepdim=True).clamp(min=1e-10) ab = p2 - p1 n = torch.linalg.cross(ab, bc) n = n / torch.linalg.norm(n, dim=-1, keepdim=True).clamp(min=1e-10) m = torch.linalg.cross(n, bc) # Compute positions (vectorized) theta_internal = torch.pi - theta sin_theta = torch.sin(theta_internal) cos_theta = torch.cos(theta_internal) dx = d * cos_theta dy = d * sin_theta * torch.cos(phi) dz = d * sin_theta * torch.sin(phi) new_pos = ( p3 + dx.unsqueeze(-1) * bc + dy.unsqueeze(-1) * m - dz.unsqueeze(-1) * n ) xyz[atom_idx] = new_pos return xyz def _place_rigid_rings(self, xyz: torch.Tensor) -> torch.Tensor: """ Place ring atoms as rigid groups (vectorized). Parameters ---------- xyz : torch.Tensor Current coordinates. Returns ------- torch.Tensor Updated coordinates with ring atoms placed. 
""" if self.n_rings == 0: return xyz # Get anchor positions (vectorized) anchor_pos = xyz[self.ring_anchor_atoms] parent_idx = self.parent_indices[self.ring_anchor_atoms] valid_parent_mask = parent_idx >= 0 parent_pos = torch.zeros_like(anchor_pos) parent_pos[valid_parent_mask] = xyz[parent_idx[valid_parent_mask]] # Compute z-axis (vectorized) z_axis = anchor_pos - parent_pos z_norm = torch.linalg.norm(z_axis, dim=-1, keepdim=True) z_axis = torch.where( z_norm > 1e-10, z_axis / z_norm.clamp(min=1e-10), torch.tensor([0.0, 0.0, 1.0], dtype=self._dtype, device=self._device) ) z_axis[~valid_parent_mask] = torch.tensor( [0.0, 0.0, 1.0], dtype=self._dtype, device=self._device ) # Compute perpendicular axes (vectorized) x_base = torch.zeros_like(z_axis) x_base[:, 0] = 1.0 nearly_x = torch.abs(z_axis[:, 0]) >= 0.9 x_base[nearly_x, 0] = 0.0 x_base[nearly_x, 1] = 1.0 x_axis = x_base - (x_base * z_axis).sum(dim=-1, keepdim=True) * z_axis x_axis = x_axis / torch.linalg.norm(x_axis, dim=-1, keepdim=True).clamp(min=1e-10) y_axis = torch.linalg.cross(z_axis, x_axis) R = torch.stack([x_axis, y_axis, z_axis], dim=-1) # Transform local coords to global (batched matmul) global_offsets = torch.bmm( self.ring_local_coords, R.transpose(-1, -2) ) global_pos = global_offsets + anchor_pos.unsqueeze(1) # Scatter results (vectorized) valid_mask = self.ring_members >= 0 anchor_expanded = self.ring_anchor_atoms.unsqueeze(1) valid_mask = valid_mask & (self.ring_members != anchor_expanded) ring_indices, local_indices = torch.where(valid_mask) atom_indices = self.ring_members[ring_indices, local_indices] positions = global_pos[ring_indices, local_indices] xyz[atom_indices] = positions return xyz
[docs] def shake(self, magnitude: float = 0.1) -> torch.Tensor: """ Add Gaussian noise to internal parameters. Parameters ---------- magnitude : float, optional Standard deviation of Gaussian noise. Default is 0.1. Returns ------- torch.Tensor New Cartesian coordinates after perturbation. """ with torch.no_grad(): if self.bond_lengths.numel() > 0: self.bond_lengths.data += torch.randn_like(self.bond_lengths) * magnitude self.bond_lengths.data = torch.clamp(self.bond_lengths.data, min=0.5) if self.angles.numel() > 0: self.angles.data += torch.randn_like(self.angles) * magnitude self.angles.data = torch.clamp( self.angles.data, min=0.1, max=torch.pi - 0.1 ) if self.torsions.numel() > 0: self.torsions.data += torch.randn_like(self.torsions) * magnitude self.torsions.data = torch.atan2( torch.sin(self.torsions.data), torch.cos(self.torsions.data) ) if self.segment_positions.numel() > 0: self.segment_positions.data += ( torch.randn_like(self.segment_positions) * magnitude ) if self.secondary_root_positions.numel() > 0: self.secondary_root_positions.data += ( torch.randn_like(self.secondary_root_positions) * magnitude ) return self.forward()
# ===== Freeze/Unfreeze Methods =====
[docs] def fix( self, selection: Union[torch.Tensor, slice, None] = None, freeze_at_current: bool = True ) -> None: """ Fix (freeze) atoms to use fixed xyz coordinates. Parameters ---------- selection : torch.Tensor, slice, or None Boolean mask or indices of atoms to fix. freeze_at_current : bool, optional If True, store current coordinates for selected atoms. """ if selection is None: selection = slice(None) if isinstance(selection, torch.Tensor) and selection.dtype == torch.bool: mask = selection elif isinstance(selection, torch.Tensor): mask = torch.zeros(self.n_atoms, dtype=torch.bool, device=self._device) mask[selection] = True elif isinstance(selection, slice): mask = torch.zeros(self.n_atoms, dtype=torch.bool, device=self._device) mask[selection] = True else: raise TypeError( f"selection must be Tensor, slice, or None, got {type(selection)}" ) if freeze_at_current: with torch.no_grad(): current_xyz = self.forward() self.fixed_xyz[mask] = current_xyz[mask] self.refinable_mask[mask] = False
[docs] def freeze( self, selection: Union[torch.Tensor, slice, None] = None, freeze_at_current: bool = True ) -> None: """Alias for fix().""" self.fix(selection, freeze_at_current)
[docs] def refine( self, selection: Union[torch.Tensor, slice, None] = None, rebuild: bool = True ) -> None: """ Make atoms refinable. Parameters ---------- selection : torch.Tensor, slice, or None Boolean mask or indices of atoms to make refinable. rebuild : bool, optional If True, rebuild internal coordinates from fixed_xyz. """ if selection is None: selection = slice(None) if isinstance(selection, torch.Tensor) and selection.dtype == torch.bool: mask = selection elif isinstance(selection, torch.Tensor): mask = torch.zeros(self.n_atoms, dtype=torch.bool, device=self._device) mask[selection] = True elif isinstance(selection, slice): mask = torch.zeros(self.n_atoms, dtype=torch.bool, device=self._device) mask[selection] = True else: raise TypeError( f"selection must be Tensor, slice, or None, got {type(selection)}" ) if rebuild: self._update_internal_coords_from_xyz(self.fixed_xyz, mask) self.refinable_mask[mask] = True
[docs] def unfreeze( self, selection: Union[torch.Tensor, slice, None] = None, rebuild: bool = True ) -> None: """Alias for refine().""" self.refine(selection, rebuild)
[docs] def fix_all(self, freeze_at_current: bool = True) -> None: """Fix all atoms.""" self.fix(None, freeze_at_current)
[docs] def freeze_all(self, freeze_at_current: bool = True) -> None: """Alias for fix_all().""" self.fix_all(freeze_at_current)
[docs] def refine_all(self, rebuild: bool = True) -> None: """Make all atoms refinable.""" self.refine(None, rebuild)
[docs] def unfreeze_all(self, rebuild: bool = True) -> None: """Alias for refine_all().""" self.refine_all(rebuild)
def _update_internal_coords_from_xyz( self, xyz: torch.Tensor, mask: torch.Tensor ) -> None: """ Update internal coordinate parameters from xyz (vectorized). Parameters ---------- xyz : torch.Tensor Atomic coordinates. mask : torch.Tensor Boolean mask indicating which atoms to update. """ # Update bond lengths (vectorized) bond_update_mask = mask & (self.atom_depth >= 1) if bond_update_mask.any(): child_xyz = xyz[bond_update_mask] parent_xyz = xyz[self.parent_indices[bond_update_mask]] new_bonds = torch.linalg.norm(child_xyz - parent_xyz, dim=-1) bond_idx = self.bond_param_indices[bond_update_mask] with torch.no_grad(): self.bond_lengths.data[bond_idx] = new_bonds # Update angles (vectorized) angle_update_mask = mask & (self.atom_depth >= 2) if angle_update_mask.any(): child_xyz = xyz[angle_update_mask] parent_xyz = xyz[self.parent_indices[angle_update_mask]] grandparent_xyz = xyz[self.grandparent_indices[angle_update_mask]] v1 = child_xyz - parent_xyz v2 = grandparent_xyz - parent_xyz cos_angles = torch.sum(v1 * v2, dim=-1) / ( torch.linalg.norm(v1, dim=-1) * torch.linalg.norm(v2, dim=-1) + 1e-10 ) cos_angles = torch.clamp(cos_angles, -1.0, 1.0) new_angles = torch.acos(cos_angles) angle_idx = self.angle_param_indices[angle_update_mask] with torch.no_grad(): self.angles.data[angle_idx] = new_angles # Update torsions (vectorized) torsion_update_mask = mask & (self.atom_depth >= 3) if torsion_update_mask.any(): child_xyz = xyz[torsion_update_mask] parent_xyz = xyz[self.parent_indices[torsion_update_mask]] grandparent_xyz = xyz[self.grandparent_indices[torsion_update_mask]] great_grandparent_xyz = xyz[ self.great_grandparent_indices[torsion_update_mask] ] new_torsions = self._compute_torsion_angles( great_grandparent_xyz, grandparent_xyz, parent_xyz, child_xyz ) torsion_idx = self.torsion_param_indices[torsion_update_mask] with torch.no_grad(): self.torsions.data[torsion_idx] = new_torsions # Update segment positions (vectorized where possible) root_in_mask = 
mask[self.segment_roots] if root_in_mask.any(): with torch.no_grad(): self.segment_positions.data[root_in_mask] = xyz[ self.segment_roots[root_in_mask] ] # Update secondary root positions if self.secondary_root_indices.numel() > 0: sec_root_in_mask = mask[self.secondary_root_indices] if sec_root_in_mask.any(): with torch.no_grad(): self.secondary_root_positions.data[sec_root_in_mask] = xyz[ self.secondary_root_indices[sec_root_in_mask] ] @property def n_refinable(self) -> int: """Return the number of refinable atoms.""" return self.refinable_mask.sum().item() @property def n_fixed(self) -> int: """Return the number of fixed atoms.""" return (~self.refinable_mask).sum().item() def __repr__(self) -> str: n_secondary = self.secondary_root_indices.numel() return ( f"SegmentedInternalCoordinateTensor(n_atoms={self.n_atoms}, " f"n_segments={self.n_segments}, n_aa_per_segment={self.n_aa_per_segment}, " f"n_secondary_roots={n_secondary}, " f"n_refinable={self.n_refinable}, n_fixed={self.n_fixed}, " f"n_bonds={self.n_bonds}, n_angles={self.n_angles}, " f"n_torsions={self.n_torsions}, n_rings={self.n_rings}, " f"max_depth={self.max_depth}, device={self._device})" )