"""
Internal coordinate parametrization for atomic structures.
This module provides the InternalCoordinateTensor class which parametrizes
atomic XYZ coordinates using internal coordinates (bond lengths, angles,
torsions) instead of Cartesian coordinates. This enables physically meaningful
perturbations and differentiable reconstruction.
Key features:
- Bond detection: Atoms within 2 Å are considered bonded (using torch.cdist)
- Internal coordinate parametrization: N atoms → N-1 bonds, N-2 angles, N-3 torsions per chain
- Chain handling: Unconnected chains treated as rigid groups with position/orientation
- Fully differentiable: Complete gradient flow from internal params → Cartesian coords
- Fully vectorized: No Python loops over atoms - all operations via tensor ops
- Ring handling: Rings are treated as rigid entities with only the anchor movable
"""
from typing import Optional, Union
import torch
import torch.nn as nn
from torchref.base.alignment.rotation import axis_angle_to_rotation_matrix
from torchref.utils.device_mixin import DeviceMixin
[docs]
class InternalCoordinateTensor(DeviceMixin, nn.Module):
"""
Parameter wrapper using internal coordinates (Z-matrix style).
Stores: bond_lengths, angles, torsions, chain_positions, chain_orientations
Reconstructs: Cartesian xyz on forward()
This provides a physically meaningful parametrization of atomic coordinates
where perturbations correspond to changes in bond lengths, angles, and
torsion angles rather than arbitrary Cartesian displacements.
Parameters
----------
initial_xyz : torch.Tensor
Initial Cartesian coordinates of shape (N, 3).
bond_cutoff : float, optional
Distance cutoff for bond detection in Angstroms. Default is 2.0.
requires_grad : bool, optional
Whether parameters should have gradients. Default is True.
dtype : torch.dtype, optional
Data type for tensors. Default is same as initial_xyz.
device : torch.device, optional
Device for tensors. Default is same as initial_xyz.
Attributes
----------
n_atoms : int
Number of atoms.
n_chains : int
Number of disconnected chains.
max_depth : int
Maximum depth in the spanning tree.
bond_lengths : nn.Parameter
Bond length parameters in Angstroms.
angles : nn.Parameter
Angle parameters in radians.
torsions : nn.Parameter
Torsion angle parameters in radians.
chain_positions : nn.Parameter
Absolute positions of chain root atoms.
chain_orientations : nn.Parameter
Axis-angle orientations for each chain.
"""
[docs]
def __init__(
    self,
    initial_xyz: torch.Tensor,
    bond_cutoff: float = 2.0,
    requires_grad: bool = True,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
):
    """
    Build the internal-coordinate representation of ``initial_xyz``.

    Parameters
    ----------
    initial_xyz : torch.Tensor
        Initial Cartesian coordinates of shape (N, 3).
    bond_cutoff : float, optional
        Distance cutoff for bond detection in Angstroms. Default is 2.0.
    requires_grad : bool, optional
        Whether the created parameters should track gradients.
        Default is True.
    dtype : torch.dtype, optional
        Data type for tensors. Falls back to ``initial_xyz.dtype``.
    device : torch.device, optional
        Device on which results are delivered. Falls back to
        ``initial_xyz.device``.
    """
    super().__init__()

    resolved_dtype = initial_xyz.dtype if dtype is None else dtype
    requested_device = initial_xyz.device if device is None else device
    self._dtype = resolved_dtype

    # Two devices are tracked on purpose:
    #   * ``_output_device`` - where the caller wants results delivered;
    #     ``.to(device)`` only changes this value.
    #   * ``_device``        - where internal tensors live; always CPU.
    # The original author keeps the state on CPU because tree building,
    # ring detection and the reconstruction pass are dominated by small
    # indexed tensor operations with high per-op dispatch cost on
    # MPS/CUDA.
    if isinstance(requested_device, torch.device):
        self._output_device = requested_device
    else:
        self._output_device = torch.device(requested_device)
    self._device = torch.device("cpu")

    self.n_atoms = initial_xyz.shape[0]
    self.bond_cutoff = bond_cutoff

    cpu_xyz = initial_xyz.to(dtype=resolved_dtype, device=self._device)

    # Pipeline: distance graph -> spanning trees -> ring bookkeeping ->
    # internal coordinates. The later steps read buffers registered by
    # the earlier ones, so the order matters.
    bond_graph = self._build_molecular_graph(cpu_xyz, bond_cutoff)
    self._build_spanning_trees(bond_graph)
    self._detect_rings(bond_graph)
    self._extract_internal_coords(cpu_xyz, requires_grad)
