"""
Closed segmented internal coordinate parametrization for atomic structures.
This module provides the ClosedSegmentedInternalCoordinateTensor class which
extends the segmented approach with analytical chain closure. Between larger
segments (~18 residues each), 3-residue "junctions" are placed whose backbone
torsions (phi, psi) are slave DOFs solved by Newton's method. Backward
gradients flow via the Implicit Function Theorem, giving exact gradients
without differentiating through the solver.
Key differences from SegmentedInternalCoordinateTensor:
- Larger segments (default 18 residues vs 3)
- Junction residues between segments maintain chain continuity
- Junction phi/psi are solved, not free parameters
- IFT provides exact gradients through the closure constraint
- Junctions preferentially placed in loop regions
"""
import logging
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torchref.base.alignment.rotation import rotation_matrix_euler_zyz
from torchref.base.chain_closure import (
identify_backbone_atoms,
get_chain_residues,
compute_backbone_torsions,
estimate_secondary_structure,
plan_junction_placement,
get_junction_backbone_indices,
backbone_fk_junction,
closure_residual,
JunctionSolver,
AA_NAMES,
)
from torchref.utils.caching import CachedForwardMixin
from torchref.utils.device_mixin import DeviceMixin
logger = logging.getLogger(__name__)
[docs]
class ClosedSegmentedInternalCoordinateTensor(DeviceMixin, CachedForwardMixin, nn.Module):
"""
Parameter wrapper using segmented internal coordinates with chain closure.
Stores: per-segment bond_lengths, angles, torsions, segment_positions,
segment_orientations. Junction backbone torsions are slave DOFs solved
by Newton's method with IFT gradients.
Parameters
----------
initial_xyz : torch.Tensor
Initial Cartesian coordinates of shape (N, 3).
pdb : pd.DataFrame
PDB DataFrame with columns 'chainid', 'resseq', 'name', 'index', 'resname'.
n_aa_per_segment : int, optional
Number of amino acids per segment. Default is 18.
junction_size : int, optional
Number of residues per junction. Default is 3.
bond_cutoff : float, optional
Distance cutoff for bond detection in Angstroms. Default is 2.0.
cif_dict : dict, optional
CIF dictionary containing bond definitions per residue type.
prefer_loops : bool, optional
If True, slide junctions to prefer loop regions. Default is True.
requires_grad : bool, optional
Whether parameters should have gradients. Default is True.
dtype : torch.dtype, optional
Data type for tensors.
device : torch.device, optional
Device for tensors.
"""
[docs]
def __init__(
    self,
    initial_xyz: torch.Tensor,
    pdb: pd.DataFrame,
    n_aa_per_segment: int = 18,
    junction_size: int = 3,
    bond_cutoff: float = 2.0,
    cif_dict: Optional[dict] = None,
    prefer_loops: bool = True,
    requires_grad: bool = True,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
):
    """
    Store the configuration and run the six initialization stages.

    ``dtype`` and ``device`` fall back to those of ``initial_xyz`` when
    left as None; the coordinates are converted to the chosen dtype/device
    before being handed to the stages below. See the class docstring for
    the meaning of each parameter.
    """
    super().__init__()

    # Resolve dtype/device first: every later stage allocates with them.
    self._dtype = initial_xyz.dtype if dtype is None else dtype
    self._device = initial_xyz.device if device is None else device

    # Plain configuration attributes.
    self.pdb = pdb
    self.cif_dict = cif_dict
    self.n_atoms = initial_xyz.shape[0]
    self.n_aa_per_segment = n_aa_per_segment
    self.junction_size = junction_size
    self.bond_cutoff = bond_cutoff
    self.prefer_loops = prefer_loops

    coords = initial_xyz.to(dtype=self._dtype, device=self._device)

    # Stage 1: identify backbone atoms and plan segments/junctions.
    self._identify_backbone_and_plan(pdb, coords)
    # Stage 2: assign atoms to segments and pick one root atom per segment.
    self._build_segments_fast(pdb)
    # Stage 3: build BFS spanning trees for all segments.
    self._build_segment_trees_parallel(coords)
    # Stage 4: extract the fixed backbone geometry of the junctions.
    self._extract_junction_geometry(coords)
    # Stage 5: extract the internal coordinates.
    self._extract_internal_coords(coords, requires_grad)
    # Stage 6: initialize the junction solver.
    self._init_junction_solver(coords)
@property
def dtype(self) -> torch.dtype:
    """Return the tensor dtype stored in ``self._dtype`` by ``__init__``."""
    return self._dtype
@property
def device(self) -> torch.device:
    """Return the tensor device stored in ``self._device`` by ``__init__``."""
    return self._device
# =========================================================================
# Initialization methods
# =========================================================================
def _identify_backbone_and_plan(
    self, pdb: pd.DataFrame, xyz: torch.Tensor
) -> None:
    """
    Locate backbone atoms, plan segments/junctions and index junction atoms.

    Sets ``_backbone_map``, ``_chain_residues``, ``_planned_segments``,
    ``_planned_junctions``, ``n_junctions``, ``_junction_data`` and the three
    junction lookup sets, and registers the ``junction_atom_indices`` and
    ``junction_backbone_indices`` buffers (sorted, dtype long).
    """
    self._backbone_map = identify_backbone_atoms(pdb)
    self._chain_residues = get_chain_residues(pdb)

    # Torsions are only an input to the secondary-structure estimate, which
    # is itself only requested when loops are to be preferred.
    backbone_torsions = compute_backbone_torsions(
        xyz.detach(), self._backbone_map, self._chain_residues
    )
    secondary_structure = None
    if self.prefer_loops:
        secondary_structure = estimate_secondary_structure(backbone_torsions)

    planned_segments, planned_junctions = plan_junction_placement(
        self._chain_residues,
        self._backbone_map,
        n_aa_per_segment=self.n_aa_per_segment,
        junction_size=self.junction_size,
        ss=secondary_structure,
        prefer_loops=self.prefer_loops,
    )
    self._planned_segments = planned_segments
    self._planned_junctions = planned_junctions
    self.n_junctions = len(planned_junctions)

    all_atom_ids = set()       # every atom of a junction residue
    backbone_atom_ids = set()  # N / CA / C atoms of junction residues
    residue_keys = set()       # (chainid, resseq) keys of junction residues
    junction_records = []      # one dict per junction

    for residues in planned_junctions:
        junction_records.append(
            {
                "residues": residues,
                # List of {N: idx, CA: idx, C: idx}
                "backbone": get_junction_backbone_indices(
                    residues, self._backbone_map
                ),
            }
        )
        for residue_key in residues:
            residue_keys.add(residue_key)
            backbone = self._backbone_map[residue_key]
            backbone_atom_ids.update(
                backbone[atom_name] for atom_name in ("N", "CA", "C")
            )
            # Side-chain atoms included: take every row of this residue.
            chain_id, res_seq = residue_key
            in_residue = (pdb["chainid"] == chain_id) & (pdb["resseq"] == res_seq)
            all_atom_ids.update(
                int(atom_id) for atom_id in pdb.loc[in_residue, "index"].values
            )

    self._junction_atom_set = all_atom_ids
    self._junction_backbone_set = backbone_atom_ids
    self._junction_residue_keys = residue_keys
    self._junction_data = junction_records

    # Both buffers are left empty when no junction atom was found at all.
    atom_ids = sorted(all_atom_ids)
    backbone_ids = sorted(backbone_atom_ids) if all_atom_ids else []
    self.register_buffer(
        "junction_atom_indices",
        torch.tensor(atom_ids, dtype=torch.long, device=self._device),
    )
    self.register_buffer(
        "junction_backbone_indices",
        torch.tensor(backbone_ids, dtype=torch.long, device=self._device),
    )
def _build_segments_fast(self, pdb: pd.DataFrame) -> None:
    """
    Group atoms into segments based on the planned segmentation.

    Segments are emitted in this order: planned (free-DOF) segments,
    junction segments, then one segment per remaining residue (residues
    whose name is not in ``AA_NAMES``, or amino-acid residues that belong to
    no planned segment or junction). Planned segments/junctions that
    contribute no atoms are skipped.

    The root of a planned segment or junction is the N atom of its first
    residue, falling back to its CA atom, then to the first collected atom.
    The root of a leftover residue is its first atom.

    Sets ``n_segments``, ``_segment_types``, ``_segments`` and
    ``_junction_to_segment``, and registers the buffers ``atom_to_segment``,
    ``segment_atom_indices`` (padded with -1), ``segment_sizes`` and
    ``segment_roots``.
    """
    device = self._device
    stripped_names = pdb["name"].str.strip()

    def atom_lookup(atom_name):
        """Map (chainid, resseq) -> atom index for atoms named ``atom_name``."""
        rows = pdb.loc[stripped_names == atom_name]
        keys = zip(rows["chainid"].values, rows["resseq"].values)
        return dict(zip(keys, rows["index"].values))

    n_lookup = atom_lookup("N")
    ca_lookup = atom_lookup("CA")

    # One pass over the residues: atom lists and residue names, in groupby
    # order (dicts keep insertion order, which fixes the leftover ordering).
    residue_atoms = {}
    residue_names = {}
    for residue_key, group in pdb.groupby(["chainid", "resseq"]):
        residue_atoms[residue_key] = group["index"].values.tolist()
        residue_names[residue_key] = group["resname"].iloc[0]

    segments = []
    segment_root_atoms = []
    segment_types = []  # 'segment', 'junction', 'non_protein'
    protein_atoms = set()  # atoms claimed by a planned segment or junction

    def add_planned(residues, kind):
        """Append the atoms of ``residues`` as one segment; False if empty."""
        atoms = [
            atom
            for res_key in residues
            if res_key in residue_atoms
            for atom in residue_atoms[res_key]
        ]
        if not atoms:
            return False
        first_res = residues[0]
        if first_res in n_lookup:
            root = n_lookup[first_res]
        elif first_res in ca_lookup:
            root = ca_lookup[first_res]
        else:
            root = atoms[0]
        segments.append(atoms)
        segment_root_atoms.append(root)
        segment_types.append(kind)
        protein_atoms.update(atoms)
        return True

    # Planned segments (free DOF segments)
    for seg_residues in self._planned_segments:
        add_planned(seg_residues, "segment")

    # Junction segments; remember which segment slot each junction received.
    self._junction_to_segment = {}
    for junc_idx, junc_residues in enumerate(self._planned_junctions):
        slot = len(segments)
        if add_planned(junc_residues, "junction"):
            self._junction_to_segment[junc_idx] = slot

    # Leftover residues: each becomes its own segment rooted at its first atom.
    for residue_key, atoms in residue_atoms.items():
        is_amino_acid = residue_names[residue_key].upper() in AA_NAMES
        if not is_amino_acid or protein_atoms.isdisjoint(atoms):
            segments.append(atoms)
            segment_root_atoms.append(atoms[0])
            segment_types.append("non_protein")

    self.n_segments = len(segments)
    self._segment_types = segment_types

    # Mapping tensors; -1 marks "no segment" / padding.
    atom_to_segment = torch.full(
        (self.n_atoms,), -1, dtype=torch.long, device=device
    )
    max_segment_size = max((len(seg) for seg in segments), default=0)
    segment_atom_indices = torch.full(
        (self.n_segments, max_segment_size), -1, dtype=torch.long, device=device
    )
    segment_sizes = torch.zeros(self.n_segments, dtype=torch.long, device=device)
    for seg_idx, atom_indices in enumerate(segments):
        members = torch.tensor(atom_indices, dtype=torch.long, device=device)
        segment_sizes[seg_idx] = len(atom_indices)
        segment_atom_indices[seg_idx, : len(atom_indices)] = members
        atom_to_segment[members] = seg_idx

    self.register_buffer("atom_to_segment", atom_to_segment)
    self.register_buffer("segment_atom_indices", segment_atom_indices)
    self.register_buffer("segment_sizes", segment_sizes)
    self.register_buffer(
        "segment_roots",
        torch.tensor(segment_root_atoms, dtype=torch.long, device=device),
    )
    self._segments = segments
