torchref.base.targets package

Pure-math loss kernels for refinement targets.

Each module in this package provides a single math function that mirrors the tensor pipeline inside the corresponding torchref.refinement.targets target, with all model/restraints/scaler/MixedTensor bookkeeping stripped out. The signatures take only tensors and scalars — this is the boundary a Triton kernel would replace.

Mapping

  • refinement.targets.geometry.BondTarget -> bond.bond_math
  • refinement.targets.geometry.AngleTarget -> angle.angle_math
  • refinement.targets.geometry.ChiralTarget -> chiral.chiral_math
  • refinement.targets.geometry.PlanarityTarget -> planarity.planarity_math
  • refinement.targets.geometry.TorsionTarget -> torsion.torsion_unimodal_math, torsion_omega_math
  • refinement.targets.geometry.RamachandranTarget -> ramachandran.ramachandran_math
  • refinement.targets.geometry.NonBondedTarget -> nonbonded.nonbonded_heavy_math
  • refinement.targets.adp.ADPSimilarityTarget -> adp.adp_simu_math
  • refinement.targets.adp.ADPEntropyTarget -> adp.adp_kl_math
  • refinement.targets.adp.ADPLocalityTarget -> adp.adp_locality_math
  • refinement.targets.xray.MaximumLikelihoodXrayTarget -> xray_ml.ml_xray_loss_math
  • refinement.targets.xray.GaussianXrayTarget -> xray_gaussian.gaussian_xray_loss_math
  • refinement.targets.xray.LeastSquaresXrayTarget -> xray_ls.ls_xray_loss_math
  • refinement.targets.xray.BhattacharyyaXrayTarget -> xray_bhattacharyya.bhattacharyya_xray_loss_math

torchref.base.targets.adp_kl_math(log_adp, target_log_std=0.2)

KL divergence regularizer on log(B).

Mirrors Model.adp_kl_divergence_loss: the KL divergence between an empirical Gaussian (mean fixed to the current mean(log_adp), detached; std = std(log_adp)) and a target Gaussian with the same mean and a fixed std of target_log_std.
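A plain-Python sketch of the quantity described above, useful for checking the arithmetic by hand. Two details are assumptions not confirmed by this page: the divergence is taken as KL(empirical ‖ target), and std(log_adp) is the sample (n - 1) standard deviation, which is torch.std's default. Because both Gaussians share a mean, the mean-difference term vanishes and the KL reduces to log(σ_t / σ_e) + σ_e² / (2σ_t²) - 1/2. The name adp_kl_reference is hypothetical.

```python
import math
import statistics


def adp_kl_reference(log_adp, target_log_std=0.2):
    """Sketch: KL(N(mu, s_e^2) || N(mu, s_t^2)) with a shared mean mu.

    Assumes the direction empirical -> target and a sample (n - 1) std;
    neither detail is confirmed by the adp_kl_math docstring.
    """
    s_e = statistics.stdev(log_adp)  # empirical std of log(B)
    s_t = target_log_std             # fixed target std
    # Shared mean: the (mu_e - mu_t)^2 term of the Gaussian KL vanishes.
    return math.log(s_t / s_e) + s_e**2 / (2.0 * s_t**2) - 0.5


log_b = [math.log(b) for b in (18.0, 22.0, 25.0, 30.0)]
print(adp_kl_reference(log_b))
```

The loss is zero exactly when the spread of log(B) matches target_log_std and grows for both narrower and wider distributions.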

torchref.base.targets.adp_locality_math(b, neighbor_indices, neighbor_distances)

ADP locality NLL: weighted MSE on log(B) differences between each atom and its k nearest neighbours (KNN).

Mirrors ADPLocalityTarget.forward. Neighbor list construction is the target’s bookkeeping and is not included here.

torchref.base.targets.adp_simu_math(b, pair_indices, simu_sigma)

ADP similarity (SIMU) NLL on bonded-atom B-factor differences.

Dispatches to torchref.base.targets.triton.adp_simu_math_triton() on CUDA float32 (~1.6× faster fwd+bw on A100). Falls back to eager otherwise.

Parameters:
  • b (torch.Tensor) – (N_atoms,) B-factors.

  • pair_indices (torch.Tensor) – (N, 2) bonded-atom pairs to compare.

  • simu_sigma (torch.Tensor) – Scalar sigma on the difference (a buffer in the target).
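A minimal scalar sketch of a SIMU-style term: for each bonded pair, the B-factor difference is divided by simu_sigma and scored with a Gaussian NLL. The 0.5 factor and the absence of a log(sigma) normalizer are assumptions; the docstring only says "NLL on bonded-atom B-factor differences". adp_simu_reference is a hypothetical name.

```python
def adp_simu_reference(b, pair_indices, simu_sigma):
    """Sketch: sum over bonded pairs of 0.5 * ((b_i - b_j) / sigma)^2.

    The 0.5 factor and the missing log(sigma) constant are assumptions.
    """
    total = 0.0
    for i, j in pair_indices:
        z = (b[i] - b[j]) / simu_sigma  # standardized B-factor difference
        total += 0.5 * z * z
    return total


print(adp_simu_reference([10.0, 12.0, 12.0], [(0, 1), (1, 2)], 2.0))
```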

torchref.base.targets.angle_math(xyz, idx, references_rad, sigmas_rad)

Angle NLL: gather, compute angle, Gaussian NLL, sum.

Dispatches to torchref.base.targets.triton.angle_math_triton() on CUDA float32 (~4× faster fwd+bw on A100). Falls back to eager otherwise.

Parameters:
  • xyz (torch.Tensor) – (N_atoms, 3) Cartesian coordinates.

  • idx (torch.Tensor) – (N, 3) integer indices [a, b, c] with b as the vertex.

  • references_rad (torch.Tensor) – (N,) target angles in radians.

  • sigmas_rad (torch.Tensor) – (N,) standard deviations in radians.
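The four steps named above (gather, compute angle, Gaussian NLL, sum) can be sketched per angle in plain Python. The angle at vertex b is taken from the normalized dot product of the two bond vectors. The per-angle term 0.5 * z² is an assumption: whether the real kernel adds log(sigma) or other constants is not stated here. angle_reference is a hypothetical name.

```python
import math


def angle_reference(xyz, idx, references_rad, sigmas_rad):
    """Sketch: gather, compute the angle at vertex b, Gaussian NLL, sum.

    Uses 0.5 * z**2 per angle; normalizing constants are omitted.
    """
    total = 0.0
    for (a, b, c), ref, sig in zip(idx, references_rad, sigmas_rad):
        u = [xyz[a][k] - xyz[b][k] for k in range(3)]  # vertex -> a
        v = [xyz[c][k] - xyz[b][k] for k in range(3)]  # vertex -> c
        dot = sum(p * q for p, q in zip(u, v))
        cos_t = dot / (math.hypot(*u) * math.hypot(*v))
        theta = math.acos(max(-1.0, min(1.0, cos_t)))  # clamp rounding noise
        z = (theta - ref) / sig
        total += 0.5 * z * z
    return total


xyz = [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]]  # 90 degree angle
print(angle_reference(xyz, [(0, 1, 2)], [math.pi / 2], [0.05]))
```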

torchref.base.targets.bhattacharyya_xray_loss_math(F_obs, F_calc, sigma_d, sigma_m, mask)

Bhattacharyya overlap loss between data and model Gaussians.

L_h = (F_obs - |F_calc|)^2 / (4 * (sigma_d^2 + sigma_m^2)) + 0.5 * log((sigma_d^2 + sigma_m^2) / (2 * sigma_d * sigma_m))

Dispatches to torchref.base.targets.triton.xray_bhattacharyya.bhattacharyya_xray_loss_math_triton() on CUDA float32; falls back to the eager implementation otherwise.
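A scalar sketch of the per-reflection term L_h, summed over the work-set mask. Summing (rather than averaging) over the mask is an assumption, as is treating sigma_d and sigma_m as per-reflection values. abs() is used for |F_calc| so the sketch also accepts complex structure factors. bhattacharyya_reference is a hypothetical name.

```python
import math


def bhattacharyya_reference(F_obs, F_calc, sigma_d, sigma_m, mask):
    """Sketch: masked sum of the Bhattacharyya distance L_h per reflection.

    The sum reduction is an assumption; abs() also handles complex F_calc.
    """
    total = 0.0
    for fo, fc, sd, sm, keep in zip(F_obs, F_calc, sigma_d, sigma_m, mask):
        if not keep:
            continue  # reflections outside the work set contribute nothing
        var = sd * sd + sm * sm
        total += (fo - abs(fc)) ** 2 / (4.0 * var) + 0.5 * math.log(var / (2.0 * sd * sm))
    return total


print(bhattacharyya_reference([3.0], [1.0], [1.0], [1.0], [True]))
```

With equal data and model sigmas the log term is zero, so the loss reduces to the scaled squared amplitude difference.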

torchref.base.targets.bond_math(xyz, idx, references, sigmas)

Bond NLL: gather, distance, Gaussian NLL, sum.

Mirrors BondTarget.forward (geometry/bond) including the bond-length computation from Restraints.bond_lengths.

On CUDA float32 inputs this dispatches to torchref.base.targets.triton.bond_math_triton() (~2.5× faster fwd+bw on A100). All other inputs use the eager implementation.
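A plain-Python sketch of the eager path described above: gather each atom pair, take the Euclidean distance, score it with a Gaussian NLL and sum. The per-bond term 0.5 * z² without normalizing constants is an assumption. bond_reference is a hypothetical name.

```python
import math


def bond_reference(xyz, idx, references, sigmas):
    """Sketch: gather the pair, take the distance, Gaussian NLL, sum.

    The per-bond term is 0.5 * z**2; constants are left out (an assumption).
    """
    total = 0.0
    for (i, j), d0, sig in zip(idx, references, sigmas):
        d = math.dist(xyz[i], xyz[j])  # Euclidean bond length
        z = (d - d0) / sig
        total += 0.5 * z * z
    return total


xyz = [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]]
print(bond_reference(xyz, [(0, 1)], [1.48], [0.02]))
```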

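Parameters:
  • xyz (torch.Tensor) – (N_atoms, 3) Cartesian coordinates.

  • idx (torch.Tensor) – (N, 2) integer indices [i, j] of the bonded atoms.

  • references (torch.Tensor) – (N,) target bond lengths.

  • sigmas (torch.Tensor) – (N,) standard deviations.
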
torchref.base.targets.chiral_math(xyz, indices, ideal_volumes, sigmas)

Chiral volume NLL (matches ChiralTarget.forward).

For each chiral center, computes the signed tetrahedral volume V = v1 . (v2 x v3), where vi = xyz[ai] - xyz[center] for the three neighbours a1, a2, a3. Achiral centers (ideal_volumes == 0) are restrained by comparing |V| with a target of 2.5.

Dispatches to torchref.base.targets.triton.chiral_math_triton() on CUDA float32 (~4× faster fwd+bw on A100). Falls back to eager otherwise.

Parameters:
  • xyz (torch.Tensor) – (N_atoms, 3) Cartesian coordinates.

  • indices (torch.Tensor) – (N, 4) integer indices [center, a1, a2, a3].

  • ideal_volumes (torch.Tensor) – (N,) target signed volumes.

  • sigmas (torch.Tensor) – (N,) standard deviations.
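A scalar sketch of the chiral-volume term as described: the signed volume is the triple product of the three center-to-neighbour vectors, and rows with an ideal volume of 0 are scored on |V| against 2.5. The 0.5 * z² form is an assumption, and chiral_reference and its achiral_target keyword are hypothetical.

```python
def chiral_reference(xyz, indices, ideal_volumes, sigmas, achiral_target=2.5):
    """Sketch of the chiral-volume NLL.

    V = v1 . (v2 x v3) with vi = xyz[ai] - xyz[center]. Rows whose ideal
    volume is 0 compare |V| with achiral_target. 0.5 * z**2 is assumed.
    """
    total = 0.0
    for (c, a1, a2, a3), v0, sig in zip(indices, ideal_volumes, sigmas):
        v1, v2, v3 = ([xyz[a][k] - xyz[c][k] for k in range(3)] for a in (a1, a2, a3))
        cross = (  # v2 x v3
            v2[1] * v3[2] - v2[2] * v3[1],
            v2[2] * v3[0] - v2[0] * v3[2],
            v2[0] * v3[1] - v2[1] * v3[0],
        )
        vol = sum(p * q for p, q in zip(v1, cross))  # signed volume
        if v0 == 0:
            vol, v0 = abs(vol), achiral_target  # sign-agnostic restraint
        z = (vol - v0) / sig
        total += 0.5 * z * z
    return total


xyz = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
print(chiral_reference(xyz, [(0, 1, 2, 3)], [1.0], [0.2]))  # V = +1
```

Swapping two neighbours flips the sign of V, which is what lets the restraint distinguish enantiomers; the |V| branch deliberately discards that sign.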

torchref.base.targets.gaussian_xray_loss_math(F_obs, F_calc, sigma, mask)

Gaussian NLL on already-scaled amplitudes.

Dispatches to torchref.base.targets.triton.xray_gaussian.gaussian_xray_loss_math_triton() on CUDA float32 inputs; falls back to the eager implementation otherwise.
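A sketch of a Gaussian NLL on already-scaled amplitudes, restricted to the mask. The choice of constants (keeping log(sigma), dropping 0.5 * log(2π)) and the sum reduction are assumptions, since this page does not spell them out. gaussian_xray_reference is a hypothetical name.

```python
import math


def gaussian_xray_reference(F_obs, F_calc, sigma, mask):
    """Sketch: masked sum of Gaussian NLL terms on scaled amplitudes.

    Keeps log(sigma) but drops 0.5 * log(2 * pi); both choices are assumed.
    """
    total = 0.0
    for fo, fc, s, keep in zip(F_obs, F_calc, sigma, mask):
        if keep:
            z = (fo - fc) / s
            total += 0.5 * z * z + math.log(s)
    return total


print(gaussian_xray_reference([12.0], [10.0], [2.0], [True]))
```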

torchref.base.targets.ls_xray_loss_math(F_obs, F_calc, sigma, mask, weighting='sigma')

Weighted least-squares loss on already-scaled amplitudes.

Dispatches to torchref.base.targets.triton.xray_ls.ls_xray_loss_math_triton() on CUDA float32; falls back to the eager implementation otherwise.
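A sketch of the default weighting='sigma' case, assuming it means weights w = 1 / sigma² on the squared amplitude residuals with a sum over the mask. Whether the real kernel includes a 0.5 factor or normalizes by the weight sum is not stated, and any other weighting values are not modelled. ls_xray_reference is a hypothetical name.

```python
def ls_xray_reference(F_obs, F_calc, sigma, mask):
    """Sketch of weighting='sigma', assuming w = 1 / sigma**2 and a masked sum."""
    return sum(
        (fo - fc) ** 2 / (s * s)  # sigma-weighted squared residual
        for fo, fc, s, keep in zip(F_obs, F_calc, sigma, mask)
        if keep
    )


print(ls_xray_reference([12.0, 3.0], [10.0, 0.0], [2.0, 1.0], [True, False]))
```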

torchref.base.targets.ml_xray_loss_math(F_obs, F_calc, sigma, centric_flags, mask)

Maximum-likelihood X-ray loss on already-scaled amplitudes.

Matches MaximumLikelihoodXrayTarget.forward lines 37-84.

Dispatches to torchref.base.targets.triton.xray_ml.ml_xray_loss_math_triton() on CUDA float32; falls back to the eager implementation otherwise.

Parameters:
  • F_obs (torch.Tensor) – (N,) observed amplitudes (zeros outside mask).

  • F_calc (torch.Tensor) – (N,) scaled calculated amplitudes (already real-valued, zeros outside mask).

  • sigma (torch.Tensor) – (N,) per-reflection sigma (ones outside mask).

  • centric_flags (torch.Tensor or None) – (N,) bool, True for centric reflections. None is treated as all-acentric.

  • mask (torch.Tensor) – (N,) bool work-set mask applied to the final sum.

torchref.base.targets.nonbonded_heavy_math(xyz, indices, min_distances, symop_indices, cell_offsets, symop_matrices, symop_translations, fractional_matrix, inv_fractional_matrix, c_rep, r_exp, buffer, sigma_vdw)

Heavy-heavy VDW prolsq repulsion NLL.

Matches the prolsq branch of NonBondedTarget.forward and the symmetry-aware pair-position gather from NonBondedTarget._compute_positions. The H-VDW contribution added by NonBondedHTarget is excluded — see nonbonded_h (TBD) for that.

Dispatches to torchref.base.targets.triton.nonbonded_heavy_math_triton() on CUDA float32 (~1.4× faster fwd+bw on A100; the forward kernel is memory-bound, so most of the speed-up comes from the analytic backward). Falls back to eager otherwise.

Parameters:
  • xyz (torch.Tensor) – (N_atoms, 3) Cartesian coordinates of the ASU.

  • indices (torch.Tensor) – (N, 2) per-pair atom indices.

  • min_distances (torch.Tensor) – (N,) VDW threshold per pair.

  • symop_indices (torch.Tensor, optional) – (N,) symmetry-operator index per pair; 0 = identity.

  • cell_offsets (torch.Tensor, optional) – (N, 3) fractional cell offsets per pair.

  • symop_matrices (torch.Tensor, optional) – (n_symops, 3, 3) rotation parts of the symmetry-operator table.

  • symop_translations (torch.Tensor, optional) – (n_symops, 3) translation parts of the symmetry-operator table.

  • fractional_matrix (torch.Tensor) – (3, 3) cell.fractional_matrix.

  • inv_fractional_matrix (torch.Tensor) – (3, 3) inverse of cell.fractional_matrix.

  • c_rep (torch.Tensor) – Scalar repulsion coefficient.

  • r_exp (torch.Tensor) – Scalar repulsion exponent.

  • sigma_vdw (torch.Tensor) – Scalar effective tolerance.

  • buffer (float) – Distance buffer in Å.
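A sketch of the symmetry-aware pair-position gather, under two assumptions this page does not confirm: the symmetry operators act on fractional coordinates, and the operator plus cell offset are applied to the second atom of each pair. The position is converted to fractional coordinates, rotated and translated, shifted by the integer cell offset, and converted back. The prolsq repulsion itself is not sketched because its functional form is not given here. Both function names are hypothetical.

```python
def matvec(m, v):
    """3x3 matrix (nested lists) times a 3-vector."""
    return [sum(m[r][k] * v[k] for k in range(3)) for r in range(3)]


def symmetry_mate_position(x, rot, trans, cell_offset,
                           fractional_matrix, inv_fractional_matrix):
    """Sketch: Cartesian position of a symmetry mate of atom x.

    Assumes x_sym = F^-1 (R (F x) + t + offset) with F = fractional_matrix,
    i.e. symmetry operators defined in fractional coordinates.
    """
    frac = matvec(fractional_matrix, x)  # Cartesian -> fractional
    frac_sym = [r + t + o for r, t, o in zip(matvec(rot, frac), trans, cell_offset)]
    return matvec(inv_fractional_matrix, frac_sym)  # fractional -> Cartesian


# Cubic 10 Å cell, identity operator, one cell along a.
F = [[0.1, 0.0, 0.0], [0.0, 0.1, 0.0], [0.0, 0.0, 0.1]]
F_inv = [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]
I3 = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
print(symmetry_mate_position([1.0, 2.0, 3.0], I3, [0.0, 0.0, 0.0], [1.0, 0.0, 0.0], F, F_inv))
```

Under these assumptions the pair distance would then be math.dist(xyz[i], symmetry_mate_position(xyz[j], ...)), and symop index 0 (identity rotation, zero translation) with a zero offset returns the atom's own ASU position.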

torchref.base.targets.planarity_math(xyz, plane_groups)

Planarity NLL summed over plane-size buckets.

Matches PlanarityTarget.forward. Each entry of plane_groups is (indices, sigmas) where indices has shape (P, n_atoms) with n_atoms > 3 (3-atom planes have zero deviation by construction and are skipped by the caller).

The plane normal is computed in float64 via SVD and detached: gradients flow through the deviation projection but not through the SVD.

Dispatches to torchref.base.targets.triton.planarity.planarity_math_triton() on CUDA float32 (a fused per-bucket gather/centroid/project/NLL kernel with an analytic backward). Falls back to eager otherwise.

torchref.base.targets.ramachandran_math(xyz, phi_idx, psi_idx, nll_surfaces, surface_type)

Ramachandran bilinear-interpolated NLL.

Mirrors RamachandranTarget.forward.

Dispatches to torchref.base.targets.triton.ramachandran_math_triton() on CUDA float32 (~10× faster fwd+bw on A100). Falls back to eager otherwise.

Parameters:
  • xyz (torch.Tensor) – (N_atoms, 3) Cartesian coordinates.

  • phi_idx (torch.Tensor) – (N, 4) atom indices defining the backbone dihedral φ.

  • psi_idx (torch.Tensor) – (N, 4) atom indices defining the backbone dihedral ψ.

  • nll_surfaces (torch.Tensor) – (n_surface_types, 360, 360) precomputed NLL = -log P(φ, ψ | type).

  • surface_type (torch.Tensor) – (N,) integer type per residue.
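A sketch of the periodic bilinear lookup that turns a (φ, ψ) pair into an NLL value from one 360 x 360 surface (for residue r, presumably nll_surfaces[surface_type[r]]). The grid convention here, 1-degree bins with index 0 at -180° and rows indexed by φ, is an assumption; the key point is that neighbour indices wrap modulo the grid size so the interpolation is continuous across ±180°. bilinear_periodic is a hypothetical name.

```python
import math


def bilinear_periodic(surface, phi_deg, psi_deg):
    """Sketch: periodic bilinear lookup into an n x n NLL grid.

    Assumes bins of 360/n degrees with index 0 at -180 degrees on both axes
    and surface[i][j] indexed as [phi][psi].
    """
    n = len(surface)
    u = (phi_deg + 180.0) % 360.0 * n / 360.0  # continuous phi index
    v = (psi_deg + 180.0) % 360.0 * n / 360.0  # continuous psi index
    i0, j0 = int(math.floor(u)) % n, int(math.floor(v)) % n
    i1, j1 = (i0 + 1) % n, (j0 + 1) % n  # wrap across +180 / -180
    fu, fv = u - math.floor(u), v - math.floor(v)
    return (
        (1 - fu) * (1 - fv) * surface[i0][j0]
        + fu * (1 - fv) * surface[i1][j0]
        + (1 - fu) * fv * surface[i0][j1]
        + fu * fv * surface[i1][j1]
    )


surf = [[float(i)] * 360 for i in range(360)]  # value depends on phi bin only
print(bilinear_periodic(surf, -169.5, 0.0))    # halfway between bins 10 and 11
```

At φ = 179.5° the lookup blends the last bin (359) with bin 0, which is the wrap-around that a non-periodic interpolator would get wrong.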

torchref.base.targets.torsion_omega_math(xyz, idx, sigmas_deg, is_proline, w_cis_proline=0.05, w_cis_general=0.0005)

Omega cis/trans mixture NLL.

Mirrors targets.geometry.torsions._omega_mixture_nll plus the omega-angle computation. Models each ω as a 2-component mixture: w_trans VM(ω; π, κ) + w_cis VM(ω; 0, κ).

Dispatches to torchref.base.targets.triton.torsion_omega_math_triton() on CUDA float32 (~5.6× faster fwd+bw on A100). Falls back to eager otherwise.
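A per-residue sketch of the cis/trans mixture above. Several details are assumptions: κ = 1 / σ² with σ converted from degrees to radians, w_trans = 1 - w_cis, and w_cis chosen by is_proline. Because both von Mises components share κ, their normalizer log(2π I0(κ)) is a common additive constant in ω and is dropped. The two log-weighted components are combined with a log-sum-exp for numerical stability. omega_mixture_nll is a hypothetical name.

```python
import math


def omega_mixture_nll(omega_rad, sigma_deg, is_proline,
                      w_cis_proline=0.05, w_cis_general=0.0005):
    """Sketch: -log(w_trans VM(omega; pi, k) + w_cis VM(omega; 0, k)).

    Assumes kappa = 1 / sigma_rad**2 and w_trans = 1 - w_cis; drops the
    shared normalizer log(2 pi I0(kappa)), which is constant in omega.
    """
    kappa = 1.0 / math.radians(sigma_deg) ** 2
    w_cis = w_cis_proline if is_proline else w_cis_general
    w_trans = 1.0 - w_cis
    # Unnormalized log-densities of the trans (pi) and cis (0) components.
    a = math.log(w_trans) + kappa * math.cos(omega_rad - math.pi)
    b = math.log(w_cis) + kappa * math.cos(omega_rad)
    m = max(a, b)  # log-sum-exp shift
    return -(m + math.log(math.exp(a - m) + math.exp(b - m)))


print(omega_mixture_nll(0.0, 5.0, True), omega_mixture_nll(0.0, 5.0, False))
```

With the default weights a cis peptide costs roughly log(0.05 / 0.0005) ≈ 4.6 nats less before proline than elsewhere, under the assumptions above.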

torchref.base.targets.torsion_unimodal_math(deviations_rad, sigmas_deg)

Unimodal von Mises NLL on already-wrapped deviations.

Mirrors targets.geometry.torsions._von_mises_nll. The caller is expected to compute wrapped angular deviations beforehand (currently done inside Restraints.torsion_deviations_with_sigmas).

No Triton dispatch yet — the periodic-wrap logic upstream is still in eager Python.
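A sketch of a unimodal von Mises term on already-wrapped deviations δ. The von Mises NLL is -κ cos δ + log(2π I0(κ)); writing it as κ(1 - cos δ) shifts it by the δ-independent constant log(2π I0(κ)) - κ, so the sketch is zero at δ = 0 and approaches 0.5 (δ/σ)² for small δ. The mapping κ = 1 / σ² from sigmas_deg is an assumption, and von_mises_nll_reference is a hypothetical name.

```python
import math


def von_mises_nll_reference(deviations_rad, sigmas_deg):
    """Sketch: sum of kappa * (1 - cos(delta)) with kappa = 1 / sigma_rad**2.

    Equal to the von Mises NLL up to the additive constant
    log(2 pi I0(kappa)) - kappa; the kappa mapping is an assumption.
    """
    total = 0.0
    for delta, s_deg in zip(deviations_rad, sigmas_deg):
        kappa = 1.0 / math.radians(s_deg) ** 2
        total += kappa * (1.0 - math.cos(delta))
    return total


print(von_mises_nll_reference([0.01], [10.0]))
```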