@property
def dtype(self):
"""Return the dtype of tensors."""
return self._dtype
@property
def device(self):
"""Logical device — where forward()'s result is delivered.
Internal parameters/buffers stay on CPU regardless; this is the
device requested by the caller (e.g. via ``.to('mps')``) and is
the device the forward output is migrated to.
"""
return self._output_device
[docs]
def to(self, *args, **kwargs): # type: ignore[override]
"""Update output device and optionally cast dtype.
Unlike ``DeviceMixin.to``, this does **not** move internal
parameters/buffers to ``device`` — they stay on CPU to avoid
the per-op dispatch overhead of MPS/CUDA on the sequential
spanning-tree + parallel-scan code. The ``device`` argument
only updates ``_output_device``; ``dtype`` still propagates
normally and recasts all CPU tensors.
"""
device = kwargs.get("device", None)
dtype = kwargs.get("dtype", None)
for a in args:
if isinstance(a, torch.device):
device = a
elif isinstance(a, str):
device = a
elif isinstance(a, torch.dtype):
dtype = a
if device is not None:
self._output_device = (
torch.device(device) if not isinstance(device, torch.device) else device
)
if dtype is not None:
nn.Module.to(self, dtype=dtype)
self._dtype = dtype
return self
[docs]
def cuda(self, device=None): # type: ignore[override]
if device is None:
target = torch.device("cuda")
elif isinstance(device, int):
target = torch.device(f"cuda:{device}")
else:
target = torch.device(device)
self._output_device = target
return self
[docs]
def cpu(self): # type: ignore[override]
self._output_device = torch.device("cpu")
return self
def _to_output(self, xyz: torch.Tensor) -> torch.Tensor:
"""Migrate a CPU result tensor to the configured output device."""
if xyz.device != self._output_device:
return xyz.to(self._output_device)
return xyz
@staticmethod
def _build_molecular_graph(
xyz: torch.Tensor, cutoff: float = 2.0
) -> torch.Tensor:
"""
Build adjacency matrix from atomic distances.
Parameters
----------
xyz : torch.Tensor
Atomic coordinates of shape (N, 3).
cutoff : float
Distance cutoff for bonds in Angstroms.
Returns
-------
torch.Tensor
Boolean adjacency matrix of shape (N, N).
"""
# Compute pairwise distances
distances = torch.cdist(xyz, xyz)
# Atoms within cutoff are bonded (exclude self-connections)
adjacency = (distances < cutoff) & (distances > 0.1)
return adjacency
def _build_spanning_trees(self, adjacency: torch.Tensor) -> None:
"""
Build spanning tree(s) via BFS and populate index buffers.
For each connected component, builds a spanning tree and records
parent/grandparent/great-grandparent indices for each atom.
Parameters
----------
adjacency : torch.Tensor
Boolean adjacency matrix of shape (N, N).
"""
n_atoms = self.n_atoms
device = self._device
# Initialize arrays
parent_indices = torch.full((n_atoms,), -1, dtype=torch.long, device=device)
grandparent_indices = torch.full(
(n_atoms,), -1, dtype=torch.long, device=device
)
great_grandparent_indices = torch.full(
(n_atoms,), -1, dtype=torch.long, device=device
)
atom_depth = torch.full((n_atoms,), -1, dtype=torch.long, device=device)
chain_ids = torch.full((n_atoms,), -1, dtype=torch.long, device=device)
chain_roots = []
visited = torch.zeros(n_atoms, dtype=torch.bool, device=device)
chain_id = 0
# Process each connected component
for start_atom in range(n_atoms):
if visited[start_atom]:
continue
# BFS from this atom
chain_roots.append(start_atom)
queue = [start_atom]
visited[start_atom] = True
parent_indices[start_atom] = -1
atom_depth[start_atom] = 0
chain_ids[start_atom] = chain_id
queue_idx = 0
while queue_idx < len(queue):
current = queue[queue_idx]
queue_idx += 1
# Find neighbors
neighbors = torch.where(adjacency[current])[0]
for neighbor in neighbors:
neighbor = neighbor.item()
if not visited[neighbor]:
visited[neighbor] = True
parent_indices[neighbor] = current
atom_depth[neighbor] = atom_depth[current] + 1
chain_ids[neighbor] = chain_id
# Set grandparent
if parent_indices[current] >= 0:
grandparent_indices[neighbor] = parent_indices[current]
# Set great-grandparent
if grandparent_indices[current] >= 0:
great_grandparent_indices[neighbor] = grandparent_indices[
current
]
queue.append(neighbor)
chain_id += 1
# Register buffers
self.register_buffer("parent_indices", parent_indices)
self.register_buffer("grandparent_indices", grandparent_indices)
self.register_buffer("great_grandparent_indices", great_grandparent_indices)
self.register_buffer("atom_depth", atom_depth)
self.register_buffer("chain_ids", chain_ids)
self.register_buffer(
"chain_roots", torch.tensor(chain_roots, dtype=torch.long, device=device)
)
self.n_chains = len(chain_roots)
self.max_depth = atom_depth.max().item()
# Build parameter index mappings
self._build_param_indices()
# Pre-compute depth indices for fast forward pass
self._build_depth_indices()
def _build_param_indices(self) -> None:
"""
Build mappings from atoms to parameter indices.
Creates bond_param_indices, angle_param_indices, and torsion_param_indices
which map each atom to its corresponding internal coordinate parameter.
"""
device = self._device
n_atoms = self.n_atoms
# Atoms with depth >= 1 have bonds
# Atoms with depth >= 2 have angles
# Atoms with depth >= 3 have torsions
# Count parameters
has_bond = self.atom_depth >= 1
has_angle = self.atom_depth >= 2
has_torsion = self.atom_depth >= 3
n_bonds = has_bond.sum().item()
n_angles = has_angle.sum().item()
n_torsions = has_torsion.sum().item()
# Create index mappings
bond_param_indices = torch.full((n_atoms,), -1, dtype=torch.long, device=device)
angle_param_indices = torch.full(
(n_atoms,), -1, dtype=torch.long, device=device
)
torsion_param_indices = torch.full(
(n_atoms,), -1, dtype=torch.long, device=device
)
# Assign indices in order of atom index
bond_idx = 0
angle_idx = 0
torsion_idx = 0
for i in range(n_atoms):
if has_bond[i]:
bond_param_indices[i] = bond_idx
bond_idx += 1
if has_angle[i]:
angle_param_indices[i] = angle_idx
angle_idx += 1
if has_torsion[i]:
torsion_param_indices[i] = torsion_idx
torsion_idx += 1
self.register_buffer("bond_param_indices", bond_param_indices)
self.register_buffer("angle_param_indices", angle_param_indices)
self.register_buffer("torsion_param_indices", torsion_param_indices)
self.n_bonds = n_bonds
self.n_angles = n_angles
self.n_torsions = n_torsions
def _build_depth_indices(self) -> None:
"""
Pre-compute atom indices for each depth level.
This avoids recomputing masks during the forward pass, which is
the main performance bottleneck for large structures.
"""
# Store indices for each depth level (depths 3+ use full NeRF)
# We store as a list of tensors since depths have varying atom counts
depth_atom_indices = []
depth_parent_indices = []
depth_gp_indices = []
depth_ggp_indices = []
depth_bond_param_indices = []
depth_angle_param_indices = []
depth_torsion_param_indices = []
for d in range(3, self.max_depth + 1):
mask = self.atom_depth == d
if mask.any():
atom_idx = torch.where(mask)[0]
depth_atom_indices.append(atom_idx)
depth_parent_indices.append(self.parent_indices[atom_idx])
depth_gp_indices.append(self.grandparent_indices[atom_idx])
depth_ggp_indices.append(self.great_grandparent_indices[atom_idx])
depth_bond_param_indices.append(self.bond_param_indices[atom_idx])
depth_angle_param_indices.append(self.angle_param_indices[atom_idx])
depth_torsion_param_indices.append(self.torsion_param_indices[atom_idx])
else:
# Empty tensor for this depth
empty = torch.tensor([], dtype=torch.long, device=self._device)
depth_atom_indices.append(empty)
depth_parent_indices.append(empty)
depth_gp_indices.append(empty)
depth_ggp_indices.append(empty)
depth_bond_param_indices.append(empty)
depth_angle_param_indices.append(empty)
depth_torsion_param_indices.append(empty)
# Store as instance attributes (not buffers since they're lists)
self._depth_atom_indices = depth_atom_indices
self._depth_parent_indices = depth_parent_indices
self._depth_gp_indices = depth_gp_indices
self._depth_ggp_indices = depth_ggp_indices
self._depth_bond_param_indices = depth_bond_param_indices
self._depth_angle_param_indices = depth_angle_param_indices
self._depth_torsion_param_indices = depth_torsion_param_indices
def _detect_rings(self, adjacency: torch.Tensor) -> None:
"""
Detect rings in the molecular graph and set up rigid ring handling.
Rings are identified by finding back-edges in the spanning tree.
Ring atoms are marked for special handling during reconstruction.
Parameters
----------
adjacency : torch.Tensor
Boolean adjacency matrix of shape (N, N).
"""
device = self._device
n_atoms = self.n_atoms
# Initialize ring detection buffers
ring_member_mask = torch.zeros(n_atoms, dtype=torch.bool, device=device)
ring_group_id = torch.full((n_atoms,), -1, dtype=torch.long, device=device)
# Find back edges (edges that are not in the spanning tree)
# A back edge connects an atom to a non-parent ancestor
back_edges = []
for i in range(n_atoms):
neighbors = torch.where(adjacency[i])[0]
for j in neighbors:
j = j.item()
if j <= i:
continue # Only check each edge once
# Check if this edge is not in the spanning tree
if self.parent_indices[i] != j and self.parent_indices[j] != i:
# This is a back edge - indicates a ring
back_edges.append((i, j))
# For each back edge, find the ring members
ring_anchors = []
ring_members_list = []
ring_idx = 0
for edge_i, edge_j in back_edges:
# Find path from i to j through the tree
# The ring consists of atoms on both paths to their common ancestor
# Find ancestors of i
ancestors_i = set()
current = edge_i
while current >= 0:
ancestors_i.add(current)
current = self.parent_indices[current].item()
# Find common ancestor with j
ring_atoms = []
current = edge_j
while current >= 0:
ring_atoms.append(current)
if current in ancestors_i:
# Found common ancestor
common_ancestor = current
break
current = self.parent_indices[current].item()
# Add path from i to common ancestor
current = edge_i
while current != common_ancestor:
ring_atoms.append(current)
current = self.parent_indices[current].item()
# Mark ring atoms
for atom in ring_atoms:
ring_member_mask[atom] = True
if ring_group_id[atom] < 0: # Only set if not already in a ring
ring_group_id[atom] = ring_idx
# The anchor is the atom closest to root (smallest depth)
depths = torch.tensor(
[self.atom_depth[a].item() for a in ring_atoms], device=device
)
anchor_idx = ring_atoms[depths.argmin().item()]
ring_anchors.append(anchor_idx)
ring_members_list.append(ring_atoms)
ring_idx += 1
self.register_buffer("ring_member_mask", ring_member_mask)
self.register_buffer("ring_group_id", ring_group_id)
self.n_rings = len(ring_anchors)
if self.n_rings > 0:
self.register_buffer(
"ring_anchor_atoms",
torch.tensor(ring_anchors, dtype=torch.long, device=device),
)
# Store ring members as padded tensor
max_ring_size = max(len(r) for r in ring_members_list)
ring_members_tensor = torch.full(
(self.n_rings, max_ring_size), -1, dtype=torch.long, device=device
)
ring_sizes = torch.zeros(self.n_rings, dtype=torch.long, device=device)
for i, members in enumerate(ring_members_list):
ring_sizes[i] = len(members)
ring_members_tensor[i, : len(members)] = torch.tensor(
members, dtype=torch.long, device=device
)
self.register_buffer("ring_members", ring_members_tensor)
self.register_buffer("ring_sizes", ring_sizes)
else:
self.register_buffer(
"ring_anchor_atoms", torch.tensor([], dtype=torch.long, device=device)
)
self.register_buffer(
"ring_members",
torch.tensor([], dtype=torch.long, device=device).reshape(0, 0),
)
self.register_buffer(
"ring_sizes", torch.tensor([], dtype=torch.long, device=device)
)
def _extract_internal_coords(
    self, xyz: torch.Tensor, requires_grad: bool = True
) -> None:
    """
    Extract internal coordinates from Cartesian coordinates.

    Computes bond lengths, angles, and torsions from the initial xyz
    and creates parameters for them. In addition this method registers
    the reference data needed for reconstruction: ring-local
    coordinates, one reference direction per chain for depth-1 atoms,
    one reference perpendicular per depth-2 atom, an all-True
    ``refinable_mask`` and a copy of ``xyz`` as ``fixed_xyz``.

    Relies on the buffers registered by the spanning-tree and ring
    detection steps (``atom_depth``, ``parent_indices``, ``chain_roots``,
    ``n_rings``, ...), so it must run after them.

    Parameters
    ----------
    xyz : torch.Tensor
        Atomic coordinates of shape (N, 3).
    requires_grad : bool
        Whether parameters should have gradients.
    """
    device = self._device
    dtype = self._dtype
    # Extract bond lengths for atoms with depth >= 1
    # (boolean-mask indexing keeps ascending atom order)
    bond_mask = self.atom_depth >= 1
    if bond_mask.any():
        child_xyz = xyz[bond_mask]
        parent_xyz = xyz[self.parent_indices[bond_mask]]
        bond_lengths = torch.linalg.norm(child_xyz - parent_xyz, dim=-1)
    else:
        bond_lengths = torch.tensor([], dtype=dtype, device=device)
    # Extract angles for atoms with depth >= 2
    angle_mask = self.atom_depth >= 2
    if angle_mask.any():
        child_xyz = xyz[angle_mask]
        parent_xyz = xyz[self.parent_indices[angle_mask]]
        grandparent_xyz = xyz[self.grandparent_indices[angle_mask]]
        # Angle at parent between child and grandparent
        v1 = child_xyz - parent_xyz
        v2 = grandparent_xyz - parent_xyz
        # 1e-10 guards the division against zero-length vectors; the
        # clamp keeps acos inside its domain despite rounding.
        cos_angles = torch.sum(v1 * v2, dim=-1) / (
            torch.linalg.norm(v1, dim=-1) * torch.linalg.norm(v2, dim=-1) + 1e-10
        )
        cos_angles = torch.clamp(cos_angles, -1.0, 1.0)
        angles = torch.acos(cos_angles)
    else:
        angles = torch.tensor([], dtype=dtype, device=device)
    # Extract torsions for atoms with depth >= 3
    torsion_mask = self.atom_depth >= 3
    if torsion_mask.any():
        child_xyz = xyz[torsion_mask]
        parent_xyz = xyz[self.parent_indices[torsion_mask]]
        grandparent_xyz = xyz[self.grandparent_indices[torsion_mask]]
        great_grandparent_xyz = xyz[self.great_grandparent_indices[torsion_mask]]
        # Torsion great-grandparent -> grandparent -> parent -> child
        torsions = self._compute_torsion_angles(
            great_grandparent_xyz, grandparent_xyz, parent_xyz, child_xyz
        )
    else:
        torsions = torch.tensor([], dtype=dtype, device=device)
    # Extract chain positions (positions of root atoms)
    chain_positions = xyz[self.chain_roots].clone()
    # Initialize chain orientations to zero (no rotation)
    chain_orientations = torch.zeros(self.n_chains, 3, dtype=dtype, device=device)
    # Extract ring local coordinates if there are rings
    if self.n_rings > 0:
        ring_local_coords = self._extract_ring_local_coords(xyz)
        self.register_buffer("ring_local_coords", ring_local_coords)
    else:
        self.register_buffer(
            "ring_local_coords",
            torch.tensor([], dtype=dtype, device=device).reshape(0, 0, 3),
        )
    # Store second atom directions for each chain (for placing second atoms)
    # NOTE(review): only ONE direction is stored per chain, taken from the
    # lowest-index depth-1 atom. _place_second_atoms looks this buffer up by
    # chain id for every depth-1 atom, so a root with several bonded
    # neighbours would have all of them placed along the same direction -
    # confirm whether that is intended.
    second_atom_dirs = []
    for chain_idx in range(self.n_chains):
        root = self.chain_roots[chain_idx].item()
        # Find atoms with depth 1 in this chain
        mask = (self.chain_ids == chain_idx) & (self.atom_depth == 1)
        if mask.any():
            first_child = torch.where(mask)[0][0].item()
            direction = xyz[first_child] - xyz[root]
            norm = torch.linalg.norm(direction)
            if norm > 1e-10:
                direction = direction / norm
            else:
                # Degenerate (coincident atoms): fall back to the x axis
                direction = torch.tensor([1.0, 0.0, 0.0], dtype=dtype, device=device)
        else:
            # Single-atom chain: direction is a placeholder
            direction = torch.tensor([1.0, 0.0, 0.0], dtype=dtype, device=device)
        second_atom_dirs.append(direction)
    self.register_buffer(
        "second_atom_dirs", torch.stack(second_atom_dirs) if second_atom_dirs else
        torch.zeros(0, 3, dtype=dtype, device=device)
    )
    # For depth-2 atoms, we need to store reference perpendiculars
    # These are computed for each individual depth-2 atom
    depth2_mask = self.atom_depth == 2
    n_depth2 = depth2_mask.sum().item()
    # Create mapping from atom index to depth2_perps index
    # -1 for atoms that aren't depth-2
    depth2_atom_to_perp_idx = torch.full(
        (self.n_atoms,), -1, dtype=torch.long, device=device
    )
    if n_depth2 > 0:
        depth2_perps = torch.zeros(n_depth2, 3, dtype=dtype, device=device)
        depth2_indices = torch.where(depth2_mask)[0]
        for i, atom_idx in enumerate(depth2_indices):
            atom_idx_val = atom_idx.item()
            depth2_atom_to_perp_idx[atom_idx_val] = i
            parent = self.parent_indices[atom_idx_val].item()
            grandparent = self.grandparent_indices[atom_idx_val].item()
            # Direction from grandparent to parent
            v_bond = xyz[parent] - xyz[grandparent]
            v_bond_norm = torch.linalg.norm(v_bond)
            if v_bond_norm > 1e-10:
                v_bond = v_bond / v_bond_norm
            else:
                v_bond = torch.tensor([1.0, 0.0, 0.0], dtype=dtype, device=device)
            # Direction from parent to child
            v_child = xyz[atom_idx_val] - xyz[parent]
            # Get perpendicular component
            v_perp = v_child - torch.dot(v_child, v_bond) * v_bond
            v_perp_norm = torch.linalg.norm(v_perp)
            if v_perp_norm > 1e-10:
                v_perp = v_perp / v_perp_norm
            else:
                # Collinear case - use arbitrary perpendicular
                # (pick the seed axis that is not nearly parallel to
                # v_bond, then orthogonalise it against v_bond)
                if abs(v_bond[0]) < 0.9:
                    v_perp = torch.tensor([1.0, 0.0, 0.0], dtype=dtype, device=device)
                else:
                    v_perp = torch.tensor([0.0, 1.0, 0.0], dtype=dtype, device=device)
                v_perp = v_perp - torch.dot(v_perp, v_bond) * v_bond
                v_perp = v_perp / torch.linalg.norm(v_perp)
            depth2_perps[i] = v_perp
        self.register_buffer("depth2_perps", depth2_perps)
    else:
        self.register_buffer(
            "depth2_perps", torch.zeros(0, 3, dtype=dtype, device=device)
        )
    self.register_buffer("depth2_atom_to_perp_idx", depth2_atom_to_perp_idx)
    # Register parameters
    self.bond_lengths = nn.Parameter(bond_lengths.clone(), requires_grad=requires_grad)
    self.angles = nn.Parameter(angles.clone(), requires_grad=requires_grad)
    self.torsions = nn.Parameter(torsions.clone(), requires_grad=requires_grad)
    self.chain_positions = nn.Parameter(
        chain_positions.clone(), requires_grad=requires_grad
    )
    self.chain_orientations = nn.Parameter(
        chain_orientations.clone(), requires_grad=requires_grad
    )
    # Initialize refinable mask (all atoms refinable by default)
    # True = use internal coordinates (refinable), False = use fixed xyz (frozen)
    self.register_buffer(
        "refinable_mask",
        torch.ones(self.n_atoms, dtype=torch.bool, device=device)
    )
    # Buffer to store frozen atom coordinates
    # (initialised with a copy of the input coordinates)
    self.register_buffer(
        "fixed_xyz",
        xyz.clone()
    )
