torchref.model.internal_coordinates module

Internal coordinate parametrization for atomic structures.

This module provides the InternalCoordinateTensor class which parametrizes atomic XYZ coordinates using internal coordinates (bond lengths, angles, torsions) instead of Cartesian coordinates. This enables physically meaningful perturbations and differentiable reconstruction.

Key features:

  • Bond detection: atoms within bond_cutoff (2 Å by default) of each other are considered bonded (computed with torch.cdist).

  • Internal coordinate parametrization: a chain of N atoms → N-1 bond lengths, N-2 angles, N-3 torsions.

  • Chain handling: unconnected chains are treated as rigid groups, each with its own position and orientation.

  • Fully differentiable: complete gradient flow from internal parameters to Cartesian coordinates.

  • Fully vectorized: no Python loops over atoms; all operations are tensor ops.

  • Ring handling: rings are treated as rigid entities, with only the anchor movable.
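The bond-detection and chain-splitting features listed above can be sketched in plain Python. This is a minimal sketch, not the module's code: it uses an O(N²) pair loop where the class uses torch.cdist, it treats a pair as bonded when the distance is strictly below the cutoff (an assumption), and the union-find grouping and the helper names detect_bonds, find_chains and internal_counts are hypothetical.

```python
import math
from itertools import combinations


def detect_bonds(xyz, cutoff=2.0):
    """Return (i, j) pairs, i < j, whose distance is below `cutoff` (Angstroms)."""
    # A vectorized version would threshold a torch.cdist distance matrix instead.
    # Strict "<" versus "<=" is an assumption of this sketch.
    return [
        (i, j)
        for i, j in combinations(range(len(xyz)), 2)
        if math.dist(xyz[i], xyz[j]) < cutoff
    ]


def find_chains(n_atoms, bonds):
    """Group atoms into connected components ("chains") with union-find."""
    parent = list(range(n_atoms))

    def root(a):
        while parent[a] != a:
            parent[a] = parent[parent[a]]  # path halving keeps trees shallow
            a = parent[a]
        return a

    for i, j in bonds:
        parent[root(i)] = root(j)

    groups = {}
    for a in range(n_atoms):
        groups.setdefault(root(a), []).append(a)
    return sorted(groups.values())


def internal_counts(chain_size):
    """(bonds, angles, torsions) for a chain of `chain_size` atoms: N-1, N-2, N-3."""
    return max(chain_size - 1, 0), max(chain_size - 2, 0), max(chain_size - 3, 0)


# Example: a 4-atom zigzag plus a distant 2-atom fragment.
xyz = [(0.0, 0.0, 0.0), (1.5, 0.0, 0.0), (2.0, 1.4, 0.0), (3.5, 1.4, 0.0),
       (10.0, 10.0, 10.0), (11.0, 10.0, 10.0)]
bonds = detect_bonds(xyz)
chains = find_chains(len(xyz), bonds)
counts = [internal_counts(len(c)) for c in chains]
```

The bookkeeping adds up: for an acyclic chain with N ≥ 3 atoms, (N-1) + (N-2) + (N-3) = 3N - 6 internal coordinates, and the 3 position plus 3 axis-angle orientation parameters per chain supply the remaining 6 of the 3N Cartesian degrees of freedom.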

class torchref.model.internal_coordinates.InternalCoordinateTensor(initial_xyz, bond_cutoff=2.0, requires_grad=True, dtype=None, device=None)

Bases: DeviceMixin, Module

Parameter wrapper using internal coordinates (Z-matrix style).

Stores: bond_lengths, angles, torsions, chain_positions, chain_orientations.

Reconstructs: Cartesian xyz on forward().

This provides a physically meaningful parametrization of atomic coordinates where perturbations correspond to changes in bond lengths, angles, and torsion angles rather than arbitrary Cartesian displacements.
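Reconstruction places each atom from three already-placed reference atoms. A common way to do this is the NeRF (natural extension reference frame) construction, sketched below in plain Python. Whether InternalCoordinateTensor uses exactly this local frame and torsion sign convention is an assumption, and place_atom and the vector helpers are hypothetical names.

```python
import math
from typing import Tuple

Vec = Tuple[float, float, float]


def sub(a: Vec, b: Vec) -> Vec:
    return (a[0] - b[0], a[1] - b[1], a[2] - b[2])


def dot(a: Vec, b: Vec) -> float:
    return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]


def cross(a: Vec, b: Vec) -> Vec:
    return (a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0])


def unit(a: Vec) -> Vec:
    n = math.sqrt(dot(a, a))
    return (a[0] / n, a[1] / n, a[2] / n)


def place_atom(a: Vec, b: Vec, c: Vec, bond: float, angle: float, torsion: float) -> Vec:
    """Place D bonded to C so that |CD| = bond, angle B-C-D = angle and
    dihedral A-B-C-D = torsion (radians). Undefined if A, B, C are collinear."""
    bc = unit(sub(c, b))
    n = unit(cross(sub(b, a), bc))  # normal to the A-B-C plane
    m = cross(n, bc)                # completes a right-handed frame (bc, m, n)
    # Spherical coordinates of D expressed in that local frame.
    dx = -bond * math.cos(angle)
    dy = bond * math.sin(angle) * math.cos(torsion)
    dz = bond * math.sin(angle) * math.sin(torsion)
    return tuple(c[k] + dx * bc[k] + dy * m[k] + dz * n[k] for k in range(3))


A, B, C = (0.0, 1.0, 0.0), (0.0, 0.0, 0.0), (1.5, 0.0, 0.0)
D = place_atom(A, B, C, bond=1.5, angle=math.radians(109.5), torsion=math.radians(60.0))
```

Because every atom is a smooth function of its bond length, angle and torsion, a perturbation of one torsion rotates everything downstream of that bond rather than displacing a single atom.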

Parameters:
  • initial_xyz (torch.Tensor) – Initial Cartesian coordinates of shape (N, 3).

  • bond_cutoff (float, optional) – Distance cutoff for bond detection in Angstroms. Default is 2.0.

  • requires_grad (bool, optional) – Whether parameters should have gradients. Default is True.

  • dtype (torch.dtype, optional) – Data type for tensors. Default is same as initial_xyz.

  • device (torch.device, optional) – Device for tensors. Default is same as initial_xyz.

n_atoms

Number of atoms.

Type:

int

n_chains

Number of disconnected chains.

Type:

int

max_depth

Maximum depth in the spanning tree.

Type:

int

bond_lengths

Bond length parameters in Angstroms.

Type:

nn.Parameter

angles

Angle parameters in radians.

Type:

nn.Parameter

torsions

Torsion angle parameters in radians.

Type:

nn.Parameter

chain_positions

Absolute positions of chain root atoms.

Type:

nn.Parameter

chain_orientations

Axis-angle orientations for each chain.

Type:

nn.Parameter
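chain_orientations stores one axis-angle vector per chain. A minimal sketch of turning such a vector into a rotation matrix with Rodrigues' formula and applying a chain's rigid motion follows. Rotating about the chain's root atom and then moving that root to its chain_positions entry is an assumed convention, and axis_angle_to_matrix and place_chain are hypothetical names.

```python
import math


def axis_angle_to_matrix(v):
    """Rodrigues' formula: rotation by |v| radians about the unit axis v / |v|."""
    theta = math.sqrt(sum(x * x for x in v))
    if theta < 1e-12:  # zero vector means "no rotation"
        return [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
    kx, ky, kz = (x / theta for x in v)
    c, s = math.cos(theta), math.sin(theta)
    t = 1.0 - c
    return [
        [c + kx * kx * t, kx * ky * t - kz * s, kx * kz * t + ky * s],
        [ky * kx * t + kz * s, c + ky * ky * t, ky * kz * t - kx * s],
        [kz * kx * t - ky * s, kz * ky * t + kx * s, c + kz * kz * t],
    ]


def place_chain(points, orientation, position):
    """Rotate a chain about its first (root) atom, then move the root to `position`."""
    R = axis_angle_to_matrix(orientation)
    root = points[0]
    placed = []
    for p in points:
        d = [p[k] - root[k] for k in range(3)]
        placed.append(tuple(position[i] + sum(R[i][k] * d[k] for k in range(3))
                            for i in range(3)))
    return placed


# Quarter turn about z, then translate the root to (5, 5, 5).
chain = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)]
moved = place_chain(chain, orientation=(0.0, 0.0, math.pi / 2), position=(5.0, 5.0, 5.0))
```

An axis-angle vector has exactly 3 unconstrained components and no normalization constraint, which makes it convenient as a gradient-descent parameter compared with a raw 3×3 matrix.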

__init__(initial_xyz, bond_cutoff=2.0, requires_grad=True, dtype=None, device=None)

Initialize InternalCoordinateTensor from Cartesian coordinates.

Parameters:
  • initial_xyz (torch.Tensor) – Initial Cartesian coordinates of shape (N, 3).

  • bond_cutoff (float, optional) – Distance cutoff for bond detection in Angstroms. Default is 2.0.

  • requires_grad (bool, optional) – Whether parameters should have gradients. Default is True.

  • dtype (torch.dtype, optional) – Data type for tensors. Default is same as initial_xyz.

  • device (torch.device, optional) – Device for tensors. Default is same as initial_xyz.

property dtype

Return the dtype of tensors.

property device

Logical device: the device to which forward()’s result is delivered.

Internal parameters and buffers stay on the CPU regardless; this is the device requested by the caller (e.g. via .to('mps')), and the forward output is migrated to it.

to(*args, **kwargs)

Update output device and optionally cast dtype.

Unlike DeviceMixin.to, this does not move internal parameters/buffers to the device; they stay on the CPU to avoid the per-op dispatch overhead of MPS/CUDA on the sequential spanning-tree and parallel-scan code. The device argument only updates _output_device; dtype still propagates normally and recasts all CPU tensors.

cuda(device=None)

Move all model parameters and buffers to the GPU.

This also makes associated parameters and buffers different objects. So it should be called before constructing the optimizer if the module will live on GPU while being optimized.

Note

This method modifies the module in-place.

Args:
device (int, optional): if specified, all parameters will be copied to that device.

Returns:

Module: self

cpu()

Move all model parameters and buffers to the CPU.

Note

This method modifies the module in-place.

Returns:

Module: self

forward_slow()

Reconstruct Cartesian xyz from internal coordinates.

Fully vectorized: all atoms at the same depth of the spanning tree are placed in parallel, so the number of sequential steps grows with max_depth (one step per depth level).
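Processing each depth level in parallel relies on grouping atoms by their depth in the spanning tree. A minimal sketch, assuming the tree is stored as a parent-index array with -1 marking chain roots (a hypothetical layout; the class's internal buffers are not documented here), and using a hypothetical depth_levels helper:

```python
from collections import defaultdict


def depth_levels(parent):
    """Group atoms of a spanning forest by depth; roots (parent == -1) have depth 0.

    Assumes every parent index precedes its children, so one pass suffices.
    """
    depth = [0] * len(parent)
    for i, p in enumerate(parent):
        if p >= 0:
            depth[i] = depth[p] + 1
    levels = defaultdict(list)
    for i, d in enumerate(depth):
        levels[d].append(i)
    return [levels[d] for d in range(max(depth) + 1)] if parent else []


# Backbone 0-1-2-3 with side chain 1-4-5, plus a second chain 6-7.
parent = [-1, 0, 1, 2, 1, 4, -1, 6]
levels = depth_levels(parent)
max_depth = len(levels) - 1
```

Atoms inside one level depend only on atoms in shallower levels, so a tensor implementation can gather their reference atoms by index and place the whole level with one batched operation; in this naive scheme the levels themselves are still visited one after another.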

Returns:

Reconstructed Cartesian coordinates of shape (N, 3).

Return type:

torch.Tensor

forward()

Reconstruct Cartesian xyz from internal coordinates.

Uses the optimized parallel-scan method for efficiency.

Returns:

Reconstructed Cartesian coordinates of shape (N, 3), on the configured output device.

Return type:

torch.Tensor

shake(magnitude=0.1)

Add Gaussian noise to internal parameters (fully vectorized).

All operations are batched tensor ops, with no loops.

Parameters:

magnitude (float, optional) – Standard deviation of Gaussian noise. Default is 0.1. For bond lengths, this is in Angstroms. For angles and torsions, this is in radians.

Returns:

New Cartesian coordinates after perturbation.

Return type:

torch.Tensor
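A minimal sketch of what shake() does to the parameters, using plain Python lists and the standard random module in place of tensors. The function name shake_params and its seed argument are hypothetical; shake() itself takes only magnitude and returns reconstructed Cartesian coordinates, whereas this sketch returns perturbed parameter copies.

```python
import math
import random


def shake_params(bond_lengths, angles, torsions, magnitude=0.1, seed=None):
    """Return perturbed copies of the three internal-coordinate lists.

    One `magnitude` is used for all of them, so it means Angstroms for bonds and
    radians for angles/torsions. (In PyTorch the same step is one batched op per
    tensor, e.g. p + magnitude * torch.randn_like(p).)
    """
    rng = random.Random(seed)

    def jitter(values):
        return [v + rng.gauss(0.0, magnitude) for v in values]

    return jitter(bond_lengths), jitter(angles), jitter(torsions)


# A 3-atom chain: 2 bonds, 1 angle, 0 torsions.
bonds, angles, torsions = [1.53, 1.33], [1.91], []
shaken = shake_params(bonds, angles, torsions, magnitude=0.1, seed=0)

# Scale check: the default 0.1 rad is about 5.7 degrees of angular noise.
print(f"{math.degrees(0.1):.2f} deg")  # prints "5.73 deg"
```

Because a single magnitude is shared, the default gives comparable "sizes" of noise only in the sense that 0.1 Å and about 5.7° are both modest; a torsion change also moves every downstream atom, so its Cartesian effect grows with distance from the rotated bond.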

fix(selection=None, freeze_at_current=True)

Fix (freeze) atoms to use fixed xyz coordinates instead of internal coordinates.

Fixed atoms will not be updated during reconstruction from internal coordinates. Their positions will remain at the stored fixed_xyz values.

Parameters:
  • selection (torch.Tensor, slice, or None) – Boolean mask of shape (n_atoms,), index tensor, or slice selecting the atoms to fix. If None, all atoms are fixed.

  • freeze_at_current (bool, optional) – If True (default), store current reconstructed xyz for the selected atoms. If False, use the existing fixed_xyz values.
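fix() accepts a boolean mask, indices, a slice or None, and fixed atoms bypass reconstruction. A sketch of both steps on plain Python lists follows; selection_to_mask and merge_fixed are hypothetical helpers, and whether the class combines coordinates with torch.where is an assumption. The mask sum also gives n_fixed and n_refinable.

```python
def selection_to_mask(selection, n_atoms):
    """Normalize None / boolean mask / index list / slice to a list of booleans.

    Index lists are assumed to contain non-negative indices.
    """
    if selection is None:
        return [True] * n_atoms
    if isinstance(selection, slice):
        chosen = set(range(*selection.indices(n_atoms)))
        return [i in chosen for i in range(n_atoms)]
    selection = list(selection)
    if len(selection) == n_atoms and all(isinstance(s, bool) for s in selection):
        return selection
    chosen = set(selection)
    return [i in chosen for i in range(n_atoms)]


def merge_fixed(reconstructed, fixed_xyz, fixed_mask):
    """Fixed atoms keep their stored xyz; refinable atoms take the reconstruction.

    Tensor equivalent: torch.where(fixed_mask[:, None], fixed_xyz, reconstructed).
    """
    return [f if m else r for r, f, m in zip(reconstructed, fixed_xyz, fixed_mask)]


fixed_mask = selection_to_mask([1], 3)
xyz = merge_fixed([(0.0, 0.0, 0.0), (1.0, 1.0, 1.0), (2.0, 2.0, 2.0)],
                  [(9.0, 9.0, 9.0)] * 3, fixed_mask)
n_fixed = sum(fixed_mask)
n_refinable = len(fixed_mask) - n_fixed
```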

freeze(selection=None, freeze_at_current=True)

Alias for fix(). Freeze atoms to use fixed xyz coordinates.

See fix() for full documentation.

refine(selection=None, rebuild=True)

Make atoms refinable by computing their positions from internal coordinates.

This unfreezes atoms, meaning their positions will be computed from bond lengths, angles, and torsions during the forward pass.

Parameters:
  • selection (torch.Tensor, slice, or None) – Boolean mask of shape (n_atoms,), index tensor, or slice selecting the atoms to make refinable. If None, all atoms are made refinable.

  • rebuild (bool, optional) – If True (default), rebuild internal coordinates from current fixed_xyz for the selected atoms. This ensures the internal coordinates match the current atom positions before unfreezing.
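With rebuild=True, refine() re-derives internal coordinates from fixed_xyz before unfreezing. The three measurements involved are standard geometry, sketched below in plain Python with an atan2-based dihedral. Which atoms act as the references A, B, C for a given atom comes from the class's spanning tree, which is not reproduced here, and the function names are hypothetical.

```python
import math


def _sub(a, b):
    return tuple(x - y for x, y in zip(a, b))


def _dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def _cross(a, b):
    return (a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0])


def bond_length(c, d):
    return math.dist(c, d)


def bond_angle(b, c, d):
    """Angle B-C-D in radians (vertex at C)."""
    u, v = _sub(b, c), _sub(d, c)
    cos = _dot(u, v) / math.sqrt(_dot(u, u) * _dot(v, v))
    return math.acos(max(-1.0, min(1.0, cos)))  # clamp guards against rounding


def torsion(a, b, c, d):
    """Dihedral A-B-C-D in radians, in [-pi, pi]."""
    b0, b2 = _sub(a, b), _sub(d, c)
    axis = _sub(c, b)
    n = math.sqrt(_dot(axis, axis))
    b1 = tuple(x / n for x in axis)
    # Project the outer bonds onto the plane perpendicular to the B-C axis.
    v = _sub(b0, tuple(_dot(b0, b1) * x for x in b1))
    w = _sub(b2, tuple(_dot(b2, b1) * x for x in b1))
    return math.atan2(_dot(_cross(b1, v), w), _dot(v, w))


A, B, C, D = (0.0, 1.0, 0.0), (0.0, 0.0, 0.0), (1.5, 0.0, 0.0), (2.0, 0.0, 1.0)
params = (bond_length(C, D), bond_angle(B, C, D), torsion(A, B, C, D))
```

Using atan2 rather than acos for the torsion keeps its sign (and therefore its handedness) and stays well conditioned near 0 and ±π.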

unfreeze(selection=None, rebuild=True)

Alias for refine(). Unfreeze atoms to use internal coordinates.

See refine() for full documentation.

fix_all(freeze_at_current=True)

Fix (freeze) all atoms.

Parameters:

freeze_at_current (bool, optional) – If True (default), store current reconstructed xyz for all atoms.

freeze_all(freeze_at_current=True)

Alias for fix_all(). Freeze all atoms.

refine_all(rebuild=True)

Make all atoms refinable.

Parameters:

rebuild (bool, optional) – If True (default), rebuild internal coordinates from current fixed_xyz.

unfreeze_all(rebuild=True)

Alias for refine_all(). Unfreeze all atoms.

property n_refinable: int

Return the number of refinable (unfrozen) atoms.

property n_fixed: int

Return the number of fixed (frozen) atoms.

forward_parallel()

Reconstruct Cartesian xyz using parallel scan for backbone.

This is an optimized forward pass that:

  1. Places backbone atoms using a parallel prefix scan (O(log N) steps).

  2. Places side-chain atoms using depth iterations (O(max_sc_depth) steps).

For deep trees where the backbone is long but side chains are short, this can be significantly faster than the level-by-level forward_slow().
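The backbone step relies on the fact that composing rigid transforms is associative, so the global frame of every backbone atom is a prefix "sum" of local transforms. A minimal sketch with a Hillis-Steele inclusive scan follows; to stay short it uses 2D rigid motions (rotation angle plus translation) instead of the 3D frames a real backbone needs, and compose and inclusive_scan are hypothetical names, not the class's API.

```python
import math


def compose(outer, inner):
    """2D rigid motion `outer` applied after `inner`; each is (theta, tx, ty)."""
    to, ox, oy = outer
    ti, ix, iy = inner
    c, s = math.cos(to), math.sin(to)
    return (to + ti, ox + c * ix - s * iy, oy + s * ix + c * iy)


def inclusive_scan(items, op):
    """Hillis-Steele scan: result[i] = items[0] op items[1] op ... op items[i].

    Runs ceil(log2 n) rounds; the updates inside one round are independent,
    which is what a batched tensor implementation exploits.
    """
    out, offset, rounds = list(items), 1, 0
    while offset < len(out):
        out = [out[i] if i < offset else op(out[i - offset], out[i])
               for i in range(len(out))]
        offset *= 2
        rounds += 1
    return out, rounds


# Six identical local frames: each rotated 0.3 rad and offset 1.5 along its parent's x axis.
local = [(0.3, 1.5, 0.0)] * 6
frames, rounds = inclusive_scan(local, compose)
positions = [(x, y) for _, x, y in frames]  # frame origins = backbone atom positions
```

Hillis-Steele does O(N log N) total work in exchange for the short round count; a Blelloch-style work-efficient scan needs only O(N) work but roughly twice as many rounds.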

Returns:

Reconstructed Cartesian coordinates of shape (N, 3).

Return type:

torch.Tensor