def _build_segment_trees_parallel(self, xyz: torch.Tensor) -> None:
"""
Build spanning trees for all segments using fully parallel multi-source BFS.
Identical to SegmentedInternalCoordinateTensor's implementation.
"""
device = self._device
n_atoms = self.n_atoms
parent_indices = torch.full(
(n_atoms,), -1, dtype=torch.long, device=device
)
grandparent_indices = torch.full(
(n_atoms,), -1, dtype=torch.long, device=device
)
great_grandparent_indices = torch.full(
(n_atoms,), -1, dtype=torch.long, device=device
)
atom_depth = torch.full((n_atoms,), -1, dtype=torch.long, device=device)
# Build molecular graph
if self.cif_dict is not None:
adjacency = self._build_adjacency_from_cif(xyz)
else:
distances = torch.cdist(xyz, xyz)
adjacency = (distances < self.bond_cutoff) & (distances > 0.1)
# Build compact neighbor list
neighbor_counts = adjacency.sum(dim=1).to(torch.long)
max_neighbors = neighbor_counts.max().item()
edge_src, edge_tgt = torch.where(adjacency)
sorted_indices = torch.argsort(edge_src, stable=True)
edge_src_sorted = edge_src[sorted_indices]
edge_tgt_sorted = edge_tgt[sorted_indices]
neighbors_padded = torch.full(
(n_atoms, max_neighbors), -1, dtype=torch.long, device=device
)
if edge_src_sorted.numel() > 0:
offsets = torch.zeros(n_atoms + 1, dtype=torch.long, device=device)
offsets[1:] = torch.cumsum(neighbor_counts, dim=0)
edge_global_idx = torch.arange(len(edge_src_sorted), device=device)
pos_in_group = edge_global_idx - offsets[edge_src_sorted]
neighbors_padded[edge_src_sorted, pos_in_group] = edge_tgt_sorted
# Parallel BFS from all roots
atom_depth[self.segment_roots] = 0
current_frontier = self.segment_roots.clone()
current_depth = 0
max_depth = 0
while current_frontier.numel() > 0:
frontier_neighbors = neighbors_padded[current_frontier]
candidates = frontier_neighbors.reshape(-1)
frontier_expanded = current_frontier.unsqueeze(1).expand_as(
frontier_neighbors
)
parents = frontier_expanded.reshape(-1)
valid_mask = candidates >= 0
candidates = candidates[valid_mask]
parents = parents[valid_mask]
if candidates.numel() == 0:
break
unvisited_mask = atom_depth[candidates] == -1
candidates = candidates[unvisited_mask]
parents = parents[unvisited_mask]
if candidates.numel() == 0:
break
candidate_segments = self.atom_to_segment[candidates]
parent_segments = self.atom_to_segment[parents]
same_segment_mask = candidate_segments == parent_segments
candidates = candidates[same_segment_mask]
parents = parents[same_segment_mask]
if candidates.numel() == 0:
break
unique_candidates, inverse_indices = torch.unique(
candidates, return_inverse=True
)
first_occurrence = torch.zeros(
unique_candidates.numel(), dtype=torch.long, device=device
)
first_occurrence.scatter_(0, inverse_indices, parents)
new_depth = current_depth + 1
atom_depth[unique_candidates] = new_depth
parent_indices[unique_candidates] = first_occurrence
parent_of_new = first_occurrence
gp_of_new = parent_indices[parent_of_new]
has_gp = gp_of_new >= 0
grandparent_indices[unique_candidates[has_gp]] = gp_of_new[has_gp]
ggp_of_new = grandparent_indices[parent_of_new]
has_ggp = ggp_of_new >= 0
great_grandparent_indices[unique_candidates[has_ggp]] = ggp_of_new[
has_ggp
]
max_depth = new_depth
current_frontier = unique_candidates
current_depth = new_depth
# Handle disconnected components
max_component_iterations = 100
for _ in range(max_component_iterations):
orphan_mask = (atom_depth == -1) & (self.atom_to_segment >= 0)
if not orphan_mask.any():
break
orphan_indices = torch.where(orphan_mask)[0]
orphan_segments = self.atom_to_segment[orphan_indices]
unique_orphan_segments = torch.unique(orphan_segments)
new_roots = []
for seg in unique_orphan_segments:
seg_orphans = orphan_indices[orphan_segments == seg]
if seg_orphans.numel() > 0:
new_roots.append(seg_orphans[0].item())
if not new_roots:
break
new_roots = torch.tensor(new_roots, dtype=torch.long, device=device)
atom_depth[new_roots] = 0
parent_indices[new_roots] = -1
current_frontier = new_roots
while current_frontier.numel() > 0:
frontier_neighbors = neighbors_padded[current_frontier]
candidates = frontier_neighbors.reshape(-1)
frontier_expanded = current_frontier.unsqueeze(1).expand_as(
frontier_neighbors
)
parents = frontier_expanded.reshape(-1)
valid_mask = candidates >= 0
candidates = candidates[valid_mask]
parents = parents[valid_mask]
if candidates.numel() == 0:
break
unvisited_mask = atom_depth[candidates] == -1
candidates = candidates[unvisited_mask]
parents = parents[unvisited_mask]
if candidates.numel() == 0:
break
candidate_segments = self.atom_to_segment[candidates]
parent_segments = self.atom_to_segment[parents]
same_segment_mask = candidate_segments == parent_segments
candidates = candidates[same_segment_mask]
parents = parents[same_segment_mask]
if candidates.numel() == 0:
break
unique_candidates, inverse_indices = torch.unique(
candidates, return_inverse=True
)
first_occurrence = torch.zeros(
unique_candidates.numel(), dtype=torch.long, device=device
)
first_occurrence.scatter_(0, inverse_indices, parents)
parent_depths = atom_depth[first_occurrence]
new_depths = parent_depths + 1
atom_depth[unique_candidates] = new_depths
parent_indices[unique_candidates] = first_occurrence
parent_of_new = first_occurrence
gp_of_new = parent_indices[parent_of_new]
has_gp = gp_of_new >= 0
grandparent_indices[unique_candidates[has_gp]] = gp_of_new[has_gp]
ggp_of_new = grandparent_indices[parent_of_new]
has_ggp = ggp_of_new >= 0
great_grandparent_indices[unique_candidates[has_ggp]] = ggp_of_new[
has_ggp
]
max_depth = max(max_depth, new_depths.max().item())
current_frontier = unique_candidates
# Identify secondary roots
depth0_mask = atom_depth == 0
is_primary_root = torch.zeros(n_atoms, dtype=torch.bool, device=device)
is_primary_root[self.segment_roots] = True
secondary_root_mask = depth0_mask & ~is_primary_root
secondary_root_indices = torch.where(secondary_root_mask)[0]
self.register_buffer("secondary_root_indices", secondary_root_indices)
self.max_depth = max_depth
self.register_buffer("parent_indices", parent_indices)
self.register_buffer("grandparent_indices", grandparent_indices)
self.register_buffer("great_grandparent_indices", great_grandparent_indices)
self.register_buffer("atom_depth", atom_depth)
self._build_param_indices_vectorized()
self._build_depth_indices()
self._detect_rings_vectorized(adjacency)
def _build_adjacency_from_cif(self, xyz: torch.Tensor) -> torch.Tensor:
"""Build adjacency matrix from CIF dictionary (same as parent class)."""
device = self._device
n_atoms = self.n_atoms
pdb = self.pdb
bond_idx1_list = []
bond_idx2_list = []
pdb["name_stripped"] = pdb["name"].str.strip()
residue_groups = pdb.groupby(["chainid", "resseq"])
atoms_with_cif_bonds = set()
for (chainid, resseq), group in residue_groups:
resname = group["resname"].iloc[0]
if resname not in self.cif_dict:
continue
cif_residue = self.cif_dict[resname]
if "bonds" not in cif_residue:
continue
cif_bonds = cif_residue["bonds"]
name_to_idx = dict(zip(group["name_stripped"], group["index"]))
atom1_names = cif_bonds["atom1"].str.strip().values
atom2_names = cif_bonds["atom2"].str.strip().values
for a1, a2 in zip(atom1_names, atom2_names):
if a1 in name_to_idx and a2 in name_to_idx:
idx1, idx2 = name_to_idx[a1], name_to_idx[a2]
bond_idx1_list.append(idx1)
bond_idx2_list.append(idx2)
atoms_with_cif_bonds.add(idx1)
atoms_with_cif_bonds.add(idx2)
# Peptide bonds
c_atoms = pdb[pdb["name_stripped"] == "C"].set_index(
["chainid", "resseq"]
)["index"]
n_atoms_df = pdb[pdb["name_stripped"] == "N"].set_index(
["chainid", "resseq"]
)["index"]
for chainid in pdb["chainid"].unique():
chain = pdb[pdb["chainid"] == chainid]
chain_aa = chain[chain["resname"].isin(AA_NAMES)]
resseqs = sorted(chain_aa["resseq"].unique())
for i in range(len(resseqs) - 1):
r1, r2 = resseqs[i], resseqs[i + 1]
if r2 - r1 != 1:
continue
try:
idx_c = c_atoms.loc[(chainid, r1)]
idx_n = n_atoms_df.loc[(chainid, r2)]
bond_idx1_list.append(idx_c)
bond_idx2_list.append(idx_n)
atoms_with_cif_bonds.add(idx_c)
atoms_with_cif_bonds.add(idx_n)
except KeyError:
continue
adjacency = torch.zeros(n_atoms, n_atoms, dtype=torch.bool, device=device)
if bond_idx1_list:
idx1 = torch.tensor(bond_idx1_list, dtype=torch.long, device=device)
idx2 = torch.tensor(bond_idx2_list, dtype=torch.long, device=device)
adjacency[idx1, idx2] = True
adjacency[idx2, idx1] = True
# Fallback for atoms without CIF bonds
atoms_without_cif = set(range(n_atoms)) - atoms_with_cif_bonds
if atoms_without_cif:
atoms_list = sorted(atoms_without_cif)
atoms_tensor = torch.tensor(atoms_list, dtype=torch.long, device=device)
xyz_subset = xyz[atoms_tensor]
distances_subset = torch.cdist(xyz_subset, xyz)
bond_mask = (distances_subset < self.bond_cutoff) & (
distances_subset > 0.1
)
rows, cols = torch.where(bond_mask)
adjacency[atoms_tensor[rows], cols] = True
adjacency[cols, atoms_tensor[rows]] = True
pdb.drop(columns=["name_stripped"], inplace=True, errors="ignore")
return adjacency
def _build_param_indices_vectorized(self) -> None:
"""Build mappings from atoms to parameter indices."""
device = self._device
has_bond = self.atom_depth >= 1
has_angle = self.atom_depth >= 2
has_torsion = self.atom_depth >= 3
# Build junction masks
is_junction_backbone = torch.zeros(
self.n_atoms, dtype=torch.bool, device=device
)
if self.junction_backbone_indices.numel() > 0:
is_junction_backbone[self.junction_backbone_indices] = True
is_junction_atom = torch.zeros(
self.n_atoms, dtype=torch.bool, device=device
)
if self.junction_atom_indices.numel() > 0:
is_junction_atom[self.junction_atom_indices] = True
# For torsions: exclude junction backbone atoms only (solved by Newton).
# Junction sidechain torsions (chi angles) remain free DOFs.
has_free_torsion = has_torsion & ~is_junction_backbone
self.n_bonds = has_bond.sum().item()
self.n_angles = has_angle.sum().item()
self.n_torsions = has_free_torsion.sum().item()
self.n_all_torsions = has_torsion.sum().item()
bond_param_indices = torch.full(
(self.n_atoms,), -1, dtype=torch.long, device=device
)
angle_param_indices = torch.full(
(self.n_atoms,), -1, dtype=torch.long, device=device
)
torsion_param_indices = torch.full(
(self.n_atoms,), -1, dtype=torch.long, device=device
)
if has_bond.any():
bond_cumsum = torch.cumsum(has_bond.long(), dim=0) - 1
bond_param_indices[has_bond] = bond_cumsum[has_bond]
if has_angle.any():
angle_cumsum = torch.cumsum(has_angle.long(), dim=0) - 1
angle_param_indices[has_angle] = angle_cumsum[has_angle]
if has_free_torsion.any():
torsion_cumsum = torch.cumsum(has_free_torsion.long(), dim=0) - 1
torsion_param_indices[has_free_torsion] = torsion_cumsum[has_free_torsion]
self.register_buffer("bond_param_indices", bond_param_indices)
self.register_buffer("angle_param_indices", angle_param_indices)
self.register_buffer("torsion_param_indices", torsion_param_indices)
self.register_buffer("is_junction_backbone", is_junction_backbone)
self.register_buffer("is_junction_atom", is_junction_atom)
def _build_depth_indices(self) -> None:
"""Pre-compute atom indices for each depth level.
Builds two sets of indices:
1. Regular depth indices — excludes ALL junction atoms
2. Junction sidechain depth indices — only junction non-backbone atoms,
placed after the junction solver provides backbone positions
"""
device = self._device
empty = torch.tensor([], dtype=torch.long, device=device)
# --- Regular (non-junction) depth-3+ indices ---
depth_atom_indices = []
depth_parent_indices = []
depth_gp_indices = []
depth_ggp_indices = []
depth_bond_param_indices = []
depth_angle_param_indices = []
depth_torsion_param_indices = []
for d in range(3, self.max_depth + 1):
# Exclude ALL junction atoms from regular placement
mask = (self.atom_depth == d) & ~self.is_junction_atom
if mask.any():
atom_idx = torch.where(mask)[0]
depth_atom_indices.append(atom_idx)
depth_parent_indices.append(self.parent_indices[atom_idx])
depth_gp_indices.append(self.grandparent_indices[atom_idx])
depth_ggp_indices.append(
self.great_grandparent_indices[atom_idx]
)
depth_bond_param_indices.append(self.bond_param_indices[atom_idx])
depth_angle_param_indices.append(
self.angle_param_indices[atom_idx]
)
depth_torsion_param_indices.append(
self.torsion_param_indices[atom_idx]
)
else:
depth_atom_indices.append(empty.clone())
depth_parent_indices.append(empty.clone())
depth_gp_indices.append(empty.clone())
depth_ggp_indices.append(empty.clone())
depth_bond_param_indices.append(empty.clone())
depth_angle_param_indices.append(empty.clone())
depth_torsion_param_indices.append(empty.clone())
self._depth_atom_indices = depth_atom_indices
self._depth_parent_indices = depth_parent_indices
self._depth_gp_indices = depth_gp_indices
self._depth_ggp_indices = depth_ggp_indices
self._depth_bond_param_indices = depth_bond_param_indices
self._depth_angle_param_indices = depth_angle_param_indices
self._depth_torsion_param_indices = depth_torsion_param_indices
# --- Junction sidechain depth indices ---
# These atoms are placed AFTER the junction solver provides backbone positions.
# We need depth-ordered placement for junction non-backbone atoms.
jsc_depth_atom_indices = []
jsc_depth_parent_indices = []
jsc_depth_gp_indices = []
jsc_depth_ggp_indices = []
jsc_depth_bond_param_indices = []
jsc_depth_angle_param_indices = []
# Junction sidechain atoms: in junction but not backbone
is_jsc = self.is_junction_atom & ~self.is_junction_backbone
for d in range(1, self.max_depth + 1):
mask = (self.atom_depth == d) & is_jsc
if mask.any():
atom_idx = torch.where(mask)[0]
jsc_depth_atom_indices.append(atom_idx)
jsc_depth_parent_indices.append(self.parent_indices[atom_idx])
jsc_depth_gp_indices.append(self.grandparent_indices[atom_idx])
jsc_depth_ggp_indices.append(
self.great_grandparent_indices[atom_idx]
)
jsc_depth_bond_param_indices.append(self.bond_param_indices[atom_idx])
jsc_depth_angle_param_indices.append(self.angle_param_indices[atom_idx])
else:
jsc_depth_atom_indices.append(empty.clone())
jsc_depth_parent_indices.append(empty.clone())
jsc_depth_gp_indices.append(empty.clone())
jsc_depth_ggp_indices.append(empty.clone())
jsc_depth_bond_param_indices.append(empty.clone())
jsc_depth_angle_param_indices.append(empty.clone())
self._jsc_depth_atom_indices = jsc_depth_atom_indices
self._jsc_depth_parent_indices = jsc_depth_parent_indices
self._jsc_depth_gp_indices = jsc_depth_gp_indices
self._jsc_depth_ggp_indices = jsc_depth_ggp_indices
self._jsc_depth_bond_param_indices = jsc_depth_bond_param_indices
self._jsc_depth_angle_param_indices = jsc_depth_angle_param_indices
self._jsc_min_depth = 1 # starting depth for junction sidechain placement
def _detect_rings_vectorized(self, adjacency: torch.Tensor) -> None:
"""Detect rings in the molecular graph (same as parent class)."""
device = self._device
n_atoms = self.n_atoms
ring_member_mask = torch.zeros(n_atoms, dtype=torch.bool, device=device)
ring_group_id = torch.full(
(n_atoms,), -1, dtype=torch.long, device=device
)
edge_i, edge_j = torch.where(adjacency.triu(diagonal=1))
is_tree_edge = (self.parent_indices[edge_i] == edge_j) | (
self.parent_indices[edge_j] == edge_i
)
same_segment = (
self.atom_to_segment[edge_i] == self.atom_to_segment[edge_j]
)
back_edge_mask = ~is_tree_edge & same_segment
back_edge_i = edge_i[back_edge_mask]
back_edge_j = edge_j[back_edge_mask]
individual_rings = []
for idx in range(len(back_edge_i)):
i = back_edge_i[idx].item()
j = back_edge_j[idx].item()
ancestors_i = set()
current = i
while current >= 0:
ancestors_i.add(current)
current = self.parent_indices[current].item()
ring_atoms = []
current = j
common_ancestor = None
while current >= 0:
ring_atoms.append(current)
if current in ancestors_i:
common_ancestor = current
break
current = self.parent_indices[current].item()
if common_ancestor is None:
continue
current = i
while current != common_ancestor:
ring_atoms.append(current)
current = self.parent_indices[current].item()
individual_rings.append(set(ring_atoms))
ring_systems = self._merge_overlapping_rings(individual_rings)
ring_anchors = []
ring_members_list = []
for ring_idx, ring_atoms_set in enumerate(ring_systems):
ring_atoms = list(ring_atoms_set)
for atom in ring_atoms:
ring_member_mask[atom] = True
ring_group_id[atom] = ring_idx
depths = self.atom_depth[ring_atoms]
anchor_idx = ring_atoms[depths.argmin().item()]
ring_anchors.append(anchor_idx)
ring_members_list.append(ring_atoms)
self.register_buffer("ring_member_mask", ring_member_mask)
self.register_buffer("ring_group_id", ring_group_id)
self.n_rings = len(ring_anchors)
if self.n_rings > 0:
self.register_buffer(
"ring_anchor_atoms",
torch.tensor(ring_anchors, dtype=torch.long, device=device),
)
max_ring_size = max(len(r) for r in ring_members_list)
ring_members_tensor = torch.full(
(self.n_rings, max_ring_size), -1, dtype=torch.long, device=device
)
ring_sizes = torch.zeros(
self.n_rings, dtype=torch.long, device=device
)
for i, members in enumerate(ring_members_list):
ring_sizes[i] = len(members)
ring_members_tensor[i, : len(members)] = torch.tensor(
members, dtype=torch.long, device=device
)
self.register_buffer("ring_members", ring_members_tensor)
self.register_buffer("ring_sizes", ring_sizes)
else:
self.register_buffer(
"ring_anchor_atoms",
torch.tensor([], dtype=torch.long, device=device),
)
self.register_buffer(
"ring_members",
torch.tensor([], dtype=torch.long, device=device).reshape(0, 0),
)
self.register_buffer(
"ring_sizes",
torch.tensor([], dtype=torch.long, device=device),
)
def _merge_overlapping_rings(self, rings: list) -> list:
"""Merge overlapping rings using union-find."""
if not rings:
return []
n_rings = len(rings)
parent = list(range(n_rings))
rank = [0] * n_rings
def find(x):
if parent[x] != x:
parent[x] = find(parent[x])
return parent[x]
def union(x, y):
px, py = find(x), find(y)
if px == py:
return
if rank[px] < rank[py]:
px, py = py, px
parent[py] = px
if rank[px] == rank[py]:
rank[px] += 1
for i in range(n_rings):
for j in range(i + 1, n_rings):
if rings[i] & rings[j]:
union(i, j)
groups = defaultdict(list)
for i in range(n_rings):
groups[find(i)].append(i)
ring_systems = []
for ring_indices in groups.values():
merged = set()
for ri in ring_indices:
merged |= rings[ri]
ring_systems.append(merged)
return ring_systems
def _extract_junction_geometry(self, xyz: torch.Tensor) -> None:
    """
    Extract fixed backbone geometry for junctions.

    Stores bond lengths, bond angles (in NeRF placement order), omega angles,
    psi_prev and the post-junction bond geometry. All are registered as
    buffers (fixed, not parameters), with one row per junction:

    - ``junction_bond_lengths``: per residue [C_prev-N, N-CA, CA-C]
    - ``junction_nerf_angles``: per residue
      [angle(CA_prev, C_prev, N), angle(C_prev, N, CA), angle(N, CA, C)]
    - ``junction_omega``: per residue dihedral(CA_prev, C_prev, N, CA)
    - ``junction_psi_prev``: dihedral(N_prev, CA_prev, C_prev, N_first),
      or 0.0 when the preceding residue is unknown
    - ``junction_post_bond_length`` / ``junction_post_bond_angle``:
      |C_last - N_post| and angle(CA_last, C_last, N_post), or the constants
      1.329 / 2.028 when there is no following residue

    ``_junction_post_n_indices`` holds the global index of the post-junction
    N atom per junction (-1 when unavailable).

    For the first residue of a junction the "previous" atoms come from the
    last residue of the preceding planned segment; when that residue is
    unavailable, placeholder positions offset along -x from the junction's
    first N are used instead.
    """
    device = self._device
    dtype = self._dtype

    if self.n_junctions == 0:
        per_atom = 3 * self.junction_size
        empty_shapes = {
            "junction_bond_lengths": (0, per_atom),
            "junction_nerf_angles": (0, per_atom),
            "junction_omega": (0, self.junction_size),
            "junction_psi_prev": (0,),
            "junction_post_bond_length": (0,),
            "junction_post_bond_angle": (0,),
        }
        for name, shape in empty_shapes.items():
            self.register_buffer(
                name, torch.zeros(*shape, dtype=dtype, device=device)
            )
        self._junction_post_n_indices = []
        return

    def scalar(value: float) -> torch.Tensor:
        return torch.tensor(value, dtype=dtype, device=device)

    def offset(x: float) -> torch.Tensor:
        return torch.tensor([x, 0, 0], dtype=dtype, device=device)

    def finalize(values) -> torch.Tensor:
        return torch.stack(values).to(dtype=dtype, device=device)

    all_bond_lengths = []
    all_nerf_angles = []
    all_omega = []
    all_psi_prev = []
    all_post_bl = []
    all_post_ba = []
    post_n_indices = []  # Global index of post-junction N atom per junction

    for junc_idx, junc_info in enumerate(self._junction_data):
        junc_residues = junc_info["residues"]
        junc_backbone = junc_info["backbone"]

        # Backbone atoms of the residue right before the junction, if known.
        prev_bb = None
        pre_seg_idx = self._find_pre_junction_segment(junc_idx)
        if pre_seg_idx is not None:
            prev_res = self._planned_segments[pre_seg_idx][-1]
            if prev_res in self._backbone_map:
                prev_bb = self._backbone_map[prev_res]

        bl_list = []
        nerf_ang_list = []
        omega_list = []
        prev_ca = prev_c = None
        for res_i in range(len(junc_residues)):
            bb = junc_backbone[res_i]
            n_pos = xyz[bb["N"]]
            ca_pos = xyz[bb["CA"]]
            c_pos = xyz[bb["C"]]

            if res_i == 0:
                if prev_bb is not None:
                    prev_ca = xyz[prev_bb["CA"]]
                    prev_c = xyz[prev_bb["C"]]
                else:
                    # Placeholder predecessor laid out along -x.
                    prev_ca = n_pos - offset(1.3)
                    prev_c = n_pos - offset(0.5)
            # For res_i > 0, prev_ca / prev_c were carried over from the
            # previous junction residue at the end of the last iteration.

            # Bond lengths: C_prev-N, N-CA, CA-C
            bl_list.append(torch.linalg.norm(n_pos - prev_c))
            bl_list.append(torch.linalg.norm(ca_pos - n_pos))
            bl_list.append(torch.linalg.norm(c_pos - ca_pos))
            # NeRF angles at the pivot atoms C_prev, N and CA
            nerf_ang_list.append(_compute_angle(prev_ca, prev_c, n_pos))
            nerf_ang_list.append(_compute_angle(prev_c, n_pos, ca_pos))
            nerf_ang_list.append(_compute_angle(n_pos, ca_pos, c_pos))
            # Omega: dihedral(CA_prev, C_prev, N, CA)
            omega_list.append(_compute_torsion(prev_ca, prev_c, n_pos, ca_pos))

            prev_ca, prev_c = ca_pos, c_pos

        # psi_prev: dihedral(N_prev, CA_prev, C_prev, N_first_junction)
        first_n = xyz[junc_backbone[0]["N"]]
        if prev_bb is not None:
            psi_prev = _compute_torsion(
                xyz[prev_bb["N"]], xyz[prev_bb["CA"]], xyz[prev_bb["C"]], first_n
            )
        else:
            psi_prev = scalar(0.0)

        # Post-junction geometry: C(last_junc) -> N(first residue after it).
        # Defaults are typical peptide-bond values (2.028 rad is ~116 deg).
        last_bb = junc_backbone[-1]
        last_c = xyz[last_bb["C"]]
        last_ca = xyz[last_bb["CA"]]
        post_bl = scalar(1.329)
        post_ba = scalar(2.028)
        post_n_idx = -1
        post_seg_idx = self._find_post_junction_segment(junc_idx)
        if post_seg_idx is not None and post_seg_idx < len(self._planned_segments):
            post_res = self._planned_segments[post_seg_idx][0]
            if post_res in self._backbone_map:
                post_n_idx = self._backbone_map[post_res]["N"]
                post_n_pos = xyz[post_n_idx]
                post_bl = torch.linalg.norm(post_n_pos - last_c)
                post_ba = _compute_angle(last_ca, last_c, post_n_pos)
        post_n_indices.append(post_n_idx)

        all_bond_lengths.append(finalize(bl_list))
        all_nerf_angles.append(finalize(nerf_ang_list))
        all_omega.append(finalize(omega_list))
        all_psi_prev.append(psi_prev.to(dtype=dtype, device=device))
        all_post_bl.append(post_bl.to(dtype=dtype, device=device))
        all_post_ba.append(post_ba.to(dtype=dtype, device=device))

    self.register_buffer("junction_bond_lengths", torch.stack(all_bond_lengths))
    self.register_buffer("junction_nerf_angles", torch.stack(all_nerf_angles))
    self.register_buffer("junction_omega", torch.stack(all_omega))
    self.register_buffer("junction_psi_prev", torch.stack(all_psi_prev))
    self.register_buffer("junction_post_bond_length", torch.stack(all_post_bl))
    self.register_buffer("junction_post_bond_angle", torch.stack(all_post_ba))
    self._junction_post_n_indices = post_n_indices