def _extract_ring_local_coords(self, xyz: torch.Tensor) -> torch.Tensor:
"""
Extract local coordinates for ring atoms relative to ring anchors.
Parameters
----------
xyz : torch.Tensor
Atomic coordinates of shape (N, 3).
Returns
-------
torch.Tensor
Ring local coordinates of shape (n_rings, max_ring_size, 3).
"""
device = self._device
dtype = self._dtype
max_ring_size = self.ring_members.shape[1] if self.n_rings > 0 else 0
ring_local_coords = torch.zeros(
self.n_rings, max_ring_size, 3, dtype=dtype, device=device
)
for ring_idx in range(self.n_rings):
anchor = self.ring_anchor_atoms[ring_idx].item()
anchor_pos = xyz[anchor]
# Compute local frame at anchor (using parent direction)
if self.parent_indices[anchor] >= 0:
parent_pos = xyz[self.parent_indices[anchor]]
z_axis = anchor_pos - parent_pos
z_norm = torch.linalg.norm(z_axis)
if z_norm > 1e-10:
z_axis = z_axis / z_norm
else:
z_axis = torch.tensor([0.0, 0.0, 1.0], dtype=dtype, device=device)
else:
z_axis = torch.tensor([0.0, 0.0, 1.0], dtype=dtype, device=device)
# Arbitrary perpendicular axes
if abs(z_axis[0]) < 0.9:
x_axis = torch.tensor([1.0, 0.0, 0.0], dtype=dtype, device=device)
else:
x_axis = torch.tensor([0.0, 1.0, 0.0], dtype=dtype, device=device)
x_axis = x_axis - torch.dot(x_axis, z_axis) * z_axis
x_axis = x_axis / torch.linalg.norm(x_axis)
y_axis = torch.linalg.cross(z_axis, x_axis)
R = torch.stack([x_axis, y_axis, z_axis], dim=1) # (3, 3)
# Store local coords for each ring member
for local_idx in range(self.ring_sizes[ring_idx].item()):
atom_idx = self.ring_members[ring_idx, local_idx].item()
if atom_idx >= 0:
global_offset = xyz[atom_idx] - anchor_pos
local_offset = R.T @ global_offset # Transform to local frame
ring_local_coords[ring_idx, local_idx] = local_offset
return ring_local_coords
@staticmethod
def _compute_torsion_angles(
p1: torch.Tensor, p2: torch.Tensor, p3: torch.Tensor, p4: torch.Tensor
) -> torch.Tensor:
"""
Compute torsion angles for batched positions.
Parameters
----------
p1, p2, p3, p4 : torch.Tensor
Atom positions of shape (N, 3) defining torsion angle p1-p2-p3-p4.
Returns
-------
torch.Tensor
Torsion angles in radians, shape (N,).
"""
# Vectors along the backbone
b1 = p2 - p1
b2 = p3 - p2
b3 = p4 - p3
# Normal vectors to planes
n1 = torch.linalg.cross(b1, b2)
n2 = torch.linalg.cross(b2, b3)
# Normalize
n1 = n1 / torch.linalg.norm(n1, dim=-1, keepdim=True).clamp(min=1e-10)
n2 = n2 / torch.linalg.norm(n2, dim=-1, keepdim=True).clamp(min=1e-10)
# Unit vector along b2
b2_norm = b2 / torch.linalg.norm(b2, dim=-1, keepdim=True).clamp(min=1e-10)
# m1 is perpendicular to n1 and b2
m1 = torch.linalg.cross(n1, b2_norm)
# Compute torsion angle using atan2
x = torch.sum(n1 * n2, dim=-1)
y = torch.sum(m1 * n2, dim=-1)
return torch.atan2(y, x)
def _place_second_atoms(self, xyz: torch.Tensor) -> torch.Tensor:
    """
    Place all depth-1 atoms (direct children of chain roots) at once.

    Each atom is put at ``chain_position + bond_length * R @ dir``,
    where ``R`` is the rotation built from the chain's axis-angle
    orientation and ``dir`` is the chain's stored reference direction.

    NOTE(review): ``second_atom_dirs`` is indexed by chain id, so all
    depth-1 atoms belonging to one chain use the same direction.

    Parameters
    ----------
    xyz : torch.Tensor
        Current coordinates of shape (N, 3). Not modified in place.

    Returns
    -------
    torch.Tensor
        ``xyz`` itself when there are no depth-1 atoms, otherwise a
        copy with the depth-1 rows replaced.
    """
    depth1 = self.atom_depth == 1
    if not depth1.any():
        return xyz

    owner_chain = self.chain_ids[depth1]  # (n_second,)

    # Rotate each chain's reference direction by the chain orientation.
    rotations = axis_angle_to_rotation_matrix(
        self.chain_orientations[owner_chain]
    )  # (n_second, 3, 3)
    reference_dirs = self.second_atom_dirs[owner_chain]  # (n_second, 3)
    directions = torch.bmm(rotations, reference_dirs.unsqueeze(-1)).squeeze(-1)

    lengths = self.bond_lengths[self.bond_param_indices[depth1]]  # (n_second,)
    origins = self.chain_positions[owner_chain]  # (n_second, 3)

    placed = xyz.clone()
    placed[depth1] = origins + lengths.unsqueeze(-1) * directions
    return placed
