torchref.io package

I/O module for crystallographic data files.

This module provides:
  • Dataset classes for handling reflection data

  • Format-specific readers and writers (MTZ, PDB, CIF)

  • Automatic format detection via DataRouter

High-level API

Load a single dataset:

from torchref.io import ReflectionData
data = ReflectionData(verbose=1)
data.load_mtz('structure.mtz')

Multi-dataset handling:

from torchref.io import DatasetCollection
collection = DatasetCollection()
collection.add_dataset('native', native_data)
collection.add_dataset('derivative', derivative_data)

Direct format access:

from torchref.io import mtz
reader = mtz.read('data.mtz')
data_dict, cell, spacegroup = reader()
class torchref.io.CrystalDataset(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None)

Bases: DeviceMixin

Base dataclass for crystallographic datasets.

Defines all possible tensor fields (optional) and handles device management and serialization. Subclasses add domain-specific methods.

This lightweight design enables scaling to thousands of datasets without the overhead of torch.nn.Module.

Parameters:
  • device (torch.device) – Device for tensors (‘cpu’, ‘cuda’, etc.). Defaults to the currently configured device (device.current).

  • verbose (int) – Verbosity level (0=silent, 1=normal, 2=debug). Default is 1.

Examples

Basic usage:

import torch
from torchref.io import CrystalDataset

data = CrystalDataset(device='cuda')
data.hkl = torch.tensor([[1, 0, 0], [0, 1, 0]], dtype=torch.int32)
data.cpu()  # Move all tensors to CPU
hkl: Tensor | None = None
F: Tensor | None = None
F_sigma: Tensor | None = None
I: Tensor | None = None
I_sigma: Tensor | None = None
rfree_flags: Tensor | None = None
resolution: Tensor | None = None
bin_indices: Tensor | None = None
outlier_flags: Tensor | None = None
phase: Tensor | None = None
fom: Tensor | None = None
E: Tensor | None = None
E_squared: Tensor | None = None
F_squared_corrected: Tensor | None = None
U_aniso: Tensor | None = None
radial_shell_indices: Tensor | None = None
cell: Cell | None = None
spacegroup: str | None = None
device: device
verbose: int = 1
rfree_source: str | None = None
amplitude_source: str | None = None
intensity_source: str | None = None
phase_source: str | None = None
wilson_b: float | None = None
wilson_b_structure: float | None = None
wilson_b_solvent: float | None = None
wilson_k_sol: float | None = None
outlier_detection_params: Dict[str, Any] | None = None
__post_init__()

Initialize non-field attributes after dataclass init.

save_state(path)

Save dataset state to file.

Parameters:

path (str) – Output file path.

Examples

Save to file:

data.save_state('reflection_data.pt')
classmethod load_state(path, device=device(type='cpu'))

Load dataset state from file.

Parameters:
  • path (str) – Input file path.

  • device (str) – Device to load tensors onto.

Returns:

Loaded dataset.

Return type:

CrystalDataset

Examples

Load from file:

data = ReflectionData.load_state('reflection_data.pt', device='cuda')
__len__()

Return number of reflections in dataset.

__repr__()

String representation of dataset.

property spacegroup_name: str | None

Get space group name as string (short form, e.g., ‘P212121’).

property spacegroup_hm: str | None

Get space group Hermann-Mauguin name with spaces (e.g., ‘P 21 21 21’).

property spacegroup_number: int | None

Get space group number (1-230).

__init__(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None)
class torchref.io.ReflectionData(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, _centric=None, _n_bins=None, _FrenchWilson=None, source=None, dataset=None, last_op=None, reader=None)

Bases: CrystalDataset, DebugMixin

Container for crystallographic reflection data.

This class handles loading, processing, and accessing reflection data including Miller indices, structure factor amplitudes, intensities, and R-free flags. All data is stored as PyTorch tensors for GPU acceleration.

Parameters:
  • verbose (int, optional) – Verbosity level for logging (0=silent, 1=normal, 2=debug). Default is 1.

  • device (str, optional) – Device to store tensors on (‘cpu’, ‘cuda’, ‘cuda:0’, etc.). Defaults to the currently configured device (device.current).

hkl

Miller indices of shape (N, 3), dtype int32.

Type:

torch.Tensor

F

Structure factor amplitudes of shape (N,), dtype float32.

Type:

torch.Tensor

F_sigma

Amplitude uncertainties of shape (N,), dtype float32.

Type:

torch.Tensor

I

Intensities of shape (N,), dtype float32.

Type:

torch.Tensor

I_sigma

Intensity uncertainties of shape (N,), dtype float32.

Type:

torch.Tensor

rfree_flags

R-free test set flags of shape (N,), dtype bool.

Type:

torch.Tensor

cell

Unit cell parameters [a, b, c, alpha, beta, gamma].

Type:

Cell

spacegroup

Space group symbol.

Type:

str

resolution

Resolution per reflection in Ångströms of shape (N,).

Type:

torch.Tensor

wilson_b

Overall Wilson B-factor in Ų.

Type:

float

Examples

Load reflection data from an MTZ file:

from torchref.io import ReflectionData

data = ReflectionData(verbose=1, device='cuda')
data.load_mtz('data.mtz')
print(f"Loaded {len(data.hkl)} reflections")
print(f"Resolution range: {data.resolution.min():.2f} - {data.resolution.max():.2f} Å")
source: ReflectionData | None = None
dataset: DataFrame | None = None
last_op: str | None = None
reader: Any | None = None
__post_init__()

Initialize non-dataclass attributes after dataclass init.

This is called automatically after the dataclass __init__.

load(reader)

Load reflection data using a data reader.

Parameters:

reader (callable) – Data reader object that returns (data_dict, cell, spacegroup) when called. Can be MTZ, ReflectionCIFReader, or other compatible reader.

Returns:

Self, for method chaining.

Return type:

ReflectionData

Raises:

ValueError – If unit cell parameters are missing or no amplitude/intensity data found.

classmethod from_tensors(hkl, F, F_sigma, cell, spacegroup, rfree_flags=None, device=device(type='cpu'), verbose=1)

Construct ReflectionData directly from tensors.

Parameters:
  • hkl (torch.Tensor) – Miller indices of shape (N, 3).

  • F (torch.Tensor) – Structure factor amplitudes of shape (N,).

  • F_sigma (torch.Tensor) – Amplitude uncertainties of shape (N,).

  • cell (Cell) – Unit cell parameters.

  • spacegroup (SpaceGroup) – Space group.

  • rfree_flags (torch.Tensor, optional) – R-free flags of shape (N,), dtype bool. If None, flags are generated automatically (2% free fraction).

  • device (str, optional) – Device for tensors. Defaults to the currently configured device (device.current).

  • verbose (int, optional) – Verbosity level. Default is 1.

Returns:

Fully initialized reflection data with all cleanup applied.

Return type:

ReflectionData

load_mtz(path, column_names=None)

Load reflection data from an MTZ file.

Parameters:
  • path (str) – Path to MTZ file.

  • column_names (dict, optional) – Explicit column name mapping to override automatic detection. Supported keys: "F", "SIGF", "I", "SIGI". Example: {"F": "DFo", "SIGF": "sig_DFo"}.

Returns:

Self, for method chaining.

Return type:

ReflectionData

load_cif(path, data_block=None)

Load reflection data from a CIF file.

Parameters:
  • path (str) – Path to CIF file.

  • data_block (str, optional) – Specific data block name to read (e.g., ‘r1vlmsf’). If None, reads the first data block. Useful for multi-dataset CIF files.

Returns:

Self, for method chaining.

Return type:

ReflectionData

static list_cif_data_blocks(path)

List all data blocks available in a CIF file without loading data.

Useful for multi-dataset CIF files to inspect available blocks before loading a specific one.

Parameters:

path (str) – Path to CIF file.

Returns:

Names of all data blocks in the CIF file.

Return type:

list of str

Examples

List and load a specific data block:

blocks = ReflectionData.list_cif_data_blocks('1VLM-sf.cif')
print(blocks)  # ['r1vlmsf', 'r1vlmAsf', 'r1vlmBsf', ...]
data = ReflectionData().load_cif('1VLM-sf.cif', data_block=blocks[1])
get_bins(n_bins=20, min_per_bin=100)

Create resolution bins with approximately equal reflection counts.

Parameters:
  • n_bins (int, optional) – Target number of resolution bins. Default is 20.

  • min_per_bin (int, optional) – Minimum reflections per bin. Default is 100.

Returns:

  • bin_indices (torch.Tensor) – Tensor of shape (N,) with bin index for each reflection.

  • n_bins (int) – Actual number of bins created (may be less than target for small datasets).

Return type:

Tuple[Tensor, int]
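The docstring does not specify how the bins are formed. As a rough pure-Python sketch, assuming reflections are sorted by d-spacing and split into contiguous groups of near-equal size (the helper name equal_count_bins is hypothetical, not part of torchref), equal-count binning with a minimum bin size could look like this:

```python
def equal_count_bins(d_spacings, n_bins=20, min_per_bin=100):
    """Assign each reflection to one of up to n_bins bins of near-equal size.

    Returns (bin_indices, n_bins_actual). Bin 0 holds the lowest-resolution
    (largest d) reflections. Hypothetical helper, not the torchref implementation.
    """
    n = len(d_spacings)
    if n == 0:
        return [], 0
    # Use fewer bins for small datasets so each bin gets >= min_per_bin
    # reflections where possible (always at least one bin)
    n_actual = max(1, min(n_bins, n // min_per_bin))
    # Rank reflections from low resolution (large d) to high resolution (small d)
    order = sorted(range(n), key=lambda i: d_spacings[i], reverse=True)
    bin_indices = [0] * n
    for rank, i in enumerate(order):
        # rank * n_actual // n spreads the n ranks evenly over n_actual bins
        bin_indices[i] = rank * n_actual // n
    return bin_indices, n_actual
```

Each bin then holds either floor(N / n_bins_actual) or ceil(N / n_bins_actual) reflections; ties in d keep their input order because Python's sort is stable.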

mean_res_per_bin()

Calculate mean resolution for each bin.

Returns:

Mean resolution for each bin in Ångströms.

Return type:

torch.Tensor

Raises:

ValueError – If bins have not been created yet.

mean_F_per_bin()

Calculate mean structure factor amplitude per resolution bin.

Returns:

Mean F per bin of shape (n_bins,).

Return type:

torch.Tensor

Raises:

ValueError – If bins have not been created yet.

mean_sigma_per_bin()

Calculate mean structure factor uncertainty per resolution bin.

Returns:

Mean sigma_F per bin of shape (n_bins,), or None if no uncertainties.

Return type:

torch.Tensor or None

Raises:

ValueError – If bins have not been created yet.

regenerate_rfree_flags(free_fraction=0.02, n_bins=20, min_per_bin=100, seed=None, force=False)

Regenerate R-free flags with resolution-stratified sampling.

Parameters:
  • free_fraction (float, optional) – Fraction of reflections to mark as free. Default is 0.02 (2%).

  • n_bins (int, optional) – Target number of resolution bins. Default is 20.

  • min_per_bin (int, optional) – Minimum reflections per resolution bin. Default is 100.

  • seed (int, optional) – Random seed for reproducibility. Default is None.

  • force (bool, optional) – If True, regenerate even if flags already exist. Default is False.

Examples

Generate R-free flags:

# Generate 2% free reflections with reproducible seed
data.regenerate_rfree_flags(free_fraction=0.02, n_bins=20, seed=42)
# Generate 5% free with 10 bins, overwriting existing
data.regenerate_rfree_flags(free_fraction=0.05, n_bins=10, force=True)
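Resolution-stratified sampling draws the same free fraction from every resolution bin, so the test set covers the whole resolution range. A minimal sketch under stated assumptions: the helper stratified_rfree_flags is hypothetical, bin indices are precomputed, at least one reflection per bin is marked free, and flags use the 1 = work, 0 = free convention described for __call__.

```python
import random


def stratified_rfree_flags(bin_indices, free_fraction=0.02, seed=None):
    """Return one flag per reflection: 0 = free (test), 1 = work.

    The same fraction is drawn independently from every resolution bin.
    """
    rng = random.Random(seed)  # private generator: same seed -> same flags
    flags = [1] * len(bin_indices)
    members_by_bin = {}
    for i, b in enumerate(bin_indices):
        members_by_bin.setdefault(b, []).append(i)
    for members in members_by_bin.values():
        # At least one free reflection per bin, even for small bins
        n_free = max(1, round(free_fraction * len(members)))
        for i in rng.sample(members, n_free):
            flags[i] = 0
    return flags
```

Using a local random.Random instance rather than the module-level generator keeps the seed from interfering with, or being disturbed by, other random calls in the program.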
get_structure_factors(as_complex=False)

Get structure factors, optionally as complex numbers.

Parameters:

as_complex (bool, optional) – If True and phases available, return F*exp(i*phi). Default is False.

Returns:

Structure factor amplitudes or complex structure factors.

Return type:

torch.Tensor

Raises:

ValueError – If no amplitude data is loaded.

get_structure_factors_with_sigma()

Get structure factor amplitudes and their uncertainties.

Returns:

  • F (torch.Tensor) – Structure factor amplitudes of shape (N,).

  • F_sigma (torch.Tensor or None) – Uncertainties of shape (N,), or None if not available.

Raises:

ValueError – If no amplitude data is loaded.

Return type:

Tuple[Tensor, Tensor | None]

Examples

Get amplitudes with uncertainties:

F, sigma_F = data.get_structure_factors_with_sigma()
if sigma_F is not None:
    # F_calc: calculated amplitudes from a model, on the same scale as F
    weighted_residual = (F - F_calc) / sigma_F
get_hkl()

Return Miller indices for valid reflections.

Returns:

Miller indices of shape (N, 3), dtype int32.

Return type:

torch.Tensor

Raises:

ValueError – If no Miller indices are loaded.

filter_by_resolution(d_min=None, d_max=None)

Filter reflections by resolution range.

Adds a boolean mask to self.masks for the specified resolution range.

Parameters:
  • d_min (float, optional) – High-resolution cutoff, i.e. the smallest d-spacing to keep (e.g., 1.5 Å).

  • d_max (float, optional) – Low-resolution cutoff, i.e. the largest d-spacing to keep (e.g., 50.0 Å).

Returns:

Self, for method chaining.

Return type:

ReflectionData

get_mask()

Return combined mask from all active filters.

Returns:

Boolean mask combining all filter conditions.

Return type:

torch.Tensor

cut_res(highres=None, lowres=None)

Filter reflections by resolution range.

Alias for filter_by_resolution with more intuitive naming.

Parameters:
  • highres (float, optional) – High resolution cutoff (small d-spacing, e.g., 1.5 Å). Keeps reflections with d >= highres.

  • lowres (float, optional) – Low resolution cutoff (large d-spacing, e.g., 50.0 Å). Keeps reflections with d <= lowres.

Returns:

Self, for method chaining.

Return type:

ReflectionData

Examples

Filter by resolution:

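# Keep reflections with d-spacings between 50 Å and 1.5 Å.
# Each call adds a mask to data itself and returns data (not a copy).
filtered = data.cut_res(highres=1.5, lowres=50.0)
# Keep only reflections with d-spacings between 2.0 Å and 1.0 Å
high_res = data.cut_res(highres=1.0, lowres=2.0)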
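The cutoff semantics above (keep reflections with highres <= d <= lowres) can be sketched in pure Python. For simplicity this assumes an orthogonal cell (all angles 90°), where 1/d² = h²/a² + k²/b² + l²/c²; both helper names are hypothetical.

```python
import math


def d_spacing_orthorhombic(h, k, l, a, b, c):
    """d-spacing in Å for a cell with all angles 90° (orthorhombic/tetragonal/cubic)."""
    return 1.0 / math.sqrt((h / a) ** 2 + (k / b) ** 2 + (l / c) ** 2)


def resolution_mask(d_values, highres=None, lowres=None):
    """True where highres <= d <= lowres; a limit left as None is not applied."""
    mask = []
    for d in d_values:
        keep = True
        if highres is not None and d < highres:
            keep = False  # beyond the high-resolution cutoff (d too small)
        if lowres is not None and d > lowres:
            keep = False  # beyond the low-resolution cutoff (d too large)
        mask.append(keep)
    return mask
```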
get_rfree_masks()

Get boolean masks for work and test (free) sets.

Returns:

  • work_mask (torch.Tensor or None) – Boolean tensor for work set (flag != 0).

  • test_mask (torch.Tensor or None) – Boolean tensor for test/free set (flag == 0). Both are None if no R-free flags are available.

Return type:

Tuple[Tensor | None, Tensor | None]

Examples

Separate work and test sets:

work_mask, test_mask = data.get_rfree_masks()
if work_mask is not None:
    F_work = data.F[work_mask]
    F_test = data.F[test_mask]
get_work_set()

Get structure factors for the work set (R-free flag != 0).

Returns:

  • F_work (torch.Tensor) – Structure factors for work set.

  • sigma_work (torch.Tensor or None) – Uncertainties for work set, or None if not available.

Return type:

Tuple[Tensor, Tensor | None]

Notes

If no R-free flags are available, returns the full dataset and issues a warning.

get_test_set()

Get structure factors for the test set (R-free flag == 0).

Returns:

  • F_test (torch.Tensor) – Structure factors for test/free set.

  • sigma_test (torch.Tensor or None) – Uncertainties for test set, or None if not available.

Raises:

ValueError – If no R-free flags are available.

Return type:

Tuple[Tensor, Tensor | None]

get_max_res()

Return maximum resolution (lowest d-spacing).

Returns:

Maximum resolution in Ångströms.

Return type:

float

get_min_res()

Return minimum resolution (highest d-spacing).

Returns:

Minimum resolution in Ångströms.

Return type:

float

__len__()

Return number of reflections.

Returns:

Number of reflections in the dataset.

Return type:

int

property d_min: float | None

Return maximum resolution (lowest d-spacing).

Returns:

Maximum resolution in Ångströms.

Return type:

float

__repr__()

Return string representation.

Returns:

Summary of reflection data including count, sources, and resolution.

Return type:

str

get_valid_mask()

Return the combined validity mask for all reflections.

This is the mask used to filter reflections in forward(). True indicates a valid (included) reflection, False indicates an excluded one.

Returns:

Boolean mask of shape (N,) where N is the total number of reflections. True = valid/included, False = invalid/excluded.

Return type:

torch.Tensor

Examples

Check validity:

mask = data.get_valid_mask()
print(f"{mask.sum().item()} of {len(mask)} reflections are valid")
data_indexed()

Return reflection data as indexed (filtered) tensors.

This method filters out invalid reflections and returns smaller tensors containing only valid data. Useful for operations that don’t support MaskedTensors or for writing output files.

Returns:

  • hkl (torch.Tensor) – Miller indices of shape (M, 3) where M is number of valid reflections.

  • F (torch.Tensor) – Structure factor amplitudes of shape (M,).

  • F_sigma (torch.Tensor or None) – Uncertainties of shape (M,) or None.

  • rfree_flags (torch.Tensor or None) – R-free flags of shape (M,) or None.

Return type:

Tuple[Tensor, Tensor, Tensor | None, Tensor | None]

See also

forward

Main method returning MaskedTensors.

Examples

Get indexed data for file writing:

hkl, F, sigma, rfree = data.data_indexed()
F_np = F.cpu().numpy()  # Safe for writing to files
__call__(mask=True, scale=True)

Return core reflection data with MaskedTensors for F and sigma.

F and F_sigma are returned as MaskedTensors which keep all reflections but mark invalid ones as masked. Aggregation operations (sum, mean, etc.) automatically skip masked values. HKL and rfree_flags remain regular tensors.

Parameters:
  • mask (bool, optional) – If True, apply current masks to F and sigma. Default is True.

  • scale (bool, optional) – If True, apply scaling to F and sigma before returning. Default is True.

Returns:

  • hkl (torch.Tensor) – Miller indices of shape (N, 3). Full size, unfiltered.

  • F (MaskedTensor) – Structure factor amplitudes of shape (N,) with invalid reflections masked.

  • F_sigma (MaskedTensor or None) – Uncertainties of shape (N,) with invalid reflections masked, or None.

  • rfree_flags (torch.Tensor or None) – R-free flags of shape (N,) or None. Full size, unfiltered. 1=work, 0=free.

Return type:

Tuple[Tensor, MaskedTensor, MaskedTensor, Tensor]

Notes

MaskedTensors:

  • Are PyTorch tensors with an associated boolean mask

  • Aggregations (sum, mean, etc.) skip masked values automatically

  • Element-wise operations preserve the mask

  • Use .get_data() and .get_mask() to access underlying data

  • Use .to_tensor(fill_value) to convert back to regular tensor

  • Note: MaskedTensor is in prototype stage in PyTorch

Loss functions and targets extract valid data from MaskedTensors before computation to work correctly with complex F_calc values.

Examples

Access reflection data with MaskedTensors:

hkl, F, sigma, rfree = data()
print(F.shape)  # Full shape
print(F.sum())  # Only sums valid (unmasked) values

# Access underlying data
valid_mask = F.get_mask()
F_values = F.get_data()[valid_mask]
data_fill_masked(mode='mean')

Return data tensors with missing or flagged reflections filled in.

Parameters:

mode (str, optional) – Fill strategy: ‘mean’ fills missing/flagged reflections with the mean of the present data; ‘zero’ fills them with zero. Default is ‘mean’.

__getitem__(key)

Index into the reflection dataset.

Parameters:

key (torch.Tensor) – Boolean mask or integer indices for selection.

Returns:

New ReflectionData object with selected reflections.

Return type:

ReflectionData

__select__(indices, op=None)

Select reflections by boolean mask or integer indices.

Iterates over all dataclass fields generically, so new tensor fields are handled automatically.

Parameters:
  • indices (torch.Tensor) – Boolean mask of shape (N,) or integer indices for selection.

  • op (str, optional) – Operation name for tracking purposes.

Returns:

New ReflectionData object with selected reflections.

Return type:

ReflectionData

sanitize_F()

Remove invalid values from structure factors.

Adds a mask to filter out NaN, Inf, and non-positive values from F and F_sigma.

check_all_data_types()
validate_hkl(hkl_ref)

Expand dataset to match a reference HKL set.

Reorders and expands the current dataset to match the reference HKL set. Reflections present in the reference but missing from this dataset are filled with placeholder values and masked out. This ensures all datasets aligned to the same reference have identical shapes and can be processed together without data loss from intersection operations.

Parameters:

hkl_ref (torch.Tensor) – Reference Miller indices tensor of shape (N, 3), dtype int32. This defines the canonical HKL ordering for all aligned datasets.

Returns:

Self, modified in-place with expanded arrays matching hkl_ref.

Return type:

ReflectionData

Notes

After calling this method:
  • self.hkl will equal hkl_ref exactly

  • All data arrays (F, F_sigma, rfree_flags, etc.) are reordered/expanded

  • Missing reflections are filled with 0 (or appropriate defaults)

  • A mask ‘hkl_present’ is added marking which reflections have real data

  • forward() will return MaskedTensors that skip missing reflections

This approach avoids the problem where intersecting many datasets with different outliers/missing reflections causes exponential data loss.

Examples

Align multiple datasets to a common HKL set:

reference_hkl = data1.hkl.clone()
data1.validate_hkl(reference_hkl)
data2.validate_hkl(reference_hkl)
# Now data1 and data2 have identical shapes
assert data1.hkl.shape == data2.hkl.shape
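The alignment described above amounts to a lookup from Miller index to row. A minimal pure-Python sketch, assuming a single F column stands in for all fields and that reflections absent from hkl_ref are dropped so the result matches hkl_ref exactly (align_to_reference is a hypothetical helper):

```python
def align_to_reference(hkl, F, hkl_ref, fill_value=0.0):
    """Reorder/expand (hkl, F) to follow hkl_ref; return (F_aligned, present_mask)."""
    lookup = {tuple(h): i for i, h in enumerate(hkl)}  # Miller index -> row
    F_aligned, present = [], []
    for h in hkl_ref:
        i = lookup.get(tuple(h))
        if i is None:
            F_aligned.append(fill_value)  # placeholder, marked as not present
            present.append(False)
        else:
            F_aligned.append(F[i])
            present.append(True)
    return F_aligned, present
```

Because every dataset is expanded to the same reference rather than intersected with the others, a reflection missing from one dataset only masks that dataset's entry.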
find_outliers(model, scaler, z_threshold=4.0)

Identify outlier reflections based on the log-ratio distribution.

Assumes that log(F_obs) - log(F_calc) is approximately normally distributed. Outliers are reflections where |log_ratio - mean| > z_threshold * std_dev.

Parameters:
  • model (ModelFT) – ModelFT object to compute structure factors.

  • scaler (Scaler) – Scaler object to scale calculated structure factors.

  • z_threshold (float, optional) – Z-score threshold to classify outliers. Default is 4.0.

Returns:

Boolean mask where True indicates outliers.

Return type:

torch.Tensor
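The log-ratio criterion can be sketched without a model or scaler by passing already-scaled calculated amplitudes directly. The helper log_ratio_outliers is hypothetical, and using the population standard deviation is an assumption.

```python
import math
import statistics


def log_ratio_outliers(F_obs, F_calc, z_threshold=4.0):
    """Flag reflections with |log_ratio - mean| > z_threshold * std.

    F_obs, F_calc: positive amplitudes, F_calc already scaled to F_obs.
    """
    log_ratio = [math.log(fo) - math.log(fc) for fo, fc in zip(F_obs, F_calc)]
    mean = statistics.fmean(log_ratio)
    std = statistics.pstdev(log_ratio)  # population standard deviation
    if std == 0.0:
        return [False] * len(log_ratio)  # no spread, nothing to flag
    return [abs(r - mean) > z_threshold * std for r in log_ratio]
```

Note that extreme outliers inflate the mean and standard deviation themselves; with very few reflections this can hide them, which is one reason robust estimators (median, MAD) are sometimes preferred.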

get_log_ratio(model, scaler)

Compute log-ratio between observed and calculated structure factors.

Parameters:
  • model (ModelFT) – ModelFT object to compute structure factors.

  • scaler (Scaler) – Scaler object to scale calculated structure factors.

Returns:

Log-ratio values: log(F_obs) - log(F_calc).

Return type:

torch.Tensor

get_outlier_statistics()

Get statistics about flagged outliers.

Returns:

Dictionary containing:
  • n_outliers : int

  • n_total : int

  • fraction_outliers : float

  • detection_params : dict or None

  • outlier_resolution_stats : dict (if resolution available)

Return type:

dict

unpack_one()

Unpack one level of provenance by returning the source this dataset was derived from.

Does not recurse through the full provenance chain and does not apply any flags.

Returns:

Parent source or self if no source.

Return type:

ReflectionData

flag_suspicious_sigma(z_threshold=5.0)

Flag sigma values that deviate significantly from the expected distribution.

Sigma values from a detector are expected to follow an approximately log-normal distribution. Values whose z-scores exceed the threshold are flagged as suspicious.

Parameters:

z_threshold (float, optional) – Z-score threshold for flagging outliers. Default is 5.0.

dump()

Dump all reflection data to console for debugging.

Prints type, shape, and device information for all attributes.

write_mtz(fname, fcalc=None, model_ft=None)

Write reflection data to MTZ file with optional map coefficients.

Parameters:
  • fname (str) – Output MTZ filename.

  • fcalc (torch.Tensor, optional) – Complex calculated structure factors of shape (N,). If provided, computes phases and map coefficients.

  • model_ft (ModelFT, optional) – ModelFT object to compute fcalc if not provided.

Notes

The MTZ file will contain canonical column names:
  • FP, SIGFP: Observed amplitudes and uncertainties

  • I, SIGI: Observed intensities and uncertainties (if available)

  • FreeR_flag: R-free test set flags

  • FWT, PHWT: 2mFo-DFc map coefficients (if fcalc provided)

  • DELFWT, PHDELWT: mFo-DFc map coefficients (if fcalc provided)

Map coefficients are computed as:
  • 2mFo-DFc: 2*Fo - Fc (filled to resolution limit)

  • mFo-DFc: Fo - Fc
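Taking the formulas above literally (no m or D weighting), the map coefficients can be sketched by attaching the model phase to each difference amplitude. This is a pure-Python illustration: the helper map_coefficients is hypothetical, and the resolution filling mentioned for 2mFo-DFc is not handled. A negative amplitude comes out as a positive amplitude with its phase shifted by 180°.

```python
import cmath
import math


def _phase_deg(z):
    """Phase of complex z in degrees, wrapped to [0, 360)."""
    return math.degrees(cmath.phase(z)) % 360.0


def map_coefficients(F_obs, F_calc):
    """Return (FWT, PHWT, DELFWT, PHDELWT) per reflection, phases in degrees.

    F_obs: observed amplitudes; F_calc: complex calculated structure factors.
    Uses 2Fo - |Fc| and Fo - |Fc| with the model phase (unweighted).
    """
    rows = []
    for fo, fc in zip(F_obs, F_calc):
        # exp(i * phi_calc); fall back to phase 0 when F_calc is exactly zero
        unit = fc / abs(fc) if abs(fc) > 0 else complex(1.0, 0.0)
        two_fo_fc = (2.0 * fo - abs(fc)) * unit
        fo_fc = (fo - abs(fc)) * unit
        rows.append((abs(two_fo_fc), _phase_deg(two_fo_fc),
                     abs(fo_fc), _phase_deg(fo_fc)))
    return rows
```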

Examples

Write MTZ with map coefficients:

data = ReflectionData().load_mtz('observed.mtz')
model = Model().load_pdb('model.pdb')
model_ft = ModelFT(model, data.cell, data.spacegroup)
fcalc = model_ft.forward(data.hkl)
data.write_mtz('output.mtz', fcalc=fcalc)
property centric

Get boolean mask for centric reflections (full size, unfiltered).

Calculates it if not already present. Returns unfiltered centric flags matching the full HKL array size, consistent with how forward() returns full-size arrays when using MaskedTensors.

Returns:

Boolean tensor of shape (N,) where N is total reflections. True indicates centric reflection, False indicates acentric.

Return type:

torch.Tensor or None
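A reflection is centric when some rotation of the point group maps h onto -h. A minimal sketch, assuming Miller indices transform as row vectors (h' = hR) and using the two P2₁ rotations as illustrative input; the helper is_centric is hypothetical. In centrosymmetric space groups the inversion makes every reflection centric.

```python
def is_centric(hkl, rotations):
    """True if some rotation R maps h onto -h, with h transformed as a row vector (hR).

    Note that (0, 0, 0) is trivially "centric" under the identity.
    """
    h = tuple(hkl)
    minus_h = tuple(-x for x in h)
    for R in rotations:
        h_rot = tuple(sum(h[i] * R[i][j] for i in range(3)) for j in range(3))
        if h_rot == minus_h:
            return True
    return False


# Rotation parts of the P2_1 operators (unique axis b): x,y,z and -x,y,-z
P21_ROTATIONS = [
    ((1, 0, 0), (0, 1, 0), (0, 0, 1)),
    ((-1, 0, 0), (0, 1, 0), (0, 0, -1)),
]
```

For P2₁ this reproduces the familiar result that the h0l zone is centric and all other reflections are acentric.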

calc_patterson(grid_size=None, grid_sampling=1)

Calculate Patterson map of the dataset.

The Patterson function P(u,v,w) = Σ|F(hkl)|² exp(-2πi(hu+kv+lw)) is computed via inverse FFT of F². Data is expanded to P1 symmetry using only observed reflections (no filling of missing data).

Parameters:
  • grid_size (tuple of int, optional) – Grid dimensions (Nx, Ny, Nz). If None, automatically determined from unit cell and resolution.

  • grid_sampling (float, optional) – Sampling interval for the grid. Default is 1, which samples the grid twice as finely as normal for the given resolution.

Returns:

Real-valued Patterson map of shape (Nx, Ny, Nz). Origin is at grid position [0, 0, 0].

Return type:

torch.Tensor
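To illustrate the formula, a one-dimensional analogue can be evaluated by direct summation instead of an FFT. The sketch is hypothetical (not torchref code) and assumes the input lists every index h together with its Friedel mate -h, so the imaginary parts cancel and only cosines remain; the 1/V prefactor is omitted, as in the formula above.

```python
import math


def patterson_1d(amplitudes, n_grid):
    """Evaluate P(u) = sum_h |F_h|^2 cos(2*pi*h*u) at u = g / n_grid, g = 0..n_grid-1.

    amplitudes: dict mapping integer h -> |F_h|, including both h and -h.
    """
    P = []
    for g in range(n_grid):
        u = g / n_grid
        P.append(sum(F * F * math.cos(2.0 * math.pi * h * u)
                     for h, F in amplitudes.items()))
    return P
```

For two equal point atoms at x = 0 and x = 0.25, |F_h|² = 2 + 2·cos(πh/2), and the map peaks at the origin and at the interatomic vectors u = ±0.25 (grid points u = 0.25 and 0.75).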

possible_hkl()

Generate all possible HKL indices within the resolution limit.

Returns:

Tensor of shape (M, 3) containing all possible Miller indices within the resolution limit defined by self.resolution.

Return type:

torch.Tensor
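For an orthogonal cell, the search range is bounded by |h| <= a/d_min (and likewise for k and l), since h²/a² alone cannot exceed 1/d_min². A hypothetical pure-Python sketch that ignores space-group symmetry and systematic absences (plain P1, both Friedel mates included):

```python
import math


def possible_hkl_orthogonal(a, b, c, d_min):
    """All (h, k, l) != (0, 0, 0) with d >= d_min for a cell with 90° angles."""
    # Small tolerance guards against x / d_min landing just below an integer
    h_max, k_max, l_max = (int(math.floor(x / d_min + 1e-9)) for x in (a, b, c))
    # Small relative tolerance so reflections exactly at d_min are kept
    inv_d2_max = (1.0 / d_min ** 2) * (1.0 + 1e-9)
    out = []
    for h in range(-h_max, h_max + 1):
        for k in range(-k_max, k_max + 1):
            for l in range(-l_max, l_max + 1):
                if (h, k, l) == (0, 0, 0):
                    continue
                inv_d2 = (h / a) ** 2 + (k / b) ** 2 + (l / c) ** 2
                if inv_d2 <= inv_d2_max:
                    out.append((h, k, l))
    return out
```

In P1, the reflections that a fill-style operation would add are then simply the set difference between these indices and the observed ones.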

remap(new_hkl, index_mapping, phase_shifts=None, spacegroup=None, op_name='remap')

Create new ReflectionData with remapped HKL set and data.

This is the core method for index-based transformations of reflection data. It handles remapping all tensor fields based on an index mapping, with support for missing reflections (indicated by -1 in index_mapping).

Parameters:
  • new_hkl (torch.Tensor, shape (M, 3)) – New Miller indices.

  • index_mapping (torch.Tensor, shape (M,), dtype int64) – Maps new indices to original: new[i] = old[index_mapping[i]]. Values of -1 indicate missing reflections (filled with defaults).

  • phase_shifts (torch.Tensor, optional, shape (M,)) – Phase offsets to apply (e.g., from symmetry translations).

  • spacegroup (str, int, gemmi.SpaceGroup, or None) – New spacegroup. If None, keeps original.

  • op_name (str) – Operation name for provenance tracking.

Returns:

New object with remapped data. Missing reflections get:
  • 0.0 for F, I, phase, fom

  • 1.0 for F_sigma, I_sigma (conservative uncertainty)

  • True for masks[‘missing’]

Return type:

ReflectionData
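The index_mapping convention can be illustrated on a single field with plain Python lists. The helper remap_field is hypothetical; the default values follow the list above (0.0 for amplitudes, 1.0 for sigmas), and phase shifts are not handled.

```python
def remap_field(values, index_mapping, default):
    """new[i] = values[index_mapping[i]]; entries of -1 get `default` and are flagged missing."""
    new_values, missing = [], []
    for j in index_mapping:
        if j == -1:
            new_values.append(default)  # reflection absent from the original data
            missing.append(True)
        else:
            new_values.append(values[j])
            missing.append(False)
    return new_values, missing
```

Applying the same mapping to every field (with a per-field default) keeps all arrays aligned with new_hkl.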

fill(d_min=None)

Fill missing reflections within resolution limit.

Generates all possible reflections for the current spacegroup within the resolution limit, identifies which are missing, and creates a complete dataset. Missing reflections are filled with default values.

Parameters:

d_min (float, optional) – High resolution limit in Angstroms. If None, uses the highest resolution (smallest d-spacing) in the current dataset.

Returns:

New ReflectionData with a complete set of reflections. Missing reflections have:
  • F, I, phase, fom = 0.0

  • F_sigma, I_sigma = 1.0

  • masks[‘missing’] = True

Return type:

ReflectionData

Examples

Fill to completeness:

data = ReflectionData().load_mtz('data.mtz')
data_filled = data.fill(d_min=2.0)
print(f"Original: {len(data)}, Filled: {len(data_filled)}")
expand_to_p1(include_friedel=True, remove_absences=True)

Expand reflection data from asymmetric unit to P1.

Applies all symmetry operations from the current space group to generate all symmetry-equivalent reflections. Returns a NEW ReflectionData object with expanded reflections; does not modify self.

Parameters:
  • include_friedel (bool, default True) – Include Friedel mates (-h, -k, -l). For normal (non-anomalous) scattering, Friedel pairs have identical amplitudes.

  • remove_absences (bool, default True) – Remove systematically absent reflections from output.

Returns:

New ReflectionData object with:
  • All symmetry-equivalent reflections

  • spacegroup set to P1

  • Expanded tensor fields (F, F_sigma, I, I_sigma, phase, fom, rfree_flags)

  • Recalculated resolution values

  • Provenance tracking via source and last_op attributes

Return type:

ReflectionData

Notes

The expansion handles tensor fields as follows:

  • hkl: Symmetry operations applied, Friedel mates added, duplicates removed

  • F, F_sigma, I, I_sigma, fom, rfree_flags: Indexed from original

  • phase: Indexed + phase shift from translation component

  • resolution: Recalculated from expanded hkl + cell

  • bin_indices: Cleared (invalidated by expansion)
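The phase-shift rule can be made concrete for a single reflection. This hypothetical sketch assumes the convention F(h) = Σ f·exp(+2πi h·x), under which an operator x' = Rx + t gives F(hR) = F(h)·exp(-2πi h·t); the P2₁ operators (unique axis b) are written out by hand as example input, and when two operators produce the same index the first value is kept.

```python
# Example operators for P2_1 (unique axis b): x,y,z and -x,y+1/2,-z.
# Each entry is (R, t): a 3x3 rotation and a fractional translation.
P21_OPS = [
    (((1, 0, 0), (0, 1, 0), (0, 0, 1)), (0.0, 0.0, 0.0)),
    (((-1, 0, 0), (0, 1, 0), (0, 0, -1)), (0.0, 0.5, 0.0)),
]


def expand_reflection(hkl, F, phase_deg, ops, include_friedel=True):
    """Return {hkl': (F, phase')} for all symmetry equivalents of one reflection.

    phase' = phase - 360 * (h . t); Friedel mates get -phase'
    (valid only when anomalous scattering is ignored).
    """
    out = {}
    for R, t in ops:
        # Miller indices transform as a row vector: h' = h R
        h_new = tuple(sum(hkl[i] * R[i][j] for i in range(3)) for j in range(3))
        shift = 360.0 * sum(hkl[i] * t[i] for i in range(3))
        ph = (phase_deg - shift) % 360.0
        out.setdefault(h_new, (F, ph))
        if include_friedel:
            out.setdefault(tuple(-x for x in h_new), (F, (-ph) % 360.0))
    return out
```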

Examples

Expand to P1 for calculations:

# Load data in original space group
data = ReflectionData().load_mtz('data.mtz')
print(f"Original: {len(data)} reflections, {data.spacegroup}")

# Expand to P1
data_p1 = data.expand_to_p1()
print(f"Expanded: {len(data_p1)} reflections, {data_p1.spacegroup}")

# Without Friedel mates (for anomalous data)
data_p1_anom = data.expand_to_p1(include_friedel=False)
reduce_to_spacegroup(spacegroup, include_friedel=True, aggregation='mean')

Reduce P1 reflection data to asymmetric unit of a target spacegroup.

This is the inverse of expand_to_p1(). Takes reflection data in P1 and merges symmetry-equivalent reflections into single ASU reflections using the specified aggregation function.

Parameters:
  • spacegroup (str, int, or gemmi.SpaceGroup) – Target space group specification.

  • include_friedel (bool, default True) – If True, also merge Friedel mates when reducing.

  • aggregation (str, default 'mean') – Aggregation function for merging equivalent reflections: ‘mean’ averages values (default, good for amplitudes); ‘sum’ sums values; ‘first’ takes the first valid value (no averaging).

Returns:

New ReflectionData with merged reflections in the target spacegroup.

Return type:

ReflectionData

Notes

Field handling during reduction:

  • F, I: Aggregated using specified function

  • F_sigma, I_sigma: Propagated as sqrt(sum(sigma^2) / n) for ‘mean’

  • phase: Aggregated via complex averaging (handles phase wrapping)

  • fom: Weighted by amplitude during phase averaging

  • rfree_flags: OR operation (True if any equivalent is True)
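For the ‘mean’ aggregation, the rules above can be sketched for one group of equivalents. Assumptions: the helper merge_equivalents is hypothetical, phases are averaged as unweighted unit vectors (the amplitude weighting mentioned for fom is left out), and sigma follows the sqrt(sum(sigma^2) / n) rule exactly as stated.

```python
import cmath
import math


def merge_equivalents(F_values, sigmas, phases_deg):
    """Merge one group of symmetry-equivalent observations ('mean' aggregation).

    Returns (F_mean, sigma, phase_deg) with phase_deg in [0, 360).
    """
    n = len(F_values)
    F_mean = sum(F_values) / n
    sigma = math.sqrt(sum(s * s for s in sigmas) / n)
    # Average unit phase vectors so that 350° and 10° average to 0°, not 180°
    vec = sum(cmath.exp(1j * math.radians(p)) for p in phases_deg) / n
    phase = math.degrees(cmath.phase(vec)) % 360.0
    return F_mean, sigma, phase
```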

Examples

Reduce symmetry-expanded data:

# Expand to P1 for calculations, then reduce back
data_p1 = data.expand_to_p1()
# ... work with data_p1.F in P1 ...
data_merged = data_p1.reduce_to_spacegroup('P21')

# Reduce with sum instead of mean
data_summed = data_p1.reduce_to_spacegroup('P21', aggregation='sum')
canonicalize(include_friedel=True)

Return new ReflectionData with HKL in standard CCP4 ASU form.

Remaps all Miller indices to the canonical CCP4 asymmetric unit representative using gemmi.ReciprocalAsu, adjusts phases accordingly, and sorts reflections lexicographically by (h, k, l).

Parameters:

include_friedel (bool, default True) – Whether Friedel mates are considered equivalent.

Returns:

New object with canonicalized, sorted Miller indices.

Return type:

ReflectionData

get_scattering_vectors()

Get scattering vectors (s-vectors) from hkl and cell.

The s-vector for a reflection hkl is defined as:

s = B* @ hkl

where B* is the reciprocal basis matrix.

Returns:

s_vectors – Reciprocal space vectors in Angstroms^-1, shape (N, 3).

Return type:

torch.Tensor

Raises:

ValueError – If hkl or cell is not available.
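As a sketch of s = B* @ hkl, the reciprocal basis can be built from the six cell parameters via a* = (b × c)/V and its cyclic permutations. The orthogonalization convention (a along x, b in the xy plane) and the helper names are assumptions; torchref may orient the basis differently, but |s| = 1/d does not depend on that choice.

```python
import math


def reciprocal_basis(a, b, c, alpha, beta, gamma):
    """Return (a*, b*, c*) in 1/Å for a cell (lengths in Å, angles in degrees)."""
    al, be, ga = (math.radians(x) for x in (alpha, beta, gamma))
    # Direct basis: a along x, b in the xy plane, c completing the cell
    va = (a, 0.0, 0.0)
    vb = (b * math.cos(ga), b * math.sin(ga), 0.0)
    cx = c * math.cos(be)
    cy = c * (math.cos(al) - math.cos(be) * math.cos(ga)) / math.sin(ga)
    vc = (cx, cy, math.sqrt(c * c - cx * cx - cy * cy))

    def cross(u, v):
        return (u[1] * v[2] - u[2] * v[1],
                u[2] * v[0] - u[0] * v[2],
                u[0] * v[1] - u[1] * v[0])

    volume = sum(x * y for x, y in zip(va, cross(vb, vc)))
    # a* = (b x c) / V etc., so that a_i . a*_j = delta_ij
    return tuple(tuple(x / volume for x in cross(u, v))
                 for u, v in ((vb, vc), (vc, va), (va, vb)))


def scattering_vector(hkl, recip):
    """s = h a* + k b* + l c*, with |s| = 1/d."""
    return tuple(sum(hkl[i] * recip[i][j] for i in range(3)) for j in range(3))
```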

get_radial_shells(n_shells=20, d_min=None, d_max=None)

Create uniform radial shells in 1/d space for normalization.

This creates shells with uniform spacing in reciprocal space (1/d), unlike get_bins(), which creates equal-count bins.

Parameters:
  • n_shells (int) – Number of radial shells. Default is 20.

  • d_min (float, optional) – High resolution limit in Angstroms. If None, uses dataset minimum.

  • d_max (float, optional) – Low resolution limit in Angstroms. If None, uses dataset maximum.

Returns:

  • shell_edges (torch.Tensor) – Shell boundaries in Angstroms^-1, shape (n_shells+1,).

  • shell_centers (torch.Tensor) – Shell centers in Angstroms^-1, shape (n_shells,).

  • shell_indices (torch.Tensor) – Shell index for each reflection, shape (N,). Values -1 for out-of-range.

Return type:

Tuple[Tensor, Tensor, Tensor]
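A minimal NumPy sketch of uniform 1/d shells with -1 marking out-of-range reflections. radial_shells is a hypothetical stand-in that works on d-spacings directly; it is not the torchref method, and whether the real method puts a reflection exactly at d_min into the last shell is an assumption here.

```python
import numpy as np


def radial_shells(d, n_shells=20, d_min=None, d_max=None):
    """Uniform shells in 1/d (hypothetical helper mirroring get_radial_shells)."""
    d = np.asarray(d, dtype=float)
    d_min = d.min() if d_min is None else d_min
    d_max = d.max() if d_max is None else d_max
    s = 1.0 / d                                   # 1/d in Angstroms^-1
    edges = np.linspace(1.0 / d_max, 1.0 / d_min, n_shells + 1)
    centers = 0.5 * (edges[:-1] + edges[1:])
    idx = np.searchsorted(edges, s, side="right") - 1
    idx[s == edges[-1]] = n_shells - 1            # keep d_min in the last shell
    idx[(s < edges[0]) | (s > edges[-1])] = -1    # out of range
    return edges, centers, idx
```

Equal spacing in 1/d means high-resolution shells hold many more reflections than low-resolution ones, which is the trade-off against equal-count bins.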

fit_anisotropy(n_shells=20, d_min=None, d_max=None, n_iterations=100, verbose=None)[source]

Fit anisotropy correction parameters to minimize CV within shells.

Optimizes U parameters so that corrected F² values have minimal coefficient of variation within each resolution shell.
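The quantity being minimized can be sketched as follows: the coefficient of variation (standard deviation divided by mean) of F² within each shell, averaged over shells. This is an illustrative objective only; shell_cv_loss is hypothetical, and averaging (rather than summing or weighting) the per-shell values is an assumption.

```python
import numpy as np


def shell_cv_loss(F2, shell_idx, n_shells):
    """Mean coefficient of variation (std / mean) of F^2 across shells."""
    cvs = []
    for s in range(n_shells):
        vals = F2[shell_idx == s]
        if vals.size < 2:
            continue  # too few reflections for a spread estimate
        cvs.append(vals.std() / vals.mean())
    return float(np.mean(cvs))
```

An anisotropic fall-off makes F² depend on direction within a shell, which inflates this spread; a good U correction flattens it.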

Parameters:
  • n_shells (int) – Number of resolution shells for variance calculation.

  • d_min (float, optional) – High resolution limit in Angstroms. If None, uses dataset minimum.

  • d_max (float, optional) – Low resolution limit in Angstroms. If None, uses dataset maximum.

  • n_iterations (int) – Number of optimization iterations.

  • verbose (bool, optional) – Print progress. If None, uses self.verbose.

Returns:

U – Fitted anisotropy parameters [u11, u22, u33, u12, u13, u23], shape (6,). Also stored in self.U_aniso.

Return type:

torch.Tensor

Raises:

ValueError – If no amplitude data is available.

setup_anisotropy(U_aniso=None)[source]

Set up anisotropy correction parameters.

Parameters:

U_aniso (torch.Tensor, optional) – Anisotropic parameters [u11, u22, u33, u12, u13, u23], shape (6,). If None, the parameters are initialized as zeros.

apply_anisotropy_correction(U_aniso=None)[source]

Apply anisotropy correction to F² values.

Parameters:

U_aniso (torch.Tensor, optional) – Anisotropic parameters [u11, u22, u33, u12, u13, u23], shape (6,). If None, uses self.U_aniso (must have called fit_anisotropy first).

Returns:

  • F_corrected (torch.Tensor) – Anisotropy-corrected F values, shape (N,).

  • sigma_F_corrected (torch.Tensor) – Uncertainties of corrected F values, shape (N,).

Raises:

ValueError – If no U parameters available and none provided.

Return type:

Tuple[Tensor, Tensor]
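A NumPy sketch of applying a six-parameter U correction to amplitudes and their sigmas. The sign and scale convention here is an assumption: it removes an anisotropic Debye-Waller factor exp(-2π² sᵀUs) on amplitudes, with s in Å⁻¹ and U in Å². The actual torchref convention may differ. apply_aniso is a hypothetical helper.

```python
import numpy as np


def apply_aniso(F, sigF, s, U6):
    """Scale F and sigF by exp(2*pi^2 * s^T U s) (assumed convention)."""
    u11, u22, u33, u12, u13, u23 = U6
    U = np.array([[u11, u12, u13],
                  [u12, u22, u23],
                  [u13, u23, u33]])
    quad = np.einsum("ni,ij,nj->n", s, U, s)   # s^T U s per reflection
    factor = np.exp(2.0 * np.pi ** 2 * quad)
    # Sigmas scale linearly with F, so their relative error is unchanged
    return F * factor, sigF * factor
```

With isotropic U = u·I and B = 8π²u, the factor reduces to exp(B|s|²/4), the familiar isotropic B-factor correction for amplitudes.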

compute_e_values(n_shells=20, d_min=None, d_max=None, apply_anisotropy=True, fit_anisotropy=True, verbose=None)[source]

Compute E-values with optional anisotropy correction.

E-values are normalized structure factors where <E²> = 1 within each resolution shell. Anisotropy correction can be applied first to account for directional variation in diffraction.
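A minimal sketch of the normalization <E²> = 1 per shell, given amplitudes and shell indices such as those from get_radial_shells. e_values is hypothetical. It ignores the statistical epsilon factors and centric/acentric distinctions that a full implementation may apply.

```python
import numpy as np


def e_values(F, shell_idx, n_shells):
    """Normalize amplitudes so that <E^2> = 1 in every shell (sketch)."""
    F2 = np.asarray(F, dtype=float) ** 2
    shell_idx = np.asarray(shell_idx)
    E2 = np.full_like(F2, np.nan)          # out-of-range reflections stay NaN
    for s in range(n_shells):
        sel = shell_idx == s
        if sel.any():
            E2[sel] = F2[sel] / F2[sel].mean()
    return np.sqrt(E2), E2
```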

Parameters:
  • n_shells (int) – Number of resolution shells for normalization.

  • d_min (float, optional) – High resolution limit in Angstroms. If None, uses dataset minimum.

  • d_max (float, optional) – Low resolution limit in Angstroms. If None, uses dataset maximum.

  • apply_anisotropy (bool) – If True, apply anisotropy correction before E-value calculation.

  • fit_anisotropy (bool) – If True and apply_anisotropy is True, fit anisotropy parameters. If False and apply_anisotropy is True, uses existing self.U_aniso.

  • verbose (bool, optional) – Print progress. If None, uses self.verbose.

Returns:

E – E-values, shape (N,). Also stored in self.E. self.E_squared is also populated with E² values.

Return type:

torch.Tensor

Raises:

ValueError – If no amplitude data is available.

Examples

Compute E-values with automatic anisotropy correction:

from torchref.io import ReflectionData
data = ReflectionData().load_mtz('data.mtz')
E = data.compute_e_values(n_shells=30)
print(f"E-values: mean={E.mean():.3f}, std={E.std():.3f}")

Compute E-values without anisotropy correction:

E = data.compute_e_values(apply_anisotropy=False)
setup_scale(scale=None)[source]

Set overall scale factor, parametrized in log space.

Parameters:

scale (float, optional) – If provided, sets the scale factor directly. If None, computes scale to make mean F equal to 1.0.

Returns:

The scale factor applied.

Return type:

float

get_corrected_data()[source]

Get scaled amplitude and uncertainty tensors.

Applies the current scale factor (from self.log_scale) to F and F_sigma.

Returns:

Scaled F and F_sigma tensors. If log_scale is None, returns the unscaled tensors.

Return type:

Tuple[torch.Tensor, torch.Tensor]
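The log-space scale used by setup_scale and get_corrected_data can be sketched in plain Python. Storing log(scale) instead of scale keeps the applied factor positive under unconstrained gradient optimization. Both helpers are hypothetical illustrations of the documented behavior, not the torchref code.

```python
import math


def setup_log_scale(F, scale=None):
    """Return log(scale); if scale is None, choose it so mean(scale * F) == 1."""
    if scale is None:
        scale = len(F) / sum(F)
    return math.log(scale)


def corrected_data(F, F_sigma, log_scale):
    """Apply exp(log_scale) to F and F_sigma; pass through if log_scale is None."""
    if log_scale is None:
        return F, F_sigma
    k = math.exp(log_scale)
    return [k * f for f in F], [k * s for s in F_sigma]
```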

parameters()[source]

Get list of learnable parameters for optimization.

Returns:

Iterable of torch Parameters to optimize; includes log_scale and U_aniso if they have been set.

Return type:

iter[Parameter]

__init__(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, _centric=None, _n_bins=None, _FrenchWilson=None, source=None, dataset=None, last_op=None, reader=None)
class torchref.io.DatasetCollection(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, _datasets=<factory>, _dataset_order=<factory>, _reference_dataset=None, _common_hkl=None, _cell=None, _spacegroup=None, _resolution=None, _scale_factors=<factory>)[source]

Bases: CrystalDataset

Container for multiple related crystal datasets.

All datasets share a common HKL set for efficient computation. Datasets are aligned to a reference dataset (by default, the first one added), and reflections missing from the other datasets are masked out.
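The alignment described above can be sketched in plain Python: each dataset's values are placed onto the reference HKL list, and a mask records which reflections were actually measured. align_to_reference is hypothetical. Dropping reflections that are absent from the reference is an assumption about how the common HKL set is built.

```python
import math


def align_to_reference(ref_hkl, hkl, values):
    """Place `values` (indexed by `hkl`) onto `ref_hkl`; NaN + False where missing."""
    lookup = {tuple(h): v for h, v in zip(hkl, values)}
    aligned, mask = [], []
    for h in ref_hkl:
        v = lookup.get(tuple(h))
        aligned.append(math.nan if v is None else v)
        mask.append(v is not None)
    return aligned, mask
```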

Parameters:
  • verbose (int, optional) – Verbosity level (0=silent, 1=normal, 2=debug). Default is 1.

  • device (str, optional) – Device for tensors (‘cpu’, ‘cuda’, etc.). Defaults to the configured device.current.

hkl

Common HKL set for all datasets.

Type:

torch.Tensor

n_datasets

Number of datasets in collection.

Type:

int

Examples

from torchref.io import DatasetCollection, ReflectionData

collection = DatasetCollection(device='cuda')

native = ReflectionData().load_mtz('native.mtz')
derivative = ReflectionData().load_mtz('derivative.mtz')

collection.add_dataset('native', native, set_as_reference=True)
collection.add_dataset('derivative', derivative)

for name, dataset in collection:
    print(f"{name}: {len(dataset)} reflections")

# Access by name
native_F = collection['native'].F
add_dataset(name, dataset, set_as_reference=False)[source]

Add a dataset to the collection.

Parameters:
  • name (str) – Identifier for this dataset.

  • dataset (ReflectionData) – The dataset to add.

  • set_as_reference (bool, optional) – If True, use this dataset’s HKL as the reference. Default is False, but the first dataset added automatically becomes the reference.

Returns:

Self, for method chaining.

Return type:

DatasetCollection

Raises:

ValueError – If a dataset with the same name already exists.

Examples

collection = DatasetCollection()
collection.add_dataset('native', native_data, set_as_reference=True)
collection.add_dataset('derivative', derivative_data)
property hkl: Tensor | None

Common HKL set for all datasets.

property datasets: Dict[str, ReflectionData]

Access all datasets as a dictionary.

property n_datasets: int

Number of datasets in collection.

property reference_dataset: str | None

Name of the reference dataset.

property spacegroup: str | None

Space group of the reference dataset.

__getitem__(name)[source]

Get dataset by name.

Parameters:

name (str) – Name of the dataset.

Returns:

The requested dataset.

Return type:

ReflectionData

Raises:

KeyError – If dataset name not found.

__iter__()[source]

Iterate over (name, dataset) pairs in order of addition.

Yields:

tuple of (str, ReflectionData) – Name and dataset for each dataset in collection.

__len__()[source]

Number of reflections in common HKL set.

__contains__(name)[source]

Check if dataset exists in collection.

__call__(mask=True)[source]

Return the data of all datasets, scaled if scale factors are set.

Parameters:

mask (bool, optional) – Whether to apply masking. Default is True.

Returns:

Dictionary mapping name to (hkl, F, F_sigma, rfree) tuples.

Return type:

dict

scale()[source]

Scale all datasets to a common reference scale.

This method optimizes the scaling parameters of all non-reference datasets to minimize the mean squared error between their structure factors and those of the reference dataset, correcting for both overall scale differences and anisotropy. It uses the L-BFGS optimizer with a strong Wolfe line search to refine the scaling parameters iteratively.

Returns:

The collection instance, for method chaining.

Return type:

DatasetCollection

Raises:

ValueError – If no reference dataset has been set, or if the collection contains only the reference dataset. At least two datasets are required.

Notes

A reference dataset must be set before calling this method. The first dataset added becomes the reference automatically; pass set_as_reference=True to add_dataset to choose a different one. All datasets except the reference have their scaling parameters optimized.
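To illustrate what "overall scale" means, the closed-form least-squares scale between a dataset and the reference is shown below. This is a swapped-in technique for illustration: the documented method instead refines scale and anisotropy together with L-BFGS. least_squares_scale is hypothetical, and the optional mask mirrors the masking of reflections missing from one dataset.

```python
def least_squares_scale(F_ref, F, mask=None):
    """Scale k minimizing sum((F_ref - k * F)^2) over unmasked reflections."""
    if mask is None:
        mask = [True] * len(F)
    pairs = [(fr, f) for fr, f, m in zip(F_ref, F, mask) if m]
    num = sum(fr * f for fr, f in pairs)
    den = sum(f * f for _, f in pairs)
    return num / den
```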

keys()[source]

Return list of dataset names.

values()[source]

Return list of datasets.

items()[source]

Return list of (name, dataset) tuples.

__init__(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, _datasets=<factory>, _dataset_order=<factory>, _reference_dataset=None, _common_hkl=None, _cell=None, _spacegroup=None, _resolution=None, _scale_factors=<factory>)
get(name, default=None)[source]

Get dataset by name with default fallback.

__repr__()[source]

String representation of collection.

class torchref.io.FcalcDataset(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, fcalc=None, fcalc_amp=None, fcalc_phase=None)[source]

Bases: CrystalDataset

Dataset for storing calculated structure factors.

Provides a lightweight container for Fcalc values with:

  • Cell and spacegroup information (using torchref.symmetry types)

  • HKL indices and resolution

  • Complex Fcalc with amplitude/phase decomposition

  • MTZ export capability

This class inherits from CrystalDataset and overrides the spacegroup field to store torchref.symmetry.SpaceGroup instead of gemmi.SpaceGroup.

Parameters:
  • hkl (torch.Tensor, optional) – Miller indices of shape (N, 3).

  • resolution (torch.Tensor, optional) – Resolution per reflection of shape (N,).

  • cell (Cell, optional) – Unit cell object.

  • spacegroup (SpaceGroup, optional) – Space group object (torchref.symmetry.SpaceGroup).

  • fcalc (torch.Tensor, optional) – Complex structure factors of shape (N,).

  • fcalc_amp (torch.Tensor, optional) – Amplitudes |Fcalc| of shape (N,).

  • fcalc_phase (torch.Tensor, optional) – Phases in radians of shape (N,).

  • device (torch.device) – Device for tensors.

Examples

Create from cell and resolution:

import torch
from torchref.io.datasets import FcalcDataset

dataset = FcalcDataset.from_cell_and_resolution(
    cell=[50.0, 60.0, 70.0, 90.0, 90.0, 90.0],
    spacegroup='P212121',
    d_min=2.0,
)

# Set Fcalc values (complex tensor)
fcalc = torch.randn(len(dataset), dtype=torch.complex64)
dataset.set_fcalc(fcalc)

# Write to MTZ
dataset.write_mtz('output.mtz')
spacegroup: SpaceGroup | None = None
fcalc: Tensor | None = None
fcalc_amp: Tensor | None = None
fcalc_phase: Tensor | None = None
static from_cell_and_resolution(cell, spacegroup, d_min=2.0, d_max=None, device=device(type='cpu'), dtype=torch.float32)[source]

Create FcalcDataset with HKL generated to given resolution.

Parameters:
  • cell (torch.Tensor, list, or Cell) – Unit cell [a, b, c, alpha, beta, gamma] or Cell object.

  • spacegroup (SpaceGroupLike) – Space group (str, int, gemmi.SpaceGroup, or torchref.symmetry.SpaceGroup).

  • d_min (float, optional) – High resolution limit in Angstroms. Default is 2.0.

  • d_max (float, optional) – Low resolution limit in Angstroms. If provided, reflections with d-spacing > d_max are removed.

  • device (torch.device) – Target device.

  • dtype (torch.dtype) – Float dtype for tensors.

Returns:

New dataset with HKL and resolution populated.

Return type:

FcalcDataset
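The resolution cutoff can be sketched for an orthorhombic cell, where 1/d² = (h/a)² + (k/b)² + (l/c)². This simplified, hypothetical helper enumerates the full P1 sphere. It does not reduce to an asymmetric unit or remove systematic absences, both of which a space-group-aware implementation would need.

```python
import itertools
import math


def generate_hkl_orthorhombic(a, b, c, d_min, d_max=None):
    """All (hkl, d) with d_min <= d (<= d_max), orthorhombic cells only."""
    hmax, kmax, lmax = (int(math.floor(x / d_min)) for x in (a, b, c))
    out = []
    for h, k, l in itertools.product(range(-hmax, hmax + 1),
                                     range(-kmax, kmax + 1),
                                     range(-lmax, lmax + 1)):
        if (h, k, l) == (0, 0, 0):
            continue
        inv_d2 = (h / a) ** 2 + (k / b) ** 2 + (l / c) ** 2
        if inv_d2 > 1.0 / d_min ** 2:
            continue  # beyond the high-resolution limit
        if d_max is not None and inv_d2 < 1.0 / d_max ** 2:
            continue  # d-spacing larger than d_max
        out.append(((h, k, l), 1.0 / math.sqrt(inv_d2)))
    return out
```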

Examples

from torchref.io.datasets import FcalcDataset
from torchref.symmetry import Cell, SpaceGroup

cell = Cell([50.0, 60.0, 70.0, 90.0, 90.0, 90.0])
sg = SpaceGroup('P212121')
dataset = FcalcDataset.from_cell_and_resolution(
    cell=cell, spacegroup=sg, d_min=2.0,
)
print(f"Generated {len(dataset)} reflections")
set_fcalc(fcalc)[source]

Assign complex Fcalc values.

Automatically computes amplitude and phase from complex values.

Parameters:

fcalc (torch.Tensor) – Complex structure factors with shape (N,).

Raises:

ValueError – If fcalc length doesn’t match HKL length.
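The amplitude/phase decomposition and the length check can be sketched with Python's cmath on plain complex numbers. decompose is a hypothetical helper; the method itself works on complex torch tensors.

```python
import cmath


def decompose(fcalc, n_hkl):
    """Split complex Fcalc into amplitudes and phases (radians)."""
    if len(fcalc) != n_hkl:
        raise ValueError(f"fcalc has {len(fcalc)} values but there are {n_hkl} reflections")
    amp = [abs(z) for z in fcalc]             # |Fcalc|
    phase = [cmath.phase(z) for z in fcalc]   # in (-pi, pi]
    return amp, phase
```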

Examples

import torch

# Create complex Fcalc values (assumes `dataset` already has HKL populated)
fcalc = torch.randn(len(dataset), dtype=torch.complex64)
dataset.set_fcalc(fcalc)

print(dataset.fcalc_amp[:5])   # Amplitudes
print(dataset.fcalc_phase[:5]) # Phases in radians
write_mtz(filepath)[source]

Write Fcalc to MTZ file.

Parameters:

filepath (str) – Output MTZ filename.

Raises:

ValueError – If no Fcalc values have been set.

Examples

dataset.set_fcalc(fcalc_values)
dataset.write_mtz('calculated.mtz')
write_mtz_as_fobs(filepath, sigma_frac=0.05, f_column='F-obs', sigf_column='SIGF-obs', phase_column='PHIF-model')[source]

Write Fcalc to MTZ as if it were observed data (F-obs columns).

Useful for creating simulated “experimental” MTZ files that can be read back by ReflectionData.load_mtz() as observed amplitudes.

Parameters:
  • filepath (str) – Output MTZ filename.

  • sigma_frac (float, optional) – Sigma as a fraction of |F|. Default is 0.05 (5%).

  • f_column (str, optional) – Column name for amplitudes. Default is ‘F-obs’.

  • sigf_column (str, optional) – Column name for sigma. Default is ‘SIGF-obs’.

  • phase_column (str, optional) – Column name for model phases. Default is ‘PHIF-model’.

Examples

dataset.set_fcalc(fcalc_values)
dataset.write_mtz_as_fobs('simulated_obs.mtz', sigma_frac=0.05)
__repr__()[source]

String representation of dataset.

property spacegroup_name: str | None

Get space group name as string (short form, e.g., ‘P212121’).

property spacegroup_hm: str | None

Get space group Hermann-Mauguin name with spaces (e.g., ‘P 21 21 21’).

property spacegroup_number: int | None

Get space group number (1-230).

__init__(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, fcalc=None, fcalc_amp=None, fcalc_phase=None)
class torchref.io.MTZReader(verbose=0, column_names=None)[source]

Bases: object

Reader for MTZ files containing crystallographic structure factor data.

This class reads MTZ files using reciprocalspaceship and extracts:

  • Miller indices (h, k, l)

  • Structure factor amplitudes or intensities

  • Associated uncertainties (sigma values)

  • R-free test set flags

verbose

Verbosity level for logging (0=silent, 1=normal, 2=debug).

Type:

int

data

Dictionary containing extracted data arrays.

Type:

dict

cell

Unit cell parameters [a, b, c, alpha, beta, gamma].

Type:

np.ndarray

spacegroup

Space group object.

Type:

gemmi.SpaceGroup

Examples

from torchref.io import mtz
reader = mtz.read('data.mtz', verbose=1)
data_dict, cell, spacegroup = reader()
print(f"Found {len(data_dict['HKL'])} reflections in {spacegroup.short_name()}")
AMPLITUDE_PRIORITY = ['F-obs', 'FOBS', 'FP', 'F', 'F-obs-filtered', 'FOBS-filtered', 'F(+)', 'FPLUS', 'FMEAN', 'F-pk', 'F_pk', 'FO', 'FODD', 'F-model', 'FC', 'FCALC']
INTENSITY_PRIORITY = ['I-obs', 'IOBS', 'I', 'IMEAN', 'I-obs-filtered', 'IOBS-filtered', 'I(+)', 'IPLUS', 'IP', 'I-pk', 'I_pk', 'IHLI', 'I_full', 'IOBS_full', 'IO']
RFREE_FLAG_NAMES = ['R-free-flags', 'RFREE', 'FreeR_flag', 'FREE', 'R-free', 'Rfree', 'FREER', 'FREE_FLAG', 'test', 'TEST', 'free', 'Free']
__init__(verbose=0, column_names=None)[source]

Initialize MTZ reader.

Parameters:
  • verbose (int, optional) – Verbosity level (0=silent, 1=normal, 2=debug). Default is 0.

  • column_names (dict, optional) – Explicit column name mapping to override automatic detection. Supported keys: "F", "SIGF", "I", "SIGI". Example: {"F": "DFo", "SIGF": "sig_DFo"}.
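A plain-Python sketch of how priority lists such as AMPLITUDE_PRIORITY can drive column selection, with an explicit override taking precedence. The "first match wins" rule and the pick_column helper are assumptions for illustration; the reader's actual selection logic may apply further checks (for example, on column types).

```python
# Priority list copied from MTZReader.AMPLITUDE_PRIORITY above
AMPLITUDE_PRIORITY = ['F-obs', 'FOBS', 'FP', 'F', 'F-obs-filtered', 'FOBS-filtered',
                      'F(+)', 'FPLUS', 'FMEAN', 'F-pk', 'F_pk', 'FO', 'FODD',
                      'F-model', 'FC', 'FCALC']


def pick_column(available, priority, override=None):
    """Return the first priority name present in `available` (hypothetical helper).

    An explicit override (as with column_names={"F": ...}) wins if present.
    """
    if override is not None:
        if override not in available:
            raise KeyError(f"column {override!r} not found in file")
        return override
    for name in priority:
        if name in available:
            return name
    return None
```

With columns ['H', 'K', 'L', 'FP', 'SIGFP', 'F', 'FC'], this picks 'FP' because it appears earlier in the list than 'F'.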

read(filepath)[source]

Read an MTZ file and extract reflection data.

Parameters:

filepath (str) – Path to the MTZ file.

Returns:

Self, for method chaining.

Return type:

MTZReader

__call__()[source]

Return extracted data in a standardized format.

Returns:

  • data (dict) – Dictionary with extracted data arrays.

  • cell (np.ndarray) – Unit cell parameters [a, b, c, alpha, beta, gamma].

  • spacegroup (str) – Space group name string.

Return type:

Tuple[dict, ndarray, str]

class torchref.io.PDBReader(verbose=0)[source]

Bases: object

Reader for PDB files containing atomic coordinate data.

This class reads PDB files and extracts atomic coordinates, properties, and crystallographic metadata.

verbose

Verbosity level for logging.

Type:

int

dataframe

DataFrame containing atomic data.

Type:

pd.DataFrame

cell

Unit cell parameters [a, b, c, alpha, beta, gamma].

Type:

list or None

spacegroup

Space group symbol.

Type:

str or None

Examples

reader = pdb.read('structure.pdb', verbose=1)
df, cell, spacegroup = reader()
print(f"Loaded {len(df)} atoms")
__init__(verbose=0)[source]

Initialize PDB reader.

Parameters:

verbose (int, optional) – Verbosity level (0=silent, 1=normal, 2=debug). Default is 0.

read(filepath)[source]

Read a PDB file and extract atomic data.

Parameters:

filepath (str) – Path to the PDB file.

Returns:

Self, for method chaining.

Return type:

PDBReader

__call__()[source]

Return extracted data in a standardized format.

Returns:

  • dataframe (pd.DataFrame) – DataFrame with atomic data.

  • cell (np.ndarray or None) – Unit cell parameters [a, b, c, alpha, beta, gamma].

  • spacegroup (str or None) – Space group symbol.

Return type:

Tuple[DataFrame, ndarray | None, str | None]

class torchref.io.CIFReader(filepath=None, data_block=None, parse_all_blocks=False)[source]

Bases: object

A dictionary-like reader for CIF/mmCIF files.

Loops are stored as pandas DataFrames. Other data is stored in a hierarchical dictionary structure.

Parameters:
  • filepath (str, optional) – Path to CIF file to load immediately.

  • data_block (str, optional) – Specific data block name to read (e.g., ‘r1vlmsf’). If None and parse_all_blocks=False, reads the first data block. If None and parse_all_blocks=True, reads all data blocks.

  • parse_all_blocks (bool, default False) – If True, parse all data blocks and merge them into a single dictionary (useful for restraint files). If False, parse only the specified block or the first block.

data

Dictionary storing parsed CIF data.

Type:

dict

filepath

Path to the loaded CIF file.

Type:

Path or None

available_blocks

List of data block names found in the file.

Type:

list

__init__(filepath=None, data_block=None, parse_all_blocks=False)[source]

Initialize CIF reader.

Parameters:
  • filepath (str, optional) – Path to CIF file to load immediately.

  • data_block (str, optional) – Specific data block name to read.

  • parse_all_blocks (bool, default False) – If True, parse all data blocks and merge.

classmethod from_string(content, **kwargs)[source]

Create CIFReader from string content instead of a file.

load(filepath)[source]

Load and parse a CIF file.

Parameters:

filepath (str) – Path to CIF file.

write(filepath)[source]

Write the CIF data back to a file.

Parameters:

filepath (str) – Output file path.

__getitem__(key)[source]

Get item by key.

__setitem__(key, value)[source]

Set item by key.

__contains__(key)[source]

Check if key exists.

__len__()[source]

Return number of top-level categories.

keys()[source]

Return dictionary keys.

values()[source]

Return dictionary values.

items()[source]

Return dictionary items.

get(key, default=None)[source]

Get item with default value.

__repr__()[source]

String representation.

summary()[source]

Print a summary of the CIF contents.

class torchref.io.ReflectionCIFReader(filepath, verbose=0, data_block=None)[source]

Bases: object

Reader for structure factor CIF files (e.g., *-sf.cif from PDB).

Handles extraction of:

  • Miller indices (h, k, l)

  • Structure factor amplitudes (F) and uncertainties (σF)

  • Intensities (I) and uncertainties (σI)

  • Phases and figures of merit

  • R-free flags

  • Unit cell and space group metadata

Compatible with the legacy MTZ reader interface:

reader = ReflectionCIFReader('7JI4-sf.cif').read()
data_dict, cell, spacegroup = reader()

Example:

reader = ReflectionCIFReader('7JI4-sf.cif')
refln_data = reader.get_reflection_data()
h, k, l = refln_data['h'], refln_data['k'], refln_data['l']
F_obs = refln_data['F_obs']

__init__(filepath, verbose=0, data_block=None)[source]

Initialize and load structure factor CIF file.

Parameters:
  • filepath (str) – Path to structure factor CIF file.

  • verbose (int, default 0) – Verbosity level (0=silent, 1=info, 2=debug).

  • data_block (str, optional) – Specific data block name to read (e.g., ‘r1vlmsf’). If None, reads the first data block. Useful for files with multiple datasets.

read(filepath=None)[source]

Read a CIF file (for compatibility with the legacy interface).

Parameters:

filepath (str, optional) – Path to CIF file. Uses the initialization path if not provided.

Returns:

Self, for method chaining.

Return type:

ReflectionCIFReader

__call__()[source]

Get data in legacy MTZ-compatible format.

Returns:

  • data (dict) – Dictionary with extracted data arrays:
      - ‘h’, ‘k’, ‘l’: Miller indices
      - ‘F’, ‘SIGF’: amplitudes and sigmas (if available)
      - ‘I’, ‘SIGI’: intensities and sigmas (if available)
      - ‘R-free-flags’: R-free test set flags (if available)

  • cell (numpy.ndarray) – Cell parameters [a, b, c, alpha, beta, gamma].

  • spacegroup (gemmi.SpaceGroup) – Space group object.

Return type:

Tuple[Dict[str, ndarray], ndarray, SpaceGroup]

get_reflection_data()[source]

Extract reflection data with standardized column names.

Returns:

DataFrame with columns:

  • h, k, l: Miller indices

  • F_obs, sigma_F_obs: observed amplitudes (if available)

  • I_obs, sigma_I_obs: observed intensities (if available)

  • phase, fom: phase and figure of merit (if available)

  • free_flag: R-free flags (if available)

Return type:

pandas.DataFrame

Notes

Missing columns will be filled with NaN or appropriate defaults.

has_miller_indices()[source]

Check if file contains Miller indices.

has_amplitudes()[source]

Check if file contains structure factor amplitudes.

has_intensities()[source]

Check if file contains intensity measurements.

has_phases()[source]

Check if file contains phase information.

has_rfree_flags()[source]

Check if file contains R-free flags.

get_miller_indices()[source]

Get Miller indices as Nx3 array.

Returns:

Array of shape (N, 3) with h, k, l indices

get_amplitudes()[source]

Get structure factor amplitudes and uncertainties.

Returns:

Dict with keys ‘F’ and ‘sigma_F’, or None if not available

get_intensities()[source]

Get intensities and uncertainties.

Returns:

Dict with keys ‘I’ and ‘sigma_I’, or None if not available

get_cell_parameters()[source]

Extract unit cell parameters [a, b, c, alpha, beta, gamma].

Returns:

List of 6 floats, or None if not found

get_space_group()[source]

Extract space group name.

Returns:

Space group name string. Returns “P 1” if not found.

Return type:

str

class torchref.io.ModelCIFReader(filepath, verbose=0)[source]

Bases: object

Reader for model/structure CIF files (e.g., *.cif from PDB).

Handles extraction of:

  • Atomic coordinates and properties

  • Alternative conformations

  • Anisotropic displacement parameters

  • Unit cell and space group

Compatible with the legacy PDB reader interface:

reader = ModelCIFReader('3E98.cif').read()
dataframe, cell, spacegroup = reader()

Example:

reader = ModelCIFReader('3E98.cif')
atom_df = reader.get_atom_data()
cell = reader.get_cell_parameters()

__init__(filepath, verbose=0)[source]

Initialize and load model CIF file.

Parameters:
  • filepath (str) – Path to model CIF file.

  • verbose (int, default 0) – Verbosity level (0=silent, 1=info, 2=debug).

read(filepath=None)[source]

Read a CIF file (for compatibility with legacy interface).

Parameters:

filepath (str, optional) – Path to CIF file. Uses initialization path if not provided.

Returns:

Self for method chaining.

Return type:

ModelCIFReader

__call__()[source]

Get data in legacy PDB-compatible format.

Returns:

  • dataframe (pandas.DataFrame) – Atom data with columns: ATOM, serial, name, altloc, resname, chainid, resseq, icode, x, y, z, occupancy, tempfactor, element, charge, anisou_flag, u11, u22, u33, u12, u13, u23.

  • cell (list) – Cell parameters [a, b, c, alpha, beta, gamma].

  • spacegroup (gemmi.SpaceGroup) – Space group object.

Return type:

Tuple[DataFrame, List[float], SpaceGroup]

get_atom_data()[source]

Extract atomic coordinate data in PDB-compatible format.

Returns:

DataFrame with columns matching PDB format:

  • ATOM, serial, name, altloc, resname, chainid, resseq, icode

  • x, y, z, occupancy, tempfactor

  • element, charge

  • anisou_flag, u11, u22, u33, u12, u13, u23

Return type:

pandas.DataFrame

get_atom_data_by_model()[source]

Split atom data by pdbx_PDB_model_num.

For single-model files, returns {1: dataframe}. For multi-model files, returns one DataFrame per model number.

Returns:

Mapping of model number to atom DataFrame.

Return type:

dict of int -> pandas.DataFrame
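A pandas sketch of splitting atoms by model number. It assumes the atom table carries a pdbx_PDB_model_num column; that column is not in the PDB-compatible column list above, so split_by_model is a hypothetical illustration of the documented behavior rather than the reader's code.

```python
import pandas as pd


def split_by_model(atoms, column="pdbx_PDB_model_num"):
    """Return {model_number: DataFrame}; files without the column count as model 1."""
    if column not in atoms.columns:
        return {1: atoms}
    return {
        int(num): group.reset_index(drop=True)
        for num, group in atoms.groupby(column, sort=True)
    }
```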

get_cell_parameters()[source]

Extract unit cell parameters [a, b, c, alpha, beta, gamma].

get_space_group()[source]

Extract space group name.

Returns:

Space group name string. Returns “P 1” if not found.

Return type:

str

has_coordinates()[source]

Check if atomic coordinates are available.

has_cell_parameters()[source]

Check if unit cell parameters are available.

has_space_group()[source]

Check if space group information is available.

has_occupancy()[source]

Check if occupancy data is available.

has_bfactor()[source]

Check if B-factor/temperature factor data is available.

has_anisotropic_data()[source]

Check if anisotropic displacement parameters are available.

get_coordinates()[source]

Extract atomic coordinates as numpy array.

Returns:

Nx3 array of [x, y, z] coordinates, or None if not available.

Return type:

numpy.ndarray or None

get_atom_info()[source]

Extract atom information (without coordinates).

Returns:

DataFrame with atom names, residue info, elements, etc.

Return type:

pandas.DataFrame

class torchref.io.RestraintCIFReader(filepath)[source]

Bases: object

Reader for chemical restraint dictionary CIF files (e.g., from monomer library).

Handles extraction of:

  • Bond restraints (ideal lengths and ESDs)

  • Angle restraints

  • Torsion/dihedral restraints

  • Planarity restraints

  • Chirality definitions

Validates that the file contains proper restraint parameters (not just structure definitions).

Example:

reader = RestraintCIFReader('external_monomer_library/a/ALA.cif')
comp_data = reader.get_all_restraints()
bond_df = comp_data['ALA']['bonds']

__init__(filepath)[source]

Initialize and load restraint CIF file.

Parameters:

filepath (str) – Path to restraint dictionary CIF file.

get_all_restraints()[source]

Extract all restraint data for all compounds with standardized column names.

Returns:

Dictionary mapping compound ID to dict of restraint types:

{
    'ALA': {
        'bonds': DataFrame(atom1, atom2, value, sigma),
        'angles': DataFrame(atom1, atom2, atom3, value, sigma),
        'torsions': DataFrame(atom1, atom2, atom3, atom4, value, sigma, periodicity),
        'planes': DataFrame(atom, plane_id),
        'chirals': DataFrame(atom_centre, atom1, atom2, atom3, volume_sign)
    },
    ...
}

Return type:

dict

get_compound_restraints(comp_id)[source]

Extract restraints for a specific compound with standardized column names.

Parameters:

comp_id (str) – Compound identifier (e.g., ‘ALA’).

Returns:

Dictionary of restraint DataFrames with standardized columns:

{
    'bonds': DataFrame(atom1, atom2, value, sigma),
    'angles': DataFrame(atom1, atom2, atom3, value, sigma),
    'torsions': DataFrame(atom1, atom2, atom3, atom4, value, sigma, periodicity),
    'planes': DataFrame(atom, plane_id),
    'chirals': DataFrame(atom_centre, atom1, atom2, atom3, volume_sign),
    'atoms': DataFrame(atom_id, type_symbol, charge, ...),
}

Return type:

dict

get_bond_restraints(comp_id)[source]

Get bond restraints with standardized column names.

Returns:
DataFrame with columns:
  • atom1, atom2: Atom names

  • value: Ideal bond length (Å)

  • sigma: Estimated standard deviation (Å)
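Standardizing bond columns amounts to renaming the raw CIF loop items and converting their text values to floats. The mapping below assumes the CCP4 monomer-library item names (_chem_comp_bond.atom_id_1, atom_id_2, value_dist, value_dist_esd). standardize_bonds is a hypothetical helper, not the reader's implementation.

```python
import pandas as pd

# Assumed mapping from CCP4 monomer-library _chem_comp_bond items to the
# standardized names listed above.
BOND_COLUMNS = {
    "atom_id_1": "atom1",
    "atom_id_2": "atom2",
    "value_dist": "value",
    "value_dist_esd": "sigma",
}


def standardize_bonds(raw):
    """Rename raw _chem_comp_bond columns and convert numeric text to float."""
    df = raw.rename(columns=BOND_COLUMNS)
    df["value"] = df["value"].astype(float)
    df["sigma"] = df["sigma"].astype(float)
    return df[["atom1", "atom2", "value", "sigma"]]
```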

get_compound_id()[source]

Get the primary compound ID from this file.

has_bond_restraints()[source]

Check if bond restraints are available.

has_angle_restraints()[source]

Check if angle restraints are available.

has_torsion_restraints()[source]

Check if torsion restraints are available.

has_plane_restraints()[source]

Check if plane restraints are available.

has_chirality_restraints()[source]

Check if chirality definitions are available.

class torchref.io.DataRouter(filepath, verbose=1)[source]

Bases: object

Automatic file type detection and reader selection.

This class examines a file and automatically selects the appropriate reader based on file extension and content.

Parameters:
  • filepath (str or Path) – Path to the file to read.

  • verbose (int, default 1) – Verbosity level (0=quiet, 1=normal, 2+=debug).

filepath

Path to the file to read.

Type:

Path

verbose

Verbosity level for logging.

Type:

int

data_type

Type of data detected (‘reflections’, ‘structure’, ‘restraints’, ‘ihm_ensemble’, or None).

Type:

str or None

file_format

File format detected (‘mtz’, ‘pdb’, ‘cif’, or None).

Type:

str or None

reader

The appropriate reader instance (or None if not yet created).

Type:

object or None

Examples

from torchref.io import DataRouter
router = DataRouter("structure.cif")
reader = router.get_reader()
print(router.data_type)  # 'structure'
MTZ_EXTENSIONS = {'.mtz'}
PDB_EXTENSIONS = {'.ent', '.pdb'}
CIF_EXTENSIONS = {'.cif', '.mmcif'}
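Detection by extension, followed by content sniffing for CIF files, can be sketched as below. The extension sets are copied from the class attributes above. The CIF heuristics (looking for _refln., _atom_site. or _chem_comp_bond. categories) are an assumption about how reflection, model and restraint CIFs might be told apart; IHM ensemble detection is not covered. detect is a hypothetical helper.

```python
from pathlib import Path

MTZ_EXTENSIONS = {'.mtz'}
PDB_EXTENSIONS = {'.ent', '.pdb'}
CIF_EXTENSIONS = {'.cif', '.mmcif'}


def detect(filepath, text=None):
    """Return (file_format, data_type), or (None, None) if undetermined."""
    ext = Path(filepath).suffix.lower()
    if ext in MTZ_EXTENSIONS:
        return 'mtz', 'reflections'
    if ext in PDB_EXTENSIONS:
        return 'pdb', 'structure'
    if ext in CIF_EXTENSIONS:
        if text is None:
            text = Path(filepath).read_text()
        # Heuristic: check characteristic mmCIF categories (assumption)
        if '_refln.' in text:
            return 'cif', 'reflections'
        if '_atom_site.' in text:
            return 'cif', 'structure'
        if '_chem_comp_bond.' in text:
            return 'cif', 'restraints'
    return None, None
```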
__init__(filepath, verbose=1)[source]

Initialize the DataRouter.

Parameters:
  • filepath (str or Path) – Path to the data file.

  • verbose (int, default 1) – Verbosity level (0=quiet, 1=normal, 2+=debug).

get_reader()[source]

Get the appropriate reader for this file.

Returns:

Reader instance (ReflectionCIFReader, ModelCIFReader, RestraintCIFReader, MTZ, or PDB, depending on file type).

Return type:

object

Raises:

DataRouterError – If file type is not supported or cannot be determined.
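For a .cif file, the extension alone cannot separate reflections, models, restraint dictionaries and IHM ensembles. One plausible content check looks for characteristic mmCIF categories, as sketched below. The helpers classify_cif_text and detect_data_type, the category choices and their order are assumptions, not torchref's detection logic. IHM categories are tested first because IHM files also carry _atom_site.

```python
def classify_cif_text(text):
    """Guess a DataRouter-style data_type from the mmCIF categories present."""
    # IHM files also carry _atom_site, so test for IHM categories first.
    if "_ihm_multi_state_modeling." in text or "_ihm_model_group." in text:
        return "ihm_ensemble"
    # Structure-factor files (e.g. *-sf.cif) store reflections in _refln.
    if "_refln." in text:
        return "reflections"
    # Monomer-library restraint dictionaries describe bonds in _chem_comp_bond.
    if "_chem_comp_bond." in text:
        return "restraints"
    # Coordinate files list atoms in _atom_site.
    if "_atom_site." in text:
        return "structure"
    return None


def detect_data_type(file_format, text=""):
    """Map a detected file_format (plus CIF text, if any) to a data_type."""
    if file_format == "mtz":
        return "reflections"
    if file_format == "pdb":
        return "structure"
    if file_format == "cif":
        return classify_cif_text(text)
    return None


sf_cif = "data_r7ji4sf\nloop_\n_refln.index_h\n_refln.index_k\n_refln.index_l\n"
model_cif = "data_7JI4\nloop_\n_atom_site.group_PDB\n_atom_site.id\n"
print(detect_data_type("cif", sf_cif))     # reflections
print(detect_data_type("cif", model_cif))  # structure
print(detect_data_type("mtz"))             # reflections
```

A None result corresponds to the case where DataRouter would raise DataRouterError.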

get_data()[source]

Get the data from the file using the appropriate reader.

This is a convenience method that calls get_reader() and then invokes the reader to get the data.

Returns:

  • For reflections: (data_dict, cell, spacegroup)

  • For structure: (dataframe, residues, spacegroup)

  • For restraints: restraint data (format depends on the reader)

Return type:

tuple

classmethod route(filepath, verbose=1)[source]

Factory method to quickly route a file to the appropriate reader.

Parameters:
  • filepath (str or Path) – Path to the data file.

  • verbose (int, default 1) – Verbosity level.

Returns:

Tuple of (reader, data_type), where:
  • reader: The appropriate reader instance.

  • data_type: String indicating the type (‘reflections’, ‘structure’, ‘restraints’).

Return type:

tuple

Examples

reader, data_type = DataRouter.route("7JI4-sf.cif")
if data_type == 'reflections':
    data_dict, cell, spacegroup = reader()
__repr__()[source]

String representation of the DataRouter.

__str__()[source]

User-friendly string representation.

exception torchref.io.DataRouterError[source]

Bases: Exception

Exception raised when file type cannot be determined or is unsupported.

class torchref.io.IHMEnsembleMapping(states=<factory>, model_groups=<factory>, cell=None, spacegroup=None, atom_data_per_state=None)[source]

Bases: object

Complete mapping between IHM mmCIF categories and torchref structures.

This is the central interchange object: both IHMReader and IHMWriter operate through it. It can also be constructed manually for programmatic workflows (e.g., building an IHM file from a KineticRefinement result without reading one first).

Parameters:
  • states (List[IHMStateInfo]) – Structural states (one per base model in ModelCollection).

  • model_groups (List[IHMModelGroupInfo]) – Model groups / timepoints (one per timepoint in ModelCollection).

  • cell (list of float, optional) – Unit cell parameters [a, b, c, alpha, beta, gamma].

  • spacegroup (str, optional) – Space group name (Hermann-Mauguin notation).

  • atom_data_per_state (dict, optional) – Mapping of state_id -> pandas DataFrame with atom data. Populated by IHMReader.read_atom_data().

states: List[IHMStateInfo]
model_groups: List[IHMModelGroupInfo]
cell: List[float] | None = None
spacegroup: str | None = None
atom_data_per_state: Dict[int, DataFrame] | None = None
get_state_ids()[source]

Return state IDs ordered by state_id.

get_timepoint_names()[source]

Return model group names ordered by group_id.

get_fractions_for_group(group_name)[source]

Return population fractions for a model group, ordered by state_id.

Parameters:

group_name (str) – Name of the model group.

Returns:

Fractions ordered by ascending state_id.

Return type:

list of float

Raises:

KeyError – If no group with the given name exists.
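The ordering rule can be illustrated with plain dictionaries standing in for IHMModelGroupInfo.state_fractions. The sketch below assumes every state appears in the group's state_fractions. The helper fractions_for_group and the dict-of-dicts layout are hypothetical. The real method operates on the mapping's model_groups list.

```python
def fractions_for_group(groups, group_name):
    """Return the fractions of group_name ordered by ascending state_id.

    groups maps group name -> {state_id: fraction}; this is a stand-in for
    IHMModelGroupInfo objects (assumption).
    """
    if group_name not in groups:
        raise KeyError(f"No model group named {group_name!r}")
    state_fractions = groups[group_name]
    # Sort by state_id, not by dict insertion order.
    return [state_fractions[sid] for sid in sorted(state_fractions)]


groups = {"dark": {1: 1.0, 2: 0.0}, "1ps": {2: 0.3, 1: 0.7}}
print(fractions_for_group(groups, "1ps"))  # [0.7, 0.3]
```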

identify_dark_group()[source]

Heuristic: identify the reference / dark group.

Returns the name of the first model group where a single state has population fraction >= 0.95, or None if no such group exists.
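The heuristic reduces to a single scan over the groups. In this sketch, (name, state_fractions) pairs stand in for IHMModelGroupInfo objects. "First" is taken to mean list order, which is an assumption. The function name mirrors the method for readability but is a standalone illustration.

```python
def identify_dark_group(groups, threshold=0.95):
    """Name of the first group in which one state holds >= threshold."""
    for name, state_fractions in groups:
        if any(fraction >= threshold for fraction in state_fractions.values()):
            return name
    return None


groups = [
    ("1ps", {1: 0.7, 2: 0.3}),
    ("dark", {1: 0.98, 2: 0.02}),
    ("5ps", {1: 0.5, 2: 0.5}),
]
print(identify_dark_group(groups))  # dark
```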

get_state_by_id(state_id)[source]

Look up a state by its ID.

Raises:

KeyError – If no state with the given ID exists.

get_group_by_name(name)[source]

Look up a model group by name.

Raises:

KeyError – If no group with the given name exists.

validate()[source]

Check internal consistency.

Raises:

ValueError – If there are no states, if any fraction references an undefined state_id, or if the fractions of any group do not sum to ~1.0.
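The three checks can be expressed directly. The sketch below uses a list of state IDs and a dict of group name -> {state_id: fraction} as stand-ins for the dataclass fields. The absolute tolerance of 1e-3 for "~1.0" is an assumption, since the documentation does not give the value.

```python
import math


def validate(state_ids, groups, tol=1e-3):
    """Raise ValueError if the ensemble description is inconsistent."""
    if not state_ids:
        raise ValueError("No states defined")
    known = set(state_ids)
    for name, state_fractions in groups.items():
        # Every fraction must point at a defined state.
        unknown = set(state_fractions) - known
        if unknown:
            raise ValueError(f"Group {name!r} references unknown state_id(s) {sorted(unknown)}")
        # Populations within one group must sum to ~1.0.
        total = sum(state_fractions.values())
        if not math.isclose(total, 1.0, abs_tol=tol):
            raise ValueError(f"Fractions of group {name!r} sum to {total:.3f}, not 1.0")


validate([1, 2], {"dark": {1: 1.0, 2: 0.0}, "1ps": {1: 0.7, 2: 0.3}})  # passes
try:
    validate([1, 2], {"bad": {1: 0.6, 2: 0.6}})
except ValueError as err:
    print(err)  # Fractions of group 'bad' sum to 1.200, not 1.0
```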

__init__(states=<factory>, model_groups=<factory>, cell=None, spacegroup=None, atom_data_per_state=None)
class torchref.io.IHMStateInfo(state_id, name, details='', model_num=1)[source]

Bases: object

Metadata for a single structural state (e.g., ground state, intermediate).

Parameters:
  • state_id (int) – Unique identifier matching _ihm_multi_state_modeling.state_id.

  • name (str) – Human-readable name (e.g., "ground_state", "intermediate_1").

  • details (str) – Free-text description of this state.

  • model_num (int) – Value of pdbx_PDB_model_num in the _atom_site loop that identifies this state’s coordinates. Default is 1.

state_id: int
name: str
details: str = ''
model_num: int = 1
__init__(state_id, name, details='', model_num=1)
class torchref.io.IHMModelGroupInfo(group_id, name, state_fractions=<factory>, time_delay=None, time_delay_units='s')[source]

Bases: object

Metadata for a model group (experimental condition / timepoint).

Parameters:
  • group_id (int) – Unique identifier matching _ihm_model_group.id.

  • name (str) – Human-readable name (e.g., "dark", "1ps", "5ps").

  • state_fractions (Dict[int, float]) – Mapping of state_id -> population fraction for this group. Fractions should sum to 1.0.

  • time_delay (float, optional) – Time delay in time_delay_units (for time-resolved experiments).

  • time_delay_units (str) – Units for time_delay. Default "s" (seconds).

group_id: int
name: str
state_fractions: Dict[int, float]
time_delay: float | None = None
time_delay_units: str = 's'
__init__(group_id, name, state_fractions=<factory>, time_delay=None, time_delay_units='s')
class torchref.io.RefinementMetadata(program='TORCHREF', program_version='', refinement_method='', resolution_high=None, resolution_low=None, n_reflections_work=None, n_reflections_test=None, n_reflections_all=None, percent_free=None, r_work=None, r_free=None, b_mean_overall=None, b_min=None, b_max=None, rmsd_bond_lengths=None, rmsd_bond_angles=None, n_atoms_total=None, n_atoms_protein=None, n_atoms_solvent=None, solvent_model_ksol=None, solvent_model_bsol=None, cell=None, spacegroup=None, title='', authors=<factory>, passthrough_pdb_remarks=<factory>, passthrough_cif_categories=<factory>, custom_remarks=<factory>)[source]

Bases: object

Unified metadata for PDB headers and mmCIF categories.

Fields map to both PDB REMARK 3 lines and PDBx/mmCIF _refine category items. Only populated (non-None) fields are rendered.

Parameters:
  • program (str) – Refinement program name.

  • program_version (str) – Program version string.

program: str = 'TORCHREF'
program_version: str = ''
refinement_method: str = ''
resolution_high: float | None = None
resolution_low: float | None = None
n_reflections_work: int | None = None
n_reflections_test: int | None = None
n_reflections_all: int | None = None
percent_free: float | None = None
r_work: float | None = None
r_free: float | None = None
b_mean_overall: float | None = None
b_min: float | None = None
b_max: float | None = None
rmsd_bond_lengths: float | None = None
rmsd_bond_angles: float | None = None
n_atoms_total: int | None = None
n_atoms_protein: int | None = None
n_atoms_solvent: int | None = None
solvent_model_ksol: float | None = None
solvent_model_bsol: float | None = None
cell: List[float] | None = None
spacegroup: str | None = None
title: str = ''
authors: List[str]
passthrough_pdb_remarks: List[str]
passthrough_cif_categories: Dict[str, Any]
custom_remarks: List[str]
to_dict()[source]

Serialize to a JSON-compatible dictionary, dropping None values.

classmethod from_dict(d)[source]

Reconstruct from a dictionary (inverse of to_dict).
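A dataclass makes the to_dict/from_dict pair short. The sketch below uses a hypothetical stand-in, MiniMetadata, with a handful of the fields above. Ignoring unknown keys in from_dict is an assumption. Because None values are dropped, the omitted fields fall back to their defaults when the dictionary is read back.

```python
from dataclasses import asdict, dataclass, field, fields
from typing import List, Optional


@dataclass
class MiniMetadata:
    """Hypothetical stand-in with a few RefinementMetadata-like fields."""
    program: str = "TORCHREF"
    resolution_high: Optional[float] = None
    r_work: Optional[float] = None
    r_free: Optional[float] = None
    authors: List[str] = field(default_factory=list)

    def to_dict(self):
        # Drop None values; empty containers are kept because they are not None.
        return {k: v for k, v in asdict(self).items() if v is not None}

    @classmethod
    def from_dict(cls, d):
        # Ignore keys that are not fields (assumption).
        names = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in d.items() if k in names})


meta = MiniMetadata(resolution_high=1.8, r_work=0.18)
d = meta.to_dict()
print(d)  # {'program': 'TORCHREF', 'resolution_high': 1.8, 'r_work': 0.18, 'authors': []}
print(MiniMetadata.from_dict(d) == meta)  # True
```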

classmethod from_refinement(refinement)[source]

Extract metadata from a completed Refinement object.

Reuses existing statistics from collect_metrics(), get_rfactor(), and reflection data attributes. Silently skips any unavailable statistics.

Parameters:

refinement (torchref.refinement.Refinement) – A refinement object (after refinement is complete).

classmethod from_pdb_file(filepath)[source]

Extract header metadata from an existing PDB file.

Captures TITLE, AUTHOR, and REMARK records for pass-through.
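In the fixed-column PDB format, the record name occupies columns 1-6 and TITLE/AUTHOR text starts at column 11, after a continuation field in columns 9-10. The sketch below captures the three record types on that basis. The function capture_header and its return layout are hypothetical. It does not rejoin author names that are split across continuation lines. The header lines are illustrative.

```python
def capture_header(lines):
    """Collect TITLE text, AUTHOR names and raw REMARK lines from a PDB header."""
    title_parts, authors, remarks = [], [], []
    for line in lines:
        record = line[:6].strip()  # columns 1-6 hold the record name
        if record == "TITLE":
            title_parts.append(line[10:80].strip())  # text starts at column 11
        elif record == "AUTHOR":
            authors.extend(a.strip() for a in line[10:80].split(",") if a.strip())
        elif record == "REMARK":
            remarks.append(line.rstrip("\n"))  # keep verbatim for pass-through
    return " ".join(title_parts), authors, remarks


header = [
    "TITLE     CRYSTAL STRUCTURE OF AN EXAMPLE PROTEIN",
    "TITLE    2 AT 1.8 ANGSTROM RESOLUTION",
    "AUTHOR    A.B.SMITH,C.D.JONES",
    "REMARK   2 RESOLUTION.    1.80 ANGSTROMS.",
    "ATOM      1  N   ALA A   1      11.104   6.134  -6.504  1.00  0.00           N",
]
title, authors, remarks = capture_header(header)
print(title)    # CRYSTAL STRUCTURE OF AN EXAMPLE PROTEIN AT 1.8 ANGSTROM RESOLUTION
print(authors)  # ['A.B.SMITH', 'C.D.JONES']
```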

classmethod from_cif_file(filepath)[source]

Extract refinement metadata from an existing mmCIF file.

Captures _struct.title, _audit_author.name, and _refine category items for pass-through.

merge(other)[source]

Merge other into self. Non-None values in other take precedence.

Pass-through containers are combined (not replaced).

Parameters:

other (RefinementMetadata) – Metadata to merge in (takes precedence for non-None fields).

Returns:

A new merged instance.

Return type:

RefinementMetadata
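The merge rules can be sketched with dataclasses.replace, which returns a new instance and leaves both inputs untouched. This uses a hypothetical stand-in, MiniMetadata, with only Optional scalar fields and pass-through containers. How torchref combines containers is not specified, so list concatenation and dict union with other winning on key collisions are assumptions.

```python
from dataclasses import dataclass, field, fields, replace
from typing import Any, Dict, List, Optional


@dataclass
class MiniMetadata:
    """Hypothetical stand-in with RefinementMetadata-like fields."""
    r_work: Optional[float] = None
    r_free: Optional[float] = None
    passthrough_pdb_remarks: List[str] = field(default_factory=list)
    passthrough_cif_categories: Dict[str, Any] = field(default_factory=dict)

    def merge(self, other):
        updates = {}
        for f in fields(self):
            mine, theirs = getattr(self, f.name), getattr(other, f.name)
            if isinstance(mine, list):
                updates[f.name] = mine + theirs  # combine, do not replace
            elif isinstance(mine, dict):
                updates[f.name] = {**mine, **theirs}  # other wins on collisions
            elif theirs is not None:
                updates[f.name] = theirs  # non-None values in other take precedence
        return replace(self, **updates)  # new instance; self is unchanged


base = MiniMetadata(r_work=0.20, r_free=0.25, passthrough_pdb_remarks=["remark from base"])
new = MiniMetadata(r_work=0.18, passthrough_pdb_remarks=["remark from other"])
merged = base.merge(new)
print(merged.r_work, merged.r_free)     # 0.18 0.25
print(merged.passthrough_pdb_remarks)   # ['remark from base', 'remark from other']
```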

render_pdb_header()[source]

Render metadata as PDB header records (REMARK 3, TITLE, AUTHOR).

Returns:

Multi-line string ready to insert into a PDB file.

Return type:

str

render_cif_categories()[source]

Render metadata as mmCIF category dictionaries.

Returns a nested dictionary whose outer keys are mmCIF category names and whose inner dictionaries map item names to string values. Uses official PDBx/mmCIF item names.

Returns:

Nested dictionary {category: {field: value}}.

Return type:

dict
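The "only populated fields are rendered" rule combined with a field-to-item mapping might look like the sketch below. The items listed are real PDBx/mmCIF _refine items. However, the mapping table REFINE_ITEMS, the "_refine" category key and plain str() formatting are assumptions for illustration, and the table covers only a small subset of the fields above.

```python
# Hypothetical field -> PDBx/mmCIF _refine item mapping (small subset).
REFINE_ITEMS = {
    "resolution_high": "ls_d_res_high",
    "resolution_low": "ls_d_res_low",
    "r_work": "ls_R_factor_R_work",
    "r_free": "ls_R_factor_R_free",
    "b_mean_overall": "B_iso_mean",
}


def render_refine_category(values):
    """Build {category: {item: str_value}} from the non-None fields only."""
    refine = {}
    for field_name, item in REFINE_ITEMS.items():
        value = values.get(field_name)
        if value is not None:
            refine[item] = str(value)  # mmCIF values are written as text
    return {"_refine": refine} if refine else {}


print(render_refine_category({"resolution_high": 1.8, "r_work": 0.18, "r_free": None}))
# {'_refine': {'ls_d_res_high': '1.8', 'ls_R_factor_R_work': '0.18'}}
```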

__init__(program='TORCHREF', program_version='', refinement_method='', resolution_high=None, resolution_low=None, n_reflections_work=None, n_reflections_test=None, n_reflections_all=None, percent_free=None, r_work=None, r_free=None, b_mean_overall=None, b_min=None, b_max=None, rmsd_bond_lengths=None, rmsd_bond_angles=None, n_atoms_total=None, n_atoms_protein=None, n_atoms_solvent=None, solvent_model_ksol=None, solvent_model_bsol=None, cell=None, spacegroup=None, title='', authors=<factory>, passthrough_pdb_remarks=<factory>, passthrough_cif_categories=<factory>, custom_remarks=<factory>)
torchref.io.MTZ

alias of MTZReader

torchref.io.PDB

alias of PDBReader

Subpackages

Submodules