def _find_pre_junction_segment(self, junc_idx: int) -> Optional[int]:
"""Find the index (into _planned_segments) of the segment before this junction."""
# Junction junc_idx connects segment junc_idx to segment junc_idx+1
if junc_idx < len(self._planned_segments):
return junc_idx
return None
def _find_post_junction_segment(self, junc_idx: int) -> Optional[int]:
"""Find the index (into _planned_segments) of the segment after this junction."""
if junc_idx + 1 < len(self._planned_segments):
return junc_idx + 1
return None
def _extract_internal_coords(
self, xyz: torch.Tensor, requires_grad: bool = True
) -> None:
"""
Extract internal coordinates from Cartesian coordinates.
Junction backbone torsions are excluded from the free torsion parameters.
"""
device = self._device
dtype = self._dtype
# Bond lengths (all atoms with depth >= 1)
bond_mask = self.atom_depth >= 1
if bond_mask.any():
child_xyz = xyz[bond_mask]
parent_xyz = xyz[self.parent_indices[bond_mask]]
bond_lengths = torch.linalg.norm(child_xyz - parent_xyz, dim=-1)
else:
bond_lengths = torch.tensor([], dtype=dtype, device=device)
# Angles (all atoms with depth >= 2)
angle_mask = self.atom_depth >= 2
if angle_mask.any():
child_xyz = xyz[angle_mask]
parent_xyz = xyz[self.parent_indices[angle_mask]]
grandparent_xyz = xyz[self.grandparent_indices[angle_mask]]
v1 = child_xyz - parent_xyz
v2 = grandparent_xyz - parent_xyz
cos_angles = torch.sum(v1 * v2, dim=-1) / (
torch.linalg.norm(v1, dim=-1)
* torch.linalg.norm(v2, dim=-1)
+ 1e-10
)
cos_angles = torch.clamp(cos_angles, -1.0, 1.0)
angles = torch.acos(cos_angles)
else:
angles = torch.tensor([], dtype=dtype, device=device)
# Torsions (depth >= 3, EXCLUDING junction backbone atoms only)
torsion_mask = (self.atom_depth >= 3) & ~self.is_junction_backbone
if torsion_mask.any():
child_xyz = xyz[torsion_mask]
parent_xyz = xyz[self.parent_indices[torsion_mask]]
grandparent_xyz = xyz[self.grandparent_indices[torsion_mask]]
great_grandparent_xyz = xyz[
self.great_grandparent_indices[torsion_mask]
]
torsions = self._compute_torsion_angles(
great_grandparent_xyz, grandparent_xyz, parent_xyz, child_xyz
)
else:
torsions = torch.tensor([], dtype=dtype, device=device)
# Segment positions and orientations
segment_positions = xyz[self.segment_roots].clone()
if self.secondary_root_indices.numel() > 0:
secondary_root_positions = xyz[self.secondary_root_indices].clone()
else:
secondary_root_positions = torch.zeros(0, 3, dtype=dtype, device=device)
segment_orientations = torch.zeros(
self.n_segments, 3, dtype=dtype, device=device
)
# Ring local coordinates
if self.n_rings > 0:
ring_local_coords = self._extract_ring_local_coords_vectorized(xyz)
self.register_buffer("ring_local_coords", ring_local_coords)
else:
self.register_buffer(
"ring_local_coords",
torch.tensor([], dtype=dtype, device=device).reshape(0, 0, 3),
)
# Shallow atom references
self._setup_shallow_atom_references_vectorized(xyz)
# Register parameters (free DOFs)
self.bond_lengths = nn.Parameter(
bond_lengths.clone(), requires_grad=requires_grad
)
self.angles = nn.Parameter(angles.clone(), requires_grad=requires_grad)
self.torsions = nn.Parameter(
torsions.clone(), requires_grad=requires_grad
)
self.segment_positions = nn.Parameter(
segment_positions.clone(), requires_grad=requires_grad
)
self.segment_orientations = nn.Parameter(
segment_orientations.clone(), requires_grad=requires_grad
)
self.secondary_root_positions = nn.Parameter(
secondary_root_positions.clone(), requires_grad=requires_grad
)
# Initialize refinable mask
self.register_buffer(
"refinable_mask",
torch.ones(self.n_atoms, dtype=torch.bool, device=device),
)
self.register_buffer("fixed_xyz", xyz.clone())
def _init_junction_solver(self, xyz: torch.Tensor) -> None:
    """
    Create ``self.junction_solver``, warm-started from initial coordinates.

    With no junctions an empty solver (``(0, 2 * junction_size)`` initial
    phi/psi) is created and the method returns. Otherwise phi/psi for every
    junction residue are read from ``compute_backbone_torsions``; residues
    missing from that result and NaN torsions fall back to ``0.0``. After
    the solver is built, the junction atom mapping is constructed.

    Parameters
    ----------
    xyz : torch.Tensor
        Initial Cartesian coordinates; detached before torsions are measured.
    """
    dtype, device = self._dtype, self._device

    if self.n_junctions == 0:
        self.junction_solver = JunctionSolver(
            n_junctions=0,
            junction_size=self.junction_size,
            initial_phi_psi=torch.zeros(
                0, 2 * self.junction_size, dtype=dtype, device=device
            ),
            dtype=dtype,
            device=device,
        )
        return

    backbone_torsions = compute_backbone_torsions(
        xyz.detach(), self._backbone_map, self._chain_residues
    )

    # One row per junction: [phi_0, psi_0, phi_1, psi_1, ...].
    warm_start_rows = []
    for junction in self._junction_data:
        row = []
        for residue in junction["residues"]:
            if residue not in backbone_torsions:
                row.extend([0.0, 0.0])
                continue
            phi = backbone_torsions[residue]["phi"]
            psi = backbone_torsions[residue]["psi"]
            row.append(0.0 if np.isnan(phi) else phi)
            row.append(0.0 if np.isnan(psi) else psi)
        warm_start_rows.append(torch.tensor(row, dtype=dtype, device=device))

    self.junction_solver = JunctionSolver(
        n_junctions=self.n_junctions,
        junction_size=self.junction_size,
        initial_phi_psi=torch.stack(warm_start_rows),  # (J, 2*junction_size)
        max_iter=50,
        tol=1e-4,
        tikhonov_eps=1e-6,
        dtype=dtype,
        device=device,
    )

    # Map junction backbone atoms onto positions in the solver output.
    self._build_junction_atom_mapping()