def _place_third_atoms(self, xyz: torch.Tensor) -> torch.Tensor:
    """
    Place all depth-2 atoms from bond length and bond angle.

    No torsion is defined at this depth. The out-of-axis direction
    instead comes from the per-atom reference perpendicular in
    ``depth2_perps``, rotated by the chain orientation and
    re-orthogonalised against the grandparent -> parent axis.

    Parameters
    ----------
    xyz : torch.Tensor
        Current coordinates of shape (N, 3). Not modified in place.

    Returns
    -------
    torch.Tensor
        ``xyz`` itself when there are no depth-2 atoms, otherwise a
        copy with the depth-2 rows replaced.
    """
    depth2 = self.atom_depth == 2
    if not depth2.any():
        return xyz

    result = xyz.clone()

    pivot = result[self.parent_indices[depth2]]  # (n_third, 3)
    before_pivot = result[self.grandparent_indices[depth2]]  # (n_third, 3)
    lengths = self.bond_lengths[self.bond_param_indices[depth2]]  # (n_third,)
    bend = self.angles[self.angle_param_indices[depth2]]  # (n_third,)

    # Unit axis pointing from grandparent through parent.
    axis = pivot - before_pivot
    axis = axis / torch.linalg.norm(axis, dim=-1, keepdim=True).clamp(min=1e-10)

    # Stored perpendiculars, rotated together with their chain.
    perp_slots = self.depth2_atom_to_perp_idx[torch.where(depth2)[0]]
    rotations = axis_angle_to_rotation_matrix(
        self.chain_orientations[self.chain_ids[depth2]]
    )
    side = torch.bmm(rotations, self.depth2_perps[perp_slots].unsqueeze(-1)).squeeze(-1)

    # Remove any component along the axis and renormalise.
    side = side - torch.sum(side * axis, dim=-1, keepdim=True) * axis
    side = side / torch.linalg.norm(side, dim=-1, keepdim=True).clamp(min=1e-10)

    # The stored angle is grandparent-parent-child, so the deviation
    # from straight continuation along the axis is (pi - angle).
    deviation = torch.pi - bend
    along = lengths * torch.cos(deviation)
    across = lengths * torch.sin(deviation)

    result[depth2] = pivot + along.unsqueeze(-1) * axis + across.unsqueeze(-1) * side
    return result
