torchref.base.chain_closure package

Chain closure submodule for differentiable analytical junction closure.

Provides backbone identification, junction planning, NeRF (Natural Extension Reference Frame) forward kinematics, Newton-based closure solving, and gradient computation via the Implicit Function Theorem (IFT).

torchref.base.chain_closure.identify_backbone_atoms(pdb)

Map (chainid, resseq) to backbone atom indices {N: idx, CA: idx, C: idx}.

Parameters:

pdb (pd.DataFrame) – PDB DataFrame with columns ‘chainid’, ‘resseq’, ‘name’, ‘index’, ‘resname’.

Returns:

Mapping from (chainid, resseq) to dict of atom name -> atom index for backbone atoms N, CA, C. Only residues with all three atoms present are included.

Return type:

dict
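A minimal sketch of the grouping rule. It assumes the DataFrame has already been turned into row dicts (for example with pdb.to_dict("records")). `map_backbone` is a hypothetical name, and altloc or duplicate-atom handling is not modeled.

```python
from collections import defaultdict

BACKBONE = ("N", "CA", "C")

def map_backbone(rows):
    """Hypothetical sketch: rows are dicts with 'chainid', 'resseq', 'name', 'index'."""
    found = defaultdict(dict)
    for row in rows:
        if row["name"] in BACKBONE:
            found[(row["chainid"], row["resseq"])][row["name"]] = row["index"]
    # Keep only residues where N, CA and C were all seen.
    return {key: atoms for key, atoms in found.items() if len(atoms) == len(BACKBONE)}
```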

torchref.base.chain_closure.get_chain_residues(pdb)

Get ordered list of protein residue keys per chain.

Parameters:

pdb (pd.DataFrame) – PDB DataFrame.

Returns:

Mapping from chainid to sorted list of (chainid, resseq) tuples.

Return type:

dict
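A minimal sketch of the per-chain ordering. It assumes the residue keys are already known, for example the keys returned by identify_backbone_atoms(). `group_keys_by_chain` is hypothetical, and it does not reproduce the real function's test for which residues count as protein.

```python
from collections import defaultdict

def group_keys_by_chain(keys):
    """Hypothetical sketch: group (chainid, resseq) keys per chain, sorted by resseq."""
    by_chain = defaultdict(list)
    for chainid, resseq in keys:
        by_chain[chainid].append((chainid, resseq))
    # Within one chain the tuples share chainid, so sorting orders them by resseq.
    return {chainid: sorted(items) for chainid, items in by_chain.items()}
```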

torchref.base.chain_closure.compute_backbone_torsions(xyz, backbone_map, chain_residues)

Compute phi, psi, omega torsion angles for each residue.

Parameters:
  • xyz (torch.Tensor) – Atomic coordinates of shape (N, 3).

  • backbone_map (dict) – From identify_backbone_atoms().

  • chain_residues (dict) – From get_chain_residues().

Returns:

Mapping from (chainid, resseq) to {‘phi’: float, ‘psi’: float, ‘omega’: float}. Values are in radians. Missing angles are set to NaN.

Return type:

dict
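Each torsion is a signed dihedral over four backbone atoms. The following is a NumPy sketch of the standard formula; the library itself works on torch tensors. The phi/psi definitions in the comments are the usual ones. The omega indexing is an assumption, because conventions differ on whether omega for residue i spans residues i-1/i or i/i+1.

```python
import numpy as np

def dihedral(p0, p1, p2, p3):
    """Signed dihedral angle p0-p1-p2-p3 in radians, in [-pi, pi]."""
    p0, p1, p2, p3 = (np.asarray(p, dtype=float) for p in (p0, p1, p2, p3))
    b0 = p0 - p1
    b1 = p2 - p1
    b2 = p3 - p2
    b1 = b1 / np.linalg.norm(b1)
    # Project the outer bonds onto the plane perpendicular to the central bond.
    v = b0 - np.dot(b0, b1) * b1
    w = b2 - np.dot(b2, b1) * b1
    x = np.dot(v, w)
    y = np.dot(np.cross(b1, v), w)
    return float(np.arctan2(y, x))

# Usual backbone definitions for residue i (omega indexing is an assumption):
#   phi_i   = dihedral(C[i-1], N[i], CA[i], C[i])
#   psi_i   = dihedral(N[i], CA[i], C[i], N[i+1])
#   omega_i = dihedral(CA[i], C[i], N[i+1], CA[i+1])
# At chain termini a neighbor atom is missing, which is where NaN values arise.
```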

torchref.base.chain_closure.estimate_secondary_structure(torsions)

Simple Ramachandran region classification: H (helix), E (sheet), L (loop).

Parameters:

torsions (dict) – From compute_backbone_torsions().

Returns:

Mapping from (chainid, resseq) to ‘H’, ‘E’, or ‘L’.

Return type:

dict
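The following is a sketch of a rectangular Ramachandran classifier. The region bounds are illustrative assumptions and are not the thresholds used by estimate_secondary_structure(). NaN angles, such as those at chain termini, fall through to 'L' because every comparison with NaN is false.

```python
import math

def classify_ss(torsions):
    """Hypothetical sketch: torsions maps key -> {'phi', 'psi', ...} in radians."""
    labels = {}
    for key, ang in torsions.items():
        phi = math.degrees(ang["phi"])
        psi = math.degrees(ang["psi"])
        # Illustrative bounds only (degrees); NaN fails every comparison -> 'L'.
        if -100.0 <= phi <= -30.0 and -80.0 <= psi <= -10.0:
            labels[key] = "H"
        elif -180.0 <= phi <= -45.0 and (psi >= 90.0 or psi <= -150.0):
            labels[key] = "E"
        else:
            labels[key] = "L"
    return labels
```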

torchref.base.chain_closure.plan_junction_placement(chain_residues, backbone_map, n_aa_per_segment=18, junction_size=3, ss=None, prefer_loops=True)

Plan segment and junction placement along protein chains.

Divides each chain into segments of ~n_aa_per_segment residues with junction_size-residue junctions between them. Optionally slides junctions to prefer loop regions.

The algorithm:

1. Determine nominal junction positions at every n_aa_per_segment residues.
2. Optionally slide each junction within ±slide_range to prefer loops.
3. Build segments from the non-junction gaps between junctions.

Parameters:
  • chain_residues (dict) – From get_chain_residues().

  • backbone_map (dict) – From identify_backbone_atoms().

  • n_aa_per_segment (int) – Target number of residues per free-DOF segment.

  • junction_size (int) – Number of residues per junction (slave DOFs).

  • ss (dict, optional) – Secondary structure assignments from estimate_secondary_structure().

  • prefer_loops (bool) – If True and ss is provided, slide junctions to prefer loop regions.

Returns:

  • segments (list of list) – Each inner list contains (chainid, resseq) keys for one segment.

  • junctions (list of list) – Each inner list contains (chainid, resseq) keys for one junction. Junction i connects segment i to segment i+1.

Return type:

Tuple[List[List[Tuple[str, int]]], List[List[Tuple[str, int]]]]
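The three steps can be sketched for a single chain. Several details here are assumptions: the junction stride (n_aa_per_segment free residues followed by junction_size junction residues), the default slide_range=3, and the scoring rule (the window with the most 'L' residues wins, and ties go to the nominal position). `plan_chain` is a hypothetical helper. Unlike the real function, it takes one chain's ordered keys rather than the full chain_residues/backbone_map pair.

```python
def plan_chain(keys, n_aa_per_segment=18, junction_size=3, ss=None,
               prefer_loops=True, slide_range=3):
    """Hypothetical single-chain sketch; keys are ordered (chainid, resseq) tuples."""
    n = len(keys)

    def loop_score(start):
        # Number of loop ('L') residues inside a junction window starting at `start`.
        return sum(ss.get(k) == "L" for k in keys[start:start + junction_size])

    # Step 1: nominal junction starts after every n_aa_per_segment free residues,
    # keeping at least one free residue after each junction.
    nominal, pos = [], n_aa_per_segment
    while pos + junction_size < n:
        nominal.append(pos)
        pos += n_aa_per_segment + junction_size

    # Step 2: optionally slide each junction by up to +-slide_range residues.
    starts, prev_end = [], 0
    for s in nominal:
        best = s
        if ss is not None and prefer_loops:
            candidates = [c for c in range(s - slide_range, s + slide_range + 1)
                          if c > prev_end and c + junction_size < n] or [s]
            # Most loop residues wins; ties go to the position closest to nominal.
            best = max(candidates, key=lambda c: (loop_score(c), -abs(c - s)))
        starts.append(best)
        prev_end = best + junction_size

    # Step 3: segments are the free residues before, between and after junctions.
    junctions = [keys[s:s + junction_size] for s in starts]
    segments, prev_end = [], 0
    for s in starts:
        segments.append(keys[prev_end:s])
        prev_end = s + junction_size
    segments.append(keys[prev_end:])
    return segments, junctions
```

As in the documented return value, junction i sits between segment i and segment i+1. Every residue ends up in exactly one segment or junction.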

torchref.base.chain_closure.get_junction_backbone_indices(junction_residues, backbone_map)

Get ordered backbone atom indices for junction residues.

Parameters:
  • junction_residues (list) – List of (chainid, resseq) tuples for the junction.

  • backbone_map (dict) – From identify_backbone_atoms().

Returns:

List of dicts with ‘N’, ‘CA’, ‘C’ atom indices, one per residue.

Return type:

list

Raises:

ValueError – If any junction residue lacks backbone atoms.
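The following is a sketch of the lookup and its error behavior. `junction_indices` is a hypothetical name.

```python
def junction_indices(junction_residues, backbone_map):
    """Hypothetical sketch: one {'N', 'CA', 'C'} index dict per junction residue."""
    out = []
    for key in junction_residues:
        atoms = backbone_map.get(key)
        if atoms is None or any(name not in atoms for name in ("N", "CA", "C")):
            raise ValueError(f"junction residue {key} lacks backbone atoms")
        out.append({name: atoms[name] for name in ("N", "CA", "C")})
    return out
```

Flattening the result as N, CA, C per residue gives 3*K indices. Presumably, though not confirmed here, these follow the same atom order as the (3*K, 3) backbone_xyz returned by backbone_fk_junction().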

torchref.base.chain_closure.backbone_fk_junction(p1_start, p2_start, p3_start, phi_psi, nerf_bond_lengths, nerf_bond_angles, omega, psi_prev, post_bond_length=None, post_bond_angle=None)

Forward kinematics through junction backbone using NeRF.

Optionally extends the FK by one atom (N_post) using psi(K-1), which enables the closure to target the post-junction N position dynamically.

Parameters:
  • p1_start (torch.Tensor) – N position of the last pre-junction residue, shape (3,) or (J, 3).

  • p2_start (torch.Tensor) – CA position of the last pre-junction residue, shape (3,) or (J, 3).

  • p3_start (torch.Tensor) – C position of the last pre-junction residue, shape (3,) or (J, 3).

  • phi_psi (torch.Tensor) – Slave DOFs: [phi_0, psi_0, phi_1, psi_1, …], shape (2*K,) or (J, 2*K) where K = junction_size.

  • nerf_bond_lengths (torch.Tensor) – Bond lengths in NeRF order: [C_prev-N_0, N_0-CA_0, CA_0-C_0, …], shape (3*K,) or (J, 3*K).

  • nerf_bond_angles (torch.Tensor) – Bond angles at pivot atoms in NeRF order: [angle_at_C_prev_for_N, angle_at_N_for_CA, angle_at_CA_for_C, …], shape (3*K,) or (J, 3*K).

  • omega (torch.Tensor) – Fixed omega angles, shape (K,) or (J, K).

  • psi_prev (torch.Tensor) – Fixed psi of pre-junction residue (used to place first N), shape () or (J,).

  • post_bond_length (torch.Tensor, optional) – C(K-1) -> N_post bond length, shape (J,). If provided together with post_bond_angle, the FK extends one more atom using psi(K-1).

  • post_bond_angle (torch.Tensor, optional) – Angle at C(K-1) for N_post placement, shape (J,).

Returns:

  • end_p1, end_p2, end_p3 (torch.Tensor) – N, CA, C of last junction residue.

  • backbone_xyz (torch.Tensor) – All junction backbone positions, shape (3*K, 3) or (J, 3*K, 3).

  • end_point (torch.Tensor) – Closure target point: N_post if post params given, else C(K-1). Shape (3,) or (J, 3).

Return type:

Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]
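The following is an unbatched NumPy sketch of NeRF placement and the junction FK loop; it accepts shape (3,) inputs only. nerf_place uses the standard NeRF construction: it builds a local frame from the last three atoms, then places the new atom from its bond length, bond angle and torsion. The torsion-to-atom pairing is an assumption about the library's indexing: psi_prev places N_0, omega[i] places CA_i, phi_i places C_i, and psi_i places the next N. That pairing follows the NeRF ordering of nerf_bond_lengths and nerf_bond_angles. `nerf_place` and `fk_junction` are hypothetical names.

```python
import numpy as np

def nerf_place(a, b, c, bond_length, bond_angle, torsion):
    """Place d so |cd| = bond_length, angle(b, c, d) = bond_angle and
    dihedral(a, b, c, d) = torsion (angles in radians)."""
    bc = (c - b) / np.linalg.norm(c - b)
    n = np.cross(b - a, bc)
    n = n / np.linalg.norm(n)
    m = np.stack([bc, np.cross(n, bc), n], axis=1)  # orthonormal local frame (columns)
    d_local = bond_length * np.array([
        -np.cos(bond_angle),
        np.sin(bond_angle) * np.cos(torsion),
        np.sin(bond_angle) * np.sin(torsion),
    ])
    return c + m @ d_local

def fk_junction(p1, p2, p3, phi_psi, bond_lengths, bond_angles, omega, psi_prev,
                post_bond_length=None, post_bond_angle=None):
    a, b, c = (np.asarray(p, dtype=float) for p in (p1, p2, p3))
    torsion = psi_prev  # psi of the pre-junction residue places the first N
    atoms = []
    for i in range(len(omega)):
        n_i = nerf_place(a, b, c, bond_lengths[3 * i], bond_angles[3 * i], torsion)
        ca_i = nerf_place(b, c, n_i, bond_lengths[3 * i + 1], bond_angles[3 * i + 1], omega[i])
        c_i = nerf_place(c, n_i, ca_i, bond_lengths[3 * i + 2], bond_angles[3 * i + 2],
                         phi_psi[2 * i])
        atoms.extend([n_i, ca_i, c_i])
        a, b, c = n_i, ca_i, c_i
        torsion = phi_psi[2 * i + 1]  # psi_i places the next N
    end_point = c  # C(K-1) unless the post-junction N is requested
    if post_bond_length is not None and post_bond_angle is not None:
        end_point = nerf_place(a, b, c, post_bond_length, post_bond_angle, torsion)
    return a, b, c, np.stack(atoms), end_point
```

Each placement depends only on the three preceding atoms. As a result, the FK is a short sequential chain of differentiable operations. In torch, this is the property that lets autograd supply the Jacobian needed by the closure solve.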

torchref.base.chain_closure.closure_residual(end_p3, target_p3)

Compute the 3D position closure residual.

Parameters:
  • end_p3 (torch.Tensor) – Computed C position of last junction residue, shape (3,) or (J, 3).

  • target_p3 (torch.Tensor) – Target C position, shape (3,) or (J, 3).

Returns:

3D position residual, shape (3,) or (J, 3).

Return type:

torch.Tensor
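The following sketch assumes the residual is the computed position minus the target; the sign convention is not stated here. The function broadcasts over an optional leading junction dimension. The per-junction norm is one plausible convergence measure. `position_residual` is a hypothetical name.

```python
import numpy as np

def position_residual(end_p3, target_p3):
    """Hypothetical sketch: computed minus target, shape (3,) or (J, 3)."""
    return np.asarray(end_p3, dtype=float) - np.asarray(target_p3, dtype=float)

residual = position_residual([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]],
                             [[1.0, 2.0, 2.0], [0.0, 0.0, 0.0]])
per_junction_error = np.linalg.norm(residual, axis=-1)  # one distance per junction
```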

class torchref.base.chain_closure.JunctionClosure(*args, **kwargs)

Bases: Function

Custom autograd for chain closure using the Implicit Function Theorem.

Forward: Newton solve for junction phi/psi. Backward: IFT adjoint for exact gradients.

static forward(ctx, phi_psi_init, p1_start, p2_start, p3_start, target_p3, nerf_bond_lengths, nerf_bond_angles, omega, psi_prev, post_bond_length, post_bond_angle, max_iter=20, tol=0.0001, tikhonov_eps=1e-06)

Newton solve for the junction phi/psi. The solve starts from phi_psi_init and stops after max_iter iterations or once it converges to within tol. The geometric arguments have the same meaning as in backbone_fk_junction(); target_p3 is the closure target (see closure_residual()).
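JunctionClosure's forward pass is described only as a Newton solve for the junction phi/psi. As an illustration, the following is a minimum-norm Newton iteration on a toy underdetermined closure problem. The toy is a planar three-link arm whose tip must reach a 2D target: 3 unknowns and 2 equations. This mirrors 2K junction torsions constrained by a 3D target. Two details are assumptions, not confirmed details of JunctionClosure.forward: the damped step delta = -J^T (J J^T + eps I)^{-1} r, and the role given to tikhonov_eps. The arm functions are hypothetical.

```python
import numpy as np

LINKS = np.array([1.0, 1.0, 1.0])  # toy planar arm: 3 unknown angles, 2D target

def arm_tip(angles):
    s = np.cumsum(angles)  # absolute direction of each link
    return np.array([np.sum(LINKS * np.cos(s)), np.sum(LINKS * np.sin(s))])

def arm_jacobian(angles):
    s = np.cumsum(angles)
    # d(tip)/d(angle_j) sums the contributions of links k >= j (reverse cumulative sum).
    jx = np.cumsum((-LINKS * np.sin(s))[::-1])[::-1]
    jy = np.cumsum((LINKS * np.cos(s))[::-1])[::-1]
    return np.stack([jx, jy])  # shape (2, n_angles)

def newton_closure(angles, target, max_iter=20, tol=1e-4, tikhonov_eps=1e-6):
    angles = np.array(angles, dtype=float)
    for _ in range(max_iter):
        r = arm_tip(angles) - target  # closure residual
        if np.linalg.norm(r) < tol:
            break
        J = arm_jacobian(angles)
        # Minimum-norm Newton step with Tikhonov damping (assumed role of tikhonov_eps):
        # delta = -J^T (J J^T + eps I)^{-1} r
        lam = np.linalg.solve(J @ J.T + tikhonov_eps * np.eye(len(r)), r)
        angles = angles - J.T @ lam
    return angles

target = arm_tip(np.array([0.4, 0.6, 0.5]))
solution = newton_closure([0.5, 0.5, 0.5], target)
```

The minimum-norm step only requires solving a small square system, whose size is the number of constraints rather than the number of DOFs. It also changes the angles as little as possible per iteration, which suits warm-started solves.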

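static backward(ctx, grad_output)

IFT-based backward pass for the implicit closure constraint.

Notation: phi is the slave DOFs phi_psi, and theta is the remaining inputs. F is the 3D closure residual, and J_phi = dF/d(phi_psi) is its (3, 2K) Jacobian at the solution. Given F(phi_psi*, theta) = 0, the IFT gives:

$$\frac{\partial \phi^*}{\partial \theta} = -J_\phi^{+} \, \frac{\partial F}{\partial \theta}$$

where ^+ denotes the pseudoinverse; for a full-row-rank J_phi, J_phi^+ = J_phi^T (J_phi J_phi^T)^{-1}.

The VJP is:

$$\bar{\theta} = g^\top \frac{\partial \phi^*}{\partial \theta} = -\lambda^\top \frac{\partial F}{\partial \theta}, \qquad \lambda = \left(J_\phi J_\phi^\top\right)^{-1} J_\phi \, g$$

where g = grad_output, bar-theta = grad_theta, and lambda = lam_3d is the 3D adjoint.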
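The following is a numerical check of the adjoint form. It uses NumPy with random matrices (full row rank with probability one) in place of the real Jacobians. The adjoint route needs only one 3x3 solve per junction and never forms d(phi_psi*)/d(theta). It agrees with the explicit pseudoinverse formula. `ift_vjp` is a hypothetical name, and the optional eps damping is an assumption.

```python
import numpy as np

rng = np.random.default_rng(0)
J_phi = rng.normal(size=(3, 6))    # stand-in for dF/d(phi_psi): 3 residuals, 2K = 6 DOFs
F_theta = rng.normal(size=(3, 4))  # stand-in for dF/d(theta): 4 other inputs
g = rng.normal(size=6)             # stand-in for grad_output

def ift_vjp(J_phi, F_theta, g, eps=0.0):
    # Adjoint: one small (3, 3) solve; eps > 0 would add Tikhonov damping (assumption).
    lam_3d = np.linalg.solve(J_phi @ J_phi.T + eps * np.eye(J_phi.shape[0]), J_phi @ g)
    return -lam_3d @ F_theta

vjp_adjoint = ift_vjp(J_phi, F_theta, g)
# Explicit route: form d(phi_psi*)/d(theta) = -pinv(J_phi) @ F_theta, then contract with g.
vjp_direct = g @ (-np.linalg.pinv(J_phi) @ F_theta)
```

In an autograd setting, the final contraction -lam_3d^T dF/d(theta) would itself be a vector-Jacobian product through the residual. The Jacobian with respect to theta therefore never needs to be materialized either.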

class torchref.base.chain_closure.JunctionSolver(n_junctions, junction_size, initial_phi_psi, max_iter=20, tol=0.0001, tikhonov_eps=1e-06, dtype=None, device=None)

Bases: DeviceMixin, Module

Manages warm-start buffers and solves junction closures.

Parameters:
  • n_junctions (int) – Number of junctions.

  • junction_size (int) – Number of residues per junction.

  • initial_phi_psi (torch.Tensor) – Initial guess for phi/psi, shape (J, 2*K).

  • max_iter (int) – Maximum Newton iterations.

  • tol (float) – Convergence tolerance.

__init__(n_junctions, junction_size, initial_phi_psi, max_iter=20, tol=0.0001, tikhonov_eps=1e-06, dtype=None, device=None)

forward(p1_start, p2_start, p3_start, target_p3, nerf_bond_lengths, nerf_bond_angles, omega, psi_prev, post_bond_length=None, post_bond_angle=None)

Solve junction closures and return backbone positions.

Returns:

  • phi_psi (torch.Tensor) – Solved phi/psi, shape (J, 2*K).

  • backbone_xyz (torch.Tensor) – Junction backbone positions, shape (J, 3*K, 3).

Return type:

Tuple[Tensor, Tensor]
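The following sketch shows only the warm-start pattern. A hypothetical solve_fn stands in for the Newton solve: each call starts from the stored (J, 2K) solution and then overwrites it. In an nn.Module, such a buffer would typically be registered with register_buffer() and updated from a detached copy of the solution, so that the next warm start carries no autograd history. Whether JunctionSolver does exactly this is an assumption.

```python
import numpy as np

class WarmStartSolver:
    """Hypothetical sketch of the warm-start pattern; not the torchref class."""

    def __init__(self, initial_phi_psi, solve_fn):
        # (J, 2K) buffer holding the latest solution, reused as the next initial guess.
        self.phi_psi = np.array(initial_phi_psi, dtype=float)
        self.solve_fn = solve_fn

    def __call__(self, *args):
        solved = self.solve_fn(self.phi_psi, *args)
        # Store a copy so later in-place edits to `solved` do not alter the buffer.
        self.phi_psi = np.array(solved, dtype=float)
        return solved

# Toy solve_fn: "solving" just shifts the initial guess, to show the state carrying over.
solver = WarmStartSolver(np.zeros((2, 6)), lambda init, shift: init + shift)
first = solver(1.0)
second = solver(1.0)  # starts from `first`, so every entry is now 2.0
```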

property closure_residuals: Tensor | None

Submodules