torchref.io.datasets package
Crystallographic dataset classes.
This module provides PyTorch-based dataset classes for handling crystallographic data:
- CrystalDataset: Abstract base class
- ReflectionData: Single crystal reflection dataset
- FcalcDataset: Dataset for calculated structure factors
- DatasetCollection: Container for multiple related datasets
Examples
from torchref.io.datasets import ReflectionData
data = ReflectionData(device='cuda')
data.load_mtz('observed.mtz')
print(f"Loaded {len(data)} reflections")
from torchref.io.datasets import FcalcDataset
fcalc = FcalcDataset.from_cell_and_resolution(
    cell=[50.0, 60.0, 70.0, 90.0, 90.0, 90.0],
    spacegroup='P212121',
    d_min=2.0,
)
from torchref.io.datasets import DatasetCollection
# native_data and derivative_data are datasets loaded beforehand (e.g. ReflectionData objects)
collection = DatasetCollection()
collection.add_dataset('native', native_data)
collection.add_dataset('derivative', derivative_data)
- class torchref.io.datasets.CrystalDataset(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None)[source]
Bases: DeviceMixin
Base dataclass for crystallographic datasets.
Defines all possible tensor fields (optional) and handles device management and serialization. Subclasses add domain-specific methods.
This lightweight design scales to thousands of datasets without the overhead of torch.nn.Module.
- Parameters:
device (torch.device) – Device for tensors (‘cpu’, ‘cuda’, etc.). Defaults to the configured device.current.
verbose (int) – Verbosity level (0=silent, 1=normal, 2=debug). Default is 1.
Examples
Basic usage:
import torch
from torchref.io.datasets import CrystalDataset
data = CrystalDataset(device='cuda')
data.hkl = torch.tensor([[1, 0, 0], [0, 1, 0]], dtype=torch.int32)
data.cpu()  # Move all tensors to CPU
- save_state(path)[source]
Save dataset state to file.
- Parameters:
path (str) – Output file path.
Examples
Save to file:
data.save_state('reflection_data.pt')
- classmethod load_state(path, device=device(type='cpu'))[source]
Load dataset state from file.
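- Parameters:
path (str) – Path to a file written by save_state().
device (torch.device or str, optional) – Device to load the tensors onto.
- Returns:
Loaded dataset.
- Return type:
CrystalDataset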
Examples
Load from file:
data = ReflectionData.load_state('reflection_data.pt', device='cuda')
- property spacegroup_hm: str | None
Get space group Hermann-Mauguin name with spaces (e.g., ‘P 21 21 21’).
- __init__(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None)
- class torchref.io.datasets.ReflectionData(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, _centric=None, _n_bins=None, _FrenchWilson=None, source=None, dataset=None, last_op=None, reader=None)[source]
Bases: CrystalDataset, DebugMixin
Container for crystallographic reflection data.
This class handles loading, processing, and accessing reflection data including Miller indices, structure factor amplitudes, intensities, and R-free flags. All data is stored as PyTorch tensors for GPU acceleration.
- Attributes:
- hkl
Miller indices of shape (N, 3), dtype int32.
- Type:
torch.Tensor
- F
Structure factor amplitudes of shape (N,), dtype float32.
- Type:
torch.Tensor
- F_sigma
Amplitude uncertainties of shape (N,), dtype float32.
- Type:
torch.Tensor
- I
Intensities of shape (N,), dtype float32.
- Type:
torch.Tensor
- I_sigma
Intensity uncertainties of shape (N,), dtype float32.
- Type:
torch.Tensor
- rfree_flags
R-free test set flags of shape (N,), dtype bool.
- Type:
torch.Tensor
- cell
Unit cell parameters [a, b, c, alpha, beta, gamma].
- Type:
Cell
- resolution
Resolution per reflection in Ångströms of shape (N,).
- Type:
torch.Tensor
Examples
Load reflection data from an MTZ file:
data = ReflectionData(verbose=1, device='cuda')
data.load_mtz('data.mtz')
print(f"Loaded {len(data.hkl)} reflections")
print(f"Resolution range: {data.resolution.min():.2f} - {data.resolution.max():.2f} Å")
- source: ReflectionData | None = None
- __post_init__()[source]
Initialize non-dataclass attributes after dataclass init.
This is called automatically after the dataclass __init__.
- load(reader)[source]
Load reflection data using a data reader.
- Parameters:
reader (callable) – Data reader object that returns (data_dict, cell, spacegroup) when called. Can be MTZ, ReflectionCIFReader, or other compatible reader.
- Returns:
Self, for method chaining.
- Return type:
ReflectionData
- Raises:
ValueError – If unit cell parameters are missing or no amplitude/intensity data found.
- classmethod from_tensors(hkl, F, F_sigma, cell, spacegroup, rfree_flags=None, device=device(type='cpu'), verbose=1)[source]
Construct ReflectionData directly from tensors.
- Parameters:
hkl (torch.Tensor) – Miller indices of shape (N, 3).
F (torch.Tensor) – Structure factor amplitudes of shape (N,).
F_sigma (torch.Tensor) – Amplitude uncertainties of shape (N,).
cell (Cell) – Unit cell parameters.
spacegroup (SpaceGroup) – Space group.
rfree_flags (torch.Tensor, optional) – R-free flags of shape (N,), dtype bool. If None, flags are generated automatically (2% free fraction).
device (str, optional) – Device for tensors. Defaults to the configured device.current.
verbose (int, optional) – Verbosity level. Default is 1.
- Returns:
Fully initialized reflection data with all cleanup applied.
- Return type:
ReflectionData
- load_mtz(path, column_names=None)[source]
Load reflection data from MTZ file.
- Parameters:
- Returns:
Self, for method chaining.
- Return type:
ReflectionData
- load_cif(path, data_block=None)[source]
Load reflection data from CIF file.
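- Parameters:
path (str) – Path to the CIF file.
data_block (str, optional) – Name of the data block to load; use list_cif_data_blocks() to see the available names.
- Returns:
Self, for method chaining.
- Return type:
ReflectionData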
- static list_cif_data_blocks(path)[source]
List all data blocks available in a CIF file without loading data.
Useful for multi-dataset CIF files to inspect available blocks before loading a specific one.
- Parameters:
path (str) – Path to CIF file.
- Returns:
Names of all data blocks in the CIF file.
- Return type:
list[str]
Examples
List and load a specific data block:
blocks = ReflectionData.list_cif_data_blocks('1VLM-sf.cif')
print(blocks)  # ['r1vlmsf', 'r1vlmAsf', 'r1vlmBsf', ...]
data = ReflectionData().load_cif('1VLM-sf.cif', data_block=blocks[1])
- get_bins(n_bins=20, min_per_bin=100)[source]
Create resolution bins with approximately equal reflection counts.
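- Parameters:
n_bins (int, optional) – Target number of resolution bins. Default is 20.
min_per_bin (int, optional) – Minimum reflections per resolution bin. Default is 100.
- Returns:
bin_indices (torch.Tensor) – Tensor of shape (N,) with bin index for each reflection.
n_bins (int) – Actual number of bins created (may be less than target for small datasets).
- Return type:
tuple[torch.Tensor, int]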
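The equal-count binning described above can be sketched in pure Python as follows. This is a minimal illustration, not torchref's implementation: the function name equal_count_bins is hypothetical, it works on plain lists of d-spacings rather than tensors, and it assumes the bin count is reduced to n // min_per_bin when the dataset is too small.

```python
def equal_count_bins(d_spacings, n_bins=20, min_per_bin=100):
    """Assign each reflection to one of up to n_bins equal-count resolution bins.

    Returns (bin_indices, n_bins_used); bin 0 holds the lowest-resolution
    (largest d-spacing) reflections.
    """
    n = len(d_spacings)
    if n == 0:
        return [], 0
    # Use fewer bins when the dataset cannot supply min_per_bin reflections per bin.
    n_used = max(1, min(n_bins, n // min_per_bin))
    # Rank reflections from low resolution (large d) to high resolution (small d).
    order = sorted(range(n), key=lambda i: d_spacings[i], reverse=True)
    bin_indices = [0] * n
    for rank, idx in enumerate(order):
        # Integer division spreads reflections evenly across the bins.
        bin_indices[idx] = rank * n_used // n
    return bin_indices, n_used


# 1000 reflections with min_per_bin=100 allow at most 10 bins.
d = [10.0 - 0.008 * i for i in range(1000)]
bins, n_used = equal_count_bins(d, n_bins=20, min_per_bin=100)
print(n_used, [bins.count(b) for b in range(n_used)])
```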
- mean_res_per_bin()[source]
Calculate mean resolution for each bin.
- Returns:
Mean resolution for each bin in Ångströms.
- Return type:
torch.Tensor
- Raises:
ValueError – If bins have not been created yet.
- mean_F_per_bin()[source]
Calculate mean structure factor amplitude per resolution bin.
- Returns:
Mean F per bin of shape (n_bins,).
- Return type:
torch.Tensor
- Raises:
ValueError – If bins have not been created yet.
- mean_sigma_per_bin()[source]
Calculate mean structure factor uncertainty per resolution bin.
- Returns:
Mean sigma_F per bin of shape (n_bins,), or None if no uncertainties.
- Return type:
torch.Tensor or None
- Raises:
ValueError – If bins have not been created yet.
- regenerate_rfree_flags(free_fraction=0.02, n_bins=20, min_per_bin=100, seed=None, force=False)[source]
Regenerate R-free flags with resolution-stratified sampling.
- Parameters:
free_fraction (float, optional) – Fraction of reflections to mark as free. Default is 0.02 (2%).
n_bins (int, optional) – Target number of resolution bins. Default is 20.
min_per_bin (int, optional) – Minimum reflections per resolution bin. Default is 100.
seed (int, optional) – Random seed for reproducibility. Default is None.
force (bool, optional) – If True, regenerate even if flags already exist. Default is False.
Examples
Generate R-free flags:
# Generate 2% free reflections with a reproducible seed
data.regenerate_rfree_flags(free_fraction=0.02, n_bins=20, seed=42)
# Generate 5% free with 10 bins, overwriting existing flags
data.regenerate_rfree_flags(free_fraction=0.05, n_bins=10, force=True)
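A sketch of resolution-stratified free-set selection, assuming bin indices such as those produced by get_bins() and the flag convention used by get_rfree_masks() (nonzero/True = work, zero/False = free). The helper stratified_free_flags is hypothetical and uses Python's random module instead of torch.

```python
import random


def stratified_free_flags(bin_indices, free_fraction=0.02, seed=None):
    """Return per-reflection flags: True = work set, False = free (test) set.

    In every resolution bin, round(free_fraction * bin_size) reflections
    (at least one) are drawn at random for the free set.
    """
    rng = random.Random(seed)  # local RNG so a seed gives reproducible flags
    members_by_bin = {}
    for i, b in enumerate(bin_indices):
        members_by_bin.setdefault(b, []).append(i)
    flags = [True] * len(bin_indices)
    for b in sorted(members_by_bin):  # fixed bin order keeps results seed-stable
        members = members_by_bin[b]
        n_free = max(1, round(free_fraction * len(members)))
        for i in rng.sample(members, n_free):
            flags[i] = False
    return flags


bins = [i // 100 for i in range(1000)]  # 10 bins of 100 reflections
flags = stratified_free_flags(bins, free_fraction=0.05, seed=42)
print(flags.count(False))  # 50 in total: 5 per bin
```

Sampling per bin keeps the free set spread evenly over resolution, so R-free is not dominated by whichever shell happens to be over-represented.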
- get_structure_factors(as_complex=False)[source]
Get structure factors, optionally as complex numbers.
- Parameters:
as_complex (bool, optional) – If True and phases available, return F*exp(i*phi). Default is False.
- Returns:
Structure factor amplitudes or complex structure factors.
- Return type:
torch.Tensor
- Raises:
ValueError – If no amplitude data is loaded.
- get_structure_factors_with_sigma()[source]
Get structure factor amplitudes and their uncertainties.
- Returns:
F (torch.Tensor) – Structure factor amplitudes of shape (N,).
F_sigma (torch.Tensor or None) – Uncertainties of shape (N,), or None if not available.
- Raises:
ValueError – If no amplitude data is loaded.
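- Return type:
tuple[torch.Tensor, torch.Tensor | None]
Examples
Get amplitudes with uncertainties:
F, sigma_F = data.get_structure_factors_with_sigma()
if sigma_F is not None:
    # F_calc: calculated amplitudes from your model (not defined here)
    weighted_residual = (F - F_calc) / sigma_F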
- get_hkl()[source]
Return Miller indices for valid reflections.
- Returns:
Miller indices of shape (N, 3), dtype int32.
- Return type:
torch.Tensor
- Raises:
ValueError – If no Miller indices are loaded.
- filter_by_resolution(d_min=None, d_max=None)[source]
Filter reflections by resolution range.
Adds a boolean mask to self.masks for the specified resolution range.
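- Parameters:
d_min (float, optional) – High resolution limit in Ångströms (smallest d-spacing kept).
d_max (float, optional) – Low resolution limit in Ångströms (largest d-spacing kept).
- Returns:
Self, for method chaining.
- Return type:
ReflectionData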
- get_mask()[source]
Return combined mask from all active filters.
- Returns:
Boolean mask combining all filter conditions.
- Return type:
torch.Tensor
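One way to picture how filter masks combine into the single mask returned by get_mask(), assuming the masks are joined with a logical AND and that resolution limits are inclusive. resolution_mask and combine_masks are hypothetical helpers operating on plain lists.

```python
def resolution_mask(d_spacings, d_min=None, d_max=None):
    """True where d_min <= d <= d_max; a None limit is treated as unbounded."""
    return [(d_min is None or d >= d_min) and (d_max is None or d <= d_max)
            for d in d_spacings]


def combine_masks(masks, n):
    """Logical AND of all named masks; True = reflection kept."""
    combined = [True] * n
    for name, mask in masks.items():
        if len(mask) != n:
            raise ValueError(f"mask {name!r} has length {len(mask)}, expected {n}")
        combined = [c and m for c, m in zip(combined, mask)]
    return combined


d = [3.5, 2.4, 1.8, 1.2]
masks = {
    "resolution": resolution_mask(d, d_min=1.5, d_max=3.0),
    "sanitize": [True, True, False, True],  # e.g. a NaN/non-positive F filter
}
print(combine_masks(masks, len(d)))
```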
- cut_res(highres=None, lowres=None)[source]
Filter reflections by resolution range.
Alias for filter_by_resolution with more intuitive naming.
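- Parameters:
highres (float, optional) – High resolution limit in Ångströms (equivalent to d_min).
lowres (float, optional) – Low resolution limit in Ångströms (equivalent to d_max).
- Returns:
Self, for method chaining.
- Return type:
ReflectionData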
Examples
Filter by resolution:
# Keep reflections between 50 Å and 1.5 Å
filtered = data.cut_res(highres=1.5, lowres=50.0)
# Keep only the 2.0 to 1.0 Å shell
high_res = data.cut_res(highres=1.0, lowres=2.0)
- get_rfree_masks()[source]
Get boolean masks for work and test (free) sets.
- Returns:
work_mask (torch.Tensor or None) – Boolean tensor for work set (flag != 0).
test_mask (torch.Tensor or None) – Boolean tensor for test/free set (flag == 0). Both are None if no R-free flags are available.
- Return type:
tuple[torch.Tensor | None, torch.Tensor | None]
Examples
Separate work and test sets:
work_mask, test_mask = data.get_rfree_masks()
if work_mask is not None:
    F_work = data.F[work_mask]
    F_test = data.F[test_mask]
- get_work_set()[source]
Get structure factors for the work set (R-free flag != 0).
- Returns:
F_work (torch.Tensor) – Structure factors for work set.
sigma_work (torch.Tensor or None) – Uncertainties for work set, or None if not available.
- Return type:
tuple[torch.Tensor, torch.Tensor | None]
Notes
If no R-free flags are available, returns the full dataset and emits a warning.
- get_test_set()[source]
Get structure factors for the test set (R-free flag == 0).
- Returns:
F_test (torch.Tensor) – Structure factors for test/free set.
sigma_test (torch.Tensor or None) – Uncertainties for test set, or None if not available.
- Raises:
ValueError – If no R-free flags are available.
- Return type:
tuple[torch.Tensor, torch.Tensor | None]
- get_max_res()[source]
Return maximum resolution (lowest d-spacing).
- Returns:
Maximum resolution in Ångströms.
- Return type:
float
- get_min_res()[source]
Return minimum resolution (highest d-spacing).
- Returns:
Minimum resolution in Ångströms.
- Return type:
float
- __len__()[source]
Return number of reflections.
- Returns:
Number of reflections in the dataset.
- Return type:
int
- property d_min: float | None
Return maximum resolution (lowest d-spacing).
- Returns:
Maximum resolution in Ångströms.
- Return type:
float | None
- __repr__()[source]
Return string representation.
- Returns:
Summary of reflection data including count, sources, and resolution.
- Return type:
str
- get_valid_mask()[source]
Return the combined validity mask for all reflections.
This is the mask used to filter reflections in forward(). True indicates a valid (included) reflection, False indicates an excluded one.
- Returns:
Boolean mask of shape (N,) where N is the total number of reflections. True = valid/included, False = invalid/excluded.
- Return type:
torch.Tensor
Examples
Check validity:
mask = data.get_valid_mask()
print(f"{mask.sum().item()} of {len(mask)} reflections are valid")
- data_indexed()[source]
Return reflection data as indexed (filtered) tensors.
This method filters out invalid reflections and returns smaller tensors containing only valid data. Useful for operations that don’t support MaskedTensors or for writing output files.
- Returns:
hkl (torch.Tensor) – Miller indices of shape (M, 3) where M is number of valid reflections.
F (torch.Tensor) – Structure factor amplitudes of shape (M,).
F_sigma (torch.Tensor or None) – Uncertainties of shape (M,) or None.
rfree_flags (torch.Tensor or None) – R-free flags of shape (M,) or None.
- Return type:
tuple
See also
forward – Main method returning MaskedTensors.
Examples
Get indexed data for file writing:
hkl, F, sigma, rfree = data.data_indexed()
F_np = F.cpu().numpy()  # Safe for writing to files
- __call__(mask=True, scale=True)[source]
Return core reflection data with MaskedTensors for F and sigma.
F and F_sigma are returned as MaskedTensors which keep all reflections but mark invalid ones as masked. Aggregation operations (sum, mean, etc.) automatically skip masked values. HKL and rfree_flags remain regular tensors.
- Parameters:
- Returns:
hkl (torch.Tensor) – Miller indices of shape (N, 3). Full size, unfiltered.
F (MaskedTensor) – Structure factor amplitudes of shape (N,) with invalid reflections masked.
F_sigma (MaskedTensor or None) – Uncertainties of shape (N,) with invalid reflections masked, or None.
rfree_flags (torch.Tensor or None) – R-free flags of shape (N,) or None. Full size, unfiltered. 1=work, 0=free.
- Return type:
tuple
Notes
MaskedTensors:
- Are PyTorch tensors with an associated boolean mask
- Skip masked values automatically in aggregations (sum, mean, etc.)
- Preserve the mask through element-wise operations
- Expose the underlying data via .get_data() and .get_mask()
- Convert back to a regular tensor via .to_tensor(fill_value)
MaskedTensor is a prototype feature in PyTorch, so its API may change.
Loss functions and targets extract valid data from MaskedTensors before computation to work correctly with complex F_calc values.
Examples
Access reflection data with MaskedTensors:
hkl, F, sigma, rfree = data()
print(F.shape)  # Full shape
print(F.sum())  # Only sums valid (unmasked) values
# Access underlying data
valid_mask = F.get_mask()
F_values = F.get_data()[valid_mask]
- data_fill_masked(mode='mean')[source]
Return data tensors with missing or flagged reflections filled in.
- Parameters:
mode (str, optional) – Fill strategy. ‘mean’ fills missing or flagged reflections with the mean of the present data; ‘zero’ fills them with zero. Default is ‘mean’.
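A minimal sketch of the two fill modes on plain lists, assuming a mask in which True marks a present (valid) reflection. fill_masked is a hypothetical name, not the method's actual code.

```python
def fill_masked(values, mask, mode="mean"):
    """Replace entries whose mask is False; mask True = present/valid."""
    present = [v for v, ok in zip(values, mask) if ok]
    if mode == "mean":
        fill = sum(present) / len(present) if present else 0.0
    elif mode == "zero":
        fill = 0.0
    else:
        raise ValueError(f"unknown mode {mode!r}; expected 'mean' or 'zero'")
    return [v if ok else fill for v, ok in zip(values, mask)]


F = [1.0, 100.0, 3.0]   # the 100.0 is a flagged value
ok = [True, False, True]
print(fill_masked(F, ok), fill_masked(F, ok, mode="zero"))
```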
- __getitem__(key)[source]
Index into the reflection dataset.
- Parameters:
key (torch.Tensor) – Boolean mask or integer indices for selection.
- Returns:
New ReflectionData object with selected reflections.
- Return type:
ReflectionData
- __select__(indices, op=None)[source]
Select reflections by boolean mask or integer indices.
Iterates over all dataclass fields generically, so new tensor fields are handled automatically.
- Parameters:
indices (torch.Tensor) – Boolean mask of shape (N,) or integer indices for selection.
op (str, optional) – Operation name for tracking purposes.
- Returns:
New ReflectionData object with selected reflections.
- Return type:
ReflectionData
- sanitize_F()[source]
Remove invalid values from structure factors.
Adds a mask to filter out NaN, Inf, and non-positive values from F and F_sigma.
- validate_hkl(hkl_ref)[source]
Expand dataset to match a reference HKL set.
Reorders and expands the current dataset to match the reference HKL set. Reflections present in the reference but missing from this dataset are filled with placeholder values and masked out. This ensures all datasets aligned to the same reference have identical shapes and can be processed together without data loss from intersection operations.
- Parameters:
hkl_ref (torch.Tensor) – Reference Miller indices tensor of shape (N, 3), dtype int32. This defines the canonical HKL ordering for all aligned datasets.
- Returns:
Self, modified in-place with expanded arrays matching hkl_ref.
- Return type:
ReflectionData
Notes
After calling this method:
- self.hkl will equal hkl_ref exactly
- All data arrays (F, F_sigma, rfree_flags, etc.) are reordered/expanded
- Missing reflections are filled with 0 (or appropriate defaults)
- A mask ‘hkl_present’ is added marking which reflections have real data
- forward() will return MaskedTensors that skip missing reflections
This approach avoids the problem where intersecting many datasets with different outliers/missing reflections causes exponential data loss.
Examples
Align multiple datasets to a common HKL set:
reference_hkl = data1.hkl.clone()
data1.validate_hkl(reference_hkl)
data2.validate_hkl(reference_hkl)
# Now data1 and data2 have identical shapes
assert data1.hkl.shape == data2.hkl.shape
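The alignment can be pictured as a dictionary lookup from each reference (h, k, l) to its row in the current dataset. This sketch assumes plain tuples and a single value list, and align_to_reference is a hypothetical name. Reflections present only in the current dataset are dropped, since the output follows hkl_ref exactly.

```python
def align_to_reference(hkl, values, hkl_ref, fill_value=0.0):
    """Reorder/expand `values` (one per hkl) onto the ordering of hkl_ref.

    Returns (aligned_values, hkl_present); hkl_present[i] is False for
    reference reflections that this dataset does not contain.
    """
    lookup = {tuple(h): i for i, h in enumerate(hkl)}
    aligned, present = [], []
    for h in hkl_ref:
        i = lookup.get(tuple(h))
        if i is None:
            aligned.append(fill_value)  # placeholder, masked out downstream
            present.append(False)
        else:
            aligned.append(values[i])
            present.append(True)
    return aligned, present


hkl_ref = [(0, 0, 2), (1, 1, 0), (1, 0, 0)]
F_aligned, hkl_present = align_to_reference([(1, 0, 0), (0, 0, 2)], [10.0, 20.0], hkl_ref)
print(F_aligned, hkl_present)
```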
- find_outliers(model, scaler, z_threshold=4.0)[source]
Identify outlier reflections based on log-ratio distribution.
Uses the fact that log(F_obs) - log(F_calc) should be normally distributed. Outliers are reflections where |log_ratio - mean| > z_threshold * std_dev.
- Parameters:
- Returns:
Boolean mask where True indicates outliers.
- Return type:
torch.Tensor
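The log-ratio test can be sketched as follows, assuming strictly positive F_obs and F_calc (for example after sanitize_F()) and a population standard deviation. The function name is hypothetical, and the model/scaler objects of the real method are replaced by precomputed amplitude lists.

```python
import math


def log_ratio_outliers(F_obs, F_calc, z_threshold=4.0):
    """Flag reflections with |r - mean(r)| > z_threshold * std(r), r = log(F_obs) - log(F_calc)."""
    r = [math.log(fo) - math.log(fc) for fo, fc in zip(F_obs, F_calc)]
    mean = sum(r) / len(r)
    std = math.sqrt(sum((x - mean) ** 2 for x in r) / len(r))  # population std
    if std == 0.0:
        return [False] * len(r)  # identical ratios: nothing can be an outlier
    return [abs(x - mean) > z_threshold * std for x in r]


F_calc = [10.0] * 50
F_obs = [10.0 * (1 + 0.01 * ((i % 5) - 2)) for i in range(50)]
F_obs[7] = 200.0  # one grossly wrong observation
flags = log_ratio_outliers(F_obs, F_calc)
print([i for i, f in enumerate(flags) if f])
```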
- get_log_ratio(model, scaler)[source]
Compute log-ratio between observed and calculated structure factors.
- Parameters:
- Returns:
Log-ratio values: log(F_obs) - log(F_calc).
- Return type:
torch.Tensor
- get_outlier_statistics()[source]
Get statistics about flagged outliers.
- Returns:
Dictionary containing:
- n_outliers : int
- n_total : int
- fraction_outliers : float
- detection_params : dict or None
- outlier_resolution_stats : dict (if resolution available)
- Return type:
dict
- unpack_one()[source]
Unpack one level of source.
Does not recurse fully and does not flag.
- Returns:
Parent source or self if no source.
- Return type:
ReflectionData
- flag_suspicious_sigma(z_threshold=5.0)[source]
Flag sigma values that deviate significantly from expected distribution.
Sigma values from a detector should follow a log-normal distribution. Values with z-scores beyond threshold are flagged as suspicious.
- Parameters:
z_threshold (float, optional) – Z-score threshold for flagging outliers. Default is 5.0.
- dump()[source]
Dump all reflection data to console for debugging.
Prints type, shape, and device information for all attributes.
- write_mtz(fname, fcalc=None, model_ft=None)[source]
Write reflection data to MTZ file with optional map coefficients.
- Parameters:
fname (str) – Output MTZ filename.
fcalc (torch.Tensor, optional) – Complex calculated structure factors of shape (N,). If provided, computes phases and map coefficients.
model_ft (ModelFT, optional) – ModelFT object to compute fcalc if not provided.
Notes
- The MTZ file will contain canonical column names:
FP, SIGFP: Observed amplitudes and uncertainties
I, SIGI: Observed intensities and uncertainties (if available)
FreeR_flag: R-free test set flags
FWT, PHWT: 2mFo-DFc map coefficients (if fcalc provided)
DELFWT, PHDELWT: mFo-DFc map coefficients (if fcalc provided)
- Map coefficients are computed as:
2mFo-DFc: 2*Fo - Fc (filled to resolution limit)
mFo-DFc: Fo - Fc
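A sketch of the map-coefficient arithmetic stated above (2Fo - Fc and Fo - Fc, i.e. m = D = 1), using the phase of each complex Fcalc. Building the complex coefficient first turns a negative amplitude into a 180° phase flip. The function name is hypothetical, phases are returned in degrees as MTZ files conventionally store them, and the filling of missing reflections is not modelled.

```python
import cmath
import math


def map_coefficients(F_obs, F_calc):
    """Return (FWT, PHWT, DELFWT, PHDELWT) per reflection; phases in degrees."""
    rows = []
    for fo, fc in zip(F_obs, F_calc):
        unit = cmath.exp(1j * cmath.phase(fc))  # unit phasor with the model phase
        c_2fofc = (2.0 * fo - abs(fc)) * unit   # 2Fo - Fc
        c_fofc = (fo - abs(fc)) * unit          # Fo - Fc
        rows.append((abs(c_2fofc), math.degrees(cmath.phase(c_2fofc)),
                     abs(c_fofc), math.degrees(cmath.phase(c_fofc))))
    return rows


rows = map_coefficients([10.0, 5.0], [8j, -8.0])
print(rows)
```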
Examples
Write MTZ with map coefficients:
data = ReflectionData().load_mtz('observed.mtz')
model = Model().load_pdb('model.pdb')
model_ft = ModelFT(model, data.cell, data.spacegroup)
fcalc = model_ft.forward(data.hkl)
data.write_mtz('output.mtz', fcalc=fcalc)
- property centric
Get boolean mask for centric reflections (full size, unfiltered).
Calculates it if not already present. Returns unfiltered centric flags matching the full HKL array size, consistent with how forward() returns full-size arrays when using MaskedTensors.
- Returns:
Boolean tensor of shape (N,) where N is total reflections. True indicates centric reflection, False indicates acentric.
- Return type:
torch.Tensor or None
- calc_patterson(grid_size=None, grid_sampling=1)[source]
Calculate Patterson map of the dataset.
The Patterson function P(u,v,w) = Σ|F(hkl)|² exp(-2πi(hu+kv+lw)) is computed via inverse FFT of F². Data is expanded to P1 symmetry using only observed reflections (no filling of missing data).
- Parameters:
- Returns:
Real-valued Patterson map of shape (Nx, Ny, Nz). Origin is at grid position [0, 0, 0].
- Return type:
torch.Tensor
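For intuition, the Patterson function can also be evaluated by direct summation on a small fractional grid. This sketch swaps the inverse FFT for an explicit cosine sum, which gives the same real map for a Friedel-symmetric set (|F(h)| = |F(-h)|) but is far slower; patterson_direct is a hypothetical name. Every term contributes cos(0) = 1 at u = 0, so the origin holds the largest value.

```python
import math


def patterson_direct(hkl, F, grid=(8, 8, 8)):
    """P(u, v, w) = sum of |F|^2 cos(2 pi (hu + kv + lw)) on an nx*ny*nz grid."""
    nx, ny, nz = grid
    P = [[[0.0] * nz for _ in range(ny)] for _ in range(nx)]
    for (h, k, l), f in zip(hkl, F):
        weight = f * f  # Patterson coefficients are |F|^2; no phases needed
        for ix in range(nx):
            for iy in range(ny):
                for iz in range(nz):
                    arg = 2.0 * math.pi * (h * ix / nx + k * iy / ny + l * iz / nz)
                    P[ix][iy][iz] += weight * math.cos(arg)
    return P


P = patterson_direct([(1, 0, 0), (0, 1, 0), (1, 1, 1)], [3.0, 2.0, 1.0])
print(P[0][0][0])  # origin peak: 9 + 4 + 1
```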
- possible_hkl()[source]
Generate all possible HKL indices within the resolution limit.
- Returns:
Tensor of shape (M, 3) containing all possible Miller indices within the resolution limit defined by self.resolution.
- Return type:
torch.Tensor
- remap(new_hkl, index_mapping, phase_shifts=None, spacegroup=None, op_name='remap')[source]
Create new ReflectionData with remapped HKL set and data.
This is the core method for index-based transformations of reflection data. It handles remapping all tensor fields based on an index mapping, with support for missing reflections (indicated by -1 in index_mapping).
- Parameters:
new_hkl (torch.Tensor, shape (M, 3)) – New Miller indices.
index_mapping (torch.Tensor, shape (M,), dtype int64) – Maps new indices to original: new[i] = old[index_mapping[i]]. Values of -1 indicate missing reflections (filled with defaults).
phase_shifts (torch.Tensor, optional, shape (M,)) – Phase offsets to apply (e.g., from symmetry translations).
spacegroup (str, int, gemmi.SpaceGroup, or None) – New spacegroup. If None, keeps original.
op_name (str) – Operation name for provenance tracking.
- Returns:
New object with remapped data. Missing reflections get:
- 0.0 for F, I, phase, fom
- 1.0 for F_sigma, I_sigma (conservative uncertainty)
- True for masks[‘missing’]
- Return type:
ReflectionData
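The index_mapping convention can be illustrated on plain lists. remap_field and remap_phases are hypothetical helpers; the defaults 0.0 for F and 1.0 for sigma follow the Returns description above.

```python
def remap_field(values, index_mapping, default):
    """new[i] = values[index_mapping[i]]; an index of -1 yields `default`."""
    return [default if j == -1 else values[j] for j in index_mapping]


def remap_phases(phases, index_mapping, phase_shifts=None):
    """Remap phases (radians) and add optional per-reflection phase shifts."""
    out = remap_field(phases, index_mapping, 0.0)
    if phase_shifts is not None:
        out = [p if j == -1 else p + s
               for p, s, j in zip(out, phase_shifts, index_mapping)]
    return out


mapping = [2, -1, 0]  # new reflection 1 has no source reflection
F_new = remap_field([5.0, 6.0, 7.0], mapping, 0.0)
sigma_new = remap_field([0.1, 0.2, 0.3], mapping, 1.0)
missing = [j == -1 for j in mapping]
print(F_new, sigma_new, missing)
```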
- fill(d_min=None)[source]
Fill missing reflections within resolution limit.
Generates all possible reflections for the current spacegroup within the resolution limit, identifies which are missing, and creates a complete dataset. Missing reflections are filled with default values.
- Parameters:
d_min (float, optional) – High resolution limit in Angstroms. If None, uses the minimum resolution from the current dataset.
- Returns:
New ReflectionData with the complete set of reflections. Missing reflections have:
- F, I, phase, fom = 0.0
- F_sigma, I_sigma = 1.0
- masks[‘missing’] = True
- Return type:
ReflectionData
Examples
Fill to completeness:
data = ReflectionData().load_mtz('data.mtz')
data_filled = data.fill(d_min=2.0)
print(f"Original: {len(data)}, Filled: {len(data_filled)}")
- expand_to_p1(include_friedel=True, remove_absences=True)[source]
Expand reflection data from asymmetric unit to P1.
Applies all symmetry operations from the current space group to generate all symmetry-equivalent reflections. Returns a NEW ReflectionData object with expanded reflections; does not modify self.
- Parameters:
- Returns:
New ReflectionData object with:
- All symmetry-equivalent reflections
- spacegroup set to P1
- Expanded tensor fields (F, F_sigma, I, I_sigma, phase, fom, rfree_flags)
- Recalculated resolution values
- Provenance tracking via source and last_op attributes
- Return type:
ReflectionData
Notes
The expansion handles tensor fields as follows:
hkl: Symmetry operations applied, Friedel mates added, duplicates removed
F, F_sigma, I, I_sigma, fom, rfree_flags: Indexed from original
phase: Indexed + phase shift from translation component
resolution: Recalculated from expanded hkl + cell
bin_indices: Cleared (invalidated by expansion)
Examples
Expand to P1 for calculations:
# Load data in original space group
data = ReflectionData().load_mtz('data.mtz')
print(f"Original: {len(data)} reflections, {data.spacegroup}")
# Expand to P1
data_p1 = data.expand_to_p1()
print(f"Expanded: {len(data_p1)} reflections, {data_p1.spacegroup}")
# Without Friedel mates (for anomalous data)
data_p1_anom = data.expand_to_p1(include_friedel=False)
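A self-contained sketch of P1 expansion for a single space group, P 21 21 21, whose four operators all have diagonal rotations. For an operator x' = Rx + t it uses F(hR) = F(h) exp(-2πi h·t) (with the convention F(h) = Σ f exp(+2πi h·x)), Friedel's law for the mates, and the standard absence test: a reflection mapped onto itself with non-integer h·t is systematically absent. The operator table and function names are written out by hand for illustration; the real method takes its operators from the dataset's space group.

```python
import math

# Symmetry operators of P 21 21 21 (x' = R x + t); every R is diagonal,
# so each rotation is stored as its diagonal.
P212121_OPS = [
    ((1, 1, 1), (0.0, 0.0, 0.0)),
    ((-1, -1, 1), (0.5, 0.0, 0.5)),
    ((-1, 1, -1), (0.0, 0.5, 0.5)),
    ((1, -1, -1), (0.5, 0.5, 0.0)),
]


def _apply(hkl, r):
    return tuple(x * ri for x, ri in zip(hkl, r))


def is_systematically_absent(hkl):
    """Absent if some operator maps hkl onto itself with a non-integer h.t."""
    for r, t in P212121_OPS:
        if _apply(hkl, r) == tuple(hkl):
            shift = sum(x * ti for x, ti in zip(hkl, t))
            if abs(shift - round(shift)) > 1e-9:
                return True
    return False


def expand_p212121_to_p1(hkl, F, phase, include_friedel=True, remove_absences=True):
    """Generate all symmetry mates; phases in radians."""
    expanded = {}
    for h, f, phi in zip(hkl, F, phase):
        if remove_absences and is_systematically_absent(h):
            continue
        for r, t in P212121_OPS:
            h_new = _apply(h, r)
            phi_new = phi - 2 * math.pi * sum(x * ti for x, ti in zip(h, t))
            mates = [(h_new, phi_new)]
            if include_friedel:
                # Friedel's law: F(-h) = conj(F(h)), so the phase changes sign.
                mates.append((tuple(-x for x in h_new), -phi_new))
            for hh, pp in mates:
                if hh not in expanded:  # keep the first copy of duplicates
                    expanded[hh] = (f, math.remainder(pp, 2 * math.pi))
    hkl_p1 = list(expanded)
    return hkl_p1, [expanded[h][0] for h in hkl_p1], [expanded[h][1] for h in hkl_p1]


hkl_p1, F_p1, phi_p1 = expand_p212121_to_p1([(1, 2, 3), (0, 0, 1)], [10.0, 4.0], [0.3, 0.0])
print(len(hkl_p1))  # (0, 0, 1) is absent; (1, 2, 3) yields 8 mates
```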
- reduce_to_spacegroup(spacegroup, include_friedel=True, aggregation='mean')[source]
Reduce P1 reflection data to asymmetric unit of a target spacegroup.
This is the inverse of expand_to_p1(). Takes reflection data in P1 and merges symmetry-equivalent reflections into single ASU reflections using the specified aggregation function.
- Parameters:
spacegroup (str, int, or gemmi.SpaceGroup) – Target space group specification.
include_friedel (bool, default True) – If True, also merge Friedel mates when reducing.
aggregation (str, default 'mean') – Aggregation function for merging equivalent reflections:
- ‘mean’: Average values (default, good for amplitudes)
- ‘sum’: Sum values
- ‘first’: Take first valid value (no averaging)
- Returns:
New ReflectionData with merged reflections in the target spacegroup.
- Return type:
ReflectionData
Notes
Field handling during reduction:
F, I: Aggregated using specified function
F_sigma, I_sigma: Propagated as sqrt(sum(sigma^2) / n) for ‘mean’
phase: Aggregated via complex averaging (handles phase wrapping)
fom: Weighted by amplitude during phase averaging
rfree_flags: OR operation (True if any equivalent is True)
Examples
Reduce symmetry-expanded data:
# Expand to P1 for calculations, then reduce back
data_p1 = data.expand_to_p1()
# ... modify data_p1.F ...
data_merged = data_p1.reduce_to_spacegroup('P21')
# Reduce with sum instead of mean
data_summed = data_p1.reduce_to_spacegroup('P21', aggregation='sum')
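A sketch of ‘mean’ aggregation for a P1 set reduced to an orthorhombic space group, where all eight sign combinations of (h, k, l) are equivalent (Laue class mmm). It applies the rules listed in the Notes above: mean F, sigma as sqrt(sum(sigma^2) / n), an amplitude-weighted phasor average for the phase, and OR for rfree_flags. It assumes the phases have already been referred to a common representative reflection, so symmetry phase shifts are ignored; reduce_mmm_mean is a hypothetical name.

```python
import cmath
import math


def reduce_mmm_mean(hkl, F, sigma, phase, rfree):
    """Merge P1 reflections equivalent under Laue class mmm (all sign changes)."""
    groups = {}
    for h, f, s, p, w in zip(hkl, F, sigma, phase, rfree):
        key = tuple(abs(x) for x in h)  # mmm: every sign combination is equivalent
        groups.setdefault(key, []).append((f, s, p, w))
    merged = {}
    for key, members in groups.items():
        n = len(members)
        f_mean = sum(m[0] for m in members) / n
        sigma_merged = math.sqrt(sum(m[1] ** 2 for m in members) / n)  # rule from the Notes
        # Amplitude-weighted sum of unit phasors handles wrap-around at +/- pi.
        z = sum(m[0] * cmath.exp(1j * m[2]) for m in members)
        rfree_merged = any(m[3] for m in members)  # OR over equivalents
        merged[key] = (f_mean, sigma_merged, cmath.phase(z), rfree_merged)
    return merged


merged = reduce_mmm_mean(
    [(1, 2, 3), (-1, 2, 3), (1, -2, -3), (2, 0, 0)],
    [10.0, 12.0, 11.0, 5.0],
    [1.0, 1.0, 1.0, 0.5],
    [0.1, 0.1, 0.1, 0.0],
    [True, False, True, True],
)
print(merged[(1, 2, 3)])
```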
- canonicalize(include_friedel=True)[source]
Return new ReflectionData with HKL in standard CCP4 ASU form.
Remaps all Miller indices to the canonical CCP4 asymmetric unit representative using gemmi.ReciprocalAsu, adjusts phases accordingly, and sorts reflections lexicographically by (h, k, l).
- Parameters:
include_friedel (bool, default True) – Whether Friedel mates are considered equivalent.
- Returns:
New object with canonicalized, sorted Miller indices.
- Return type:
ReflectionData
- get_scattering_vectors()[source]
Get scattering vectors (s-vectors) from hkl and cell.
- The s-vector for a reflection hkl is defined as:
s = B* @ hkl
where B* is the reciprocal basis matrix.
- Returns:
s_vectors – Reciprocal space vectors in Angstroms^-1, shape (N, 3).
- Return type:
torch.Tensor
- Raises:
ValueError – If hkl or cell is not available.
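Because |s| = 1/d, the resolution of a reflection can be computed without building B* explicitly, via the reciprocal metric tensor G* = G⁻¹, where G is the real-space metric tensor of the cell. A minimal pure-Python sketch; the helper names inv3 and d_spacing are hypothetical.

```python
import math


def inv3(m):
    """Inverse of a 3x3 matrix via the adjugate (cofactor) formula."""
    (a, b, c), (d, e, f), (g, h, i) = m
    det = a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g)
    return [[(e * i - f * h) / det, (c * h - b * i) / det, (b * f - c * e) / det],
            [(f * g - d * i) / det, (a * i - c * g) / det, (c * d - a * f) / det],
            [(d * h - e * g) / det, (b * g - a * h) / det, (a * e - b * d) / det]]


def d_spacing(hkl, cell):
    """d = 1 / |s| with |s|^2 = h^T G* h, G* being the reciprocal metric tensor."""
    a, b, c, alpha, beta, gamma = cell
    ca, cb, cg = (math.cos(math.radians(x)) for x in (alpha, beta, gamma))
    # Real-space metric tensor G (dot products of the cell vectors).
    G = [[a * a, a * b * cg, a * c * cb],
         [a * b * cg, b * b, b * c * ca],
         [a * c * cb, b * c * ca, c * c]]
    G_star = inv3(G)
    s2 = sum(hkl[p] * G_star[p][q] * hkl[q] for p in range(3) for q in range(3))
    return 1.0 / math.sqrt(s2)


cell = [50.0, 60.0, 70.0, 90.0, 90.0, 90.0]
print(d_spacing((1, 0, 0), cell), d_spacing((0, 0, 2), cell))  # approximately 50 and 35
```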
- get_radial_shells(n_shells=20, d_min=None, d_max=None)[source]
Create uniform radial shells in 1/d space for normalization.
This creates shells with uniform spacing in reciprocal space (1/d), different from get_bins() which creates equal-count bins.
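- Parameters:
n_shells (int, optional) – Number of radial shells. Default is 20.
d_min (float, optional) – High resolution limit in Angstroms. If None, uses dataset minimum.
d_max (float, optional) – Low resolution limit in Angstroms. If None, uses dataset maximum.
- Returns:
shell_edges (torch.Tensor) – Shell boundaries in Angstroms^-1, shape (n_shells+1,).
shell_centers (torch.Tensor) – Shell centers in Angstroms^-1, shape (n_shells,).
shell_indices (torch.Tensor) – Shell index for each reflection, shape (N,). Values -1 for out-of-range.
- Return type:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]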
- fit_anisotropy(n_shells=20, d_min=None, d_max=None, n_iterations=100, verbose=None)[source]
Fit anisotropy correction parameters to minimize CV within shells.
Optimizes U parameters so that corrected F² values have minimal coefficient of variation within each resolution shell.
- Parameters:
n_shells (int) – Number of resolution shells for variance calculation.
d_min (float, optional) – High resolution limit in Angstroms. If None, uses dataset minimum.
d_max (float, optional) – Low resolution limit in Angstroms. If None, uses dataset maximum.
n_iterations (int) – Number of optimization iterations.
verbose (bool, optional) – Print progress. If None, uses self.verbose.
- Returns:
U – Fitted anisotropy parameters [u11, u22, u33, u12, u13, u23], shape (6,). Also stored in self.U_aniso.
- Return type:
torch.Tensor
- Raises:
ValueError – If no amplitude data is available.
- setup_anisotropy(U_aniso=None)[source]
Set up anisotropy correction parameters.
- Parameters:
U_aniso (torch.Tensor, optional) – Anisotropic parameters [u11, u22, u33, u12, u13, u23], shape (6,). If None, initializes them to zeros.
- apply_anisotropy_correction(U_aniso=None)[source]
Apply anisotropy correction to F² values.
- Parameters:
U_aniso (torch.Tensor, optional) – Anisotropic parameters [u11, u22, u33, u12, u13, u23], shape (6,). If None, uses self.U_aniso (must have called fit_anisotropy first).
- Returns:
F_corrected (torch.Tensor) – Anisotropy-corrected F values, shape (N,).
sigma_F_corrected (torch.Tensor) – Uncertainties of corrected F values, shape (N,).
- Raises:
ValueError – If no U parameters available and none provided.
- Return type:
tuple[torch.Tensor, torch.Tensor]
- compute_e_values(n_shells=20, d_min=None, d_max=None, apply_anisotropy=True, fit_anisotropy=True, verbose=None)[source]
Compute E-values with optional anisotropy correction.
E-values are normalized structure factors where <E²> = 1 within each resolution shell. Anisotropy correction can be applied first to account for directional variation in diffraction.
- Parameters:
n_shells (int) – Number of resolution shells for normalization.
d_min (float, optional) – High resolution limit in Angstroms. If None, uses dataset minimum.
d_max (float, optional) – Low resolution limit in Angstroms. If None, uses dataset maximum.
apply_anisotropy (bool) – If True, apply anisotropy correction before E-value calculation.
fit_anisotropy (bool) – If True and apply_anisotropy is True, fit anisotropy parameters. If False and apply_anisotropy is True, uses existing self.U_aniso.
verbose (bool, optional) – Print progress. If None, uses self.verbose.
- Returns:
E – E-values, shape (N,). Also stored in self.E. self.E_squared is also populated with E² values.
- Return type:
torch.Tensor
- Raises:
ValueError – If no amplitude data is available.
Examples
Compute E-values with automatic anisotropy correction:
data = ReflectionData().load_mtz('data.mtz')
E = data.compute_e_values(n_shells=30)
print(f"E-values: mean={E.mean():.3f}, std={E.std():.3f}")
Compute E-values without anisotropy correction:
E = data.compute_e_values(apply_anisotropy=False)
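A simplified sketch of shell normalization, E = F / sqrt(<F²>), using shells of uniform width in 1/d as get_radial_shells() describes. It omits the anisotropy correction and the statistical-weight (epsilon) factor that full E-value calculations normally include; normalize_e_values is a hypothetical name. By construction, the sum of E² over each shell equals the number of reflections in that shell, so <E²> = 1 per shell.

```python
import math


def normalize_e_values(F, d, n_shells=20):
    """E = F / sqrt(<F^2>_shell) with shells of equal width in 1/d."""
    s = [1.0 / x for x in d]
    s_min, s_max = min(s), max(s)
    width = (s_max - s_min) / n_shells or 1.0  # guard against a single resolution
    shell = [min(int((x - s_min) / width), n_shells - 1) for x in s]
    sum_f2 = [0.0] * n_shells
    count = [0] * n_shells
    for f, sh in zip(F, shell):
        sum_f2[sh] += f * f
        count[sh] += 1
    mean_f2 = [sf / c if c else 0.0 for sf, c in zip(sum_f2, count)]
    return [f / math.sqrt(mean_f2[sh]) if mean_f2[sh] > 0 else 0.0
            for f, sh in zip(F, shell)]


E = normalize_e_values([1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
                       [3.0, 2.9, 2.5, 2.4, 2.0, 1.9], n_shells=2)
print(sum(e * e for e in E))  # equals the number of reflections
```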
- get_corrected_data()[source]
Get scaled amplitude and uncertainty tensors.
Applies the current scale factor (from self.log_scale) to F and F_sigma.
- Returns:
Scaled F and F_sigma tensors. If log_scale is None, returns original.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- parameters()[source]
Get list of learnable parameters for optimization.
- Returns:
Iterator over the torch.nn.Parameter objects to optimize. Includes log_scale and U_aniso if set.
- Return type:
Iterator[torch.nn.Parameter]
- __init__(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, _centric=None, _n_bins=None, _FrenchWilson=None, source=None, dataset=None, last_op=None, reader=None)
- class torchref.io.datasets.FcalcDataset(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, fcalc=None, fcalc_amp=None, fcalc_phase=None)[source]
Bases:
CrystalDataset
Dataset for storing calculated structure factors.
Provides a lightweight container for Fcalc values with:
- Cell and spacegroup information (using torchref.symmetry types)
- HKL indices and resolution
- Complex Fcalc with amplitude/phase decomposition
- MTZ export capability
This class inherits from CrystalDataset and overrides the spacegroup field to store torchref.symmetry.SpaceGroup instead of gemmi.SpaceGroup.
- Parameters:
hkl (torch.Tensor, optional) – Miller indices of shape (N, 3).
resolution (torch.Tensor, optional) – Resolution per reflection of shape (N,).
cell (Cell, optional) – Unit cell object.
spacegroup (SpaceGroup, optional) – Space group object (torchref.symmetry.SpaceGroup).
fcalc (torch.Tensor, optional) – Complex structure factors of shape (N,).
fcalc_amp (torch.Tensor, optional) – Amplitudes |Fcalc| of shape (N,).
fcalc_phase (torch.Tensor, optional) – Phases in radians of shape (N,).
device (torch.device) – Device for tensors.
Examples
Create from cell and resolution:
import torch
from torchref.io.datasets import FcalcDataset
dataset = FcalcDataset.from_cell_and_resolution(
    cell=[50.0, 60.0, 70.0, 90.0, 90.0, 90.0],
    spacegroup='P212121',
    d_min=2.0,
)
# Set Fcalc values (complex tensor)
fcalc = torch.randn(len(dataset), dtype=torch.complex64)
dataset.set_fcalc(fcalc)
# Write to MTZ
dataset.write_mtz('output.mtz')
- spacegroup: SpaceGroup | None = None
- static from_cell_and_resolution(cell, spacegroup, d_min=2.0, d_max=None, device=device(type='cpu'), dtype=torch.float32)[source]
Create FcalcDataset with HKL generated to given resolution.
- Parameters:
cell (torch.Tensor, list, or Cell) – Unit cell [a, b, c, alpha, beta, gamma] or Cell object.
spacegroup (SpaceGroupLike) – Space group (str, int, gemmi.SpaceGroup, or torchref.symmetry.SpaceGroup).
d_min (float, optional) – High resolution limit in Angstroms. Default is 2.0.
d_max (float, optional) – Low resolution limit in Angstroms. If provided, reflections with d-spacing > d_max are removed.
device (torch.device) – Target device.
dtype (torch.dtype) – Float dtype for tensors.
- Returns:
New dataset with HKL and resolution populated.
- Return type:
FcalcDataset
Examples
from torchref.io.datasets import FcalcDataset
from torchref.symmetry import Cell, SpaceGroup
cell = Cell([50.0, 60.0, 70.0, 90.0, 90.0, 90.0])
sg = SpaceGroup('P212121')
dataset = FcalcDataset.from_cell_and_resolution(
    cell=cell,
    spacegroup=sg,
    d_min=2.0,
)
print(f"Generated {len(dataset)} reflections")
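Generating HKL "to a given resolution" rests on the relation 1/d² = hᵀG*h, where G* is the reciprocal metric tensor (the inverse of the direct-cell metric tensor G). The sketch below enumerates every index inside the resolution sphere for an arbitrary cell. It is a hypothetical helper (generate_hkl), not torchref's code, and it ignores space-group symmetry: it returns the full sphere, including Friedel mates and systematic absences, whereas from_cell_and_resolution() also receives a space group and presumably uses it to restrict the set.

```python
import math

import torch


def generate_hkl(cell, d_min, d_max=None):
    """Hypothetical sketch: all (h, k, l) != 0 with d_min <= d (<= d_max)."""
    a, b, c, alpha, beta, gamma = cell
    ca, cb, cg = (math.cos(math.radians(x)) for x in (alpha, beta, gamma))
    # Direct-space metric tensor G; its inverse is the reciprocal metric G*.
    G = torch.tensor(
        [[a * a, a * b * cg, a * c * cb],
         [a * b * cg, b * b, b * c * ca],
         [a * c * cb, b * c * ca, c * c]],
        dtype=torch.float64,
    )
    G_star = torch.linalg.inv(G)
    # |h| <= a / d_min (likewise for k and l), so this box holds every candidate.
    hmax, kmax, lmax = (math.floor(x / d_min) for x in (a, b, c))
    grid = torch.cartesian_prod(
        torch.arange(-hmax, hmax + 1),
        torch.arange(-kmax, kmax + 1),
        torch.arange(-lmax, lmax + 1),
    )
    hkl = grid[(grid != 0).any(dim=1)]  # drop (0, 0, 0)
    h = hkl.to(torch.float64)
    inv_d2 = torch.einsum("ni,ij,nj->n", h, G_star, h)  # 1/d^2 = h^T G* h
    d = inv_d2.rsqrt()
    keep = d >= d_min
    if d_max is not None:
        keep &= (d <= d_max)  # drop low-resolution reflections
    return hkl[keep], d[keep]
```

For a 10 Å cubic cell and d_min = 4.9 Å this keeps the 32 non-zero indices with h² + k² + l² ≤ 4.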
- set_fcalc(fcalc)[source]
Assign complex Fcalc values.
Automatically computes amplitude and phase from complex values.
- Parameters:
fcalc (torch.Tensor) – Complex structure factors with shape (N,).
- Raises:
ValueError – If fcalc length doesn’t match HKL length.
Examples
import torch
# Create complex Fcalc values
fcalc = torch.randn(len(dataset), dtype=torch.complex64)
dataset.set_fcalc(fcalc)
print(dataset.fcalc_amp[:5])    # Amplitudes
print(dataset.fcalc_phase[:5])  # Phases in radians
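The amplitude/phase decomposition that set_fcalc() performs corresponds to torch.abs and torch.angle on the complex tensor, with phases in radians in (−π, π]. The sketch uses a hypothetical helper, split_fcalc, and mirrors the length check described under Raises; it is an illustration, not torchref's code.

```python
import torch


def split_fcalc(fcalc, n_hkl):
    """Hypothetical helper: |F| and phase (radians) from complex Fcalc."""
    if fcalc.shape[0] != n_hkl:
        raise ValueError(
            f"fcalc has {fcalc.shape[0]} values but there are {n_hkl} reflections"
        )
    amplitude = torch.abs(fcalc)  # |F|
    phase = torch.angle(fcalc)    # atan2(imag, real), in radians
    return amplitude, phase
```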
- write_mtz(filepath)[source]
Write Fcalc to MTZ file.
- Parameters:
filepath (str) – Output MTZ filename.
- Raises:
ValueError – If no Fcalc values have been set.
Examples
dataset.set_fcalc(fcalc_values)
dataset.write_mtz('calculated.mtz')
- write_mtz_as_fobs(filepath, sigma_frac=0.05, f_column='F-obs', sigf_column='SIGF-obs', phase_column='PHIF-model')[source]
Write Fcalc to MTZ as if it were observed data (F-obs columns).
Useful for creating simulated “experimental” MTZ files that can be read back by ReflectionData.load_mtz() as observed amplitudes.
- Parameters:
filepath (str) – Output MTZ filename.
sigma_frac (float, optional) – Sigma as a fraction of |F|. Default is 0.05 (5%).
f_column (str, optional) – Column name for amplitudes. Default is ‘F-obs’.
sigf_column (str, optional) – Column name for sigma. Default is ‘SIGF-obs’.
phase_column (str, optional) – Column name for model phases. Default is ‘PHIF-model’.
Examples
dataset.set_fcalc(fcalc_values)
dataset.write_mtz_as_fobs('simulated_obs.mtz', sigma_frac=0.05)
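Per the parameter description, write_mtz_as_fobs() derives the sigma column as sigma_frac × |F|. The sketch below only assembles the column arrays such a file would hold, keyed by the documented default column names. The helper fobs_columns is hypothetical, and converting phases from radians to degrees reflects the usual MTZ convention for phase columns, not a confirmed detail of torchref.

```python
import torch


def fobs_columns(hkl, fcalc, sigma_frac=0.05, f_column="F-obs",
                 sigf_column="SIGF-obs", phase_column="PHIF-model"):
    """Hypothetical sketch: column data for a simulated 'observed' MTZ."""
    amp = fcalc.abs()
    return {
        "H": hkl[:, 0],
        "K": hkl[:, 1],
        "L": hkl[:, 2],
        f_column: amp,
        sigf_column: sigma_frac * amp,  # constant fractional error model
        phase_column: torch.rad2deg(fcalc.angle()),  # MTZ phases in degrees
    }
```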
- property spacegroup_hm: str | None
Get space group Hermann-Mauguin name with spaces (e.g., ‘P 21 21 21’).
- __init__(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, fcalc=None, fcalc_amp=None, fcalc_phase=None)
- class torchref.io.datasets.DatasetCollection(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, _datasets=<factory>, _dataset_order=<factory>, _reference_dataset=None, _common_hkl=None, _cell=None, _spacegroup=None, _resolution=None, _scale_factors=<factory>)[source]
Bases:
CrystalDataset
Container for multiple related crystal datasets.
All datasets share a common HKL set for efficient computation. Datasets are aligned to the reference dataset (by default the first one added), and reflections missing from a dataset are masked out.
- Parameters:
- hkl
Common HKL set for all datasets.
- Type:
torch.Tensor
Examples
from torchref.io import DatasetCollection, ReflectionData
collection = DatasetCollection(device='cuda')
native = ReflectionData().load_mtz('native.mtz')
derivative = ReflectionData().load_mtz('derivative.mtz')
collection.add_dataset('native', native, set_as_reference=True)
collection.add_dataset('derivative', derivative)
for name, dataset in collection:
    print(f"{name}: {len(dataset)} reflections")
# Access by name
native_F = collection['native'].F
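Aligning a dataset to a reference HKL set and masking missing reflections can be illustrated by encoding each Miller index as a single integer key and matching keys with torch.searchsorted. align_to_reference is a hypothetical helper, and filling missing reflections with NaN is an assumption made for this sketch, not torchref's documented behavior.

```python
import torch


def align_to_reference(ref_hkl, hkl, values):
    """Hypothetical sketch: reorder `values` onto `ref_hkl`.

    Returns (aligned, found): aligned has one entry per reference reflection
    (NaN where the dataset lacks that index); found is the boolean mask.
    """
    offset = int(max(ref_hkl.abs().max().item(), hkl.abs().max().item())) + 1
    base = 2 * offset + 1

    def encode(x):
        # Shift indices to be non-negative, then pack (h, k, l) into one int.
        x = x.to(torch.int64) + offset
        return (x[:, 0] * base + x[:, 1]) * base + x[:, 2]

    ref_keys = encode(ref_hkl)
    sorted_keys, order = torch.sort(encode(hkl))
    # Position where each reference key would sit in the sorted dataset keys.
    pos = torch.searchsorted(sorted_keys, ref_keys).clamp(max=len(sorted_keys) - 1)
    found = sorted_keys[pos] == ref_keys
    aligned = torch.full((len(ref_hkl),), float("nan"), dtype=values.dtype)
    aligned[found] = values[order[pos[found]]]
    return aligned, found
```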
- add_dataset(name, dataset, set_as_reference=False)[source]
Add a dataset to the collection.
- Parameters:
name (str) – Identifier for this dataset.
dataset (ReflectionData) – The dataset to add.
set_as_reference (bool, optional) – If True, use this dataset’s HKL as the reference. Default is False, but the first dataset added automatically becomes the reference.
- Returns:
Self, for method chaining.
- Return type:
DatasetCollection
- Raises:
ValueError – If a dataset with the same name already exists.
Examples
collection = DatasetCollection()
collection.add_dataset('native', native_data, set_as_reference=True)
collection.add_dataset('derivative', derivative_data)
- property datasets: Dict[str, ReflectionData]
Access all datasets as a dictionary.
- __getitem__(name)[source]
Get dataset by name.
- __iter__()[source]
Iterate over (name, dataset) pairs in order of addition.
- Yields:
tuple of (str, ReflectionData) – Name and dataset for each dataset in collection.
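The bookkeeping described for add_dataset(), __getitem__() and __iter__() (unique names, the first-added dataset as default reference, iteration in order of addition) maps naturally onto an ordinary dict, since Python dicts preserve insertion order. MiniCollection is a hypothetical stand-in, not torchref's DatasetCollection, and it omits HKL alignment, scaling and device handling.

```python
class MiniCollection:
    """Hypothetical stand-in for the naming/reference logic only."""

    def __init__(self):
        self._datasets = {}   # insertion-ordered name -> dataset
        self.reference = None

    def add_dataset(self, name, dataset, set_as_reference=False):
        if name in self._datasets:
            raise ValueError(f"Dataset '{name}' already exists")
        self._datasets[name] = dataset
        # The first dataset becomes the reference unless one is chosen explicitly.
        if set_as_reference or self.reference is None:
            self.reference = name
        return self  # allow method chaining

    def __getitem__(self, name):
        return self._datasets[name]

    def __iter__(self):
        # (name, dataset) pairs in order of addition.
        return iter(self._datasets.items())

    def __len__(self):
        return len(self._datasets)
```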
- scale()[source]
Scale all datasets to a common reference scale.
This method optimizes the scaling parameters of every non-reference dataset to minimize the mean squared error between its structure factors and those of the reference dataset, correcting both the overall scale difference and anisotropy. The optimization uses L-BFGS with a strong Wolfe line search and refines the scaling parameters over multiple steps.
- Returns:
The collection instance, for method chaining.
- Return type:
DatasetCollection
- Raises:
ValueError – If no reference dataset has been set, or if the collection contains only the reference dataset. At least two datasets are required.
Notes
A reference dataset must exist before calling this method; the first dataset added becomes the reference unless another is added with set_as_reference=True. All datasets except the reference have their scaling parameters optimized.
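The optimization described for scale() can be sketched for one dataset: fit a log-scale and a symmetric anisotropic U (six unique components) so that k·exp(−2π² sᵀUs)·F matches the reference amplitudes in the mean-squared sense, using torch.optim.LBFGS with line_search_fn="strong_wolfe". The sketch assumes Cartesian scattering vectors s (in Å⁻¹) on an already aligned HKL set; the function name fit_scale_aniso and this exact parameterization are assumptions, not torchref's code.

```python
import math

import torch


def fit_scale_aniso(F, F_ref, s, steps=5):
    """Hypothetical sketch: fit k and U so k*exp(-2 pi^2 s^T U s)*F ~ F_ref."""
    log_k = torch.zeros((), dtype=F.dtype, requires_grad=True)
    u = torch.zeros(6, dtype=F.dtype, requires_grad=True)  # U11,U22,U33,U12,U13,U23
    opt = torch.optim.LBFGS([log_k, u], max_iter=50, line_search_fn="strong_wolfe")

    def model():
        # Assemble the symmetric 3x3 U from its six unique components.
        U = torch.stack([
            torch.stack([u[0], u[3], u[4]]),
            torch.stack([u[3], u[1], u[5]]),
            torch.stack([u[4], u[5], u[2]]),
        ])
        quad = torch.einsum("ni,ij,nj->n", s, U, s)  # s^T U s per reflection
        return torch.exp(log_k - 2 * math.pi**2 * quad) * F

    def closure():
        # LBFGS re-evaluates the loss several times per step via this closure.
        opt.zero_grad()
        loss = torch.mean((model() - F_ref) ** 2)
        loss.backward()
        return loss

    for _ in range(steps):
        opt.step(closure)
    with torch.no_grad():
        return model(), log_k.detach(), u.detach()
```

Working with log_k rather than k keeps the fitted scale positive without constraints, and L-BFGS suits this small, smooth, seven-parameter problem.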
- __init__(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, _datasets=<factory>, _dataset_order=<factory>, _reference_dataset=None, _common_hkl=None, _cell=None, _spacegroup=None, _resolution=None, _scale_factors=<factory>)
Submodules
- torchref.io.datasets.base module
CrystalDataset, CrystalDataset.hkl, CrystalDataset.F, CrystalDataset.F_sigma, CrystalDataset.I, CrystalDataset.I_sigma, CrystalDataset.rfree_flags, CrystalDataset.resolution, CrystalDataset.bin_indices, CrystalDataset.outlier_flags, CrystalDataset.phase, CrystalDataset.fom, CrystalDataset.E, CrystalDataset.E_squared, CrystalDataset.F_squared_corrected, CrystalDataset.U_aniso, CrystalDataset.radial_shell_indices, CrystalDataset.cell, CrystalDataset.spacegroup, CrystalDataset.device, CrystalDataset.verbose, CrystalDataset.rfree_source, CrystalDataset.amplitude_source, CrystalDataset.intensity_source, CrystalDataset.phase_source, CrystalDataset.wilson_b, CrystalDataset.wilson_b_structure, CrystalDataset.wilson_b_solvent, CrystalDataset.wilson_k_sol, CrystalDataset.outlier_detection_params, CrystalDataset.__post_init__(), CrystalDataset.save_state(), CrystalDataset.load_state(), CrystalDataset.__len__(), CrystalDataset.__repr__(), CrystalDataset.spacegroup_name, CrystalDataset.spacegroup_hm, CrystalDataset.spacegroup_number, CrystalDataset.__init__()
- torchref.io.datasets.collection module
DatasetCollection, DatasetCollection.hkl, DatasetCollection.n_datasets, DatasetCollection.add_dataset(), DatasetCollection.datasets, DatasetCollection.reference_dataset, DatasetCollection.spacegroup, DatasetCollection.__getitem__(), DatasetCollection.__iter__(), DatasetCollection.__len__(), DatasetCollection.__contains__(), DatasetCollection.__call__(), DatasetCollection.scale(), DatasetCollection.keys(), DatasetCollection.values(), DatasetCollection.items(), DatasetCollection.__init__(), DatasetCollection.get(), DatasetCollection.__repr__()
- torchref.io.datasets.fcalc_data module
FcalcDataset, FcalcDataset.spacegroup, FcalcDataset.fcalc, FcalcDataset.fcalc_amp, FcalcDataset.fcalc_phase, FcalcDataset.from_cell_and_resolution(), FcalcDataset.set_fcalc(), FcalcDataset.write_mtz(), FcalcDataset.write_mtz_as_fobs(), FcalcDataset.__repr__(), FcalcDataset.spacegroup_name, FcalcDataset.spacegroup_hm, FcalcDataset.spacegroup_number, FcalcDataset.__init__()
- torchref.io.datasets.reflection_data module
ReflectionData, ReflectionData.hkl, ReflectionData.F, ReflectionData.F_sigma, ReflectionData.I, ReflectionData.I_sigma, ReflectionData.rfree_flags, ReflectionData.cell, ReflectionData.spacegroup, ReflectionData.resolution, ReflectionData.wilson_b, ReflectionData.source, ReflectionData.dataset, ReflectionData.last_op, ReflectionData.reader, ReflectionData.__post_init__(), ReflectionData.load(), ReflectionData.from_tensors(), ReflectionData.load_mtz(), ReflectionData.load_cif(), ReflectionData.list_cif_data_blocks(), ReflectionData.get_bins(), ReflectionData.mean_res_per_bin(), ReflectionData.mean_F_per_bin(), ReflectionData.mean_sigma_per_bin(), ReflectionData.regenerate_rfree_flags(), ReflectionData.get_structure_factors(), ReflectionData.get_structure_factors_with_sigma(), ReflectionData.get_hkl(), ReflectionData.filter_by_resolution(), ReflectionData.get_mask(), ReflectionData.cut_res(), ReflectionData.get_rfree_masks(), ReflectionData.get_work_set(), ReflectionData.get_test_set(), ReflectionData.get_max_res(), ReflectionData.get_min_res(), ReflectionData.__len__(), ReflectionData.d_min, ReflectionData.__repr__(), ReflectionData.get_valid_mask(), ReflectionData.data_indexed(), ReflectionData.__call__(), ReflectionData.data_fill_masked(), ReflectionData.__getitem__(), ReflectionData.__select__(), ReflectionData.sanitize_F(), ReflectionData.check_all_data_types(), ReflectionData.validate_hkl(), ReflectionData.find_outliers(), ReflectionData.get_log_ratio(), ReflectionData.get_outlier_statistics(), ReflectionData.unpack_one(), ReflectionData.flag_suspicious_sigma(), ReflectionData.dump(), ReflectionData.write_mtz(), ReflectionData.centric, ReflectionData.calc_patterson(), ReflectionData.possible_hkl(), ReflectionData.remap(), ReflectionData.fill(), ReflectionData.expand_to_p1(), ReflectionData.reduce_to_spacegroup(), ReflectionData.canonicalize(), ReflectionData.get_scattering_vectors(), ReflectionData.get_radial_shells(), ReflectionData.fit_anisotropy(), ReflectionData.setup_anisotropy(), ReflectionData.apply_anisotropy_correction(), ReflectionData.compute_e_values(), ReflectionData.setup_scale(), ReflectionData.get_corrected_data(), ReflectionData.parameters(), ReflectionData.__init__()