torchref.io package
I/O module for crystallographic data files.
This module provides:
- Dataset classes for handling reflection data
- Format-specific readers and writers (MTZ, PDB, CIF)
- Automatic format detection via DataRouter
High-level API
Load a single dataset:
from torchref.io import ReflectionData
data = ReflectionData(verbose=1)
data.load_mtz('structure.mtz')
Multi-dataset handling:
from torchref.io import DatasetCollection
collection = DatasetCollection()
collection.add_dataset('native', native_data)
collection.add_dataset('derivative', derivative_data)
Direct format access:
from torchref.io import mtz
reader = mtz.read('data.mtz')
data_dict, cell, spacegroup = reader()
- class torchref.io.CrystalDataset(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None)[source]
Bases: DeviceMixin
Base dataclass for crystallographic datasets.
Defines all possible tensor fields (optional) and handles device management and serialization. Subclasses add domain-specific methods.
This lightweight design scales to thousands of datasets without the overhead of torch.nn.Module.
- Parameters:
device (torch.device) – Device for tensors (‘cpu’, ‘cuda’, etc.). Defaults to the configured device.current.
verbose (int) – Verbosity level (0=silent, 1=normal, 2=debug). Default is 1.
Examples
Basic usage:
import torch
from torchref.io import CrystalDataset
data = CrystalDataset(device='cuda')
data.hkl = torch.tensor([[1, 0, 0], [0, 1, 0]])
data.cpu()  # Move all tensors to CPU
- save_state(path)[source]
Save dataset state to file.
- Parameters:
path (str) – Output file path.
Examples
Save to file:
data.save_state('reflection_data.pt')
- classmethod load_state(path, device=device(type='cpu'))[source]
Load dataset state from file.
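- Parameters:
path (str) – Input file path.
device (torch.device or str, optional) – Device to load the tensors onto. Default is CPU.
- Returns:
Loaded dataset.
- Return type:
CrystalDataset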
Examples
Load from file:
data = ReflectionData.load_state('reflection_data.pt', device='cuda')
- property spacegroup_hm: str | None
Get space group Hermann-Mauguin name with spaces (e.g., ‘P 21 21 21’).
- __init__(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None)
- class torchref.io.ReflectionData(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, _centric=None, _n_bins=None, _FrenchWilson=None, source=None, dataset=None, last_op=None, reader=None)[source]
Bases: CrystalDataset, DebugMixin
Container for crystallographic reflection data.
This class handles loading, processing, and accessing reflection data including Miller indices, structure factor amplitudes, intensities, and R-free flags. All data is stored as PyTorch tensors for GPU acceleration.
- Parameters:
- hkl
Miller indices of shape (N, 3), dtype int32.
- Type:
torch.Tensor
- F
Structure factor amplitudes of shape (N,), dtype float32.
- Type:
torch.Tensor
- F_sigma
Amplitude uncertainties of shape (N,), dtype float32.
- Type:
torch.Tensor
- I
Intensities of shape (N,), dtype float32.
- Type:
torch.Tensor
- I_sigma
Intensity uncertainties of shape (N,), dtype float32.
- Type:
torch.Tensor
- rfree_flags
R-free test set flags of shape (N,), dtype bool.
- Type:
torch.Tensor
- cell
Unit cell parameters [a, b, c, alpha, beta, gamma].
- Type:
Cell
- resolution
Resolution (d-spacing) of each reflection in Ångströms, shape (N,).
- Type:
torch.Tensor
Examples
Load reflection data from an MTZ file:
data = ReflectionData(verbose=1, device='cuda')
data.load_mtz('data.mtz')
print(f"Loaded {len(data.hkl)} reflections")
print(f"Resolution range: {data.resolution.min():.2f} - {data.resolution.max():.2f} Å")
- source: ReflectionData | None = None
- __post_init__()[source]
Initialize non-dataclass attributes after dataclass init.
This is called automatically after the dataclass __init__.
- load(reader)[source]
Load reflection data using a data reader.
- Parameters:
reader (callable) – Data reader object that returns (data_dict, cell, spacegroup) when called. Can be MTZ, ReflectionCIFReader, or other compatible reader.
- Returns:
Self, for method chaining.
- Return type:
ReflectionData
- Raises:
ValueError – If unit cell parameters are missing or no amplitude/intensity data found.
- classmethod from_tensors(hkl, F, F_sigma, cell, spacegroup, rfree_flags=None, device=device(type='cpu'), verbose=1)[source]
Construct ReflectionData directly from tensors.
- Parameters:
hkl (torch.Tensor) – Miller indices of shape (N, 3).
F (torch.Tensor) – Structure factor amplitudes of shape (N,).
F_sigma (torch.Tensor) – Amplitude uncertainties of shape (N,).
cell (Cell) – Unit cell parameters.
spacegroup (SpaceGroup) – Space group.
rfree_flags (torch.Tensor, optional) – R-free flags of shape (N,), dtype bool. If None, flags are generated automatically (2% free fraction).
device (str, optional) – Device for tensors. Defaults to the configured device.current.
verbose (int, optional) – Verbosity level. Default is 1.
- Returns:
Fully initialized reflection data with all cleanup applied.
- Return type:
ReflectionData
- load_mtz(path, column_names=None)[source]
Load reflection data from MTZ file.
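- Parameters:
path (str) – Path to the MTZ file.
column_names (optional) – Explicit column names to read. Default is None.
- Returns:
Self, for method chaining.
- Return type:
ReflectionData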
- load_cif(path, data_block=None)[source]
Load reflection data from CIF file.
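- Parameters:
path (str) – Path to the CIF file.
data_block (str, optional) – Name of the data block to load (see list_cif_data_blocks). Default is None.
- Returns:
Self, for method chaining.
- Return type:
ReflectionData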
- static list_cif_data_blocks(path)[source]
List all data blocks available in a CIF file without loading data.
Useful for multi-dataset CIF files to inspect available blocks before loading a specific one.
- Parameters:
path (str) – Path to CIF file.
- Returns:
Names of all data blocks in the CIF file.
- Return type:
list of str
Examples
List and load a specific data block:
blocks = ReflectionData.list_cif_data_blocks('1VLM-sf.cif')
print(blocks)  # ['r1vlmsf', 'r1vlmAsf', 'r1vlmBsf', ...]
data = ReflectionData().load_cif('1VLM-sf.cif', data_block=blocks[1])
- get_bins(n_bins=20, min_per_bin=100)[source]
Create resolution bins with approximately equal reflection counts.
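- Parameters:
n_bins (int, optional) – Target number of resolution bins. Default is 20.
min_per_bin (int, optional) – Minimum number of reflections per bin. Default is 100.
- Returns:
bin_indices (torch.Tensor) – Tensor of shape (N,) with the bin index of each reflection.
n_bins (int) – Actual number of bins created (may be fewer than the target for small datasets).
- Return type:
tuple of (torch.Tensor, int)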
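The equal-count binning above can be sketched in plain Python. This is a minimal illustration, not torchref's code: `equal_count_bins` is a hypothetical name, `resolution` is assumed to be a list of d-spacings in Ångströms, and the rule for reducing the bin count on small datasets (`n // min_per_bin`) is an assumption.

```python
from collections import Counter


def equal_count_bins(resolution, n_bins=20, min_per_bin=100):
    """Hypothetical sketch: assign reflections to bins with ~equal counts.

    Bin 0 holds the lowest-resolution (largest d) reflections.
    """
    n = len(resolution)
    # Shrink the bin count so every bin can hold at least min_per_bin reflections
    actual = max(1, min(n_bins, n // min_per_bin))
    # Rank reflections from largest to smallest d-spacing
    order = sorted(range(n), key=lambda i: -resolution[i])
    bin_indices = [0] * n
    for rank, i in enumerate(order):
        # Integer division splits the ranked list into `actual` near-equal chunks
        bin_indices[i] = rank * actual // n
    return bin_indices, actual


d = [1.0 + 0.01 * i for i in range(1000)]
bins, n_actual = equal_count_bins(d, n_bins=20, min_per_bin=100)
print(n_actual, Counter(bins)[0])  # 10 100
```

Ranking by d rather than cutting at fixed d values is what keeps the counts equal even when reflections crowd at high resolution.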
- mean_res_per_bin()[source]
Calculate mean resolution for each bin.
- Returns:
Mean resolution for each bin in Ångströms.
- Return type:
torch.Tensor
- Raises:
ValueError – If bins have not been created yet.
- mean_F_per_bin()[source]
Calculate mean structure factor amplitude per resolution bin.
- Returns:
Mean F per bin of shape (n_bins,).
- Return type:
torch.Tensor
- Raises:
ValueError – If bins have not been created yet.
- mean_sigma_per_bin()[source]
Calculate mean structure factor uncertainty per resolution bin.
- Returns:
Mean sigma_F per bin of shape (n_bins,), or None if no uncertainties.
- Return type:
torch.Tensor or None
- Raises:
ValueError – If bins have not been created yet.
- regenerate_rfree_flags(free_fraction=0.02, n_bins=20, min_per_bin=100, seed=None, force=False)[source]
Regenerate R-free flags with resolution-stratified sampling.
- Parameters:
free_fraction (float, optional) – Fraction of reflections to mark as free. Default is 0.02 (2%).
n_bins (int, optional) – Target number of resolution bins. Default is 20.
min_per_bin (int, optional) – Minimum reflections per resolution bin. Default is 100.
seed (int, optional) – Random seed for reproducibility. Default is None.
force (bool, optional) – If True, regenerate even if flags already exist. Default is False.
Examples
Generate R-free flags:
# Generate 2% free reflections with a reproducible seed
data.regenerate_rfree_flags(free_fraction=0.02, n_bins=20, seed=42)
# Generate 5% free with 10 bins, overwriting existing flags
data.regenerate_rfree_flags(free_fraction=0.05, n_bins=10, force=True)
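A minimal sketch of resolution-stratified R-free selection, assuming bin indices like those from equal-count binning and the 1 = work / 0 = free convention described under __call__. The function name `stratified_rfree_flags` and the "at least one free reflection per bin" rule are illustrative assumptions; torchref's sampling may differ.

```python
import random


def stratified_rfree_flags(bin_indices, free_fraction=0.02, seed=None):
    """Hypothetical sketch: draw ~free_fraction of reflections from every bin.

    Flags follow the convention 1 = work, 0 = free.
    """
    rng = random.Random(seed)  # private generator, so `seed` gives reproducible flags
    flags = [1] * len(bin_indices)
    members = {}
    for i, b in enumerate(bin_indices):
        members.setdefault(b, []).append(i)
    for idx in members.values():
        # At least one free reflection per bin keeps every shell represented
        n_free = max(1, round(free_fraction * len(idx)))
        for i in rng.sample(idx, n_free):
            flags[i] = 0
    return flags


bins = [i % 10 for i in range(1000)]  # 10 bins of 100 reflections
flags = stratified_rfree_flags(bins, free_fraction=0.05, seed=42)
print(flags.count(0))  # 50
```

Sampling within each bin, rather than from the whole dataset at once, guarantees that every resolution range contributes to the free set.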
- get_structure_factors(as_complex=False)[source]
Get structure factors, optionally as complex numbers.
- Parameters:
as_complex (bool, optional) – If True and phases available, return F*exp(i*phi). Default is False.
- Returns:
Structure factor amplitudes or complex structure factors.
- Return type:
torch.Tensor
- Raises:
ValueError – If no amplitude data is loaded.
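With as_complex=True the amplitudes are combined with the phases as F·exp(iφ). A plain-Python sketch, assuming phases stored in degrees (the usual MTZ convention); torchref may store phases in radians, and `to_complex` is a hypothetical helper.

```python
import cmath
import math


def to_complex(F, phase_deg):
    """Return F * exp(i * phi) for each reflection; phases assumed in degrees."""
    return [f * cmath.exp(1j * math.radians(p)) for f, p in zip(F, phase_deg)]


Fc = to_complex([2.0, 3.0], [90.0, 180.0])
# Amplitude and phase are recovered with abs() and cmath.phase()
print(abs(Fc[0]), math.degrees(cmath.phase(Fc[1])))
```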
- get_structure_factors_with_sigma()[source]
Get structure factor amplitudes and their uncertainties.
- Returns:
F (torch.Tensor) – Structure factor amplitudes of shape (N,).
F_sigma (torch.Tensor or None) – Uncertainties of shape (N,), or None if not available.
- Raises:
ValueError – If no amplitude data is loaded.
- Return type:
tuple of (torch.Tensor, torch.Tensor or None)
Examples
Get amplitudes with uncertainties:
F, sigma_F = data.get_structure_factors_with_sigma()
if sigma_F is not None:
    # F_calc: model amplitudes computed elsewhere (not defined in this snippet)
    weighted_residual = (F - F_calc) / sigma_F
- get_hkl()[source]
Return Miller indices for valid reflections.
- Returns:
Miller indices of shape (N, 3), dtype int32.
- Return type:
torch.Tensor
- Raises:
ValueError – If no Miller indices are loaded.
- filter_by_resolution(d_min=None, d_max=None)[source]
Filter reflections by resolution range.
Adds a boolean mask to self.masks for the specified resolution range.
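- Parameters:
d_min (float, optional) – High-resolution limit (smallest d-spacing to keep) in Ångströms. Default is None.
d_max (float, optional) – Low-resolution limit (largest d-spacing to keep) in Ångströms. Default is None.
- Returns:
Self, for method chaining.
- Return type:
ReflectionData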
- get_mask()[source]
Return combined mask from all active filters.
- Returns:
Boolean mask combining all filter conditions.
- Return type:
torch.Tensor
- cut_res(highres=None, lowres=None)[source]
Filter reflections by resolution range.
Alias for filter_by_resolution with more intuitive naming.
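- Parameters:
highres (float, optional) – High-resolution limit (smallest d-spacing to keep) in Ångströms. Default is None.
lowres (float, optional) – Low-resolution limit (largest d-spacing to keep) in Ångströms. Default is None.
- Returns:
Self, for method chaining.
- Return type:
ReflectionData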
Examples
Filter by resolution:
# Keep reflections between 50 Å and 1.5 Å
filtered = data.cut_res(highres=1.5, lowres=50.0)
# Keep only high-resolution data between 2.0 Å and 1.0 Å
high_res = data.cut_res(highres=1.0, lowres=2.0)
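The resolution filter reduces to comparing each reflection's d-spacing with the two limits. A minimal sketch, assuming a cell with all angles at 90° so that d has a closed form; the helper names are hypothetical, and a general cell needs the reciprocal metric instead. The index (0, 0, 0) must be excluded because its d is undefined.

```python
import math


def d_spacing_orthorhombic(hkl, a, b, c):
    """d = 1 / sqrt((h/a)^2 + (k/b)^2 + (l/c)^2); valid only for 90-degree cells."""
    return [1.0 / math.sqrt((h / a) ** 2 + (k / b) ** 2 + (l / c) ** 2) for h, k, l in hkl]


def resolution_mask(d, highres=None, lowres=None):
    """True where highres <= d <= lowres; a limit of None is not applied."""
    return [(highres is None or x >= highres) and (lowres is None or x <= lowres) for x in d]


hkl = [(1, 0, 0), (2, 0, 0), (10, 0, 0)]
d = d_spacing_orthorhombic(hkl, 20.0, 20.0, 20.0)  # about 20, 10 and 2 Angstroms
print(resolution_mask(d, highres=5.0, lowres=15.0))  # [False, True, False]
```

Note the direction of the inequalities: the high-resolution limit is the smaller d-spacing, so it is a lower bound on d.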
- get_rfree_masks()[source]
Get boolean masks for work and test (free) sets.
- Returns:
work_mask (torch.Tensor or None) – Boolean tensor for work set (flag != 0).
test_mask (torch.Tensor or None) – Boolean tensor for test/free set (flag == 0). Both are None if no R-free flags are available.
- Return type:
tuple of (torch.Tensor or None, torch.Tensor or None)
Examples
Separate work and test sets:
work_mask, test_mask = data.get_rfree_masks()
if work_mask is not None:
    F_work = data.F[work_mask]
    F_test = data.F[test_mask]
- get_work_set()[source]
Get structure factors for the work set (R-free flag != 0).
- Returns:
F_work (torch.Tensor) – Structure factors for work set.
sigma_work (torch.Tensor or None) – Uncertainties for work set, or None if not available.
- Return type:
tuple of (torch.Tensor, torch.Tensor or None)
Notes
If no R-free flags are available, returns the full dataset and emits a warning.
- get_test_set()[source]
Get structure factors for the test set (R-free flag == 0).
- Returns:
F_test (torch.Tensor) – Structure factors for test/free set.
sigma_test (torch.Tensor or None) – Uncertainties for test set, or None if not available.
- Raises:
ValueError – If no R-free flags are available.
- Return type:
tuple of (torch.Tensor, torch.Tensor or None)
- get_max_res()[source]
Return maximum resolution (lowest d-spacing).
- Returns:
Maximum resolution in Ångströms.
- Return type:
float
- get_min_res()[source]
Return minimum resolution (highest d-spacing).
- Returns:
Minimum resolution in Ångströms.
- Return type:
float
- __len__()[source]
Return number of reflections.
- Returns:
Number of reflections in the dataset.
- Return type:
int
- property d_min: float | None
Return maximum resolution (lowest d-spacing).
- Returns:
Maximum resolution in Ångströms.
- Return type:
float or None
- __repr__()[source]
Return string representation.
- Returns:
Summary of reflection data including count, sources, and resolution.
- Return type:
str
- get_valid_mask()[source]
Return the combined validity mask for all reflections.
This is the mask used to filter reflections in forward(). True indicates a valid (included) reflection, False indicates an excluded one.
- Returns:
Boolean mask of shape (N,) where N is the total number of reflections. True = valid/included, False = invalid/excluded.
- Return type:
torch.Tensor
Examples
Check validity:
mask = data.get_valid_mask()
print(f"{int(mask.sum())} of {len(mask)} reflections are valid")
- data_indexed()[source]
Return reflection data as indexed (filtered) tensors.
This method filters out invalid reflections and returns smaller tensors containing only valid data. Useful for operations that don’t support MaskedTensors or for writing output files.
- Returns:
hkl (torch.Tensor) – Miller indices of shape (M, 3) where M is number of valid reflections.
F (torch.Tensor) – Structure factor amplitudes of shape (M,).
F_sigma (torch.Tensor or None) – Uncertainties of shape (M,) or None.
rfree_flags (torch.Tensor or None) – R-free flags of shape (M,) or None.
- Return type:
tuple of (torch.Tensor, torch.Tensor, torch.Tensor or None, torch.Tensor or None)
See also
forward – Main method returning MaskedTensors.
Examples
Get indexed data for file writing:
hkl, F, sigma, rfree = data.data_indexed()
F_np = F.cpu().numpy()  # Safe for writing to files
- __call__(mask=True, scale=True)[source]
Return core reflection data with MaskedTensors for F and sigma.
F and F_sigma are returned as MaskedTensors which keep all reflections but mark invalid ones as masked. Aggregation operations (sum, mean, etc.) automatically skip masked values. HKL and rfree_flags remain regular tensors.
- Parameters:
- Returns:
hkl (torch.Tensor) – Miller indices of shape (N, 3). Full size, unfiltered.
F (MaskedTensor) – Structure factor amplitudes of shape (N,) with invalid reflections masked.
F_sigma (MaskedTensor or None) – Uncertainties of shape (N,) with invalid reflections masked, or None.
rfree_flags (torch.Tensor or None) – R-free flags of shape (N,) or None. Full size, unfiltered. 1=work, 0=free.
- Return type:
tuple of (torch.Tensor, MaskedTensor, MaskedTensor or None, torch.Tensor or None)
Notes
MaskedTensors:
Are PyTorch tensors with an associated boolean mask
Aggregations (sum, mean, etc.) skip masked values automatically
Element-wise operations preserve the mask
Use .get_data() and .get_mask() to access underlying data
Use .to_tensor(fill_value) to convert back to regular tensor
Note: MaskedTensor is in prototype stage in PyTorch
Loss functions and targets extract valid data from MaskedTensors before computation to work correctly with complex F_calc values.
Examples
Access reflection data with MaskedTensors:
hkl, F, sigma, rfree = data()
print(F.shape)  # Full shape
print(F.sum())  # Only sums valid (unmasked) values
# Access underlying data
valid_mask = F.get_mask()
F_values = F.get_data()[valid_mask]
- data_fill_masked(mode='mean')[source]
Return data tensors with missing or flagged reflections filled according to mode.
- Parameters:
mode (str, optional) – Fill strategy. ‘mean’ fills missing/flagged reflections with the mean of the present data; ‘zero’ fills them with zero. Default is ‘mean’.
- __getitem__(key)[source]
Index into the reflection dataset.
- Parameters:
key (torch.Tensor) – Boolean mask or integer indices for selection.
- Returns:
New ReflectionData object with selected reflections.
- Return type:
ReflectionData
- __select__(indices, op=None)[source]
Select reflections by boolean mask or integer indices.
Iterates over all dataclass fields generically, so new tensor fields are handled automatically.
- Parameters:
indices (torch.Tensor) – Boolean mask of shape (N,) or integer indices for selection.
op (str, optional) – Operation name for tracking purposes.
- Returns:
New ReflectionData object with selected reflections.
- Return type:
ReflectionData
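The generic field iteration described above can be illustrated with the standard dataclasses module. This sketch uses a hypothetical `ToyReflections` dataclass holding Python lists instead of tensors and accepts only integer indices; it assumes a field is per-reflection when its length equals the number of Miller indices, which may not be the rule torchref applies.

```python
from dataclasses import dataclass, fields, replace
from typing import List, Optional, Tuple


@dataclass
class ToyReflections:
    """Hypothetical stand-in with per-reflection list fields."""
    hkl: Optional[List[Tuple[int, int, int]]] = None
    F: Optional[List[float]] = None
    F_sigma: Optional[List[float]] = None
    spacegroup: Optional[str] = None  # not per-reflection, copied unchanged


def select(data, indices):
    """Return a copy keeping only `indices` in every per-reflection field."""
    n = len(data.hkl)
    updates = {}
    for f in fields(data):
        value = getattr(data, f.name)
        # Index only fields holding exactly one entry per reflection
        if isinstance(value, list) and len(value) == n:
            updates[f.name] = [value[i] for i in indices]
    return replace(data, **updates)


data = ToyReflections(hkl=[(1, 0, 0), (0, 1, 0), (0, 0, 1)], F=[1.0, 2.0, 3.0], spacegroup="P 1")
subset = select(data, [2, 0])
print(subset.hkl, subset.F)  # [(0, 0, 1), (1, 0, 0)] [3.0, 1.0]
```

Because the loop walks `fields()`, a new per-reflection field added to the dataclass is selected without touching `select`.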
- sanitize_F()[source]
Remove invalid values from structure factors.
Adds a mask to filter out NaN, Inf, and non-positive values from F and F_sigma.
- validate_hkl(hkl_ref)[source]
Expand dataset to match a reference HKL set.
Reorders and expands the current dataset to match the reference HKL set. Reflections present in the reference but missing from this dataset are filled with placeholder values and masked out. This ensures all datasets aligned to the same reference have identical shapes and can be processed together without data loss from intersection operations.
- Parameters:
hkl_ref (torch.Tensor) – Reference Miller indices tensor of shape (N, 3), dtype int32. This defines the canonical HKL ordering for all aligned datasets.
- Returns:
Self, modified in-place with expanded arrays matching hkl_ref.
- Return type:
ReflectionData
Notes
After calling this method:
- self.hkl will equal hkl_ref exactly
- All data arrays (F, F_sigma, rfree_flags, etc.) are reordered/expanded
- Missing reflections are filled with 0 (or appropriate defaults)
- A mask ‘hkl_present’ is added marking which reflections have real data
- forward() will return MaskedTensors that skip missing reflections
This approach avoids the problem where intersecting many datasets with different outliers/missing reflections causes exponential data loss.
Examples
Align multiple datasets to a common HKL set:
reference_hkl = data1.hkl.clone()
data1.validate_hkl(reference_hkl)
data2.validate_hkl(reference_hkl)
# Now data1 and data2 have identical shapes
assert data1.hkl.shape == data2.hkl.shape
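The alignment step amounts to a dictionary lookup from Miller index to row. A plain-Python sketch under stated assumptions: `align_to_reference` is hypothetical, only F is aligned, missing reflections get a 0.0 placeholder plus a presence flag, and reflections absent from the reference are dropped.

```python
def align_to_reference(hkl, F, hkl_ref, fill=0.0):
    """Reorder and expand F onto hkl_ref; return (F_aligned, present_mask)."""
    # Map each Miller index to its position in this dataset
    lookup = {tuple(h): i for i, h in enumerate(hkl)}
    F_aligned, present = [], []
    for h in hkl_ref:
        i = lookup.get(tuple(h))
        present.append(i is not None)
        # Missing reflections get a placeholder and are masked out via `present`
        F_aligned.append(F[i] if i is not None else fill)
    return F_aligned, present


hkl_ref = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
F_aligned, present = align_to_reference([(0, 0, 1), (1, 0, 0)], [5.0, 7.0], hkl_ref)
print(F_aligned, present)  # [7.0, 0.0, 5.0] [True, False, True]
```

Every aligned dataset ends up with the length of hkl_ref, which is why shapes match without intersecting the index sets.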
- find_outliers(model, scaler, z_threshold=4.0)[source]
Identify outlier reflections based on log-ratio distribution.
Uses the fact that log(F_obs) - log(F_calc) should be normally distributed. Outliers are reflections where |log_ratio - mean| > z_threshold * std_dev.
- Parameters:
- Returns:
Boolean mask where True indicates outliers.
- Return type:
torch.Tensor
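The log-ratio criterion can be sketched with the statistics module. This assumes F_calc has already been scaled (the scaler argument is left out), that all amplitudes are strictly positive, and that the population standard deviation is used; `log_ratio_outliers` is a hypothetical name.

```python
import math
import statistics


def log_ratio_outliers(F_obs, F_calc, z_threshold=4.0):
    """Flag |r - mean(r)| > z_threshold * std(r), with r = log(F_obs) - log(F_calc).

    Both amplitude lists must be strictly positive.
    """
    r = [math.log(fo) - math.log(fc) for fo, fc in zip(F_obs, F_calc)]
    mean = statistics.fmean(r)
    std = statistics.pstdev(r)  # population standard deviation
    return [abs(x - mean) > z_threshold * std for x in r]


F_obs = [10.0] * 50 + [1000.0]
F_calc = [10.0] * 51
flags = log_ratio_outliers(F_obs, F_calc)
print(sum(flags), flags[-1])  # 1 True
```

Working with the log ratio makes the test insensitive to an overall scale error: a constant factor between F_obs and F_calc shifts every r equally and cancels in r - mean(r).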
- get_log_ratio(model, scaler)[source]
Compute log-ratio between observed and calculated structure factors.
- Parameters:
- Returns:
Log-ratio values: log(F_obs) - log(F_calc).
- Return type:
torch.Tensor
- get_outlier_statistics()[source]
Get statistics about flagged outliers.
- Returns:
Dictionary containing:
- n_outliers : int
- n_total : int
- fraction_outliers : float
- detection_params : dict or None
- outlier_resolution_stats : dict (if resolution available)
- Return type:
dict
- unpack_one()[source]
Unpack one level of source.
Does not recurse fully and does not flag.
- Returns:
Parent source or self if no source.
- Return type:
ReflectionData
- flag_suspicious_sigma(z_threshold=5.0)[source]
Flag sigma values that deviate significantly from expected distribution.
Sigma values from a detector should follow a log-normal distribution. Values with z-scores beyond threshold are flagged as suspicious.
- Parameters:
z_threshold (float, optional) – Z-score threshold for flagging outliers. Default is 5.0.
- dump()[source]
Dump all reflection data to console for debugging.
Prints type, shape, and device information for all attributes.
- write_mtz(fname, fcalc=None, model_ft=None)[source]
Write reflection data to MTZ file with optional map coefficients.
- Parameters:
fname (str) – Output MTZ filename.
fcalc (torch.Tensor, optional) – Complex calculated structure factors of shape (N,). If provided, computes phases and map coefficients.
model_ft (ModelFT, optional) – ModelFT object to compute fcalc if not provided.
Notes
- The MTZ file will contain canonical column names:
FP, SIGFP: Observed amplitudes and uncertainties
I, SIGI: Observed intensities and uncertainties (if available)
FreeR_flag: R-free test set flags
FWT, PHWT: 2mFo-DFc map coefficients (if fcalc provided)
DELFWT, PHDELWT: mFo-DFc map coefficients (if fcalc provided)
- Map coefficients are computed as:
2mFo-DFc: 2*Fo - Fc (filled to resolution limit)
mFo-DFc: Fo - Fc
Examples
Write MTZ with map coefficients:
data = ReflectionData().load_mtz('observed.mtz')
model = Model().load_pdb('model.pdb')
model_ft = ModelFT(model, data.cell, data.spacegroup)
fcalc = model_ft.forward(data.hkl)
data.write_mtz('output.mtz', fcalc=fcalc)
- property centric
Get boolean mask for centric reflections (full size, unfiltered).
Calculates it if not already present. Returns unfiltered centric flags matching the full HKL array size, consistent with how forward() returns full-size arrays when using MaskedTensors.
- Returns:
Boolean tensor of shape (N,) where N is total reflections. True indicates centric reflection, False indicates acentric.
- Return type:
torch.Tensor or None
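A reflection h is centric when some rotation of the space group maps it onto -h. A sketch of that test, assuming the rotation parts are given as integer 3×3 matrices acting on h as a row vector; the P 1 21 1 matrices are illustrative input and `is_centric` is hypothetical.

```python
def is_centric(hkl, rotations):
    """True if some rotation R maps h to -h (h treated as a row vector, h R)."""
    h = tuple(hkl)
    minus_h = tuple(-x for x in h)
    for R in rotations:
        hR = tuple(sum(h[i] * R[i][j] for i in range(3)) for j in range(3))
        if hR == minus_h:
            return True
    return False


# Rotation parts of P 1 21 1 (unique axis b): identity and the 2-fold about b
P21_ROTATIONS = [
    ((1, 0, 0), (0, 1, 0), (0, 0, 1)),
    ((-1, 0, 0), (0, 1, 0), (0, 0, -1)),
]
print(is_centric((1, 0, 2), P21_ROTATIONS), is_centric((1, 1, 2), P21_ROTATIONS))  # True False
```

In this monoclinic example the h0l zone comes out centric, which matches the textbook result for a 2-fold axis along b.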
- calc_patterson(grid_size=None, grid_sampling=1)[source]
Calculate Patterson map of the dataset.
The Patterson function P(u,v,w) = Σ|F(hkl)|² exp(-2πi(hu+kv+lw)) is computed via inverse FFT of F². Data is expanded to P1 symmetry using only observed reflections (no filling of missing data).
- Parameters:
- Returns:
Real-valued Patterson map of shape (Nx, Ny, Nz). Origin is at grid position [0, 0, 0].
- Return type:
torch.Tensor
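The Patterson synthesis can be illustrated in one dimension, where the result is easy to check by hand. This sketch evaluates the sum directly rather than by FFT, omits the 1/V prefactor, and uses two hypothetical point atoms; the largest non-origin peak should sit at their separation.

```python
import cmath
import math


def patterson_1d(F_sq, n_grid):
    """P(u) = sum_h |F(h)|^2 exp(-2*pi*i*h*u) sampled at u = j / n_grid.

    F_sq maps integer index h -> |F(h)|^2 and should contain both h and -h so
    that the imaginary parts cancel and P(u) is real.
    """
    P = []
    for j in range(n_grid):
        u = j / n_grid
        total = sum(f2 * cmath.exp(-2j * math.pi * h * u) for h, f2 in F_sq.items())
        P.append(total.real)
    return P


# Two point atoms at x = 0.10 and x = 0.35: interatomic vector 0.25
F_sq = {}
for h in range(-9, 10):
    F = cmath.exp(2j * math.pi * h * 0.10) + cmath.exp(2j * math.pi * h * 0.35)
    F_sq[h] = abs(F) ** 2
P = patterson_1d(F_sq, 20)
print(round(P[0], 6), round(P[5], 6))  # 40.0 20.0 (origin and u = 0.25)
```

The phases never enter: only |F|² is used, which is why the Patterson map shows interatomic vectors rather than atomic positions.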
- possible_hkl()[source]
Generate all possible HKL indices within the resolution limit.
- Returns:
Tensor of shape (M, 3) containing all possible Miller indices within the resolution limit defined by self.resolution.
- Return type:
torch.Tensor
- remap(new_hkl, index_mapping, phase_shifts=None, spacegroup=None, op_name='remap')[source]
Create new ReflectionData with remapped HKL set and data.
This is the core method for index-based transformations of reflection data. It handles remapping all tensor fields based on an index mapping, with support for missing reflections (indicated by -1 in index_mapping).
- Parameters:
new_hkl (torch.Tensor, shape (M, 3)) – New Miller indices.
index_mapping (torch.Tensor, shape (M,), dtype int64) – Maps new indices to original: new[i] = old[index_mapping[i]]. Values of -1 indicate missing reflections (filled with defaults).
phase_shifts (torch.Tensor, optional, shape (M,)) – Phase offsets to apply (e.g., from symmetry translations).
spacegroup (str, int, gemmi.SpaceGroup, or None) – New spacegroup. If None, keeps original.
op_name (str) – Operation name for provenance tracking.
- Returns:
New object with remapped data. Missing reflections get:
- 0.0 for F, I, phase, fom
- 1.0 for F_sigma, I_sigma (conservative uncertainty)
- True for masks[‘missing’]
- Return type:
ReflectionData
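The index-mapping rule new[i] = old[index_mapping[i]] can be sketched for three fields. Assumptions: phases in degrees, plain lists instead of tensors, and a hypothetical `remap_fields` helper; the defaults for missing reflections follow the list above.

```python
def remap_fields(F, F_sigma, phase_deg, index_mapping, phase_shifts=None):
    """Hypothetical sketch: new[i] = old[index_mapping[i]], with -1 = missing."""
    if phase_shifts is None:
        phase_shifts = [0.0] * len(index_mapping)
    new_F, new_sigma, new_phase, missing = [], [], [], []
    for m, shift in zip(index_mapping, phase_shifts):
        if m < 0:
            # Defaults listed above: 0.0 for F and phase, 1.0 for sigma
            new_F.append(0.0)
            new_sigma.append(1.0)
            new_phase.append(0.0)
            missing.append(True)
        else:
            new_F.append(F[m])
            new_sigma.append(F_sigma[m])
            # Apply the per-reflection phase offset and wrap into [0, 360)
            new_phase.append((phase_deg[m] + shift) % 360.0)
            missing.append(False)
    return new_F, new_sigma, new_phase, missing


out = remap_fields([10.0, 20.0], [1.0, 2.0], [30.0, 350.0], [1, -1, 0], [20.0, 0.0, 0.0])
print(out)  # ([20.0, 0.0, 10.0], [2.0, 1.0, 1.0], [10.0, 0.0, 30.0], [False, True, False])
```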
- fill(d_min=None)[source]
Fill missing reflections within resolution limit.
Generates all possible reflections for the current spacegroup within the resolution limit, identifies which are missing, and creates a complete dataset. Missing reflections are filled with default values.
- Parameters:
d_min (float, optional) – High resolution limit in Angstroms. If None, uses the minimum resolution from the current dataset.
- Returns:
New ReflectionData with complete set of reflections. Missing reflections have:
- F, I, phase, fom = 0.0
- F_sigma, I_sigma = 1.0
- masks[‘missing’] = True
- Return type:
ReflectionData
Examples
Fill to completeness:
data = ReflectionData().load_mtz('data.mtz')
data_filled = data.fill(d_min=2.0)
print(f"Original: {len(data)}, Filled: {len(data_filled)}")
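Enumerating the possible reflections and diffing against the observed set is the core of filling. A minimal sketch, assuming a cell with 90° angles, no space-group reduction (every P1 index is kept), and a hypothetical helper name; torchref works per space group and handles general cells.

```python
import itertools


def possible_hkl_orthorhombic(a, b, c, d_min):
    """All (h, k, l) != (0, 0, 0) with d >= d_min for a cell with 90-degree angles.

    No space-group reduction: every index in P1 is returned.
    """
    hmax, kmax, lmax = (int(x / d_min) for x in (a, b, c))
    limit = 1.0 / d_min ** 2 + 1e-12  # small tolerance for reflections exactly at d_min
    out = []
    for h, k, l in itertools.product(
        range(-hmax, hmax + 1), range(-kmax, kmax + 1), range(-lmax, lmax + 1)
    ):
        if (h, k, l) != (0, 0, 0) and (h / a) ** 2 + (k / b) ** 2 + (l / c) ** 2 <= limit:
            out.append((h, k, l))
    return out


possible = possible_hkl_orthorhombic(10.0, 10.0, 10.0, d_min=5.0)
observed = {(1, 0, 0), (0, 1, 0)}
missing = [hkl for hkl in possible if hkl not in observed]
print(len(possible), len(missing))  # 32 30
```

The tolerance matters: without it, floating-point rounding can drop reflections that lie exactly on the d_min sphere.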
- expand_to_p1(include_friedel=True, remove_absences=True)[source]
Expand reflection data from asymmetric unit to P1.
Applies all symmetry operations from the current space group to generate all symmetry-equivalent reflections. Returns a NEW ReflectionData object with expanded reflections; does not modify self.
- Parameters:
- Returns:
New ReflectionData object with: - All symmetry-equivalent reflections - spacegroup set to P1 - Expanded tensor fields (F, F_sigma, I, I_sigma, phase, fom, rfree_flags) - Recalculated resolution values - Provenance tracking via source and last_op attributes
- Return type:
Notes
The expansion handles tensor fields as follows:
hkl: Symmetry operations applied, Friedel mates added, duplicates removed
F, F_sigma, I, I_sigma, fom, rfree_flags: Indexed from original
phase: Indexed + phase shift from translation component
resolution: Recalculated from expanded hkl + cell
bin_indices: Cleared (invalidated by expansion)
Examples
Expand to P1 for calculations:
# Load data in original space group
data = ReflectionData().load_mtz('data.mtz')
print(f"Original: {len(data)} reflections, {data.spacegroup}")
# Expand to P1
data_p1 = data.expand_to_p1()
print(f"Expanded: {len(data_p1)} reflections, {data_p1.spacegroup}")
# Without Friedel mates (for anomalous data)
data_p1_anom = data.expand_to_p1(include_friedel=False)
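Symmetry expansion rests on the relation F(hR) = F(h)·exp(-2πi h·t) for an operator x' = Rx + t: each equivalent keeps the amplitude and shifts the phase by -360°·(h·t). A plain-Python sketch, assuming phases in degrees, no anomalous signal for the Friedel mates, and the P 1 21 1 operators as example input; `expand_reflections_p1` is a hypothetical stand-alone function, not the method itself.

```python
def expand_reflections_p1(reflections, symops, include_friedel=True):
    """reflections: {(h, k, l): (F, phase_deg)}; symops: [(R, t), ...]."""
    expanded = {}
    for h, (F, phi) in reflections.items():
        for R, t in symops:
            # Row-vector convention: (h R)_j = sum_i h_i R_ij
            hR = tuple(sum(h[i] * R[i][j] for i in range(3)) for j in range(3))
            # Translation part shifts the phase by -360 * (h . t) degrees
            new_phi = (phi - 360.0 * sum(h[i] * t[i] for i in range(3))) % 360.0
            # Keep the first copy of each index so duplicates are dropped
            expanded.setdefault(hR, (F, new_phi))
            if include_friedel:
                # Friedel's law: |F(-h)| = |F(h)| and phi(-h) = -phi(h)
                expanded.setdefault(tuple(-x for x in hR), (F, (-new_phi) % 360.0))
    return expanded


# P 1 21 1: x, y, z and -x, y+1/2, -z
P21_OPS = [
    (((1, 0, 0), (0, 1, 0), (0, 0, 1)), (0.0, 0.0, 0.0)),
    (((-1, 0, 0), (0, 1, 0), (0, 0, -1)), (0.0, 0.5, 0.0)),
]
p1 = expand_reflections_p1({(1, 1, 2): (5.0, 40.0)}, P21_OPS)
print(len(p1), p1[(-1, 1, -2)])  # 4 (5.0, 220.0)
```

The 180° shift for (-1, 1, -2) is the familiar (-1)^k rule for a 21 screw axis along b.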
- reduce_to_spacegroup(spacegroup, include_friedel=True, aggregation='mean')[source]
Reduce P1 reflection data to asymmetric unit of a target spacegroup.
This is the inverse of expand_to_p1(). Takes reflection data in P1 and merges symmetry-equivalent reflections into single ASU reflections using the specified aggregation function.
- Parameters:
spacegroup (str, int, or gemmi.SpaceGroup) – Target space group specification.
include_friedel (bool, default True) – If True, also merge Friedel mates when reducing.
aggregation (str, default 'mean') – Aggregation function for merging equivalent reflections:
- ‘mean’: Average values (default, good for amplitudes)
- ‘sum’: Sum values
- ‘first’: Take first valid value (no averaging)
- Returns:
New ReflectionData with merged reflections in the target spacegroup.
- Return type:
ReflectionData
Notes
Field handling during reduction:
F, I: Aggregated using specified function
F_sigma, I_sigma: Propagated as sqrt(sum(sigma^2) / n) for ‘mean’
phase: Aggregated via complex averaging (handles phase wrapping)
fom: Weighted by amplitude during phase averaging
rfree_flags: OR operation (True if any equivalent is True)
Examples
Reduce symmetry-expanded data:
# Expand to P1 for calculations, then reduce back
data_p1 = data.expand_to_p1()
# ... modify data_p1.F ...
data_merged = data_p1.reduce_to_spacegroup('P21')
# Reduce with sum instead of mean
data_summed = data_p1.reduce_to_spacegroup('P21', aggregation='sum')
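A sketch of the per-group aggregation for aggregation='mean', assuming the equivalents have already been grouped under their ASU index with phases expressed for that representative. It follows the sigma rule stated above but averages unweighted unit phase vectors (the fom weighting is omitted); `merge_equivalents` is hypothetical.

```python
import cmath
import math


def merge_equivalents(groups):
    """Merge symmetry-equivalent observations.

    groups: {asu_hkl: [(F, sigma, phase_deg), ...]}, phases already expressed
    for the ASU representative.
    """
    merged = {}
    for hkl, obs in groups.items():
        n = len(obs)
        F_mean = sum(o[0] for o in obs) / n
        # Propagation rule quoted above for 'mean': sqrt(sum(sigma^2) / n)
        sigma = math.sqrt(sum(o[1] ** 2 for o in obs) / n)
        # Average unit phase vectors so 359 and 1 degrees merge to ~0, not 180
        z = sum(cmath.exp(1j * math.radians(o[2])) for o in obs) / n
        merged[hkl] = (F_mean, sigma, math.degrees(cmath.phase(z)) % 360.0)
    return merged


merged = merge_equivalents({(1, 0, 0): [(10.0, 1.0, 359.0), (12.0, 1.0, 1.0)]})
F_mean, sigma, phase = merged[(1, 0, 0)]
print(F_mean, sigma)  # 11.0 1.0
```

A naive arithmetic mean of 359° and 1° would give 180°, the opposite phase; averaging on the unit circle avoids that wrap-around error.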
- canonicalize(include_friedel=True)[source]
Return new ReflectionData with HKL in standard CCP4 ASU form.
Remaps all Miller indices to the canonical CCP4 asymmetric unit representative using gemmi.ReciprocalAsu, adjusts phases accordingly, and sorts reflections lexicographically by (h, k, l).
- Parameters:
include_friedel (bool, default True) – Whether Friedel mates are considered equivalent.
- Returns:
New object with canonicalized, sorted Miller indices.
- Return type:
ReflectionData
- get_scattering_vectors()[source]
Get scattering vectors (s-vectors) from hkl and cell.
- The s-vector for a reflection hkl is defined as:
s = B* @ hkl
where B* is the reciprocal basis matrix.
- Returns:
s_vectors – Reciprocal space vectors in Angstroms^-1, shape (N, 3).
- Return type:
torch.Tensor
- Raises:
ValueError – If hkl or cell is not available.
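The relation s = B* @ hkl can be computed from the cell parameters by building the direct basis vectors and taking a* = (b × c)/V and its cyclic permutations. This sketch assumes the common orientation with a along x and b in the xy-plane. torchref's B* may use another orientation, but |s| = 1/d does not depend on that choice. The helper names are hypothetical.

```python
import math


def _cross(u, v):
    return (u[1] * v[2] - u[2] * v[1], u[2] * v[0] - u[0] * v[2], u[0] * v[1] - u[1] * v[0])


def reciprocal_basis(a, b, c, alpha, beta, gamma):
    """Return Cartesian a*, b*, c* (1/Angstrom) for a cell in Angstroms and degrees.

    Assumed orthogonalization: a along x, b in the xy-plane.
    """
    al, be, ga = (math.radians(x) for x in (alpha, beta, gamma))
    va = (a, 0.0, 0.0)
    vb = (b * math.cos(ga), b * math.sin(ga), 0.0)
    cx = c * math.cos(be)
    cy = c * (math.cos(al) - math.cos(be) * math.cos(ga)) / math.sin(ga)
    vc = (cx, cy, math.sqrt(c * c - cx * cx - cy * cy))
    volume = sum(p * q for p, q in zip(va, _cross(vb, vc)))
    # a* = (b x c) / V, b* = (c x a) / V, c* = (a x b) / V
    return tuple(
        tuple(x / volume for x in _cross(u, v)) for u, v in ((vb, vc), (vc, va), (va, vb))
    )


def scattering_vector(hkl, basis):
    """s = h a* + k b* + l c*, i.e. B* @ hkl with a*, b*, c* as the columns of B*."""
    return tuple(sum(hkl[i] * basis[i][j] for i in range(3)) for j in range(3))


basis = reciprocal_basis(10.0, 10.0, 10.0, 90.0, 90.0, 90.0)
s = scattering_vector((1, 2, 2), basis)
d = 1.0 / math.sqrt(sum(x * x for x in s))  # |s| = 1/d
print(round(d, 4))  # 3.3333
```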
- get_radial_shells(n_shells=20, d_min=None, d_max=None)[source]
Create uniform radial shells in 1/d space for normalization.
This creates shells with uniform spacing in reciprocal space (1/d), different from get_bins() which creates equal-count bins.
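- Parameters:
n_shells (int, optional) – Number of shells. Default is 20.
d_min (float, optional) – High resolution limit in Angstroms. If None, uses dataset minimum.
d_max (float, optional) – Low resolution limit in Angstroms. If None, uses dataset maximum.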
- Returns:
shell_edges (torch.Tensor) – Shell boundaries in Angstroms^-1, shape (n_shells+1,).
shell_centers (torch.Tensor) – Shell centers in Angstroms^-1, shape (n_shells,).
shell_indices (torch.Tensor) – Shell index for each reflection, shape (N,). Values -1 for out-of-range.
- Return type:
tuple of (torch.Tensor, torch.Tensor, torch.Tensor)
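In contrast to equal-count bins, uniform shells split the interval [1/d_max, 1/d_min] into n_shells equal widths. A plain-Python sketch; `radial_shells` is hypothetical, and clamping a reflection at exactly 1/d_min into the last shell is an assumption.

```python
def radial_shells(d, n_shells=20, d_min=None, d_max=None):
    """Uniform shells in 1/d; shell index -1 for reflections outside [d_min, d_max]."""
    d_min = min(d) if d_min is None else d_min
    d_max = max(d) if d_max is None else d_max
    s_lo, s_hi = 1.0 / d_max, 1.0 / d_min
    width = (s_hi - s_lo) / n_shells
    edges = [s_lo + i * width for i in range(n_shells + 1)]
    centers = [0.5 * (edges[i] + edges[i + 1]) for i in range(n_shells)]
    indices = []
    for x in d:
        s = 1.0 / x
        if s < s_lo or s > s_hi:
            indices.append(-1)
        else:
            # Clamp so a reflection at exactly 1/d_min lands in the last shell
            indices.append(min(int((s - s_lo) / width), n_shells - 1))
    return edges, centers, indices


d = [1 / 0.15, 1 / 0.25, 1 / 0.35, 1 / 0.45, 20.0, 1.0]
edges, centers, idx = radial_shells(d, n_shells=4, d_min=2.0, d_max=10.0)
print(idx)  # [0, 1, 2, 3, -1, -1]
```

Because reflection density grows roughly as (1/d)² per unit of 1/d, uniform shells hold far more reflections at high resolution than at low resolution.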
- fit_anisotropy(n_shells=20, d_min=None, d_max=None, n_iterations=100, verbose=None)[source]
Fit anisotropy correction parameters to minimize CV within shells.
Optimizes U parameters so that corrected F² values have minimal coefficient of variation within each resolution shell.
- Parameters:
n_shells (int) – Number of resolution shells for variance calculation.
d_min (float, optional) – High resolution limit in Angstroms. If None, uses dataset minimum.
d_max (float, optional) – Low resolution limit in Angstroms. If None, uses dataset maximum.
n_iterations (int) – Number of optimization iterations.
verbose (bool, optional) – Print progress. If None, uses self.verbose.
- Returns:
U – Fitted anisotropy parameters [u11, u22, u33, u12, u13, u23], shape (6,). Also stored in self.U_aniso.
- Return type:
torch.Tensor
- Raises:
ValueError – If no amplitude data is available.
- setup_anisotropy(U_aniso=None)[source]
Set up anisotropy correction parameters.
- Parameters:
U_aniso (torch.Tensor, optional) – Anisotropic parameters [u11, u22, u33, u12, u13, u23], shape (6,). If None, the parameters are initialized as zeros.
- apply_anisotropy_correction(U_aniso=None)[source]
Apply anisotropy correction to F² values.
- Parameters:
U_aniso (torch.Tensor, optional) – Anisotropic parameters [u11, u22, u33, u12, u13, u23], shape (6,). If None, uses self.U_aniso (must have called fit_anisotropy first).
- Returns:
F_corrected (torch.Tensor) – Anisotropy-corrected F values, shape (N,).
sigma_F_corrected (torch.Tensor) – Uncertainties of corrected F values, shape (N,).
- Raises:
ValueError – If no U parameters available and none provided.
- Return type:
tuple of (torch.Tensor, torch.Tensor)
- compute_e_values(n_shells=20, d_min=None, d_max=None, apply_anisotropy=True, fit_anisotropy=True, verbose=None)[source]
Compute E-values with optional anisotropy correction.
E-values are normalized structure factors where <E²> = 1 within each resolution shell. Anisotropy correction can be applied first to account for directional variation in diffraction.
- Parameters:
n_shells (int) – Number of resolution shells for normalization.
d_min (float, optional) – High resolution limit in Angstroms. If None, uses dataset minimum.
d_max (float, optional) – Low resolution limit in Angstroms. If None, uses dataset maximum.
apply_anisotropy (bool) – If True, apply anisotropy correction before E-value calculation.
fit_anisotropy (bool) – If True and apply_anisotropy is True, fit anisotropy parameters. If False and apply_anisotropy is True, uses existing self.U_aniso.
verbose (bool, optional) – Print progress. If None, uses self.verbose.
- Returns:
E – E-values, shape (N,). Also stored in self.E. self.E_squared is also populated with E² values.
- Return type:
torch.Tensor
- Raises:
ValueError – If no amplitude data is available.
Examples
Compute E-values with automatic anisotropy correction:
data = ReflectionData().load_mtz('data.mtz')
E = data.compute_e_values(n_shells=30)
print(f"E-values: mean={E.mean():.3f}, std={E.std():.3f}")
Compute E-values without anisotropy correction:
E = data.compute_e_values(apply_anisotropy=False)
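The normalization ⟨E²⟩ = 1 per shell can be sketched as E = F / sqrt(⟨F²⟩_shell). This simplified version assumes shell indices like those from uniform radial shells, skips anisotropy correction, and omits the multiplicity (epsilon) factor often applied to reflections on symmetry elements; `e_values` is hypothetical.

```python
import math


def e_values(F, shell_indices):
    """E = F / sqrt(<F^2>_shell), so that <E^2> = 1 within each shell.

    Shell index -1 marks reflections outside every shell (E is None).
    """
    sums, counts = {}, {}
    for f, s in zip(F, shell_indices):
        if s >= 0:
            sums[s] = sums.get(s, 0.0) + f * f
            counts[s] = counts.get(s, 0) + 1
    return [
        f / math.sqrt(sums[s] / counts[s]) if s >= 0 else None
        for f, s in zip(F, shell_indices)
    ]


E = e_values([1.0, 3.0, 2.0, 2.0], [0, 0, 1, 1])
print(E[2], E[3])  # 1.0 1.0
```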
- get_corrected_data()[source]
Get scaled amplitude and uncertainty tensors.
Applies the current scale factor (from self.log_scale) to F and F_sigma.
- Returns:
Scaled F and F_sigma tensors. If log_scale is None, returns original.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- parameters()[source]
Get the learnable parameters for optimization.
- Returns:
Iterator over the torch Parameters to optimize; includes log_scale and U_aniso if set.
- Return type:
iter[Parameter]
- __init__(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, _centric=None, _n_bins=None, _FrenchWilson=None, source=None, dataset=None, last_op=None, reader=None)
- class torchref.io.DatasetCollection(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, _datasets=<factory>, _dataset_order=<factory>, _reference_dataset=None, _common_hkl=None, _cell=None, _spacegroup=None, _resolution=None, _scale_factors=<factory>)[source]
Bases:
CrystalDataset
Container for multiple related crystal datasets.
All datasets share a common HKL set for efficient computation. Datasets are aligned to the reference dataset (by default, the first one added). Reflections missing from the other datasets are masked out.
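The alignment step can be pictured with a minimal pure-Python sketch. It indexes a dataset's reflections by Miller index, reorders them onto the reference HKL list, and records a mask for the reflections the dataset lacks. The dictionary lookup, the NaN fill value, and the helper name align_to_reference are illustrative assumptions, not the collection's tensor implementation.

```python
import math
from typing import Dict, List, Tuple

HKL = Tuple[int, int, int]


def align_to_reference(ref_hkl: List[HKL], hkl: List[HKL],
                       F: List[float]) -> Tuple[List[float], List[bool]]:
    """Reorder F onto ref_hkl; missing reflections get NaN and mask False."""
    lookup: Dict[HKL, float] = dict(zip(hkl, F))
    aligned, mask = [], []
    for idx in ref_hkl:
        present = idx in lookup
        aligned.append(lookup[idx] if present else math.nan)
        mask.append(present)
    return aligned, mask


ref = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
deriv_hkl = [(0, 0, 1), (1, 0, 0)]  # (0, 1, 0) was not measured
aligned, mask = align_to_reference(ref, deriv_hkl, [7.0, 3.0])
print(aligned, mask)  # [3.0, nan, 7.0] [True, False, True]
```

Because the loop runs over the reference list, reflections that only a non-reference dataset contains are simply not carried over in this sketch.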
- hkl
Common HKL set for all datasets.
Examples
from torchref.io import DatasetCollection, ReflectionData
collection = DatasetCollection(device='cuda')
native = ReflectionData().load_mtz('native.mtz')
derivative = ReflectionData().load_mtz('derivative.mtz')
collection.add_dataset('native', native, set_as_reference=True)
collection.add_dataset('derivative', derivative)
for name, dataset in collection:
    print(f"{name}: {len(dataset)} reflections")
# Access by name
native_F = collection['native'].F
- add_dataset(name, dataset, set_as_reference=False)[source]
Add a dataset to the collection.
- Parameters:
name (str) – Identifier for this dataset.
dataset (ReflectionData) – The dataset to add.
set_as_reference (bool, optional) – If True, use this dataset’s HKL as the reference. Default is False, but the first dataset added automatically becomes the reference.
- Returns:
Self, for method chaining.
- Return type:
DatasetCollection
- Raises:
ValueError – If a dataset with the same name already exists.
Examples
collection = DatasetCollection()
collection.add_dataset('native', native_data, set_as_reference=True)
collection.add_dataset('derivative', derivative_data)
- property datasets: Dict[str, ReflectionData]
Access all datasets as a dictionary.
- __getitem__(name)[source]
Get dataset by name.
- __iter__()[source]
Iterate over (name, dataset) pairs in order of addition.
- Yields:
tuple of (str, ReflectionData) – Name and dataset for each dataset in collection.
- scale()[source]
Scale all datasets to a common reference scale.
Optimizes the scaling parameters of every non-reference dataset to minimize the mean squared error between its structure factors and those of the reference dataset. The optimization corrects for both overall scale differences and anisotropy. It uses the L-BFGS optimizer with a strong Wolfe line search to refine the scaling parameters over multiple optimization steps.
- Returns:
The collection instance, for method chaining.
- Return type:
DatasetCollection
- Raises:
ValueError – If no reference dataset has been set, or if the collection contains only the reference dataset. At least two datasets are required.
Notes
A reference dataset must be set before calling this method. The first dataset added becomes the reference unless another is added with set_as_reference=True. Only the non-reference datasets have their scaling parameters optimized.
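scale() refines overall scale and anisotropy iteratively with L-BFGS. For the isotropic part alone, minimizing sum((F_ref - k*F)^2) has a closed-form solution, k = sum(F_ref*F) / sum(F^2). The sketch below uses that closed form as a simplified stand-in, with no anisotropy and no optimizer, to show what the overall scale factor represents. overall_scale is a hypothetical helper.

```python
from typing import List, Optional


def overall_scale(F_ref: List[float], F: List[float],
                  mask: Optional[List[bool]] = None) -> float:
    """Least-squares k minimizing sum((F_ref - k * F)**2) over unmasked pairs."""
    if mask is None:
        mask = [True] * len(F)
    num = sum(r * f for r, f, m in zip(F_ref, F, mask) if m)
    den = sum(f * f for f, m in zip(F, mask) if m)
    if den == 0:
        raise ValueError("no valid reflections to scale")
    return num / den


F_ref = [100.0, 80.0, 60.0]
F_deriv = [50.0, 40.0, 30.0]  # same data at half the scale
k = overall_scale(F_ref, F_deriv)
print(k)  # 2.0
```

An iterative optimizer becomes necessary once an anisotropic term is added, because the model is then no longer linear in its parameters.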
- __init__(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, _datasets=<factory>, _dataset_order=<factory>, _reference_dataset=None, _common_hkl=None, _cell=None, _spacegroup=None, _resolution=None, _scale_factors=<factory>)
- class torchref.io.FcalcDataset(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, fcalc=None, fcalc_amp=None, fcalc_phase=None)[source]
Bases:
CrystalDataset
Dataset for storing calculated structure factors.
Provides a lightweight container for Fcalc values with:
- Cell and spacegroup information (using torchref.symmetry types)
- HKL indices and resolution
- Complex Fcalc with amplitude/phase decomposition
- MTZ export capability
This class inherits from CrystalDataset and overrides the spacegroup field to store torchref.symmetry.SpaceGroup instead of gemmi.SpaceGroup.
- Parameters:
hkl (torch.Tensor, optional) – Miller indices of shape (N, 3).
resolution (torch.Tensor, optional) – Resolution per reflection of shape (N,).
cell (Cell, optional) – Unit cell object.
spacegroup (SpaceGroup, optional) – Space group object (torchref.symmetry.SpaceGroup).
fcalc (torch.Tensor, optional) – Complex structure factors of shape (N,).
fcalc_amp (torch.Tensor, optional) – Amplitudes |Fcalc| of shape (N,).
fcalc_phase (torch.Tensor, optional) – Phases in radians of shape (N,).
device (torch.device) – Device for tensors.
Examples
Create from cell and resolution:
import torch
from torchref.io.datasets import FcalcDataset
dataset = FcalcDataset.from_cell_and_resolution(
    cell=[50.0, 60.0, 70.0, 90.0, 90.0, 90.0],
    spacegroup='P212121',
    d_min=2.0,
)
# Set Fcalc values (complex tensor)
fcalc = torch.randn(len(dataset), dtype=torch.complex64)
dataset.set_fcalc(fcalc)
# Write to MTZ
dataset.write_mtz('output.mtz')
- spacegroup: SpaceGroup | None = None
- static from_cell_and_resolution(cell, spacegroup, d_min=2.0, d_max=None, device=device(type='cpu'), dtype=torch.float32)[source]
Create FcalcDataset with HKL generated to given resolution.
- Parameters:
cell (torch.Tensor, list, or Cell) – Unit cell [a, b, c, alpha, beta, gamma] or Cell object.
spacegroup (SpaceGroupLike) – Space group (str, int, gemmi.SpaceGroup, or torchref.symmetry.SpaceGroup).
d_min (float, optional) – High resolution limit in Angstroms. Default is 2.0.
d_max (float, optional) – Low resolution limit in Angstroms. If provided, reflections with d-spacing > d_max are removed.
device (torch.device) – Target device.
dtype (torch.dtype) – Float dtype for tensors.
- Returns:
New dataset with HKL and resolution populated.
- Return type:
FcalcDataset
Examples
from torchref.io import FcalcDataset
from torchref.symmetry import Cell, SpaceGroup
cell = Cell([50.0, 60.0, 70.0, 90.0, 90.0, 90.0])
sg = SpaceGroup('P212121')
dataset = FcalcDataset.from_cell_and_resolution(
    cell=cell,
    spacegroup=sg,
    d_min=2.0,
)
print(f"Generated {len(dataset)} reflections")
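For an orthorhombic cell, reflection (h, k, l) has 1/d^2 = h^2/a^2 + k^2/b^2 + l^2/c^2. Generating HKL to d_min therefore amounts to enumerating indices inside a sphere in reciprocal space. The sketch below covers orthorhombic cells only. It does not reduce to the asymmetric unit or remove systematic absences, which a full implementation would handle using the space group. hkl_to_resolution is a hypothetical helper, and monoclinic or triclinic cells would need the general reciprocal metric tensor.

```python
import math
from typing import List, Optional, Tuple

Reflection = Tuple[Tuple[int, int, int], float]


def hkl_to_resolution(a: float, b: float, c: float, d_min: float,
                      d_max: Optional[float] = None) -> List[Reflection]:
    """Enumerate (hkl, d) for an orthorhombic cell with d >= d_min (and d <= d_max).

    Illustrative only: covers all octants, skips (0, 0, 0), and applies no
    symmetry reduction or systematic-absence filtering.
    """
    # |h| / a <= 1 / d_min bounds each index independently.
    h_max, k_max, l_max = (math.floor(x / d_min) for x in (a, b, c))
    out = []
    for h in range(-h_max, h_max + 1):
        for k in range(-k_max, k_max + 1):
            for l in range(-l_max, l_max + 1):
                if (h, k, l) == (0, 0, 0):
                    continue
                inv_d2 = (h / a) ** 2 + (k / b) ** 2 + (l / c) ** 2
                d = 1.0 / math.sqrt(inv_d2)
                if d >= d_min and (d_max is None or d <= d_max):
                    out.append(((h, k, l), d))
    return out


refl = hkl_to_resolution(50.0, 60.0, 70.0, d_min=10.0)
print(len(refl), min(d for _, d in refl))
```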
- set_fcalc(fcalc)[source]
Assign complex Fcalc values.
Automatically computes amplitude and phase from complex values.
- Parameters:
fcalc (torch.Tensor) – Complex structure factors with shape (N,).
- Raises:
ValueError – If fcalc length doesn’t match HKL length.
Examples
import torch
# Create complex Fcalc values
fcalc = torch.randn(len(dataset), dtype=torch.complex64)
dataset.set_fcalc(fcalc)
print(dataset.fcalc_amp[:5])    # Amplitudes
print(dataset.fcalc_phase[:5])  # Phases in radians
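set_fcalc derives amplitudes and phases from the complex values and rejects input whose length does not match the HKL list. The sketch below reproduces that bookkeeping with Python's built-in complex numbers and cmath in place of torch tensors. split_fcalc is a hypothetical helper, not the method itself.

```python
import cmath
import math
from typing import List, Tuple


def split_fcalc(fcalc: List[complex], n_hkl: int) -> Tuple[List[float], List[float]]:
    """Return (|F|, phase in radians) after checking the length against HKL."""
    if len(fcalc) != n_hkl:
        raise ValueError(f"fcalc has {len(fcalc)} values but HKL has {n_hkl}")
    amp = [abs(f) for f in fcalc]
    phase = [cmath.phase(f) for f in fcalc]  # values in [-pi, pi]
    return amp, phase


amp, phase = split_fcalc([3 + 4j, -1 + 0j], n_hkl=2)
print(amp, [round(math.degrees(p), 1) for p in phase])  # [5.0, 1.0] [53.1, 180.0]
```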
- write_mtz(filepath)[source]
Write Fcalc to MTZ file.
- Parameters:
filepath (str) – Output MTZ filename.
- Raises:
ValueError – If no Fcalc values have been set.
Examples
dataset.set_fcalc(fcalc_values)
dataset.write_mtz('calculated.mtz')
- write_mtz_as_fobs(filepath, sigma_frac=0.05, f_column='F-obs', sigf_column='SIGF-obs', phase_column='PHIF-model')[source]
Write Fcalc to MTZ as if it were observed data (F-obs columns).
Useful for creating simulated “experimental” MTZ files that can be read back by ReflectionData.load_mtz() as observed amplitudes.
- Parameters:
filepath (str) – Output MTZ filename.
sigma_frac (float, optional) – Sigma as a fraction of |F|. Default is 0.05 (5%).
f_column (str, optional) – Column name for amplitudes. Default is ‘F-obs’.
sigf_column (str, optional) – Column name for sigma. Default is ‘SIGF-obs’.
phase_column (str, optional) – Column name for model phases. Default is ‘PHIF-model’.
Examples
dataset.set_fcalc(fcalc_values)
dataset.write_mtz_as_fobs('simulated_obs.mtz', sigma_frac=0.05)
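With sigma_frac=0.05, each reflection's sigma is 5% of its amplitude, so every simulated observation has F/sigma = 20. The sketch below builds the amplitude and sigma columns under the assumption that sigma is exactly sigma_frac * |F|, with no floor and no added noise. The phase column is omitted. fobs_columns is a hypothetical helper that returns plain lists keyed by column label.

```python
from typing import Dict, List


def fobs_columns(fcalc: List[complex], sigma_frac: float = 0.05,
                 f_column: str = "F-obs",
                 sigf_column: str = "SIGF-obs") -> Dict[str, List[float]]:
    """Build amplitude and sigma columns as described in the docstring."""
    amps = [abs(f) for f in fcalc]
    return {f_column: amps, sigf_column: [sigma_frac * a for a in amps]}


cols = fobs_columns([3 + 4j, 10 + 0j])
print(cols)  # {'F-obs': [5.0, 10.0], 'SIGF-obs': [0.25, 0.5]}
```

A constant fractional sigma weights all reflections equally in relative terms, which suits simulations but does not resemble real measurement errors.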
- property spacegroup_hm: str | None
Get space group Hermann-Mauguin name with spaces (e.g., ‘P 21 21 21’).
- __init__(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, fcalc=None, fcalc_amp=None, fcalc_phase=None)
- class torchref.io.MTZReader(verbose=0, column_names=None)[source]
Bases:
object
Reader for MTZ files containing crystallographic structure factor data.
This class reads MTZ files using reciprocalspaceship and extracts:
- Miller indices (h, k, l)
- Structure factor amplitudes or intensities
- Associated uncertainties (sigma values)
- R-free test set flags
- cell
Unit cell parameters [a, b, c, alpha, beta, gamma].
- Type:
np.ndarray
- spacegroup
Space group object.
- Type:
gemmi.SpaceGroup
Examples
from torchref.io import mtz
reader = mtz.read('data.mtz', verbose=1)
data_dict, cell, spacegroup = reader()
print(f"Found {len(data_dict['HKL'])} reflections in {spacegroup.short_name()}")
- AMPLITUDE_PRIORITY = ['F-obs', 'FOBS', 'FP', 'F', 'F-obs-filtered', 'FOBS-filtered', 'F(+)', 'FPLUS', 'FMEAN', 'F-pk', 'F_pk', 'FO', 'FODD', 'F-model', 'FC', 'FCALC']
- INTENSITY_PRIORITY = ['I-obs', 'IOBS', 'I', 'IMEAN', 'I-obs-filtered', 'IOBS-filtered', 'I(+)', 'IPLUS', 'IP', 'I-pk', 'I_pk', 'IHLI', 'I_full', 'IOBS_full', 'IO']
- RFREE_FLAG_NAMES = ['R-free-flags', 'RFREE', 'FreeR_flag', 'FREE', 'R-free', 'Rfree', 'FREER', 'FREE_FLAG', 'test', 'TEST', 'free', 'Free']
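The priority lists suggest that the reader picks the first listed label that exists in the file. Below is a minimal sketch of that selection. It assumes exact, case-sensitive matching, and the real reader may also consider column types or fall back to intensities. pick_column is a hypothetical helper, and the list is a truncated copy of AMPLITUDE_PRIORITY.

```python
from typing import Iterable, List, Optional

# Truncated copy of MTZReader.AMPLITUDE_PRIORITY, for illustration.
AMPLITUDE_PRIORITY = ["F-obs", "FOBS", "FP", "F", "F-obs-filtered"]


def pick_column(available: Iterable[str], priority: List[str]) -> Optional[str]:
    """Return the highest-priority label present in `available`, or None."""
    present = set(available)
    for label in priority:
        if label in present:
            return label
    return None


labels = ["H", "K", "L", "FP", "SIGFP", "F", "FreeR_flag"]
print(pick_column(labels, AMPLITUDE_PRIORITY))  # FP
```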
- class torchref.io.PDBReader(verbose=0)[source]
Bases:
object
Reader for PDB files containing atomic coordinate data.
This class reads PDB files and extracts atomic coordinates, properties, and crystallographic metadata.
- dataframe
DataFrame containing atomic data.
- Type:
pd.DataFrame
Examples
from torchref.io import pdb
reader = pdb.read('structure.pdb', verbose=1)
df, cell, spacegroup = reader()
print(f"Loaded {len(df)} atoms")
- __init__(verbose=0)[source]
Initialize PDB reader.
- Parameters:
verbose (int, optional) – Verbosity level (0=silent, 1=normal, 2=debug). Default is 0.
- class torchref.io.CIFReader(filepath=None, data_block=None, parse_all_blocks=False)[source]
Bases:
object
A dictionary-like reader for CIF/mmCIF files.
Loops are stored as pandas DataFrames. Other data is stored in a hierarchical dictionary structure.
- Parameters:
filepath (str, optional) – Path to CIF file to load immediately.
data_block (str, optional) – Specific data block name to read (e.g., ‘r1vlmsf’). If None and parse_all_blocks=False, reads the first data block. If None and parse_all_blocks=True, reads all data blocks.
parse_all_blocks (bool, default False) – If True, parse all data blocks and merge them into a single dictionary (useful for restraint files). If False, parse only the specified block or the first block.
- filepath
Path to the loaded CIF file.
- Type:
Path or None
- classmethod from_string(content, **kwargs)[source]
Create CIFReader from string content instead of a file.
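CIFReader stores loops as pandas DataFrames. To show how a loop_ block maps onto a table, here is a deliberately small parser for a single loop with unquoted, whitespace-separated values. This is not the library's parser: it ignores quoting, multi-line text fields, and multiple data blocks, and it returns a plain dict of lists instead of a DataFrame to stay dependency-free. parse_simple_loop is a hypothetical name.

```python
from typing import Dict, List


def parse_simple_loop(text: str) -> Dict[str, List[str]]:
    """Parse one CIF loop_ with unquoted tokens into {tag: [values...]}."""
    tags: List[str] = []
    values: List[str] = []
    in_loop = False
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("#") or line.startswith("data_"):
            continue
        if line == "loop_":
            in_loop = True
            continue
        if in_loop and line.startswith("_") and not values:
            tags.append(line.split()[0])  # tag header lines come before values
        elif in_loop:
            values.extend(line.split())
    if tags and len(values) % len(tags) != 0:
        raise ValueError("value count is not a multiple of the tag count")
    # Values are stored row-major, so column i is every len(tags)-th value.
    return {tag: values[i::len(tags)] for i, tag in enumerate(tags)}


cif = """data_demo
loop_
_refln.index_h
_refln.index_k
_refln.index_l
_refln.F_meas_au
1 0 0 120.5
0 1 0 98.2
"""
table = parse_simple_loop(cif)
print(table["_refln.F_meas_au"])  # ['120.5', '98.2']
```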
- class torchref.io.ReflectionCIFReader(filepath, verbose=0, data_block=None)[source]
Bases:
object
Reader for structure factor CIF files (e.g., *-sf.cif from PDB).
Handles extraction of:
- Miller indices (h, k, l)
- Structure factor amplitudes (F) and uncertainties (σF)
- Intensities (I) and uncertainties (σI)
- Phases and figures of merit
- R-free flags
- Unit cell and space group metadata
- Compatible with legacy MTZ reader interface:
reader = ReflectionCIFReader('7JI4-sf.cif').read()
data_dict, cell, spacegroup = reader()
- Example:
reader = ReflectionCIFReader('7JI4-sf.cif')
refln_data = reader.get_reflection_data()
h, k, l = refln_data['h'], refln_data['k'], refln_data['l']
F_obs = refln_data['F_obs']
- __init__(filepath, verbose=0, data_block=None)[source]
Initialize and load structure factor CIF file.
- read(filepath=None)[source]
Read a CIF file (for compatibility with legacy interface).
- Parameters:
filepath (str, optional) – Path to CIF file. Uses initialization path if not provided.
- Returns:
Self for method chaining.
- __call__()[source]
Get data in legacy MTZ-compatible format.
- Returns:
data (dict) – Dictionary with extracted data arrays:
- ‘h’, ‘k’, ‘l’: Miller indices
- ‘F’, ‘SIGF’: Amplitudes and sigmas (if available)
- ‘I’, ‘SIGI’: Intensities and sigmas (if available)
- ‘R-free-flags’: R-free test set flags (if available)
cell (numpy.ndarray) – Cell parameters [a, b, c, alpha, beta, gamma].
spacegroup (gemmi.SpaceGroup) – Space group object.
- Return type:
tuple
- get_reflection_data()[source]
Extract reflection data with standardized column names.
- Returns:
DataFrame with columns:
- h, k, l: Miller indices
- F_obs, sigma_F_obs: Observed amplitudes (if available)
- I_obs, sigma_I_obs: Observed intensities (if available)
- phase, fom: Phase and figure of merit (if available)
- free_flag: R-free flags (if available)
- Return type:
pandas.DataFrame
Notes
Missing columns will be filled with NaN or appropriate defaults.
- get_miller_indices()[source]
Get Miller indices as Nx3 array.
- Returns:
Array of shape (N, 3) with h, k, l indices
- get_amplitudes()[source]
Get structure factor amplitudes and uncertainties.
- Returns:
Dict with keys ‘F’ and ‘sigma_F’, or None if not available
- get_intensities()[source]
Get intensities and uncertainties.
- Returns:
Dict with keys ‘I’ and ‘sigma_I’, or None if not available
- class torchref.io.ModelCIFReader(filepath, verbose=0)[source]
Bases:
object
Reader for model/structure CIF files (e.g., *.cif from PDB).
Handles extraction of:
- Atomic coordinates and properties
- Alternative conformations
- Anisotropic displacement parameters
- Unit cell and space group
- Compatible with legacy PDB reader interface:
reader = ModelCIFReader('3E98.cif').read()
dataframe, cell, spacegroup = reader()
- Example:
reader = ModelCIFReader('3E98.cif')
atom_df = reader.get_atom_data()
cell = reader.get_cell_parameters()
- read(filepath=None)[source]
Read a CIF file (for compatibility with legacy interface).
- Parameters:
filepath (str, optional) – Path to CIF file. Uses initialization path if not provided.
- Returns:
Self for method chaining.
- Return type:
ModelCIFReader
- __call__()[source]
Get data in legacy PDB-compatible format.
- Returns:
dataframe (pandas.DataFrame) – Atom data with columns: ATOM, serial, name, altloc, resname, chainid, resseq, icode, x, y, z, occupancy, tempfactor, element, charge, anisou_flag, u11, u22, u33, u12, u13, u23.
cell (list) – Cell parameters [a, b, c, alpha, beta, gamma].
spacegroup (gemmi.SpaceGroup) – Space group object.
- Return type:
tuple
- get_atom_data()[source]
Extract atomic coordinate data in PDB-compatible format.
- Returns:
DataFrame with columns matching PDB format:
- ATOM, serial, name, altloc, resname, chainid, resseq, icode
- x, y, z, occupancy, tempfactor
- element, charge
- anisou_flag, u11, u22, u33, u12, u13, u23
- Return type:
pandas.DataFrame
- get_atom_data_by_model()[source]
Split atom data by pdbx_PDB_model_num.
For single-model files, returns {1: dataframe}. For multi-model files, returns one DataFrame per model number.
- Returns:
Mapping of model number to atom DataFrame.
- Return type:
dict of int -> pandas.DataFrame
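The splitting rule can be sketched without pandas: group atom records by their pdbx_PDB_model_num value, and fall back to model 1 when the value is absent. Representing atoms as dicts and treating a missing column as a single-model file are assumptions made for this illustration. split_by_model is a hypothetical helper.

```python
from collections import defaultdict
from typing import Dict, List


def split_by_model(atoms: List[dict]) -> Dict[int, List[dict]]:
    """Group atom records by 'pdbx_PDB_model_num' (default model 1)."""
    groups: Dict[int, List[dict]] = defaultdict(list)
    for atom in atoms:
        groups[int(atom.get("pdbx_PDB_model_num", 1))].append(atom)
    return dict(groups)


atoms = [
    {"name": "CA", "pdbx_PDB_model_num": "1"},
    {"name": "CA", "pdbx_PDB_model_num": "2"},
    {"name": "CB", "pdbx_PDB_model_num": "2"},
]
models = split_by_model(atoms)
print({k: len(v) for k, v in models.items()})  # {1: 1, 2: 2}
```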
- get_space_group()[source]
Extract space group name.
- Returns:
Space group name string. Returns “P 1” if not found.
- Return type:
str
- get_coordinates()[source]
Extract atomic coordinates as numpy array.
- Returns:
Nx3 array of [x, y, z] coordinates, or None if not available.
- Return type:
numpy.ndarray or None
- class torchref.io.RestraintCIFReader(filepath)[source]
Bases:
object
Reader for chemical restraint dictionary CIF files (e.g., from the monomer library).
Handles extraction of:
- Bond restraints (ideal lengths and ESDs)
- Angle restraints
- Torsion/dihedral restraints
- Planarity restraints
- Chirality definitions
Validates that the file contains proper restraint parameters (not just structure definitions).
- Example:
reader = RestraintCIFReader('external_monomer_library/a/ALA.cif')
comp_data = reader.get_all_restraints()
bond_df = comp_data['ALA']['bonds']
- __init__(filepath)[source]
Initialize and load restraint CIF file.
- Parameters:
filepath (str) – Path to restraint dictionary CIF file.
- get_all_restraints()[source]
Extract all restraint data for all compounds with standardized column names.
- Returns:
Dictionary mapping compound ID to dict of restraint types:
{
    'ALA': {
        'bonds': DataFrame(atom1, atom2, value, sigma),
        'angles': DataFrame(atom1, atom2, atom3, value, sigma),
        'torsions': DataFrame(atom1, atom2, atom3, atom4, value, sigma, periodicity),
        'planes': DataFrame(atom, plane_id),
        'chirals': DataFrame(atom_centre, atom1, atom2, atom3, volume_sign)
    },
    ...
}
- Return type:
dict
- get_compound_restraints(comp_id)[source]
Extract restraints for a specific compound with standardized column names.
- Parameters:
comp_id (str) – Compound identifier (e.g., ‘ALA’).
- Returns:
Dictionary of restraint DataFrames with standardized columns:
{
    'bonds': DataFrame(atom1, atom2, value, sigma),
    'angles': DataFrame(atom1, atom2, atom3, value, sigma),
    'torsions': DataFrame(atom1, atom2, atom3, atom4, value, sigma, periodicity),
    'planes': DataFrame(atom, plane_id),
    'chirals': DataFrame(atom_centre, atom1, atom2, atom3, volume_sign),
    'atoms': DataFrame(atom_id, type_symbol, charge, etc.)
}
- Return type:
dict
- class torchref.io.DataRouter(filepath, verbose=1)[source]
Bases:
object
Automatic file type detection and reader selection.
This class examines a file and automatically selects the appropriate reader based on file extension and content.
- Parameters:
- filepath
Path to the file to read.
- Type:
Path
- data_type
Type of data detected (‘reflections’, ‘structure’, ‘restraints’, ‘ihm_ensemble’, or None).
- Type:
str or None
Examples
from torchref.io import DataRouter
router = DataRouter("structure.cif")
reader = router.get_reader()
print(router.data_type)  # 'structure'
- MTZ_EXTENSIONS = {'.mtz'}
- PDB_EXTENSIONS = {'.ent', '.pdb'}
- CIF_EXTENSIONS = {'.cif', '.mmcif'}
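Routing can be approximated in two steps. The extension sets above select the file format, and for CIF files the mmCIF categories present suggest the data type. The sketch below assumes simple substring checks for the _refln., _atom_site., and _chem_comp_bond. categories, and the format labels 'mtz', 'pdb', and 'cif' are hypothetical. The real DataRouter's content inspection, including IHM ensemble detection, may differ, and it raises DataRouterError where this sketch raises ValueError.

```python
from pathlib import Path
from typing import Optional, Tuple

MTZ_EXTENSIONS = {".mtz"}
PDB_EXTENSIONS = {".ent", ".pdb"}
CIF_EXTENSIONS = {".cif", ".mmcif"}


def guess_route(filepath: str, content: str = "") -> Tuple[str, Optional[str]]:
    """Return (file_format, data_type) from the suffix and, for CIF, the content."""
    suffix = Path(filepath).suffix.lower()
    if suffix in MTZ_EXTENSIONS:
        return "mtz", "reflections"
    if suffix in PDB_EXTENSIONS:
        return "pdb", "structure"
    if suffix in CIF_EXTENSIONS:
        if "_refln." in content:
            return "cif", "reflections"
        if "_atom_site." in content:
            return "cif", "structure"
        if "_chem_comp_bond." in content:
            return "cif", "restraints"
        return "cif", None
    # torchref raises DataRouterError here; ValueError keeps the sketch standalone.
    raise ValueError(f"unsupported file type: {suffix!r}")


print(guess_route("data.mtz"))                              # ('mtz', 'reflections')
print(guess_route("7JI4-sf.cif", "loop_\n_refln.index_h"))  # ('cif', 'reflections')
```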
- get_reader()[source]
Get the appropriate reader for this file.
- Returns:
Reader instance (ReflectionCIFReader, ModelCIFReader, RestraintCIFReader, MTZ, or PDB depending on file type).
- Return type:
- Raises:
DataRouterError – If file type is not supported or cannot be determined.
- get_data()[source]
Get the data from the file using the appropriate reader.
This is a convenience method that calls get_reader() and then invokes the reader to get the data.
- Returns:
For reflections: (data_dict, cell, spacegroup)
For structure: (dataframe, cell, spacegroup)
For restraints: Restraint data (format depends on reader)
- Return type:
- classmethod route(filepath, verbose=1)[source]
Factory method to quickly route a file to the appropriate reader.
- Parameters:
- Returns:
Tuple of (reader, data_type) where:
- reader: The appropriate reader instance
- data_type: String indicating the type (‘reflections’, ‘structure’, ‘restraints’)
- Return type:
tuple
Examples
from torchref.io import DataRouter
reader, data_type = DataRouter.route("7JI4-sf.cif")
if data_type == 'reflections':
    data_dict, cell, spacegroup = reader()
- exception torchref.io.DataRouterError[source]
Bases:
Exception
Exception raised when file type cannot be determined or is unsupported.
- class torchref.io.IHMEnsembleMapping(states=<factory>, model_groups=<factory>, cell=None, spacegroup=None, atom_data_per_state=None)[source]
Bases:
object
Complete mapping between IHM mmCIF categories and torchref structures.
This is the central interchange object: both IHMReader and IHMWriter operate through it. It can also be constructed manually for programmatic workflows (e.g., building an IHM file from a KineticRefinement result without reading one first).
- Parameters:
states (List[IHMStateInfo]) – Structural states (one per base model in ModelCollection).
model_groups (List[IHMModelGroupInfo]) – Model groups / timepoints (one per timepoint in ModelCollection).
cell (list of float, optional) – Unit cell parameters [a, b, c, alpha, beta, gamma].
spacegroup (str, optional) – Space group name (Hermann-Mauguin notation).
atom_data_per_state (dict, optional) – Mapping of state_id -> pandas DataFrame with atom data. Populated by IHMReader.read_atom_data().
- states: List[IHMStateInfo]
- model_groups: List[IHMModelGroupInfo]
- get_fractions_for_group(group_name)[source]
Return population fractions for a model group, ordered by state_id.
- identify_dark_group()[source]
Heuristic: identify the reference / dark group.
Returns the name of the first model group where a single state has population fraction >= 0.95, or None if no such group exists.
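The heuristic can be sketched over a plain dict that maps group names to their state_fractions (the state_id -> fraction mapping described for IHMModelGroupInfo). An insertion-ordered dict stands in for the model_groups list. find_dark_group is a hypothetical helper.

```python
from typing import Dict, Optional


def find_dark_group(groups: Dict[str, Dict[int, float]],
                    threshold: float = 0.95) -> Optional[str]:
    """Return the first group in which one state has fraction >= threshold."""
    for name, fractions in groups.items():
        if any(f >= threshold for f in fractions.values()):
            return name
    return None


groups = {
    "1ps": {1: 0.70, 2: 0.30},
    "dark": {1: 1.00, 2: 0.00},
    "5ps": {1: 0.50, 2: 0.50},
}
print(find_dark_group(groups))  # dark
```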
- get_state_by_id(state_id)[source]
Look up a state by its ID.
- Raises:
KeyError – If no state with the given ID exists.
- get_group_by_name(name)[source]
Look up a model group by name.
- Raises:
KeyError – If no group with the given name exists.
- validate()[source]
Check internal consistency.
- Raises:
ValueError – If states are empty, fractions don’t reference valid states, or fractions don’t sum to ~1.0 for any group.
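The consistency checks listed under Raises can be sketched as follows. The 1e-3 tolerance on the fraction sum is an assumption, because the docstring only says the fractions must sum to about 1.0. check_mapping is a hypothetical helper that works on state IDs and plain fraction dicts.

```python
from typing import Dict, List


def check_mapping(state_ids: List[int],
                  groups: Dict[str, Dict[int, float]], tol: float = 1e-3) -> None:
    """Raise ValueError on the conditions that validate() documents."""
    if not state_ids:
        raise ValueError("no states defined")
    known = set(state_ids)
    for name, fractions in groups.items():
        unknown = set(fractions) - known
        if unknown:
            raise ValueError(f"group {name!r} references unknown states {sorted(unknown)}")
        total = sum(fractions.values())
        if abs(total - 1.0) > tol:
            raise ValueError(f"group {name!r} fractions sum to {total:.3f}, not 1.0")


check_mapping([1, 2], {"dark": {1: 1.0}, "1ps": {1: 0.7, 2: 0.3}})  # passes silently
```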
- __init__(states=<factory>, model_groups=<factory>, cell=None, spacegroup=None, atom_data_per_state=None)
- class torchref.io.IHMStateInfo(state_id, name, details='', model_num=1)[source]
Bases:
object
Metadata for a single structural state (e.g., ground state, intermediate).
- Parameters:
state_id (int) – Unique identifier matching _ihm_multi_state_modeling.state_id.
name (str) – Human-readable name (e.g., "ground_state", "intermediate_1").
details (str) – Free-text description of this state.
model_num (int) – pdbx_PDB_model_num in the _atom_site loop that corresponds to this state’s coordinates.
- __init__(state_id, name, details='', model_num=1)
- class torchref.io.IHMModelGroupInfo(group_id, name, state_fractions=<factory>, time_delay=None, time_delay_units='s')[source]
Bases:
object
Metadata for a model group (experimental condition / timepoint).
- Parameters:
group_id (int) – Unique identifier matching _ihm_model_group.id.
name (str) – Human-readable name (e.g., "dark", "1ps", "5ps").
state_fractions (Dict[int, float]) – Mapping of state_id -> population fraction for this group. Fractions should sum to 1.0.
time_delay (float, optional) – Time delay in time_delay_units (for time-resolved experiments).
time_delay_units (str) – Units for time_delay. Default "s" (seconds).
- __init__(group_id, name, state_fractions=<factory>, time_delay=None, time_delay_units='s')
- class torchref.io.RefinementMetadata(program='TORCHREF', program_version='', refinement_method='', resolution_high=None, resolution_low=None, n_reflections_work=None, n_reflections_test=None, n_reflections_all=None, percent_free=None, r_work=None, r_free=None, b_mean_overall=None, b_min=None, b_max=None, rmsd_bond_lengths=None, rmsd_bond_angles=None, n_atoms_total=None, n_atoms_protein=None, n_atoms_solvent=None, solvent_model_ksol=None, solvent_model_bsol=None, cell=None, spacegroup=None, title='', authors=<factory>, passthrough_pdb_remarks=<factory>, passthrough_cif_categories=<factory>, custom_remarks=<factory>)[source]
Bases:
object
Unified metadata for PDB headers and mmCIF categories.
Fields map to both PDB REMARK 3 lines and PDBx/mmCIF _refine category items. Only populated (non-None) fields are rendered.
- classmethod from_refinement(refinement)[source]
Extract metadata from a completed Refinement object.
Reuses existing statistics from collect_metrics(), get_rfactor(), and reflection data attributes. Silently skips any unavailable statistics.
- Parameters:
refinement (torchref.refinement.Refinement) – A refinement object (after refinement is complete).
- classmethod from_pdb_file(filepath)[source]
Extract header metadata from an existing PDB file.
Captures TITLE, AUTHOR, and REMARK records for pass-through.
- classmethod from_cif_file(filepath)[source]
Extract refinement metadata from an existing mmCIF file.
Captures _struct.title, _audit_author.name, and _refine category items for pass-through.
- merge(other)[source]
Merge other with self and return the result as a new instance. Non-None values in other take precedence.
Pass-through containers are combined (not replaced).
- Parameters:
other (RefinementMetadata) – Metadata to merge in (takes precedence for non-None fields).
- Returns:
A new merged instance.
- Return type:
RefinementMetadata
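The merge rule can be sketched with a small stand-in dataclass: non-None values from other win, pass-through containers are combined, and a new instance is returned. MiniMetadata is hypothetical and carries only a few RefinementMetadata-like fields. Combining containers by list concatenation is an assumption, because the docstring does not say which container type is used.

```python
from dataclasses import dataclass, field, fields, replace
from typing import List, Optional


@dataclass
class MiniMetadata:
    """Hypothetical stand-in with a few RefinementMetadata-like fields."""
    r_work: Optional[float] = None
    r_free: Optional[float] = None
    passthrough_pdb_remarks: List[str] = field(default_factory=list)

    def merge(self, other: "MiniMetadata") -> "MiniMetadata":
        updates = {}
        for f in fields(self):
            mine, theirs = getattr(self, f.name), getattr(other, f.name)
            if isinstance(mine, list):
                updates[f.name] = mine + theirs  # combine, do not replace
            elif theirs is not None:
                updates[f.name] = theirs  # other takes precedence
        return replace(self, **updates)  # new instance; self is unchanged


base = MiniMetadata(r_work=0.21, r_free=0.25, passthrough_pdb_remarks=["REMARK 2"])
new = MiniMetadata(r_free=0.24, passthrough_pdb_remarks=["REMARK 200"])
merged = base.merge(new)
print(merged.r_work, merged.r_free, merged.passthrough_pdb_remarks)
# 0.21 0.24 ['REMARK 2', 'REMARK 200']
```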
- render_pdb_header()[source]
Render metadata as PDB header records (REMARK 3, TITLE, AUTHOR).
- Returns:
Multi-line string ready to insert into a PDB file.
- Return type:
str
- render_cif_categories()[source]
Render metadata as mmCIF category dictionaries.
Returns a dict of dicts keyed by mmCIF category, with item names as keys and string values. Uses official PDBx/mmCIF field names.
- Returns:
Nested dictionary {category: {field: value}}.
- Return type:
dict
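For single-row categories, turning the nested {category: {field: value}} dictionary into mmCIF text is straightforward. The sketch below assumes category keys may or may not carry a leading underscore, quotes values that contain whitespace, writes "?" (the mmCIF marker for an unknown value) for empty strings, and does not escape embedded quotes. format_cif_items is a hypothetical helper. The item names _refine.ls_R_factor_R_work and _refine.ls_R_factor_R_free are standard PDBx/mmCIF, and the values are illustrative only.

```python
from typing import Dict


def format_cif_items(categories: Dict[str, Dict[str, str]]) -> str:
    """Render {category: {field: value}} as single-row mmCIF key-value lines."""
    lines = []
    for category, items in categories.items():
        prefix = category if category.startswith("_") else "_" + category
        for name, value in items.items():
            text = str(value)
            if not text:
                text = "?"  # mmCIF marker for an unknown value
            elif any(ch.isspace() for ch in text):
                text = f"'{text}'"  # quote values containing whitespace
            lines.append(f"{prefix}.{name} {text}")
        lines.append("#")
    return "\n".join(lines)


out = format_cif_items({
    "refine": {"ls_R_factor_R_work": "0.2100", "ls_R_factor_R_free": "0.2450"},
    "struct": {"title": "Example structure"},
})
print(out)
```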
- __init__(program='TORCHREF', program_version='', refinement_method='', resolution_high=None, resolution_low=None, n_reflections_work=None, n_reflections_test=None, n_reflections_all=None, percent_free=None, r_work=None, r_free=None, b_mean_overall=None, b_min=None, b_max=None, rmsd_bond_lengths=None, rmsd_bond_angles=None, n_atoms_total=None, n_atoms_protein=None, n_atoms_solvent=None, solvent_model_ksol=None, solvent_model_bsol=None, cell=None, spacegroup=None, title='', authors=<factory>, passthrough_pdb_remarks=<factory>, passthrough_cif_categories=<factory>, custom_remarks=<factory>)
Subpackages
- torchref.io.datasets package
CrystalDataset: hkl, F, F_sigma, I, I_sigma, rfree_flags, resolution, bin_indices, outlier_flags, phase, fom, E, E_squared, F_squared_corrected, U_aniso, radial_shell_indices, cell, spacegroup, device, verbose, rfree_source, amplitude_source, intensity_source, phase_source, wilson_b, wilson_b_structure, wilson_b_solvent, wilson_k_sol, outlier_detection_params, __post_init__(), save_state(), load_state(), __len__(), __repr__(), spacegroup_name, spacegroup_hm, spacegroup_number, __init__()
ReflectionData: hkl, F, F_sigma, I, I_sigma, rfree_flags, cell, spacegroup, resolution, wilson_b, source, dataset, last_op, reader, __post_init__(), load(), from_tensors(), load_mtz(), load_cif(), list_cif_data_blocks(), get_bins(), mean_res_per_bin(), mean_F_per_bin(), mean_sigma_per_bin(), regenerate_rfree_flags(), get_structure_factors(), get_structure_factors_with_sigma(), get_hkl(), filter_by_resolution(), get_mask(), cut_res(), get_rfree_masks(), get_work_set(), get_test_set(), get_max_res(), get_min_res(), __len__(), d_min, __repr__(), get_valid_mask(), data_indexed(), __call__(), data_fill_masked(), __getitem__(), __select__(), sanitize_F(), check_all_data_types(), validate_hkl(), find_outliers(), get_log_ratio(), get_outlier_statistics(), unpack_one(), flag_suspicious_sigma(), dump(), write_mtz(), centric, calc_patterson(), possible_hkl(), remap(), fill(), expand_to_p1(), reduce_to_spacegroup(), canonicalize(), get_scattering_vectors(), get_radial_shells(), fit_anisotropy(), setup_anisotropy(), apply_anisotropy_correction(), compute_e_values(), setup_scale(), get_corrected_data(), parameters(), __init__()
FcalcDataset: spacegroup, fcalc, fcalc_amp, fcalc_phase, from_cell_and_resolution(), set_fcalc(), write_mtz(), write_mtz_as_fobs(), __repr__(), spacegroup_name, spacegroup_hm, spacegroup_number, __init__()
DatasetCollection: hkl, n_datasets, add_dataset(), datasets, reference_dataset, spacegroup, __getitem__(), __iter__(), __len__(), __contains__(), __call__(), scale(), keys(), values(), items(), __init__(), get(), __repr__()
- Submodules
- torchref.io.cif module
- torchref.io.cif_readers module
CIFReader: data, filepath, available_blocks, __init__(), from_string(), load(), write(), __getitem__(), __setitem__(), __contains__(), __len__(), keys(), values(), items(), get(), __repr__(), summary()
ReflectionCIFReader: __init__(), read(), __call__(), get_reflection_data(), has_miller_indices(), has_amplitudes(), has_intensities(), has_phases(), has_rfree_flags(), get_miller_indices(), get_amplitudes(), get_intensities(), get_cell_parameters(), get_space_group()
ModelCIFReader: __init__(), read(), __call__(), get_atom_data(), get_atom_data_by_model(), get_cell_parameters(), get_space_group(), has_coordinates(), has_cell_parameters(), has_space_group(), has_occupancy(), has_bfactor(), has_anisotropic_data(), get_coordinates(), get_atom_info()
RestraintCIFReader: __init__(), get_all_restraints(), get_compound_restraints(), get_bond_restraints(), get_compound_id(), has_bond_restraints(), has_angle_restraints(), has_torsion_restraints(), has_plane_restraints(), has_chirality_restraints()
- torchref.io.data_router module
DataRouterError
DataRouter: filepath, verbose, data_type, file_format, reader, MTZ_EXTENSIONS, PDB_EXTENSIONS, CIF_EXTENSIONS, __init__(), get_reader(), get_data(), route(), __repr__(), __str__()
- torchref.io.ihm module
- torchref.io.ihm_mapping module
- Concept Mapping
IHMStateInfo
IHMModelGroupInfo
IHMEnsembleMapping: states, model_groups, cell, spacegroup, atom_data_per_state, get_state_ids(), get_timepoint_names(), get_fractions_for_group(), identify_dark_group(), get_state_by_id(), get_group_by_name(), validate(), __init__()
- torchref.io.metadata module
RefinementMetadata: program, program_version, refinement_method, resolution_high, resolution_low, n_reflections_work, n_reflections_test, n_reflections_all, percent_free, r_work, r_free, b_mean_overall, b_min, b_max, rmsd_bond_lengths, rmsd_bond_angles, n_atoms_total, n_atoms_protein, n_atoms_solvent, solvent_model_ksol, solvent_model_bsol, cell, spacegroup, title, authors, passthrough_pdb_remarks, passthrough_cif_categories, custom_remarks, to_dict(), from_dict(), from_refinement(), from_pdb_file(), from_cif_file(), merge(), render_pdb_header(), render_cif_categories(), __init__()
- torchref.io.mtz module
- torchref.io.pdb module