def _build_junction_atom_mapping(self) -> None:
"""Build mapping from junction backbone atoms to solver output positions."""
device = self._device
# For each junction, the solver outputs 3*junction_size backbone atoms
# in order: [N0, CA0, C0, N1, CA1, C1, ...]
# We need to map these to global atom indices
junc_global_indices = [] # Global atom index
junc_output_indices = [] # (junction_idx, position_in_output)
for junc_idx, junc_info in enumerate(self._junction_data):
for res_i, res_key in enumerate(junc_info["residues"]):
bb = junc_info["backbone"][res_i]
for atom_i, atom_name in enumerate(["N", "CA", "C"]):
global_idx = bb[atom_name]
output_pos = res_i * 3 + atom_i
junc_global_indices.append(global_idx)
junc_output_indices.append((junc_idx, output_pos))
if junc_global_indices:
self.register_buffer(
"_junc_global_indices",
torch.tensor(junc_global_indices, dtype=torch.long, device=device),
)
self._junc_output_indices = junc_output_indices
else:
self.register_buffer(
"_junc_global_indices",
torch.tensor([], dtype=torch.long, device=device),
)
self._junc_output_indices = []
# Also build mapping for junction sidechain atoms
# These are placed by NeRF from the junction backbone positions
# They're already in the depth indices as regular NeRF atoms
# (unless they're junction backbone atoms, which are excluded)
def _setup_shallow_atom_references_vectorized(
self, xyz: torch.Tensor
) -> None:
"""Store reference directions for depth-1 and depth-2 atoms."""
device = self._device
dtype = self._dtype
# Depth-1 atoms
depth1_mask = self.atom_depth == 1
n_depth1 = depth1_mask.sum().item()
if n_depth1 > 0:
depth1_indices = torch.where(depth1_mask)[0]
parent_idx = self.parent_indices[depth1_indices]
directions = xyz[depth1_indices] - xyz[parent_idx]
norms = torch.linalg.norm(
directions, dim=-1, keepdim=True
).clamp(min=1e-10)
depth1_dirs = directions / norms
depth1_atom_to_dir_idx = torch.full(
(self.n_atoms,), -1, dtype=torch.long, device=device
)
depth1_atom_to_dir_idx[depth1_indices] = torch.arange(
n_depth1, dtype=torch.long, device=device
)
self.register_buffer("depth1_dirs", depth1_dirs)
self.register_buffer(
"depth1_atom_to_dir_idx", depth1_atom_to_dir_idx
)
else:
self.register_buffer(
"depth1_dirs", torch.zeros(0, 3, dtype=dtype, device=device)
)
self.register_buffer(
"depth1_atom_to_dir_idx",
torch.full(
(self.n_atoms,), -1, dtype=torch.long, device=device
),
)
# Depth-2 atoms
depth2_mask = self.atom_depth == 2
n_depth2 = depth2_mask.sum().item()
if n_depth2 > 0:
depth2_indices = torch.where(depth2_mask)[0]
parent_idx = self.parent_indices[depth2_indices]
grandparent_idx = self.grandparent_indices[depth2_indices]
v_bond = xyz[parent_idx] - xyz[grandparent_idx]
v_bond_norm = torch.linalg.norm(
v_bond, dim=-1, keepdim=True
).clamp(min=1e-10)
v_bond = v_bond / v_bond_norm
v_child = xyz[depth2_indices] - xyz[parent_idx]
proj = torch.sum(v_child * v_bond, dim=-1, keepdim=True)
v_perp = v_child - proj * v_bond
v_perp_norm = torch.linalg.norm(v_perp, dim=-1, keepdim=True)
collinear = v_perp_norm.squeeze(-1) < 1e-10
if collinear.any():
default_perp = torch.zeros_like(v_perp)
nearly_x = torch.abs(v_bond[:, 0]) >= 0.9
default_perp[~nearly_x, 0] = 1.0
default_perp[nearly_x, 1] = 1.0
proj_default = torch.sum(
default_perp * v_bond, dim=-1, keepdim=True
)
default_perp = default_perp - proj_default * v_bond
default_perp = default_perp / torch.linalg.norm(
default_perp, dim=-1, keepdim=True
).clamp(min=1e-10)
v_perp[collinear] = default_perp[collinear]
v_perp_norm[collinear] = 1.0
depth2_perps = v_perp / v_perp_norm.clamp(min=1e-10)
depth2_atom_to_perp_idx = torch.full(
(self.n_atoms,), -1, dtype=torch.long, device=device
)
depth2_atom_to_perp_idx[depth2_indices] = torch.arange(
n_depth2, dtype=torch.long, device=device
)
self.register_buffer("depth2_perps", depth2_perps)
self.register_buffer(
"depth2_atom_to_perp_idx", depth2_atom_to_perp_idx
)
else:
self.register_buffer(
"depth2_perps", torch.zeros(0, 3, dtype=dtype, device=device)
)
self.register_buffer(
"depth2_atom_to_perp_idx",
torch.full(
(self.n_atoms,), -1, dtype=torch.long, device=device
),
)
def _extract_ring_local_coords_vectorized(
self, xyz: torch.Tensor
) -> torch.Tensor:
"""Extract local coordinates for ring atoms."""
device = self._device
dtype = self._dtype
max_ring_size = (
self.ring_members.shape[1] if self.n_rings > 0 else 0
)
ring_local_coords = torch.zeros(
self.n_rings, max_ring_size, 3, dtype=dtype, device=device
)
if self.n_rings == 0:
return ring_local_coords
anchor_pos = xyz[self.ring_anchor_atoms]
parent_idx = self.parent_indices[self.ring_anchor_atoms]
valid_parent = parent_idx >= 0
parent_pos = torch.zeros_like(anchor_pos)
parent_pos[valid_parent] = xyz[parent_idx[valid_parent]]
z_axis = anchor_pos - parent_pos
z_norm = torch.linalg.norm(z_axis, dim=-1, keepdim=True)
z_axis = torch.where(
z_norm > 1e-10,
z_axis / z_norm.clamp(min=1e-10),
torch.tensor([0.0, 0.0, 1.0], dtype=dtype, device=device),
)
z_axis[~valid_parent] = torch.tensor(
[0.0, 0.0, 1.0], dtype=dtype, device=device
)
x_base = torch.zeros_like(z_axis)
x_base[:, 0] = 1.0
nearly_x = torch.abs(z_axis[:, 0]) >= 0.9
x_base[nearly_x, 0] = 0.0
x_base[nearly_x, 1] = 1.0
x_axis = (
x_base - (x_base * z_axis).sum(dim=-1, keepdim=True) * z_axis
)
x_axis = x_axis / torch.linalg.norm(
x_axis, dim=-1, keepdim=True
).clamp(min=1e-10)
y_axis = torch.linalg.cross(z_axis, x_axis)
R = torch.stack([x_axis, y_axis, z_axis], dim=-1)
valid_mask = self.ring_members >= 0
safe_members = self.ring_members.clamp(min=0)
member_pos = xyz[safe_members]
offsets = member_pos - anchor_pos.unsqueeze(1)
ring_local_coords = torch.bmm(offsets, R)
ring_local_coords = ring_local_coords * valid_mask.unsqueeze(-1)
return ring_local_coords
@staticmethod
def _compute_torsion_angles(
p1: torch.Tensor,
p2: torch.Tensor,
p3: torch.Tensor,
p4: torch.Tensor,
) -> torch.Tensor:
"""Compute torsion angles for batched positions."""
b1 = p2 - p1
b2 = p3 - p2
b3 = p4 - p3
n1 = torch.linalg.cross(b1, b2)
n2 = torch.linalg.cross(b2, b3)
n1 = n1 / torch.linalg.norm(n1, dim=-1, keepdim=True).clamp(min=1e-10)
n2 = n2 / torch.linalg.norm(n2, dim=-1, keepdim=True).clamp(min=1e-10)
b2_norm = b2 / torch.linalg.norm(b2, dim=-1, keepdim=True).clamp(
min=1e-10
)
m1 = torch.linalg.cross(n1, b2_norm)
x = torch.sum(n1 * n2, dim=-1)
y = torch.sum(m1 * n2, dim=-1)
return torch.atan2(y, x)
# =========================================================================
# Forward pass
# =========================================================================
[docs]
def forward(self) -> torch.Tensor:
    """
    Reconstruct Cartesian xyz from internal coordinates with chain closure.

    Steps:
    1. Place segment roots, secondary roots and segment atoms (NeRF)
    2. Place rigid ring atoms
    3. Solve junction closures and place junction backbone atoms
    4. Place junction sidechain atoms
    5. Overwrite frozen atoms with ``fixed_xyz``

    Returns
    -------
    torch.Tensor
        Reconstructed Cartesian coordinates of shape (N, 3).
    """
    coords = torch.zeros(
        self.n_atoms, 3, dtype=self._dtype, device=self._device
    )

    # One rotation matrix per segment, computed in a single batched call.
    segment_rotations = rotation_matrix_euler_zyz(self.segment_orientations)

    # Roots carry absolute positions.
    coords[self.segment_roots] = self.segment_positions
    if self.secondary_root_indices.numel() > 0:
        coords[self.secondary_root_indices] = self.secondary_root_positions

    # Shallow atoms use stored reference directions rotated with the segment.
    coords = self._place_depth1_atoms(coords, segment_rotations)
    coords = self._place_depth2_atoms(coords, segment_rotations)

    # Deeper atoms are placed level by level from precomputed indices.
    for depth_idx, _ in enumerate(self._depth_atom_indices):
        coords = self._place_atoms_at_depth_fast(coords, depth_idx)

    coords = self._place_rigid_rings(coords)

    # Junction backbone comes from the closure solver; sidechains follow.
    if self.n_junctions > 0:
        coords = self._place_junction_sidechain_atoms(
            self._solve_and_place_junctions(coords)
        )

    # Frozen atoms always report their stored coordinates.
    frozen = ~self.refinable_mask
    if frozen.any():
        coords[frozen] = self.fixed_xyz[frozen]

    return coords
