torchref.base.targets package
Pure-math loss kernels for refinement targets.
Each module in this package provides a single math function that mirrors the
tensor pipeline inside the corresponding torchref.refinement.targets
target, with all model/restraints/scaler/MixedTensor bookkeeping stripped
out. The signatures take only tensors and scalars — this is the boundary a
Triton kernel would replace.
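Most functions below describe the same eager/Triton split: use the Triton kernel for CUDA float32 inputs, otherwise fall back to eager PyTorch. A minimal sketch of such a guard follows. It is not torchref's dispatch code. The names `use_triton` and `dispatch` are hypothetical, and a real guard would presumably also check that Triton is importable.

```python
from typing import Callable

import torch


def use_triton(*tensors: torch.Tensor) -> bool:
    """Hypothetical guard: True only when every input lives on CUDA and every
    floating-point input is float32 (integer index tensors are allowed)."""
    if not tensors or not all(t.is_cuda for t in tensors):
        return False
    return all(t.dtype == torch.float32 for t in tensors if t.is_floating_point())


def dispatch(eager_fn: Callable, triton_fn: Callable, *tensors: torch.Tensor):
    """Hypothetical helper: route to triton_fn when use_triton() holds, else eager_fn."""
    fn = triton_fn if use_triton(*tensors) else eager_fn
    return fn(*tensors)
```

With this guard, CPU or float64 inputs always take the eager path, so the eager implementation also serves as the reference for testing the kernels.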
Mapping
- refinement.targets.geometry.BondTarget -> bond.bond_math
- refinement.targets.geometry.AngleTarget -> angle.angle_math
- refinement.targets.geometry.ChiralTarget -> chiral.chiral_math
- refinement.targets.geometry.PlanarityTarget -> planarity.planarity_math
- refinement.targets.geometry.TorsionTarget -> torsion.torsion_unimodal_math, torsion_omega_math
- refinement.targets.geometry.RamachandranTarget -> ramachandran.ramachandran_math
- refinement.targets.geometry.NonBondedTarget -> nonbonded.nonbonded_heavy_math
- refinement.targets.adp.ADPSimilarityTarget -> adp.adp_simu_math
- refinement.targets.adp.ADPEntropyTarget -> adp.adp_kl_math
- refinement.targets.adp.ADPLocalityTarget -> adp.adp_locality_math
- refinement.targets.xray.MaximumLikelihoodXrayTarget -> xray_ml.ml_xray_loss_math
- refinement.targets.xray.GaussianXrayTarget -> xray_gaussian.gaussian_xray_loss_math
- refinement.targets.xray.LeastSquaresXrayTarget -> xray_ls.ls_xray_loss_math
- refinement.targets.xray.BhattacharyyaXrayTarget -> xray_bhattacharyya.bhattacharyya_xray_loss_math
- torchref.base.targets.adp_kl_math(log_adp, target_log_std=0.2)
KL divergence regularizer on log(B).
Mirrors Model.adp_kl_divergence_loss: the KL divergence between an empirical Gaussian (mean fixed to the current mean(log_adp), detached; std = std(log_adp)) and a target Gaussian with the same mean and a fixed std of target_log_std.
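Because both Gaussians share the same (detached) mean, the closed-form KL reduces to a function of the two standard deviations. The sketch below assumes the direction KL(empirical || target) and the unbiased `torch.std` estimator. Neither is confirmed by this page, and `adp_kl_sketch` is a hypothetical name.

```python
import torch


def adp_kl_sketch(log_adp: torch.Tensor, target_log_std: float = 0.2) -> torch.Tensor:
    # Empirical std of log(B) (torch.std default: unbiased estimator).
    s_emp = log_adp.std()
    s_tgt = torch.as_tensor(target_log_std, dtype=log_adp.dtype, device=log_adp.device)
    # Equal means make the mean term vanish, leaving only the std ratio
    # (ASSUMED direction):
    # KL(N(mu, s_emp^2) || N(mu, s_tgt^2)) = log(s_tgt/s_emp) + s_emp^2/(2 s_tgt^2) - 1/2
    return torch.log(s_tgt / s_emp) + s_emp**2 / (2 * s_tgt**2) - 0.5
```

The loss is zero exactly when the spread of log(B) matches target_log_std and grows for both over- and under-dispersed B-factors.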
- torchref.base.targets.adp_locality_math(b, neighbor_indices, neighbor_distances)
ADP locality NLL: a weighted MSE on the log(B) differences between each atom and its k nearest neighbors (KNN).
Mirrors ADPLocalityTarget.forward. Neighbor-list construction is part of the target's bookkeeping and is not included here.
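The gather-and-difference pattern can be sketched as follows. This page does not specify the weighting. The inverse-squared-distance weights below are a hypothetical stand-in, as is the assumed (N_atoms, K) shape of the neighbor tensors. `adp_locality_sketch` is not a torchref function.

```python
import torch


def adp_locality_sketch(b, neighbor_indices, neighbor_distances):
    """b: (N_atoms,); neighbor_indices / neighbor_distances: ASSUMED (N_atoms, K)."""
    log_b = torch.log(b)
    # Gather each atom's K neighbours and take differences in log space: (N_atoms, K).
    diff = log_b[neighbor_indices] - log_b.unsqueeze(1)
    # HYPOTHETICAL weights: inverse squared distance, normalised over each atom's K neighbours.
    w = neighbor_distances.clamp_min(1e-6) ** -2
    w = w / w.sum(dim=1, keepdim=True)
    return (w * diff**2).sum()
```

Working in log(B) penalises relative rather than absolute differences, so a 10 -> 20 jump costs the same as 40 -> 80.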
- torchref.base.targets.adp_simu_math(b, pair_indices, simu_sigma)
ADP similarity (SIMU) NLL on bonded-atom B-factor differences.
Dispatches to torchref.base.targets.triton.adp_simu_math_triton() on CUDA float32 (~1.6× faster fwd+bw on A100). Falls back to eager otherwise.
Parameters:
b (torch.Tensor) – (N_atoms,) B-factors.
pair_indices (torch.Tensor) – (N, 2) bonded-atom pairs to compare.
simu_sigma (torch.Tensor) – Scalar sigma on the difference (a buffer in the target).
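An eager SIMU term can be sketched as a Gaussian NLL on the per-pair B-factor difference with one shared sigma. This sketch assumes additive constants are dropped and the result is summed. `adp_simu_sketch` is a hypothetical name, not the library function.

```python
import torch


def adp_simu_sketch(b, pair_indices, simu_sigma):
    # B-factor difference across each bonded pair: (N,)
    d = b[pair_indices[:, 0]] - b[pair_indices[:, 1]]
    # Gaussian NLL with a single shared sigma, additive constants dropped.
    return 0.5 * ((d / simu_sigma) ** 2).sum()
```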
- torchref.base.targets.angle_math(xyz, idx, references_rad, sigmas_rad)
Angle NLL: gather, compute angle, Gaussian NLL, sum.
Dispatches to torchref.base.targets.triton.angle_math_triton() on CUDA float32 (~4× faster fwd+bw on A100). Falls back to eager otherwise.
Parameters:
xyz (torch.Tensor) – (N_atoms, 3) Cartesian coordinates.
idx (torch.Tensor) – (N, 3) integer indices [a, b, c] with b as the vertex.
references_rad (torch.Tensor) – (N,) target angles in radians.
sigmas_rad (torch.Tensor) – (N,) standard deviations in radians.
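The "gather, compute angle, Gaussian NLL, sum" pipeline can be sketched in eager PyTorch. This sketch computes the angle with atan2 rather than acos. That numerical choice is ours and is not confirmed for torchref. `angle_sketch` is a hypothetical name, and sigma-only constants are dropped.

```python
import torch


def angle_sketch(xyz, idx, references_rad, sigmas_rad):
    a, b, c = xyz[idx[:, 0]], xyz[idx[:, 1]], xyz[idx[:, 2]]
    u, v = a - b, c - b  # arms measured from the vertex b
    # atan2(|u x v|, u . v) is better conditioned than acos near 0 and pi.
    theta = torch.atan2(torch.linalg.cross(u, v).norm(dim=-1), (u * v).sum(-1))
    z = (theta - references_rad) / sigmas_rad
    return 0.5 * (z**2).sum()  # Gaussian NLL up to sigma-only constants
```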
- torchref.base.targets.bhattacharyya_xray_loss_math(F_obs, F_calc, sigma_d, sigma_m, mask)
Bhattacharyya overlap loss between data and model Gaussians. For each reflection h:
$$L_h = \frac{(F_{obs} - |F_{calc}|)^2}{4(\sigma_d^2 + \sigma_m^2)} + \frac{1}{2}\log\frac{\sigma_d^2 + \sigma_m^2}{2\sigma_d\sigma_m}$$
Dispatches to torchref.base.targets.triton.xray_bhattacharyya.bhattacharyya_xray_loss_math_triton() on CUDA float32; falls back to the eager implementation otherwise.
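The per-reflection formula translates directly to tensors. This sketch assumes all of F_obs, F_calc, sigma_d and sigma_m are (N,) and that the masked terms are summed. It selects the work set before taking the log, so masked-out zeros cannot produce NaN gradients. `bhattacharyya_sketch` is a hypothetical name.

```python
import torch


def bhattacharyya_sketch(F_obs, F_calc, sigma_d, sigma_m, mask):
    fo, fc = F_obs[mask], F_calc[mask].abs()  # .abs() also handles complex F_calc
    sd, sm = sigma_d[mask], sigma_m[mask]
    var = sd**2 + sm**2
    # Mean-separation term plus variance-mismatch term of the Bhattacharyya distance.
    L = (fo - fc) ** 2 / (4 * var) + 0.5 * torch.log(var / (2 * sd * sm))
    return L.sum()  # ASSUMED reduction: sum over the work set
```

The second term vanishes when sigma_d equals sigma_m, so equal-width Gaussians reduce the loss to a scaled squared residual.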
- torchref.base.targets.bond_math(xyz, idx, references, sigmas)
Bond NLL: gather, distance, Gaussian NLL, sum.
Mirrors BondTarget.forward (geometry/bond), including the bond-length computation from Restraints.bond_lengths. On CUDA float32 inputs this dispatches to torchref.base.targets.triton.bond_math_triton() (~2.5× faster fwd+bw on A100). All other inputs use the eager implementation.
Parameters:
xyz (torch.Tensor) – (N_atoms, 3) Cartesian coordinates.
idx (torch.Tensor) – (N, 2) integer indices into xyz.
references (torch.Tensor) – (N,) target bond lengths.
sigmas (torch.Tensor) – (N,) standard deviations.
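The four steps in the summary line map one-to-one onto tensor operations. This is a minimal eager sketch under the assumption that sigma-only constants are dropped. `bond_sketch` is a hypothetical name, not the library function.

```python
import torch


def bond_sketch(xyz, idx, references, sigmas):
    # Gather both ends of every bond and measure its length: (N,)
    d = (xyz[idx[:, 0]] - xyz[idx[:, 1]]).norm(dim=-1)
    z = (d - references) / sigmas
    return 0.5 * (z**2).sum()  # Gaussian NLL up to sigma-only constants
```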
- torchref.base.targets.chiral_math(xyz, indices, ideal_volumes, sigmas)
Chiral volume NLL (matches ChiralTarget.forward).
For each chiral center, computes the signed tetrahedral volume V = v1 . (v2 x v3), where vi = xyz[ai] - xyz[center]. Achiral centers (ideal_volumes == 0) are restrained on |V| against 2.5.
Dispatches to torchref.base.targets.triton.chiral_math_triton() on CUDA float32 (~4× faster fwd+bw on A100). Falls back to eager otherwise.
Parameters:
xyz (torch.Tensor) – (N_atoms, 3) Cartesian coordinates.
indices (torch.Tensor) – (N, 4) integer indices [center, a1, a2, a3].
ideal_volumes (torch.Tensor) – (N,) target signed volumes.
sigmas (torch.Tensor) – (N,) standard deviations.
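The signed-volume rule and the achiral special case can be sketched with `torch.where`. This sketch assumes a plain Gaussian NLL with constants dropped. `chiral_sketch` is a hypothetical name.

```python
import torch


def chiral_sketch(xyz, indices, ideal_volumes, sigmas):
    center = xyz[indices[:, 0]]
    v1, v2, v3 = (xyz[indices[:, k]] - center for k in (1, 2, 3))
    V = (v1 * torch.linalg.cross(v2, v3)).sum(-1)  # signed volume v1 . (v2 x v3)
    achiral = ideal_volumes == 0
    # Achiral centres: |V| against 2.5; chiral centres: signed V against the ideal.
    model = torch.where(achiral, V.abs(), V)
    target = torch.where(achiral, torch.full_like(V, 2.5), ideal_volumes)
    return 0.5 * (((model - target) / sigmas) ** 2).sum()
```

Swapping any two substituents flips the sign of V, which is why a chiral center with the wrong handedness is penalised even when the volume magnitude is right.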
- torchref.base.targets.gaussian_xray_loss_math(F_obs, F_calc, sigma, mask)
Gaussian NLL on already-scaled amplitudes.
Dispatches to torchref.base.targets.triton.xray_gaussian.gaussian_xray_loss_math_triton() on CUDA float32 inputs; falls back to the eager implementation otherwise.
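A per-reflection Gaussian NLL on amplitudes can be sketched as follows. This sketch assumes the log(sigma) term is kept, 0.5*log(2*pi) is dropped, and the result is summed over the work set. `gaussian_xray_sketch` is a hypothetical name.

```python
import torch


def gaussian_xray_sketch(F_obs, F_calc, sigma, mask):
    # Select the work set first so masked-out entries never enter the log.
    fo, fc, s = F_obs[mask], F_calc[mask].abs(), sigma[mask]
    # Per-reflection Gaussian NLL, constant 0.5*log(2*pi) dropped.
    return (0.5 * ((fo - fc) / s) ** 2 + torch.log(s)).sum()
```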
- torchref.base.targets.ls_xray_loss_math(F_obs, F_calc, sigma, mask, weighting='sigma')
Weighted least-squares loss on already-scaled amplitudes.
Dispatches to torchref.base.targets.triton.xray_ls.ls_xray_loss_math_triton() on CUDA float32; falls back to the eager implementation otherwise.
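Only the default weighting='sigma' is documented here. The sketch below assumes that value means inverse-variance weights 1/sigma^2 and rejects anything else rather than guessing at other modes. `ls_xray_sketch` is a hypothetical name.

```python
import torch


def ls_xray_sketch(F_obs, F_calc, sigma, mask, weighting="sigma"):
    fo, fc = F_obs[mask], F_calc[mask].abs()
    if weighting == "sigma":
        w = sigma[mask] ** -2  # ASSUMED: inverse-variance weights
    else:
        raise ValueError(f"weighting={weighting!r} is not covered by this sketch")
    return (w * (fo - fc) ** 2).sum()
```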
- torchref.base.targets.ml_xray_loss_math(F_obs, F_calc, sigma, centric_flags, mask)
Maximum-likelihood X-ray loss on already-scaled amplitudes.
Matches MaximumLikelihoodXrayTarget.forward, lines 37-84.
Dispatches to torchref.base.targets.triton.xray_ml.ml_xray_loss_math_triton() on CUDA float32; falls back to the eager implementation otherwise.
Parameters:
F_obs (torch.Tensor) – (N,) observed amplitudes (zeros outside mask).
F_calc (torch.Tensor) – (N,) scaled calculated amplitudes (already real-valued, zeros outside mask).
sigma (torch.Tensor) – (N,) per-reflection sigma (ones outside mask).
centric_flags (torch.Tensor or None) – (N,) bool, True for centric reflections. None is treated as all-acentric.
mask (torch.Tensor) – (N,) bool work-set mask applied to the final sum.
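A common form for an amplitude ML target uses the Rice distribution for acentric reflections and the corresponding cosh form for centric ones. The sketch below assumes that form, with sigma^2 as the full variance term. torchref's actual parametrization (for example, how D or epsilon enter) is not stated on this page and may differ. `ml_xray_sketch` is a hypothetical name.

```python
import math

import torch


def ml_xray_sketch(F_obs, F_calc, sigma, centric_flags, mask):
    fo, fc, s = F_obs[mask], F_calc[mask], sigma[mask]
    if centric_flags is None:
        centric = torch.zeros_like(fo, dtype=torch.bool)  # None -> all acentric
    else:
        centric = centric_flags[mask]
    var = s**2
    # Acentric (Rice): -log p = -log(2 Fo / var) + (Fo^2 + Fc^2)/var - log I0(2 Fo Fc / var)
    x_a = 2 * fo * fc / var
    log_i0 = torch.log(torch.special.i0e(x_a)) + x_a  # stable log I0 for x >= 0
    nll_a = -torch.log(2 * fo / var) + (fo**2 + fc**2) / var - log_i0
    # Centric: -log p = -0.5 log(2 / (pi var)) + (Fo^2 + Fc^2)/(2 var) - log cosh(Fo Fc / var)
    x_c = fo * fc / var
    log_cosh = torch.logaddexp(x_c, -x_c) - math.log(2.0)
    nll_c = -0.5 * torch.log(2 / (math.pi * var)) + (fo**2 + fc**2) / (2 * var) - log_cosh
    return torch.where(centric, nll_c, nll_a).sum()
```

Using the exponentially scaled `i0e` and `logaddexp` keeps both branches finite for strong reflections, where I0 and cosh would overflow.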
- torchref.base.targets.nonbonded_heavy_math(xyz, indices, min_distances, symop_indices, cell_offsets, symop_matrices, symop_translations, fractional_matrix, inv_fractional_matrix, c_rep, r_exp, buffer, sigma_vdw)
Heavy-heavy VDW prolsq repulsion NLL.
Matches the prolsq branch of NonBondedTarget.forward and the symmetry-aware pair-position gather from NonBondedTarget._compute_positions. The H-VDW contribution added by NonBondedHTarget is excluded; see nonbonded_h (TBD) for that.
Dispatches to torchref.base.targets.triton.nonbonded_heavy_math_triton() on CUDA float32 (~1.4× faster fwd+bw on A100; the forward kernel is memory-bound, so most of the gain comes from the analytic backward). Falls back to eager otherwise.
Parameters:
xyz (torch.Tensor) – (N_atoms, 3) Cartesian coordinates of the ASU.
indices (torch.Tensor) – (N, 2) per-pair atom indices.
min_distances (torch.Tensor) – (N,) VDW threshold per pair.
symop_indices (torch.Tensor, optional) – (N,) symmetry-operator index per pair; 0 = identity.
cell_offsets (torch.Tensor, optional) – (N, 3) fractional cell offsets per pair.
symop_matrices (torch.Tensor, optional) – (n_symops, 3, 3) matrices of the symmetry-operator table.
symop_translations (torch.Tensor, optional) – (n_symops, 3) translations of the symmetry-operator table.
fractional_matrix (torch.Tensor) – (3, 3) cell.fractional_matrix.
inv_fractional_matrix (torch.Tensor) – (3, 3) inverse of cell.fractional_matrix.
c_rep (torch.Tensor) – Scalar repulsion coefficient.
r_exp (torch.Tensor) – Scalar repulsion exponent.
buffer (float) – Distance buffer in Å.
sigma_vdw (torch.Tensor) – Scalar effective tolerance.
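The symmetry-aware position gather can be sketched separately from the energy. The exact prolsq repulsion form (how c_rep, r_exp, buffer and sigma_vdw combine) is not given on this page, so this sketch stops at the pair distances. It makes several assumptions: fractional_matrix maps Cartesian to fractional coordinates as column vectors, the symmetry operator is applied to the second atom of each pair, and translations and cell offsets are fractional. `symmetry_pair_distances` is a hypothetical name.

```python
import torch


def symmetry_pair_distances(xyz, indices, symop_indices, cell_offsets,
                            symop_matrices, symop_translations,
                            fractional_matrix, inv_fractional_matrix):
    """Distance per pair after moving the second atom to its symmetry mate."""
    xi = xyz[indices[:, 0]]
    xj = xyz[indices[:, 1]]
    if symop_indices is None:
        return (xi - xj).norm(dim=-1)  # no symmetry information: plain ASU distance
    # ASSUMED convention: fractional = fractional_matrix @ cartesian (column vectors).
    fj = xj @ fractional_matrix.T
    R = symop_matrices[symop_indices]        # (N, 3, 3)
    t = symop_translations[symop_indices]    # (N, 3)
    fj = torch.einsum("nij,nj->ni", R, fj) + t
    if cell_offsets is not None:
        fj = fj + cell_offsets               # lattice shift, in fractional units
    xj_sym = fj @ inv_fractional_matrix.T    # back to Cartesian
    return (xi - xj_sym).norm(dim=-1)
```

Applying the operator in fractional space keeps the symmetry table cell-independent; only the two 3×3 matrices carry the unit-cell geometry.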
- torchref.base.targets.planarity_math(xyz, plane_groups)
Planarity NLL summed over plane-size buckets.
Matches PlanarityTarget.forward. Each entry of plane_groups is (indices, sigmas), where indices has shape (P, n_atoms) with n_atoms > 3 (3-atom planes have zero deviation by construction and are skipped by the caller).
The plane normal is computed in float64 via SVD and detached; gradients flow through the deviation projection but not through the SVD.
Dispatches to torchref.base.targets.triton.planarity.planarity_math_triton() on CUDA fp32 (per-bucket fused gather/centroid/project/NLL kernel + analytic backward). Falls back to eager otherwise.
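The eager path described above (float64 SVD for a detached normal, then a differentiable projection) can be sketched per bucket. This sketch assumes sigmas is either (P,) or (P, n_atoms) and uses a Gaussian NLL with constants dropped. `planarity_sketch` is a hypothetical name.

```python
import torch


def planarity_sketch(xyz, plane_groups):
    total = xyz.new_zeros(())
    for indices, sigmas in plane_groups:
        pts = xyz[indices]                                   # (P, n_atoms, 3)
        centred = pts - pts.mean(dim=1, keepdim=True)
        with torch.no_grad():
            # float64 SVD; the last right-singular vector is the best-fit plane normal.
            _, _, Vh = torch.linalg.svd(centred.double())
            normal = Vh[:, -1, :].to(xyz.dtype)             # (P, 3), no gradient
        dev = (centred * normal.unsqueeze(1)).sum(-1)        # (P, n_atoms) signed distances
        # ASSUMED: sigmas is (P,) or (P, n_atoms); reshape so both broadcast.
        s = sigmas.reshape(sigmas.shape[0], -1)
        total = total + 0.5 * ((dev / s) ** 2).sum()
    return total
```

Detaching the normal avoids differentiating through the SVD, whose gradient becomes unstable when singular values are close.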
- torchref.base.targets.ramachandran_math(xyz, phi_idx, psi_idx, nll_surfaces, surface_type)
Ramachandran bilinear-interpolated NLL.
Mirrors RamachandranTarget.forward.
Dispatches to torchref.base.targets.triton.ramachandran_math_triton() on CUDA float32 (~10× faster fwd+bw on A100). Falls back to eager otherwise.
Parameters:
xyz (torch.Tensor) – (N_atoms, 3) Cartesian coordinates.
phi_idx (torch.Tensor) – (N, 4) atom indices of the φ backbone dihedral.
psi_idx (torch.Tensor) – (N, 4) atom indices of the ψ backbone dihedral.
nll_surfaces (torch.Tensor) – (n_surface_types, 360, 360) precomputed NLL = -log P(φ, ψ | type).
surface_type (torch.Tensor) – (N,) integer type per residue.
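The two stages, dihedral computation and periodic bilinear lookup, can be sketched in eager PyTorch. The grid layout is an assumption not stated on this page: axis 1 is φ, axis 2 is ψ, and node k sits at -180 + k degrees. The dihedral uses the standard atan2 formula. `dihedral`, `bilinear_periodic` and `ramachandran_sketch` are hypothetical names.

```python
import torch


def dihedral(p0, p1, p2, p3):
    """Signed dihedral in radians for batched (N, 3) points."""
    b1, b2, b3 = p1 - p0, p2 - p1, p3 - p2
    n1 = torch.linalg.cross(b1, b2)
    n2 = torch.linalg.cross(b2, b3)
    y = b2.norm(dim=-1) * (b1 * n2).sum(-1)
    x = (n1 * n2).sum(-1)
    return torch.atan2(y, x)


def bilinear_periodic(nll_surfaces, surface_type, phi, psi):
    # ASSUMED layout: axis 1 = phi, axis 2 = psi; node k sits at -180 + k*(360/n) degrees.
    n = nll_surfaces.shape[-1]
    scale = n / 360.0
    u = torch.remainder(torch.rad2deg(phi) + 180.0, 360.0) * scale
    v = torch.remainder(torch.rad2deg(psi) + 180.0, 360.0) * scale
    fu, fv = u - u.floor(), v - v.floor()
    i0, j0 = u.floor().long() % n, v.floor().long() % n
    i1, j1 = (i0 + 1) % n, (j0 + 1) % n  # periodic wrap at +/-180 degrees
    S, t = nll_surfaces, surface_type
    return ((1 - fu) * (1 - fv) * S[t, i0, j0] + fu * (1 - fv) * S[t, i1, j0]
            + (1 - fu) * fv * S[t, i0, j1] + fu * fv * S[t, i1, j1])


def ramachandran_sketch(xyz, phi_idx, psi_idx, nll_surfaces, surface_type):
    phi = dihedral(*(xyz[phi_idx[:, k]] for k in range(4)))
    psi = dihedral(*(xyz[psi_idx[:, k]] for k in range(4)))
    return bilinear_periodic(nll_surfaces, surface_type, phi, psi).sum()
```

Indexing `S[t, i0, j0]` with three (N,) index tensors reads one value per residue, so the (N, 360, 360) slice is never materialised.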
- torchref.base.targets.torsion_omega_math(xyz, idx, sigmas_deg, is_proline, w_cis_proline=0.05, w_cis_general=0.0005)
Omega cis/trans mixture NLL.
Mirrors targets.geometry.torsions._omega_mixture_nll plus the omega-angle computation. Models each ω as a two-component mixture: w_trans VM(ω; π, κ) + w_cis VM(ω; 0, κ).
Dispatches to torchref.base.targets.triton.torsion_omega_math_triton() on CUDA float32 (~5.6× faster fwd+bw on A100). Falls back to eager otherwise.
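The mixture NLL can be evaluated stably in log space with `logaddexp`. This sketch takes ω directly rather than computing it from xyz and idx. It assumes w_trans = 1 - w_cis and κ = 1/σ² with σ converted to radians, and uses the standard von Mises normaliser 2π·I0(κ). None of these conventions is confirmed by this page, and `omega_mixture_nll_sketch` is a hypothetical name.

```python
import math

import torch


def omega_mixture_nll_sketch(omega, sigmas_deg, is_proline,
                             w_cis_proline=0.05, w_cis_general=0.0005):
    """omega: (N,) radians; sigmas_deg: (N,); is_proline: (N,) bool."""
    # ASSUMED concentration: kappa = 1 / sigma^2 with sigma converted to radians.
    kappa = torch.deg2rad(sigmas_deg) ** -2
    w_cis = torch.where(is_proline,
                        torch.full_like(omega, w_cis_proline),
                        torch.full_like(omega, w_cis_general))
    cos_w = torch.cos(omega)
    # log[w * exp(kappa cos(omega - mu))] for mu = pi (trans) and mu = 0 (cis).
    log_trans = torch.log1p(-w_cis) - kappa * cos_w
    log_cis = torch.log(w_cis) + kappa * cos_w
    # Shared von Mises normaliser log(2 pi I0(kappa)), using i0e for stability.
    log_norm = math.log(2 * math.pi) + torch.log(torch.special.i0e(kappa)) + kappa
    return (log_norm - torch.logaddexp(log_trans, log_cis)).sum()
```

Both components share κ, so the normaliser factors out of the mixture and only the two exponentials need to be combined.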
- torchref.base.targets.torsion_unimodal_math(deviations_rad, sigmas_deg)
Unimodal von Mises NLL on already-wrapped deviations.
Mirrors targets.geometry.torsions._von_mises_nll. The caller is expected to compute wrapped angular deviations beforehand (currently done inside Restraints.torsion_deviations_with_sigmas).
No Triton dispatch yet; the periodic-wrap logic upstream is still in eager Python.
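The caller-side wrap and the von Mises NLL can be sketched as follows. This sketch assumes κ = 1/σ² with σ converted to radians. `wrap_angle` and `von_mises_nll_sketch` are hypothetical names and not the code inside Restraints.torsion_deviations_with_sigmas.

```python
import math

import torch


def wrap_angle(x):
    """Wrap radians into [-pi, pi] via atan2(sin, cos)."""
    return torch.atan2(torch.sin(x), torch.cos(x))


def von_mises_nll_sketch(deviations_rad, sigmas_deg):
    # ASSUMED: kappa = 1 / sigma^2 with sigma converted from degrees to radians.
    kappa = torch.deg2rad(sigmas_deg) ** -2
    # -log VM(delta; 0, kappa) = -kappa cos(delta) + log(2 pi I0(kappa)); i0e keeps it finite.
    log_norm = math.log(2 * math.pi) + torch.log(torch.special.i0e(kappa)) + kappa
    return (log_norm - kappa * torch.cos(deviations_rad)).sum()
```

For small deviations, κ(1 - cos δ) ≈ δ²/(2σ²), so the von Mises NLL behaves like a Gaussian near the minimum while remaining periodic.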
Subpackages
- torchref.base.targets.triton package
- adp_simu_math_triton()
- angle_math_triton()
- bhattacharyya_xray_loss_math_triton()
- bond_math_triton()
- chiral_math_triton()
- gaussian_xray_loss_math_triton()
- ls_xray_loss_math_triton()
- ml_xray_loss_math_triton()
- nonbonded_heavy_math_triton()
- planarity_math_triton()
- ramachandran_math_triton()
- torsion_omega_math_triton()
Submodules
- torchref.base.targets.triton.adp_simu module
- torchref.base.targets.triton.angle module
- torchref.base.targets.triton.bond module
- torchref.base.targets.triton.chiral module
- torchref.base.targets.triton.nonbonded module
- torchref.base.targets.triton.place_hydrogens module
- torchref.base.targets.triton.planarity module
- torchref.base.targets.triton.ramachandran module
- torchref.base.targets.triton.torsion module
- torchref.base.targets.triton.xray_bhattacharyya module
- torchref.base.targets.triton.xray_gaussian module
- torchref.base.targets.triton.xray_ls module
- torchref.base.targets.triton.xray_ml module
Submodules
- torchref.base.targets.adp module
- torchref.base.targets.angle module
- torchref.base.targets.bond module
- torchref.base.targets.chiral module
- torchref.base.targets.nonbonded module
- torchref.base.targets.planarity module
- torchref.base.targets.ramachandran module
- torchref.base.targets.torsion module
- torchref.base.targets.xray_bhattacharyya module
- torchref.base.targets.xray_gaussian module
- torchref.base.targets.xray_ls module
- torchref.base.targets.xray_ml module