torchref.io.datasets package

Crystallographic dataset classes.

This module provides PyTorch-based dataset classes for handling crystallographic data:

  • CrystalDataset: Abstract base class

  • ReflectionData: Single crystal reflection dataset

  • FcalcDataset: Dataset for calculated structure factors

  • DatasetCollection: Container for multiple related datasets

Examples

from torchref.io.datasets import ReflectionData
data = ReflectionData(device='cuda')
data.load_mtz('observed.mtz')
print(f"Loaded {len(data)} reflections")

from torchref.io.datasets import FcalcDataset
fcalc = FcalcDataset.from_cell_and_resolution(
    cell=[50.0, 60.0, 70.0, 90.0, 90.0, 90.0],
    spacegroup='P212121',
    d_min=2.0,
)

from torchref.io.datasets import DatasetCollection
collection = DatasetCollection()
collection.add_dataset('native', native_data)
collection.add_dataset('derivative', derivative_data)
class torchref.io.datasets.CrystalDataset(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None)[source]

Bases: DeviceMixin

Base dataclass for crystallographic datasets.

Defines all possible tensor fields (optional) and handles device management and serialization. Subclasses add domain-specific methods.

This lightweight design enables scaling to 1000s of datasets without the overhead of torch.nn.Module.

Parameters:
  • device (torch.device) – Device for tensors (‘cpu’, ‘cuda’, etc.). Defaults to the configured device.current.

  • verbose (int) – Verbosity level (0=silent, 1=normal, 2=debug). Default is 1.

Examples

Basic usage:

data = CrystalDataset(device='cuda')
data.hkl = torch.tensor([[1, 0, 0], [0, 1, 0]])
data.cpu()  # Move all tensors to CPU
hkl: Tensor | None = None
F: Tensor | None = None
F_sigma: Tensor | None = None
I: Tensor | None = None
I_sigma: Tensor | None = None
rfree_flags: Tensor | None = None
resolution: Tensor | None = None
bin_indices: Tensor | None = None
outlier_flags: Tensor | None = None
phase: Tensor | None = None
fom: Tensor | None = None
E: Tensor | None = None
E_squared: Tensor | None = None
F_squared_corrected: Tensor | None = None
U_aniso: Tensor | None = None
radial_shell_indices: Tensor | None = None
cell: Cell | None = None
spacegroup: str | None = None
device: device
verbose: int = 1
rfree_source: str | None = None
amplitude_source: str | None = None
intensity_source: str | None = None
phase_source: str | None = None
wilson_b: float | None = None
wilson_b_structure: float | None = None
wilson_b_solvent: float | None = None
wilson_k_sol: float | None = None
outlier_detection_params: Dict[str, Any] | None = None
__post_init__()[source]

Initialize non-field attributes after dataclass init.

save_state(path)[source]

Save dataset state to file.

Parameters:

path (str) – Output file path.

Examples

Save to file:

data.save_state('reflection_data.pt')
classmethod load_state(path, device=device(type='cpu'))[source]

Load dataset state from file.

Parameters:
  • path (str) – Input file path.

  • device (str) – Device to load tensors onto.

Returns:

Loaded dataset.

Return type:

CrystalDataset

Examples

Load from file:

data = ReflectionData.load_state('reflection_data.pt', device='cuda')
__len__()[source]

Return number of reflections in dataset.

__repr__()[source]

String representation of dataset.

property spacegroup_name: str | None

Get space group name as string (short form, e.g., ‘P212121’).

property spacegroup_hm: str | None

Get space group Hermann-Mauguin name with spaces (e.g., ‘P 21 21 21’).

property spacegroup_number: int | None

Get space group number (1-230).

__init__(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None)
class torchref.io.datasets.ReflectionData(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, _centric=None, _n_bins=None, _FrenchWilson=None, source=None, dataset=None, last_op=None, reader=None)[source]

Bases: CrystalDataset, DebugMixin

Container for crystallographic reflection data.

This class handles loading, processing, and accessing reflection data including Miller indices, structure factor amplitudes, intensities, and R-free flags. All data is stored as PyTorch tensors for GPU acceleration.

Parameters:
  • verbose (int, optional) – Verbosity level for logging (0=silent, 1=normal, 2=debug). Default is 1.

  • device (str, optional) – Device to store tensors on (‘cpu’, ‘cuda’, ‘cuda:0’, etc.). Defaults to the configured device.current.

hkl

Miller indices of shape (N, 3), dtype int32.

Type:

torch.Tensor

F

Structure factor amplitudes of shape (N,), dtype float32.

Type:

torch.Tensor

F_sigma

Amplitude uncertainties of shape (N,), dtype float32.

Type:

torch.Tensor

I

Intensities of shape (N,), dtype float32.

Type:

torch.Tensor

I_sigma

Intensity uncertainties of shape (N,), dtype float32.

Type:

torch.Tensor

rfree_flags

R-free test set flags of shape (N,), dtype bool.

Type:

torch.Tensor

cell

Unit cell parameters [a, b, c, alpha, beta, gamma].

Type:

torch.Tensor

spacegroup

Space group symbol.

Type:

str

resolution

Resolution per reflection in Ångströms of shape (N,).

Type:

torch.Tensor

wilson_b

Overall Wilson B-factor in Ų.

Type:

float

Examples

Load reflection data from an MTZ file:

data = ReflectionData(verbose=1, device='cuda')
data.load_mtz('data.mtz')
print(f"Loaded {len(data.hkl)} reflections")
print(f"Resolution range: {data.resolution.min():.2f} - {data.resolution.max():.2f} Å")
source: ReflectionData | None = None
dataset: DataFrame | None = None
last_op: str | None = None
reader: Any | None = None
__post_init__()[source]

Initialize non-dataclass attributes after dataclass init.

This is called automatically after the dataclass __init__.

load(reader)[source]

Load reflection data using a data reader.

Parameters:

reader (callable) – Data reader object that returns (data_dict, cell, spacegroup) when called. Can be MTZ, ReflectionCIFReader, or other compatible reader.

Returns:

Self, for method chaining.

Return type:

ReflectionData

Raises:

ValueError – If unit cell parameters are missing or no amplitude/intensity data found.

classmethod from_tensors(hkl, F, F_sigma, cell, spacegroup, rfree_flags=None, device=device(type='cpu'), verbose=1)[source]

Construct ReflectionData directly from tensors.

Parameters:
  • hkl (torch.Tensor) – Miller indices of shape (N, 3).

  • F (torch.Tensor) – Structure factor amplitudes of shape (N,).

  • F_sigma (torch.Tensor) – Amplitude uncertainties of shape (N,).

  • cell (Cell) – Unit cell parameters.

  • spacegroup (SpaceGroup) – Space group.

  • rfree_flags (torch.Tensor, optional) – R-free flags of shape (N,), dtype bool. If None, flags are generated automatically (2% free fraction).

  • device (str, optional) – Device for tensors. Defaults to the configured device.current.

  • verbose (int, optional) – Verbosity level. Default is 1.

Returns:

Fully initialized reflection data with all cleanup applied.

Return type:

ReflectionData

load_mtz(path, column_names=None)[source]

Load reflection data from MTZ file.

Parameters:
  • path (str) – Path to MTZ file.

  • column_names (dict, optional) – Explicit column name mapping to override automatic detection. Supported keys: "F", "SIGF", "I", "SIGI". Example: {"F": "DFo", "SIGF": "sig_DFo"}.

Returns:

Self, for method chaining.

Return type:

ReflectionData

load_cif(path, data_block=None)[source]

Load reflection data from CIF file.

Parameters:
  • path (str) – Path to CIF file.

  • data_block (str, optional) – Specific data block name to read (e.g., ‘r1vlmsf’). If None, reads the first data block. Useful for multi-dataset CIF files.

Returns:

Self, for method chaining.

Return type:

ReflectionData

static list_cif_data_blocks(path)[source]

List all data blocks available in a CIF file without loading data.

Useful for multi-dataset CIF files to inspect available blocks before loading a specific one.

Parameters:

path (str) – Path to CIF file.

Returns:

Names of all data blocks in the CIF file.

Return type:

list of str

Examples

List and load a specific data block:

blocks = ReflectionData.list_cif_data_blocks('1VLM-sf.cif')
print(blocks)  # ['r1vlmsf', 'r1vlmAsf', 'r1vlmBsf', ...]
data = ReflectionData().load_cif('1VLM-sf.cif', data_block=blocks[1])
get_bins(n_bins=20, min_per_bin=100)[source]

Create resolution bins with approximately equal reflection counts.

Parameters:
  • n_bins (int, optional) – Target number of resolution bins. Default is 20.

  • min_per_bin (int, optional) – Minimum reflections per bin. Default is 100.

Returns:

  • bin_indices (torch.Tensor) – Tensor of shape (N,) with bin index for each reflection.

  • n_bins (int) – Actual number of bins created (may be less than target for small datasets).

Return type:

Tuple[Tensor, int]
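
The idea of equal-count binning can be pictured with a minimal pure-Python sketch. This is not torchref's implementation: the helper name equal_count_bins and the rule used to shrink the bin count for small datasets are assumptions made for illustration only.

```python
def equal_count_bins(resolution, n_bins=20, min_per_bin=100):
    """Assign each reflection to a resolution bin of roughly equal size.

    Hypothetical helper; the fallback rule for small datasets is an assumption.
    """
    n = len(resolution)
    # Reduce the bin count until every bin can hold at least min_per_bin reflections.
    n_bins = max(1, min(n_bins, n // min_per_bin))
    # Sort reflection indices from low resolution (large d) to high resolution (small d).
    order = sorted(range(n), key=lambda i: resolution[i], reverse=True)
    bin_indices = [0] * n
    for rank, i in enumerate(order):
        # Consecutive ranks are spread evenly over the bins.
        bin_indices[i] = rank * n_bins // n
    return bin_indices, n_bins


# 1000 synthetic d-spacings starting at 1.5 Å.
d_spacings = [1.5 + 0.0485 * i for i in range(1000)]
bins, n_used = equal_count_bins(d_spacings, n_bins=20, min_per_bin=100)
counts = [bins.count(b) for b in range(n_used)]
print(n_used, counts)
```

With 1000 reflections and min_per_bin=100, the sketch falls back from the requested 20 bins to 10 bins of 100 reflections each, which mirrors the "may be less than target" behaviour described above.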

mean_res_per_bin()[source]

Calculate mean resolution for each bin.

Returns:

Mean resolution for each bin in Ångströms.

Return type:

torch.Tensor

Raises:

ValueError – If bins have not been created yet.

mean_F_per_bin()[source]

Calculate mean structure factor amplitude per resolution bin.

Returns:

Mean F per bin of shape (n_bins,).

Return type:

torch.Tensor

Raises:

ValueError – If bins have not been created yet.

mean_sigma_per_bin()[source]

Calculate mean structure factor uncertainty per resolution bin.

Returns:

Mean sigma_F per bin of shape (n_bins,), or None if no uncertainties.

Return type:

torch.Tensor or None

Raises:

ValueError – If bins have not been created yet.

regenerate_rfree_flags(free_fraction=0.02, n_bins=20, min_per_bin=100, seed=None, force=False)[source]

Regenerate R-free flags with resolution-stratified sampling.

Parameters:
  • free_fraction (float, optional) – Fraction of reflections to mark as free. Default is 0.02 (2%).

  • n_bins (int, optional) – Target number of resolution bins. Default is 20.

  • min_per_bin (int, optional) – Minimum reflections per resolution bin. Default is 100.

  • seed (int, optional) – Random seed for reproducibility. Default is None.

  • force (bool, optional) – If True, regenerate even if flags already exist. Default is False.

Examples

Generate R-free flags:

# Generate 2% free reflections with reproducible seed
data.regenerate_rfree_flags(free_fraction=0.02, n_bins=20, seed=42)
# Generate 5% free with 10 bins, overwriting existing
data.regenerate_rfree_flags(free_fraction=0.05, n_bins=10, force=True)
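
Resolution-stratified sampling means that the free set is drawn separately inside each resolution bin, so every bin contributes about the same fraction. The following standard-library sketch illustrates the principle; stratified_free_flags is a hypothetical helper, and the "True means free" convention is specific to this sketch rather than a statement about how torchref stores its flags.

```python
import random


def stratified_free_flags(bin_indices, n_bins, free_fraction=0.02, seed=None):
    """Mark roughly free_fraction of the reflections in every bin as free.

    Hypothetical helper: True means "free (test) set" in this sketch only.
    """
    rng = random.Random(seed)
    flags = [False] * len(bin_indices)
    for b in range(n_bins):
        members = [i for i, bi in enumerate(bin_indices) if bi == b]
        # Pick at least one free reflection from each non-empty bin.
        n_free = max(1, round(free_fraction * len(members))) if members else 0
        for i in rng.sample(members, n_free):
            flags[i] = True
    return flags


# Ten bins of 100 reflections each.
bin_indices = [i // 100 for i in range(1000)]
flags = stratified_free_flags(bin_indices, n_bins=10, free_fraction=0.02, seed=42)
print(sum(flags))
```

Passing the same seed reproduces the same selection, which is the reason a seed parameter is useful when test sets must stay fixed between runs.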
get_structure_factors(as_complex=False)[source]

Get structure factors, optionally as complex numbers.

Parameters:

as_complex (bool, optional) – If True and phases available, return F*exp(i*phi). Default is False.

Returns:

Structure factor amplitudes or complex structure factors.

Return type:

torch.Tensor

Raises:

ValueError – If no amplitude data is loaded.
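
As a rough illustration of what F*exp(i*phi) means, the sketch below combines amplitudes and phases with the standard library. It assumes phases in degrees, which is an assumption of this example and not something the description above specifies; to_complex is a hypothetical helper.

```python
import cmath
import math


def to_complex(amplitudes, phases_deg):
    """Combine amplitudes and phases into complex structure factors F*exp(i*phi)."""
    return [cmath.rect(f, math.radians(p)) for f, p in zip(amplitudes, phases_deg)]


F = [10.0, 5.0, 2.0]
phi = [0.0, 90.0, 180.0]  # assumed to be in degrees for this sketch
F_complex = to_complex(F, phi)
print([abs(z) for z in F_complex])
```

The modulus of each complex value recovers the original amplitude, so no information is lost by the conversion.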

get_structure_factors_with_sigma()[source]

Get structure factor amplitudes and their uncertainties.

Returns:

  • F (torch.Tensor) – Structure factor amplitudes of shape (N,).

  • F_sigma (torch.Tensor or None) – Uncertainties of shape (N,), or None if not available.

Raises:

ValueError – If no amplitude data is loaded.

Return type:

Tuple[Tensor, Tensor | None]

Examples

Get amplitudes with uncertainties:

F, sigma_F = data.get_structure_factors_with_sigma()
if sigma_F is not None:
    weighted_residual = (F - F_calc) / sigma_F  # F_calc: calculated amplitudes, computed elsewhere
get_hkl()[source]

Return Miller indices for valid reflections.

Returns:

Miller indices of shape (N, 3), dtype int32.

Return type:

torch.Tensor

Raises:

ValueError – If no Miller indices are loaded.

filter_by_resolution(d_min=None, d_max=None)[source]

Filter reflections by resolution range.

Adds a boolean mask to self.masks for the specified resolution range.

Parameters:
  • d_min (float, optional) – Minimum resolution / high resolution cutoff (e.g., 1.5 Å).

  • d_max (float, optional) – Maximum resolution / low resolution cutoff (e.g., 50.0 Å).

Returns:

Self, for method chaining.

Return type:

ReflectionData

get_mask()[source]

Return combined mask from all active filters.

Returns:

Boolean mask combining all filter conditions.

Return type:

torch.Tensor

cut_res(highres=None, lowres=None)[source]

Filter reflections by resolution range.

Alias for filter_by_resolution with more intuitive naming.

Parameters:
  • highres (float, optional) – High resolution cutoff (small d-spacing, e.g., 1.5 Å). Keeps reflections with d >= highres.

  • lowres (float, optional) – Low resolution cutoff (large d-spacing, e.g., 50.0 Å). Keeps reflections with d <= lowres.

Returns:

Self, for method chaining.

Return type:

ReflectionData

Examples

Filter by resolution:

# Keep reflections between 50 Å and 1.5 Å
filtered = data.cut_res(highres=1.5, lowres=50.0)
# Keep only high-resolution data (between 1.0 Å and 2.0 Å)
high_res = data.cut_res(highres=1.0, lowres=2.0)
get_rfree_masks()[source]

Get boolean masks for work and test (free) sets.

Returns:

  • work_mask (torch.Tensor or None) – Boolean tensor for work set (flag != 0).

  • test_mask (torch.Tensor or None) – Boolean tensor for test/free set (flag == 0). Both are None if no R-free flags are available.

Return type:

Tuple[Tensor | None, Tensor | None]

Examples

Separate work and test sets:

work_mask, test_mask = data.get_rfree_masks()
if work_mask is not None:
    F_work = data.F[work_mask]
    F_test = data.F[test_mask]
get_work_set()[source]

Get structure factors for the work set (R-free flag != 0).

Returns:

  • F_work (torch.Tensor) – Structure factors for work set.

  • sigma_work (torch.Tensor or None) – Uncertainties for work set, or None if not available.

Return type:

Tuple[Tensor, Tensor | None]

Notes

Returns the full dataset, with a warning, if no R-free flags are available.

get_test_set()[source]

Get structure factors for the test set (R-free flag == 0).

Returns:

  • F_test (torch.Tensor) – Structure factors for test/free set.

  • sigma_test (torch.Tensor or None) – Uncertainties for test set, or None if not available.

Raises:

ValueError – If no R-free flags are available.

Return type:

Tuple[Tensor, Tensor | None]

get_max_res()[source]

Return maximum resolution (lowest d-spacing).

Returns:

Maximum resolution in Ångströms.

Return type:

float

get_min_res()[source]

Return minimum resolution (highest d-spacing).

Returns:

Minimum resolution in Ångströms.

Return type:

float

__len__()[source]

Return number of reflections.

Returns:

Number of reflections in the dataset.

Return type:

int

property d_min: float | None

Return maximum resolution (lowest d-spacing).

Returns:

Maximum resolution in Ångströms.

Return type:

float

__repr__()[source]

Return string representation.

Returns:

Summary of reflection data including count, sources, and resolution.

Return type:

str

get_valid_mask()[source]

Return the combined validity mask for all reflections.

This is the mask used to filter reflections in forward(). True indicates a valid (included) reflection, False indicates an excluded one.

Returns:

Boolean mask of shape (N,) where N is the total number of reflections. True = valid/included, False = invalid/excluded.

Return type:

torch.Tensor

Examples

Check validity:

mask = data.get_valid_mask()
print(f"{mask.sum()} of {len(mask)} reflections are valid")
data_indexed()[source]

Return reflection data as indexed (filtered) tensors.

This method filters out invalid reflections and returns smaller tensors containing only valid data. Useful for operations that don’t support MaskedTensors or for writing output files.

Returns:

  • hkl (torch.Tensor) – Miller indices of shape (M, 3) where M is number of valid reflections.

  • F (torch.Tensor) – Structure factor amplitudes of shape (M,).

  • F_sigma (torch.Tensor or None) – Uncertainties of shape (M,) or None.

  • rfree_flags (torch.Tensor or None) – R-free flags of shape (M,) or None.

Return type:

Tuple[Tensor, Tensor, Tensor | None, Tensor | None]

See also

forward

Main method returning MaskedTensors.

Examples

Get indexed data for file writing:

hkl, F, sigma, rfree = data.data_indexed()
F_np = F.cpu().numpy()  # Safe for writing to files
__call__(mask=True, scale=True)[source]

Return core reflection data with MaskedTensors for F and sigma.

F and F_sigma are returned as MaskedTensors which keep all reflections but mark invalid ones as masked. Aggregation operations (sum, mean, etc.) automatically skip masked values. HKL and rfree_flags remain regular tensors.

Parameters:
  • mask (bool, optional) – If True, apply current masks to F and sigma. Default is True.

  • scale (bool, optional) – If True, apply scaling to F and sigma before returning. Default is True.

Returns:

  • hkl (torch.Tensor) – Miller indices of shape (N, 3). Full size, unfiltered.

  • F (MaskedTensor) – Structure factor amplitudes of shape (N,) with invalid reflections masked.

  • F_sigma (MaskedTensor or None) – Uncertainties of shape (N,) with invalid reflections masked, or None.

  • rfree_flags (torch.Tensor or None) – R-free flags of shape (N,) or None. Full size, unfiltered. 1=work, 0=free.

Return type:

Tuple[Tensor, MaskedTensor, MaskedTensor, Tensor]

Notes

MaskedTensors:

  • Are PyTorch tensors with an associated boolean mask

  • Aggregations (sum, mean, etc.) skip masked values automatically

  • Element-wise operations preserve the mask

  • Use .get_data() and .get_mask() to access underlying data

  • Use .to_tensor(fill_value) to convert back to regular tensor

  • Note: MaskedTensor is in prototype stage in PyTorch

Loss functions and targets extract valid data from MaskedTensors before computation to work correctly with complex F_calc values.

Examples

Access reflection data with MaskedTensors:

hkl, F, sigma, rfree = data()
print(F.shape)  # Full shape
print(F.sum())  # Only sums valid (unmasked) values

# Access underlying data
valid_mask = F.get_mask()
F_values = F.get_data()[valid_mask]
data_fill_masked(mode='mean')[source]

Return data tensors with missing or flagged reflections filled with a substitute value.

Parameters:

mode (str, optional) – Fill strategy. ‘mean’ fills missing/flagged reflections with the mean of the present data; ‘zero’ fills them with zero. Default is ‘mean’.
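
The two fill modes can be sketched in plain Python as follows. fill_masked is a hypothetical helper written for this illustration; it operates on lists and a separate validity list, whereas the method above works on the dataset's own tensors and masks.

```python
def fill_masked(values, valid, mode="mean"):
    """Replace entries where valid is False with the mean of valid entries or with zero."""
    if mode == "mean":
        present = [v for v, ok in zip(values, valid) if ok]
        fill = sum(present) / len(present) if present else 0.0
    elif mode == "zero":
        fill = 0.0
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    return [v if ok else fill for v, ok in zip(values, valid)]


F = [10.0, float("nan"), 30.0, 20.0]
valid = [True, False, True, True]
print(fill_masked(F, valid, mode="mean"))
print(fill_masked(F, valid, mode="zero"))
```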

__getitem__(key)[source]

Index into the reflection dataset.

Parameters:

key (torch.Tensor) – Boolean mask or integer indices for selection.

Returns:

New ReflectionData object with selected reflections.

Return type:

ReflectionData

__select__(indices, op=None)[source]

Select reflections by boolean mask or integer indices.

Iterates over all dataclass fields generically, so new tensor fields are handled automatically.

Parameters:
  • indices (torch.Tensor) – Boolean mask of shape (N,) or integer indices for selection.

  • op (str, optional) – Operation name for tracking purposes.

Returns:

New ReflectionData object with selected reflections.

Return type:

ReflectionData

sanitize_F()[source]

Remove invalid values from structure factors.

Adds a mask to filter out NaN, Inf, and non-positive values from F and F_sigma.

check_all_data_types()[source]
validate_hkl(hkl_ref)[source]

Expand dataset to match a reference HKL set.

Reorders and expands the current dataset to match the reference HKL set. Reflections present in the reference but missing from this dataset are filled with placeholder values and masked out. This ensures all datasets aligned to the same reference have identical shapes and can be processed together without data loss from intersection operations.

Parameters:

hkl_ref (torch.Tensor) – Reference Miller indices tensor of shape (N, 3), dtype int32. This defines the canonical HKL ordering for all aligned datasets.

Returns:

Self, modified in-place with expanded arrays matching hkl_ref.

Return type:

ReflectionData

Notes

After calling this method:

  • self.hkl will equal hkl_ref exactly

  • All data arrays (F, F_sigma, rfree_flags, etc.) are reordered/expanded

  • Missing reflections are filled with 0 (or appropriate defaults)

  • A mask ‘hkl_present’ is added marking which reflections have real data

  • forward() will return MaskedTensors that skip missing reflections

This approach avoids the problem where intersecting many datasets with different outliers/missing reflections causes exponential data loss.

Examples

Align multiple datasets to a common HKL set:

reference_hkl = data1.hkl.clone()
data1.validate_hkl(reference_hkl)
data2.validate_hkl(reference_hkl)
# Now data1 and data2 have identical shapes
assert data1.hkl.shape == data2.hkl.shape
find_outliers(model, scaler, z_threshold=4.0)[source]

Identify outlier reflections based on log-ratio distribution.

Uses the fact that log(F_obs) - log(F_calc) should be normally distributed. Outliers are reflections where |log_ratio - mean| > z_threshold * std_dev.

Parameters:
  • model (ModelFT) – ModelFT object to compute structure factors.

  • scaler (Scaler) – Scaler object to scale calculated structure factors.

  • z_threshold (float, optional) – Z-score threshold to classify outliers. Default is 4.0.

Returns:

Boolean mask where True indicates outliers.

Return type:

torch.Tensor
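
The z-score test on the log-ratio can be sketched with the standard library. This is only an illustration of the criterion |log_ratio - mean| > z_threshold * std_dev: log_ratio_outliers is a hypothetical helper that takes already-scaled calculated amplitudes as plain lists instead of a ModelFT and a Scaler, and it assumes all amplitudes are strictly positive.

```python
import math
import statistics


def log_ratio_outliers(f_obs, f_calc, z_threshold=4.0):
    """Flag reflections whose log(F_obs) - log(F_calc) lies far from the mean."""
    log_ratio = [math.log(o) - math.log(c) for o, c in zip(f_obs, f_calc)]
    mean = statistics.fmean(log_ratio)
    std = statistics.pstdev(log_ratio)
    if std == 0.0:
        # All ratios identical: nothing can be an outlier.
        return [False] * len(log_ratio)
    return [abs(r - mean) > z_threshold * std for r in log_ratio]


# 99 well-fitted reflections plus one grossly mismeasured one.
f_calc = [100.0] * 100
f_obs = [100.0 * (1.0 + 0.01 * ((i % 5) - 2)) for i in range(99)] + [1000.0]
outliers = log_ratio_outliers(f_obs, f_calc, z_threshold=4.0)
print(sum(outliers), outliers.index(True))
```

Only the last reflection, whose observed amplitude is ten times the calculated one, is flagged.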

get_log_ratio(model, scaler)[source]

Compute log-ratio between observed and calculated structure factors.

Parameters:
  • model (ModelFT) – ModelFT object to compute structure factors.

  • scaler (Scaler) – Scaler object to scale calculated structure factors.

Returns:

Log-ratio values: log(F_obs) - log(F_calc).

Return type:

torch.Tensor

get_outlier_statistics()[source]

Get statistics about flagged outliers.

Returns:

Dictionary containing:

  • n_outliers : int

  • n_total : int

  • fraction_outliers : float

  • detection_params : dict or None

  • outlier_resolution_stats : dict (if resolution available)

Return type:

dict

unpack_one()[source]

Unpack one level of source.

Does not recurse fully and does not flag.

Returns:

Parent source or self if no source.

Return type:

ReflectionData

flag_suspicious_sigma(z_threshold=5.0)[source]

Flag sigma values that deviate significantly from expected distribution.

Sigma values from a detector should follow a log-normal distribution. Values with z-scores beyond threshold are flagged as suspicious.

Parameters:

z_threshold (float, optional) – Z-score threshold for flagging outliers. Default is 5.0.

dump()[source]

Dump all reflection data to console for debugging.

Prints type, shape, and device information for all attributes.

write_mtz(fname, fcalc=None, model_ft=None)[source]

Write reflection data to MTZ file with optional map coefficients.

Parameters:
  • fname (str) – Output MTZ filename.

  • fcalc (torch.Tensor, optional) – Complex calculated structure factors of shape (N,). If provided, computes phases and map coefficients.

  • model_ft (ModelFT, optional) – ModelFT object to compute fcalc if not provided.

Notes

The MTZ file will contain canonical column names:
  • FP, SIGFP: Observed amplitudes and uncertainties

  • I, SIGI: Observed intensities and uncertainties (if available)

  • FreeR_flag: R-free test set flags

  • FWT, PHWT: 2mFo-DFc map coefficients (if fcalc provided)

  • DELFWT, PHDELWT: mFo-DFc map coefficients (if fcalc provided)

Map coefficients are computed as:
  • 2mFo-DFc: 2*Fo - Fc (filled to resolution limit)

  • mFo-DFc: Fo - Fc
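
The two formulas above can be sketched as follows. The sketch assumes that both coefficient sets carry the phase of the calculated structure factor, which is the usual convention but is an assumption here; map_coefficients is a hypothetical helper, and the filling of unobserved reflections is not shown.

```python
import cmath


def map_coefficients(f_obs, f_calc):
    """Return (2Fo-Fc, Fo-Fc) coefficients, both carrying the phase of Fc."""
    two_fo_fc, fo_fc = [], []
    for fo, fc in zip(f_obs, f_calc):
        phase = cmath.phase(fc)  # phase of Fc in radians
        fc_amp = abs(fc)
        two_fo_fc.append(cmath.rect(2.0 * fo - fc_amp, phase))
        fo_fc.append(cmath.rect(fo - fc_amp, phase))
    return two_fo_fc, fo_fc


f_obs = [12.0, 8.0]                   # observed amplitudes
f_calc = [10.0 + 0.0j, 0.0 + 10.0j]   # complex calculated structure factors
fwt, delfwt = map_coefficients(f_obs, f_calc)
print(fwt, delfwt)
```

A negative difference amplitude, as for the second reflection, simply points the coefficient in the opposite direction of the Fc phase.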

Examples

Write MTZ with map coefficients:

data = ReflectionData().load_mtz('observed.mtz')
model = Model().load_pdb('model.pdb')
model_ft = ModelFT(model, data.cell, data.spacegroup)
fcalc = model_ft.forward(data.hkl)
data.write_mtz('output.mtz', fcalc=fcalc)
property centric

Get boolean mask for centric reflections (full size, unfiltered).

Calculates it if not already present. Returns unfiltered centric flags matching the full HKL array size, consistent with how forward() returns full-size arrays when using MaskedTensors.

Returns:

Boolean tensor of shape (N,) where N is total reflections. True indicates centric reflection, False indicates acentric.

Return type:

torch.Tensor or None

calc_patterson(grid_size=None, grid_sampling=1)[source]

Calculate Patterson map of the dataset.

The Patterson function P(u,v,w) = Σ|F(hkl)|² exp(-2πi(hu+kv+lw)) is computed via inverse FFT of F². Data is expanded to P1 symmetry using only observed reflections (no filling of missing data).

Parameters:
  • grid_size (tuple of int, optional) – Grid dimensions (Nx, Ny, Nz). If None, automatically determined from unit cell and resolution.

  • grid_sampling (float, optional) – Sampling interval for the grid. Default is 1. This sets the grid so that it is sampled twice as finely as normal for a given resolution.

Returns:

Real-valued Patterson map of shape (Nx, Ny, Nz). Origin is at grid position [0, 0, 0].

Return type:

torch.Tensor
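
To make the Patterson sum concrete, here is a deliberately tiny one-dimensional version evaluated by direct summation rather than by FFT. It is a teaching sketch, not the method's algorithm: patterson_1d is a hypothetical helper, and pairing each h with its Friedel mate -h is what turns the complex exponential into a cosine.

```python
import math


def patterson_1d(hkl_to_f, n_grid):
    """Evaluate P(u) = sum_h |F(h)|^2 cos(2*pi*h*u) on a 1D grid by direct summation.

    Each reflection h is paired with its Friedel mate -h, which makes the map real.
    """
    grid = []
    for j in range(n_grid):
        u = j / n_grid
        total = 0.0
        for h, f in hkl_to_f.items():
            weight = 1.0 if h == 0 else 2.0  # h and -h contribute equally
            total += weight * f * f * math.cos(2.0 * math.pi * h * u)
        grid.append(total)
    return grid


# Toy 1D "crystal": amplitudes for h = 1 and h = 2 only.
amplitudes = {1: 3.0, 2: 1.0}
p_map = patterson_1d(amplitudes, n_grid=8)
print(p_map[0], max(p_map))
```

The largest value sits at the origin (grid position 0), and the map is symmetric about it, as expected for a Patterson function.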

possible_hkl()[source]

Generate all possible HKL indices within the resolution limit.

Returns:

Tensor of shape (M, 3) containing all possible Miller indices within the resolution limit defined by self.resolution.

Return type:

torch.Tensor

remap(new_hkl, index_mapping, phase_shifts=None, spacegroup=None, op_name='remap')[source]

Create new ReflectionData with remapped HKL set and data.

This is the core method for index-based transformations of reflection data. It handles remapping all tensor fields based on an index mapping, with support for missing reflections (indicated by -1 in index_mapping).

Parameters:
  • new_hkl (torch.Tensor, shape (M, 3)) – New Miller indices.

  • index_mapping (torch.Tensor, shape (M,), dtype int64) – Maps new indices to original: new[i] = old[index_mapping[i]]. Values of -1 indicate missing reflections (filled with defaults).

  • phase_shifts (torch.Tensor, optional, shape (M,)) – Phase offsets to apply (e.g., from symmetry translations).

  • spacegroup (str, int, gemmi.SpaceGroup, or None) – New spacegroup. If None, keeps original.

  • op_name (str) – Operation name for provenance tracking.

Returns:

New object with remapped data. Missing reflections get:

  • 0.0 for F, I, phase, fom

  • 1.0 for F_sigma, I_sigma (conservative uncertainty)

  • True for masks[‘missing’]

Return type:

ReflectionData
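
The index-mapping convention, new[i] = old[index_mapping[i]] with -1 for missing entries, can be sketched with plain lists. remap_field and the dictionary lookup used to build the mapping are hypothetical and serve only to illustrate the convention.

```python
def remap_field(values, index_mapping, default):
    """Build new[i] = old[index_mapping[i]], using default where the mapping is -1."""
    return [values[j] if j >= 0 else default for j in index_mapping]


old_hkl = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
new_hkl = [(0, 0, 1), (1, 1, 0), (1, 0, 0)]

# Look up each new index in the old set; -1 marks a reflection with no source.
lookup = {hkl: i for i, hkl in enumerate(old_hkl)}
index_mapping = [lookup.get(hkl, -1) for hkl in new_hkl]

F_old = [10.0, 20.0, 30.0]
sigma_old = [0.5, 0.6, 0.7]
F_new = remap_field(F_old, index_mapping, default=0.0)
sigma_new = remap_field(sigma_old, index_mapping, default=1.0)
missing = [j < 0 for j in index_mapping]
print(index_mapping, F_new, sigma_new, missing)
```

The reflection (1, 1, 0) has no counterpart in the old set, so it receives an amplitude of 0.0, an uncertainty of 1.0, and a True entry in the missing list.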

fill(d_min=None)[source]

Fill missing reflections within resolution limit.

Generates all possible reflections for the current spacegroup within the resolution limit, identifies which are missing, and creates a complete dataset. Missing reflections are filled with default values.

Parameters:

d_min (float, optional) – High resolution limit in Angstroms. If None, uses the minimum resolution from the current dataset.

Returns:

New ReflectionData with complete set of reflections. Missing reflections have:

  • F, I, phase, fom = 0.0

  • F_sigma, I_sigma = 1.0

  • masks[‘missing’] = True

Return type:

ReflectionData

Examples

Fill to completeness:

data = ReflectionData().load_mtz('data.mtz')
data_filled = data.fill(d_min=2.0)
print(f"Original: {len(data)}, Filled: {len(data_filled)}")
expand_to_p1(include_friedel=True, remove_absences=True)[source]

Expand reflection data from asymmetric unit to P1.

Applies all symmetry operations from the current space group to generate all symmetry-equivalent reflections. Returns a NEW ReflectionData object with expanded reflections; does not modify self.

Parameters:
  • include_friedel (bool, default True) – Include Friedel mates (-h, -k, -l). For normal (non-anomalous) scattering, Friedel pairs have identical amplitudes.

  • remove_absences (bool, default True) – Remove systematically absent reflections from output.

Returns:

New ReflectionData object with:

  • All symmetry-equivalent reflections

  • spacegroup set to P1

  • Expanded tensor fields (F, F_sigma, I, I_sigma, phase, fom, rfree_flags)

  • Recalculated resolution values

  • Provenance tracking via source and last_op attributes

Return type:

ReflectionData

Notes

The expansion handles tensor fields as follows:

  • hkl: Symmetry operations applied, Friedel mates added, duplicates removed

  • F, F_sigma, I, I_sigma, fom, rfree_flags: Indexed from original

  • phase: Indexed + phase shift from translation component

  • resolution: Recalculated from expanded hkl + cell

  • bin_indices: Cleared (invalidated by expansion)

Examples

Expand to P1 for calculations:

# Load data in original space group
data = ReflectionData().load_mtz('data.mtz')
print(f"Original: {len(data)} reflections, {data.spacegroup}")

# Expand to P1
data_p1 = data.expand_to_p1()
print(f"Expanded: {len(data_p1)} reflections, {data_p1.spacegroup}")

# Without Friedel mates (for anomalous data)
data_p1_anom = data.expand_to_p1(include_friedel=False)
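
What "applying all symmetry operations" does to the Miller indices can be sketched for space group P21 (unique axis b), whose two operators have simple diagonal rotation parts. expand_hkl is a hypothetical helper covering only the hkl part of the expansion; phase shifts from the translation components and removal of systematic absences are left out.

```python
def expand_hkl(hkl_list, rotations, include_friedel=True):
    """Apply each rotation to each Miller index and drop duplicates.

    Miller indices are treated as row vectors, so h' = h @ R.
    """
    seen = []
    for h, k, l in hkl_list:
        for rot in rotations:
            new = tuple(h * rot[0][c] + k * rot[1][c] + l * rot[2][c] for c in range(3))
            candidates = [new]
            if include_friedel:
                candidates.append(tuple(-x for x in new))
            for cand in candidates:
                if cand not in seen:
                    seen.append(cand)
    return seen


# Rotation parts of the two operators of P21 (unique axis b): x,y,z and -x,y+1/2,-z.
P21_ROTATIONS = [
    [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
    [[-1, 0, 0], [0, 1, 0], [0, 0, -1]],
]
print(expand_hkl([(1, 2, 3)], P21_ROTATIONS))
print(expand_hkl([(1, 2, 3)], P21_ROTATIONS, include_friedel=False))
```

A general reflection in P21 expands to four indices with Friedel mates and to two without them.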
reduce_to_spacegroup(spacegroup, include_friedel=True, aggregation='mean')[source]

Reduce P1 reflection data to asymmetric unit of a target spacegroup.

This is the inverse of expand_to_p1(). Takes reflection data in P1 and merges symmetry-equivalent reflections into single ASU reflections using the specified aggregation function.

Parameters:
  • spacegroup (str, int, or gemmi.SpaceGroup) – Target space group specification.

  • include_friedel (bool, default True) – If True, also merge Friedel mates when reducing.

  • aggregation (str, default 'mean') – Aggregation function for merging equivalent reflections: ‘mean’ averages values (default, good for amplitudes); ‘sum’ sums values; ‘first’ takes the first valid value (no averaging).

Returns:

New ReflectionData with merged reflections in the target spacegroup.

Return type:

ReflectionData

Notes

Field handling during reduction:

  • F, I: Aggregated using specified function

  • F_sigma, I_sigma: Propagated as sqrt(sum(sigma^2) / n) for ‘mean’

  • phase: Aggregated via complex averaging (handles phase wrapping)

  • fom: Weighted by amplitude during phase averaging

  • rfree_flags: OR operation (True if any equivalent is True)

Examples

Reduce symmetry-expanded data:

# Expand to P1 for calculations, then reduce back
data_p1 = data.expand_to_p1()
# ... modify F_p1 ...
data_merged = data_p1.reduce_to_spacegroup('P21')

# Reduce with sum instead of mean
data_summed = data_p1.reduce_to_spacegroup('P21', aggregation='sum')
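
Complex averaging of phases, mentioned in the notes above, avoids the wrap-around problem of averaging angles directly. The sketch below shows the principle with amplitude weighting; average_phases is a hypothetical helper, and the use of degrees is an assumption of this example.

```python
import cmath
import math


def average_phases(amplitudes, phases_deg):
    """Amplitude-weighted circular mean of phases, returned in degrees."""
    total = sum(cmath.rect(f, math.radians(p)) for f, p in zip(amplitudes, phases_deg))
    return math.degrees(cmath.phase(total))


# Two equivalent reflections whose phases straddle the 0/360 degree wrap.
print(average_phases([1.0, 1.0], [350.0, 10.0]))
```

The result is 0 degrees up to rounding error, whereas a plain arithmetic mean of 350 and 10 would give 180, the opposite direction.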
canonicalize(include_friedel=True)[source]

Return new ReflectionData with HKL in standard CCP4 ASU form.

Remaps all Miller indices to the canonical CCP4 asymmetric unit representative using gemmi.ReciprocalAsu, adjusts phases accordingly, and sorts reflections lexicographically by (h, k, l).

Parameters:

include_friedel (bool, default True) – Whether Friedel mates are considered equivalent.

Returns:

New object with canonicalized, sorted Miller indices.

Return type:

ReflectionData

get_scattering_vectors()[source]

Get scattering vectors (s-vectors) from hkl and cell.

The s-vector for a reflection hkl is defined as:

s = B* @ hkl

where B* is the reciprocal basis matrix.

Returns:

s_vectors – Reciprocal space vectors in Angstroms^-1, shape (N, 3).

Return type:

torch.Tensor

Raises:

ValueError – If hkl or cell is not available.
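As an illustration of the relation s = B* @ hkl, here is a minimal plain-Python sketch. It assumes the crystallographic convention without a 2π factor (so |s| = 1/d) and a direct-space basis with a along x and b in the xy plane; the helper names are hypothetical and the library's internal matrix layout may differ.

```python
import math

def cross(u, v):
    return (u[1] * v[2] - u[2] * v[1],
            u[2] * v[0] - u[0] * v[2],
            u[0] * v[1] - u[1] * v[0])

def dot(u, v):
    return sum(x * y for x, y in zip(u, v))

def reciprocal_basis(a, b, c, alpha, beta, gamma):
    """Return (a*, b*, c*) in 1/Angstrom for cell lengths in Angstrom, angles in degrees."""
    al, be, ga = (math.radians(x) for x in (alpha, beta, gamma))
    # Direct-space basis vectors (assumed orientation: a along x, b in the xy plane).
    va = (a, 0.0, 0.0)
    vb = (b * math.cos(ga), b * math.sin(ga), 0.0)
    cx = c * math.cos(be)
    cy = c * (math.cos(al) - math.cos(be) * math.cos(ga)) / math.sin(ga)
    cz = math.sqrt(c * c - cx * cx - cy * cy)
    vc = (cx, cy, cz)
    volume = dot(va, cross(vb, vc))
    a_star = tuple(x / volume for x in cross(vb, vc))
    b_star = tuple(x / volume for x in cross(vc, va))
    c_star = tuple(x / volume for x in cross(va, vb))
    return a_star, b_star, c_star

def scattering_vector(hkl, basis):
    """s = h*a* + k*b* + l*c*, i.e. the reciprocal basis applied to hkl."""
    h, k, l = hkl
    a_star, b_star, c_star = basis
    return tuple(h * a_star[i] + k * b_star[i] + l * c_star[i] for i in range(3))

basis = reciprocal_basis(50.0, 60.0, 70.0, 90.0, 90.0, 90.0)
s = scattering_vector((1, 2, 3), basis)
d_spacing = 1.0 / math.sqrt(dot(s, s))
print(f"s = {s}, d = {d_spacing:.3f} A")
```

Under this convention the norm of each s-vector is the reciprocal of the reflection's d-spacing, which ties the output to the resolution field.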

get_radial_shells(n_shells=20, d_min=None, d_max=None)[source]

Create uniform radial shells in 1/d space for normalization.

This creates shells with uniform spacing in reciprocal space (1/d), different from get_bins() which creates equal-count bins.

Parameters:
  • n_shells (int) – Number of radial shells. Default is 20.

  • d_min (float, optional) – High resolution limit in Angstroms. If None, uses dataset minimum.

  • d_max (float, optional) – Low resolution limit in Angstroms. If None, uses dataset maximum.

Returns:

  • shell_edges (torch.Tensor) – Shell boundaries in Angstroms^-1, shape (n_shells+1,).

  • shell_centers (torch.Tensor) – Shell centers in Angstroms^-1, shape (n_shells,).

  • shell_indices (torch.Tensor) – Shell index for each reflection, shape (N,). Values -1 for out-of-range.

Return type:

Tuple[Tensor, Tensor, Tensor]
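The shell construction described above can be sketched in plain Python. This is an illustration under stated assumptions rather than the library's code: `radial_shells` is a hypothetical function, and placing reflections that sit exactly on the upper limit into the last shell is a choice made here.

```python
def radial_shells(d_values, n_shells, d_min, d_max):
    """Uniform shells in 1/d; returns (edges, centers, indices)."""
    s_min, s_max = 1.0 / d_max, 1.0 / d_min
    width = (s_max - s_min) / n_shells
    edges = [s_min + i * width for i in range(n_shells + 1)]
    centers = [0.5 * (edges[i] + edges[i + 1]) for i in range(n_shells)]
    indices = []
    for d in d_values:
        s = 1.0 / d
        if s < s_min or s > s_max:
            indices.append(-1)  # outside the requested resolution range
        else:
            # Clamp so a reflection exactly at s_max falls in the last shell.
            indices.append(min(int((s - s_min) / width), n_shells - 1))
    return edges, centers, indices

edges, centers, indices = radial_shells(
    [8.0, 4.0, 3.0, 2.0, 1.5], n_shells=4, d_min=2.0, d_max=10.0
)
print(indices)
```

Because the spacing is uniform in 1/d, the number of reflections per shell generally varies, unlike equal-count bins.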

fit_anisotropy(n_shells=20, d_min=None, d_max=None, n_iterations=100, verbose=None)[source]

Fit anisotropy correction parameters to minimize CV within shells.

Optimizes U parameters so that corrected F² values have minimal coefficient of variation within each resolution shell.

Parameters:
  • n_shells (int) – Number of resolution shells for variance calculation.

  • d_min (float, optional) – High resolution limit in Angstroms. If None, uses dataset minimum.

  • d_max (float, optional) – Low resolution limit in Angstroms. If None, uses dataset maximum.

  • n_iterations (int) – Number of optimization iterations.

  • verbose (bool, optional) – Print progress. If None, uses self.verbose.

Returns:

U – Fitted anisotropy parameters [u11, u22, u33, u12, u13, u23], shape (6,). Also stored in self.U_aniso.

Return type:

torch.Tensor

Raises:

ValueError – If no amplitude data is available.
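The quantity being minimized can be sketched as follows. This is a simplified illustration of a per-shell coefficient of variation, assuming population standard deviation and an unweighted average over shells; `mean_shell_cv` is a hypothetical name and the library's actual objective may be defined differently.

```python
from collections import defaultdict
from statistics import mean, pstdev

def mean_shell_cv(values, shell_indices):
    """Average coefficient of variation (std / mean) over resolution shells."""
    shells = defaultdict(list)
    for value, idx in zip(values, shell_indices):
        if idx >= 0:  # skip reflections flagged as out of range
            shells[idx].append(value)
    cvs = [
        pstdev(vals) / mean(vals)
        for vals in shells.values()
        if len(vals) > 1 and mean(vals) > 0
    ]
    return mean(cvs) if cvs else 0.0

# Shell 0 varies, shell 1 is perfectly uniform.
f_squared = [1.0, 3.0, 5.0, 5.0]
shell_idx = [0, 0, 1, 1]
print(mean_shell_cv(f_squared, shell_idx))
```

A well-fitted correction would reduce this value, since directional falloff inflates the spread of F² within a shell.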

setup_anisotropy(U_aniso=None)[source]

Set up anisotropy correction parameters.

Parameters:

U_aniso (torch.Tensor, optional) – Anisotropic parameters [u11, u22, u33, u12, u13, u23], shape (6,). If None, initializes as zeros.

apply_anisotropy_correction(U_aniso=None)[source]

Apply anisotropy correction to F² values.

Parameters:

U_aniso (torch.Tensor, optional) – Anisotropic parameters [u11, u22, u33, u12, u13, u23], shape (6,). If None, uses self.U_aniso, which requires a prior call to fit_anisotropy.

Returns:

  • F_corrected (torch.Tensor) – Anisotropy-corrected F values, shape (N,).

  • sigma_F_corrected (torch.Tensor) – Uncertainties of corrected F values, shape (N,).

Raises:

ValueError – If no U parameters available and none provided.

Return type:

Tensor

compute_e_values(n_shells=20, d_min=None, d_max=None, apply_anisotropy=True, fit_anisotropy=True, verbose=None)[source]

Compute E-values with optional anisotropy correction.

E-values are normalized structure factors where <E²> = 1 within each resolution shell. Anisotropy correction can be applied first to account for directional variation in diffraction.

Parameters:
  • n_shells (int) – Number of resolution shells for normalization.

  • d_min (float, optional) – High resolution limit in Angstroms. If None, uses dataset minimum.

  • d_max (float, optional) – Low resolution limit in Angstroms. If None, uses dataset maximum.

  • apply_anisotropy (bool) – If True, apply anisotropy correction before E-value calculation.

  • fit_anisotropy (bool) – If True and apply_anisotropy is True, fit anisotropy parameters. If False and apply_anisotropy is True, uses existing self.U_aniso.

  • verbose (bool, optional) – Print progress. If None, uses self.verbose.

Returns:

E – E-values, shape (N,). Also stored in self.E. self.E_squared is also populated with E² values.

Return type:

torch.Tensor

Raises:

ValueError – If no amplitude data is available.
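The normalization condition <E²> = 1 per shell can be sketched in plain Python. The sketch assumes E = F / sqrt(<F²>) within each shell and ignores symmetry (epsilon) factors and anisotropy; `e_values` is a hypothetical helper, not the method's implementation.

```python
import math
from collections import defaultdict

def e_values(F, shell_indices):
    """Normalize amplitudes so that mean(E^2) is 1 within each shell."""
    sums = defaultdict(float)
    counts = defaultdict(int)
    for f, idx in zip(F, shell_indices):
        if idx >= 0:
            sums[idx] += f * f
            counts[idx] += 1
    E = []
    for f, idx in zip(F, shell_indices):
        if idx < 0:
            E.append(float("nan"))  # no shell assigned, so no normalization
        else:
            E.append(f / math.sqrt(sums[idx] / counts[idx]))
    return E

F = [10.0, 20.0, 30.0, 1.0, 2.0, 3.0]
shells = [0, 0, 0, 1, 1, 1]
E = e_values(F, shells)
print([round(e, 3) for e in E])
```

The two shells differ in scale by a factor of ten, yet they yield identical E-values, which is the point of the normalization.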

Examples

Compute E-values with automatic anisotropy correction:

data = ReflectionData().load_mtz('data.mtz')
E = data.compute_e_values(n_shells=30)
print(f"E-values: mean={E.mean():.3f}, std={E.std():.3f}")

Compute E-values without anisotropy correction:

E = data.compute_e_values(apply_anisotropy=False)
setup_scale(scale=None)[source]

Set overall scale factor, parametrized in log space.

Parameters:

scale (float, optional) – If provided, sets the scale factor directly. If None, computes scale to make mean F equal to 1.0.

Returns:

The scale factor applied.

Return type:

float
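A log-space parametrization can be sketched as below. This is a plain-Python illustration that assumes the default behaviour amounts to scale = 1 / mean(F); the variable names are illustrative, and in the library the log-scale would be a tensor rather than a float.

```python
import math

F = [12.0, 30.0, 18.0]
F_sigma = [1.2, 3.0, 1.8]

# Store the logarithm of the scale; exp() keeps the applied scale positive
# no matter what value an optimizer moves log_scale to.
log_scale = -math.log(sum(F) / len(F))
scale = math.exp(log_scale)

# Amplitudes and their uncertainties are multiplied by the same factor.
F_scaled = [scale * f for f in F]
sigma_scaled = [scale * s for s in F_sigma]

print(f"scale = {scale:.4f}, mean scaled F = {sum(F_scaled) / len(F_scaled):.4f}")
```

Scaling sigmas by the same factor leaves every F/sigma ratio unchanged.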

get_corrected_data()[source]

Get scaled amplitude and uncertainty tensors.

Applies the current scale factor (from self.log_scale) to F and F_sigma.

Returns:

Scaled F and F_sigma tensors. If log_scale is None, returns original.

Return type:

Tuple[torch.Tensor, torch.Tensor]

parameters()[source]

Get list of learnable parameters for optimization.

Returns:

Iterator over the torch Parameters to be optimized. Includes log_scale and U_aniso if set.

Return type:

iter[Parameter]

__init__(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, _centric=None, _n_bins=None, _FrenchWilson=None, source=None, dataset=None, last_op=None, reader=None)
class torchref.io.datasets.FcalcDataset(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, fcalc=None, fcalc_amp=None, fcalc_phase=None)[source]

Bases: CrystalDataset

Dataset for storing calculated structure factors.

Provides a lightweight container for Fcalc values with:

  • Cell and spacegroup information (using torchref.symmetry types)

  • HKL indices and resolution

  • Complex Fcalc with amplitude/phase decomposition

  • MTZ export capability

This class inherits from CrystalDataset and overrides the spacegroup field to store torchref.symmetry.SpaceGroup instead of gemmi.SpaceGroup.

Parameters:
  • hkl (torch.Tensor, optional) – Miller indices of shape (N, 3).

  • resolution (torch.Tensor, optional) – Resolution per reflection of shape (N,).

  • cell (Cell, optional) – Unit cell object.

  • spacegroup (SpaceGroup, optional) – Space group object (torchref.symmetry.SpaceGroup).

  • fcalc (torch.Tensor, optional) – Complex structure factors of shape (N,).

  • fcalc_amp (torch.Tensor, optional) – Amplitudes |Fcalc| of shape (N,).

  • fcalc_phase (torch.Tensor, optional) – Phases in radians of shape (N,).

  • device (torch.device) – Device for tensors.

Examples

Create from cell and resolution:

import torch
from torchref.io.datasets import FcalcDataset

dataset = FcalcDataset.from_cell_and_resolution(
    cell=[50.0, 60.0, 70.0, 90.0, 90.0, 90.0],
    spacegroup='P212121',
    d_min=2.0,
)

# Set Fcalc values (complex tensor)
fcalc = torch.randn(len(dataset), dtype=torch.complex64)
dataset.set_fcalc(fcalc)

# Write to MTZ
dataset.write_mtz('output.mtz')
spacegroup: SpaceGroup | None = None
fcalc: Tensor | None = None
fcalc_amp: Tensor | None = None
fcalc_phase: Tensor | None = None
static from_cell_and_resolution(cell, spacegroup, d_min=2.0, d_max=None, device=device(type='cpu'), dtype=torch.float32)[source]

Create FcalcDataset with HKL generated to given resolution.

Parameters:
  • cell (torch.Tensor, list, or Cell) – Unit cell [a, b, c, alpha, beta, gamma] or Cell object.

  • spacegroup (SpaceGroupLike) – Space group (str, int, gemmi.SpaceGroup, or torchref.symmetry.SpaceGroup).

  • d_min (float, optional) – High resolution limit in Angstroms. Default is 2.0.

  • d_max (float, optional) – Low resolution limit in Angstroms. If provided, reflections with d-spacing > d_max are removed.

  • device (torch.device) – Target device.

  • dtype (torch.dtype) – Float dtype for tensors.

Returns:

New dataset with HKL and resolution populated.

Return type:

FcalcDataset

Examples

from torchref.symmetry import Cell, SpaceGroup

cell = Cell([50.0, 60.0, 70.0, 90.0, 90.0, 90.0])
sg = SpaceGroup('P212121')
dataset = FcalcDataset.from_cell_and_resolution(
    cell=cell, spacegroup=sg, d_min=2.0,
)
print(f"Generated {len(dataset)} reflections")
set_fcalc(fcalc)[source]

Assign complex Fcalc values.

Automatically computes amplitude and phase from complex values.

Parameters:

fcalc (torch.Tensor) – Complex structure factors with shape (N,).

Raises:

ValueError – If fcalc length doesn’t match HKL length.
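The amplitude/phase decomposition is the polar form of each complex value. A minimal sketch with the standard library follows; with tensors, `torch.abs` and `torch.angle` compute the same two quantities, though whether the method uses exactly those calls is an assumption.

```python
import cmath
import math

fcalc = [3.0 + 4.0j, 0.0 + 2.0j, -1.0 + 0.0j]

amplitudes = [abs(f) for f in fcalc]        # |Fcalc|
phases = [cmath.phase(f) for f in fcalc]    # radians, in (-pi, pi]

# Recombining amplitude and phase recovers the original complex values.
rebuilt = [cmath.rect(a, p) for a, p in zip(amplitudes, phases)]

for f, a, p in zip(fcalc, amplitudes, phases):
    print(f"{f}: amplitude={a:.3f}, phase={math.degrees(p):.1f} deg")
```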

Examples

# Create complex Fcalc values
fcalc = torch.randn(len(dataset), dtype=torch.complex64)
dataset.set_fcalc(fcalc)

print(dataset.fcalc_amp[:5])   # Amplitudes
print(dataset.fcalc_phase[:5]) # Phases in radians
write_mtz(filepath)[source]

Write Fcalc to MTZ file.

Parameters:

filepath (str) – Output MTZ filename.

Raises:

ValueError – If no Fcalc values have been set.

Examples

dataset.set_fcalc(fcalc_values)
dataset.write_mtz('calculated.mtz')
write_mtz_as_fobs(filepath, sigma_frac=0.05, f_column='F-obs', sigf_column='SIGF-obs', phase_column='PHIF-model')[source]

Write Fcalc to MTZ as if it were observed data (F-obs columns).

Useful for creating simulated “experimental” MTZ files that can be read back by ReflectionData.load_mtz() as observed amplitudes.

Parameters:
  • filepath (str) – Output MTZ filename.

  • sigma_frac (float, optional) – Sigma as a fraction of |F|. Default is 0.05 (5%).

  • f_column (str, optional) – Column name for amplitudes. Default is ‘F-obs’.

  • sigf_column (str, optional) – Column name for sigma. Default is ‘SIGF-obs’.

  • phase_column (str, optional) – Column name for model phases. Default is ‘PHIF-model’.

Examples

dataset.set_fcalc(fcalc_values)
dataset.write_mtz_as_fobs('simulated_obs.mtz', sigma_frac=0.05)
__repr__()[source]

String representation of dataset.

property spacegroup_name: str | None

Get space group name as string (short form, e.g., ‘P212121’).

property spacegroup_hm: str | None

Get space group Hermann-Mauguin name with spaces (e.g., ‘P 21 21 21’).

property spacegroup_number: int | None

Get space group number (1-230).

__init__(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, fcalc=None, fcalc_amp=None, fcalc_phase=None)
class torchref.io.datasets.DatasetCollection(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, _datasets=<factory>, _dataset_order=<factory>, _reference_dataset=None, _common_hkl=None, _cell=None, _spacegroup=None, _resolution=None, _scale_factors=<factory>)[source]

Bases: CrystalDataset

Container for multiple related crystal datasets.

All datasets share a common HKL set for efficient computation. Datasets are aligned using the first dataset as a reference, with missing reflections in subsequent datasets masked out.
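The alignment idea can be sketched in plain Python. This is an illustration only: `align_to_reference` is a hypothetical helper, and filling missing reflections with NaN alongside a boolean mask is an assumption about one reasonable representation, not a description of the library's internals.

```python
def align_to_reference(ref_hkl, hkl, values):
    """Reorder `values` onto the reference HKL list and mask missing entries."""
    lookup = {tuple(h): v for h, v in zip(hkl, values)}
    aligned = [lookup.get(tuple(h), float("nan")) for h in ref_hkl]
    mask = [tuple(h) in lookup for h in ref_hkl]  # True where data is present
    return aligned, mask

ref_hkl = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]

# A second dataset in a different order and missing (0, 1, 0).
other_hkl = [(0, 0, 1), (1, 0, 0)]
other_F = [7.0, 3.0]

aligned, mask = align_to_reference(ref_hkl, other_hkl, other_F)
print(aligned, mask)
```

Once every dataset is indexed by the same HKL list, comparisons between datasets reduce to element-wise operations on the masked entries.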

Parameters:
  • verbose (int, optional) – Verbosity level (0=silent, 1=normal, 2=debug). Default is 1.

  • device (str, optional) – Device for tensors (‘cpu’, ‘cuda’, etc.). Defaults to the configured device.current.

hkl

Common HKL set for all datasets.

Type:

torch.Tensor

n_datasets

Number of datasets in collection.

Type:

int

Examples

from torchref.io import DatasetCollection, ReflectionData

collection = DatasetCollection(device='cuda')

native = ReflectionData().load_mtz('native.mtz')
derivative = ReflectionData().load_mtz('derivative.mtz')

collection.add_dataset('native', native, set_as_reference=True)
collection.add_dataset('derivative', derivative)

for name, dataset in collection:
    print(f"{name}: {len(dataset)} reflections")

# Access by name
native_F = collection['native'].F
add_dataset(name, dataset, set_as_reference=False)[source]

Add a dataset to the collection.

Parameters:
  • name (str) – Identifier for this dataset.

  • dataset (ReflectionData) – The dataset to add.

  • set_as_reference (bool, optional) – If True, use this dataset’s HKL as the reference. Default is False, but the first dataset added automatically becomes the reference.

Returns:

Self, for method chaining.

Return type:

DatasetCollection

Raises:

ValueError – If a dataset with the same name already exists.

Examples

collection = DatasetCollection()
collection.add_dataset('native', native_data, set_as_reference=True)
collection.add_dataset('derivative', derivative_data)
property hkl: Tensor | None

Common HKL set for all datasets.

property datasets: Dict[str, ReflectionData]

Access all datasets as a dictionary.

property n_datasets: int

Number of datasets in collection.

property reference_dataset: str | None

Name of the reference dataset.

property spacegroup: str | None

Space group of the reference dataset.

__getitem__(name)[source]

Get dataset by name.

Parameters:

name (str) – Name of the dataset.

Returns:

The requested dataset.

Return type:

ReflectionData

Raises:

KeyError – If dataset name not found.

__iter__()[source]

Iterate over (name, dataset) pairs in order of addition.

Yields:

tuple of (str, ReflectionData) – Name and dataset for each dataset in collection.

__len__()[source]

Number of reflections in common HKL set.

__contains__(name)[source]

Check if dataset exists in collection.

__call__(mask=True)[source]

Return the data of all datasets, scaled if scale factors are set.

Parameters:

mask (bool, optional) – Whether to apply masking. Default is True.

Returns:

Dictionary mapping name to (hkl, F, F_sigma, rfree) tuples.

Return type:

dict

scale()[source]

Scale all datasets to a common reference scale.

This method optimizes the scaling parameters of all non-reference datasets to minimize the mean squared error between their structure factors and those of the reference dataset. The optimization corrects for both overall scale differences and anisotropy. It uses the L-BFGS optimizer with a strong Wolfe line search to refine the scaling parameters over multiple optimization steps.

Returns:

The collection instance, allowing for method chaining.

Return type:

DatasetCollection

Raises:

ValueError – If no reference dataset has been set before calling this method, or if the collection contains only the reference dataset. At least two datasets are required.

Notes

The reference dataset must be set before calling this method, for example via add_dataset(..., set_as_reference=True). All datasets except the reference have their scaling parameters optimized.
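For intuition about the least-squares target, the sketch below solves a simplified version of the problem in closed form. It is a stand-in, not the method itself: it fits a single overall scale per dataset, ignores anisotropy, and replaces the iterative L-BFGS refinement with the analytic minimizer k = Σ(F_ref·F) / Σ(F²). The function name is hypothetical.

```python
def least_squares_scale(F_ref, F):
    """Scale k minimizing sum((F_ref - k * F)^2) over paired amplitudes."""
    numerator = sum(r * f for r, f in zip(F_ref, F))
    denominator = sum(f * f for f in F)
    return numerator / denominator

def mean_squared_error(F_ref, F, k):
    return sum((r - k * f) ** 2 for r, f in zip(F_ref, F)) / len(F)

F_ref = [100.0, 40.0, 65.0, 20.0]
F_other = [51.0, 19.0, 33.0, 10.5]  # roughly half the reference scale

k = least_squares_scale(F_ref, F_other)
print(f"k = {k:.4f}")
print(f"MSE before: {mean_squared_error(F_ref, F_other, 1.0):.2f}")
print(f"MSE after:  {mean_squared_error(F_ref, F_other, k):.2f}")
```

An iterative optimizer becomes necessary once anisotropic terms enter, because the model is then no longer linear in its parameters.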

keys()[source]

Return list of dataset names.

values()[source]

Return list of datasets.

items()[source]

Return list of (name, dataset) tuples.

__init__(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, _datasets=<factory>, _dataset_order=<factory>, _reference_dataset=None, _common_hkl=None, _cell=None, _spacegroup=None, _resolution=None, _scale_factors=<factory>)
get(name, default=None)[source]

Get dataset by name with default fallback.

__repr__()[source]

String representation of collection.

Submodules