def _place_depth1_atoms(
self, xyz: torch.Tensor, R_matrices: torch.Tensor
) -> torch.Tensor:
"""Place atoms at depth 1 (excluding junction atoms)."""
mask = (self.atom_depth == 1) & ~self.is_junction_atom
if not mask.any():
return xyz
xyz = xyz.clone()
atom_indices = torch.where(mask)[0]
seg_ids = self.atom_to_segment[atom_indices]
parent_idx = self.parent_indices[atom_indices]
parent_positions = xyz[parent_idx]
R = R_matrices[seg_ids]
dir_indices = self.depth1_atom_to_dir_idx[atom_indices]
base_dirs = self.depth1_dirs[dir_indices]
rotated_dirs = torch.bmm(R, base_dirs.unsqueeze(-1)).squeeze(-1)
bond_idx = self.bond_param_indices[mask]
bond_lengths = self.bond_lengths[bond_idx]
new_positions = parent_positions + bond_lengths.unsqueeze(-1) * rotated_dirs
xyz[mask] = new_positions
return xyz
def _place_depth2_atoms(
self, xyz: torch.Tensor, R_matrices: torch.Tensor
) -> torch.Tensor:
"""Place atoms at depth 2 (excluding junction atoms)."""
mask = (self.atom_depth == 2) & ~self.is_junction_atom
if not mask.any():
return xyz
xyz = xyz.clone()
parent_idx = self.parent_indices[mask]
grandparent_idx = self.grandparent_indices[mask]
parent_pos = xyz[parent_idx]
grandparent_pos = xyz[grandparent_idx]
bond_idx = self.bond_param_indices[mask]
angle_idx = self.angle_param_indices[mask]
bond_lengths = self.bond_lengths[bond_idx]
angles = self.angles[angle_idx]
bc = parent_pos - grandparent_pos
bc_norm = torch.linalg.norm(bc, dim=-1, keepdim=True).clamp(min=1e-10)
bc_unit = bc / bc_norm
atom_indices = torch.where(mask)[0]
perp_idx = self.depth2_atom_to_perp_idx[atom_indices]
ref_perp = self.depth2_perps[perp_idx]
seg_ids = self.atom_to_segment[atom_indices]
R = R_matrices[seg_ids]
rotated_ref = torch.bmm(R, ref_perp.unsqueeze(-1)).squeeze(-1)
proj = torch.sum(rotated_ref * bc_unit, dim=-1, keepdim=True)
rotated_ref = rotated_ref - proj * bc_unit
ref_norm = torch.linalg.norm(
rotated_ref, dim=-1, keepdim=True
).clamp(min=1e-10)
rotated_ref = rotated_ref / ref_norm
theta_internal = torch.pi - angles
dx = bond_lengths * torch.cos(theta_internal)
dy = bond_lengths * torch.sin(theta_internal)
new_positions = (
parent_pos
+ dx.unsqueeze(-1) * bc_unit
+ dy.unsqueeze(-1) * rotated_ref
)
xyz[mask] = new_positions
return xyz
def _place_atoms_at_depth_fast(
self, xyz: torch.Tensor, depth_idx: int
) -> torch.Tensor:
"""Place atoms at a depth level using pre-computed indices."""
atom_idx = self._depth_atom_indices[depth_idx]
if atom_idx.numel() == 0:
return xyz
xyz = xyz.clone()
parent_idx = self._depth_parent_indices[depth_idx]
gp_idx = self._depth_gp_indices[depth_idx]
ggp_idx = self._depth_ggp_indices[depth_idx]
p1 = xyz[ggp_idx]
p2 = xyz[gp_idx]
p3 = xyz[parent_idx]
bond_idx = self._depth_bond_param_indices[depth_idx]
angle_idx = self._depth_angle_param_indices[depth_idx]
torsion_idx = self._depth_torsion_param_indices[depth_idx]
d = self.bond_lengths[bond_idx]
theta = self.angles[angle_idx]
phi = self.torsions[torsion_idx]
bc = p3 - p2
bc = bc / torch.linalg.norm(bc, dim=-1, keepdim=True).clamp(min=1e-10)
ab = p2 - p1
n = torch.linalg.cross(ab, bc)
n = n / torch.linalg.norm(n, dim=-1, keepdim=True).clamp(min=1e-10)
m = torch.linalg.cross(n, bc)
theta_internal = torch.pi - theta
sin_theta = torch.sin(theta_internal)
cos_theta = torch.cos(theta_internal)
dx = d * cos_theta
dy = d * sin_theta * torch.cos(phi)
dz = d * sin_theta * torch.sin(phi)
new_pos = (
p3
+ dx.unsqueeze(-1) * bc
+ dy.unsqueeze(-1) * m
- dz.unsqueeze(-1) * n
)
xyz[atom_idx] = new_pos
return xyz
def _place_rigid_rings(self, xyz: torch.Tensor) -> torch.Tensor:
"""Place ring atoms as rigid groups."""
if self.n_rings == 0:
return xyz
xyz = xyz.clone()
anchor_pos = xyz[self.ring_anchor_atoms]
parent_idx = self.parent_indices[self.ring_anchor_atoms]
valid_parent_mask = parent_idx >= 0
parent_pos = torch.zeros_like(anchor_pos)
parent_pos[valid_parent_mask] = xyz[parent_idx[valid_parent_mask]]
z_axis = anchor_pos - parent_pos
z_norm = torch.linalg.norm(z_axis, dim=-1, keepdim=True)
z_axis = torch.where(
z_norm > 1e-10,
z_axis / z_norm.clamp(min=1e-10),
torch.tensor(
[0.0, 0.0, 1.0], dtype=self._dtype, device=self._device
),
)
z_axis[~valid_parent_mask] = torch.tensor(
[0.0, 0.0, 1.0], dtype=self._dtype, device=self._device
)
x_base = torch.zeros_like(z_axis)
x_base[:, 0] = 1.0
nearly_x = torch.abs(z_axis[:, 0]) >= 0.9
x_base[nearly_x, 0] = 0.0
x_base[nearly_x, 1] = 1.0
x_axis = (
x_base - (x_base * z_axis).sum(dim=-1, keepdim=True) * z_axis
)
x_axis = x_axis / torch.linalg.norm(
x_axis, dim=-1, keepdim=True
).clamp(min=1e-10)
y_axis = torch.linalg.cross(z_axis, x_axis)
R = torch.stack([x_axis, y_axis, z_axis], dim=-1)
global_offsets = torch.bmm(
self.ring_local_coords, R.transpose(-1, -2)
)
global_pos = global_offsets + anchor_pos.unsqueeze(1)
valid_mask = self.ring_members >= 0
anchor_expanded = self.ring_anchor_atoms.unsqueeze(1)
valid_mask = valid_mask & (self.ring_members != anchor_expanded)
ring_indices, local_indices = torch.where(valid_mask)
atom_indices = self.ring_members[ring_indices, local_indices]
positions = global_pos[ring_indices, local_indices]
xyz[atom_indices] = positions
return xyz
def _solve_and_place_junctions(self, xyz: torch.Tensor) -> torch.Tensor:
"""
Solve junction closures and place junction backbone atoms.
Extracts (N, CA, C) of the last pre-junction residue as starting
positions for the NeRF-based FK, uses stored target C positions,
runs the Newton solver, and scatters backbone atoms into xyz.
"""
xyz = xyz.clone()
# Build start positions for each junction: N, CA, C of last pre-junction residue
p1_list = [] # N positions
p2_list = [] # CA positions
p3_list = [] # C positions
for junc_idx in range(self.n_junctions):
pre_seg_idx = self._find_pre_junction_segment(junc_idx)
if pre_seg_idx is not None and pre_seg_idx < len(self._planned_segments):
last_res = self._planned_segments[pre_seg_idx][-1]
if last_res in self._backbone_map:
bb = self._backbone_map[last_res]
p1_list.append(xyz[bb["N"]])
p2_list.append(xyz[bb["CA"]])
p3_list.append(xyz[bb["C"]])
else:
# Fallback: use identity-like positions
p1_list.append(torch.zeros(3, dtype=self._dtype, device=self._device))
p2_list.append(torch.tensor([1.0, 0.0, 0.0], dtype=self._dtype, device=self._device))
p3_list.append(torch.tensor([2.0, 0.0, 0.0], dtype=self._dtype, device=self._device))
else:
p1_list.append(torch.zeros(3, dtype=self._dtype, device=self._device))
p2_list.append(torch.tensor([1.0, 0.0, 0.0], dtype=self._dtype, device=self._device))
p3_list.append(torch.tensor([2.0, 0.0, 0.0], dtype=self._dtype, device=self._device))
p1_start = torch.stack(p1_list) # (J, 3)
p2_start = torch.stack(p2_list) # (J, 3)
p3_start = torch.stack(p3_list) # (J, 3)
# Build dynamic targets: N of first post-junction residue
target_list = []
for junc_idx in range(self.n_junctions):
n_idx = self._junction_post_n_indices[junc_idx]
if n_idx >= 0:
target_list.append(xyz[n_idx])
else:
# No post-junction segment — use a dummy target
target_list.append(torch.zeros(3, dtype=self._dtype, device=self._device))
target_n = torch.stack(target_list) # (J, 3)
# Solve junction closures
phi_psi, backbone_xyz = self.junction_solver(
p1_start, p2_start, p3_start,
target_n,
self.junction_bond_lengths,
self.junction_nerf_angles,
self.junction_omega,
self.junction_psi_prev,
self.junction_post_bond_length,
self.junction_post_bond_angle,
)
# Scatter junction backbone atoms into xyz
for i, (junc_idx, output_pos) in enumerate(self._junc_output_indices):
global_idx = self._junc_global_indices[i]
xyz[global_idx] = backbone_xyz[junc_idx, output_pos]
return xyz
def _place_junction_sidechain_atoms(
self, xyz: torch.Tensor
) -> torch.Tensor:
"""
Place junction sidechain atoms using NeRF from solved backbone positions.
Called AFTER the junction solver has placed backbone atoms. Uses the
pre-built junction sidechain depth indices to place atoms in depth order,
using the same bond length/angle parameters and stored reference geometry.
"""
xyz = xyz.clone()
for depth_list_idx in range(len(self._jsc_depth_atom_indices)):
atom_idx = self._jsc_depth_atom_indices[depth_list_idx]
if atom_idx.numel() == 0:
continue
actual_depth = depth_list_idx + self._jsc_min_depth
if actual_depth == 1:
# Depth-1: use reference directions (like _place_depth1_atoms)
parent_idx = self._jsc_depth_parent_indices[depth_list_idx]
parent_pos = xyz[parent_idx]
bond_idx = self._jsc_depth_bond_param_indices[depth_list_idx]
bond_lengths = self.bond_lengths[bond_idx]
seg_ids = self.atom_to_segment[atom_idx]
R = rotation_matrix_euler_zyz(self.segment_orientations)[seg_ids]
dir_indices = self.depth1_atom_to_dir_idx[atom_idx]
base_dirs = self.depth1_dirs[dir_indices]
rotated_dirs = torch.bmm(R, base_dirs.unsqueeze(-1)).squeeze(-1)
new_pos = parent_pos + bond_lengths.unsqueeze(-1) * rotated_dirs
xyz[atom_idx] = new_pos
elif actual_depth == 2:
# Depth-2: use reference perpendiculars (like _place_depth2_atoms)
parent_idx = self._jsc_depth_parent_indices[depth_list_idx]
gp_idx = self._jsc_depth_gp_indices[depth_list_idx]
parent_pos = xyz[parent_idx]
gp_pos = xyz[gp_idx]
bond_idx = self._jsc_depth_bond_param_indices[depth_list_idx]
angle_idx = self._jsc_depth_angle_param_indices[depth_list_idx]
bond_lengths = self.bond_lengths[bond_idx]
angles = self.angles[angle_idx]
bc = parent_pos - gp_pos
bc_norm = torch.linalg.norm(bc, dim=-1, keepdim=True).clamp(min=1e-10)
bc_unit = bc / bc_norm
perp_idx = self.depth2_atom_to_perp_idx[atom_idx]
ref_perp = self.depth2_perps[perp_idx]
seg_ids = self.atom_to_segment[atom_idx]
R = rotation_matrix_euler_zyz(self.segment_orientations)[seg_ids]
rotated_ref = torch.bmm(R, ref_perp.unsqueeze(-1)).squeeze(-1)
proj = torch.sum(rotated_ref * bc_unit, dim=-1, keepdim=True)
rotated_ref = rotated_ref - proj * bc_unit
ref_norm = torch.linalg.norm(
rotated_ref, dim=-1, keepdim=True
).clamp(min=1e-10)
rotated_ref = rotated_ref / ref_norm
theta_internal = torch.pi - angles
dx = bond_lengths * torch.cos(theta_internal)
dy = bond_lengths * torch.sin(theta_internal)
new_pos = (
parent_pos
+ dx.unsqueeze(-1) * bc_unit
+ dy.unsqueeze(-1) * rotated_ref
)
xyz[atom_idx] = new_pos
else:
# Depth-3+: standard NeRF formula
parent_idx = self._jsc_depth_parent_indices[depth_list_idx]
gp_idx = self._jsc_depth_gp_indices[depth_list_idx]
ggp_idx = self._jsc_depth_ggp_indices[depth_list_idx]
p1 = xyz[ggp_idx]
p2 = xyz[gp_idx]
p3 = xyz[parent_idx]
bond_idx = self._jsc_depth_bond_param_indices[depth_list_idx]
angle_idx = self._jsc_depth_angle_param_indices[depth_list_idx]
d = self.bond_lengths[bond_idx]
theta = self.angles[angle_idx]
# Junction sidechain torsions are in self.torsions (free DOFs)
torsion_idx = self.torsion_param_indices[atom_idx]
phi = self.torsions[torsion_idx]
bc = p3 - p2
bc = bc / torch.linalg.norm(bc, dim=-1, keepdim=True).clamp(min=1e-10)
ab = p2 - p1
n = torch.linalg.cross(ab, bc)
n = n / torch.linalg.norm(n, dim=-1, keepdim=True).clamp(min=1e-10)
m = torch.linalg.cross(n, bc)
theta_internal = torch.pi - theta
sin_theta = torch.sin(theta_internal)
cos_theta = torch.cos(theta_internal)
dx = d * cos_theta
dy = d * sin_theta * torch.cos(phi)
dz = d * sin_theta * torch.sin(phi)
new_pos = (
p3
+ dx.unsqueeze(-1) * bc
+ dy.unsqueeze(-1) * m
- dz.unsqueeze(-1) * n
)
xyz[atom_idx] = new_pos
return xyz
# =========================================================================
# Interface methods (same API as SegmentedInternalCoordinateTensor)
# =========================================================================
[docs]
def shake(self, magnitude: float = 0.1) -> torch.Tensor:
    """Add Gaussian noise to internal parameters.

    Bond lengths, bond angles and torsions each receive independent
    standard-normal noise scaled by ``magnitude``. Afterwards bond lengths
    are clamped to >= 0.5, angles to [0.1, pi - 0.1], and torsions are
    wrapped with ``atan2(sin, cos)``. Segment positions and orientations
    are NOT perturbed because random independent translation of segments
    creates gaps that exceed the junction chain's reach.

    Parameters
    ----------
    magnitude : float
        Scale factor applied to the standard-normal noise.

    Returns
    -------
    torch.Tensor
        Result of ``self.forward()`` after the perturbation.
    """
    with torch.no_grad():
        lengths = self.bond_lengths
        if lengths.numel() > 0:
            lengths.data.add_(torch.randn_like(lengths) * magnitude)
            lengths.data = lengths.data.clamp(min=0.5)

        bends = self.angles
        if bends.numel() > 0:
            bends.data.add_(torch.randn_like(bends) * magnitude)
            bends.data = bends.data.clamp(min=0.1, max=torch.pi - 0.1)

        twists = self.torsions
        if twists.numel() > 0:
            twists.data.add_(torch.randn_like(twists) * magnitude)
            # Wrap back into the principal range.
            twists.data = torch.atan2(
                torch.sin(twists.data), torch.cos(twists.data)
            )

    return self.forward()