def _place_atoms_at_depth(self, xyz: torch.Tensor, depth: int) -> torch.Tensor:
"""
Place all atoms at given depth in parallel (fully vectorized).
Uses the NeRF algorithm to reconstruct Cartesian coordinates from
bond lengths, angles, and torsions.
Parameters
----------
xyz : torch.Tensor
Current coordinates of shape (N, 3).
depth : int
Tree depth to process.
Returns
-------
torch.Tensor
Updated coordinates.
"""
# Get all atom indices at this depth (pre-computed mask)
mask = self.atom_depth == depth
n_atoms_at_depth = mask.sum()
if n_atoms_at_depth == 0:
return xyz
xyz = xyz.clone()
# Gather reference positions via advanced indexing (batched)
parent_idx = self.parent_indices[mask]
gp_idx = self.grandparent_indices[mask]
ggp_idx = self.great_grandparent_indices[mask]
p1 = xyz[ggp_idx] # (n_atoms_at_depth, 3)
p2 = xyz[gp_idx] # (n_atoms_at_depth, 3)
p3 = xyz[parent_idx] # (n_atoms_at_depth, 3)
# Gather internal coordinates (batched)
bond_idx = self.bond_param_indices[mask]
angle_idx = self.angle_param_indices[mask]
torsion_idx = self.torsion_param_indices[mask]
d = self.bond_lengths[bond_idx] # (n_atoms_at_depth,)
theta = self.angles[angle_idx] # (n_atoms_at_depth,)
phi = self.torsions[torsion_idx] # (n_atoms_at_depth,)
# Build local coordinate frame (batched, all in parallel)
bc = p3 - p2 # (n_atoms_at_depth, 3)
bc = bc / torch.linalg.norm(bc, dim=-1, keepdim=True).clamp(min=1e-10)
ab = p2 - p1
n = torch.linalg.cross(ab, bc)
n = n / torch.linalg.norm(n, dim=-1, keepdim=True).clamp(min=1e-10)
m = torch.linalg.cross(n, bc)
# Compute positions (batched trigonometry)
theta_internal = torch.pi - theta
sin_theta = torch.sin(theta_internal)
cos_theta = torch.cos(theta_internal)
dx = d * cos_theta # (n_atoms_at_depth,)
dy = d * sin_theta * torch.cos(phi)
dz = d * sin_theta * torch.sin(phi)
# Global positions (batched)
# Child continues along bc direction (from grandparent through parent)
# The torsion angle determines rotation around bc axis
# Negate dz to match torsion sign convention
new_pos = (
p3
+ dx.unsqueeze(-1) * bc
+ dy.unsqueeze(-1) * m
- dz.unsqueeze(-1) * n
) # (n_atoms_at_depth, 3)
# Scatter back to xyz
xyz[mask] = new_pos
return xyz
def _place_atoms_at_depth_fast(
self, xyz: torch.Tensor, depth_idx: int
) -> torch.Tensor:
"""
Place atoms at a depth level using pre-computed indices (fast path).
This avoids recomputing masks during the forward pass.
Parameters
----------
xyz : torch.Tensor
Current coordinates of shape (N, 3).
depth_idx : int
Index into pre-computed depth arrays (0 = depth 3, 1 = depth 4, etc.)
Returns
-------
torch.Tensor
Updated coordinates.
"""
atom_idx = self._depth_atom_indices[depth_idx]
if atom_idx.numel() == 0:
return xyz
xyz = xyz.clone()
# Use pre-computed indices
parent_idx = self._depth_parent_indices[depth_idx]
gp_idx = self._depth_gp_indices[depth_idx]
ggp_idx = self._depth_ggp_indices[depth_idx]
p1 = xyz[ggp_idx]
p2 = xyz[gp_idx]
p3 = xyz[parent_idx]
# Use pre-computed parameter indices
bond_idx = self._depth_bond_param_indices[depth_idx]
angle_idx = self._depth_angle_param_indices[depth_idx]
torsion_idx = self._depth_torsion_param_indices[depth_idx]
d = self.bond_lengths[bond_idx]
theta = self.angles[angle_idx]
phi = self.torsions[torsion_idx]
# Build local coordinate frame
bc = p3 - p2
bc = bc / torch.linalg.norm(bc, dim=-1, keepdim=True).clamp(min=1e-10)
ab = p2 - p1
n = torch.linalg.cross(ab, bc)
n = n / torch.linalg.norm(n, dim=-1, keepdim=True).clamp(min=1e-10)
m = torch.linalg.cross(n, bc)
# Compute positions
theta_internal = torch.pi - theta
sin_theta = torch.sin(theta_internal)
cos_theta = torch.cos(theta_internal)
dx = d * cos_theta
dy = d * sin_theta * torch.cos(phi)
dz = d * sin_theta * torch.sin(phi)
new_pos = (
p3
+ dx.unsqueeze(-1) * bc
+ dy.unsqueeze(-1) * m
- dz.unsqueeze(-1) * n
)
# Scatter back using pre-computed atom indices
xyz[atom_idx] = new_pos
return xyz
def _place_rigid_rings(self, xyz: torch.Tensor) -> torch.Tensor:
    """
    Re-position ring members as rigid bodies attached to their anchor atoms.

    For each ring a local frame is built at the anchor: z points from the
    anchor's parent to the anchor (falling back to [0, 0, 1] when the
    anchor has no parent or coincides with it), x is a fixed seed axis
    made orthogonal to z, and y completes the frame. The stored local
    offsets are rotated into that frame and added to the anchor position.

    Parameters
    ----------
    xyz : torch.Tensor
        Current coordinates of shape (N, 3). Modified in place.

    Returns
    -------
    torch.Tensor
        The same ``xyz`` tensor with non-anchor ring members overwritten.
    """
    if self.n_rings == 0:
        return xyz

    anchors = self.ring_anchor_atoms
    anchor_xyz = xyz[anchors]  # (n_rings, 3)
    anchor_parents = self.parent_indices[anchors]  # (n_rings,)
    has_parent = anchor_parents >= 0

    # Parent positions; anchors without a parent keep a zero placeholder.
    parent_xyz = torch.zeros_like(anchor_xyz)
    parent_xyz[has_parent] = xyz[anchor_parents[has_parent]]

    fallback_z = torch.tensor(
        [0.0, 0.0, 1.0], dtype=self._dtype, device=self._device
    )

    # z axis: unit vector parent -> anchor, with the fallback for
    # degenerate bonds and for parentless anchors.
    bond_vec = anchor_xyz - parent_xyz
    bond_len = torch.linalg.norm(bond_vec, dim=-1, keepdim=True)
    z_hat = torch.where(
        bond_len > 1e-10,
        bond_vec / bond_len.clamp(min=1e-10),
        fallback_z,
    )
    z_hat[~has_parent] = fallback_z

    # Seed for the x axis: [1, 0, 0], or [0, 1, 0] when z is close to x.
    close_to_x = torch.abs(z_hat[:, 0]) >= 0.9
    seed = torch.zeros_like(z_hat)
    seed[:, 0] = (~close_to_x).to(seed.dtype)
    seed[:, 1] = close_to_x.to(seed.dtype)

    # Gram-Schmidt step: remove the z component from the seed.
    x_hat = seed - (seed * z_hat).sum(dim=-1, keepdim=True) * z_hat
    x_hat = x_hat / torch.linalg.norm(x_hat, dim=-1, keepdim=True).clamp(min=1e-10)
    y_hat = torch.linalg.cross(z_hat, x_hat)

    # Columns of the rotation are the frame axes x, y, z.
    rotation = torch.stack([x_hat, y_hat, z_hat], dim=-1)  # (n_rings, 3, 3)

    # Slots to write: skip -1 padding and the anchor itself.
    members = self.ring_members  # (n_rings, max_ring_size)
    writable = (members >= 0) & (members != anchors.unsqueeze(1))

    # Row-vector form of R @ local + anchor for every slot of every ring.
    world = torch.bmm(
        self.ring_local_coords,
        rotation.transpose(-1, -2),
    ) + anchor_xyz.unsqueeze(1)  # (n_rings, max_ring_size, 3)

    ring_ids, slot_ids = torch.where(writable)
    xyz[members[ring_ids, slot_ids]] = world[ring_ids, slot_ids]
    return xyz
[docs]
def forward_slow(self) -> torch.Tensor:
    """
    Rebuild Cartesian coordinates level by level from the internal parameters.

    Chain roots are taken from ``chain_positions``; the second and third
    atoms of each chain are placed by their dedicated helpers; every
    remaining pre-indexed depth level is then placed in turn (one
    vectorized call per level). Rigid rings are applied afterwards, and
    frozen atoms are finally overwritten with ``fixed_xyz``.

    Returns
    -------
    torch.Tensor
        Reconstructed Cartesian coordinates of shape (N, 3), passed
        through ``_to_output``.
    """
    coords = torch.zeros(
        self.n_atoms, 3, dtype=self._dtype, device=self._device
    )

    # Depth 0, 1 and 2 of every chain.
    coords[self.chain_roots] = self.chain_positions
    coords = self._place_second_atoms(coords)
    coords = self._place_third_atoms(coords)

    # One sequential step per remaining depth level.
    n_levels = len(self._depth_atom_indices)
    for level in range(n_levels):
        coords = self._place_atoms_at_depth_fast(coords, level)

    coords = self._place_rigid_rings(coords)

    # Frozen (non-refinable) atoms always report their stored position.
    frozen = ~self.refinable_mask
    if frozen.any():
        coords[frozen] = self.fixed_xyz[frozen]

    return self._to_output(coords)
[docs]
def forward(self) -> torch.Tensor:
    """
    Reconstruct Cartesian xyz from the internal coordinates.

    Delegates to ``forward_parallel``.

    Returns
    -------
    torch.Tensor
        Whatever ``forward_parallel`` returns: coordinates of shape (N, 3).
    """
    coords = self.forward_parallel()
    return coords
[docs]
def shake(self, magnitude: float = 0.1) -> torch.Tensor:
    """
    Randomly perturb the internal parameters and return the new coordinates.

    Independent Gaussian noise scaled by ``magnitude`` is added to bond
    lengths, angles, torsions and chain positions. Afterwards bond
    lengths are clamped to at least 0.5, angles to [0.1, pi - 0.1], and
    torsions are wrapped through ``atan2(sin, cos)``. Empty parameter
    tensors are skipped.

    Parameters
    ----------
    magnitude : float, optional
        Standard deviation of the noise. Default is 0.1. The same value
        is used for lengths (Angstroms) and for angles/torsions (radians).

    Returns
    -------
    torch.Tensor
        Result of ``forward()`` evaluated after the perturbation.
    """
    with torch.no_grad():
        lengths = self.bond_lengths
        if lengths.numel() > 0:
            lengths.data.add_(torch.randn_like(lengths) * magnitude)
            lengths.data = lengths.data.clamp(min=0.5)

        bends = self.angles
        if bends.numel() > 0:
            bends.data.add_(torch.randn_like(bends) * magnitude)
            bends.data = bends.data.clamp(min=0.1, max=torch.pi - 0.1)

        twists = self.torsions
        if twists.numel() > 0:
            twists.data.add_(torch.randn_like(twists) * magnitude)
            # Wrap back into the principal range.
            twists.data = torch.atan2(
                torch.sin(twists.data), torch.cos(twists.data)
            )

        origins = self.chain_positions
        if origins.numel() > 0:
            origins.data.add_(torch.randn_like(origins) * magnitude)

    return self.forward()
# ===== Freeze/Unfreeze Methods (MixedTensor-style interface) =====
[docs]
def fix(
    self,
    selection: Union[torch.Tensor, slice, None] = None,
    freeze_at_current: bool = True
) -> None:
    """
    Freeze atoms so they report stored ``fixed_xyz`` positions.

    The selected atoms are flagged as non-refinable in ``refinable_mask``.

    Parameters
    ----------
    selection : torch.Tensor, slice, or None
        Boolean mask (used as given), integer index tensor, or slice
        choosing the atoms to freeze. ``None`` selects every atom.
    freeze_at_current : bool, optional
        If True (default), evaluate ``forward()`` first and copy the
        resulting positions of the selected atoms into ``fixed_xyz``.
        If False, ``fixed_xyz`` is left as it is.

    Raises
    ------
    TypeError
        If ``selection`` is not a Tensor, a slice, or None.
    """
    chosen = slice(None) if selection is None else selection

    if isinstance(chosen, torch.Tensor) and chosen.dtype == torch.bool:
        frozen = chosen
    elif isinstance(chosen, (torch.Tensor, slice)):
        # Index tensors and slices are expanded into a boolean mask.
        frozen = torch.zeros(self.n_atoms, dtype=torch.bool, device=self._device)
        frozen[chosen] = True
    else:
        raise TypeError(f"selection must be Tensor, slice, or None, got {type(selection)}")

    if freeze_at_current:
        with torch.no_grad():
            # The forward output is moved onto the device of ``fixed_xyz``
            # before the snapshot is stored.
            snapshot = self.forward().to(self.fixed_xyz.device)
            self.fixed_xyz[frozen] = snapshot[frozen]

    self.refinable_mask[frozen] = False
