torchref.base.chain_closure.closure module

Core chain closure mathematics for junction residues.

Uses the standard NeRF (Natural Extension Reference Frame) formula for backbone forward kinematics (FK). The slave DOFs are solved by Newton’s method, and their backward gradients come from the Implicit Function Theorem (IFT).

NeRF torsion assignments for placing backbone atoms:

  • Place N(i): torsion = psi(i-1) = dihedral(N(i-1), CA(i-1), C(i-1), N(i))

  • Place CA(i): torsion = omega(i) = dihedral(CA(i-1), C(i-1), N(i), CA(i))

  • Place C(i): torsion = phi(i) = dihedral(C(i-1), N(i), CA(i), C(i))
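As an illustration only, the single-atom NeRF placement step can be sketched in PyTorch. The helpers nerf_place and dihedral are hypothetical names written for this page, not torchref functions. nerf_place puts atom d at distance bond_length from c, with angle(b, c, d) = bond_angle and dihedral(a, b, c, d) = torsion (radians). Under the torsion assignments above, placing N(i) would use (a, b, c) = (N(i-1), CA(i-1), C(i-1)) and torsion psi(i-1).

```python
import torch


def nerf_place(a, b, c, bond_length, bond_angle, torsion):
    """Hypothetical NeRF helper: place atom d from reference atoms a, b, c.

    Positions are shape (3,) or batched (J, 3); scalars are () or (J,); angles in radians.
    """
    bond_length = torch.as_tensor(bond_length, dtype=c.dtype, device=c.device)
    bond_angle = torch.as_tensor(bond_angle, dtype=c.dtype, device=c.device)
    torsion = torch.as_tensor(torsion, dtype=c.dtype, device=c.device)
    # Local frame: bc along the b->c bond, n normal to the (a, b, c) plane, m completes it
    bc = c - b
    bc = bc / bc.norm(dim=-1, keepdim=True)
    n = torch.linalg.cross(b - a, bc, dim=-1)
    n = n / n.norm(dim=-1, keepdim=True)
    m = torch.linalg.cross(n, bc, dim=-1)
    # Coordinates of d relative to c, expressed in the (bc, m, n) frame
    d_local = bond_length.unsqueeze(-1) * torch.stack(
        [
            -torch.cos(bond_angle),
            torch.sin(bond_angle) * torch.cos(torsion),
            torch.sin(bond_angle) * torch.sin(torsion),
        ],
        dim=-1,
    )
    frame = torch.stack([bc, m, n], dim=-1)  # columns are the frame axes
    return c + (frame @ d_local.unsqueeze(-1)).squeeze(-1)


def dihedral(p0, p1, p2, p3):
    """Signed dihedral angle (radians) for single (3,) points, used to check placement."""
    b1 = (p2 - p1) / (p2 - p1).norm()
    v = (p0 - p1) - torch.dot(p0 - p1, b1) * b1
    w = (p3 - p2) - torch.dot(p3 - p2, b1) * b1
    return torch.atan2(torch.dot(torch.linalg.cross(b1, v), w), torch.dot(v, w))
```

The dihedral helper is included only to confirm that a placed atom reproduces the requested torsion.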

For K junction residues, the slave DOFs are phi(0..K-1) and psi(0..K-1) (2K total). The omega values are fixed, as is psi(-1), the psi of the last pre-junction residue. The residual compares the computed end-of-junction position with the known post-junction position.

torchref.base.chain_closure.closure.backbone_fk_junction(p1_start, p2_start, p3_start, phi_psi, nerf_bond_lengths, nerf_bond_angles, omega, psi_prev, post_bond_length=None, post_bond_angle=None)

Forward kinematics through junction backbone using NeRF.

Optionally extends the FK by one atom (N_post) using psi(K-1), which enables the closure to target the post-junction N position dynamically.

Parameters:
  • p1_start (torch.Tensor) – N of the last pre-junction residue, shape (3,) or (J, 3).

  • p2_start (torch.Tensor) – CA of the last pre-junction residue, shape (3,) or (J, 3).

  • p3_start (torch.Tensor) – C of the last pre-junction residue, shape (3,) or (J, 3).

  • phi_psi (torch.Tensor) – Slave DOFs: [phi_0, psi_0, phi_1, psi_1, …], shape (2*K,) or (J, 2*K) where K = junction_size.

  • nerf_bond_lengths (torch.Tensor) – Bond lengths in NeRF order: [C_prev-N_0, N_0-CA_0, CA_0-C_0, …], shape (3*K,) or (J, 3*K).

  • nerf_bond_angles (torch.Tensor) – Bond angles at pivot atoms in NeRF order: [angle_at_C_prev_for_N, angle_at_N_for_CA, angle_at_CA_for_C, …], shape (3*K,) or (J, 3*K).

  • omega (torch.Tensor) – Fixed omega angles, shape (K,) or (J, K).

  • psi_prev (torch.Tensor) – Fixed psi of pre-junction residue (used to place first N), shape () or (J,).

  • post_bond_length (torch.Tensor, optional) – C(K-1) -> N_post bond length, shape (J,). If provided together with post_bond_angle, the FK is extended by one more atom using psi(K-1).

  • post_bond_angle (torch.Tensor, optional) – Angle at C(K-1) for N_post placement, shape (J,).

Returns:

  • end_p1, end_p2, end_p3 (torch.Tensor) – N, CA, C of the last junction residue.

  • backbone_xyz (torch.Tensor) – All junction backbone positions, shape (3*K, 3) or (J, 3*K, 3).

  • end_point (torch.Tensor) – Closure target point: N_post if post_bond_length and post_bond_angle are given, else C(K-1). Shape (3,) or (J, 3).

Return type:

Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]
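A simplified, unbatched sketch of the FK loop implied by the NeRF torsion assignments and by the ordering of nerf_bond_lengths and nerf_bond_angles. fk_junction_sketch and _place are hypothetical names; the sketch assumes shape (3,) start atoms and omits the batched case that backbone_fk_junction supports.

```python
import torch


def _place(a, b, c, length, angle, torsion):
    """Hypothetical NeRF step for single (3,) atoms: returns d with |c - d| = length,
    angle(b, c, d) = angle and dihedral(a, b, c, d) = torsion (radians)."""
    bc = (c - b) / (c - b).norm()
    n = torch.linalg.cross(b - a, bc)
    n = n / n.norm()
    m = torch.linalg.cross(n, bc)
    angle = torch.as_tensor(angle, dtype=c.dtype)
    torsion = torch.as_tensor(torsion, dtype=c.dtype)
    d_local = length * torch.stack([-torch.cos(angle),
                                    torch.sin(angle) * torch.cos(torsion),
                                    torch.sin(angle) * torch.sin(torsion)])
    return c + torch.stack([bc, m, n], dim=-1) @ d_local


def fk_junction_sketch(n_prev, ca_prev, c_prev, phi_psi, bond_lengths, bond_angles,
                       omega, psi_prev, post_bond_length=None, post_bond_angle=None):
    K = omega.shape[0]
    p1, p2, p3 = n_prev, ca_prev, c_prev  # N, CA, C of the residue placed last
    atoms = []
    for i in range(K):
        psi_before = psi_prev if i == 0 else phi_psi[2 * i - 1]  # psi(i-1)
        n = _place(p1, p2, p3, bond_lengths[3 * i], bond_angles[3 * i], psi_before)
        ca = _place(p2, p3, n, bond_lengths[3 * i + 1], bond_angles[3 * i + 1], omega[i])
        c = _place(p3, n, ca, bond_lengths[3 * i + 2], bond_angles[3 * i + 2], phi_psi[2 * i])
        atoms += [n, ca, c]
        p1, p2, p3 = n, ca, c
    backbone_xyz = torch.stack(atoms)  # (3K, 3) in order N0, CA0, C0, N1, ...
    end_point = p3                     # C(K-1) unless the N_post extension is requested
    if post_bond_length is not None and post_bond_angle is not None:
        end_point = _place(p1, p2, p3, post_bond_length, post_bond_angle,
                           phi_psi[2 * K - 1])  # psi(K-1)
    return p1, p2, p3, backbone_xyz, end_point
```

In this sketch psi(K-1), stored at phi_psi[2K-1], affects the output only through the optional N_post atom, which is consistent with the role the docstring gives the post_bond_length/post_bond_angle extension.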

torchref.base.chain_closure.closure.closure_residual(end_p3, target_p3)

Compute the 3D position closure residual.

Parameters:
  • end_p3 (torch.Tensor) – Computed C position of last junction residue, shape (3,) or (J, 3).

  • target_p3 (torch.Tensor) – Target C position, shape (3,) or (J, 3).

Returns:

3D position residual, shape (3,) or (J, 3).

Return type:

torch.Tensor
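A minimal sketch of a position residual and a per-junction convergence test. The sign convention (computed minus target) and comparing the Euclidean norm against tol are assumptions, not taken from torchref; closure_residual_sketch and converged are hypothetical names.

```python
import torch


def closure_residual_sketch(end_p3: torch.Tensor, target_p3: torch.Tensor) -> torch.Tensor:
    # Assumed sign convention: computed minus target, so the residual is zero at closure.
    # Shapes (3,) or (J, 3) pass through unchanged.
    return end_p3 - target_p3


def converged(residual: torch.Tensor, tol: float = 1e-4) -> torch.Tensor:
    # Euclidean norm over the xyz axis: a scalar bool for (3,), a (J,) mask for (J, 3).
    return residual.norm(dim=-1) < tol
```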

class torchref.base.chain_closure.closure.JunctionClosure(*args, **kwargs)

Bases: Function

Custom autograd for chain closure using the Implicit Function Theorem.

Forward: Newton solve for junction phi/psi. Backward: IFT adjoint for exact gradients.
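The Newton half of this split can be illustrated on a toy problem that has the same structure: fewer equations than unknowns. The sketch below moves the end of a planar three-link chain onto a target (2 equations, 3 joint angles) with a minimum-norm Newton step damped by a small Tikhonov term. Reading tikhonov_eps as the damping added to J J^T is an assumption about torchref; planar_chain_end and newton_min_norm are hypothetical names, and the toy chain stands in for the protein backbone.

```python
import torch


def planar_chain_end(angles: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Toy forward kinematics: end point of a planar chain with relative joint angles."""
    cumulative = torch.cumsum(angles, dim=0)
    return torch.stack([(lengths * torch.cos(cumulative)).sum(),
                        (lengths * torch.sin(cumulative)).sum()])


def newton_min_norm(residual_fn, x0, max_iter=20, tol=1e-4, tikhonov_eps=1e-6):
    """Damped minimum-norm Newton for F(x) = 0 with fewer equations than unknowns."""
    x = x0.clone()
    for _ in range(max_iter):
        F = residual_fn(x)
        if F.norm() < tol:
            break
        J = torch.autograd.functional.jacobian(residual_fn, x)  # shape (m, n), m < n
        # dx = -J^T (J J^T + eps I)^{-1} F: for eps = 0 this is the smallest step
        # that zeroes the linearised residual; eps keeps the m x m solve well posed
        A = J @ J.T + tikhonov_eps * torch.eye(J.shape[0], dtype=x.dtype)
        x = x - J.T @ torch.linalg.solve(A, F)
    return x


# Usage: reach a reachable target starting from a nearby configuration
lengths = torch.ones(3, dtype=torch.float64)
target = planar_chain_end(torch.tensor([0.5, 0.4, 0.2], dtype=torch.float64), lengths)
x = newton_min_norm(lambda a: planar_chain_end(a, lengths) - target,
                    torch.full((3,), 0.3, dtype=torch.float64), tol=1e-10)
```

Because the system is underdetermined, the solver lands on one of infinitely many solutions; the minimum-norm step tends to keep it close to the starting guess, which is one reason warm starts matter.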

static forward(ctx, phi_psi_init, p1_start, p2_start, p3_start, target_p3, nerf_bond_lengths, nerf_bond_angles, omega, psi_prev, post_bond_length, post_bond_angle, max_iter=20, tol=0.0001, tikhonov_eps=1e-06)

Define the forward of the custom autograd Function.

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass

  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, inputs are a Tuple of inputs to the forward.

  • See Extending torch.autograd for more details

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or ctx.save_for_forward() if they are intended to be used in jvp.
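JunctionClosure.forward takes ctx as its first argument, i.e. Usage 1. A minimal, self-contained Usage 1 Function illustrates the pattern: tensors saved with ctx.save_for_backward(), a non-tensor setting kept on ctx, and backward returning one gradient per forward input, with None for non-differentiable arguments (as any Function with arguments such as max_iter or tol must). The class name ScaledCube is hypothetical.

```python
import torch


class ScaledCube(torch.autograd.Function):
    """Hypothetical example of Usage 1: forward receives ctx as its first argument."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.save_for_backward(x)   # tensors needed in backward go through save_for_backward
        ctx.scale = scale          # a plain Python float may be stored on ctx directly
        return scale * x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # One gradient per forward input: d/dx for x, None for the non-tensor `scale`
        return grad_output * 3.0 * ctx.scale * x ** 2, None


x = torch.tensor([1.0, 2.0], dtype=torch.float64, requires_grad=True)
ScaledCube.apply(x, 2.0).sum().backward()
# x.grad now holds 3 * 2.0 * x**2 = [6., 24.]
```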

static backward(ctx, grad_output)

IFT-based backward pass for the implicit closure constraint.

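Given F(phi_psi*, theta) = 0, where theta stands for the remaining differentiable inputs (start atoms, target, bond geometry, omega, psi_prev) and J_phi = dF/d(phi_psi) is the 3 × 2K closure Jacobian, the IFT gives:

$$\frac{\partial\,\mathrm{phi\_psi}^*}{\partial\theta} = -J_\phi^{+}\,\frac{\partial F}{\partial\theta}$$

where ^+ denotes the pseudoinverse.

The VJP is:

$$\mathrm{grad\_theta} = \mathrm{grad\_output}^{\top}\,\frac{\partial\,\mathrm{phi\_psi}^*}{\partial\theta} = -\lambda_{3d}^{\top}\,\frac{\partial F}{\partial\theta}, \qquad \lambda_{3d} = \left(J_\phi J_\phi^{\top}\right)^{-1} J_\phi\,\mathrm{grad\_output}$$

where λ_3d (lam_3d) is the adjoint. The second equality uses J_phi^+ = J_phi^T (J_phi J_phi^T)^{-1}, which holds when J_phi has full row rank (possible only for 2K ≥ 3), together with the symmetry of J_phi J_phi^T. The adjoint therefore needs only a 3 × 3 linear solve per junction.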
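A quick numerical check of the adjoint identity with random matrices: the VJP obtained by forming d(phi_psi*)/d(theta) through an explicit pseudoinverse should equal the one obtained from lam_3d with a single 3 × 3 solve. The sizes (K = 3, four entries in theta) are illustrative only and not tied to torchref.

```python
import torch

torch.manual_seed(0)
K, n_theta = 3, 4                                        # illustrative sizes
J_phi = torch.randn(3, 2 * K, dtype=torch.float64)       # dF/d(phi_psi), almost surely full row rank
F_theta = torch.randn(3, n_theta, dtype=torch.float64)   # dF/d(theta)
grad_output = torch.randn(2 * K, dtype=torch.float64)

# Route 1: explicit implicit-function Jacobian via the pseudoinverse, then the VJP
dphi_dtheta = -torch.linalg.pinv(J_phi) @ F_theta        # (2K, n_theta)
grad_theta_direct = grad_output @ dphi_dtheta

# Route 2: adjoint; one 3 x 3 solve, no (2K, n_theta) Jacobian is formed
lam_3d = torch.linalg.solve(J_phi @ J_phi.T, J_phi @ grad_output)
grad_theta_adjoint = -lam_3d @ F_theta
```

The adjoint route is what makes the backward cheap: its cost does not grow with the number of parameters in theta beyond the final product with dF/d(theta).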

class torchref.base.chain_closure.closure.JunctionSolver(n_junctions, junction_size, initial_phi_psi, max_iter=20, tol=0.0001, tikhonov_eps=1e-06, dtype=None, device=None)

Bases: DeviceMixin, Module

Manages warm-start buffers and solves junction closures.

Parameters:
  • n_junctions (int) – Number of junctions.

  • junction_size (int) – Number of residues per junction.

  • initial_phi_psi (torch.Tensor) – Initial guess for phi/psi, shape (J, 2*K).

  • max_iter (int) – Maximum Newton iterations.

  • tol (float) – Convergence tolerance.

__init__(n_junctions, junction_size, initial_phi_psi, max_iter=20, tol=0.0001, tikhonov_eps=1e-06, dtype=None, device=None)

forward(p1_start, p2_start, p3_start, target_p3, nerf_bond_lengths, nerf_bond_angles, omega, psi_prev, post_bond_length=None, post_bond_angle=None)

Solve junction closures and return backbone positions.

Returns:

  • phi_psi (torch.Tensor) – Solved phi/psi, shape (J, 2*K).

  • backbone_xyz (torch.Tensor) – Junction backbone positions, shape (J, 3*K, 3).

Return type:

Tuple[Tensor, Tensor]
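JunctionSolver is described as managing warm-start buffers. One common way to do this in an nn.Module, assumed here rather than taken from torchref, is a registered buffer that serves as the initial guess and is overwritten with each solution. WarmStartSketch and the solve_fn argument are hypothetical.

```python
import torch
from torch import nn


class WarmStartSketch(nn.Module):
    """Hypothetical warm-start holder; not torchref's JunctionSolver."""

    def __init__(self, initial_phi_psi: torch.Tensor):
        super().__init__()
        # A buffer follows .to(device) and state_dict(), but is not a trainable parameter
        self.register_buffer("phi_psi", initial_phi_psi.detach().clone())

    def forward(self, solve_fn):
        solved = solve_fn(self.phi_psi)   # start the solve from the previous solution
        with torch.no_grad():
            self.phi_psi.copy_(solved)    # remember it for the next call
        return solved


# Usage with a stand-in solver that just adds 1 to its starting guess
solver = WarmStartSketch(torch.zeros(2, 4))  # J = 2 junctions, 2K = 4 (K = 2)
first = solver(lambda start: start + 1.0)
second = solver(lambda start: start + 1.0)   # starts from `first`, so returns all 2s
```

Keeping the warm start in a buffer means it moves with the module between devices and is saved with it, while the optimizer never treats it as a parameter.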

property closure_residuals: Tensor | None