[docs]
def fix(
    self,
    selection: Union[torch.Tensor, slice, None] = None,
    freeze_at_current: bool = True,
) -> None:
    """
    Fix (freeze) atoms.

    Parameters
    ----------
    selection : torch.Tensor, slice or None
        Atoms to freeze; ``None`` selects every atom.
    freeze_at_current : bool
        If True, the selected rows of ``fixed_xyz`` are overwritten with the
        current ``forward()`` output (evaluated without gradients) before
        the atoms are marked non-refinable. If False, whatever is already
        stored in ``fixed_xyz`` is kept.
    """
    mask = self._selection_to_mask(
        slice(None) if selection is None else selection
    )

    if freeze_at_current:
        with torch.no_grad():
            snapshot = self.forward()
            self.fixed_xyz[mask] = snapshot[mask]

    self.refinable_mask[mask] = False
[docs]
def freeze(
    self,
    selection: Union[torch.Tensor, slice, None] = None,
    freeze_at_current: bool = True,
) -> None:
    """
    Alias for fix().

    Both arguments are forwarded positionally and unchanged.
    """
    delegate = self.fix
    delegate(selection, freeze_at_current)
[docs]
def refine(
    self,
    selection: Union[torch.Tensor, slice, None] = None,
    rebuild: bool = True,
) -> None:
    """
    Make atoms refinable.

    Parameters
    ----------
    selection : torch.Tensor, slice or None
        Atoms to release; ``None`` selects every atom.
    rebuild : bool
        If True, the internal-coordinate parameters of the selected atoms
        are first recomputed from ``fixed_xyz``.
    """
    mask = self._selection_to_mask(
        slice(None) if selection is None else selection
    )

    if rebuild:
        self._update_internal_coords_from_xyz(self.fixed_xyz, mask)

    self.refinable_mask[mask] = True
