torchref package
TorchRef - A PyTorch-based crystallographic refinement library.
TorchRef provides GPU-accelerated crystallographic structure refinement using PyTorch’s automatic differentiation and nn.Module architecture.
Key Features
Native PyTorch integration with nn.Module architecture
Automatic differentiation for custom target functions
GPU acceleration for structure factor calculations
Modular design for easy extension
Quick Start
from torchref import Refinement, ReflectionData, Model
# Load data and model
data = ReflectionData().load_mtz('data.mtz')
model = Model().load_pdb('structure.pdb')
# Run refinement
refinement = Refinement(data=data, model=model, device='cuda')
refinement.run_refinement(macro_cycles=10)
Modules
- io
File I/O for MTZ, PDB, CIF formats.
- model
Atomic structure models (coordinates, B-factors, occupancies).
- refinement
Core refinement framework with targets and weighting schemes.
- restraints
Geometry restraints (bonds, angles, torsions, planes). Initialized lazily because it requires downloading the monomer library.
- scaling
Structure factor scaling and bulk solvent models.
- symmetry
Crystallographic symmetry operations.
- alignment
Patterson-based structure alignment.
- math_functions
Mathematical utilities for crystallography.
- utils
General utilities and debugging tools.
- class torchref.ReflectionData(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, _centric=None, _n_bins=None, _FrenchWilson=None, source=None, dataset=None, last_op=None, reader=None)[source]
Bases: CrystalDataset, DebugMixin
Container for crystallographic reflection data.
This class handles loading, processing, and accessing reflection data including Miller indices, structure factor amplitudes, intensities, and R-free flags. All data is stored as PyTorch tensors for GPU acceleration.
- Parameters:
- hkl
Miller indices of shape (N, 3), dtype int32.
- Type:
- F
Structure factor amplitudes of shape (N,), dtype float32.
- Type:
- F_sigma
Amplitude uncertainties of shape (N,), dtype float32.
- Type:
- I
Intensities of shape (N,), dtype float32.
- Type:
- I_sigma
Intensity uncertainties of shape (N,), dtype float32.
- Type:
- rfree_flags
R-free test set flags of shape (N,), dtype bool.
- Type:
- cell
Unit cell parameters [a, b, c, alpha, beta, gamma].
- Type:
- resolution
Resolution per reflection in Ångströms of shape (N,).
- Type:
Examples
Load reflection data from an MTZ file:
data = ReflectionData(verbose=1, device='cuda')
data.load_mtz('data.mtz')
print(f"Loaded {len(data.hkl)} reflections")
print(f"Resolution range: {data.resolution.min():.2f} - {data.resolution.max():.2f} Å")
- source: ReflectionData | None = None
- __post_init__()[source]
Initialize non-dataclass attributes after dataclass init.
This is called automatically after the dataclass __init__.
- load(reader)[source]
Load reflection data using a data reader.
- Parameters:
reader (callable) – Data reader object that returns (data_dict, cell, spacegroup) when called. Can be MTZ, ReflectionCIFReader, or other compatible reader.
- Returns:
Self, for method chaining.
- Return type:
- Raises:
ValueError – If unit cell parameters are missing or no amplitude/intensity data found.
- classmethod from_tensors(hkl, F, F_sigma, cell, spacegroup, rfree_flags=None, device=device(type='cpu'), verbose=1)[source]
Construct ReflectionData directly from tensors.
- Parameters:
hkl (torch.Tensor) – Miller indices of shape (N, 3).
F (torch.Tensor) – Structure factor amplitudes of shape (N,).
F_sigma (torch.Tensor) – Amplitude uncertainties of shape (N,).
cell (Cell) – Unit cell parameters.
spacegroup (SpaceGroup) – Space group.
rfree_flags (torch.Tensor, optional) – R-free flags of shape (N,), dtype bool. If None, flags are generated automatically (2% free fraction).
device (str, optional) – Device for tensors. Defaults to the configured device.current.
verbose (int, optional) – Verbosity level. Default is 1.
- Returns:
Fully initialized reflection data with all cleanup applied.
- Return type:
- load_mtz(path, column_names=None)[source]
Load reflection data from MTZ file.
- Parameters:
- Returns:
Self, for method chaining.
- Return type:
- load_cif(path, data_block=None)[source]
Load reflection data from CIF file.
- Parameters:
- Returns:
Self, for method chaining.
- Return type:
- static list_cif_data_blocks(path)[source]
List all data blocks available in a CIF file without loading data.
Useful for multi-dataset CIF files to inspect available blocks before loading a specific one.
- Parameters:
path (str) – Path to CIF file.
- Returns:
Names of all data blocks in the CIF file.
- Return type:
Examples
List and load a specific data block:
blocks = ReflectionData.list_cif_data_blocks('1VLM-sf.cif')
print(blocks)  # ['r1vlmsf', 'r1vlmAsf', 'r1vlmBsf', ...]
data = ReflectionData().load_cif('1VLM-sf.cif', data_block=blocks[1])
- get_bins(n_bins=20, min_per_bin=100)[source]
Create resolution bins with approximately equal reflection counts.
- Parameters:
- Returns:
bin_indices (torch.Tensor) – Tensor of shape (N,) with bin index for each reflection.
n_bins (int) – Actual number of bins created (may be less than target for small datasets).
- Return type:
- mean_res_per_bin()[source]
Calculate mean resolution for each bin.
- Returns:
Mean resolution for each bin in Ångströms.
- Return type:
- Raises:
ValueError – If bins have not been created yet.
- mean_F_per_bin()[source]
Calculate mean structure factor amplitude per resolution bin.
- Returns:
Mean F per bin of shape (n_bins,).
- Return type:
- Raises:
ValueError – If bins have not been created yet.
- mean_sigma_per_bin()[source]
Calculate mean structure factor uncertainty per resolution bin.
- Returns:
Mean sigma_F per bin of shape (n_bins,), or None if no uncertainties.
- Return type:
torch.Tensor or None
- Raises:
ValueError – If bins have not been created yet.
- regenerate_rfree_flags(free_fraction=0.02, n_bins=20, min_per_bin=100, seed=None, force=False)[source]
Regenerate R-free flags with resolution-stratified sampling.
- Parameters:
free_fraction (float, optional) – Fraction of reflections to mark as free. Default is 0.02 (2%).
n_bins (int, optional) – Target number of resolution bins. Default is 20.
min_per_bin (int, optional) – Minimum reflections per resolution bin. Default is 100.
seed (int, optional) – Random seed for reproducibility. Default is None.
force (bool, optional) – If True, regenerate even if flags already exist. Default is False.
Examples
Generate R-free flags:
# Generate 2% free reflections with reproducible seed
data.regenerate_rfree_flags(free_fraction=0.02, n_bins=20, seed=42)
# Generate 5% free with 10 bins, overwriting existing
data.regenerate_rfree_flags(free_fraction=0.05, n_bins=10, force=True)
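The idea behind resolution-stratified sampling can be sketched in plain Python. This is an illustration of the general technique, not the library's implementation: the function name `stratified_free_flags` and the convention that True marks a free reflection are assumptions made for this example, and the d-spacings are made up.

```python
import random

def stratified_free_flags(resolution, free_fraction=0.1, n_bins=4, seed=0):
    """Return a list of booleans (True = free reflection, by assumption here).

    Reflections are sorted by resolution, split into roughly equal-count
    bins, and the same fraction is drawn at random from every bin.
    """
    rng = random.Random(seed)
    order = sorted(range(len(resolution)), key=lambda i: resolution[i])
    flags = [False] * len(resolution)
    bin_size = -(-len(order) // n_bins)  # ceiling division
    for start in range(0, len(order), bin_size):
        members = order[start:start + bin_size]
        n_free = max(1, round(free_fraction * len(members)))
        for i in rng.sample(members, n_free):
            flags[i] = True
    return flags

# Hypothetical d-spacings (Å) for 200 reflections between 1.5 and 20 Å.
d = [1.5 + 18.5 * i / 199 for i in range(200)]
free = stratified_free_flags(d, free_fraction=0.1, n_bins=4, seed=42)
print(sum(free), "free reflections out of", len(free))
```

Drawing per bin keeps the free set spread evenly over resolution, which a single global random draw does not guarantee for small datasets.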
- get_structure_factors(as_complex=False)[source]
Get structure factors, optionally as complex numbers.
- Parameters:
as_complex (bool, optional) – If True and phases available, return F*exp(i*phi). Default is False.
- Returns:
Structure factor amplitudes or complex structure factors.
- Return type:
- Raises:
ValueError – If no amplitude data is loaded.
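As a rough illustration of what `as_complex=True` computes, the following sketch combines amplitudes and phases into complex structure factors, F·exp(iφ). It uses the standard library rather than torch tensors, and it assumes phases are given in degrees (as is common in MTZ files); the library may store them differently. The helper name `to_complex` is hypothetical.

```python
import cmath
import math

def to_complex(amplitudes, phases_deg):
    """Combine amplitudes and phases (assumed in degrees) into F * exp(i*phi)."""
    return [cmath.rect(f, math.radians(p)) for f, p in zip(amplitudes, phases_deg)]

F = [10.0, 5.0, 2.0]
phi = [0.0, 90.0, 180.0]
F_complex = to_complex(F, phi)
for value in F_complex:
    print(f"{value.real:8.3f} {value.imag:8.3f}")
```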
- get_structure_factors_with_sigma()[source]
Get structure factor amplitudes and their uncertainties.
- Returns:
F (torch.Tensor) – Structure factor amplitudes of shape (N,).
F_sigma (torch.Tensor or None) – Uncertainties of shape (N,), or None if not available.
- Raises:
ValueError – If no amplitude data is loaded.
- Return type:
Examples
Get amplitudes with uncertainties:
F_obs, sigma_F = data.get_structure_factors_with_sigma()
if sigma_F is not None:
    # F_calc: calculated amplitudes of the same shape, obtained elsewhere
    weighted_residual = (F_obs - F_calc) / sigma_F
- get_hkl()[source]
Return Miller indices for valid reflections.
- Returns:
Miller indices of shape (N, 3), dtype int32.
- Return type:
- Raises:
ValueError – If no Miller indices are loaded.
- filter_by_resolution(d_min=None, d_max=None)[source]
Filter reflections by resolution range.
Adds a boolean mask to self.masks for the specified resolution range.
- Parameters:
- Returns:
Self, for method chaining.
- Return type:
- get_mask()[source]
Return combined mask from all active filters.
- Returns:
Boolean mask combining all filter conditions.
- Return type:
- cut_res(highres=None, lowres=None)[source]
Filter reflections by resolution range.
Alias for filter_by_resolution with more intuitive naming.
- Parameters:
- Returns:
Self, for method chaining.
- Return type:
Examples
Filter by resolution:
# Keep reflections between 50 Å and 1.5 Å
filtered = data.cut_res(highres=1.5, lowres=50.0)
# Keep only high-resolution data (between 2.0 Å and 1.0 Å)
high_res = data.cut_res(highres=1.0, lowres=2.0)
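The resolution cut amounts to a boolean mask over d-spacings. The sketch below shows that logic in plain Python; the function name `resolution_mask` and the choice of inclusive bounds are assumptions for illustration and may not match the library's behavior at the boundaries.

```python
def resolution_mask(d_spacings, highres=None, lowres=None):
    """True where highres <= d <= lowres; a bound of None is not applied."""
    mask = []
    for d in d_spacings:
        keep = True
        if highres is not None and d < highres:
            keep = False  # finer than the high-resolution limit
        if lowres is not None and d > lowres:
            keep = False  # coarser than the low-resolution limit
        mask.append(keep)
    return mask

d = [60.0, 25.0, 3.2, 1.8, 1.2]  # hypothetical d-spacings in Å
mask = resolution_mask(d, highres=1.5, lowres=50.0)
print(mask)
```

Note that "high resolution" corresponds to small d-spacing, so `highres` is the lower numerical bound.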
- get_rfree_masks()[source]
Get boolean masks for work and test (free) sets.
- Returns:
work_mask (torch.Tensor or None) – Boolean tensor for work set (flag != 0).
test_mask (torch.Tensor or None) – Boolean tensor for test/free set (flag == 0). Both are None if no R-free flags are available.
- Return type:
Examples
Separate work and test sets:
work_mask, test_mask = data.get_rfree_masks()
if work_mask is not None:
    F_work = data.F[work_mask]
    F_test = data.F[test_mask]
- get_work_set()[source]
Get structure factors for the work set (R-free flag != 0).
- Returns:
F_work (torch.Tensor) – Structure factors for work set.
sigma_work (torch.Tensor or None) – Uncertainties for work set, or None if not available.
- Return type:
Notes
Returns full dataset with warning if no R-free flags available.
- get_test_set()[source]
Get structure factors for the test set (R-free flag == 0).
- Returns:
F_test (torch.Tensor) – Structure factors for test/free set.
sigma_test (torch.Tensor or None) – Uncertainties for test set, or None if not available.
- Raises:
ValueError – If no R-free flags are available.
- Return type:
- get_max_res()[source]
Return maximum resolution (lowest d-spacing).
- Returns:
Maximum resolution in Ångströms.
- Return type:
- get_min_res()[source]
Return minimum resolution (highest d-spacing).
- Returns:
Minimum resolution in Ångströms.
- Return type:
- __len__()[source]
Return number of reflections.
- Returns:
Number of reflections in the dataset.
- Return type:
- property d_min: float | None
Return maximum resolution (lowest d-spacing).
- Returns:
Maximum resolution in Ångströms.
- Return type:
- __repr__()[source]
Return string representation.
- Returns:
Summary of reflection data including count, sources, and resolution.
- Return type:
- get_valid_mask()[source]
Return the combined validity mask for all reflections.
This is the mask used to filter reflections in forward(). True indicates a valid (included) reflection, False indicates an excluded one.
- Returns:
Boolean mask of shape (N,) where N is the total number of reflections. True = valid/included, False = invalid/excluded.
- Return type:
Examples
Check validity:
mask = data.get_valid_mask()
print(f"{mask.sum()} of {len(mask)} reflections are valid")
- data_indexed()[source]
Return reflection data as indexed (filtered) tensors.
This method filters out invalid reflections and returns smaller tensors containing only valid data. Useful for operations that don’t support MaskedTensors or for writing output files.
- Returns:
hkl (torch.Tensor) – Miller indices of shape (M, 3) where M is number of valid reflections.
F (torch.Tensor) – Structure factor amplitudes of shape (M,).
F_sigma (torch.Tensor or None) – Uncertainties of shape (M,) or None.
rfree_flags (torch.Tensor or None) – R-free flags of shape (M,) or None.
- Return type:
See also
forward: Main method returning MaskedTensors.
Examples
Get indexed data for file writing:
hkl, F, sigma, rfree = data.data_indexed()
F_np = F.cpu().numpy()  # Safe for writing to files
- __call__(mask=True, scale=True)[source]
Return core reflection data with MaskedTensors for F and sigma.
F and F_sigma are returned as MaskedTensors which keep all reflections but mark invalid ones as masked. Aggregation operations (sum, mean, etc.) automatically skip masked values. HKL and rfree_flags remain regular tensors.
- Parameters:
- Returns:
hkl (torch.Tensor) – Miller indices of shape (N, 3). Full size, unfiltered.
F (MaskedTensor) – Structure factor amplitudes of shape (N,) with invalid reflections masked.
F_sigma (MaskedTensor or None) – Uncertainties of shape (N,) with invalid reflections masked, or None.
rfree_flags (torch.Tensor or None) – R-free flags of shape (N,) or None. Full size, unfiltered. 1=work, 0=free.
- Return type:
Notes
MaskedTensors:
Are PyTorch tensors with an associated boolean mask
Aggregations (sum, mean, etc.) skip masked values automatically
Element-wise operations preserve the mask
Use .get_data() and .get_mask() to access underlying data
Use .to_tensor(fill_value) to convert back to regular tensor
Note: MaskedTensor is in prototype stage in PyTorch
Loss functions and targets extract valid data from MaskedTensors before computation to work correctly with complex F_calc values.
Examples
Access reflection data with MaskedTensors:
hkl, F, sigma, rfree = data()
print(F.shape)  # Full shape
print(F.sum())  # Only sums valid (unmasked) values
# Access underlying data
valid_mask = F.get_mask()
F_values = F.get_data()[valid_mask]
- data_fill_masked(mode='mean')[source]
Return data tensors with missing or flagged reflections filled with a substitute value.
- Parameters:
mode (str, optional) – Fill strategy. ‘mean’ fills missing/flagged values with the mean of the present data; ‘zero’ fills them with zero. Default is ‘mean’.
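The two fill modes can be illustrated with a small stand-alone function. This is a sketch of the idea only; `fill_masked` and its list-based signature are hypothetical and do not reflect the method's actual return values.

```python
def fill_masked(values, valid, mode="mean"):
    """Replace entries where valid is False by the mean of valid entries or by zero."""
    if mode == "mean":
        present = [v for v, ok in zip(values, valid) if ok]
        fill = sum(present) / len(present) if present else 0.0
    elif mode == "zero":
        fill = 0.0
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    return [v if ok else fill for v, ok in zip(values, valid)]

F = [10.0, float("nan"), 30.0, 20.0]   # second reflection is missing
valid = [True, False, True, True]
filled_mean = fill_masked(F, valid, mode="mean")
filled_zero = fill_masked(F, valid, mode="zero")
print(filled_mean, filled_zero)
```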
- __getitem__(key)[source]
Index into the reflection dataset.
- Parameters:
key (torch.Tensor) – Boolean mask or integer indices for selection.
- Returns:
New ReflectionData object with selected reflections.
- Return type:
- __select__(indices, op=None)[source]
Select reflections by boolean mask or integer indices.
Iterates over all dataclass fields generically, so new tensor fields are handled automatically.
- Parameters:
indices (torch.Tensor) – Boolean mask of shape (N,) or integer indices for selection.
op (str, optional) – Operation name for tracking purposes.
- Returns:
New ReflectionData object with selected reflections.
- Return type:
- sanitize_F()[source]
Remove invalid values from structure factors.
Adds a mask to filter out NaN, Inf, and non-positive values from F and F_sigma.
- validate_hkl(hkl_ref)[source]
Expand dataset to match a reference HKL set.
Reorders and expands the current dataset to match the reference HKL set. Reflections present in the reference but missing from this dataset are filled with placeholder values and masked out. This ensures all datasets aligned to the same reference have identical shapes and can be processed together without data loss from intersection operations.
- Parameters:
hkl_ref (torch.Tensor) – Reference Miller indices tensor of shape (N, 3), dtype int32. This defines the canonical HKL ordering for all aligned datasets.
- Returns:
Self, modified in-place with expanded arrays matching hkl_ref.
- Return type:
Notes
After calling this method:
self.hkl will equal hkl_ref exactly
All data arrays (F, F_sigma, rfree_flags, etc.) are reordered/expanded
Missing reflections are filled with 0 (or appropriate defaults)
A mask ‘hkl_present’ is added marking which reflections have real data
forward() will return MaskedTensors that skip missing reflections
This approach avoids the problem where intersecting many datasets with different outliers/missing reflections causes exponential data loss.
Examples
Align multiple datasets to a common HKL set:
reference_hkl = data1.hkl.clone()
data1.validate_hkl(reference_hkl)
data2.validate_hkl(reference_hkl)
# Now data1 and data2 have identical shapes
assert data1.hkl.shape == data2.hkl.shape
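The reorder-and-expand step can be pictured as a lookup from each reference index into the dataset, with a presence mask for reflections that are absent. The following is a minimal sketch of that idea in plain Python; `align_to_reference` is a hypothetical helper, not the library's code.

```python
def align_to_reference(hkl, F, hkl_ref, fill=0.0):
    """Reorder F to follow hkl_ref; absent reflections get `fill` and present=False."""
    lookup = {tuple(h): i for i, h in enumerate(hkl)}
    F_aligned, present = [], []
    for h in hkl_ref:
        i = lookup.get(tuple(h))
        present.append(i is not None)
        F_aligned.append(F[i] if i is not None else fill)
    return F_aligned, present

hkl_ref = [(1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 0)]
hkl = [(0, 0, 1), (1, 0, 0), (1, 1, 0)]   # different order, (0, 1, 0) missing
F = [3.0, 1.0, 4.0]
F_aligned, present = align_to_reference(hkl, F, hkl_ref)
print(F_aligned, present)
```

Because every dataset keeps the full reference length, no reflection is dropped just because one other dataset lacks it.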
- find_outliers(model, scaler, z_threshold=4.0)[source]
Identify outlier reflections based on log-ratio distribution.
Uses the fact that log(F_obs) - log(F_calc) should be normally distributed. Outliers are reflections where |log_ratio - mean| > z_threshold * std_dev.
- Parameters:
- Returns:
Boolean mask where True indicates outliers.
- Return type:
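The log-ratio criterion described above can be sketched as follows. The example uses plain Python lists and a population standard deviation; whether the library uses a population or sample estimate, and how it applies the scaler, is not stated here, so treat `log_ratio_outliers` as an illustrative assumption.

```python
import math
import statistics

def log_ratio_outliers(F_obs, F_calc, z_threshold=4.0):
    """Flag reflections whose log(F_obs) - log(F_calc) deviates from the mean
    by more than z_threshold standard deviations."""
    log_ratio = [math.log(o) - math.log(c) for o, c in zip(F_obs, F_calc)]
    mean = statistics.fmean(log_ratio)
    std = statistics.pstdev(log_ratio)
    return [abs(r - mean) > z_threshold * std for r in log_ratio]

# Hypothetical data: 49 reflections within 5% of F_calc and one gross outlier.
F_calc = [100.0] * 50
F_obs = [100.0 * (1.05 if i % 2 else 0.95) for i in range(49)] + [1000.0]
outliers = log_ratio_outliers(F_obs, F_calc)
print(sum(outliers), "outlier(s) flagged")
```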
- get_log_ratio(model, scaler)[source]
Compute log-ratio between observed and calculated structure factors.
- Parameters:
- Returns:
Log-ratio values: log(F_obs) - log(F_calc).
- Return type:
- get_outlier_statistics()[source]
Get statistics about flagged outliers.
- Returns:
Dictionary containing:
n_outliers : int
n_total : int
fraction_outliers : float
detection_params : dict or None
outlier_resolution_stats : dict (if resolution available)
- Return type:
- unpack_one()[source]
Unpack one level of source.
Does not recurse fully and does not flag.
- Returns:
Parent source or self if no source.
- Return type:
- flag_suspicious_sigma(z_threshold=5.0)[source]
Flag sigma values that deviate significantly from expected distribution.
Sigma values from a detector should follow a log-normal distribution. Values with z-scores beyond threshold are flagged as suspicious.
- Parameters:
z_threshold (float, optional) – Z-score threshold for flagging outliers. Default is 5.0.
- dump()[source]
Dump all reflection data to console for debugging.
Prints type, shape, and device information for all attributes.
- write_mtz(fname, fcalc=None, model_ft=None)[source]
Write reflection data to MTZ file with optional map coefficients.
- Parameters:
fname (str) – Output MTZ filename.
fcalc (torch.Tensor, optional) – Complex calculated structure factors of shape (N,). If provided, computes phases and map coefficients.
model_ft (ModelFT, optional) – ModelFT object to compute fcalc if not provided.
Notes
- The MTZ file will contain canonical column names:
FP, SIGFP: Observed amplitudes and uncertainties
I, SIGI: Observed intensities and uncertainties (if available)
FreeR_flag: R-free test set flags
FWT, PHWT: 2mFo-DFc map coefficients (if fcalc provided)
DELFWT, PHDELWT: mFo-DFc map coefficients (if fcalc provided)
- Map coefficients are computed as:
2mFo-DFc: 2*Fo - Fc (filled to resolution limit)
mFo-DFc: Fo - Fc
Examples
Write MTZ with map coefficients:
data = ReflectionData().load_mtz('observed.mtz')
model = Model().load_pdb('model.pdb')
model_ft = ModelFT(model, data.cell, data.spacegroup)
fcalc = model_ft.forward(data.hkl)
data.write_mtz('output.mtz', fcalc=fcalc)
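The map coefficient formulas in the Notes can be worked through by hand. The sketch below follows the simplified expressions given there (2·Fo − Fc and Fo − Fc, that is, with the weights m and D taken as 1) and attaches the phase of the calculated structure factor. The helper `map_coefficients` is hypothetical and only illustrates the arithmetic.

```python
import cmath

def map_coefficients(F_obs, F_calc):
    """Return (2Fo-Fc, Fo-Fc) complex coefficients using phases from F_calc."""
    two_fo_fc, fo_fc = [], []
    for fo, fc in zip(F_obs, F_calc):
        phase = cmath.phase(fc)   # calculated phase in radians
        amp_c = abs(fc)           # calculated amplitude
        two_fo_fc.append(cmath.rect(2 * fo - amp_c, phase))
        fo_fc.append(cmath.rect(fo - amp_c, phase))
    return two_fo_fc, fo_fc

F_obs = [12.0, 8.0]                          # observed amplitudes
F_calc = [complex(10, 0), complex(0, 10)]    # calculated structure factors
fwt, delfwt = map_coefficients(F_obs, F_calc)
print(fwt, delfwt)
```

A negative difference amplitude, as in the second reflection, simply points the coefficient opposite to the calculated phase.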
- property centric
Get boolean mask for centric reflections (full size, unfiltered).
Calculates it if not already present. Returns unfiltered centric flags matching the full HKL array size, consistent with how forward() returns full-size arrays when using MaskedTensors.
- Returns:
Boolean tensor of shape (N,) where N is total reflections. True indicates centric reflection, False indicates acentric.
- Return type:
torch.Tensor or None
- calc_patterson(grid_size=None, grid_sampling=1)[source]
Calculate Patterson map of the dataset.
The Patterson function P(u,v,w) = Σ|F(hkl)|² exp(-2πi(hu+kv+lw)) is computed via inverse FFT of F². Data is expanded to P1 symmetry using only observed reflections (no filling of missing data).
- Parameters:
- Returns:
Real-valued Patterson map of shape (Nx, Ny, Nz). Origin is at grid position [0, 0, 0].
- Return type:
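To make the Patterson formula concrete, here is a one-dimensional toy example evaluated by direct summation rather than by FFT (the method itself is described as using an inverse FFT in 3D). With |F(h)|² equal for h and −h, the complex exponential sum reduces to a cosine sum. The two-atom structure and all function names are assumptions for illustration.

```python
import cmath
import math

def structure_factor(h, positions):
    """F(h) for equal point atoms at fractional 1D positions."""
    return sum(cmath.exp(2j * math.pi * h * x) for x in positions)

def patterson_1d(positions, h_max=10, n_grid=20):
    """P(u) = sum_h |F(h)|^2 cos(2*pi*h*u), sampled on n_grid points in [0, 1)."""
    intensities = {h: abs(structure_factor(h, positions)) ** 2
                   for h in range(-h_max, h_max + 1)}
    return [sum(i_h * math.cos(2 * math.pi * h * g / n_grid)
                for h, i_h in intensities.items())
            for g in range(n_grid)]

positions = [0.10, 0.35]   # hypothetical two-atom structure
P = patterson_1d(positions)
origin = max(range(len(P)), key=lambda g: P[g])
cross_peaks = sorted(sorted(range(1, len(P)), key=lambda g: P[g], reverse=True)[:2])
print("origin peak at grid index", origin, "cross peaks at", cross_peaks)
```

The strongest peak sits at the origin, and the two cross peaks appear at the interatomic vectors ±0.25 (grid indices 5 and 15 on a 20-point grid), which is the property that makes Patterson maps useful without phases.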
- possible_hkl()[source]
Generate all possible HKL indices within the resolution limit.
- Returns:
Tensor of shape (M, 3) containing all possible Miller indices within the resolution limit defined by self.resolution.
- Return type:
- remap(new_hkl, index_mapping, phase_shifts=None, spacegroup=None, op_name='remap')[source]
Create new ReflectionData with remapped HKL set and data.
This is the core method for index-based transformations of reflection data. It handles remapping all tensor fields based on an index mapping, with support for missing reflections (indicated by -1 in index_mapping).
- Parameters:
new_hkl (torch.Tensor, shape (M, 3)) – New Miller indices.
index_mapping (torch.Tensor, shape (M,), dtype int64) – Maps new indices to original: new[i] = old[index_mapping[i]]. Values of -1 indicate missing reflections (filled with defaults).
phase_shifts (torch.Tensor, optional, shape (M,)) – Phase offsets to apply (e.g., from symmetry translations).
spacegroup (str, int, gemmi.SpaceGroup, or None) – New spacegroup. If None, keeps original.
op_name (str) – Operation name for provenance tracking.
- Returns:
New object with remapped data. Missing reflections get:
0.0 for F, I, phase, fom
1.0 for F_sigma, I_sigma (conservative uncertainty)
True for masks[‘missing’]
- Return type:
- fill(d_min=None)[source]
Fill missing reflections within resolution limit.
Generates all possible reflections for the current spacegroup within the resolution limit, identifies which are missing, and creates a complete dataset. Missing reflections are filled with default values.
- Parameters:
d_min (float, optional) – High resolution limit in Angstroms. If None, uses the minimum resolution from the current dataset.
- Returns:
New ReflectionData with complete set of reflections. Missing reflections have:
F, I, phase, fom = 0.0
F_sigma, I_sigma = 1.0
masks[‘missing’] = True
- Return type:
Examples
Fill to completeness:
data = ReflectionData().load_mtz('data.mtz')
data_filled = data.fill(d_min=2.0)
print(f"Original: {len(data)}, Filled: {len(data_filled)}")
- expand_to_p1(include_friedel=True, remove_absences=True)[source]
Expand reflection data from asymmetric unit to P1.
Applies all symmetry operations from the current space group to generate all symmetry-equivalent reflections. Returns a NEW ReflectionData object with expanded reflections; does not modify self.
- Parameters:
- Returns:
New ReflectionData object with:
All symmetry-equivalent reflections
spacegroup set to P1
Expanded tensor fields (F, F_sigma, I, I_sigma, phase, fom, rfree_flags)
Recalculated resolution values
Provenance tracking via source and last_op attributes
- Return type:
Notes
The expansion handles tensor fields as follows:
hkl: Symmetry operations applied, Friedel mates added, duplicates removed
F, F_sigma, I, I_sigma, fom, rfree_flags: Indexed from original
phase: Indexed + phase shift from translation component
resolution: Recalculated from expanded hkl + cell
bin_indices: Cleared (invalidated by expansion)
Examples
Expand to P1 for calculations:
# Load data in original space group
data = ReflectionData().load_mtz('data.mtz')
print(f"Original: {len(data)} reflections, {data.spacegroup}")
# Expand to P1
data_p1 = data.expand_to_p1()
print(f"Expanded: {len(data_p1)} reflections, {data_p1.spacegroup}")
# Without Friedel mates (for anomalous data)
data_p1_anom = data.expand_to_p1(include_friedel=False)
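The index part of the expansion (applying rotations, adding Friedel mates, removing duplicates) can be sketched for a simple case. The example below hard-codes the two rotation matrices of space group P2 with unique axis b; it ignores phases, translations, and systematic absences, and `expand_hkl` is a hypothetical helper rather than the library's routine.

```python
def expand_hkl(hkl, rotations, include_friedel=True):
    """Apply each rotation to each index, optionally add -h, and deduplicate."""
    expanded = set()
    for h in hkl:
        for R in rotations:
            # Row vector times rotation matrix: h' = h . R
            hp = tuple(sum(h[i] * R[i][j] for i in range(3)) for j in range(3))
            expanded.add(hp)
            if include_friedel:
                expanded.add(tuple(-x for x in hp))
    return sorted(expanded)

# Rotation parts of P2 (unique axis b): identity and a two-fold along b.
P2_ROTATIONS = [
    [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
    [[-1, 0, 0], [0, 1, 0], [0, 0, -1]],
]
general = expand_hkl([(1, 2, 3)], P2_ROTATIONS)
axial = expand_hkl([(0, 2, 0)], P2_ROTATIONS)
no_friedel = expand_hkl([(1, 2, 3)], P2_ROTATIONS, include_friedel=False)
print(general, axial, no_friedel)
```

A general reflection expands to four indices, while a reflection on the two-fold axis maps onto itself and yields only two, which is why duplicate removal is part of the expansion.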
- reduce_to_spacegroup(spacegroup, include_friedel=True, aggregation='mean')[source]
Reduce P1 reflection data to asymmetric unit of a target spacegroup.
This is the inverse of expand_to_p1(). Takes reflection data in P1 and merges symmetry-equivalent reflections into single ASU reflections using the specified aggregation function.
- Parameters:
spacegroup (str, int, or gemmi.SpaceGroup) – Target space group specification.
include_friedel (bool, default True) – If True, also merge Friedel mates when reducing.
aggregation (str, default 'mean') – Aggregation function for merging equivalent reflections:
‘mean’: Average values (default, good for amplitudes)
‘sum’: Sum values
‘first’: Take first valid value (no averaging)
- Returns:
New ReflectionData with merged reflections in the target spacegroup.
- Return type:
Notes
Field handling during reduction:
F, I: Aggregated using specified function
F_sigma, I_sigma: Propagated as sqrt(sum(sigma^2) / n) for ‘mean’
phase: Aggregated via complex averaging (handles phase wrapping)
fom: Weighted by amplitude during phase averaging
rfree_flags: OR operation (True if any equivalent is True)
Examples
Reduce symmetry-expanded data:
# Expand to P1 for calculations, then reduce back
data_p1 = data.expand_to_p1()
# ... modify F_p1 ...
data_merged = data_p1.reduce_to_spacegroup('P21')
# Reduce with sum instead of mean
data_summed = data_p1.reduce_to_spacegroup('P21', aggregation='sum')
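The field handling for the 'mean' aggregation can be illustrated for one group of equivalent reflections. This sketch follows the sigma expression quoted in the Notes, sqrt(sum(sigma²)/n), and averages phases as unit phasors so that values near the 0°/360° wrap merge correctly. It assumes phases in degrees and uses no amplitude weighting, both of which are simplifications; `merge_equivalents` is a hypothetical name.

```python
import cmath
import math

def merge_equivalents(F, sigma, phase_deg):
    """Merge one set of symmetry-equivalent observations into a single value."""
    n = len(F)
    F_mean = sum(F) / n
    sigma_merged = math.sqrt(sum(s * s for s in sigma) / n)
    # Average unit phasors so that 350 deg and 10 deg average to 0 deg, not 180 deg.
    phasor = sum(cmath.rect(1.0, math.radians(p)) for p in phase_deg)
    phase_mean = math.degrees(cmath.phase(phasor))  # in (-180, 180]
    return F_mean, sigma_merged, phase_mean

F_mean, sigma_merged, phase_mean = merge_equivalents(
    F=[10.0, 12.0], sigma=[1.0, 1.0], phase_deg=[350.0, 10.0]
)
print(F_mean, sigma_merged, round(phase_mean, 6))
```

A plain arithmetic mean of 350° and 10° would give 180°, the opposite of the correct answer, which is why the averaging is done in the complex plane.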
- canonicalize(include_friedel=True)[source]
Return new ReflectionData with HKL in standard CCP4 ASU form.
Remaps all Miller indices to the canonical CCP4 asymmetric unit representative using gemmi.ReciprocalAsu, adjusts phases accordingly, and sorts reflections lexicographically by (h, k, l).
- Parameters:
include_friedel (bool, default True) – Whether Friedel mates are considered equivalent.
- Returns:
New object with canonicalized, sorted Miller indices.
- Return type:
- get_scattering_vectors()[source]
Get scattering vectors (s-vectors) from hkl and cell.
- The s-vector for a reflection hkl is defined as:
s = B* @ hkl
where B* is the reciprocal basis matrix.
- Returns:
s_vectors – Reciprocal space vectors in Angstroms^-1, shape (N, 3).
- Return type:
- Raises:
ValueError – If hkl or cell is not available.
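For a cell with all angles equal to 90°, the reciprocal basis matrix is diagonal, B* = diag(1/a, 1/b, 1/c), which makes the definition easy to check by hand. The sketch below covers only that special case and assumes the convention |s| = 1/d (no factor of 2π); the function names are hypothetical.

```python
import math

def scattering_vectors_orthogonal(hkl, a, b, c):
    """s = B* @ hkl for an orthogonal cell, where B* = diag(1/a, 1/b, 1/c)."""
    return [(h / a, k / b, l / c) for h, k, l in hkl]

def d_spacing(s):
    """Resolution in Å from the length of a scattering vector in Å^-1."""
    return 1.0 / math.sqrt(sum(x * x for x in s))

hkl = [(1, 0, 0), (0, 2, 0), (3, 4, 0)]
s_vectors = scattering_vectors_orthogonal(hkl, a=30.0, b=40.0, c=50.0)
d = [d_spacing(s) for s in s_vectors]
print([round(x, 3) for x in d])
```

For non-orthogonal cells B* has off-diagonal terms that depend on the cell angles, so this shortcut does not apply.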
- get_radial_shells(n_shells=20, d_min=None, d_max=None)[source]
Create uniform radial shells in 1/d space for normalization.
This creates shells with uniform spacing in reciprocal space (1/d), different from get_bins() which creates equal-count bins.
- Parameters:
- Returns:
shell_edges (torch.Tensor) – Shell boundaries in Angstroms^-1, shape (n_shells+1,).
shell_centers (torch.Tensor) – Shell centers in Angstroms^-1, shape (n_shells,).
shell_indices (torch.Tensor) – Shell index for each reflection, shape (N,). Values -1 for out-of-range.
- Return type:
- fit_anisotropy(n_shells=20, d_min=None, d_max=None, n_iterations=100, verbose=None)[source]
Fit anisotropy correction parameters to minimize CV within shells.
Optimizes U parameters so that corrected F² values have minimal coefficient of variation within each resolution shell.
- Parameters:
n_shells (int) – Number of resolution shells for variance calculation.
d_min (float, optional) – High resolution limit in Angstroms. If None, uses dataset minimum.
d_max (float, optional) – Low resolution limit in Angstroms. If None, uses dataset maximum.
n_iterations (int) – Number of optimization iterations.
verbose (bool, optional) – Print progress. If None, uses self.verbose.
- Returns:
U – Fitted anisotropy parameters [u11, u22, u33, u12, u13, u23], shape (6,). Also stored in self.U_aniso.
- Return type:
- Raises:
ValueError – If no amplitude data is available.
- setup_anisotropy(U_aniso=None)[source]
Set up anisotropy correction parameters.
- Parameters:
U_aniso (torch.Tensor, optional) – Anisotropic parameters [u11, u22, u33, u12, u13, u23], shape (6,). If None, the parameters are initialized as zeros.
- apply_anisotropy_correction(U_aniso=None)[source]
Apply anisotropy correction to F² values.
- Parameters:
U_aniso (torch.Tensor, optional) – Anisotropic parameters [u11, u22, u33, u12, u13, u23], shape (6,). If None, uses self.U_aniso (must have called fit_anisotropy first).
- Returns:
F_corrected (torch.Tensor) – Anisotropy-corrected F values, shape (N,).
sigma_F_corrected (torch.Tensor) – Uncertainties of corrected F values, shape (N,).
- Raises:
ValueError – If no U parameters available and none provided.
- Return type:
- compute_e_values(n_shells=20, d_min=None, d_max=None, apply_anisotropy=True, fit_anisotropy=True, verbose=None)[source]
Compute E-values with optional anisotropy correction.
E-values are normalized structure factors where <E²> = 1 within each resolution shell. Anisotropy correction can be applied first to account for directional variation in diffraction.
- Parameters:
n_shells (int) – Number of resolution shells for normalization.
d_min (float, optional) – High resolution limit in Angstroms. If None, uses dataset minimum.
d_max (float, optional) – Low resolution limit in Angstroms. If None, uses dataset maximum.
apply_anisotropy (bool) – If True, apply anisotropy correction before E-value calculation.
fit_anisotropy (bool) – If True and apply_anisotropy is True, fit anisotropy parameters. If False and apply_anisotropy is True, uses existing self.U_aniso.
verbose (bool, optional) – Print progress. If None, uses self.verbose.
- Returns:
E – E-values, shape (N,). Also stored in self.E. self.E_squared is also populated with E² values.
- Return type:
- Raises:
ValueError – If no amplitude data is available.
Examples
Compute E-values with automatic anisotropy correction:
data = ReflectionData().load_mtz('data.mtz')
E = data.compute_e_values(n_shells=30)
print(f"E-values: mean={E.mean():.3f}, std={E.std():.3f}")
Compute E-values without anisotropy correction:
E = data.compute_e_values(apply_anisotropy=False)
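The normalization condition <E²> = 1 per shell can be demonstrated with a short stand-alone calculation. This sketch divides each amplitude by the root-mean-square amplitude of its shell; it leaves out anisotropy correction and symmetry (epsilon) factors, and the function `e_values` with its shell assignment is an assumption for illustration.

```python
import math

def e_values(F, shell_index):
    """Normalize amplitudes so that the mean of E^2 is 1 within each shell."""
    sums, counts = {}, {}
    for f, s in zip(F, shell_index):
        sums[s] = sums.get(s, 0.0) + f * f
        counts[s] = counts.get(s, 0) + 1
    mean_F2 = {s: sums[s] / counts[s] for s in sums}
    return [f / math.sqrt(mean_F2[s]) for f, s in zip(F, shell_index)]

F = [3.0, 4.0, 30.0, 40.0]   # hypothetical amplitudes
shells = [0, 0, 1, 1]        # shell index per reflection
E = e_values(F, shells)
mean_E2_shell0 = (E[0] ** 2 + E[1] ** 2) / 2
print([round(e, 4) for e in E], round(mean_E2_shell0, 6))
```

The two shells differ in amplitude by a factor of ten, yet produce identical E-values, which shows how normalization removes the overall fall-off with resolution.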
- get_corrected_data()[source]
Get scaled amplitude and uncertainty tensors.
Applies the current scale factor (from self.log_scale) to F and F_sigma.
- Returns:
Scaled F and F_sigma tensors. If log_scale is None, returns original.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- parameters()[source]
Get list of learnable parameters for optimization.
- Returns:
iter of torch Parameters to be optimized. Includes log_scale and U_aniso if set.
- Return type:
iter[Parameter]
- __init__(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, _centric=None, _n_bins=None, _FrenchWilson=None, source=None, dataset=None, last_op=None, reader=None)
- class torchref.DatasetCollection(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, _datasets=<factory>, _dataset_order=<factory>, _reference_dataset=None, _common_hkl=None, _cell=None, _spacegroup=None, _resolution=None, _scale_factors=<factory>)[source]
Bases: CrystalDataset
Container for multiple related crystal datasets.
All datasets share a common HKL set for efficient computation. Datasets are aligned using the first dataset as a reference, with missing reflections in subsequent datasets masked out.
- Parameters:
- hkl
Common HKL set for all datasets.
- Type:
Examples
from torchref.io import DatasetCollection, ReflectionData
collection = DatasetCollection(device='cuda')
native = ReflectionData().load_mtz('native.mtz')
derivative = ReflectionData().load_mtz('derivative.mtz')
collection.add_dataset('native', native, set_as_reference=True)
collection.add_dataset('derivative', derivative)
for name, dataset in collection:
    print(f"{name}: {len(dataset)} reflections")
# Access by name
native_F = collection['native'].F
- add_dataset(name, dataset, set_as_reference=False)[source]
Add a dataset to the collection.
- Parameters:
name (str) – Identifier for this dataset.
dataset (ReflectionData) – The dataset to add.
set_as_reference (bool, optional) – If True, use this dataset’s HKL as the reference. Default is False, but the first dataset added automatically becomes the reference.
- Returns:
Self, for method chaining.
- Return type:
- Raises:
ValueError – If a dataset with the same name already exists.
Examples
collection = DatasetCollection()
collection.add_dataset('native', native_data, set_as_reference=True)
collection.add_dataset('derivative', derivative_data)
- property datasets: Dict[str, ReflectionData]
Access all datasets as a dictionary.
- __getitem__(name)[source]
Get dataset by name.
- __iter__()[source]
Iterate over (name, dataset) pairs in order of addition.
- Yields:
tuple of (str, ReflectionData) – Name and dataset for each dataset in collection.
- scale()[source]
Scale all datasets to a common reference scale.
This method optimizes the scaling parameters of all non-reference datasets to minimize the mean squared error between their structure factors and those of the reference dataset. The optimization corrects for both overall scale differences and anisotropy. It uses the L-BFGS optimizer with strong Wolfe line search to refine the scaling parameters over multiple optimization steps.
- Returns:
The collection instance (self), for method chaining.
- Raises:
ValueError – If no reference dataset has been set before calling this method, or if the collection contains only the reference dataset. At least two datasets are required.
Notes
The reference dataset must be set before calling this method, for example with add_dataset(..., set_as_reference=True). All datasets except the reference will have their scaling parameters optimized.
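As a rough illustration of the mean-squared-error objective used by scale(), the sketch below solves only the simplest case: a single overall scale factor, which has a closed-form least-squares solution. The function name `least_squares_scale` is hypothetical, and the anisotropy correction and L-BFGS optimization that scale() performs are deliberately left out.

```python
def least_squares_scale(f_ref, f_obs):
    """Closed-form k that minimizes sum((f_ref - k * f_obs) ** 2)."""
    num = sum(r * o for r, o in zip(f_ref, f_obs))
    den = sum(o * o for o in f_obs)
    if den == 0.0:
        raise ValueError("f_obs contains no non-zero amplitudes")
    return num / den


# Hypothetical amplitudes: the second dataset is exactly half the reference.
f_ref = [10.0, 20.0, 30.0]
f_obs = [5.0, 10.0, 15.0]
k = least_squares_scale(f_ref, f_obs)
print(k)
```

Once anisotropic terms are added the problem is no longer linear in its parameters, which is one reason an iterative optimizer is a natural choice.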
- __init__(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, _datasets=<factory>, _dataset_order=<factory>, _reference_dataset=None, _common_hkl=None, _cell=None, _spacegroup=None, _resolution=None, _scale_factors=<factory>)
- class torchref.Model(dtype_float=torch.float32, verbose=1, device=device(type='cpu'), strip_H=True)[source]
Bases: DeviceMixin, DebugMixin, Module
Base model class for atomic structure models using PyTorch.
This class provides the foundation for managing atomic structure data including coordinates, atomic displacement parameters (ADPs), and occupancies. It supports both empty initialization for state_dict loading and file-based initialization from PDB/CIF files.
- Parameters:
dtype_float (torch.dtype, optional) – Data type for floating point tensors. Defaults to the configured dtypes.float.
verbose (int, optional) – Verbosity level for logging. Default is 1.
device (torch.device, optional) – Computation device. Defaults to the configured device.current.
strip_H (bool, optional) – Whether to strip hydrogen atoms when loading. Default is True.
- xyz
Atomic coordinates tensor with shape (n_atoms, 3).
- Type:
- adp
Atomic displacement parameters (isotropic B-factors) with shape (n_atoms,).
- Type:
- u
Anisotropic displacement parameters with shape (n_atoms, 6).
- Type:
- occupancy
Atomic occupancies with values in [0, 1].
- Type:
- pdb
DataFrame containing atomic model data.
- Type:
- spacegroup
Space group object.
- Type:
gemmi.SpaceGroup
Examples
Empty initialization for state_dict loading:
model = Model()
model.load_state_dict(torch.load('model.pt'))
File-based initialization:
model = Model()
model.load_pdb('structure.pdb')
- __init__(dtype_float=torch.float32, verbose=1, device=device(type='cpu'), strip_H=True)[source]
Initialize an empty Model shell.
Creates a model shell ready for file loading via load_pdb()/load_cif() or state restoration via load_state_dict().
- Parameters:
dtype_float (torch.dtype, optional) – Data type for floating point tensors. Defaults to the configured dtypes.float.
verbose (int, optional) – Verbosity level for logging. Default is 1.
device (torch.device, optional) – Computation device. Defaults to the configured device.current.
strip_H (bool, optional) – Whether to strip hydrogen atoms when loading. Default is True.
- property exclude_H_from_sf: bool
Whether to exclude hydrogen atoms from structure factor calculation.
When True, H atoms are excluded from get_iso()/get_aniso() so they do not contribute to Fcalc. They still participate in geometry and VDW restraints. Default is False.
- property cell: Cell | None
Unit cell object with parameters [a, b, c, alpha, beta, gamma].
- Returns:
The unit cell object, or None if not set.
- Return type:
Cell or None
- property spacegroup: SpaceGroup | None
Space group object.
- Returns:
The space group object, or None if not set.
- Return type:
gemmi.SpaceGroup or None
- property symmetry: SpaceGroup | None
Symmetry operations handler for this space group.
Returns the same SpaceGroup object as self.spacegroup. A separate Symmetry wrapper would be redundant, since Symmetry is just an alias.
- Returns:
The space group object, or None if not set.
- Return type:
SpaceGroup or None
- property inv_fractional_matrix: Tensor
Fractionalization matrix B^-1 (Cartesian -> fractional).
Delegates to Cell for automatic caching and device/dtype handling.
- Returns:
Shape (3, 3) fractionalization matrix.
- Return type:
- property fractional_matrix: Tensor
Orthogonalization matrix B (fractional -> Cartesian).
Delegates to Cell for automatic caching and device/dtype handling.
- Returns:
Shape (3, 3) orthogonalization matrix.
- Return type:
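For orientation, the sketch below builds an orthogonalization matrix from the six cell parameters and uses it to convert fractional coordinates to Cartesian. It assumes the common PDB-style convention (a along x, b in the xy plane); whether torchref's Cell uses exactly this convention is not stated in this reference, and the function names are hypothetical.

```python
import math


def orthogonalization_matrix(a, b, c, alpha, beta, gamma):
    """3x3 matrix B mapping fractional to Cartesian coordinates (angles in degrees)."""
    al, be, ga = (math.radians(x) for x in (alpha, beta, gamma))
    ca, cb, cg, sg = math.cos(al), math.cos(be), math.cos(ga), math.sin(ga)
    # Unit-cell volume divided by a*b*c
    v = math.sqrt(1.0 - ca * ca - cb * cb - cg * cg + 2.0 * ca * cb * cg)
    return [
        [a, b * cg, c * cb],
        [0.0, b * sg, c * (ca - cb * cg) / sg],
        [0.0, 0.0, c * v / sg],
    ]


def frac_to_cart(B, frac):
    """Multiply the 3x3 matrix B by a fractional coordinate vector."""
    return [sum(B[i][j] * frac[j] for j in range(3)) for i in range(3)]


# Orthorhombic toy cell: the matrix is (numerically) diagonal.
B = orthogonalization_matrix(10.0, 20.0, 30.0, 90.0, 90.0, 90.0)
cart = frac_to_cart(B, [0.5, 0.5, 0.5])
print(cart)
```

The fractionalization matrix is the inverse of this matrix.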
- property recB: Tensor
Reciprocal basis matrix with [a*, b*, c*] as rows.
Delegates to Cell for automatic caching and device/dtype handling.
- Returns:
Shape (3, 3) matrix where rows are the reciprocal basis vectors.
- Return type:
- property Z: Tensor
Atomic numbers for all atoms.
- Returns:
Tensor of atomic numbers with shape (n_atoms,).
- Return type:
- get_P1_parameters_iso()[source]
Get model parameters transformed to P1 space for optimization.
This is useful for optimizers that do not handle symmetry directly or MD.
- Returns:
xyz_p1 (torch.Tensor) – Fractional coordinates expanded to P1 space.
adp_p1 (torch.Tensor) – Isotropic ADPs expanded to P1 space.
occupancy_p1 (torch.Tensor) – Occupancies expanded to P1 space.
A (torch.Tensor) – Scattering factor A coefficients expanded to P1 space.
B (torch.Tensor) – Scattering factor B coefficients expanded to P1 space.
- Return type:
- get_MD_parameters()[source]
Get model parameters prepared for molecular dynamics simulation.
Returns all P1-expanded parameters plus atomic numbers for MD engines.
- Returns:
xyz_p1 (torch.Tensor) – Fractional coordinates expanded to P1 space.
adp_p1 (torch.Tensor) – Isotropic ADPs expanded to P1 space.
occupancy_p1 (torch.Tensor) – Occupancies expanded to P1 space.
A (torch.Tensor) – Scattering factor A coefficients expanded to P1 space.
B (torch.Tensor) – Scattering factor B coefficients expanded to P1 space.
Z_p1 (torch.Tensor) – Atomic numbers expanded to P1 space.
- Return type:
- property parametrization
ITC92 parametrization dictionary {element: (A, B)}.
The parametrization is built lazily on first access.
- Returns:
Dictionary mapping element symbols to tuples of (A, B) tensors.
- Return type:
- get_scattering_params_iso()[source]
Get ITC92 scattering parameters (A, B) for isotropic atoms.
- Returns:
A (torch.Tensor) – ITC92 A parameters (amplitudes) with shape (n_iso_atoms, 5).
B (torch.Tensor) – ITC92 B parameters (widths) with shape (n_iso_atoms, 5).
- get_scattering_params_aniso()[source]
Get ITC92 scattering parameters (A, B) for anisotropic atoms.
- Returns:
A (torch.Tensor) – ITC92 A parameters (amplitudes) with shape (n_aniso_atoms, 5).
B (torch.Tensor) – ITC92 B parameters (widths) with shape (n_aniso_atoms, 5).
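A sketch of how (A, B) pairs of this shape are typically evaluated: the atomic form factor is a sum of Gaussians in s = sin(θ)/λ. The carbon coefficients below are the commonly tabulated four-Gaussian-plus-constant values; storing the constant as a fifth term with B = 0 is an assumption about the five-column layout, not something this reference confirms.

```python
import math

# Carbon: four Gaussians plus a constant, with the constant stored as a
# fifth amplitude whose width is zero (assumed layout).
A_C = [2.3100, 1.0200, 1.5886, 0.8650, 0.2156]
B_C = [20.8439, 10.2075, 0.5687, 51.6512, 0.0]


def form_factor(A, B, s):
    """f(s) = sum_i A_i * exp(-B_i * s**2), with s = sin(theta)/lambda in 1/Angstrom."""
    return sum(a * math.exp(-b * s * s) for a, b in zip(A, B))


f0 = form_factor(A_C, B_C, 0.0)      # forward scattering, close to Z = 6
f_half = form_factor(A_C, B_C, 0.5)  # falls off with scattering angle
print(f0, f_half)
```

At s = 0 the sum of the amplitudes approximates the number of electrons, which is a quick sanity check for any coefficient table.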
- property restraints
Lazy restraints property.
The restraints are built on first access using the model’s pdb DataFrame and the CIF path set via set_restraints_cif().
- Returns:
The restraints object containing bond, angle, torsion, etc. restraints.
- Return type:
- bond_deviations()[source]
Compute bond length deviations using current xyz coordinates.
- Returns:
deviations (torch.Tensor) – Calculated minus expected bond lengths in Angstroms.
sigmas (torch.Tensor) – Standard deviations from CIF library in Angstroms.
- angle_deviations()[source]
Compute angle deviations using current xyz coordinates.
- Returns:
deviations (torch.Tensor) – Calculated minus expected angles in radians.
sigmas (torch.Tensor) – Standard deviations in radians.
- torsion_deviations_with_sigmas()[source]
Compute torsion deviations (wrapped for periodicity) and sigmas.
- Returns:
deviations_rad (torch.Tensor) – Wrapped deviations in radians.
sigmas_deg (torch.Tensor) – Standard deviations in degrees (for von Mises NLL).
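The periodic wrapping mentioned above can be illustrated with a minimal sketch: an angular difference is folded back into [-π, π] so that 179° and -179° count as 2° apart rather than 358°. The helper name `wrap_angle` is hypothetical, and any handling of torsion periodicities other than 2π is not covered here.

```python
import math


def wrap_angle(delta):
    """Wrap an angular difference in radians into [-pi, pi]."""
    return math.atan2(math.sin(delta), math.cos(delta))


observed = math.radians(179.0)
ideal = math.radians(-179.0)
deviation = wrap_angle(observed - ideal)
print(math.degrees(deviation))
```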
- property chain_sequences: List[Tuple[str, str]]
Per-chain amino acid sequences as single-letter codes.
Excludes HETATM records. Gaps in residue numbering are filled with ?. Non-standard residues are mapped to X.
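The gap-filling and fallback rules can be sketched as follows. This assumes one ? per missing residue number and a plain list of (resseq, resname) pairs as input; both are illustrative assumptions, and the function name `chain_sequence` is hypothetical.

```python
THREE_TO_ONE = {
    "ALA": "A", "ARG": "R", "ASN": "N", "ASP": "D", "CYS": "C",
    "GLN": "Q", "GLU": "E", "GLY": "G", "HIS": "H", "ILE": "I",
    "LEU": "L", "LYS": "K", "MET": "M", "PHE": "F", "PRO": "P",
    "SER": "S", "THR": "T", "TRP": "W", "TYR": "Y", "VAL": "V",
}


def chain_sequence(residues):
    """Build a one-letter sequence from (resseq, resname) pairs sorted by resseq."""
    seq = []
    prev = None
    for resseq, resname in residues:
        if prev is not None and resseq > prev + 1:
            # Assumed rule: one '?' per missing residue number
            seq.append("?" * (resseq - prev - 1))
        seq.append(THREE_TO_ONE.get(resname, "X"))
        prev = resseq
    return "".join(seq)


# Residues 3 and 4 are missing; MSE is not a standard residue.
residues = [(1, "MET"), (2, "ALA"), (5, "GLY"), (6, "MSE")]
sequence = chain_sequence(residues)
print(sequence)
```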
- get_chain_residues()[source]
Per-chain residue names as 3-letter codes (for IHM/CIF writing).
Excludes HETATM records. Unlike chain_sequences, returns the raw 3-letter codes without gap filling.
- get_vdw_radii()[source]
Get van der Waals radii for all atoms based on their elements.
Caches the result in self.vdw_radii for future calls.
- Returns:
Van der Waals radii for each atom with shape (n_atoms,).
- Return type:
- to(*args, **kwargs)[source]
Move Model and rebuild device-specific SF indices.
Delegates to DeviceMixin, which walks self.__dict__ (picking up self.cell, self.altloc_pairs, self._restraints and all registered parameters / buffers), refreshes the self.device tracker, and invalidates caches. Afterwards this override rebuilds the precomputed SF indices on the new device.
- copy()[source]
Create a deep copy of the Model.
Creates a complete independent copy including all registered buffers, module parameters, PDB DataFrame, and spacegroup information.
- Returns:
A new Model instance with copied data.
- Return type:
Examples
model = Model().load_pdb('structure.pdb')
model_copy = model.copy()
# model_copy is independent, changes won't affect model
- write_pdb(filename, metadata=None)[source]
Write model to PDB file with optional metadata header.
- Parameters:
filename (str) – Output PDB file path.
metadata (RefinementMetadata, optional) – Metadata to render as PDB header (REMARK 3, TITLE, etc.).
- write_cif(filename, metadata=None)[source]
Write model to mmCIF file with optional metadata.
- Parameters:
filename (str) – Output mmCIF file path.
metadata (RefinementMetadata, optional) – Metadata to include (refinement statistics, title, etc.).
- get_iso()[source]
Return per-atom parameters for the isotropic atom subset.
Selects atoms whose ADP is a single scalar b (i.e. not anisotropic). The subset is defined by ~self.aniso_flag, intersected with self._heavy_atom_mask when _exclude_H_from_sf is enabled, and is precomputed as self._iso_indices at init / whenever the mask changes.
- Returns:
xyz (torch.Tensor, shape (n_iso, 3)) – Cartesian coordinates of the isotropic atoms (Å).
adp (torch.Tensor, shape (n_iso,)) – Isotropic B-factors (Ų).
occupancy (torch.Tensor, shape (n_iso,)) – Occupancies in [0, 1].
Notes
When every atom is isotropic and no H exclusion is active (self._iso_covers_all is True, the common protein-refinement case), the per-atom indexing is skipped and self.xyz(), self.adp(), self.occupancy() are returned directly.
Motivation: self.xyz()[idx] is a no-op forward when idx = arange(N), but its backward routes through PyTorch’s aten::_index_put_impl_(accumulate=True), which performs a cub::DeviceRadixSortOnesweepKernel over len(idx) indices followed by a deduplicated scatter (~50-150 µs/iter per gather on A100 / 1DAW). Skipping the gather avoids that cost.
- parameters_of_types(types)[source]
Return the leaf nn.Parameter objects for the named parameter types.
Used by refinement entry points (refine_xyz, refine_adp, …) to construct an optimizer over only the leaves the caller intends to update. LossState.step then uses the optimizer’s param groups as intent and disables requires_grad on any other leaves the loss also touches.
- update_mask_from_selection(selection_string, target, mode='set', freeze=True)[source]
Update the refinable mask for a parameter using Phenix-style selection syntax.
This method updates the internal mask buffer (xyz_mask, adp_mask, u_mask, or occupancy_mask) based on the selection. The updated mask is NOT automatically applied to the parameter tensors - use apply_mask_to_parameter() to apply it.
- Parameters:
selection_string (str) – Phenix-style selection string (see parse_phenix_selection docs).
target (str) – Parameter to update: ‘xyz’, ‘adp’, ‘u’, or ‘occupancy’.
mode (str, optional) – How to combine with current mask:
- ‘set’: Replace mask with selection (default)
- ‘add’: Add selection to current mask
- ‘remove’: Remove selection from current mask
freeze (bool, optional) – If True (default), selected atoms will be frozen (mask=False). If False, selected atoms will be unfrozen (mask=True).
- Raises:
ValueError – If target is not recognized or selection syntax is invalid.
Examples
# Freeze chain A coordinates
model.update_mask_from_selection("chain A", "xyz", mode='set', freeze=True)
model.apply_mask_to_parameter("xyz")
# Unfreeze backbone atoms
model.update_mask_from_selection("name CA or name C or name N", "xyz", freeze=False)
model.apply_mask_to_parameter("xyz")
- apply_mask_to_parameter(target)[source]
Apply the current mask buffer to the parameter tensor.
Takes the current state of the mask buffer (xyz_mask, adp_mask, etc.) and applies it to the corresponding parameter tensor’s refinable mask.
- Parameters:
target (str) – Parameter to update: ‘xyz’, ‘adp’, ‘u’, or ‘occupancy’.
- Raises:
ValueError – If target is not recognized.
Examples
model.update_mask_from_selection("chain A", "xyz", freeze=True)
model.apply_mask_to_parameter("xyz")
- freeze_selection(selection_string, targets='all')[source]
Freeze atoms matching a Phenix-style selection for specified parameters.
Convenience method that combines update_mask_from_selection() and apply_mask_to_parameter() into a single call.
- Parameters:
Examples
# Freeze all parameters for chain A
model.freeze_selection("chain A", targets='all')
# Freeze only coordinates for residues 10-20
model.freeze_selection("resseq 10:20", targets='xyz')
- unfreeze_selection(selection_string, targets='all')[source]
Unfreeze atoms matching a Phenix-style selection for specified parameters.
Convenience method that combines update_mask_from_selection() and apply_mask_to_parameter() into a single call.
- Parameters:
Examples
# Unfreeze all parameters for chain A
model.unfreeze_selection("chain A", targets='all')
# Unfreeze only coordinates for backbone atoms
model.unfreeze_selection("name CA or name C or name N", targets='xyz')
- get_aniso()[source]
Return per-atom parameters for the anisotropic atom subset.
Selects atoms whose ADP is the 6-element anisotropic tensor u = (u11, u22, u33, u12, u13, u23). The subset is defined by self.aniso_flag, intersected with self._heavy_atom_mask when _exclude_H_from_sf is enabled, and is precomputed as self._aniso_indices at init / whenever the mask changes.
- Returns:
xyz (torch.Tensor, shape (n_aniso, 3)) – Cartesian coordinates of the anisotropic atoms (Å). Empty tensor when there are no anisotropic atoms.
u (torch.Tensor, shape (n_aniso, 6)) – Anisotropic U components (Ų) in the order (u11, u22, u33, u12, u13, u23). Empty when n_aniso == 0.
occupancy (torch.Tensor, shape (n_aniso,)) – Occupancies in [0, 1]. Empty when n_aniso == 0.
Notes
When there are no anisotropic atoms (self._aniso_is_empty is True, the common protein-refinement case), three empty placeholder tensors are returned without calling the MixedTensors at all. This avoids both the wrapped forward .clone() and the slow aten::_index_put_impl_ backward path that the self.xyz()[idx] gather would otherwise generate (see get_iso() for the same rationale).
- parameters(recurse=True)[source]
Return an iterator over module parameters.
This is typically passed to an optimizer.
- Parameters:
recurse (bool) – If True, yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
- Yields:
Parameter – Module parameter.
Examples
>>> for param in model.parameters():
>>>     print(type(param), param.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
- named_mixed_tensors()[source]
Iterate over all MixedTensor attributes with their names.
- Yields:
Tuple of (name, MixedTensor)
- register_alternative_conformations()[source]
Identify and register all alternative conformation groups in the structure.
For each residue that has alternative conformations (altloc A, B, C, etc.), this method identifies all atoms belonging to each conformation and stores their indices as tensors in a tuple.
The result is stored in self.altloc_pairs as a list of tuples, where each tuple contains tensors of atom indices for each alternative conformation.
Examples
For a residue with conformations A and B:
# Conformation A has atoms at indices [100, 101, 102, ...]
# Conformation B has atoms at indices [110, 111, 112, ...]
# Result: [(tensor([100, 101, 102, ...]), tensor([110, 111, 112, ...])), ...]
For a residue with conformations A, B, C:
# Result: [(tensor([200, 201, ...]), tensor([210, 211, ...]), tensor([220, 221, ...])), ...]
- shake_coords(stddev)[source]
Apply random Gaussian noise to atomic coordinates.
Perturbs the atomic coordinates by adding Gaussian noise with a specified standard deviation. The noise is applied to all atoms.
- Parameters:
stddev (float) – Standard deviation of the Gaussian noise to be added, in Angstroms.
- shake_adp(stddev)[source]
Apply random Gaussian noise to ADPs (atomic displacement parameters).
Perturbs the ADPs by adding Gaussian noise with a specified standard deviation. The noise is applied to all atoms.
- Parameters:
stddev (float) – Standard deviation of the Gaussian noise to be added, in Angstrom^2.
- generate_hydrogens(mon_lib_path=None)[source]
Generate hydrogen atoms for the current model using gemmi.
Places hydrogens at ideal geometry using the CCP4 monomer library and gemmi’s topology engine. Returns a new Model instance with hydrogens added; the original model is not modified.
- Parameters:
mon_lib_path (str, optional) – Path to CCP4 monomer library directory. If None, uses the monomer library bundled with torchref (covers standard amino acids and common small molecules).
- Returns:
A new Model instance with hydrogen atoms added (strip_H=False). Unknown residues are skipped silently.
- Return type:
Notes
Requires gemmi (already a torchref dependency). Heavy-atom coordinates from the current model state are used, so call this after any coordinate changes you want reflected in the H positions.
Examples
>>> model_no_h = Model().load_pdb('structure.pdb')
>>> model_with_h = model_no_h.generate_hydrogens()
>>> print(model_with_h.Z.shape)  # more atoms than model_no_h
- strip_altlocs()[source]
Return a new model with alternate conformations removed.
For each residue that has multiple altlocs, the conformer with highest average occupancy is kept (ties broken alphabetically). The altloc column is cleared to "" in the returned model. The original model is not modified.
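The selection rule (highest average occupancy, ties broken alphabetically) can be sketched for a single residue as below. The input layout, a dictionary from altloc label to per-atom occupancies, and the function name `pick_altloc` are assumptions made for illustration.

```python
def pick_altloc(occupancies_by_altloc):
    """Return the altloc label with the highest mean occupancy.

    Ties are broken alphabetically (the smaller label wins).
    """
    def mean(xs):
        return sum(xs) / len(xs)

    # Sort key: descending mean occupancy, then ascending label
    return min(
        occupancies_by_altloc,
        key=lambda label: (-mean(occupancies_by_altloc[label]), label),
    )


tied = pick_altloc({"B": [0.5, 0.5], "A": [0.5, 0.5], "C": [0.3, 0.3]})
clear = pick_altloc({"A": [0.4, 0.4], "B": [0.6, 0.6]})
print(tied, clear)
```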
- strip_hydrogens()[source]
Return a new model with hydrogen atoms removed.
The returned model has consistent DataFrame and tensors (xyz, adp, occupancy) with H atoms excluded. The original model is not modified.
- Returns:
New model without hydrogen atoms.
- Return type:
- hydrogenate(verbose=0, optimize=False, lbfgs_steps=3, max_iter=20)[source]
Return a new model with hydrogen atoms placed via Kabsch alignment.
Uses torchref’s monomer library to identify missing H atoms, places them by SVD-aligning ideal monomer coordinates onto the current model coordinates, then corrects each H to sit at ideal bond length from its parent atom. The original model is not modified.
- Parameters:
verbose (int, optional) – Verbosity level (0=silent, 1=summary, 2=detailed). Default 0.
optimize (bool, optional) – If True, run a short LBFGS geometry optimization on H positions after placement. Default False (Kabsch placement only).
lbfgs_steps (int, optional) – Number of LBFGS outer steps (only when optimize=True). Default 3.
max_iter (int, optional) – Max line-search iterations per LBFGS step. Default 20.
- Returns:
New model with hydrogen atoms added. All parameters are unfrozen in the returned model.
- Return type:
- adp_loss()[source]
Compute the ADP regularization loss.
This loss encourages ADPs to have similar values across the structure, helping to prevent overfitting during refinement.
- Returns:
Scalar tensor representing the ADP loss.
- Return type:
- adp_nll_loss(target_log_std=0.2)[source]
Compute negative log-likelihood of ADPs assuming Gaussian distribution in log-space.
This regularization penalizes ADPs that deviate from a target distribution with a FIXED standard deviation (hyperparameter), avoiding circular dependency on the current distribution’s statistics.
The NLL for a Gaussian distribution in log-space is:
NLL = 0.5 * mean[(log_adp - mu)^2 / sigma^2 + log(2*pi*sigma^2)]
Where mu is the mean of log-space ADPs (computed from current data) and sigma is the FIXED target standard deviation (hyperparameter).
- Parameters:
target_log_std (float, optional) – Target standard deviation in log-space. Default is 0.2.
- 0.1 = very tight (ADPs within ~10% of mean)
- 0.2 = moderate spread (ADPs within ~20% of mean) [RECOMMENDED]
- 0.3 = looser spread (ADPs within ~30% of mean)
- Returns:
Scalar tensor representing the NLL. Lower values indicate the distribution is closer to the target Gaussian with fixed sigma.
- Return type:
Examples
# During refinement
structure_factor_loss = compute_structure_factor_loss()
nll_reg = model.adp_nll_loss(target_log_std=0.2)
total_loss = structure_factor_loss + 0.01 * nll_reg
total_loss.backward()
Notes
Uses FIXED sigma (no circular dependency on current distribution). Smaller target_log_std = stronger regularization (tighter distribution).
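The NLL formula given for adp_nll_loss can be checked numerically with a dependency-free sketch. It follows the stated formula term by term; the function name `adp_nll` and the use of plain Python floats instead of tensors are illustrative choices, not the library's implementation.

```python
import math


def adp_nll(adps, target_log_std=0.2):
    """0.5 * mean[(log_adp - mu)^2 / sigma^2 + log(2*pi*sigma^2)] with fixed sigma."""
    log_adp = [math.log(b) for b in adps]
    mu = sum(log_adp) / len(log_adp)  # mean taken from the current data
    var = target_log_std ** 2         # fixed hyperparameter
    terms = [(x - mu) ** 2 / var + math.log(2.0 * math.pi * var) for x in log_adp]
    return 0.5 * sum(terms) / len(terms)


uniform = adp_nll([20.0, 20.0, 20.0])  # identical ADPs: only the constant term remains
spread = adp_nll([10.0, 20.0, 40.0])   # wider spread gives a larger penalty
print(uniform, spread)
```

Because sigma is fixed, the constant term does not affect gradients; only the squared log-deviations drive the regularization.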
- adp_nll_loss_per_atom(target_log_std=0.2)[source]
Compute per-atom negative log-likelihood for ADPs in log-space.
Returns the NLL contribution for each individual atom, useful for identifying outliers or applying atom-specific regularization weights.
The per-atom NLL is:
NLL_i = 0.5 * [(log_adp_i - mu)^2 / sigma^2 + log(2*pi*sigma^2)]
- Parameters:
target_log_std (float, optional) – Fixed target standard deviation in log-space. Default is 0.2.
- Returns:
Tensor of shape (n_atoms,) with per-atom NLL values. Higher values indicate atoms farther from the mean.
- Return type:
Examples
# Get per-atom NLL
atom_nll = model.adp_nll_loss_per_atom(target_log_std=0.2)
# Identify outlier atoms (high NLL)
threshold = atom_nll.mean() + 2 * atom_nll.std()
outliers = atom_nll > threshold
- adp_kl_divergence_loss(target_log_std=0.2)[source]
Compute KL divergence between log ADP distribution and target Gaussian.
Measures how different the current log ADP distribution is from a target Gaussian distribution with the current mean of log ADPs and a fixed target standard deviation.
KL divergence formula for two Gaussians with same mean:
KL(q || p) = log(sigma_target/sigma_data) + sigma_data^2 / (2*sigma_target^2) - 0.5
- Parameters:
target_log_std (float, optional) – Target standard deviation in log-space. Default is 0.2. Controls how tightly ADPs should cluster.
- Returns:
Scalar KL divergence value (always >= 0). 0 means distributions match perfectly. Higher values mean more deviation from target.
- Return type:
Examples
# Use in loss function
loss = xray_loss + w_adp * model.adp_kl_divergence_loss(0.2)
Notes
Lower target_log_std = stronger regularization (tighter distribution). Mean is detached so it adapts to the natural scale of the data.
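A small numerical sketch of the same-mean Gaussian KL formula above. It assumes the population (biased) standard deviation of the log ADPs for sigma_data; which estimator the library uses is not stated here, and the function name `adp_kl` is hypothetical.

```python
import math


def adp_kl(adps, target_log_std=0.2):
    """KL(q || p) for two Gaussians sharing a mean; q is fitted to log(adps)."""
    log_adp = [math.log(b) for b in adps]
    mu = sum(log_adp) / len(log_adp)
    # Population standard deviation (assumed estimator); must be non-zero
    sigma_data = math.sqrt(sum((x - mu) ** 2 for x in log_adp) / len(log_adp))
    ratio = sigma_data / target_log_std
    # log(sigma_target / sigma_data) + sigma_data^2 / (2 * sigma_target^2) - 0.5
    return -math.log(ratio) + 0.5 * ratio ** 2 - 0.5


# Log-values at 3 +/- 0.2 have a population std of exactly the target.
kl_match = adp_kl([math.exp(2.8), math.exp(3.2)])
# Log-values at 3 +/- 1.0 are five times wider than the target.
kl_wide = adp_kl([math.exp(2.0), math.exp(4.0)])
print(kl_match, kl_wide)
```

Note that identical ADPs give sigma_data = 0, for which the logarithm is undefined; a tensor implementation would typically need a small epsilon there.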
- state_dict(destination=None, prefix='', keep_vars=False)[source]
Return a dictionary containing the complete state of the Model.
Includes all registered buffers, model parameters (xyz, b, u, occupancy), PDB DataFrame, and metadata (spacegroup, device, dtype, etc.).
- save_state(path)[source]
Save the complete state of the model to a file.
- Parameters:
path (str) – Path to save the state dictionary to.
- classmethod create_from_state_dict(state_dict, device=device(type='cpu'), verbose=1, dtype_float=torch.float32)[source]
Create a fully initialized Model from a state dictionary.
This is the recommended way to restore a Model from a saved state. Creates an instance with properly initialized submodules, then loads the state.
- Parameters:
state_dict (dict) – State dictionary from torch.save(model.state_dict(), …).
device (torch.device, optional) – Device to place tensors on. Defaults to the configured device.current.
verbose (int, optional) – Verbosity level. Default is 1.
dtype_float (torch.dtype, optional) – Float dtype for tensors. Defaults to the configured dtypes.float.
- Returns:
Fully initialized instance with restored state.
- Return type:
- get_selection_mask(selection)[source]
Return a boolean mask for atoms matching a Phenix-style selection.
This is a convenience method that wraps parse_phenix_selection() to return a mask that can be used directly with MixedTensor.set() or other operations requiring atom selection.
- Parameters:
selection (str) – Phenix-style selection string. Supports:
- chain <id>: Select by chain (e.g., “chain A”)
- resseq <num>: Select by residue number (e.g., “resseq 10”)
- resseq <start>:<end>: Select residue range (e.g., “resseq 10:20”)
- resname <name>: Select by residue name (e.g., “resname ALA”)
- name <atom>: Select by atom name (e.g., “name CA”)
- element <elem>: Select by element (e.g., “element C”)
- altloc <id>: Select by alternate location (e.g., “altloc A”)
- all: Select all atoms
- not <selection>: Negate selection
- <sel1> and <sel2>: Intersection
- <sel1> or <sel2>: Union
- Parentheses for grouping
- Returns:
Boolean tensor of shape (n_atoms,) where True indicates selected atoms.
- Return type:
- Raises:
RuntimeError – If the model has not been initialized.
ValueError – If selection syntax is invalid.
Examples
model = Model().load_pdb('structure.pdb')
# Get mask for chain A
mask = model.get_selection_mask("chain A")
# Use mask to update coordinates
new_coords = model.xyz()[mask] + translation
model.xyz.set(new_coords, mask)
# Get mask for backbone atoms
backbone_mask = model.get_selection_mask("name CA or name C or name N or name O")
# Complex selection with parentheses
mask = model.get_selection_mask("chain A and (resname ALA or resname GLY)")
- select(selection)[source]
Return a new Model containing only atoms matching the Phenix-style selection.
Creates an independent copy of the model containing only the selected atoms. All tensor data (coordinates, ADPs, occupancies, etc.) and metadata are properly subsetted.
- Parameters:
selection (str) – Phenix-style selection string. Supports:
- chain <id>: Select by chain (e.g., “chain A”)
- resseq <num>: Select by residue number (e.g., “resseq 10”)
- resseq <start>:<end>: Select residue range (e.g., “resseq 10:20”)
- resname <name>: Select by residue name (e.g., “resname ALA”)
- name <atom>: Select by atom name (e.g., “name CA”)
- element <elem>: Select by element (e.g., “element C”)
- altloc <id>: Select by alternate location (e.g., “altloc A”)
- all: Select all atoms
- not <selection>: Negate selection
- <sel1> and <sel2>: Intersection
- <sel1> or <sel2>: Union
- Parentheses for grouping
- Returns:
New instance of the same class containing only selected atoms. If called on a subclass, returns an instance of that subclass.
- Return type:
- Raises:
RuntimeError – If the model has not been initialized.
ValueError – If selection syntax is invalid or no atoms are selected.
Examples
model = Model().load_pdb('structure.pdb')
# Select chain A
chain_a = model.select("chain A")
# Select backbone atoms
backbone = model.select("name CA or name C or name N or name O")
# Select residues 10-50 of chain B
region = model.select("chain B and resseq 10:50")
# Select all except water
no_water = model.select("not resname HOH")
# Complex selection with parentheses
complex_sel = model.select("chain A and (resname ALA or resname GLY)")
Notes
This method preserves the class type, so subclasses will return instances of themselves, not the base Model class.
- xyz_fractional()[source]
Return atomic coordinates in fractional space.
Converts Cartesian coordinates to fractional coordinates using the inverse fractional matrix.
- Returns:
Tensor of shape (n_atoms, 3) with fractional coordinates.
- Return type:
- rotate(rotation_matrix, center=None)[source]
Apply rotation to atomic coordinates (in-place).
Rotates all atoms around a specified center point. The rotation is applied using the formula: xyz_new = R @ (xyz - center) + center
- Parameters:
rotation_matrix (torch.Tensor) – 3x3 rotation matrix. Should be orthogonal (R^T @ R = I).
center (torch.Tensor, optional) – Center of rotation with shape (3,). If None, uses the centroid of all atomic coordinates.
- Returns:
Self, for method chaining.
- Return type:
Examples
# Rotate 90 degrees around Z-axis
import math
import torch
angle = math.pi / 2
R = torch.tensor([
    [math.cos(angle), -math.sin(angle), 0],
    [math.sin(angle), math.cos(angle), 0],
    [0, 0, 1]
])
model.rotate(R)
# Rotate around a specific point
center = torch.tensor([10.0, 20.0, 30.0])
model.rotate(R, center=center)
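The formula xyz_new = R @ (xyz - center) + center can be verified on a few points without torchref. The sketch below uses plain lists and a hypothetical helper `rotate_points`; it is only meant to show the arithmetic, including the default of rotating about the centroid.

```python
import math


def rotate_points(points, R, center=None):
    """Apply p_new = R @ (p - center) + center to each 3D point."""
    if center is None:
        n = len(points)
        center = [sum(p[i] for p in points) / n for i in range(3)]
    out = []
    for p in points:
        d = [p[i] - center[i] for i in range(3)]
        out.append([sum(R[i][j] * d[j] for j in range(3)) + center[i] for i in range(3)])
    return out


angle = math.pi / 2
Rz = [
    [math.cos(angle), -math.sin(angle), 0.0],
    [math.sin(angle), math.cos(angle), 0.0],
    [0.0, 0.0, 1.0],
]

# About the origin, the x unit vector maps onto the y unit vector.
rotated = rotate_points([[1.0, 0.0, 0.0]], Rz, center=[0.0, 0.0, 0.0])

# About the centroid (default), the centroid itself stays fixed.
pair = rotate_points([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]], Rz)
centroid_after = [(pair[0][i] + pair[1][i]) / 2 for i in range(3)]
print(rotated, centroid_after)
```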
- translate(translation, fractional=False)[source]
Apply translation to atomic coordinates (in-place).
Translates all atoms by a specified vector. The translation can be given in either Cartesian or fractional coordinates.
- Parameters:
translation (torch.Tensor) – Translation vector with shape (3,).
fractional (bool, optional) – If True, the translation is interpreted as fractional coordinates and converted to Cartesian before applying. Default is False (translation is in Cartesian Angstroms).
- Returns:
Self, for method chaining.
- Return type:
Examples
# Translate by 5 Angstroms along X
model.translate(torch.tensor([5.0, 0.0, 0.0]))
# Translate by half a unit cell along each axis
model.translate(torch.tensor([0.5, 0.5, 0.5]), fractional=True)
- get_centroid()[source]
Compute the centroid (center of mass) of all atoms.
- Returns:
Centroid coordinates with shape (3,).
- Return type:
- use_internal_coordinates(n_aa_per_segment=5, bond_cutoff=2.0, cif_dict=None, requires_grad=True)[source]
Switch xyz to segmented internal coordinate parametrization.
Replaces the current xyz MixedTensor with a SegmentedInternalCoordinateTensor that parametrizes atomic positions using bond lengths, angles, torsion angles, and per-segment rigid body parameters. The molecule is broken into independent segments to avoid the “lever arm problem” where small torsion changes near the root cause large displacements at distant atoms.
- Parameters:
n_aa_per_segment (int, optional) – Number of amino acids per segment. Default is 5.
- Smaller values (1-2): More segments, shallower trees, less lever arm
- Larger values (5-10): Fewer segments, deeper trees, more lever arm
bond_cutoff (float, optional) – Distance cutoff for bond detection in Angstroms. Default is 2.0. Only used when cif_dict is not provided.
cif_dict (dict, optional) – CIF dictionary containing bond definitions per residue type. If provided, bonds are determined from chemical definitions rather than distances, which is more robust for structures with poor geometry. Expected format: cif_dict[resname][‘bonds’] DataFrame with ‘atom1’, ‘atom2’.
requires_grad (bool, optional) – Whether internal coordinate parameters should have gradients. Default is True.
- Returns:
Self, for method chaining.
- Return type:
Examples
model = Model()
model.load_pdb('structure.pdb')
model.use_internal_coordinates(n_aa_per_segment=3)
# Now model.xyz() returns coordinates reconstructed from
# segmented internal coordinates
# Shake the structure using internal coordinates
new_xyz = model.xyz.shake(magnitude=0.1)
# Each segment has independent internal coordinates and
# rigid body parameters (position + orientation)
Notes
After calling this method, model.xyz will be a SegmentedInternalCoordinateTensor instead of a MixedTensor. This provides:
- Shallow spanning trees within segments (depth ~10-30 vs ~1000)
- Independent segments that don't propagate changes to distant atoms
- Rigid body parameters (position + orientation) per segment
- forward() / __call__(): Reconstruct Cartesian coordinates
- shake(magnitude): Add noise to internal parameters
- Gradient flow through all internal coordinate parameters
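The lever-arm effect mentioned above can be made concrete with a short sketch. It is not part of torchref: it only evaluates the chord length 2·r·sin(Δφ/2) travelled by an atom at perpendicular distance r from a rotation axis when a torsion changes by Δφ. The function name and the distances are illustrative assumptions.

```python
import math


def lever_arm_displacement(radius_angstrom: float, delta_deg: float) -> float:
    """Chord length moved by a point at `radius_angstrom` from the rotation axis."""
    return 2.0 * radius_angstrom * math.sin(math.radians(delta_deg) / 2.0)


# The same 1-degree torsion change, seen by a nearby atom and a distant one.
near = lever_arm_displacement(5.0, 1.0)   # atom inside a short segment (assumed)
far = lever_arm_displacement(50.0, 1.0)   # atom far down an unsegmented chain (assumed)
print(f"near: {near:.3f} A, far: {far:.3f} A")
```

The displacement grows linearly with distance from the axis, which is why shorter segments (smaller n_aa_per_segment) keep the effect of a torsion update local.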
- class torchref.ModelFT(*args, max_res=1.0, radius_angstrom=4.0, gridsize=None, wavelength=1.0, anomalous_threshold=0.5, **kwargs)[source]
Bases: CachedForwardMixin, Model
Model subclass for Fourier Transform-based electron density and structure factor calculations.
ModelFT extends the base Model class with capabilities for computing electron density maps in real space and structure factors via FFT. Uses ITC92 parametrization for electron density calculations.
- Parameters:
max_res (float, optional) – Maximum resolution for grid spacing in Angstroms. Default is 1.0.
radius_angstrom (float, optional) – Radius in Angstroms for density calculation around each atom. Default is 4.0.
gridsize (tuple of int, optional) – Explicit grid size (nx, ny, nz). If None, computed from cell and max_res.
wavelength (float or None, optional) – X-ray wavelength in Angstroms for anomalous scattering correction. Default is 1.0 (standard synchrotron, ~12.4 keV). Set to None to disable anomalous corrections entirely.
anomalous_threshold (float, optional) – Significance threshold for anomalous scattering in electrons. Atoms with |f'| > threshold or |f''| > threshold will have anomalous corrections applied. Default is 0.5.
*args – Additional positional arguments passed to parent Model class.
**kwargs – Additional keyword arguments passed to parent Model class.
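The wavelength parameter above quotes 1.0 Å as roughly 12.4 keV. A minimal sketch of that conversion, using E = hc/λ with hc ≈ 12.398 keV·Å, is shown below; the helper names are hypothetical and not part of torchref.

```python
HC_KEV_ANGSTROM = 12.398419843  # Planck constant times speed of light, in keV * Angstrom


def wavelength_to_energy_kev(wavelength_angstrom: float) -> float:
    """Photon energy in keV for a wavelength given in Angstroms."""
    return HC_KEV_ANGSTROM / wavelength_angstrom


def energy_kev_to_wavelength(energy_kev: float) -> float:
    """Wavelength in Angstroms for a photon energy given in keV."""
    return HC_KEV_ANGSTROM / energy_kev


print(f"{wavelength_to_energy_kev(1.0):.2f} keV at 1.0 A")
print(f"{energy_kev_to_wavelength(12.0):.4f} A at 12.0 keV")
```

This is handy when a beamline reports energy rather than wavelength and the value has to be passed as wavelength.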
- gridsize
Grid dimensions (nx, ny, nz).
- Type:
- real_space_grid
Real-space coordinate grid with shape (nx, ny, nz, 3).
- Type:
- map
Computed electron density map.
- Type:
torch.Tensor or None
- map_symmetry
Symmetry operator for map calculations.
- Type:
Examples
Empty initialization for state_dict loading:
model = ModelFT()
model.load_state_dict(torch.load('model.pt'))
File-based initialization:
model = ModelFT(max_res=1.5)
model.load_pdb('structure.pdb')
- __init__(*args, max_res=1.0, radius_angstrom=4.0, gridsize=None, wavelength=1.0, anomalous_threshold=0.5, **kwargs)[source]
Initialize an empty ModelFT shell.
Creates a model shell ready for file loading via load_pdb()/load_cif() or state restoration via load_state_dict().
- Parameters:
max_res (float, optional) – Maximum resolution for grid spacing in Angstroms. Default is 1.0.
radius_angstrom (float, optional) – Radius in Angstroms for density calculation. Default is 4.0.
gridsize (tuple of int, optional) – Explicit grid size tuple (nx, ny, nz). If None, computed automatically.
wavelength (float or None, optional) – X-ray wavelength in Angstroms for anomalous scattering correction. Default is 1.0 (standard synchrotron, ~12.4 keV). Set to None to disable anomalous corrections entirely.
anomalous_threshold (float, optional) – Significance threshold for anomalous scattering in electrons. Atoms with |f'| > threshold or |f''| > threshold will have anomalous corrections applied. Default is 0.5.
*args – Passed to parent Model class.
**kwargs – Passed to parent Model class.
- property cell
Unit cell object with parameters [a, b, c, alpha, beta, gamma].
- property spacegroup
Space group object.
- select(selection)[source]
Return a new Model containing only atoms matching the Phenix-style selection.
Creates an independent copy of the model containing only the selected atoms. All tensor data (coordinates, ADPs, occupancies, etc.) and metadata are properly subsetted.
- Parameters:
selection (str) – Phenix-style selection string. Supports:
- chain <id>: Select by chain (e.g., "chain A")
- resseq <num>: Select by residue number (e.g., "resseq 10")
- resseq <start>:<end>: Select residue range (e.g., "resseq 10:20")
- resname <name>: Select by residue name (e.g., "resname ALA")
- name <atom>: Select by atom name (e.g., "name CA")
- element <elem>: Select by element (e.g., "element C")
- altloc <id>: Select by alternate location (e.g., "altloc A")
- all: Select all atoms
- not <selection>: Negate selection
- <sel1> and <sel2>: Intersection
- <sel1> or <sel2>: Union
- Parentheses for grouping
- Returns:
New instance of the same class containing only selected atoms. If called on a subclass, returns an instance of that subclass.
- Return type:
- Raises:
RuntimeError – If the model has not been initialized.
ValueError – If selection syntax is invalid or no atoms are selected.
Examples
model = Model().load_pdb('structure.pdb')
# Select chain A
chain_a = model.select("chain A")
# Select backbone atoms
backbone = model.select("name CA or name C or name N or name O")
# Select residues 10-50 of chain B
region = model.select("chain B and resseq 10:50")
# Select all except water
no_water = model.select("not resname HOH")
# Complex selection with parentheses
complex_sel = model.select("chain A and (resname ALA or resname GLY)")
Notes
This method preserves the class type, so subclasses will return instances of themselves, not the base Model class.
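One way to read the selection grammar is as boolean masks combined with and / or / not. The sketch below illustrates that reading on a toy atom table; it is not torchref's parser, and the record layout and helper names (mask, and_, or_, not_) are hypothetical.

```python
# Toy atom table standing in for a loaded model (illustrative data only).
atoms = [
    {"chain": "A", "resname": "ALA", "name": "CA"},
    {"chain": "A", "resname": "GLY", "name": "N"},
    {"chain": "A", "resname": "HOH", "name": "O"},
    {"chain": "B", "resname": "ALA", "name": "CA"},
]


def mask(field, value):
    """One boolean per atom: True where atom[field] equals value."""
    return [atom[field] == value for atom in atoms]


def and_(a, b):
    return [x and y for x, y in zip(a, b)]


def or_(a, b):
    return [x or y for x, y in zip(a, b)]


def not_(a):
    return [not x for x in a]


# "chain A and (resname ALA or resname GLY)"
sel = and_(mask("chain", "A"), or_(mask("resname", "ALA"), mask("resname", "GLY")))
# "not resname HOH"
no_water = not_(mask("resname", "HOH"))
print(sel, no_water)
```

Subsetting every per-atom array with the same final mask is what keeps coordinates, ADPs and occupancies aligned in the returned model.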
- setup_gridsize(max_res=None)[source]
Compute optimal grid dimensions.
Delegates to FFT.compute_grid_size().
- Parameters:
max_res (float, optional) – Maximum resolution in Angstroms. If None, uses self.max_res.
- Returns:
Grid dimensions (nx, ny, nz) as int32 tensor.
- Return type:
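The exact rule used by FFT.compute_grid_size() is not stated here. As a rough sketch under two assumptions (a grid spacing of max_res / 3, and rounding each dimension up to a size whose only prime factors are 2, 3 and 5), grid dimensions for an orthogonal cell could be estimated as follows. The function names and the sampling factor are hypothetical.

```python
import math


def is_fft_friendly(n: int) -> bool:
    """True if n has no prime factors other than 2, 3 and 5."""
    for p in (2, 3, 5):
        while n % p == 0:
            n //= p
    return n == 1


def grid_size(cell_lengths, max_res, sampling=3.0):
    """Smallest FFT-friendly grid with spacing <= max_res / sampling on each axis."""
    dims = []
    for length in cell_lengths:
        n = max(1, math.ceil(length * sampling / max_res))
        while not is_fft_friendly(n):
            n += 1
        dims.append(n)
    return tuple(dims)


print(grid_size((50.0, 60.0, 70.0), max_res=2.0))
```

A real implementation would also have to respect space-group constraints on the grid, which this sketch ignores.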
- property A: Tensor
ITC92 A parameters (amplitudes) for all atoms.
- Returns:
A parameters with shape (n_atoms, 5).
- Return type:
- property B: Tensor
ITC92 B parameters (widths) for all atoms.
- Returns:
B parameters with shape (n_atoms, 5).
- Return type:
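For orientation, the ITC92 parametrization expresses the scattering factor as a sum of Gaussians, f₀(s) = Σ aᵢ·exp(−bᵢ·s²) + c with s = sin(θ)/λ. The sketch below evaluates it for carbon using the tabulated International Tables coefficients. Treating the constant c as a fifth term with b = 0, to match the (n_atoms, 5) shape, is an assumption about how torchref stores it.

```python
import math

# Carbon coefficients; the constant c is stored as a fifth Gaussian with zero width
# (an assumption made here to mirror the (n_atoms, 5) layout).
A_CARBON = [2.3100, 1.0200, 1.5886, 0.8650, 0.2156]
B_CARBON = [20.8439, 10.2075, 0.5687, 51.6512, 0.0]


def f0(s, a=A_CARBON, b=B_CARBON):
    """Scattering factor at s = sin(theta)/lambda, in 1/Angstrom."""
    return sum(ai * math.exp(-bi * s * s) for ai, bi in zip(a, b))


for d_spacing in (10.0, 2.0, 1.0):
    s = 1.0 / (2.0 * d_spacing)  # Bragg's law: sin(theta)/lambda = 1/(2d)
    print(f"d = {d_spacing:4.1f} A  f0 = {f0(s):.3f} e")
```

At s = 0 the sum of the A terms approaches the number of electrons (6 for carbon), which is a quick sanity check on any coefficient table.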
- get_iso()[source]
Get isotropic atoms with their ITC92 parameters.
- Returns:
xyz (torch.Tensor) – Atomic coordinates with shape (n_atoms, 3).
adp (torch.Tensor) – Atomic displacement parameters (isotropic) with shape (n_atoms,).
occupancy (torch.Tensor) – Occupancies with shape (n_atoms,).
A (torch.Tensor) – ITC92 A parameters (amplitudes) with shape (n_atoms, 5).
B (torch.Tensor) – ITC92 B parameters (widths) with shape (n_atoms, 5).
- get_aniso()[source]
Get anisotropic atoms with their ITC92 parameters.
- Returns:
xyz (torch.Tensor) – Atomic coordinates with shape (n_atoms, 3).
u (torch.Tensor) – Anisotropic U parameters with shape (n_atoms, 6).
occupancy (torch.Tensor) – Occupancies with shape (n_atoms,).
A (torch.Tensor) – ITC92 A parameters (amplitudes) with shape (n_atoms, 5).
B (torch.Tensor) – ITC92 B parameters (widths) with shape (n_atoms, 5).
- setup_grid(max_res=None, gridsize=None)[source]
Setup real-space grid for electron density calculation.
Delegates to FFT.setup_grid() using the stored cell and spacegroup.
- get_radius(min_radius_Angstrom=4.0)[source]
Get the radius in voxels used for density calculation around each atom.
- build_complete_map(radius=None, apply_symmetry=True)[source]
Build electron density map from all atoms.
Uses get_iso() and get_aniso() to get atom data and constructs the complete electron density map.
- Parameters:
- Returns:
Electron density map with symmetry applied if requested.
- Return type:
- build_initial_map(apply_symmetry=True)[source]
Build electron density map from atomic parameters.
Delegates to FFT.build_density_map() using the model’s stored parameters.
- Parameters:
apply_symmetry (bool, optional) – If True, apply crystallographic symmetry to the map. Default is True.
- Returns:
Electron density map with shape (nx, ny, nz).
- Return type:
- save_map(filename)[source]
Save the electron density map to a CCP4 format file.
- Parameters:
filename (str) – Output filename for the map.
- Raises:
ValueError – If no map has been computed yet.
- rebuild_map(radius=None)[source]
Rebuild the density map from scratch.
Convenience method that clears and rebuilds everything.
- Parameters:
radius (int, optional) – Radius in voxels around each atom. If None, uses self.radius. If specified, overrides self.radius.
- Returns:
Rebuilt electron density map.
- Return type:
- get_structure_factor(hkl, recalc=False, apply_anomalous=True)[source]
Get structure factors for given hkl reflections.
Uses CachedForwardMixin to cache the result and auto-invalidate when parameters change or a backward pass propagates through.
- Parameters:
hkl (torch.Tensor) – Miller indices with shape (n_reflections, 3).
recalc (bool, optional) – If True, forces recalculation bypassing the cache. Default is False.
apply_anomalous (bool, optional) – If True and wavelength is set, apply anomalous scattering corrections (f’ and f’’) for heavy atoms. Default is True.
- Returns:
Complex structure factors with shape (n_reflections,).
- Return type:
Notes
- The complete scattering factor is:
f(s, λ) = f₀(s) + f'(λ) + i·f''(λ)
where f₀ is the normal (Thomson) scattering factor computed via FFT, and f' and f'' are the wavelength-dependent anomalous corrections.
Anomalous corrections are only computed for atoms where |f'| > anomalous_threshold or |f''| > anomalous_threshold.
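The threshold rule in these notes can be sketched in a few lines. This is an illustration of the formula only, not torchref's code path, and the numeric f' / f'' values in the usage lines are made-up placeholders rather than tabulated data.

```python
def total_scattering_factor(f0, f_prime, f_double_prime, threshold=0.5):
    """Return f0 + f' + i*f'' when either correction exceeds the threshold, else f0."""
    if abs(f_prime) > threshold or abs(f_double_prime) > threshold:
        return complex(f0 + f_prime, f_double_prime)
    return complex(f0, 0.0)


# Light atom with negligible corrections (placeholder values): left untouched.
light = total_scattering_factor(6.0, 0.01, 0.005)
# Heavy atom near an absorption edge (placeholder values): correction applied.
heavy = total_scattering_factor(30.0, -8.0, 4.0)
print(light, heavy)
```

Skipping sub-threshold atoms keeps the imaginary part exactly zero for most of a typical protein model, so only a small subset of atoms needs the extra complex term.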
- property fft
Access the SfFFT submodule.
- forward(hkl, apply_anomalous=True)[source]
Compute structure factors for given hkl.
This is called by the mixin's __call__ which handles caching, backward-hook registration, and auto-invalidation.
- Parameters:
hkl (torch.Tensor) – Miller indices with shape (n_reflections, 3).
apply_anomalous (bool, optional) – If True and wavelength is set, apply anomalous scattering corrections (f’ and f’’) for heavy atoms. Default is True.
- Returns:
Calculated complex structure factors with shape (n_reflections,).
- Return type:
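ModelFT obtains structure factors through an FFT of the density map. For small test cases the same quantity can be cross-checked by direct summation, F(hkl) = Σⱼ occⱼ·fⱼ·exp(2πi·(h·xⱼ + k·yⱼ + l·zⱼ)). The sketch below assumes space group P1, fractional coordinates, and constant scattering factors with no B-factor term; the function name is hypothetical.

```python
import cmath


def structure_factor(hkl, frac_xyz, f, occ):
    """Direct-summation structure factor for one reflection (P1, no B-factors)."""
    h, k, l = hkl
    total = 0j
    for (x, y, z), fj, oj in zip(frac_xyz, f, occ):
        phase = 2j * cmath.pi * (h * x + k * y + l * z)
        total += oj * fj * cmath.exp(phase)
    return total


# Two identical atoms separated by half a cell along a (illustrative values).
xyz = [(0.1, 0.2, 0.3), (0.6, 0.2, 0.3)]
f = [6.0, 6.0]
occ = [1.0, 1.0]
print(abs(structure_factor((1, 0, 0), xyz, f, occ)))
print(abs(structure_factor((2, 0, 0), xyz, f, occ)))
```

The half-cell translation makes odd-h reflections cancel and even-h reflections add, the familiar systematic-absence pattern, which makes this a convenient unit test for any structure factor routine.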
- copy(detach=True)[source]
Create a deep copy of the ModelFT.
Creates a complete independent copy including all Model base class data, FFT submodule state (gridsize, real_space_grid, voxel_size, map_symmetry), ITC92 parametrization, and scalar attributes. Cache is reset to empty.
- Parameters:
detach (bool, optional) – If True, the copy’s parameters will be detached from the computation graph (default: True).
- Returns:
A new ModelFT instance with copied data.
- Return type:
Examples
model = ModelFT().load_pdb('structure.pdb')
model_copy = model.copy()
# model_copy is independent, changes won't affect model
- state_dict(destination=None, prefix='', keep_vars=False)[source]
Return a dictionary containing the complete state of the ModelFT.
Extends parent Model.state_dict() with FT-specific parameters including max_res, radius_angstrom. Grid state is handled by the FFT submodule.
- classmethod create_from_state_dict(state_dict, device=device(type='cpu'), verbose=1, dtype_float=torch.float32)[source]
Create a fully initialized ModelFT from a state dictionary.
This is the recommended way to restore a ModelFT from a saved state. Creates an instance with properly initialized submodules, then loads the state.
- Parameters:
state_dict (dict) – State dictionary from torch.save(model.state_dict(), …).
device (torch.device, optional) – Device to place tensors on. Defaults to the configured device.current.
verbose (int, optional) – Verbosity level. Default is 1.
dtype_float (torch.dtype, optional) – Float dtype for tensors. Default is dtypes.float.
- Returns:
Fully initialized instance with restored state.
- Return type:
- class torchref.Refinement(data_file=None, pdb=None, cif=None, verbose=1, max_res=None, device=None, nbins=10, manual_weights=None, component_weights=None, column_names=None)[source]
Bases: DeviceMixin, DebugMixin, Module
Refinement class to handle the overall crystallographic refinement process.
Supports two initialization patterns:
Empty initialization (for state_dict loading):
refinement = Refinement()  # Creates empty shell with submodules
refinement.load_state_dict(torch.load('refinement.pt'))
Full initialization with file paths:
refinement = Refinement(data_file='data.mtz', pdb='model.pdb')
- Parameters:
data_file (str, optional) – Path to MTZ or CIF file containing reflection data.
pdb (str, optional) – Path to PDB or CIF file containing initial model.
cif (str, optional) – Path to CIF file for restraints.
verbose (int, optional) – Verbosity level. Default is 1.
max_res (float, optional) – Maximum resolution for reflections.
device (torch.device, optional) – Computation device. Defaults to the configured device.current.
weighter (LossWeightingModule, optional) – Loss weighting module. Creates default if None.
nbins (int, optional) – Number of resolution bins. Default is 10.
- device
Computation device.
- Type:
- reflection_data
Reflection data container.
- Type:
- weighter
Loss weighting module.
- Type:
LossWeightingModule
- __init__(data_file=None, pdb=None, cif=None, verbose=1, max_res=None, device=None, nbins=10, manual_weights=None, component_weights=None, column_names=None)[source]
Initialize Refinement.
If data_file and pdb are provided, fully initializes the refinement. If not provided (empty init), creates a shell with empty submodules ready for load_state_dict().
- Parameters:
data_file (str, optional) – Path to MTZ or CIF file containing reflection data.
pdb (str, optional) – Path to PDB or CIF file containing initial model.
cif (str, optional) – Path to CIF file for restraints.
verbose (int, optional) – Verbosity level. Default is 1.
max_res (float, optional) – Maximum resolution for reflections.
device (torch.device, optional) – Computation device. Defaults to the configured device.current.
weighter (LossWeightingModule, optional) – Loss weighting module. Creates default if None.
nbins (int, optional) – Number of resolution bins. Default is 10.
- set_xray_target_mode(mode)[source]
Change the X-ray target mode.
- Parameters:
mode (str) – X-ray target mode. Options are ‘gaussian’, ‘ls’, or ‘ml’.
- property data
Expose reflection_data as ‘data’ for weighting module compatibility.
- Returns:
The reflection data container.
- Return type:
- property loss_state: LossState
Get or create the persistent LossState.
The LossState is created once and reused across refinement cycles. Targets are registered once; weights are updated each cycle.
- Returns:
The persistent loss state with targets registered.
- Return type:
- property logger: Logger
Get or create the Logger for this refinement.
- Returns:
Logger instance linked to the persistent LossState.
- Return type:
- reset_loss_state()[source]
Reset the persistent LossState and Logger.
Call this if targets need to be re-registered (e.g., after changing target modes or reinitializing targets).
- parameters(recurse=True)[source]
Return unique parameters from this module and all submodules.
Uses the default Module.parameters() to gather parameters, then removes duplicates while preserving order to avoid passing the same tensor multiple times to the optimizer.
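Order-preserving de-duplication by object identity, as described above, can be sketched as follows. Plain lists stand in for tensors here, and the helper name is hypothetical; the point is that identity (not value equality) decides what counts as a duplicate.

```python
def unique_by_identity(items):
    """Drop repeated objects (same id) while keeping first-seen order."""
    seen = set()
    unique = []
    for item in items:
        if id(item) not in seen:
            seen.add(id(item))
            unique.append(item)
    return unique


shared = [1.0]   # stands in for a parameter registered in two submodules
other = [2.0]
twin = [1.0]     # equal values, but a different object, so it is kept

result = unique_by_identity([shared, other, shared, twin])
print(len(result))
```

Comparing by identity matters because two distinct parameters can hold equal values, while one shared parameter must reach the optimizer only once.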
- adp_loss()[source]
Compute total ADP loss using TotalADPTarget.
This combines:
Bond-based similarity (SIMU-like)
Spread control (tighter than KL)
Bounds penalty
- Returns:
Total ADP loss value.
- Return type:
- nll_xray()[source]
Compute X-ray negative log-likelihood for work and test sets.
- Returns:
Tuple of (work_nll, test_nll) tensors.
- Return type:
- xray_loss_work()[source]
Compute X-ray loss on work set using instantiated target.
- Returns:
X-ray loss on work set.
- Return type:
- xray_loss_test()[source]
Compute X-ray loss on test set using instantiated target.
- Returns:
X-ray loss on test set.
- Return type:
- bond_loss()[source]
Compute bond length NLL via geometry_target.
- Returns:
Bond length NLL loss.
- Return type:
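The exact form of the bond term inside geometry_target is not spelled out here. A common choice, assumed for this sketch, is a Gaussian negative log-likelihood around the ideal bond length d₀ with standard deviation σ: 0.5·((d − d₀)/σ)² + log(σ·√(2π)) per bond. The function name and the numbers are illustrative.

```python
import math


def bond_nll(observed, ideal, sigma):
    """Sum of Gaussian negative log-likelihoods over bonds (assumed functional form)."""
    total = 0.0
    for d, d0, s in zip(observed, ideal, sigma):
        z = (d - d0) / s
        total += 0.5 * z * z + math.log(s * math.sqrt(2.0 * math.pi))
    return total


ideal = [1.53, 1.33]   # illustrative target lengths in Angstroms
sigma = [0.02, 0.02]   # illustrative standard deviations in Angstroms
at_ideal = bond_nll(ideal, ideal, sigma)
one_sigma_off = bond_nll([1.55, 1.33], ideal, sigma)
print(one_sigma_off - at_ideal)
```

With this form, a one-sigma deviation on a single bond raises the loss by 0.5, which gives the geometry weight a direct statistical interpretation.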
- torsion_loss()[source]
Compute torsion angle NLL via geometry_target.
- Returns:
Torsion angle NLL loss.
- Return type:
- geometry_loss()[source]
Compute total geometry NLL using TotalGeometryTarget.
- Returns:
Total geometry NLL loss.
- Return type:
- loss()[source]
Compute total loss using LossState pipeline.
Creates a LossState, populates meta, caches losses, updates weights, and returns the aggregated weighted loss.
- Returns:
Total weighted loss.
- Return type:
- setup_component_weighting()[source]
Set up component weighting with ResolutionWeighting + OverfittingWeighting.
- populate_state_meta(state)[source]
Populate LossState.meta with all model-level data.
Called once per macro cycle before weighting schemes are applied. This is the single location where refinement data is extracted into state.
- update_weights(state, multiply=False)[source]
Compute weights from component_weighting and update state. Weights are clipped to [0.01, 100.0] to avoid extreme values.
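The clipping step can be pictured with a small sketch. It only shows the [0.01, 100.0] clamp; how the raw weights are computed and what the multiply flag does are not covered. The helper name is hypothetical, and the 'xray' key is an assumed label alongside the hierarchical names used by LossState.

```python
def clip_weights(weights, lo=0.01, hi=100.0):
    """Clamp every weight into [lo, hi], leaving in-range values unchanged."""
    return {name: min(max(w, lo), hi) for name, w in weights.items()}


raw = {"xray": 1.0, "geometry/bond": 2500.0, "adp/simu": 1e-5}
clipped = clip_weights(raw)
print(clipped)
```

Clamping keeps a single runaway weight from dominating, or silencing, the aggregated loss within a macro cycle.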
- create_loss_state()[source]
Create a configured LossState for optimization.
Deprecated: Use the loss_state property instead for the persistent state. This method is kept for backwards compatibility.
Sets up a LossState with all targets registered as callables with hierarchical naming (e.g., ‘geometry/bond’, ‘adp/simu’). Weights are applied from component_weighting.
- Usage:
import torch
from torchref.utils import validate_loss

state = refinement.create_loss_state()
params = list(refinement.parameters())
# Assumed here for completeness: any closure-based torch optimizer works.
optimizer = torch.optim.LBFGS(params)

# Log initial state
state.aggregate(log_values=True)

# In an LBFGS closure, wrap with validate_loss so non-finite
# losses warn + reject the step instead of poisoning the run.
def closure():
    optimizer.zero_grad()
    loss = state.aggregate()
    loss.backward()
    ok = validate_loss(
        loss, state=state, parameters=params,
        context="my_refinement", raise_on_fail=False,
    )
    if not ok:
        # Clear gradients and report an infinite loss so the step is rejected.
        for p in params:
            if p.grad is not None:
                p.grad.zero_()
        return torch.full_like(loss.detach(), float("inf"))
    return loss

optimizer.step(closure)

# Log final state
state.new_entry()
state.aggregate(log_values=True)
- Returns:
Configured LossState with targets and weights.
- Return type:
- complete_loss_state()[source]
Update and return the persistent LossState.
Updates the persistent LossState with current meta, target info, cached losses, and weights. The state is reused across cycles.
The cached active-parameter leaf set is not refreshed here. Stale leaves are not a correctness hazard: a leaf that's in the set but whose Parameter object was replaced externally (e.g. by Model.freeze) just gets ignored by _freeze_graph_extras, which costs a marginal amount of wasted backward work but never produces wrong answers. If you do call Model.freeze / Model.unfreeze between LossState creation and a refinement step, call state.refresh_loss_leaves() explicitly.
- Returns:
Complete LossState with targets, meta, losses, and weights.
- Return type:
- restraints_loss()[source]
Compute total geometry restraints loss.
- Returns:
Total geometry restraints loss.
- Return type:
- collect_metrics()[source]
Collect all metrics from component_weighting.stats().
This is the standard method for gathering refinement metrics for logging. Uses the centralized component_weighting module for all statistics. Returns full unfiltered stats - filtering is done at display time.
- Returns:
Dictionary with all metrics (unfiltered, with StatEntry objects).
- Return type:
- add_target_info_to_state(state)[source]
Add target information from geometry and ADP targets to LossState.meta.
Deprecated: This method is no longer needed. Use complete_loss_state() instead, which handles all state setup in one call.
- collect_deposition_metadata(metadata=None)[source]
Collect refinement statistics into a RefinementMetadata object.
Reuses existing statistics from collect_metrics(), get_rfactor(), and reflection data attributes.
- Parameters:
metadata (RefinementMetadata, optional) – Existing metadata to merge with (e.g. from input file pass-through). Refinement statistics take precedence over pass-through values.
- Returns:
Metadata populated with final refinement statistics.
- Return type:
- write_out_pdb(out_pdb_path='refined_output.pdb', metadata=None)[source]
Write refined PDB with optional metadata header.
- Parameters:
out_pdb_path (str) – Output PDB file path.
metadata (RefinementMetadata, optional) – Metadata for PDB header. If None, auto-collected from refinement.
- write_out_cif(out_cif_path='refined_output.cif', metadata=None)[source]
Write refined coordinates as mmCIF with metadata.
- Parameters:
out_cif_path (str) – Output mmCIF file path.
metadata (RefinementMetadata, optional) – Metadata for mmCIF categories. If None, auto-collected from refinement.
- save_state(path)[source]
Save the complete state of the refinement to a file.
- Parameters:
path (str) – Path to save the state dictionary to.
- classmethod create_from_state_dict(state_dict, device=device(type='cpu'), verbose=1)[source]
Create a fully initialized Refinement from a state dictionary.
This is the recommended way to restore a Refinement from a saved state. It creates the proper submodules using their respective create_from_state_dict methods, then calls PyTorch’s default load_state_dict.
- Parameters:
state_dict (dict) – State dictionary from torch.save(refinement.state_dict(), …) or from loading a checkpoint file.
device (torch.device, optional) – Device to place tensors on. Defaults to the configured device.current.
verbose (int, optional) – Verbosity level. Default is 1.
- Returns:
Fully initialized instance with restored state.
- Return type:
Examples
Save and load refinement state:
# Save
torch.save(refinement.state_dict(), 'refinement.pt')
# Load
state = torch.load('refinement.pt')
refinement = Refinement.create_from_state_dict(state)
# Continue refinement
rwork, rfree = refinement.get_rfactor()
print(f"Restored at R-work={rwork:.4f}, R-free={rfree:.4f}")
- class torchref.LBFGSRefinement(*args, target_mode='bhattacharyya', sigma_m_scale=1.0, use_lossstate_scaler=True, **kwargs)[source]
Bases: Refinement
LBFGS-based refinement subclass using the L-BFGS optimizer for fast convergence.
L-BFGS (Limited-memory BFGS) is a quasi-Newton optimization method that approximates the inverse Hessian from a limited history of recent gradient and parameter updates, which on smooth problems typically converges in far fewer iterations than first-order methods.
Key advantages:
Converges in 1-2 macro cycles (vs 5+ for Adam)
Better final R-factors
More stable convergence
Automatically handles step size via line search
- Parameters:
target_mode (str, optional) – X-ray target mode ('gaussian', 'ls', 'ml', or 'bhattacharyya'). Default is 'bhattacharyya', matching the constructor signature.
*args – Passed to parent Refinement class.
**kwargs – Passed to parent Refinement class.
Examples
Basic usage:
from torchref.refinement import LBFGSRefinement
refinement = LBFGSRefinement(
    data_file='data.mtz',
    pdb='model.pdb',
    target_mode='ml'
)
refinement.refine(macro_cycles=2)
- LBFGS_DEFAULTS = {'history_size': 100, 'line_search_fn': 'strong_wolfe', 'lr': 1.0, 'max_iter': 20}
- __init__(*args, target_mode='bhattacharyya', sigma_m_scale=1.0, use_lossstate_scaler=True, **kwargs)[source]
Initialize LBFGS refinement.
- Parameters:
target_mode (str, optional) – X-ray target mode (‘gaussian’, ‘ls’, ‘ml’, ‘bhattacharyya’). Default is ‘bhattacharyya’.
sigma_m_scale (float, optional) – Global multiplier for σ_m in the Bhattacharyya target only. Ignored for other target modes. Default 1.0.
use_lossstate_scaler (bool, optional) – If True (default), refine_scaler() uses the full LossState with the body's x-ray target, so scaler and body steps share one consistent loss. If False, falls back to Scaler.refine_lbfgs, which minimises a standalone nll_xray and can pull scales in a different direction than the body optimization.
*args – Passed to parent Refinement class.
**kwargs – Passed to parent Refinement class.
- xray_loss()[source]
Compute X-ray loss using the instantiated target.
- Returns:
X-ray loss on work set.
- Return type:
- refine_scaler()[source]
Refine scaler parameters against the full refinement loss.
Builds the body LossState via complete_loss_state(), constructs a fresh LBFGS optimizer over list(self.scaler.parameters()), and delegates to LossState.step(). Because state.step disables requires_grad on every loss leaf outside the optimizer's intent set, xyz / adp / u / occupancy are pinned for the duration; only scaler parameters move.
The critical property is that the x-ray target used here is the same one the body refine_xyz() and refine_adp() see. The legacy Scaler.refine_lbfgs() minimises a standalone nll_xray + U^2 penalty, which can pull scales in a different direction than a bhattacharyya or ml body loss and leaves the body to chase a scaler that disagrees with its own objective.
When use_lossstate_scaler is False, fall back to the legacy Scaler.refine_lbfgs() path.
- Returns:
LossState with history if use_lossstate_scaler is True, otherwise the metrics dict from Scaler.refine_lbfgs().
- Return type:
- refine_xyz()[source]
Refine Cartesian coordinates jointly with scaler parameters.
Scaler parameters (log_scale, U, solvent terms) are included in the same LBFGS call as xyz. The joint curvature lets xyz steps see the scaler as an anchor (residuals the scaler can absorb do not have to be chased by atomic motion), and the adp/scaler_U and adp/scaler_log_scale priors apply at every step, so nothing in the scaler drifts between refine_xyz and refine_adp calls.
- Returns:
State with history containing before/after loss values.
- Return type:
- refine_adp()[source]
Refine ADP / U / occupancy jointly with scaler parameters.
Scaler parameters (log_scale, U, solvent terms) are included in the same LBFGS call as the ADP-block body parameters so the joint curvature can slide along the atomic-B / scaler-U degeneracy ridge together with the adp/scaler_U regularizer. XYZ is left frozen.
- Returns:
State with history containing before/after loss values.
- Return type:
- refine_joint()[source]
Joint LBFGS over every refinable parameter in one step.
Optimizes xyz, adp, u, occupancy, and every scaler parameter (log_scale, anisotropic U, solvent terms) in a single LBFGS call. The joint curvature couples all of them through the same x-ray target and through the adp/scaler_U / adp/scaler_log_scale priors. Unlike alternating refine_xyz → refine_adp, there's no "frozen partner" in either half that could lock the step into a locally bad direction.
- Returns:
State with history containing before/after loss values.
- Return type:
- run_training_trajectory(policy_weighting, n_steps=10, pdb_id='', structure_path='', sf_path='', seed=None, policy_version=None)[source]
Run a training trajectory with policy-guided refinement.
This method runs a sequence of refinement steps using a policy to select component weights. It records state-action-reward tuples for training the policy with AWR or similar algorithms.
- Parameters:
policy_weighting (PolicyComponentWeighting) – Policy weighting scheme (should be in training mode with sampling).
n_steps (int, optional) – Number of refinement steps in the trajectory (default: 10).
pdb_id (str, optional) – PDB identifier for recording.
structure_path (str, optional) – Path to structure file for recording.
sf_path (str, optional) – Path to structure factors file for recording.
seed (int, optional) – Random seed for reproducibility.
policy_version (str, optional) – Version identifier of the policy being used.
- Returns:
Complete trajectory with state-action-reward tuples.
- Return type:
- run_training_trajectory_joint(policy_weighting, n_steps=10, pdb_id='', structure_path='', sf_path='', seed=None, policy_version=None)[source]
Run a training trajectory with joint XYZ+ADP refinement.
Similar to run_training_trajectory() but refines xyz, adp, u, and occupancy together in each step. The LBFGS curvature history is reset at the start of each policy step because the weight updates invalidate any prior Hessian approximation.
- Parameters:
policy_weighting (PolicyComponentWeighting) – Policy weighting scheme (should be in training mode).
n_steps (int, optional) – Number of refinement steps (default: 10).
pdb_id (str, optional) – Identifiers for trajectory recording.
structure_path (str, optional) – Identifiers for trajectory recording.
sf_path (str, optional) – Identifiers for trajectory recording.
seed (int, optional) – Random seed for reproducibility.
policy_version (str, optional) – Policy version identifier.
- Returns:
Complete trajectory with state-action-reward tuples.
- Return type:
- class torchref.Scaler(model=None, data=None, nbins=20, verbose=1, device=None)[source]
Bases: ScalerBase
Full-featured scaler with Model integration.
Extends ScalerBase by maintaining a reference to a Model object and providing convenience methods that automatically compute F_calc when not provided.
Supports two initialization patterns:
Empty initialization (for state_dict loading):
scaler = Scaler()  # Creates empty shell
scaler.load_state_dict(torch.load('scaler.pt'))
Full initialization with model and data:
scaler = Scaler(model, reflection_data, nbins=20)
scaler.initialize()
- Parameters:
model (Model, optional) – Model object for structure factor calculation.
data (ReflectionData, optional) – ReflectionData object with observed data.
nbins (int, default 20) – Number of resolution bins.
verbose (int, default 1) – Verbosity level.
device (torch.device, default: configured device.current) – Computation device.
- device
Current computation device.
- Type:
- __init__(model=None, data=None, nbins=20, verbose=1, device=None)[source]
Initialize Scaler.
If model and data are provided, fully initializes the scaler. If not provided (empty init), creates a shell ready for load_state_dict().
- Parameters:
model (Model, optional) – Model object for structure factor calculation.
data (ReflectionData, optional) – ReflectionData object with observed data.
nbins (int, default 20) – Number of resolution bins.
verbose (int, default 1) – Verbosity level.
device (torch.device, optional) – Computation device. If None, derived from model then data (model wins on mismatch); otherwise forces both onto the explicit device. See torchref.utils.resolve_device().
- property model
Access the model object (not a registered submodule).
- set_model_and_data(model, data)[source]
Set model and data references after empty initialization.
This is useful when loading from state_dict and then needing to reconnect to model/data objects.
- Parameters:
model (Model) – Model object for structure factor calculation.
data (ReflectionData) – ReflectionData object with observed data.
- initialize(fcalc=None)[source]
Initialize scaling parameters.
If fcalc is not provided, computes it from the internal model.
- Parameters:
fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.
- compute_fcalc()[source]
Compute F_calc from internal model.
- Returns:
Calculated structure factors.
- Return type:
- Raises:
RuntimeError – If no model is set.
- calc_initial_scale(fcalc=None)[source]
Calculate initial scale factors.
If fcalc is not provided, computes it from the internal model.
- Parameters:
fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.
- Returns:
The log scale parameter for each resolution bin.
- Return type:
torch.nn.Parameter
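How calc_initial_scale() arrives at its per-bin values is not described here. One plausible starting point, assumed for this sketch, is the least-squares scale k = Σ(Fobs·Fcalc) / Σ(Fcalc²) in each resolution bin, stored as log k. The function name and the data are illustrative.

```python
import math


def log_scale_per_bin(f_obs, f_calc, bin_indices, n_bins):
    """Least-squares log scale per bin; assumes every bin holds at least one reflection."""
    num = [0.0] * n_bins
    den = [0.0] * n_bins
    for fo, fc, b in zip(f_obs, f_calc, bin_indices):
        num[b] += fo * fc
        den[b] += fc * fc
    return [math.log(n / d) for n, d in zip(num, den)]


f_calc = [10.0, 20.0, 5.0, 8.0]
f_obs = [20.0, 40.0, 15.0, 24.0]   # bin 0 is scaled by 2, bin 1 by 3
bins = [0, 0, 1, 1]
print([round(math.exp(v), 3) for v in log_scale_per_bin(f_obs, f_calc, bins, 2)])
```

Working in log space keeps the scale positive under unconstrained optimization, which is a likely reason for storing a log scale parameter.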
- fit_anisotropy(fcalc=None, nsteps=100)[source]
Fit anisotropic correction.
If fcalc is not provided, computes it from the internal model.
- Parameters:
fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.
nsteps (int, default 100) – Number of optimization steps.
- setup_solvent()[source]
Setup solvent model using internal model.
Creates a SolventModel using the internal model reference.
- fit_all_scales(fcalc=None)[source]
Fit all scale parameters.
If fcalc is not provided, computes it from the internal model.
- Parameters:
fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.
- screen_solvent_params(fcalc=None, steps=15, use_low_res_weighting=True, low_res_cutoff=5.0, fit_on_low_res_only=True, low_res_limit=3.5)[source]
Screen solvent parameters using grid search.
If fcalc is not provided, computes it from the internal model.
- Parameters:
fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.
steps (int, default 15) – Number of grid points for each parameter.
use_low_res_weighting (bool, default True) – If True, weight low-resolution reflections more heavily.
low_res_cutoff (float, default 5.0) – Resolution cutoff for weighting in Angstroms.
fit_on_low_res_only (bool, default True) – If True, fit using only low-resolution reflections.
low_res_limit (float, default 3.5) – Resolution limit for low-res only fitting in Angstroms.
- refine_lbfgs(fcalc=None, nsteps=3, lr=1.0, max_iter=200, history_size=10, verbose=True)[source]
Refine scale parameters using LBFGS optimizer.
If fcalc is not provided, computes it from the internal model.
- Parameters:
fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.
nsteps (int, default 3) – Number of LBFGS steps.
lr (float, default 1.0) – Learning rate (typically 1.0 for LBFGS).
max_iter (int, default 200) – Maximum iterations per LBFGS optimization step.
history_size (int, default 10) – Number of previous gradients to store for Hessian approximation.
verbose (bool, default True) – Print progress information.
- Returns:
Dictionary with refinement metrics.
- Return type:
- rfactor(fcalc=None)[source]
Calculate R-factors.
If fcalc is not provided, computes it from the internal model.
- Parameters:
fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.
- Returns:
R-work and R-free values.
- Return type:
- bin_wise_rfactor(fcalc=None)[source]
Calculate bin-wise R-factors.
If fcalc is not provided, computes it from the internal model.
- Parameters:
fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.
- Returns:
mean_res_per_bin (torch.Tensor) – Mean resolution per bin.
rwork_per_bin (torch.Tensor) – R-work per bin.
rfree_per_bin (torch.Tensor) – R-free per bin.
- get_binwise_mean_intensity(fcalc=None)[source]
Get bin-wise mean intensities.
If fcalc is not provided, computes it from the internal model.
- Parameters:
fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.
- Returns:
Mean observed intensity, mean calculated intensity, and mean resolution per bin.
- Return type:
- state_dict(destination=None, prefix='', keep_vars=False)[source]
Return a dictionary containing the complete state of the Scaler.
This includes:
All registered buffers and parameters (via parent class)
Scaler-specific metadata (nbins, etc.)
Solvent model state (if initialized)
Note: Model and data references are NOT saved (managed separately).
- class torchref.ScalerBase(data=None, nbins=20, verbose=1, device=None)[source]
Bases:
DeviceMixin, DebugMixin, Module
Base scaler class for crystallographic scaling without model dependency.
All methods that require calculated structure factors (F_calc) take them as input arguments. This allows the scaler to be used independently of any specific model implementation.
Supports two initialization patterns:
Empty initialization (for state_dict loading):
scaler = ScalerBase()  # Creates empty shell
scaler.load_state_dict(torch.load('scaler.pt'))
Full initialization with data:
scaler = ScalerBase(data=reflection_data, nbins=20)
scaler.initialize(fcalc)
- Parameters:
data (ReflectionData, optional) – ReflectionData object with observed data.
nbins (int, default 20) – Number of resolution bins.
verbose (int, default 1) – Verbosity level.
device (torch.device, default: configured device.current) – Computation device.
- device
Current computation device.
- Type:
- __init__(data=None, nbins=20, verbose=1, device=None)[source]
Initialize ScalerBase.
If data is provided, fully initializes the scaler. If not provided (empty init), creates a shell ready for load_state_dict().
- Parameters:
data (ReflectionData, optional) – ReflectionData object with observed data.
nbins (int, default 20) – Number of resolution bins.
verbose (int, default 1) – Verbosity level.
device (torch.device, optional) – Computation device. If None, derived from data (if given) or the configured default via torchref.utils.resolve_device(). An explicit value forces data onto that device.
- set_data(data)[source]
Set data reference after empty initialization.
This is useful when loading from state_dict and then needing to reconnect to a data object.
- Parameters:
data (ReflectionData) – ReflectionData object with observed data.
- initialize(fcalc)[source]
Initialize scaling parameters using provided F_calc.
- Parameters:
fcalc (torch.Tensor) – Calculated structure factors (complex).
- property hkl
Get HKL indices from data.
- calc_initial_scale(fcalc)[source]
Calculate the initial scale factor based on the ratio of observed to calculated structure factors.
Excludes reflections with negative intensities to avoid bias from French-Wilson conversion.
- Parameters:
fcalc (torch.Tensor) – Calculated structure factors (complex).
- Returns:
The log scale parameter for each resolution bin.
- Return type:
torch.nn.Parameter
- anisotropy_correction()[source]
Compute anisotropic correction factors.
- Returns:
Anisotropic correction factors for each reflection.
- Return type:
- fit_anisotropy(fcalc, nsteps=100)[source]
Fit anisotropic correction using provided F_calc.
- Parameters:
fcalc (torch.Tensor) – Calculated structure factors (complex).
nsteps (int, default 100) – Number of optimization steps.
- set_solvent_model(solvent_model)[source]
Set a pre-configured SolventModel for solvent contribution.
The SolventModel must be initialized externally (requires a Model object).
- Parameters:
solvent_model (SolventModel) – Pre-configured solvent model that can compute solvent structure factors.
- setup_binwise_solvent_scale()[source]
Setup bin-wise solvent scaling (Phenix-style kmask per bin).
This allows finer control over solvent contribution per resolution bin, which is more flexible than a single global B_sol parameter.
- fit_all_scales(fcalc)[source]
Fit all scale parameters using provided F_calc.
- Parameters:
fcalc (torch.Tensor) – Calculated structure factors (complex).
- fit_simple(fobs, fcalc)[source]
Fit a single global scale factor analytically (least-squares).
This is the simple scaling approach:
k = sum(|F_obs| * |F_calc|) / sum(|F_calc|²)
Useful for rigid body refinement where only an overall scale is needed.
- Parameters:
fobs (torch.Tensor) – Observed structure factor amplitudes.
fcalc (torch.Tensor) – Calculated structure factors (complex).
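The closed-form scale above can be checked with a few lines of plain Python. This is a minimal sketch of the formula only, not the torchref implementation; the helper name `fit_global_scale` and the toy data are hypothetical.

```python
def fit_global_scale(f_obs, f_calc):
    """Least-squares scale k minimising sum((|F_obs| - k * |F_calc|)**2)."""
    amp_calc = [abs(f) for f in f_calc]  # abs() of a complex number is its modulus
    numerator = sum(o * c for o, c in zip(f_obs, amp_calc))
    denominator = sum(c * c for c in amp_calc)
    if denominator == 0.0:
        raise ValueError("all calculated amplitudes are zero")
    return numerator / denominator


# Toy data: the observed amplitudes are exactly twice |F_calc| = 5, 2, 1.
f_calc = [3 + 4j, 0 - 2j, 1 + 0j]
f_obs = [10.0, 4.0, 2.0]
k = fit_global_scale(f_obs, f_calc)
print(k)
```

Setting the derivative of the squared residual with respect to k to zero gives exactly this ratio, which is why no iterative optimizer is needed for a single overall scale.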
- get_scale()[source]
Get the current overall scale factor value.
Returns the mean scale factor across all bins.
- Returns:
Current scale factor (not log).
- Return type:
- rfactor(fcalc)[source]
Calculate the R-factor between observed and calculated structure factors.
- Parameters:
fcalc (torch.Tensor) – Calculated structure factors (complex).
- Returns:
R-work and R-free values.
- Return type:
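For orientation, the conventional crystallographic R-factor is R = Σ| |F_obs| − |F_calc| | / Σ|F_obs|, evaluated separately on the work and free sets. The sketch below assumes the calculated structure factors are already on the scale of the observations and that free-set membership is given as a list of booleans; the function names are hypothetical and the scaling applied inside `rfactor()` is not reproduced.

```python
def r_factor(f_obs, f_calc):
    """R = sum(| |F_obs| - |F_calc| |) / sum(|F_obs|)."""
    numerator = sum(abs(o - abs(c)) for o, c in zip(f_obs, f_calc))
    return numerator / sum(f_obs)


def r_work_free(f_obs, f_calc, free_flags):
    """Split reflections by free flag and return (R-work, R-free)."""
    work = [(o, c) for o, c, free in zip(f_obs, f_calc, free_flags) if not free]
    test = [(o, c) for o, c, free in zip(f_obs, f_calc, free_flags) if free]
    return r_factor(*zip(*work)), r_factor(*zip(*test))


f_obs = [10.0, 20.0, 30.0, 40.0]
f_calc = [9 + 0j, 22 + 0j, 0 + 30j, 36 + 0j]  # already scaled (assumption)
free_flags = [False, False, False, True]
r_work, r_free = r_work_free(f_obs, f_calc, free_flags)
print(r_work, r_free)
```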
- bin_wise_rfactor(fcalc)[source]
Calculate the bin-wise R-factor between observed and calculated structure factors.
- Parameters:
fcalc (torch.Tensor) – Calculated structure factors (complex).
- Returns:
mean_res_per_bin (torch.Tensor) – Mean resolution per bin.
rwork_per_bin (torch.Tensor) – R-work per bin.
rfree_per_bin (torch.Tensor) – R-free per bin.
- bin_wise_bfactor_correction()[source]
Compute bin-wise B-factor correction factors.
- Returns:
B-factor correction factors for each reflection.
- Return type:
- get_binwise_mean_intensity(fcalc)[source]
Get bin-wise mean intensities for observed and calculated structure factors.
- Parameters:
fcalc (torch.Tensor) – Calculated structure factors (complex).
- Returns:
Mean observed intensity, mean calculated intensity, and mean resolution per bin.
- Return type:
- screen_solvent_params(fcalc, steps=15, use_low_res_weighting=True, low_res_cutoff=5.0, fit_on_low_res_only=True, low_res_limit=3.5)[source]
Screen solvent parameters (k_sol, B_sol) using grid search.
The bulk solvent contributes primarily at low resolution. Fitting on low-resolution reflections only (fit_on_low_res_only=True) prevents high-resolution reflections from dominating the optimization and pushing B_sol too low.
- Parameters:
fcalc (torch.Tensor) – Calculated structure factors (complex).
steps (int, default 15) – Number of grid points for each parameter.
use_low_res_weighting (bool, default True) – If True, weight low-resolution reflections more heavily since solvent primarily contributes at low resolution.
low_res_cutoff (float, default 5.0) – Resolution cutoff for weighting in Angstroms.
fit_on_low_res_only (bool, default True) – If True, fit using only low-resolution reflections.
low_res_limit (float, default 3.5) – Resolution limit for low-res only fitting in Angstroms.
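A grid search restricted to low-resolution reflections can be sketched as follows. This is an illustrative, dependency-free version under stated assumptions: the total model is taken as F_calc + k_sol · exp(−B_sol · s²) · F_mask with s = 1/(2d), the target is a plain unweighted least-squares residual on amplitudes, and "low resolution" means d ≥ low_res_limit. The function names and the synthetic data are hypothetical, and the weighting option is omitted.

```python
import math


def model_amplitude(f_calc, f_mask, d, k_sol, b_sol):
    """|F_calc + k_sol * exp(-B_sol * s^2) * F_mask| with s = sin(theta)/lambda."""
    s = 1.0 / (2.0 * d)
    return abs(f_calc + k_sol * math.exp(-b_sol * s * s) * f_mask)


def screen_solvent(f_obs, f_calc, f_mask, d_spacing, k_grid, b_grid, low_res_limit=3.5):
    """Return the (k_sol, B_sol) grid point with the smallest low-resolution residual."""
    low_res = [i for i, d in enumerate(d_spacing) if d >= low_res_limit]
    best = None
    for k_sol in k_grid:
        for b_sol in b_grid:
            residual = sum(
                (f_obs[i] - model_amplitude(f_calc[i], f_mask[i], d_spacing[i], k_sol, b_sol)) ** 2
                for i in low_res
            )
            if best is None or residual < best[0]:
                best = (residual, k_sol, b_sol)
    return best[1], best[2]


# Synthetic data generated from a known grid point.
d_spacing = [20.0, 10.0, 6.0, 4.0, 2.0]
f_calc = [100 + 0j, 80 + 20j, 60 - 10j, 40 + 5j, 30 + 0j]
f_mask = [-150 + 0j, -90 - 10j, -40 + 5j, -10 + 0j, -1 + 0j]
k_grid = [0.05 * i for i in range(1, 16)]  # 15 grid points
b_grid = [10.0 * i for i in range(1, 16)]  # 15 grid points
k_true, b_true = k_grid[6], b_grid[4]
f_obs = [model_amplitude(fc, fm, d, k_true, b_true)
         for fc, fm, d in zip(f_calc, f_mask, d_spacing)]
f_obs[-1] *= 3.0  # corrupt the 2.0 A reflection; it lies outside the fit range

k_best, b_best = screen_solvent(f_obs, f_calc, f_mask, d_spacing, k_grid, b_grid)
print(k_best, b_best)
```

Because the corrupted high-resolution reflection is excluded from the residual, the search still recovers the generating parameters, which is the behaviour the low-resolution restriction is meant to provide.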
- refine_lbfgs(fcalc, nsteps=3, lr=1.0, max_iter=200, history_size=10, verbose=True)[source]
Refine scale parameters using LBFGS optimizer.
This method optimizes the anisotropic scaling and B-factor parameters that relate calculated structure factors to observed structure factors. Uses the L-BFGS quasi-Newton optimization method for fast convergence.
- Parameters:
fcalc (torch.Tensor) – Calculated structure factors (complex).
nsteps (int, default 3) – Number of LBFGS steps.
lr (float, default 1.0) – Learning rate (typically 1.0 for LBFGS).
max_iter (int, default 200) – Maximum iterations per LBFGS optimization step.
history_size (int, default 10) – Number of previous gradients to store for Hessian approximation.
verbose (bool, default True) – Print progress information.
- Returns:
Dictionary with refinement metrics including steps, xray_work, xray_test, rwork, rfree.
- Return type:
- estimate_sigma_eff(fcalc, max_inflation=2.0)[source]
Estimate per-resolution-shell effective sigmas from current residuals.
Pannu & Read / SIGMAA-style correction: detects miscalibrated experimental sigmas by comparing residual variance to the claimed variance, per resolution bin.
For each resolution bin:
D_bin = < (F_obs - k * |F_calc|)^2 > (using work set)
ratio_bin = sqrt(D_bin / <sigma_F^2>)
ratio_capped = clamp(ratio_bin, 1.0, max_inflation)
sigma_eff = sigma_F * ratio_capped
Why the cap? At the start of refinement the model is poor, so residuals are dominated by model error (which refinement can fix), not noise. Uncapped inflation creates a vicious cycle: bad model -> huge sigma_eff -> weak data gradient -> bad model. Capping at max_inflation (default 2.0, i.e. sigmas can grow at most 2x) prevents runaway inflation while still correcting genuinely underestimated sigmas.
As the model improves, residuals shrink and the ratio drops toward 1, so sigma_eff converges to the raw sigma (good calibration).
Uses the work set only so the test set doesn’t leak into sigma estimation.
- Parameters:
fcalc (torch.Tensor) – Calculated structure factors (complex, unscaled).
max_inflation (float, optional) – Maximum allowed ratio sigma_eff / sigma_raw. Default 2.0.
- Returns:
Per-reflection effective sigmas, shape (N,).
- Return type:
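The per-bin recipe described for estimate_sigma_eff can be written out in plain Python as a rough sketch. It assumes amplitudes, sigmas, bin indices and work flags are given as parallel lists, and that the capped ratio of a bin is applied to every reflection in that bin, including test reflections; the function signature is hypothetical and the toy data are invented.

```python
import math


def estimate_sigma_eff(f_obs, f_calc_amp, sigma, bin_index, is_work,
                       k=1.0, max_inflation=2.0):
    """Inflate sigmas per resolution bin, with the ratio clamped to [1, max_inflation]."""
    sigma_eff = list(sigma)
    for b in set(bin_index):
        members = [i for i, bi in enumerate(bin_index) if bi == b]
        work = [i for i in members if is_work[i]]  # test set never enters the estimate
        if not work:
            continue
        d_bin = sum((f_obs[i] - k * f_calc_amp[i]) ** 2 for i in work) / len(work)
        claimed = sum(sigma[i] ** 2 for i in work) / len(work)
        ratio = math.sqrt(d_bin / claimed)
        ratio = min(max(ratio, 1.0), max_inflation)
        for i in members:
            sigma_eff[i] = sigma[i] * ratio
    return sigma_eff


f_obs = [13.0, 7.0, 10.5, 9.5, 50.0]
f_calc_amp = [10.0, 10.0, 10.0, 10.0, 10.0]
sigma = [1.0, 1.0, 1.0, 1.0, 1.0]
bin_index = [0, 0, 1, 1, 1]
is_work = [True, True, True, True, False]  # the outlier in bin 1 is a test reflection

sigma_eff = estimate_sigma_eff(f_obs, f_calc_amp, sigma, bin_index, is_work)
print(sigma_eff)
```

In bin 0 the raw ratio is 3 and is capped at 2; in bin 1 the raw ratio is 0.5 and is raised to 1, so sigmas are never deflated. The large test-set residual in bin 1 has no influence on the result.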
- forward(fcalc, use_mask=True, f_sol_override=None)[source]
Forward pass for the ScalerBase module.
- Parameters:
fcalc (torch.Tensor) – Calculated structure factors. Expected shape (N,), an additional dimension for batch is possible. N should match the full HKL size.
use_mask (bool, default True) – Deprecated parameter, kept for backward compatibility.
f_sol_override (torch.Tensor, optional) – Pre-computed raw solvent structure factors. When provided, these replace the internally cached _f_sol_raw. The scaler’s k_sol / B_sol / phase damping is still applied. This is used by CollectionScaler to supply mixed (fraction-weighted) solvent contributions.
- Returns:
Scaled structure factors of same shape as input.
- Return type:
- state_dict(destination=None, prefix='', keep_vars=False)[source]
Return a dictionary containing the complete state of the ScalerBase.
This includes:
All registered buffers and parameters (via parent class)
Scaler-specific metadata (nbins, etc.)
Solvent model state (if set)
Note: Data reference is NOT saved (managed separately).
- load_state_dict(state_dict, strict=True)[source]
Load the ScalerBase state from a dictionary.
Note: This assumes data is already set via __init__ or set_data().
- class torchref.SolventModel(model=None, radius=1.1, k_solvent=1.1, b_solvent=50.0, erosion_radius=0.9, transition=None, optimize_phase=True, initial_phase_offset=0.0, verbose=1, float_type=torch.float32, device=device(type='cpu'))[source]
Bases:
DeviceMixin, DebugMixin, Module
SolventModel to compute solvent contribution to structure factors using a Phenix-like approach.
Supports two initialization patterns:
Empty initialization (for state_dict loading):
solvent = SolventModel()  # Creates empty shell
solvent.load_state_dict(torch.load('solvent.pt'))
Full initialization with model:
solvent = SolventModel(model, k_solvent=0.35, b_solvent=46.0)
- device
Device for tensor operations.
- Type:
- float_type
Floating point data type.
- Type:
- log_k_solvent
Log of solvent scattering scale factor.
- Type:
torch.nn.Parameter
- b_solvent
Solvent B-factor.
- Type:
torch.nn.Parameter
- __init__(model=None, radius=1.1, k_solvent=1.1, b_solvent=50.0, erosion_radius=0.9, transition=None, optimize_phase=True, initial_phase_offset=0.0, verbose=1, float_type=torch.float32, device=device(type='cpu'))[source]
Initialize SolventModel.
If model is provided, fully initializes the solvent model. If not provided (empty init), creates a shell ready for load_state_dict().
- Parameters:
model (ModelFT, optional) – The atomic model used for structure factor calculations (optional for empty init).
radius (float, default 1.1) – Probe radius in Angstroms for dilation (water radius).
k_solvent (float, default 1.1) – Solvent scattering scale factor.
b_solvent (float, default 50.0) – Solvent B-factor.
erosion_radius (float, default 0.9) – Radius in Angstroms for erosion step.
transition (float, optional) – Gaussian smoothing sigma for mask edges (default: radius/4 in voxels). Avoids ringing artifacts.
optimize_phase (bool, default True) – Whether to optimize phase offset parameter.
initial_phase_offset (float, default 0.0) – Initial phase offset in radians.
verbose (int, default 1) – Verbosity level.
float_type (torch.dtype, default torch.float32) – Floating point data type.
device (torch.device, default: configured device.current) – Device for tensor operations.
- get_solvent_mask()[source]
Generate solvent mask following Phenix’s three-step process.
- Step 1 (dilation): classify voxels around each atom as protein (inside VdW), boundary (between VdW and VdW+solvent_radius), or bulk solvent (further out). Built in chunks over atoms so peak memory is O(atom_chunk_size × N_box_voxels) rather than O(N_atoms × N_box_voxels). This is critical because for typical macromolecule + grid combinations the dense form is multi-GB.
- Step 2 (symmetry expansion): transform the sparse ASU protein / boundary voxel indices through each symop and scatter into the P1 grid masks.
- Step 3 (erosion): a boundary voxel becomes solvent if any voxel within erosion_radius of it is bulk solvent. Implemented as a single F.conv3d with a precomputed spherical kernel and circular padding; this replaces the previous Python-loop + per-voxel-neighbourhood expansion, which itself ran out of memory on chunks of 10^6 boundary voxels.
- Returns:
Solvent mask (boolean) where True = solvent.
- Return type:
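The dilation and erosion logic (steps 1 and 3) can be illustrated on a one-dimensional periodic grid. This is a deliberately simplified toy, not the library code: it skips the symmetry expansion, uses a Python loop instead of F.conv3d, and the function names, the VdW radius of 1.7 Å and the grid spacing are assumptions made for the example.

```python
def periodic_distance(x, x0, length):
    """Shortest distance between x and x0 on a ring of the given length."""
    d = abs(x - x0) % length
    return min(d, length - d)


def solvent_mask_1d(n, spacing, atom_x, vdw_radius, solvent_radius, erosion_radius):
    PROTEIN, BOUNDARY, BULK = 0, 1, 2
    length = n * spacing

    # Step 1 (dilation): label each grid point by distance to the nearest atom.
    labels = []
    for i in range(n):
        d = min(periodic_distance(i * spacing, x0, length) for x0 in atom_x)
        if d <= vdw_radius:
            labels.append(PROTEIN)
        elif d <= vdw_radius + solvent_radius:
            labels.append(BOUNDARY)
        else:
            labels.append(BULK)

    # Step 3 (erosion): a boundary point becomes solvent if any point within
    # erosion_radius of it (with periodic wrap-around) is bulk solvent.
    reach = int(erosion_radius / spacing)
    mask = []
    for i, label in enumerate(labels):
        if label == BOUNDARY:
            mask.append(any(labels[(i + o) % n] == BULK for o in range(-reach, reach + 1)))
        else:
            mask.append(label == BULK)
    return mask  # True = solvent


mask = solvent_mask_1d(n=40, spacing=0.5, atom_x=[10.0],
                       vdw_radius=1.7, solvent_radius=1.1, erosion_radius=0.9)
print(sum(mask), "of", len(mask), "grid points are solvent")
```

In three dimensions the `any(...)` test over a neighbourhood is what a convolution of the bulk-solvent indicator with a spherical kernel computes, with circular padding playing the role of the modulo wrap-around used here.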
- get_rec_solvent(hkl)[source]
Compute solvent structure factors.
Uses the standard crystallographic approach: compute SFs from the solvent mask. The mask represents regions where bulk solvent scattering occurs.
- Parameters:
hkl (torch.Tensor) – Miller indices.
- Returns:
Complex solvent structure factors.
- Return type:
- forward(hkl, update_fsol=False, F_protein=None)[source]
Compute solvent contribution to structure factors at given HKL.
This method is differentiable with respect to k_solvent, b_solvent, and phase_offset parameters.
The solvent model:
Takes the binary solvent mask
Smooths it with Gaussian filter (σ=1.5 voxels) to create soft edges
Computes structure factors via FFT
Applies B-factor damping: exp(-B * s²) where s = sin(θ)/λ
If optimize_phase=True and F_protein is provided: blends mask phases with protein phases. phase_offset controls the blend: 0 = use mask phases, ±π = use protein phases
Scales by k_solvent
- Parameters:
hkl (torch.Tensor) – Miller indices, shape (N, 3).
update_fsol (bool, default False) – Whether to update solvent structure factors.
F_protein (torch.Tensor, optional) – Protein structure factors, used for phase blending.
- Returns:
Complex solvent structure factors, shape (N,).
- Return type:
- parameters()[source]
Return an iterator over module parameters.
This is typically passed to an optimizer.
- Args:
recurse (bool) – If True, yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
- Yields:
Parameter – module parameter
Example:
>>> for param in model.parameters():
>>>     print(type(param), param.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
- class torchref.Cell(data, *, dtype=torch.float32, device=device(type='cpu'), requires_grad=False)[source]
Bases:
DeviceMixin
Dataclass for crystallographic unit cells with cached derived quantities.
Stores 6 parameters: [a, b, c, alpha, beta, gamma]
- a, b, c: cell lengths in Angstroms
- alpha, beta, gamma: cell angles in degrees
Derived quantities (fractional_matrix, volume, etc.) are computed on first access and cached. The cache is cleared when the cell is moved to a different device or dtype.
Examples
>>> cell = Cell([50, 60, 70, 90, 90, 90])
>>> cell.volume  # Computed and cached
tensor(210000.)
>>> cell_gpu = cell.to('cuda')  # Move to GPU (returns new Cell)
>>> cell_gpu.device.type
'cuda'
- __init__(data, *, dtype=torch.float32, device=device(type='cpu'), requires_grad=False)[source]
Create a new Cell.
- Parameters:
data (array-like) – Unit cell parameters [a, b, c, alpha, beta, gamma]. Can be a list, numpy array, or torch tensor.
dtype (torch.dtype, optional) – Desired data type. Defaults to the configured dtypes.float.
device (torch.device or str, optional) – Desired device. Defaults to the configured device.current.
requires_grad (bool, optional) – Whether to track gradients. Defaults to False.
- Raises:
ValueError – If data does not have exactly 6 elements.
- detach()[source]
Return a new Cell with detached tensor (no gradient tracking).
- Returns:
New Cell with detached data.
- Return type:
- clone()[source]
Return a new Cell with cloned tensor data.
- Returns:
New Cell with cloned data.
- Return type:
- property fractional_matrix: Tensor
Orthogonalization matrix B (fractional -> Cartesian).
Returns the 3x3 matrix B such that: cart = frac @ B.T
- Returns:
Shape (3, 3) orthogonalization matrix.
- Return type:
- property inv_fractional_matrix: Tensor
Fractionalization matrix B^-1 (Cartesian -> fractional).
Returns the 3x3 matrix B^-1 such that: frac = cart @ B^-1.T
- Returns:
Shape (3, 3) fractionalization matrix.
- Return type:
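To make the row-vector convention cart = frac @ B.T concrete, here is a small stand-alone sketch. It assumes the widely used PDB-style orthogonalization (a along x, b in the xy plane), which may or may not be the exact convention used by Cell; the function names are hypothetical and plain lists stand in for tensors.

```python
import math


def orthogonalization_matrix(a, b, c, alpha, beta, gamma):
    """Upper-triangular B for the a-along-x, b-in-xy convention (assumed)."""
    ca, cb, cg = (math.cos(math.radians(x)) for x in (alpha, beta, gamma))
    sg = math.sin(math.radians(gamma))
    volume = a * b * c * math.sqrt(1 - ca * ca - cb * cb - cg * cg + 2 * ca * cb * cg)
    return [
        [a, b * cg, c * cb],
        [0.0, b * sg, c * (ca - cb * cg) / sg],
        [0.0, 0.0, volume / (a * b * sg)],
    ]


def fractional_to_cartesian(frac, B):
    # cart = frac @ B.T, written per component: cart[i] = sum_j B[i][j] * frac[j]
    return [sum(B[i][j] * frac[j] for j in range(3)) for i in range(3)]


def cartesian_to_fractional(cart, B):
    # B is upper triangular, so applying B^-1 reduces to back substitution.
    z = cart[2] / B[2][2]
    y = (cart[1] - B[1][2] * z) / B[1][1]
    x = (cart[0] - B[0][1] * y - B[0][2] * z) / B[0][0]
    return [x, y, z]


B = orthogonalization_matrix(50, 60, 70, 90, 100, 90)  # monoclinic cell
cart = fractional_to_cartesian([0.1, 0.2, 0.3], B)
frac = cartesian_to_fractional(cart, B)
print(cart, frac)
```

The round trip recovers the original fractional coordinates up to floating-point rounding, which is a convenient sanity check for any pair of orthogonalization and fractionalization matrices.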
- property volume: Tensor
Unit cell volume in cubic Angstroms.
- Returns:
Scalar tensor with the cell volume.
- Return type:
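The cell volume follows from the six parameters via the standard triclinic formula V = abc · sqrt(1 − cos²α − cos²β − cos²γ + 2 cosα cosβ cosγ). The sketch below evaluates it in plain Python; the function name `cell_volume` is hypothetical and the code is independent of the Cell class.

```python
import math


def cell_volume(a, b, c, alpha, beta, gamma):
    """Unit cell volume in cubic Angstroms; angles are given in degrees."""
    ca, cb, cg = (math.cos(math.radians(x)) for x in (alpha, beta, gamma))
    return a * b * c * math.sqrt(1 - ca * ca - cb * cb - cg * cg + 2 * ca * cb * cg)


orthorhombic = cell_volume(50, 60, 70, 90, 90, 90)
monoclinic = cell_volume(50, 60, 70, 90, 120, 90)
print(orthorhombic, monoclinic)
```

For the orthorhombic cell this reduces to a·b·c, and for a monoclinic cell to a·b·c·sin β.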
- property reciprocal_basis_matrix: Tensor
Reciprocal basis matrix with [a*, b*, c*] as rows.
- Returns:
Shape (3, 3) matrix where rows are the reciprocal basis vectors.
- Return type:
- compute_grid_size(max_res, oversampling=3.0)[source]
Compute minimum grid dimensions for a given resolution.
Uses Shannon-Nyquist sampling criterion to determine the minimum number of grid points needed along each axis.
- Parameters:
- Returns:
Minimum grid dimensions (nx, ny, nz).
- Return type:
Examples
>>> cell = Cell([50, 60, 70, 90, 90, 90])
>>> cell.compute_grid_size(2.0)
(75, 90, 105)
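The sampling rule behind this can be sketched for an orthogonal cell: a grid spacing of max_res / oversampling requires at least length · oversampling / max_res points along each axis (an oversampling of 2 is the Nyquist limit). The helper below is a hypothetical illustration for orthogonal cells only; how compute_grid_size rounds or handles non-orthogonal cells is not specified here.

```python
import math


def min_grid_size(lengths, max_res, oversampling=3.0):
    """Smallest point count per axis with spacing <= max_res / oversampling."""
    return tuple(math.ceil(length * oversampling / max_res) for length in lengths)


grid = min_grid_size((50.0, 60.0, 70.0), max_res=2.0)
print(grid)  # (75, 90, 105)
```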
- tolist()[source]
Convert Cell parameters to a standard Python list.
- Returns:
List of cell parameters [a, b, c, alpha, beta, gamma].
- Return type:
- fractional_to_cartesian(frac_coords)[source]
Convert fractional coordinates to Cartesian coordinates.
- Parameters:
frac_coords (torch.Tensor) – Tensor of fractional coordinates, shape (…, 3).
- Returns:
Tensor of Cartesian coordinates, shape (…, 3).
- Return type:
- cartesian_to_fractional(cart_coords)[source]
Convert Cartesian coordinates to fractional coordinates.
- Parameters:
cart_coords (torch.Tensor) – Tensor of Cartesian coordinates, shape (…, 3).
- Returns:
Tensor of fractional coordinates, shape (…, 3).
- Return type:
- class torchref.SpaceGroup(space_group=None, dtype=torch.float32, device=device(type='cpu'))[source]
Bases:
DeviceMixin, DebugMixin, Module
Unified space group handler for crystallographic symmetry operations.
This class combines space group normalization with symmetry operations, providing a single interface for:
- Normalizing input (string, int, gemmi.SpaceGroup) in the constructor
- Storing symmetry matrices and translations as PyTorch buffers
- Applying symmetry operations to fractional coordinates
- Grid size utilities for symmetry-compatible grids
- Parameters:
space_group (str, int, gemmi.SpaceGroup, SpaceGroup, or None) – Space group specification. Accepts:
- Hermann-Mauguin symbol (e.g., ‘P21’, ‘P 21 21 21’)
- Space group number (1-230)
- gemmi.SpaceGroup object
- Another SpaceGroup instance
- None (defaults to P1)
dtype (torch.dtype, default torch.float32) – Data type for rotation matrices and translations.
device (torch.device, default: configured device.current) – Device for computation.
- matrices
Rotation matrices for all symmetry operations (registered buffer).
- Type:
torch.Tensor, shape (n_ops, 3, 3)
- translations
Translation vectors for all symmetry operations (registered buffer).
- Type:
torch.Tensor, shape (n_ops, 3)
Examples
# Create from various inputs
sg = SpaceGroup('P21')
sg = SpaceGroup('P 21')  # Same result
sg = SpaceGroup(4)  # P21 by number
sg = SpaceGroup(None)  # Returns P1
# Access properties
print(sg.name)  # 'P21' (short name)
print(sg.hm)  # 'P 21' (Hermann-Mauguin with spaces)
print(sg.number)  # 4
# Apply symmetry operations
coords = torch.tensor([[0.1, 0.2, 0.3]])
transformed = sg(coords)  # Apply all symmetry operations
# Grid utilities
req = sg.get_grid_requirements()
suggested = sg.suggest_grid_size((131, 163, 148))
- property gemmi: SpaceGroup
Access a gemmi.SpaceGroup object (created on demand, not stored).
- property spacegroup: SpaceGroup
Alias for gemmi property (backward compatibility).
- property space_group: SpaceGroup
Alias for gemmi property (backward compatibility).
- apply(xyz_fractional, apply_translation=True)[source]
Apply symmetry operations to fractional coordinates (rotation + translation).
For real space coordinates, applies the full symmetry operation: x’ = R·x + t
- Parameters:
xyz_fractional (torch.Tensor) – Input tensor of shape (N, 3) representing fractional coordinates.
- Returns:
Transformed coordinates of shape (N, 3, ops) where ops is the number of symmetry operations.
- Return type:
See also
apply_to_hkl – For reciprocal space (Miller indices), rotation only.
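As an illustration of x’ = R·x + t and of the (N, 3, ops) output layout, the following sketch applies the two operations of P21 (identity and the 2-fold screw axis along b) to one fractional coordinate. It uses nested lists instead of tensors, and the helper name `apply_ops` is hypothetical.

```python
# Symmetry operations of P21 as (rotation, translation) pairs.
P21_OPS = [
    ([[1, 0, 0], [0, 1, 0], [0, 0, 1]], [0.0, 0.0, 0.0]),      # x, y, z
    ([[-1, 0, 0], [0, 1, 0], [0, 0, -1]], [0.0, 0.5, 0.0]),    # -x, y+1/2, -z
]


def apply_ops(xyz_fractional, ops, apply_translation=True):
    """Return nested lists laid out as (N, 3, ops)."""
    out = []
    for x in xyz_fractional:
        per_axis = [[0.0] * len(ops) for _ in range(3)]
        for k, (rot, trans) in enumerate(ops):
            for i in range(3):
                value = sum(rot[i][j] * x[j] for j in range(3))
                if apply_translation:
                    value += trans[i]
                per_axis[i][k] = value
        out.append(per_axis)
    return out


transformed = apply_ops([[0.1, 0.2, 0.3]], P21_OPS)
print(transformed[0])
```

The last index selects the symmetry operation, so `transformed[0][1]` holds the y coordinate of the first atom under each operation.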
- apply_to_hkl(hkl)[source]
Apply symmetry operations to Miller indices (rotation only, no translation).
For reciprocal space, only the rotational part of symmetry operations applies to Miller indices: h’ = R·h. The translation vector affects the phase of structure factors, not the indices themselves.
- Parameters:
hkl (torch.Tensor) – Input tensor of shape (N, 3) representing Miller indices.
- Returns:
Transformed Miller indices of shape (N, 3, ops) where ops is the number of symmetry operations.
- Return type:
See also
apply – For real space coordinates (rotation + translation).
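The split between rotation (acting on the indices) and translation (acting on the phase) can be seen with the screw-axis operation of P21. In the sketch below the phase factor is written as exp(2πi h·t); the sign convention is an assumption, but for this operation the factor is real, so the conclusion does not depend on it. Function names are hypothetical.

```python
import cmath
import math

ROTATION = [[-1, 0, 0], [0, 1, 0], [0, 0, -1]]  # 2-fold axis along b
TRANSLATION = [0.0, 0.5, 0.0]                   # screw component


def rotate_hkl(hkl, rot):
    """h' = R.h, with no translation applied to the indices."""
    return tuple(sum(rot[i][j] * hkl[j] for j in range(3)) for i in range(3))


def translation_phase(hkl, trans):
    """Phase factor contributed by the translation part of the operation."""
    return cmath.exp(2j * math.pi * sum(h * t for h, t in zip(hkl, trans)))


general = rotate_hkl((1, 2, 3), ROTATION)
axial_odd = translation_phase((0, 1, 0), TRANSLATION)
axial_even = translation_phase((0, 2, 0), TRANSLATION)
print(general, axial_odd, axial_even)
```

For 0k0 reflections the rotation maps the index onto itself, so the structure factor must equal itself times the phase factor. With k odd that factor is −1, which forces F = 0: the familiar systematic absence of a 2₁ screw axis.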
- expand_coords_to_P1(xyz_fractional)[source]
Expand fractional coordinates by applying all symmetry operations.
- Parameters:
xyz_fractional (torch.Tensor) – Input tensor of shape (N, 3) representing fractional coordinates.
- Returns:
Expanded coordinates of shape (N * ops, 3).
- Return type:
- get_grid_requirements()[source]
Analyze symmetry operations to determine grid size requirements.
- Returns:
{‘nx_mod’: int, ‘ny_mod’: int, ‘nz_mod’: int} Required divisibility for each axis.
- Return type:
Examples
sg = SpaceGroup('P21')
req = sg.get_grid_requirements()
print(req)  # {'nx_mod': 1, 'ny_mod': 2, 'nz_mod': 1}
- check_grid_compatibility(grid_shape)[source]
Check if a grid size is compatible with the symmetry operations.
- Parameters:
- Returns:
Dictionary with keys:
- ‘compatible’ : bool - True if grid satisfies all requirements
- ‘symmetry_compatible’ : bool - True if grid satisfies symmetry
- ‘fft_friendly’ : bool - True if all dimensions are FFT-friendly
- ‘can_use_direct_indexing’ : bool - True if no interpolation needed
- ‘issues’ : list of str - Descriptions of incompatibilities
- ‘requirements’ : dict - Required divisibility
- Return type:
Examples
sg = SpaceGroup('P21')
result = sg.check_grid_compatibility((131, 163, 148))
print(result['compatible'])  # False
print(result['issues'])  # ['ny=163 not divisible by 2']
- suggest_grid_size(min_grid_shape, make_fft_friendly=True)[source]
Suggest an optimal grid size that satisfies symmetry requirements.
- Parameters:
- Returns:
Suggested grid dimensions (nx, ny, nz).
- Return type:
Examples
sg = SpaceGroup('P21')
suggested = sg.suggest_grid_size((131, 163, 148))
print(suggested)  # (135, 164, 150) or similar
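One way such a suggestion could be computed is to round each axis up to the required multiple and then step by that multiple until the size is FFT-friendly. The sketch below assumes "FFT-friendly" means having no prime factor larger than 5, which is a common choice but not necessarily the criterion torchref applies; the function names are hypothetical.

```python
def is_fft_friendly(n, primes=(2, 3, 5)):
    """True if n has no prime factors other than those listed (assumed criterion)."""
    for p in primes:
        while n % p == 0:
            n //= p
    return n == 1


def suggest_grid(min_shape, mods, make_fft_friendly=True):
    """Smallest size per axis that is >= the minimum and divisible by its modulus."""
    suggested = []
    for n_min, mod in zip(min_shape, mods):
        n = -(-n_min // mod) * mod  # round up to a multiple of mod
        while make_fft_friendly and not is_fft_friendly(n):
            n += mod                # stay on multiples of mod
        suggested.append(n)
    return tuple(suggested)


# P21 requirements from get_grid_requirements(): nx_mod=1, ny_mod=2, nz_mod=1
grid = suggest_grid((131, 163, 148), mods=(1, 2, 1))
print(grid)
```

Under this stricter smoothness criterion the result is (135, 180, 150); the library example reports (135, 164, 150) or similar, so its FFT-friendliness test is evidently more permissive along b.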
- class torchref.Map(data, model, gridsize=None, map_type='2mFo-DFc', device=None)[source]
Bases:
DeviceMixin
Crystallographic electron density map.
- Parameters:
data (ReflectionData) – Observed reflection data with amplitudes, hkl, cell, and spacegroup.
model (ModelFT) – Model for computing Fcalc (structure factors).
gridsize (tuple of int, optional) – Grid dimensions (nx, ny, nz). If None, determined automatically from cell parameters and resolution.
map_type (str, optional) – Type of map to compute. One of "2mFo-DFc" or "Fcalc". Default is "2mFo-DFc".
- VALID_MAP_TYPES = ('2mFo-DFc', 'Fcalc')
- class torchref.DifferenceMap(data, data_reference, model, gridsize=None, device=None)[source]
Bases:
Map
Isomorphous difference map between two datasets.
Scales both datasets to a common reference using DatasetCollection, then computes difference Fourier coefficients: DF * exp(i * phi_calc), where DF = F_data - F_reference.
- Parameters:
data (ReflectionData) – Reflection data for the perturbed state (e.g., light, derivative).
data_reference (ReflectionData) – Reflection data for the reference state (e.g., dark, native).
model (ModelFT) – Model for computing phases.
gridsize (tuple of int, optional) – Grid dimensions (nx, ny, nz). If None, determined automatically.
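The difference Fourier coefficients DF * exp(i * phi_calc) can be written out directly. This sketch assumes both amplitude sets are already on a common scale and are given per reflection in the same order as the calculated structure factors; the scaling via DatasetCollection is not reproduced and the function name is hypothetical.

```python
import cmath


def difference_coefficients(f_data, f_reference, f_calc):
    """(F_data - F_reference) * exp(i * phi_calc) for each reflection."""
    coefficients = []
    for fd, fr, fc in zip(f_data, f_reference, f_calc):
        phi_calc = cmath.phase(fc)  # model phase in radians
        coefficients.append((fd - fr) * cmath.exp(1j * phi_calc))
    return coefficients


f_data = [12.0, 8.0]            # perturbed state amplitudes
f_reference = [10.0, 10.0]      # reference state amplitudes
f_calc = [0 + 5j, -3 + 0j]      # model phases: pi/2 and pi
coeffs = difference_coefficients(f_data, f_reference, f_calc)
print(coeffs)
```

A negative amplitude difference combined with the model phase is equivalent to a positive coefficient with the phase shifted by π, which is how negative difference density arises in the resulting map.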
- class torchref.DeviceMixin[source]
Bases:
object
Unified device/dtype movement.
Inherit alongside nn.Module (place before nn.Module in the MRO):
class Foo(DeviceMixin, nn.Module): ...
Or use on a plain Python class / dataclass:
@dataclass
class Bar(DeviceMixin):
    data: torch.Tensor
All of .to(), .cuda(), .cpu(), .float(), .double(), .half() route through _apply(), which:
- invokes nn.Module._apply when applicable so parameters, buffers and child modules are moved by the standard PyTorch path,
- walks self.__dict__ to pick up plain tensor attributes, nested containers and non-Module sub-objects,
- calls reset_forward_cache() and reset_cache() if either is defined.
Subpackages
- torchref.base package
- Submodules (New Organization)
- Legacy Submodules (For Backward Compatibility)
FrenchWilson, CachedRadiusMask, ReciprocalSymmetryExtractor, cartesian_to_fractional_torch(), fractional_to_cartesian_torch(), get_fractional_matrix(), get_inv_fractional_matrix_torch(), cartesian_to_fractional(), fractional_to_cartesian(), get_inv_fractional_matrix(), convert_coords_to_fractional(), smallest_diff(), smallest_diff_aniso(), reciprocal_basis_matrix(), reciprocal_basis_matrix_numpy(), get_scattering_vectors(), get_scattering_vectors_numpy(), get_s(), get_d_spacing(), compute_d_spacing_batch(), generate_possible_hkl(), place_on_grid(), extract_structure_factor_from_grid(), apply_translation_phase(), interpolate_structure_factor_from_grid(), interpolate_complex_from_grid(), trilinear_interpolate_patterson(), compute_symmetry_equivalent_hkls(), compute_translation_phases(), extract_structure_factors_with_symmetry(), interpolate_for_rotation(), smooth_reciprocal_grid(), iso_structure_factor_torched(), iso_structure_factor_torched_no_complex(), aniso_structure_factor_torched(), aniso_structure_factor_torched_no_complex(), anharmonic_correction(), anharmonic_correction_no_complex(), core_deformation(), multiplication_quasi_complex_tensor(), vectorized_add_to_map(), vectorized_add_to_map_aniso(), scatter_add_nd(), scatter_add_nd_super_slow(), find_relevant_voxels(), excise_angstrom_radius_around_coord(), add_to_solvent_mask(), add_to_phenix_mask(), find_solvent_voids(), fft(), ifft(), get_real_grid(), find_grid_size(), get_real_grid_numpy(), get_grids(), put_hkl_on_grid(), get_scattering_factors(), get_scattering_factors_unique(), get_parametrization_for_elements(), calc_scattering_factors_paramtetrization(), compute_radial_shells(), assign_to_shells(), compute_anisotropy_correction(), compute_shell_cv(), fit_anisotropy_correction(), apply_anisotropy_correction(), F_squared_to_E_values(), rotate_coords_torch(), rotate_coords_numpy(), axis_angle_to_rotation_matrix(), rotation_matrix_to_axis_angle(), quaternion_to_rotation_matrix(), random_rotation_uniform(), superpose_vectors_robust_torch(), superpose_vectors_robust(), align_torch(), align_pdbs(), get_alignment_matrix(), apply_transformation(), apply_transformation_numpy(), invert_transformation_matrix(), get_rfactor_torch(), get_rfactor(), rfactor(), get_rfactors(), bin_wise_rfactors(), calc_outliers(), calc_outliers_numpy(), nll_xray(), nll_xray_sum(), nll_xray_lognormal(), log_loss(), estimate_sigma_I(), estimate_sigma_F(), gaussian_to_lognormal_sigma(), gaussian_to_lognormal_mu(), compute_metric_tensor(), precompute_fractional_coords(), warmup(), get_cache_dir(), clear_cache(), warmup_cuda_operations(), get_cached_radius_offsets()
- Subpackages
- torchref.base.alignment package
superpose_vectors_robust_torch(), superpose_vectors_robust(), align_torch(), get_alignement_matrix(), align_pdbs(), get_alignment_matrix(), apply_transformation(), apply_transformation_numpy(), invert_transformation_matrix(), rotate_coords_torch(), rotate_coords_numpy(), axis_angle_to_rotation_matrix(), rotation_matrix_to_axis_angle(), quaternion_to_rotation_matrix(), random_rotation_uniform(), rotation_matrix_euler_zyz(), compute_radial_shells(), assign_to_shells(), compute_anisotropy_correction(), compute_shell_cv(), fit_anisotropy_correction(), apply_anisotropy_correction(), F_squared_to_E_values()
- Submodules
- torchref.base.chain_closure package
- torchref.base.coordinates package
cartesian_to_fractional_torch(), fractional_to_cartesian_torch(), get_fractional_matrix(), get_inv_fractional_matrix_torch(), cartesian_to_fractional(), fractional_to_cartesian(), get_fractional_matrix_numpy(), get_inv_fractional_matrix(), convert_coords_to_fractional(), smallest_diff(), smallest_diff_aniso()
- Submodules
- torchref.base.direct_summation package
- torchref.base.electron_density package
- torchref.base.fourier package
- torchref.base.kernels package
vectorized_add_to_map(), build_electron_density(), compute_metric_tensor(), precompute_fractional_coords(), warmup(), get_cache_dir(), clear_cache(), warmup_cuda_operations(), CachedRadiusMask, get_cached_radius_offsets(), vectorized_add_to_map_optimized(), fused_add_to_map_gpu(), fused_find_and_place_atoms(), separable_density_gpu()
- Submodules
- torchref.base.metrics package
- torchref.base.reciprocal package
reciprocal_basis_matrix(), reciprocal_basis_matrix_numpy(), get_scattering_vectors(), get_scattering_vectors_numpy(), get_s(), generate_possible_hkl(), get_d_spacing(), compute_d_spacing_batch(), place_on_grid(), extract_structure_factor_from_grid(), apply_translation_phase(), interpolate_structure_factor_from_grid(), interpolate_complex_from_grid(), trilinear_interpolate_patterson(), interpolate_for_rotation(), smooth_reciprocal_grid(), compute_symmetry_equivalent_hkls(), compute_translation_phases(), extract_structure_factors_with_symmetry(), ReciprocalSymmetryExtractor
- Submodules
- torchref.base.scattering package
get_scattering_factors_unique(), get_scattering_factors(), get_scattering_itc92(), calc_scattering_factors_paramtetrization(), get_parameterization(), get_parameterization_extended(), get_parametrization_for_elements(), get_parametrization_atom(), linear_interpolation(), load_scattering_table(), get_scattering_params_by_z(), get_element_to_z_mapping(), get_z_to_element_mapping(), get_scattering_params_for_ion(), elements_to_z(), get_anomalous_correction(), get_significant_elements(), get_anomalous_corrections_by_indices()
- Submodules
- torchref.base.targets package
- Mapping
adp_kl_math(), adp_locality_math(), adp_simu_math(), angle_math(), bhattacharyya_xray_loss_math(), bond_math(), chiral_math(), gaussian_xray_loss_math(), ls_xray_loss_math(), ml_xray_loss_math(), nonbonded_heavy_math(), planarity_math(), ramachandran_math(), torsion_omega_math(), torsion_unimodal_math()
- Subpackages
- Submodules
- torchref.base.alignment package
- Submodules
- torchref.base.CCTBX_related module
- torchref.base.french_wilson module
- torchref.base.get_scattering_factor_torch module
- torchref.base.math_numpy module
cartesian_to_fractional(), fractional_to_cartesian(), get_fractional_matrix(), get_inv_fractional_matrix(), convert_coords_to_fractional(), reciprocal_basis_matrix(), get_scattering_vectors(), get_s(), get_real_grid(), get_grids(), put_hkl_on_grid(), rotate_coords_numpy(), superpose_vectors_robust(), align_pdbs(), get_alignment_matrix(), apply_transformation(), invert_transformation_matrix(), get_rfactor(), calc_outliers(), get_res_for_dataset()
- torchref.base.math_torch module
cartesian_to_fractional_torch(), fractional_to_cartesian_torch(), get_fractional_matrix(), get_inv_fractional_matrix_torch(), smallest_diff(), smallest_diff_aniso(), reciprocal_basis_matrix(), get_scattering_vectors(), get_d_spacing(), place_on_grid(), extract_structure_factor_from_grid(), apply_translation_phase(), interpolate_structure_factor_from_grid(), interpolate_complex_from_grid(), trilinear_interpolate_patterson(), iso_structure_factor_torched(), iso_structure_factor_torched_no_complex(), aniso_structure_factor_torched(), aniso_structure_factor_torched_no_complex(), anharmonic_correction(), anharmonic_correction_no_complex(), core_deformation(), multiplication_quasi_complex_tensor(), vectorized_add_to_map(), vectorized_add_to_map_aniso(), scatter_add_nd(), scatter_add_nd_super_slow(), find_relevant_voxels(), excise_angstrom_radius_around_coord(), add_to_solvent_mask(), add_to_phenix_mask(), find_solvent_voids(), fft(), ifft(), get_real_grid(), find_grid_size(), rotate_coords_torch(), axis_angle_to_rotation_matrix(), rotation_matrix_to_axis_angle(), quaternion_to_rotation_matrix(), random_rotation_uniform(), superpose_vectors_robust_torch(), align_torch(), get_alignement_matrix(), apply_transformation(), get_rfactor_torch(), rfactor(), get_rfactors(), bin_wise_rfactors(), calc_outliers(), nll_xray(), nll_xray_sum(), nll_xray_lognormal(), log_loss(), estimate_sigma_I(), estimate_sigma_F(), gaussian_to_lognormal_sigma(), gaussian_to_lognormal_mu(), U_to_matrix(), deterministic_tensor_digest(), hash_tensors(), french_wilson_conversion()
- torchref.cli package
- torchref.io package
- High-level API
CrystalDataset, CrystalDataset.hkl, CrystalDataset.F, CrystalDataset.F_sigma, CrystalDataset.I, CrystalDataset.I_sigma, CrystalDataset.rfree_flags, CrystalDataset.resolution, CrystalDataset.bin_indices, CrystalDataset.outlier_flags, CrystalDataset.phase, CrystalDataset.fom, CrystalDataset.E, CrystalDataset.E_squared, CrystalDataset.F_squared_corrected, CrystalDataset.U_aniso, CrystalDataset.radial_shell_indices, CrystalDataset.cell, CrystalDataset.spacegroup, CrystalDataset.device, CrystalDataset.verbose, CrystalDataset.rfree_source, CrystalDataset.amplitude_source, CrystalDataset.intensity_source, CrystalDataset.phase_source, CrystalDataset.wilson_b, CrystalDataset.wilson_b_structure, CrystalDataset.wilson_b_solvent, CrystalDataset.wilson_k_sol, CrystalDataset.outlier_detection_params, CrystalDataset.__post_init__(), CrystalDataset.save_state(), CrystalDataset.load_state(), CrystalDataset.__len__(), CrystalDataset.__repr__(), CrystalDataset.spacegroup_name, CrystalDataset.spacegroup_hm, CrystalDataset.spacegroup_number, CrystalDataset.__init__()
ReflectionData, ReflectionData.hkl, ReflectionData.F, ReflectionData.F_sigma, ReflectionData.I, ReflectionData.I_sigma, ReflectionData.rfree_flags, ReflectionData.cell, ReflectionData.spacegroup, ReflectionData.resolution, ReflectionData.wilson_b, ReflectionData.source, ReflectionData.dataset, ReflectionData.last_op, ReflectionData.reader, ReflectionData.__post_init__(), ReflectionData.load(), ReflectionData.from_tensors(), ReflectionData.load_mtz(), ReflectionData.load_cif(), ReflectionData.list_cif_data_blocks(), ReflectionData.get_bins(), ReflectionData.mean_res_per_bin(), ReflectionData.mean_F_per_bin(), ReflectionData.mean_sigma_per_bin(), ReflectionData.regenerate_rfree_flags(), ReflectionData.get_structure_factors(), ReflectionData.get_structure_factors_with_sigma(), ReflectionData.get_hkl(), ReflectionData.filter_by_resolution(), ReflectionData.get_mask(), ReflectionData.cut_res(), ReflectionData.get_rfree_masks(), ReflectionData.get_work_set(), ReflectionData.get_test_set(), ReflectionData.get_max_res(), ReflectionData.get_min_res(), ReflectionData.__len__(), ReflectionData.d_min, ReflectionData.__repr__(), ReflectionData.get_valid_mask(), ReflectionData.data_indexed(), ReflectionData.__call__(), ReflectionData.data_fill_masked(), ReflectionData.__getitem__(), ReflectionData.__select__(), ReflectionData.sanitize_F(), ReflectionData.check_all_data_types(), ReflectionData.validate_hkl(), ReflectionData.find_outliers(), ReflectionData.get_log_ratio(), ReflectionData.get_outlier_statistics(), ReflectionData.unpack_one(), ReflectionData.flag_suspicious_sigma(), ReflectionData.dump(), ReflectionData.write_mtz(), ReflectionData.centric, ReflectionData.calc_patterson(), ReflectionData.possible_hkl(), ReflectionData.remap(), ReflectionData.fill(), ReflectionData.expand_to_p1(), ReflectionData.reduce_to_spacegroup(), ReflectionData.canonicalize(), ReflectionData.get_scattering_vectors(), ReflectionData.get_radial_shells(), ReflectionData.fit_anisotropy(), ReflectionData.setup_anisotropy(), ReflectionData.apply_anisotropy_correction(), ReflectionData.compute_e_values(), ReflectionData.setup_scale(), ReflectionData.get_corrected_data(), ReflectionData.parameters(), ReflectionData.__init__()
DatasetCollection, DatasetCollection.hkl, DatasetCollection.n_datasets, DatasetCollection.add_dataset(), DatasetCollection.datasets, DatasetCollection.reference_dataset, DatasetCollection.spacegroup, DatasetCollection.__getitem__(), DatasetCollection.__iter__(), DatasetCollection.__len__(), DatasetCollection.__contains__(), DatasetCollection.__call__(), DatasetCollection.scale(), DatasetCollection.keys(), DatasetCollection.values(), DatasetCollection.items(), DatasetCollection.__init__(), DatasetCollection.get(), DatasetCollection.__repr__()
FcalcDataset, FcalcDataset.spacegroup, FcalcDataset.fcalc, FcalcDataset.fcalc_amp, FcalcDataset.fcalc_phase, FcalcDataset.from_cell_and_resolution(), FcalcDataset.set_fcalc(), FcalcDataset.write_mtz(), FcalcDataset.write_mtz_as_fobs(), FcalcDataset.__repr__(), FcalcDataset.spacegroup_name, FcalcDataset.spacegroup_hm, FcalcDataset.spacegroup_number, FcalcDataset.__init__()
MTZReader, PDBReader, CIFReader, CIFReader.data, CIFReader.filepath, CIFReader.available_blocks, CIFReader.__init__(), CIFReader.from_string(), CIFReader.load(), CIFReader.write(), CIFReader.__getitem__(), CIFReader.__setitem__(), CIFReader.__contains__(), CIFReader.__len__(), CIFReader.keys(), CIFReader.values(), CIFReader.items(), CIFReader.get(), CIFReader.__repr__(), CIFReader.summary()
ReflectionCIFReader, ReflectionCIFReader.__init__(), ReflectionCIFReader.read(), ReflectionCIFReader.__call__(), ReflectionCIFReader.get_reflection_data(), ReflectionCIFReader.has_miller_indices(), ReflectionCIFReader.has_amplitudes(), ReflectionCIFReader.has_intensities(), ReflectionCIFReader.has_phases(), ReflectionCIFReader.has_rfree_flags(), ReflectionCIFReader.get_miller_indices(), ReflectionCIFReader.get_amplitudes(), ReflectionCIFReader.get_intensities(), ReflectionCIFReader.get_cell_parameters(), ReflectionCIFReader.get_space_group()
ModelCIFReader, ModelCIFReader.__init__(), ModelCIFReader.read(), ModelCIFReader.__call__(), ModelCIFReader.get_atom_data(), ModelCIFReader.get_atom_data_by_model(), ModelCIFReader.get_cell_parameters(), ModelCIFReader.get_space_group(), ModelCIFReader.has_coordinates(), ModelCIFReader.has_cell_parameters(), ModelCIFReader.has_space_group(), ModelCIFReader.has_occupancy(), ModelCIFReader.has_bfactor(), ModelCIFReader.has_anisotropic_data(), ModelCIFReader.get_coordinates(), ModelCIFReader.get_atom_info()
RestraintCIFReader, RestraintCIFReader.__init__(), RestraintCIFReader.get_all_restraints(), RestraintCIFReader.get_compound_restraints(), RestraintCIFReader.get_bond_restraints(), RestraintCIFReader.get_compound_id(), RestraintCIFReader.has_bond_restraints(), RestraintCIFReader.has_angle_restraints(), RestraintCIFReader.has_torsion_restraints(), RestraintCIFReader.has_plane_restraints(), RestraintCIFReader.has_chirality_restraints()
DataRouter, DataRouter.filepath, DataRouter.verbose, DataRouter.data_type, DataRouter.file_format, DataRouter.reader, DataRouter.MTZ_EXTENSIONS, DataRouter.PDB_EXTENSIONS, DataRouter.CIF_EXTENSIONS, DataRouter.__init__(), DataRouter.get_reader(), DataRouter.get_data(), DataRouter.route(), DataRouter.__repr__(), DataRouter.__str__()
DataRouterError, IHMEnsembleMapping, IHMEnsembleMapping.states, IHMEnsembleMapping.model_groups, IHMEnsembleMapping.cell, IHMEnsembleMapping.spacegroup, IHMEnsembleMapping.atom_data_per_state, IHMEnsembleMapping.get_state_ids(), IHMEnsembleMapping.get_timepoint_names(), IHMEnsembleMapping.get_fractions_for_group(), IHMEnsembleMapping.identify_dark_group(), IHMEnsembleMapping.get_state_by_id(), IHMEnsembleMapping.get_group_by_name(), IHMEnsembleMapping.validate(), IHMEnsembleMapping.__init__()
IHMStateInfo, IHMModelGroupInfo, RefinementMetadata, RefinementMetadata.program, RefinementMetadata.program_version, RefinementMetadata.refinement_method, RefinementMetadata.resolution_high, RefinementMetadata.resolution_low, RefinementMetadata.n_reflections_work, RefinementMetadata.n_reflections_test, RefinementMetadata.n_reflections_all, RefinementMetadata.percent_free, RefinementMetadata.r_work, RefinementMetadata.r_free, RefinementMetadata.b_mean_overall, RefinementMetadata.b_min, RefinementMetadata.b_max, RefinementMetadata.rmsd_bond_lengths, RefinementMetadata.rmsd_bond_angles, RefinementMetadata.n_atoms_total, RefinementMetadata.n_atoms_protein, RefinementMetadata.n_atoms_solvent, RefinementMetadata.solvent_model_ksol, RefinementMetadata.solvent_model_bsol, RefinementMetadata.cell, RefinementMetadata.spacegroup, RefinementMetadata.title, RefinementMetadata.authors, RefinementMetadata.passthrough_pdb_remarks, RefinementMetadata.passthrough_cif_categories, RefinementMetadata.custom_remarks, RefinementMetadata.to_dict(), RefinementMetadata.from_dict(), RefinementMetadata.from_refinement(), RefinementMetadata.from_pdb_file(), RefinementMetadata.from_cif_file(), RefinementMetadata.merge(), RefinementMetadata.render_pdb_header(), RefinementMetadata.render_cif_categories(), RefinementMetadata.__init__()
MTZ, PDB
- Subpackages
- Submodules
- torchref.kinetic package
- Key Classes
KineticModel, KineticModel.__init__(), KineticModel.forward(), KineticModel.set_baseline(), KineticModel.get_baselines(), KineticModel.get_rate_constants(), KineticModel.set_rate_constant(), KineticModel.get_efficiencies(), KineticModel.get_effective_rates(), KineticModel.get_time_constants(), KineticModel.parameters(), KineticModel.print_parameters(), KineticModel.plot_occupancies(), KineticModel.visualize()
occupancy_unrestrained, occupancies_kinetics, occupancies_kinetics.__init__(), occupancies_kinetics.forward(), occupancies_kinetics.get_regularization_loss(), occupancies_kinetics.get_rate_constants(), occupancies_kinetics.get_efficiencies(), occupancies_kinetics.get_time_constants(), occupancies_kinetics.set_rate_constant(), occupancies_kinetics.freeze_rates(), occupancies_kinetics.unfreeze_rates(), occupancies_kinetics.freeze_efficiencies(), occupancies_kinetics.unfreeze_efficiencies(), occupancies_kinetics.freeze_instrument(), occupancies_kinetics.unfreeze_instrument(), occupancies_kinetics.get_parameter_groups(), occupancies_kinetics.print_parameters(), occupancies_kinetics.plot_occupancies(), occupancies_kinetics.plot_comparison(), occupancies_kinetics.state_dict_kinetics(), occupancies_kinetics.load_from_kinetics_state()
occupancies_kinetics_multiexperiment, ModelCollection, ModelCollection.__init__(), ModelCollection.add_timepoint(), ModelCollection.add_dark(), ModelCollection.from_kinetics(), ModelCollection.from_ihm(), ModelCollection.write_ihm(), ModelCollection.keys(), ModelCollection.values(), ModelCollection.items(), ModelCollection.get(), ModelCollection.dark_key, ModelCollection.dark_model, ModelCollection.base_models, ModelCollection.n_base_models, ModelCollection.timepoint_names, ModelCollection.cell, ModelCollection.spacegroup, ModelCollection.device, ModelCollection.get_all_fractions(), ModelCollection.get_fractions_matrix(), ModelCollection.freeze_all_fractions(), ModelCollection.unfreeze_all_fractions(), ModelCollection.freeze_structures(), ModelCollection.unfreeze_structures(), ModelCollection.write_pdbs()
KineticRefinement, KineticRefinement.__init__(), KineticRefinement.setup(), KineticRefinement.set_weights(), KineticRefinement.get_loss(), KineticRefinement.print_loss_summary(), KineticRefinement.refine(), KineticRefinement.refine_structures(), KineticRefinement.refine_fractions(), KineticRefinement.refine_alternating(), KineticRefinement.refit_kinetic_prior(), KineticRefinement.refine_kinetics(), KineticRefinement.write_pdbs()
CollectionDifferenceTarget, CollectionMLTarget, MultiModelGeometryTarget, MultiModelADPTarget, KineticPriorTarget
- Submodules
- torchref.maps package
- torchref.model package
- Classes
FFT(), SfFFT, SfFFT.cell, SfFFT.spacegroup, SfFFT.symmetry, SfFFT.max_res, SfFFT.radius_angstrom, SfFFT.gridsize, SfFFT.real_space_grid, SfFFT.voxel_size, SfFFT.map_symmetry, SfFFT.__init__(), SfFFT.fractional_matrix, SfFFT.inv_fractional_matrix, SfFFT.set_cell_and_spacegroup(), SfFFT.compute_optimal_gridsize(), SfFFT.compute_real_space_grid(), SfFFT.setup_grid(), SfFFT.build_density_map(), SfFFT.map_to_structure_factors(), SfFFT.compute_structure_factors(), SfFFT.reset_cache(), SfFFT.copy()
SfDS, InternalCoordinateTensor, InternalCoordinateTensor.n_atoms, InternalCoordinateTensor.n_chains, InternalCoordinateTensor.max_depth, InternalCoordinateTensor.bond_lengths, InternalCoordinateTensor.angles, InternalCoordinateTensor.torsions, InternalCoordinateTensor.chain_positions, InternalCoordinateTensor.chain_orientations, InternalCoordinateTensor.__init__(), InternalCoordinateTensor.dtype, InternalCoordinateTensor.device, InternalCoordinateTensor.to(), InternalCoordinateTensor.cuda(), InternalCoordinateTensor.cpu(), InternalCoordinateTensor.forward_slow(), InternalCoordinateTensor.forward(), InternalCoordinateTensor.shake(), InternalCoordinateTensor.fix(), InternalCoordinateTensor.freeze(), InternalCoordinateTensor.refine(), InternalCoordinateTensor.unfreeze(), InternalCoordinateTensor.fix_all(), InternalCoordinateTensor.freeze_all(), InternalCoordinateTensor.refine_all(), InternalCoordinateTensor.unfreeze_all(), InternalCoordinateTensor.n_refinable, InternalCoordinateTensor.n_fixed, InternalCoordinateTensor.forward_parallel()
MixedModel, MixedModel.models, MixedModel.fraction_params, MixedModel.__init__(), MixedModel.fractions, MixedModel.cell, MixedModel.spacegroup, MixedModel.device, MixedModel.dtype_float, MixedModel.real_space_grid, MixedModel.fft, MixedModel.gridsize, MixedModel.map_symmetry, MixedModel.inv_fractional_matrix, MixedModel.fractional_matrix, MixedModel.setup_grid(), MixedModel.get_radius(), MixedModel.build_complete_map(), MixedModel.freeze_fractions(), MixedModel.unfreeze_fractions(), MixedModel.forward(), MixedModel.get_individual_fcalc(), MixedModel.copy(), MixedModel.__repr__(), MixedModel.write_ihm(), MixedModel.get_vdw_radii(), MixedModel.xyz()
Model, Model.xyz, Model.adp, Model.u, Model.occupancy, Model.pdb, Model.cell, Model.spacegroup, Model.symmetry, Model.initialized, Model.__init__(), Model.__bool__(), Model.exclude_H_from_sf, Model.inv_fractional_matrix, Model.fractional_matrix, Model.recB, Model.Z, Model.get_P1_parameters_iso(), Model.get_MD_parameters(), Model.parametrization, Model.get_scattering_params_iso(), Model.get_scattering_params_aniso(), Model.set_restraints_cif(), Model.restraints, Model.bond_deviations(), Model.angle_deviations(), Model.torsion_deviations_with_sigmas(), Model.load(), Model.load_pdb(), Model.load_cif(), Model.chain_sequences, Model.get_chain_residues(), Model.update_pdb(), Model.get_vdw_radii(), Model.to(), Model.copy(), Model.write_pdb(), Model.write_cif(), Model.get_iso(), Model.set_default_masks(), Model.PARAM_TYPES, Model.parameters_of_types(), Model.freeze(), Model.freeze_all(), Model.unfreeze_all(), Model.unfreeze(), Model.update_mask_from_selection(), Model.apply_mask_to_parameter(), Model.freeze_selection(), Model.unfreeze_selection(), Model.get_aniso(), Model.parameters(), Model.named_mixed_tensors(), Model.print_parameters_info(), Model.register_alternative_conformations(), Model.shake_coords(), Model.shake_adp(), Model.generate_hydrogens(), Model.strip_altlocs(), Model.strip_hydrogens(), Model.hydrogenate(), Model.adp_loss(), Model.adp_nll_loss(), Model.adp_nll_loss_per_atom(), Model.adp_kl_divergence_loss(), Model.state_dict(), Model.save_state(), Model.load_state(), Model.create_from_state_dict(), Model.get_selection_mask(), Model.select(), Model.xyz_fractional(), Model.rotate(), Model.translate(), Model.get_centroid(), Model.use_internal_coordinates()
ModelFT, ModelFT.max_res, ModelFT.radius_angstrom, ModelFT.wavelength, ModelFT.anomalous_threshold, ModelFT.gridsize, ModelFT.real_space_grid, ModelFT.map, ModelFT.parametrization, ModelFT.map_symmetry, ModelFT.__init__(), ModelFT.cell, ModelFT.spacegroup, ModelFT.load_pdb(), ModelFT.select(), ModelFT.load_cif(), ModelFT.setup_gridsize(), ModelFT.A, ModelFT.B, ModelFT.voxel_size, ModelFT.get_iso(), ModelFT.get_aniso(), ModelFT.setup_grid(), ModelFT.get_radius(), ModelFT.build_complete_map(), ModelFT.build_initial_map(), ModelFT.save_map(), ModelFT.get_map_statistics(), ModelFT.rebuild_map(), ModelFT.update_pdb(), ModelFT.reset_cache(), ModelFT.invalidate_cache(), ModelFT.get_structure_factor(), ModelFT.fft, ModelFT.forward(), ModelFT.copy(), ModelFT.state_dict(), ModelFT.create_from_state_dict()
MixedTensor, MixedTensor.refinable_mask, MixedTensor.fixed_mask, MixedTensor.fixed_values, MixedTensor.refinable_params, MixedTensor.__init__(), MixedTensor.forward(), MixedTensor.__getitem__(), MixedTensor.__setitem__(), MixedTensor.set(), MixedTensor.shape, MixedTensor.dtype, MixedTensor.device, MixedTensor.get_refinable_count(), MixedTensor.get_fixed_count(), MixedTensor.update_fixed_values(), MixedTensor.update_refinable_mask(), MixedTensor.detach(), MixedTensor.clone(), MixedTensor.copy(), MixedTensor.clip(), MixedTensor.to(), MixedTensor.refine(), MixedTensor.fix(), MixedTensor.refine_all(), MixedTensor.fix_all(), MixedTensor.name, MixedTensor.__str__(), MixedTensor.parameters()
PositiveMixedTensor, PassThroughTensor, OccupancyTensor, OccupancyTensor.expansion_mask, OccupancyTensor.linked_occ_sizes, OccupancyTensor.collapse_counts, OccupancyTensor.__init__(), OccupancyTensor.forward(), OccupancyTensor.shape, OccupancyTensor.collapsed_shape, OccupancyTensor.clamp(), OccupancyTensor.set_group_occupancy(), OccupancyTensor.get_group_occupancy(), OccupancyTensor.freeze(), OccupancyTensor.unfreeze(), OccupancyTensor.freeze_all(), OccupancyTensor.unfreeze_all(), OccupancyTensor.get_refinable_atoms(), OccupancyTensor.get_frozen_atoms(), OccupancyTensor.get_refinable_count(), OccupancyTensor.get_fixed_count(), OccupancyTensor.update_refinable_mask(), OccupancyTensor.from_residue_groups(), OccupancyTensor.copy()
SegmentedInternalCoordinateTensor, SegmentedInternalCoordinateTensor.n_atoms, SegmentedInternalCoordinateTensor.n_segments, SegmentedInternalCoordinateTensor.max_depth, SegmentedInternalCoordinateTensor.bond_lengths, SegmentedInternalCoordinateTensor.angles, SegmentedInternalCoordinateTensor.torsions, SegmentedInternalCoordinateTensor.segment_positions, SegmentedInternalCoordinateTensor.segment_orientations, SegmentedInternalCoordinateTensor.AA_NAMES, SegmentedInternalCoordinateTensor.__init__(), SegmentedInternalCoordinateTensor.dtype, SegmentedInternalCoordinateTensor.device, SegmentedInternalCoordinateTensor.forward(), SegmentedInternalCoordinateTensor.shake(), SegmentedInternalCoordinateTensor.fix(), SegmentedInternalCoordinateTensor.freeze(), SegmentedInternalCoordinateTensor.refine(), SegmentedInternalCoordinateTensor.unfreeze(), SegmentedInternalCoordinateTensor.fix_all(), SegmentedInternalCoordinateTensor.freeze_all(), SegmentedInternalCoordinateTensor.refine_all(), SegmentedInternalCoordinateTensor.unfreeze_all(), SegmentedInternalCoordinateTensor.n_refinable, SegmentedInternalCoordinateTensor.n_fixed
ClosedSegmentedInternalCoordinateTensor, ClosedSegmentedInternalCoordinateTensor.__init__(), ClosedSegmentedInternalCoordinateTensor.dtype, ClosedSegmentedInternalCoordinateTensor.device, ClosedSegmentedInternalCoordinateTensor.forward(), ClosedSegmentedInternalCoordinateTensor.shake(), ClosedSegmentedInternalCoordinateTensor.fix(), ClosedSegmentedInternalCoordinateTensor.freeze(), ClosedSegmentedInternalCoordinateTensor.refine(), ClosedSegmentedInternalCoordinateTensor.unfreeze(), ClosedSegmentedInternalCoordinateTensor.fix_all(), ClosedSegmentedInternalCoordinateTensor.freeze_all(), ClosedSegmentedInternalCoordinateTensor.refine_all(), ClosedSegmentedInternalCoordinateTensor.unfreeze_all(), ClosedSegmentedInternalCoordinateTensor.n_refinable, ClosedSegmentedInternalCoordinateTensor.n_fixed, ClosedSegmentedInternalCoordinateTensor.closure_residuals, ClosedSegmentedInternalCoordinateTensor.max_closure_gap
ModelCollection, ModelCollection.__init__(), ModelCollection.add_timepoint(), ModelCollection.add_dark(), ModelCollection.from_kinetics(), ModelCollection.from_ihm(), ModelCollection.write_ihm(), ModelCollection.keys(), ModelCollection.values(), ModelCollection.items(), ModelCollection.get(), ModelCollection.dark_key, ModelCollection.dark_model, ModelCollection.base_models, ModelCollection.n_base_models, ModelCollection.timepoint_names, ModelCollection.cell, ModelCollection.spacegroup, ModelCollection.device, ModelCollection.get_all_fractions(), ModelCollection.get_fractions_matrix(), ModelCollection.freeze_all_fractions(), ModelCollection.unfreeze_all_fractions(), ModelCollection.freeze_structures(), ModelCollection.unfreeze_structures(), ModelCollection.write_pdbs()
- Submodules
- torchref.model.closed_segmented_internal_coordinates module
- torchref.model.internal_coordinates module
- torchref.model.mixed_model module
- torchref.model.model module
- torchref.model.model_collection module
- torchref.model.model_ft module
- torchref.model.parameter_wrappers module
- torchref.model.segmented_internal_coordinates module
- torchref.model.sf_ds module
- torchref.model.sf_fft module
- torchref.refinement package
Refinement, Refinement.device, Refinement.verbose, Refinement.reflection_data, Refinement.model, Refinement.scaler, Refinement.weighter, Refinement.__init__(), Refinement.set_xray_target_mode(), Refinement.data, Refinement.loss_state, Refinement.logger, Refinement.reset_loss_state(), Refinement.get_scales(), Refinement.setup_scaler(), Refinement.parameters(), Refinement.get_fcalc(), Refinement.get_fcalc_scaled(), Refinement.adp_loss(), Refinement.get_F_calc(), Refinement.get_F_calc_scaled(), Refinement.nll_xray(), Refinement.xray_loss_work(), Refinement.xray_loss_test(), Refinement.bond_loss(), Refinement.angle_loss(), Refinement.torsion_loss(), Refinement.geometry_loss(), Refinement.loss(), Refinement.setup_component_weighting(), Refinement.populate_state_meta(), Refinement.update_weights(), Refinement.create_loss_state(), Refinement.complete_loss_state(), Refinement.xray_loss(), Refinement.restraints_loss(), Refinement.collect_metrics(), Refinement.add_target_info_to_state(), Refinement.get_rfactor(), Refinement.update_outliers(), Refinement.plot_fcalc_vs_fobs(), Refinement.write_out_mtz(), Refinement.collect_deposition_metadata(), Refinement.write_out_pdb(), Refinement.write_out_cif(), Refinement.save_state(), Refinement.load_state(), Refinement.create_from_state_dict()
LBFGSRefinement, LBFGSRefinement.target_mode, LBFGSRefinement.LBFGS_DEFAULTS, LBFGSRefinement.__init__(), LBFGSRefinement.xray_loss(), LBFGSRefinement.refine_scaler(), LBFGSRefinement.refine_xyz(), LBFGSRefinement.refine_adp(), LBFGSRefinement.refine_joint(), LBFGSRefinement.run_training_trajectory(), LBFGSRefinement.run_training_trajectory_joint(), LBFGSRefinement.refine(), LBFGSRefinement.refine_everything()
LossState, LossState.device, LossState.targets, LossState.weights, LossState.history, LossState.meta, LossState.__getitem__(), LossState.__contains__(), LossState.get(), LossState.cache_losses(), LossState.update_meta(), LossState.register_target(), LossState.register_targets(), LossState.set_weight(), LossState.set_weights(), LossState.get_weight(), LossState.get_effective_weight(), LossState.mark_compilable(), LossState.compile_aggregate(), LossState.reset_compiled_aggregate(), LossState.log(), LossState.new_entry(), LossState.get_history(), LossState.aggregate(), LossState.get_loss(), LossState.active_parameters(), LossState.refresh_loss_leaves(), LossState.reset_caches(), LossState.restore_loss_leaf_grads(), LossState.run(), LossState.step(), LossState.get_breakdown(), LossState.get_group_totals(), LossState.format_breakdown(), LossState.summary(), LossState.to(), LossState.clear(), LossState.clear_history(), LossState.__init__()
Logger, Target, DataTarget, DataTarget.name, DataTarget._model, DataTarget._data, DataTarget._scaler, DataTarget.verbose, DataTarget.__init__(), DataTarget.model, DataTarget.data, DataTarget.scaler, DataTarget.has_model, DataTarget.get_fcalc(), DataTarget.get_fcalc_scaled(), DataTarget.get_F_calc_scaled(), DataTarget.get_rfactor()
ModelTarget
- Subpackages
- torchref.refinement.optimizers package
- torchref.refinement.targets package
Target, ModelTarget, DataTarget, gaussian_nll(), von_mises_nll(), adp_similarity_nll(), XrayTarget, GaussianXrayTarget, MaximumLikelihoodXrayTarget, LeastSquaresXrayTarget, create_xray_target(), DifferenceXrayTarget, PhaseInformedDifferenceTarget, RiceDifferenceTarget, TaylorCorrectedDifferenceTarget, GeometryTarget, BondTarget, AngleTarget, TorsionTarget, PlanarityTarget, ChiralTarget, NonBondedTarget, NonBondedHTarget, RamachandranTarget, ADPTarget, ADPSimilarityTarget, RigidBondTarget, ADPEntropyTarget, ADPLocalityTarget, CombinedTargets, TotalGeometryTarget, TotalADPTarget, ForceFieldTarget, AmberTarget, OccupancyFloorDiagnostic, NegativeDensityPenalty, DisplacementRegularizer, DifferenceAmplitudeRegularizer, SampledMLPhaseTarget, SampledMLDifferenceTarget, create_sampled_ml_target(), create_sampled_ml_difference_target(), RealSpaceTarget, RealSpaceCorrelationTarget, RealSpaceDifferenceTarget, RealSpaceExtrapolatedTarget, CoordinateSimilarityTarget
- Subpackages
- Submodules
- torchref.refinement.weighting package
- Submodules
- torchref.restraints package
- Classes
Restraints, ResidueIterator, RestraintBuilder, RestraintBuilder._indices, RestraintBuilder._references, RestraintBuilder._sigmas, RestraintBuilder._count, RestraintBuilder.restraint_type, RestraintBuilder.atom_columns, RestraintBuilder.n_atoms, RestraintBuilder.__init__(), RestraintBuilder.reset(), RestraintBuilder.process_residue(), RestraintBuilder.finalize(), RestraintBuilder.count
BondRestraintBuilder, AngleRestraintBuilder, TorsionRestraintBuilder, PlaneRestraintBuilder, ChiralRestraintBuilder, ChiralRestraintBuilder._ideal_volumes, ChiralRestraintBuilder.ideal_volume, ChiralRestraintBuilder.sigma, ChiralRestraintBuilder.restraint_type, ChiralRestraintBuilder.atom_columns, ChiralRestraintBuilder.n_atoms, ChiralRestraintBuilder.__init__(), ChiralRestraintBuilder.reset(), ChiralRestraintBuilder.process_residue(), ChiralRestraintBuilder.finalize()
InterResidueBondBuilder, InterResidueBondBuilder._indices, InterResidueBondBuilder._references, InterResidueBondBuilder._sigmas, InterResidueBondBuilder.__init__(), InterResidueBondBuilder.reset(), InterResidueBondBuilder.process_peptide_bond(), InterResidueBondBuilder.process_disulfide_bond(), InterResidueBondBuilder.finalize(), InterResidueBondBuilder.count
InterResidueAngleBuilder, InterResidueTorsionBuilder, InterResidueTorsionBuilder.__init__(), InterResidueTorsionBuilder.reset(), InterResidueTorsionBuilder.process_peptide_torsions(), InterResidueTorsionBuilder.process_disulfide_torsions(), InterResidueTorsionBuilder.finalize_phi(), InterResidueTorsionBuilder.finalize_psi(), InterResidueTorsionBuilder.finalize_omega(), InterResidueTorsionBuilder.finalize_disulfide()
InterResiduePlaneBuilder, get_library_manager()
- Submodules
- torchref.restraints.builders module
- torchref.restraints.builders_fast module
PreprocessedPDB, PreprocessedCIF, RestraintBuilder, BondRestraintBuilder, AngleRestraintBuilder, TorsionRestraintBuilder, PlaneRestraintBuilder, ChiralRestraintBuilder, build_all_restraints(), PreprocessedLinkData, InterResidueBondBuilder, InterResidueAngleBuilder, InterResidueTorsionBuilder, InterResiduePlaneBuilder, ResidueIterator
- torchref.restraints.builders_numba module
- torchref.restraints.hydrogen_topology module
- torchref.restraints.library module
- torchref.restraints.neighbor_search module
- torchref.restraints.ramachandran module
- torchref.restraints.restraints module
- torchref.restraints.restraints_helper module
validate_restraint_data(), read_cif(), split_respecting_quotes(), find_cif_file_in_library(), read_link_definitions(), build_restraints_bondlength(), build_restraints_angles(), build_restraints_torsion(), build_restraints_planes(), build_restraints(), calculate_restraints_bondlength(), calculate_restraints_angles(), calculate_restraints_torsion(), calculate_restraints_all(), read_for_component(), read_comp_list()
- torchref.scaling package
- Classes
Scaler, Scaler.device, Scaler.nbins, Scaler.__init__(), Scaler.model, Scaler.set_model_and_data(), Scaler.initialize(), Scaler.compute_fcalc(), Scaler.calc_initial_scale(), Scaler.fit_anisotropy(), Scaler.setup_solvent(), Scaler.fit_all_scales(), Scaler.screen_solvent_params(), Scaler.refine_lbfgs(), Scaler.rfactor(), Scaler.bin_wise_rfactor(), Scaler.get_binwise_mean_intensity(), Scaler.state_dict(), Scaler.load_state_dict()
ScalerBase, ScalerBase.device, ScalerBase.nbins, ScalerBase.__init__(), ScalerBase.set_data(), ScalerBase.initialize(), ScalerBase.hkl, ScalerBase.calc_initial_scale(), ScalerBase.setup_anisotropy_correction(), ScalerBase.anisotropy_correction(), ScalerBase.fit_anisotropy(), ScalerBase.set_solvent_model(), ScalerBase.setup_binwise_solvent_scale(), ScalerBase.fit_all_scales(), ScalerBase.fit_simple(), ScalerBase.get_scale(), ScalerBase.rfactor(), ScalerBase.bin_wise_rfactor(), ScalerBase.setup_bin_wise_bfactor(), ScalerBase.bin_wise_bfactor_correction(), ScalerBase.get_binwise_mean_intensity(), ScalerBase.screen_solvent_params(), ScalerBase.refine_lbfgs(), ScalerBase.estimate_sigma_eff(), ScalerBase.forward(), ScalerBase.state_dict(), ScalerBase.load_state_dict(), ScalerBase.save_state(), ScalerBase.load_state()
SolventModel, SolventModel.model, SolventModel.device, SolventModel.verbose, SolventModel.float_type, SolventModel.solvent_radius, SolventModel.erosion_radius, SolventModel.optimize_phase, SolventModel.log_k_solvent, SolventModel.b_solvent, SolventModel.phase_offset, SolventModel.__init__(), SolventModel.get_solvent_mask(), SolventModel.update_solvent(), SolventModel.smooth_solvent_mask(), SolventModel.get_rec_solvent(), SolventModel.forward(), SolventModel.parameters()
CollectionScaler, CollectionScaler.__init__(), CollectionScaler.initialize(), CollectionScaler.get_mixed_solvent_raw(), CollectionScaler.forward_mixed(), CollectionScaler.refine_lbfgs_joint(), CollectionScaler.screen_solvent_params_joint(), CollectionScaler.update_all_solvent(), CollectionScaler.invalidate_solvent_cache(), CollectionScaler.component_solvent_models
- Submodules
- torchref.scripts package
- torchref.symmetry package
Cell, Cell.__init__(), Cell.reset_cache(), Cell.detach(), Cell.clone(), Cell.device, Cell.dtype, Cell.data, Cell.requires_grad, Cell.a, Cell.b, Cell.c, Cell.alpha, Cell.beta, Cell.gamma, Cell.fractional_matrix, Cell.inv_fractional_matrix, Cell.volume, Cell.reciprocal_basis_matrix, Cell.compute_grid_size(), Cell.tolist(), Cell.fractional_to_cartesian(), Cell.cartesian_to_fractional(), Cell.__repr__(), Cell.__getitem__(), Cell.__len__()
SpaceGroup, SpaceGroup.matrices, SpaceGroup.translations, SpaceGroup.n_ops, SpaceGroup.__init__(), SpaceGroup.name, SpaceGroup.hm, SpaceGroup.xhm, SpaceGroup.number, SpaceGroup.gemmi, SpaceGroup.point_group, SpaceGroup.crystal_system, SpaceGroup.centrosymmetric, SpaceGroup.dtype, SpaceGroup.device, SpaceGroup.spacegroup, SpaceGroup.space_group, SpaceGroup.space_group_name, SpaceGroup.space_group_number, SpaceGroup.short_name(), SpaceGroup.operations(), SpaceGroup.apply(), SpaceGroup.apply_to_hkl(), SpaceGroup.expand_coords_to_P1(), SpaceGroup.forward(), SpaceGroup.get_grid_requirements(), SpaceGroup.check_grid_compatibility(), SpaceGroup.suggest_grid_size(), SpaceGroup.__hash__(), SpaceGroup.__eq__(), SpaceGroup.copy()
spacegroup_to_str(), get_symmetry_operations(), get_operations_as_tensors(), is_same_spacegroup(), get_point_group(), get_crystal_system(), is_centrosymmetric(), n_operations(), Symmetry, MapSymmetry(), MapSymmetryDirect, ReciprocalSymmetry(), ReciprocalSymmetryGrid, ReciprocalSymmetryGrid.space_group, ReciprocalSymmetryGrid.grid_shape, ReciprocalSymmetryGrid.symmetry, ReciprocalSymmetryGrid.n_ops, ReciprocalSymmetryGrid.__init__(), ReciprocalSymmetryGrid.apply_to_indices(), ReciprocalSymmetryGrid.get_phase_shift(), ReciprocalSymmetryGrid.get_symmetry_mate(), ReciprocalSymmetryGrid.get_all_symmetry_mates(), ReciprocalSymmetryGrid.symmetry_average(), ReciprocalSymmetryGrid.expand_to_p1(), ReciprocalSymmetryGrid.apply_friedel(), ReciprocalSymmetryGrid.is_systematic_absence(), ReciprocalSymmetryGrid.is_centric(), ReciprocalSymmetryGrid.get_epsilon(), ReciprocalSymmetryGrid.forward(), ReciprocalSymmetryGrid.__call__(), ReciprocalSymmetryGrid.get_symmetry_info()
expand_hkl(), complete_hkl(), reduce_hkl(), canonicalize_hkl(), expand_reflections(), expand_reciprocal_grid(), get_symmetry_grid_requirements(), check_grid_compatibility(), recommend_grid_size(), find_fft_friendly_size(), is_fft_friendly(), calculate_optimal_grid_size()
- Submodules
- torchref.utils package
ParameterFingerprint, CachedForwardMixin, DeviceMixin, DeviceMovementMixin, resolve_device(), TensorMasks, TensorDict, ModuleReference, save_map(), sanitize_pdb_dataframe(), parse_phenix_selection(), create_selection_mask(), DebugMixin, print_module_summary(), StatEntry, stat(), filter_stats(), flatten_stats(), format_stats_table(), HyperparameterMixin, HyperparameterMixin.__init__(), HyperparameterMixin.register_hyperparameter(), HyperparameterMixin.get_hyperparameter(), HyperparameterMixin.set_hyperparameter(), HyperparameterMixin.hyperparameters(), HyperparameterMixin.named_hyperparameters(), HyperparameterMixin.hyperparameter_state_dict(), HyperparameterMixin.load_hyperparameter_state_dict(), HyperparameterMixin.hyperparameter_dict(), HyperparameterMixin.print_hyperparameters()
convert_to_serializable(), gradnorm(), validate_loss(), NonFiniteLossError, reset_diagnostic_budget(), collect_loss_leaves()
- Submodules
- torchref.utils.autograd_introspection module
- torchref.utils.autograd_ops module
- torchref.utils.caching module
- torchref.utils.debug_utils module
- torchref.utils.device_mixin module
- torchref.utils.device_resolution module
- torchref.utils.gradnorm module
- torchref.utils.hyperparameters module
- torchref.utils.loss_validation module
- torchref.utils.pse module
- torchref.utils.serialization module
- torchref.utils.stats module
- torchref.utils.timing module
- torchref.utils.utils module