[docs]
def freeze(
    self,
    selection: Union[torch.Tensor, slice, None] = None,
    freeze_at_current: bool = True
) -> None:
    """
    Freeze atoms; convenience alias that forwards to ``fix()``.

    Parameters
    ----------
    selection : torch.Tensor, slice, or None
        Passed through to ``fix()`` unchanged.
    freeze_at_current : bool, optional
        Passed through to ``fix()`` unchanged. Default is True.
    """
    self.fix(selection=selection, freeze_at_current=freeze_at_current)
[docs]
def refine(
    self,
    selection: Union[torch.Tensor, slice, None] = None,
    rebuild: bool = True
) -> None:
    """
    Unfreeze atoms so their positions come from internal coordinates again.

    The selected atoms are flagged as refinable in ``refinable_mask``.

    Parameters
    ----------
    selection : torch.Tensor, slice, or None
        Boolean mask (used as given), integer index tensor, or slice
        choosing the atoms to unfreeze. ``None`` selects every atom.
    rebuild : bool, optional
        If True (default), first re-derive the internal coordinates of
        the selected atoms from ``fixed_xyz`` via
        ``_update_internal_coords_from_xyz``.

    Raises
    ------
    TypeError
        If ``selection`` is not a Tensor, a slice, or None.
    """
    chosen = slice(None) if selection is None else selection

    if isinstance(chosen, torch.Tensor) and chosen.dtype == torch.bool:
        released = chosen
    elif isinstance(chosen, (torch.Tensor, slice)):
        # Index tensors and slices are expanded into a boolean mask.
        released = torch.zeros(self.n_atoms, dtype=torch.bool, device=self._device)
        released[chosen] = True
    else:
        raise TypeError(f"selection must be Tensor, slice, or None, got {type(selection)}")

    if rebuild:
        self._update_internal_coords_from_xyz(self.fixed_xyz, released)

    self.refinable_mask[released] = True
[docs]
def unfreeze(
    self,
    selection: Union[torch.Tensor, slice, None] = None,
    rebuild: bool = True
) -> None:
    """
    Unfreeze atoms; convenience alias that forwards to ``refine()``.

    Parameters
    ----------
    selection : torch.Tensor, slice, or None
        Passed through to ``refine()`` unchanged.
    rebuild : bool, optional
        Passed through to ``refine()`` unchanged. Default is True.
    """
    self.refine(selection=selection, rebuild=rebuild)
[docs]
def fix_all(self, freeze_at_current: bool = True) -> None:
    """
    Freeze every atom by calling ``fix()`` with no selection.

    Parameters
    ----------
    freeze_at_current : bool, optional
        Forwarded to ``fix()``. Default is True.
    """
    self.fix(selection=None, freeze_at_current=freeze_at_current)
[docs]
def freeze_all(self, freeze_at_current: bool = True) -> None:
    """
    Freeze every atom; convenience alias that forwards to ``fix_all()``.

    Parameters
    ----------
    freeze_at_current : bool, optional
        Forwarded to ``fix_all()``. Default is True.
    """
    self.fix_all(freeze_at_current=freeze_at_current)
[docs]
def refine_all(self, rebuild: bool = True) -> None:
    """
    Unfreeze every atom by calling ``refine()`` with no selection.

    Parameters
    ----------
    rebuild : bool, optional
        Forwarded to ``refine()``. Default is True.
    """
    self.refine(selection=None, rebuild=rebuild)
[docs]
def unfreeze_all(self, rebuild: bool = True) -> None:
    """
    Unfreeze every atom; convenience alias that forwards to ``refine_all()``.

    Parameters
    ----------
    rebuild : bool, optional
        Forwarded to ``refine_all()``. Default is True.
    """
    self.refine_all(rebuild=rebuild)
def _update_internal_coords_from_xyz(
    self,
    xyz: torch.Tensor,
    mask: torch.Tensor
) -> None:
    """
    Overwrite stored internal parameters with values measured on ``xyz``.

    For each selected atom the parameters it owns are re-measured: its
    bond length (depth >= 1), bond angle (depth >= 2) and torsion
    (depth >= 3). Chain positions are refreshed for every chain whose
    root atom is selected. All writes go through ``.data`` and record no
    autograd history.

    Parameters
    ----------
    xyz : torch.Tensor
        Atomic coordinates of shape (N, 3).
    mask : torch.Tensor
        Boolean mask of shape (N,) indicating which atoms to update.
    """
    with torch.no_grad():
        depth = self.atom_depth

        # Bond length: distance from each selected atom to its parent.
        with_bond = mask & (depth >= 1)
        if with_bond.any():
            offset = xyz[with_bond] - xyz[self.parent_indices[with_bond]]
            slots = self.bond_param_indices[with_bond]
            self.bond_lengths.data[slots] = torch.linalg.norm(offset, dim=-1)

        # Bond angle: angle at the parent between child and grandparent.
        with_angle = mask & (depth >= 2)
        if with_angle.any():
            apex = xyz[self.parent_indices[with_angle]]
            to_child = xyz[with_angle] - apex
            to_grandparent = xyz[self.grandparent_indices[with_angle]] - apex
            # The small epsilon guards against zero-length vectors; the
            # clamp keeps rounding noise inside acos' domain.
            cosine = torch.sum(to_child * to_grandparent, dim=-1) / (
                torch.linalg.norm(to_child, dim=-1)
                * torch.linalg.norm(to_grandparent, dim=-1)
                + 1e-10
            )
            cosine = torch.clamp(cosine, -1.0, 1.0)
            slots = self.angle_param_indices[with_angle]
            self.angles.data[slots] = torch.acos(cosine)

        # Torsion: dihedral over great-grandparent, grandparent, parent, child.
        with_torsion = mask & (depth >= 3)
        if with_torsion.any():
            measured = self._compute_torsion_angles(
                xyz[self.great_grandparent_indices[with_torsion]],
                xyz[self.grandparent_indices[with_torsion]],
                xyz[self.parent_indices[with_torsion]],
                xyz[with_torsion],
            )
            slots = self.torsion_param_indices[with_torsion]
            self.torsions.data[slots] = measured

        # Chain origin: follows the root atom whenever that root is selected.
        for chain_idx in range(self.n_chains):
            root = self.chain_roots[chain_idx]
            if mask[root]:
                self.chain_positions.data[chain_idx] = xyz[root]
@property
def n_refinable(self) -> int:
    """Count the atoms currently flagged as refinable (not frozen)."""
    refinable_total = self.refinable_mask.sum()
    return refinable_total.item()
@property
def n_fixed(self) -> int:
    """Count the atoms currently flagged as fixed (frozen)."""
    frozen = ~self.refinable_mask
    return frozen.sum().item()
def __repr__(self) -> str:
    """Return a one-line summary of atom, parameter and topology counts."""
    atoms = f"InternalCoordinateTensor(n_atoms={self.n_atoms}, "
    state = f"n_refinable={self.n_refinable}, n_fixed={self.n_fixed}, "
    topology = f"n_chains={self.n_chains}, n_bonds={self.n_bonds}, "
    params = f"n_angles={self.n_angles}, n_torsions={self.n_torsions}, "
    shape = f"n_rings={self.n_rings}, max_depth={self.max_depth}, "
    placement = f"device={self._device})"
    return atoms + state + topology + params + shape + placement
# ===== Parallel Scan Methods for Optimized Forward Pass =====
def _find_longest_path(self) -> list:
    """
    Find the longest root-to-leaf path (backbone) in the spanning tree.

    The parent table is copied to the host once and tree depths are
    memoised, so the work is linear in the number of atoms and no
    per-atom tensor reads are needed. When several leaves are equally
    deep, the one with the lowest atom index wins.

    Returns
    -------
    list
        Atom indices ordered from the root to the chosen leaf; empty
        when there are no atoms.
    """
    n_atoms = self.n_atoms
    if n_atoms == 0:
        return []

    # Single host transfer instead of one ``.item()`` call per visited atom.
    parents = self.parent_indices.tolist()

    has_child = [False] * n_atoms
    for atom in range(n_atoms):
        parent = parents[atom]
        if parent >= 0:
            has_child[parent] = True

    # Depth of every atom below its root. Each walk stops at the first
    # ancestor whose depth is already known, so every atom is visited once.
    depth = [-1] * n_atoms
    for start in range(n_atoms):
        if depth[start] >= 0:
            continue
        trail = []
        node = start
        while node >= 0 and depth[node] < 0:
            trail.append(node)
            node = parents[node]
        base = depth[node] if node >= 0 else -1
        for step, atom in enumerate(reversed(trail), start=1):
            depth[atom] = base + step

    # Deepest leaf; the strict comparison keeps the first one on ties.
    best_leaf = -1
    best_depth = -1
    for atom in range(n_atoms):
        if not has_child[atom] and depth[atom] > best_depth:
            best_leaf = atom
            best_depth = depth[atom]

    # Climb from the leaf to its root, then flip into root-to-leaf order.
    path = []
    node = best_leaf
    while node >= 0:
        path.append(node)
        node = parents[node]
    path.reverse()
    return path