[docs]
def unfreeze(
    self,
    selection: Union[torch.Tensor, slice, None] = None,
    rebuild: bool = True,
) -> None:
    """
    Alias for refine().

    Both arguments are forwarded positionally and unchanged.
    """
    delegate = self.refine
    delegate(selection, rebuild)
[docs]
def fix_all(self, freeze_at_current: bool = True) -> None:
    """
    Fix all atoms.

    Calls ``fix`` with a ``None`` selection, which selects every atom.
    """
    everything = None
    self.fix(everything, freeze_at_current)
[docs]
def freeze_all(self, freeze_at_current: bool = True) -> None:
    """
    Alias for fix_all().

    The ``freeze_at_current`` flag is forwarded unchanged.
    """
    delegate = self.fix_all
    delegate(freeze_at_current)
[docs]
def refine_all(self, rebuild: bool = True) -> None:
    """
    Make all atoms refinable.

    Calls ``refine`` with a ``None`` selection, which selects every atom.
    """
    everything = None
    self.refine(everything, rebuild)
[docs]
def unfreeze_all(self, rebuild: bool = True) -> None:
    """
    Alias for refine_all().

    The ``rebuild`` flag is forwarded unchanged.
    """
    delegate = self.refine_all
    delegate(rebuild)
def _selection_to_mask(
self, selection: Union[torch.Tensor, slice, None]
) -> torch.Tensor:
"""Convert selection to boolean mask."""
if isinstance(selection, torch.Tensor) and selection.dtype == torch.bool:
return selection
elif isinstance(selection, torch.Tensor):
mask = torch.zeros(
self.n_atoms, dtype=torch.bool, device=self._device
)
mask[selection] = True
return mask
elif isinstance(selection, slice):
mask = torch.zeros(
self.n_atoms, dtype=torch.bool, device=self._device
)
mask[selection] = True
return mask
else:
raise TypeError(
f"selection must be Tensor, slice, or None, got {type(selection)}"
)
def _update_internal_coords_from_xyz(
self, xyz: torch.Tensor, mask: torch.Tensor
) -> None:
"""Update internal coordinate parameters from xyz."""
# Bonds
bond_update_mask = mask & (self.atom_depth >= 1)
if bond_update_mask.any():
child_xyz = xyz[bond_update_mask]
parent_xyz = xyz[self.parent_indices[bond_update_mask]]
new_bonds = torch.linalg.norm(child_xyz - parent_xyz, dim=-1)
bond_idx = self.bond_param_indices[bond_update_mask]
with torch.no_grad():
self.bond_lengths.data[bond_idx] = new_bonds
# Angles
angle_update_mask = mask & (self.atom_depth >= 2)
if angle_update_mask.any():
child_xyz = xyz[angle_update_mask]
parent_xyz = xyz[self.parent_indices[angle_update_mask]]
grandparent_xyz = xyz[self.grandparent_indices[angle_update_mask]]
v1 = child_xyz - parent_xyz
v2 = grandparent_xyz - parent_xyz
cos_angles = torch.sum(v1 * v2, dim=-1) / (
torch.linalg.norm(v1, dim=-1)
* torch.linalg.norm(v2, dim=-1)
+ 1e-10
)
cos_angles = torch.clamp(cos_angles, -1.0, 1.0)
new_angles = torch.acos(cos_angles)
angle_idx = self.angle_param_indices[angle_update_mask]
with torch.no_grad():
self.angles.data[angle_idx] = new_angles
# Torsions (only free torsions — excludes junction backbone)
torsion_update_mask = (
mask & (self.atom_depth >= 3) & ~self.is_junction_backbone
)
if torsion_update_mask.any():
child_xyz = xyz[torsion_update_mask]
parent_xyz = xyz[self.parent_indices[torsion_update_mask]]
grandparent_xyz = xyz[
self.grandparent_indices[torsion_update_mask]
]
great_grandparent_xyz = xyz[
self.great_grandparent_indices[torsion_update_mask]
]
new_torsions = self._compute_torsion_angles(
great_grandparent_xyz, grandparent_xyz, parent_xyz, child_xyz
)
torsion_idx = self.torsion_param_indices[torsion_update_mask]
with torch.no_grad():
self.torsions.data[torsion_idx] = new_torsions
# Segment positions
root_in_mask = mask[self.segment_roots]
if root_in_mask.any():
with torch.no_grad():
self.segment_positions.data[root_in_mask] = xyz[
self.segment_roots[root_in_mask]
]
# Secondary root positions
if self.secondary_root_indices.numel() > 0:
sec_root_in_mask = mask[self.secondary_root_indices]
if sec_root_in_mask.any():
with torch.no_grad():
self.secondary_root_positions.data[sec_root_in_mask] = xyz[
self.secondary_root_indices[sec_root_in_mask]
]
@property
def n_refinable(self) -> int:
    """Number of atoms whose ``refinable_mask`` entry is True."""
    refinable_count = self.refinable_mask.sum()
    return refinable_count.item()
