torchref.base.chain_closure.closure module
Core chain closure mathematics for junction residues.
Uses the standard NeRF (Natural Extension Reference Frame) formula for backbone forward kinematics. Slave DOFs are solved by Newton’s method with IFT-based backward gradients.
NeRF torsion assignments for placing backbone atoms:
- Place N(i): torsion = psi(i-1) = dihedral(N(i-1), CA(i-1), C(i-1), N(i))
- Place CA(i): torsion = omega(i) = dihedral(CA(i-1), C(i-1), N(i), CA(i))
- Place C(i): torsion = phi(i) = dihedral(C(i-1), N(i), CA(i), C(i))
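The sketch below shows one NeRF placement step. It assumes the textbook NeRF frame construction and the IUPAC sign convention for dihedrals. `nerf_place` and `dihedral` are hypothetical helpers, not torchref functions. The usage places N(i) from N(i-1), CA(i-1) and C(i-1) with torsion psi(i-1), following the first assignment above.

```python
import math
import torch


def nerf_place(a, b, c, bond_length, bond_angle, torsion):
    """Place atom d after a, b, c (NeRF), angles in radians.

    d satisfies |d - c| = bond_length, angle(b, c, d) = bond_angle and
    dihedral(a, b, c, d) = torsion.
    """
    bond_length = torch.as_tensor(bond_length, dtype=c.dtype)
    bond_angle = torch.as_tensor(bond_angle, dtype=c.dtype)
    torsion = torch.as_tensor(torsion, dtype=c.dtype)
    # Local orthonormal frame: bc along the last bond, n normal to plane (a, b, c)
    bc = (c - b) / torch.linalg.norm(c - b)
    n = torch.linalg.cross(b - a, bc)
    n = n / torch.linalg.norm(n)
    m = torch.linalg.cross(n, bc)
    # Spherical coordinates of the new atom expressed in the (bc, m, n) frame
    x = -bond_length * torch.cos(bond_angle)
    y = bond_length * torch.sin(bond_angle) * torch.cos(torsion)
    z = bond_length * torch.sin(bond_angle) * torch.sin(torsion)
    return c + x * bc + y * m + z * n


def dihedral(a, b, c, d):
    """Signed dihedral a-b-c-d in radians (IUPAC sign convention)."""
    b1 = (c - b) / torch.linalg.norm(c - b)
    v = (a - b) - torch.dot(a - b, b1) * b1  # component of a-b orthogonal to b-c
    w = (d - c) - torch.dot(d - c, b1) * b1  # component of d-c orthogonal to b-c
    return torch.atan2(torch.dot(torch.linalg.cross(b1, v), w), torch.dot(v, w))


dt = torch.float64
n_prev = torch.tensor([-0.53, 1.36, 0.0], dtype=dt)
ca_prev = torch.zeros(3, dtype=dt)
c_prev = torch.tensor([1.52, 0.0, 0.0], dtype=dt)
# Place N(i) from N(i-1), CA(i-1), C(i-1) using torsion psi(i-1)
n_next = nerf_place(n_prev, ca_prev, c_prev, 1.33, math.radians(116.2), math.radians(-60.0))
```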
For K junction residues, the slave DOFs are phi(0..K-1) and psi(0..K-1) (2K total). omega values are fixed. psi(-1) (= psi of the last pre-junction residue) is also fixed. The residual matches the computed end-of-junction position against the known post-junction position.
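The assignments above fix the order in which torsions drive the chain. The hypothetical, unbatched helper `junction_torsions` (not part of torchref) interleaves psi_prev, omega and the slave phi/psi in that order. One consequence follows directly from those assignments: psi(K-1) places no atom inside the junction. It therefore only affects the closure point when the FK is extended to N_post.

```python
import torch


def junction_torsions(phi_psi, omega, psi_prev, extend_to_post=False):
    """Torsion used to place each junction atom, in NeRF order.

    N(i) <- psi(i-1), CA(i) <- omega(i), C(i) <- phi(i), with psi(-1) = psi_prev.
    If extend_to_post is True, psi(K-1) is appended for placing N_post.
    """
    K = omega.shape[-1]
    phi, psi = phi_psi[0::2], phi_psi[1::2]  # slave DOFs are interleaved
    torsions = []
    prev_psi = psi_prev
    for i in range(K):
        torsions.extend([prev_psi, omega[i], phi[i]])
        prev_psi = psi[i]
    if extend_to_post:
        torsions.append(psi[K - 1])
    return torch.stack(torsions)
```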
- torchref.base.chain_closure.closure.backbone_fk_junction(p1_start, p2_start, p3_start, phi_psi, nerf_bond_lengths, nerf_bond_angles, omega, psi_prev, post_bond_length=None, post_bond_angle=None)
Forward kinematics through junction backbone using NeRF.
Optionally extends the FK by one atom (N_post) using psi(K-1), which enables the closure to target the post-junction N position dynamically.
- Parameters:
p1_start (torch.Tensor) – N of the last pre-junction residue, shape (3,) or (J, 3).
p2_start (torch.Tensor) – CA of the last pre-junction residue, shape (3,) or (J, 3).
p3_start (torch.Tensor) – C of the last pre-junction residue, shape (3,) or (J, 3).
phi_psi (torch.Tensor) – Slave DOFs: [phi_0, psi_0, phi_1, psi_1, …], shape (2*K,) or (J, 2*K) where K = junction_size.
nerf_bond_lengths (torch.Tensor) – Bond lengths in NeRF order: [C_prev-N_0, N_0-CA_0, CA_0-C_0, …], shape (3*K,) or (J, 3*K).
nerf_bond_angles (torch.Tensor) – Bond angles at pivot atoms in NeRF order: [angle_at_C_prev_for_N, angle_at_N_for_CA, angle_at_CA_for_C, …], shape (3*K,) or (J, 3*K).
omega (torch.Tensor) – Fixed omega angles, shape (K,) or (J, K).
psi_prev (torch.Tensor) – Fixed psi of pre-junction residue (used to place first N), shape () or (J,).
post_bond_length (torch.Tensor, optional) – C(K-1) -> N_post bond length, shape (J,). If provided together with post_bond_angle, the FK extends one more atom using psi(K-1).
post_bond_angle (torch.Tensor, optional) – Angle at C(K-1) for N_post placement, shape (J,).
- Returns:
end_p1, end_p2, end_p3 (torch.Tensor) – N, CA, C of last junction residue.
backbone_xyz (torch.Tensor) – All junction backbone positions, shape (3*K, 3) or (J, 3*K, 3).
end_point (torch.Tensor) – Closure target point: N_post if post params given, else C(K-1). Shape (3,) or (J, 3).
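The FK can be sketched without batching as a loop of NeRF steps, assuming the torsions are already in NeRF order (psi(-1), omega(0), phi(0), psi(0), omega(1), ...). `junction_fk` and `nerf_place` are illustrative names, not the torchref implementation. The sketch omits the optional N_post extension, so its end point is C(K-1). The geometry uses typical backbone values chosen for illustration.

```python
import torch


def nerf_place(a, b, c, bond_length, bond_angle, torsion):
    # One NeRF step: sets |d-c|, angle(b, c, d) and dihedral(a, b, c, d)
    bc = (c - b) / torch.linalg.norm(c - b)
    n = torch.linalg.cross(b - a, bc)
    n = n / torch.linalg.norm(n)
    m = torch.linalg.cross(n, bc)
    return c + bond_length * (-torch.cos(bond_angle) * bc
                              + torch.sin(bond_angle) * torch.cos(torsion) * m
                              + torch.sin(bond_angle) * torch.sin(torsion) * n)


def junction_fk(n_prev, ca_prev, c_prev, torsions, bond_lengths, bond_angles):
    """Unbatched NeRF chain after the last pre-junction residue.

    Inputs are 1-D tensors in NeRF order (N_0, CA_0, C_0, N_1, ...).
    Returns the placed atoms, shape (M, 3); the last row is the end point.
    """
    atoms = [n_prev, ca_prev, c_prev]
    for t, length, angle in zip(torsions, bond_lengths, bond_angles):
        atoms.append(nerf_place(atoms[-3], atoms[-2], atoms[-1], length, angle, t))
    return torch.stack(atoms[3:])


dt = torch.float64
n_prev = torch.tensor([-0.53, 1.36, 0.0], dtype=dt)
ca_prev = torch.tensor([0.0, 0.0, 0.0], dtype=dt)
c_prev = torch.tensor([1.52, 0.0, 0.0], dtype=dt)
K = 2
lengths = torch.tensor([1.33, 1.46, 1.52] * K, dtype=dt)          # C-N, N-CA, CA-C
angles = torch.deg2rad(torch.tensor([116.2, 121.7, 111.2] * K, dtype=dt))
# [psi(-1), omega(0), phi(0), psi(0), omega(1), phi(1)] for K = 2
torsions = torch.deg2rad(torch.tensor([130.0, 180.0, -60.0, -45.0, 180.0, -60.0], dtype=dt))
backbone = junction_fk(n_prev, ca_prev, c_prev, torsions, lengths, angles)
end_point = backbone[-1]  # C(K-1), since N_post is not placed here
```

Perturbing phi(K-1) moves the end point away from a target built with the original torsions. Driving that difference back to zero is the closure problem.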
- torchref.base.chain_closure.closure.closure_residual(end_p3, target_p3)
Compute the 3D position closure residual.
- Parameters:
end_p3 (torch.Tensor) – Computed C position of last junction residue, shape (3,) or (J, 3).
target_p3 (torch.Tensor) – Target C position, shape (3,) or (J, 3).
- Returns:
3D position residual, shape (3,) or (J, 3).
- class torchref.base.chain_closure.closure.JunctionClosure(*args, **kwargs)
Bases: torch.autograd.Function
Custom autograd for chain closure using the Implicit Function Theorem.
Forward: Newton solve for junction phi/psi. Backward: IFT adjoint for exact gradients.
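The forward-Newton, backward-IFT pattern can be illustrated on a hypothetical scalar problem, `CubicRoot`, which solves x**3 + x - theta = 0. Because dF/dx = 3x**2 + 1 is never zero, the IFT derivative is exact, and torch.autograd.gradcheck can confirm it. This is a toy analogue under those assumptions, not the torchref class, which solves a 3D residual in 2K unknowns.

```python
import torch


class CubicRoot(torch.autograd.Function):
    """Toy stand-in for JunctionClosure: x solves F(x, theta) = x**3 + x - theta = 0."""

    @staticmethod
    def forward(ctx, theta):
        # Newton's method; autograd does not record these iterations
        x = theta.clone()
        for _ in range(50):
            f = x**3 + x - theta
            if f.abs().max() < 1e-14:
                break
            x = x - f / (3 * x**2 + 1)
        ctx.save_for_backward(x)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        dF_dx = 3 * x**2 + 1  # strictly positive, so the IFT applies everywhere
        dF_dtheta = -1.0
        # IFT: dx/dtheta = -(dF/dx)^{-1} dF/dtheta, chained with grad_output
        return grad_output * (-dF_dtheta / dF_dx)
```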
- static forward(ctx, phi_psi_init, p1_start, p2_start, p3_start, target_p3, nerf_bond_lengths, nerf_bond_angles, omega, psi_prev, post_bond_length, post_bond_angle, max_iter=20, tol=0.0001, tikhonov_eps=1e-06)
Solve for the junction phi/psi with Newton's method, starting from phi_psi_init, so that the closure end point matches target_p3. The solve stops after max_iter iterations or on convergence within tol; tikhonov_eps is a Tikhonov regularization constant.
This uses the combined form of torch.autograd.Function.forward, in which ctx is the first argument and no separate setup_context() is defined. Tensors needed in backward() should be saved with ctx.save_for_backward() rather than stored directly on ctx.
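The closure residual has 3 equations but 2K unknowns (more unknowns than equations once K ≥ 2), so a plain Newton step is not unique. A common choice, assumed here rather than taken from torchref, is the Tikhonov-damped minimum-norm step dx = -J^T (J J^T + eps I)^{-1} F. The sketch replaces the backbone FK with a hypothetical 4-joint planar arm (`planar_tip`, 2 equations, 4 unknowns). max_iter, tol and tikhonov_eps mirror the documented defaults, but how torchref uses tikhonov_eps is an assumption.

```python
import torch


def planar_tip(q, lengths):
    """Tip of a planar serial arm with joint angles q (toy stand-in for the FK)."""
    angles = torch.cumsum(q, dim=0)
    return torch.stack([(lengths * torch.cos(angles)).sum(),
                        (lengths * torch.sin(angles)).sum()])


def newton_min_norm(residual_fn, x0, max_iter=20, tol=1e-4, tikhonov_eps=1e-6):
    """Newton's method for a residual with fewer equations than unknowns.

    Each step is the Tikhonov-damped minimum-norm update
    dx = -J^T (J J^T + eps I)^{-1} F, a regularized pseudoinverse step.
    """
    x = x0.clone()
    for _ in range(max_iter):
        F = residual_fn(x)
        if torch.linalg.norm(F) < tol:
            break
        J = torch.autograd.functional.jacobian(residual_fn, x)  # shape (m, n)
        A = J @ J.T + tikhonov_eps * torch.eye(J.shape[0], dtype=x.dtype)
        x = x - J.T @ torch.linalg.solve(A, F)
    return x


dt = torch.float64
lengths = torch.ones(4, dtype=dt)
q_init = torch.tensor([0.2, 0.6, 0.6, 0.6], dtype=dt)  # bent arm, away from singularities
target = planar_tip(q_init + torch.tensor([0.05, -0.1, 0.08, 0.05], dtype=dt), lengths)
q_sol = newton_min_norm(lambda q: planar_tip(q, lengths) - target, q_init, max_iter=50, tol=1e-10)
```

Starting from a nearby guess (the role a warm start plays) keeps each step small, so the iteration converges quickly.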
- static backward(ctx, grad_output)
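IFT-based backward pass for the implicit closure constraint.
Given F(phi_psi*, theta) = 0, the IFT gives

$$\frac{\partial \phi^*}{\partial \theta} = -\left[\frac{\partial F}{\partial \phi}\right]^{+} \frac{\partial F}{\partial \theta},$$

where $\phi^*$ is phi_psi*, $\theta$ collects the remaining inputs, and $^{+}$ denotes the pseudoinverse. When K ≥ 2, $J_\phi = \partial F/\partial \phi$ (J_phi, shape (3, 2K)) is wide, and if it has full row rank its pseudoinverse is $J_\phi^{+} = J_\phi^{\top}(J_\phi J_\phi^{\top})^{-1}$. The VJP is therefore

$$\bar\theta^{\top} = \bar\phi^{\top}\,\frac{\partial \phi^*}{\partial \theta} = -\lambda^{\top}\,\frac{\partial F}{\partial \theta}, \qquad \lambda = (J_\phi J_\phi^{\top})^{-1} J_\phi\,\bar\phi,$$

where $\bar\phi$ is grad_output, $\bar\theta$ is grad_theta, and $\lambda$ is the 3D adjoint lam_3d. Only a 3 × 3 system has to be solved, regardless of K.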
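The adjoint formula can be checked numerically. Random matrices stand in for the true Jacobians, purely for illustration. `ift_vjp` is a hypothetical helper. It solves one 3 × 3 system instead of forming the (2K, P) sensitivity matrix, and it should agree with the explicit pseudoinverse route whenever J_phi has full row rank. The optional tikhonov_eps term is an assumed regularization, not documented behavior.

```python
import torch


def ift_vjp(J_phi, dF_dtheta, grad_output, tikhonov_eps=0.0):
    """grad_theta = -lam_3d^T dF/dtheta, with lam_3d from an (m x m) solve.

    J_phi: dF/d(phi_psi), shape (3, 2K); dF_dtheta: shape (3, P);
    grad_output: dL/d(phi_psi*), shape (2K,).
    """
    m = J_phi.shape[0]
    A = J_phi @ J_phi.T + tikhonov_eps * torch.eye(m, dtype=J_phi.dtype)
    lam_3d = torch.linalg.solve(A, J_phi @ grad_output)
    return -lam_3d @ dF_dtheta


torch.manual_seed(0)
K, P = 3, 5
J_phi = torch.randn(3, 2 * K, dtype=torch.float64)  # full row rank with probability 1
dF_dtheta = torch.randn(3, P, dtype=torch.float64)
grad_output = torch.randn(2 * K, dtype=torch.float64)

# Reference route: materialize d(phi_psi*)/d(theta) = -pinv(J_phi) @ dF/dtheta
grad_reference = grad_output @ (-torch.linalg.pinv(J_phi) @ dF_dtheta)
grad_adjoint = ift_vjp(J_phi, dF_dtheta, grad_output)
```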
- class torchref.base.chain_closure.closure.JunctionSolver(n_junctions, junction_size, initial_phi_psi, max_iter=20, tol=0.0001, tikhonov_eps=1e-06, dtype=None, device=None)
Bases: DeviceMixin, Module
Manages warm-start buffers and solves junction closures.
- Parameters:
n_junctions (int) – Number of junctions.
junction_size (int) – Number of residues per junction.
initial_phi_psi (torch.Tensor) – Initial guess for phi/psi, shape (J, 2*K).
max_iter (int) – Maximum Newton iterations.
tol (float) – Convergence tolerance.
- __init__(n_junctions, junction_size, initial_phi_psi, max_iter=20, tol=0.0001, tikhonov_eps=1e-06, dtype=None, device=None)
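This reference does not document how JunctionSolver manages its warm-start buffers. A generic PyTorch way to keep a warm start is sketched below with a hypothetical `WarmStartCache` module, which is not the torchref class. It registers the current guess, shape (J, 2*K), as a buffer, so the guess moves with .to() and is saved in state_dict without becoming a trainable parameter. After each solve, the module overwrites the buffer with the new solution.

```python
import torch
from torch import nn


class WarmStartCache(nn.Module):
    """Hypothetical warm-start holder: the next solve starts from the last solution."""

    def __init__(self, initial_phi_psi):
        super().__init__()
        # Buffer, not Parameter: follows .to(device/dtype) and state_dict, but is not trained
        self.register_buffer("phi_psi", initial_phi_psi.detach().clone())

    def forward(self, solve_fn):
        """Run solve_fn from the cached guess, then cache its result."""
        solution = solve_fn(self.phi_psi.clone())
        with torch.no_grad():
            self.phi_psi.copy_(solution.detach())
        return solution
```

Here solve_fn is a placeholder for whatever closure solve is used, for example a Newton iteration started from the cached phi/psi.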