def _compute_delta_transform(
    self, d: torch.Tensor, theta: torch.Tensor, phi: torch.Tensor
) -> tuple:
    """
    Build the per-atom rigid step (rotation, translation) used by the scan.

    The rotation is expressed in the parent's (m, n, bc) frame. Its third
    column is the unit bond direction towards the new atom, and the
    translation is that column scaled by the bond length.

    Parameters
    ----------
    d : torch.Tensor
        Bond lengths, shape (N,)
    theta : torch.Tensor
        Bond angles in radians, shape (N,)
    phi : torch.Tensor
        Torsion angles in radians, shape (N,)

    Returns
    -------
    R : torch.Tensor
        Rotation matrices, shape (N, 3, 3)
    t : torch.Tensor
        Translation vectors, shape (N, 3)
    """
    c_t, s_t = torch.cos(theta), torch.sin(theta)
    c_p, s_p = torch.cos(phi), torch.sin(phi)
    zero = torch.zeros_like(theta)

    # Assemble the matrix row by row; R[..., k] is then the k-th frame axis.
    row_m = torch.stack([-c_p * c_t, s_p, s_t * c_p], dim=-1)
    row_n = torch.stack([s_p * c_t, c_p, -s_t * s_p], dim=-1)
    row_bc = torch.stack([-s_t, zero, -c_t], dim=-1)
    R = torch.stack([row_m, row_n, row_bc], dim=-2)  # (N, 3, 3)

    # Step of length d along the new bond direction (third column).
    t = d.unsqueeze(-1) * R[..., 2]  # (N, 3)
    return R, t
def _compose_transforms(
    self,
    R1: torch.Tensor, t1: torch.Tensor,
    R2: torch.Tensor, t2: torch.Tensor
) -> tuple:
    """
    Chain two rigid transforms so that T2 is applied first, then T1.

    Parameters
    ----------
    R1, R2 : torch.Tensor
        Rotation matrices with trailing shape (3, 3).
    t1, t2 : torch.Tensor
        Translation vectors with trailing shape (3,).

    Returns
    -------
    tuple
        ``(R1 @ R2, R1 @ t2 + t1)``.
    """
    rotation = torch.matmul(R1, R2)
    # Treat t2 as a column vector for the product, then drop that axis again.
    carried = torch.matmul(R1, t2[..., None])[..., 0]
    return rotation, carried + t1
def _parallel_scan(
    self, R_deltas: torch.Tensor, t_deltas: torch.Tensor
) -> tuple:
    """
    Prefix-compose a sequence of rigid transforms in doubling strides.

    Entry 0 of the result is the identity and entry ``i`` is the product
    ``T_0 * T_1 * ... * T_{i-1}`` of the first ``i`` inputs, combined
    with ``_compose_transforms``. The stride doubles on every pass, so
    the loop body runs roughly ``log2(N + 1)`` times.

    Parameters
    ----------
    R_deltas : torch.Tensor
        Per-step rotations, shape (N, 3, 3)
    t_deltas : torch.Tensor
        Per-step translations, shape (N, 3)

    Returns
    -------
    R_cum : torch.Tensor
        Cumulative rotations, shape (N+1, 3, 3)
    t_cum : torch.Tensor
        Cumulative translations, shape (N+1, 3)
    """
    n_steps = R_deltas.shape[0]
    device, dtype = R_deltas.device, R_deltas.dtype
    identity = torch.eye(3, dtype=dtype, device=device)

    if n_steps == 0:
        return identity.unsqueeze(0), torch.zeros(1, 3, dtype=dtype, device=device)

    total = n_steps + 1  # slot 0 carries the identity transform

    # Two buffer pairs that trade roles every pass (read one, write the other).
    cur_R = torch.zeros(total, 3, 3, dtype=dtype, device=device)
    cur_t = torch.zeros(total, 3, dtype=dtype, device=device)
    nxt_R = torch.zeros(total, 3, 3, dtype=dtype, device=device)
    nxt_t = torch.zeros(total, 3, dtype=dtype, device=device)

    cur_R[0] = identity
    cur_R[1:] = R_deltas
    cur_t[1:] = t_deltas

    stride = 1
    while stride < total:
        # Slots below the stride are already final for this pass.
        nxt_R[:stride] = cur_R[:stride]
        nxt_t[:stride] = cur_t[:stride]

        # NOTE(review): the operands are gathered with index tensors (copies)
        # rather than slices; slice views would presumably clash with
        # autograd once the source buffer is overwritten in a later pass -
        # confirm before changing.
        right = torch.arange(stride, total, device=device)
        left = right - stride
        merged_R, merged_t = self._compose_transforms(
            cur_R[left], cur_t[left], cur_R[right], cur_t[right]
        )
        nxt_R[right] = merged_R
        nxt_t[right] = merged_t

        # The freshly written pair becomes the source of the next pass.
        cur_R, nxt_R = nxt_R, cur_R
        cur_t, nxt_t = nxt_t, cur_t
        stride *= 2

    return cur_R, cur_t
def _build_backbone_data(self):
    """
    Pre-compute backbone and side chain index tables for ``forward_parallel``.

    The backbone is the longest root-to-leaf path reported by
    ``_find_longest_path``. Every other atom receives a "side chain
    depth" (>= 1) that orders placement so parents always come first:

    - atoms hanging off the backbone are numbered by their distance from
      the backbone atom they branch from;
    - atoms of chains that do not contain the backbone are numbered from
      their own chain root (root = 1). Without this second seeding those
      atoms would never enter the per-depth placement tables and would
      be left unplaced by the parallel forward pass.

    Results are stored on the module:

    - buffers ``_backbone`` and ``_side_chain_depth``;
    - attributes ``_max_side_chain_depth`` and ``_backbone_set``;
    - when the backbone has more than three atoms, the parameter slots of
      backbone atoms 3.. as lists ``_backbone_{bond,angle,torsion}_idx``
      and buffers ``_backbone_{bond,angle,torsion}_idx_tensor``;
    - per-depth lists ``_sc_depth_*`` (entry ``k`` holds side chain depth
      ``k + 1``, restricted to atoms with tree depth >= 3);
    - buffers ``_all_sc_*`` concatenating the non-empty per-depth entries.
    """
    from collections import deque

    n_atoms = self.n_atoms
    device = self._device

    backbone = self._find_longest_path()
    backbone_set = set(backbone)

    # One host transfer of the parent table instead of a tensor read per atom.
    parents = self.parent_indices.tolist()
    children = [[] for _ in range(n_atoms)]
    for atom in range(n_atoms):
        parent = parents[atom]
        if parent >= 0:
            children[parent].append(atom)

    # Side chain depth: 0 on the backbone, -1 until an atom is reached.
    depth_of = [-1] * n_atoms
    for atom in backbone:
        depth_of[atom] = 0

    # Breadth-first seeds: branches leaving the backbone ...
    seeds = [
        (child, 1)
        for bb_atom in backbone
        for child in children[bb_atom]
        if child not in backbone_set
    ]
    # ... and the roots of all chains that do not contain the backbone.
    seeds.extend(
        (atom, 1)
        for atom in range(n_atoms)
        if parents[atom] < 0 and atom not in backbone_set
    )

    max_sc_depth = 0
    queue = deque(seeds)
    while queue:
        atom, depth = queue.popleft()
        depth_of[atom] = depth
        if depth > max_sc_depth:
            max_sc_depth = depth
        queue.extend((child, depth + 1) for child in children[atom])

    side_chain_depth = torch.tensor(depth_of, dtype=torch.long, device=device)

    # Store as buffers
    self.register_buffer(
        "_backbone",
        torch.tensor(backbone, dtype=torch.long, device=device)
    )
    self.register_buffer("_side_chain_depth", side_chain_depth)
    self._max_side_chain_depth = max_sc_depth
    self._backbone_set = backbone_set

    # Parameter slots of the backbone atoms placed by the parallel scan
    # (everything after the first three atoms).
    if len(backbone) > 3:
        scanned = backbone[3:]
        bond_slots = self.bond_param_indices.tolist()
        angle_slots = self.angle_param_indices.tolist()
        torsion_slots = self.torsion_param_indices.tolist()
        self._backbone_bond_idx = [bond_slots[a] for a in scanned]
        self._backbone_angle_idx = [angle_slots[a] for a in scanned]
        self._backbone_torsion_idx = [torsion_slots[a] for a in scanned]
        self.register_buffer(
            "_backbone_bond_idx_tensor",
            torch.tensor(self._backbone_bond_idx, dtype=torch.long, device=device)
        )
        self.register_buffer(
            "_backbone_angle_idx_tensor",
            torch.tensor(self._backbone_angle_idx, dtype=torch.long, device=device)
        )
        self.register_buffer(
            "_backbone_torsion_idx_tensor",
            torch.tensor(self._backbone_torsion_idx, dtype=torch.long, device=device)
        )

    # Per-depth lookup tables. Atoms with tree depth < 3 are excluded
    # because the second/third-atom helpers place them.
    lookups = (
        ("_sc_depth_parent_indices", self.parent_indices),
        ("_sc_depth_gp_indices", self.grandparent_indices),
        ("_sc_depth_ggp_indices", self.great_grandparent_indices),
        ("_sc_depth_bond_idx", self.bond_param_indices),
        ("_sc_depth_angle_idx", self.angle_param_indices),
        ("_sc_depth_torsion_idx", self.torsion_param_indices),
    )
    self._sc_depth_atom_indices = []
    for attr, _ in lookups:
        setattr(self, attr, [])

    needs_nerf = self.atom_depth >= 3
    for depth in range(1, max_sc_depth + 1):
        # An empty level simply yields empty index tensors.
        atom_idx = torch.where((side_chain_depth == depth) & needs_nerf)[0]
        self._sc_depth_atom_indices.append(atom_idx)
        for attr, table in lookups:
            getattr(self, attr).append(table[atom_idx])

    # Merged tables: concatenation of all non-empty levels, in depth order.
    occupied = [
        level
        for level, atom_idx in enumerate(self._sc_depth_atom_indices)
        if atom_idx.numel() > 0
    ]
    empty = torch.tensor([], dtype=torch.long, device=device)
    merged = (
        ("_all_sc_atoms", self._sc_depth_atom_indices),
        ("_all_sc_parents", self._sc_depth_parent_indices),
        ("_all_sc_gp", self._sc_depth_gp_indices),
        ("_all_sc_ggp", self._sc_depth_ggp_indices),
        ("_all_sc_bond_idx", self._sc_depth_bond_idx),
        ("_all_sc_angle_idx", self._sc_depth_angle_idx),
        ("_all_sc_torsion_idx", self._sc_depth_torsion_idx),
    )
    for name, per_depth in merged:
        parts = [per_depth[level] for level in occupied]
        self.register_buffer(name, torch.cat(parts) if parts else empty)

    depth_parts = [
        torch.full_like(self._sc_depth_atom_indices[level], level + 1)
        for level in occupied
    ]
    self.register_buffer(
        "_all_sc_depths", torch.cat(depth_parts) if depth_parts else empty
    )