@property
def n_fixed(self) -> int:
    """Number of atoms whose ``refinable_mask`` entry is False."""
    frozen = torch.logical_not(self.refinable_mask)
    return frozen.sum().item()
@property
def closure_residuals(self) -> Optional[torch.Tensor]:
    """
    Closure residuals currently exposed by the junction solver.

    Returns ``self.junction_solver.closure_residuals`` unchanged.
    """
    solver = self.junction_solver
    return solver.closure_residuals
@property
def max_closure_gap(self) -> float:
    """
    Largest junction-to-segment C-N distance, in the units of the coordinates.

    Runs ``forward()`` without gradients and, for each junction, measures
    the distance between the C atom of the junction's last residue and the
    N atom of the first residue of the following segment. The value is the
    raw inter-atomic distance (it is not reduced by an ideal bond length).
    Junctions without a following segment, or whose following residue has
    no backbone entry, are skipped.

    Returns
    -------
    float
        Maximum measured distance, or ``0.0`` when there are no junctions
        or nothing could be measured.
    """
    if self.n_junctions == 0:
        return 0.0

    with torch.no_grad():
        xyz = self.forward()

    max_gap = 0.0
    for junc_idx in range(self.n_junctions):
        last_bb = self._junction_data[junc_idx]["backbone"][-1]
        c_pos = xyz[last_bb["C"]]

        post_seg_idx = self._find_post_junction_segment(junc_idx)
        if post_seg_idx is None or post_seg_idx >= len(self._planned_segments):
            continue

        next_res = self._planned_segments[post_seg_idx][0]
        if next_res not in self._backbone_map:
            continue

        n_pos = xyz[self._backbone_map[next_res]["N"]]
        gap = torch.linalg.norm(c_pos - n_pos).item()
        max_gap = max(max_gap, gap)

    return max_gap
def __repr__(self) -> str:
n_secondary = self.secondary_root_indices.numel()
return (
f"ClosedSegmentedInternalCoordinateTensor("
f"n_atoms={self.n_atoms}, "
f"n_segments={self.n_segments}, "
f"n_junctions={self.n_junctions}, "
f"n_aa_per_segment={self.n_aa_per_segment}, "
f"junction_size={self.junction_size}, "
f"n_secondary_roots={n_secondary}, "
f"n_refinable={self.n_refinable}, n_fixed={self.n_fixed}, "
f"n_bonds={self.n_bonds}, n_angles={self.n_angles}, "
f"n_torsions={self.n_torsions}, n_rings={self.n_rings}, "
f"max_depth={self.max_depth}, device={self._device})"
)
# =========================================================================
# Helper functions (module-level)
# =========================================================================
def _compute_angle(
p1: torch.Tensor, p2: torch.Tensor, p3: torch.Tensor,
) -> torch.Tensor:
"""Compute angle at p2 between p1-p2-p3."""
v1 = p1 - p2
v2 = p3 - p2
cos_a = torch.dot(v1, v2) / (
torch.linalg.norm(v1) * torch.linalg.norm(v2) + 1e-10
)
cos_a = torch.clamp(cos_a, -1.0, 1.0)
return torch.acos(cos_a)
def _compute_torsion(
p1: torch.Tensor, p2: torch.Tensor, p3: torch.Tensor, p4: torch.Tensor,
) -> torch.Tensor:
"""Compute torsion angle for four points."""
b1 = p2 - p1
b2 = p3 - p2
b3 = p4 - p3
n1 = torch.linalg.cross(b1, b2)
n2 = torch.linalg.cross(b2, b3)
n1 = n1 / torch.linalg.norm(n1).clamp(min=1e-10)
n2 = n2 / torch.linalg.norm(n2).clamp(min=1e-10)
b2_norm = b2 / torch.linalg.norm(b2).clamp(min=1e-10)
m1 = torch.linalg.cross(n1, b2_norm)
x = torch.dot(n1, n2)
y = torch.dot(m1, n2)
return torch.atan2(y, x)