def _place_atoms_nerf(
    self, xyz: torch.Tensor, atom_idx: torch.Tensor,
    parent_idx: torch.Tensor, gp_idx: torch.Tensor, ggp_idx: torch.Tensor,
    bond_idx: torch.Tensor, angle_idx: torch.Tensor, torsion_idx: torch.Tensor
) -> torch.Tensor:
    """
    Place a batch of atoms from bond length, bond angle and torsion (NeRF).

    Parameters
    ----------
    xyz : torch.Tensor
        Current coordinates. Modified in place; the three reference atoms
        of every target must already be positioned.
    atom_idx, parent_idx, gp_idx, ggp_idx : torch.Tensor
        Indices of the atoms to place and of their parent, grandparent
        and great-grandparent.
    bond_idx, angle_idx, torsion_idx : torch.Tensor
        Slots into ``bond_lengths``, ``angles`` and ``torsions``.

    Returns
    -------
    torch.Tensor
        The same ``xyz`` tensor with the target rows overwritten.
    """
    great_grandparent = xyz[ggp_idx]
    grandparent = xyz[gp_idx]
    parent = xyz[parent_idx]

    length = self.bond_lengths[bond_idx]
    bend = self.angles[angle_idx]
    twist = self.torsions[torsion_idx]

    # Orthonormal frame at the parent; clamps guard zero-length vectors.
    forward_axis = parent - grandparent
    forward_axis = forward_axis / torch.linalg.norm(
        forward_axis, dim=-1, keepdim=True
    ).clamp(min=1e-10)
    plane_normal = torch.linalg.cross(grandparent - great_grandparent, forward_axis)
    plane_normal = plane_normal / torch.linalg.norm(
        plane_normal, dim=-1, keepdim=True
    ).clamp(min=1e-10)
    side_axis = torch.linalg.cross(plane_normal, forward_axis)

    # The stored bond angle is converted to its supplement before use.
    supplement = torch.pi - bend
    radial = length * torch.sin(supplement)
    along = (length * torch.cos(supplement)).unsqueeze(-1)
    sideways = (radial * torch.cos(twist)).unsqueeze(-1)
    upward = (radial * torch.sin(twist)).unsqueeze(-1)

    # The normal component enters with a minus sign.
    xyz[atom_idx] = (
        parent
        + along * forward_axis
        + sideways * side_axis
        - upward * plane_normal
    )
    return xyz
def _place_side_chains_batched(self, xyz: torch.Tensor) -> torch.Tensor:
"""
Place side chain atoms using batched depth processing.
Optimizes by using pre-computed indices and a helper function.
Parameters
----------
xyz : torch.Tensor
Current coordinates with backbone already placed.
Returns
-------
torch.Tensor
Updated coordinates with side chains placed.
"""
# Use the original per-depth lists which are already optimally pre-computed
for sc_depth_idx in range(self._max_side_chain_depth):
atom_idx = self._sc_depth_atom_indices[sc_depth_idx]
if atom_idx.numel() == 0:
continue
xyz = self._place_atoms_nerf(
xyz, atom_idx,
self._sc_depth_parent_indices[sc_depth_idx],
self._sc_depth_gp_indices[sc_depth_idx],
self._sc_depth_ggp_indices[sc_depth_idx],
self._sc_depth_bond_idx[sc_depth_idx],
self._sc_depth_angle_idx[sc_depth_idx],
self._sc_depth_torsion_idx[sc_depth_idx]
)
return xyz
[docs]
def forward_parallel(self) -> torch.Tensor:
    """
    Reconstruct Cartesian xyz, placing the backbone with a parallel scan.

    Steps:

    1. Build the backbone tables on first use (``_build_backbone_data``).
    2. Place chain roots and the second/third atoms of every chain.
    3. If the backbone has more than three atoms, compute all remaining
       backbone positions at once from the prefix-composed transforms
       (``_compute_delta_transform`` + ``_parallel_scan``), expressed in
       the frame spanned by the first three backbone atoms.
    4. Place side chain atoms level by level.
    5. Apply rigid rings, then overwrite frozen atoms with ``fixed_xyz``.

    Returns
    -------
    torch.Tensor
        Reconstructed Cartesian coordinates of shape (N, 3), passed
        through ``_to_output``.
    """
    if not hasattr(self, '_backbone'):
        self._build_backbone_data()

    coords = torch.zeros(
        self.n_atoms, 3, dtype=self._dtype, device=self._device
    )
    chain = self._backbone
    chain_len = len(chain)

    # Tree depths 0, 1 and 2 (backbone and side chains alike).
    coords[self.chain_roots] = self.chain_positions
    coords = self._place_second_atoms(coords)
    coords = self._place_third_atoms(coords)

    if chain_len > 3:
        # Internal coordinates of backbone atoms 3, 4, ...
        lengths = self.bond_lengths[self._backbone_bond_idx_tensor]
        bends = self.angles[self._backbone_angle_idx_tensor]
        twists = self.torsions[self._backbone_torsion_idx_tensor]

        step_R, step_t = self._compute_delta_transform(lengths, bends, twists)
        _, prefix_t = self._parallel_scan(step_R, step_t)

        # Starting frame (m, n, bc) located at the third backbone atom.
        first = coords[chain[0]]
        second = coords[chain[1]]
        third = coords[chain[2]]
        bc_axis = (third - second) / torch.linalg.norm(third - second).clamp(min=1e-10)
        n_axis = torch.linalg.cross(second - first, bc_axis)
        n_axis = n_axis / torch.linalg.norm(n_axis).clamp(min=1e-10)
        m_axis = torch.linalg.cross(n_axis, bc_axis)
        start_frame = torch.stack([m_axis, n_axis, bc_axis], dim=-1)

        # Entry k + 1 of the scan is the local position of backbone atom 3 + k.
        local = prefix_t[1:chain_len - 2]  # (chain_len - 3, 3)
        coords[chain[3:]] = (start_frame @ local.T).T + third

    # Remaining atoms with tree depth >= 3 that are not on the backbone.
    coords = self._place_side_chains_batched(coords)

    coords = self._place_rigid_rings(coords)

    # Frozen (non-refinable) atoms always report their stored position.
    frozen = ~self.refinable_mask
    if frozen.any():
        coords[frozen] = self.fixed_xyz[frozen]

    return self._to_output(coords)