torchref package

TorchRef - A PyTorch-based crystallographic refinement library.

TorchRef provides GPU-accelerated crystallographic structure refinement using PyTorch’s automatic differentiation and nn.Module architecture.

Key Features

  • Native PyTorch integration with nn.Module architecture

  • Automatic differentiation for custom target functions

  • GPU acceleration for structure factor calculations

  • Modular design for easy extension

Quick Start

from torchref import Refinement, ReflectionData, Model

# Load data and model
data = ReflectionData().load_mtz('data.mtz')
model = Model().load_pdb('structure.pdb')

# Run refinement
refinement = Refinement(data=data, model=model, device='cuda')
refinement.run_refinement(macro_cycles=10)

Modules

io

File I/O for MTZ, PDB, CIF formats.

model

Atomic structure models (coordinates, B-factors, occupancies).

refinement

Core refinement framework with targets and weighting schemes.

restraints

Geometry restraints (bonds, angles, torsions, planes). Initialized lazily because it requires downloading the monomer library.

scaling

Structure factor scaling and bulk solvent models.

symmetry

Crystallographic symmetry operations.

alignment

Patterson-based structure alignment.

math_functions

Mathematical utilities for crystallography.

utils

General utilities and debugging tools.

class torchref.ReflectionData(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, _centric=None, _n_bins=None, _FrenchWilson=None, source=None, dataset=None, last_op=None, reader=None)[source]

Bases: CrystalDataset, DebugMixin

Container for crystallographic reflection data.

This class handles loading, processing, and accessing reflection data including Miller indices, structure factor amplitudes, intensities, and R-free flags. All data is stored as PyTorch tensors for GPU acceleration.

Parameters:
  • verbose (int, optional) – Verbosity level for logging (0=silent, 1=normal, 2=debug). Default is 1.

  • device (str, optional) – Device to store tensors on (‘cpu’, ‘cuda’, ‘cuda:0’, etc.). Defaults to the configured device.current.

hkl

Miller indices of shape (N, 3), dtype int32.

Type:

torch.Tensor

F

Structure factor amplitudes of shape (N,), dtype float32.

Type:

torch.Tensor

F_sigma

Amplitude uncertainties of shape (N,), dtype float32.

Type:

torch.Tensor

I

Intensities of shape (N,), dtype float32.

Type:

torch.Tensor

I_sigma

Intensity uncertainties of shape (N,), dtype float32.

Type:

torch.Tensor

rfree_flags

R-free test set flags of shape (N,), dtype bool.

Type:

torch.Tensor

cell

Unit cell parameters [a, b, c, alpha, beta, gamma].

Type:

torch.Tensor

spacegroup

Space group symbol.

Type:

str

resolution

Resolution per reflection in Ångströms of shape (N,).

Type:

torch.Tensor

wilson_b

Overall Wilson B-factor in Ų.

Type:

float

Examples

Load reflection data from an MTZ file:

data = ReflectionData(verbose=1, device='cuda')
data.load_mtz('data.mtz')
print(f"Loaded {len(data.hkl)} reflections")
print(f"Resolution range: {data.resolution.min():.2f} - {data.resolution.max():.2f} Å")
source: ReflectionData | None = None
dataset: DataFrame | None = None
last_op: str | None = None
reader: Any | None = None
__post_init__()[source]

Initialize non-dataclass attributes after dataclass init.

This is called automatically after the dataclass __init__.

load(reader)[source]

Load reflection data using a data reader.

Parameters:

reader (callable) – Data reader object that returns (data_dict, cell, spacegroup) when called. Can be MTZ, ReflectionCIFReader, or other compatible reader.

Returns:

Self, for method chaining.

Return type:

ReflectionData

Raises:

ValueError – If unit cell parameters are missing or no amplitude/intensity data found.

classmethod from_tensors(hkl, F, F_sigma, cell, spacegroup, rfree_flags=None, device=device(type='cpu'), verbose=1)[source]

Construct ReflectionData directly from tensors.

Parameters:
  • hkl (torch.Tensor) – Miller indices of shape (N, 3).

  • F (torch.Tensor) – Structure factor amplitudes of shape (N,).

  • F_sigma (torch.Tensor) – Amplitude uncertainties of shape (N,).

  • cell (Cell) – Unit cell parameters.

  • spacegroup (SpaceGroup) – Space group.

  • rfree_flags (torch.Tensor, optional) – R-free flags of shape (N,), dtype bool. If None, flags are generated automatically (2% free fraction).

  • device (str, optional) – Device for tensors. Defaults to the configured device.current.

  • verbose (int, optional) – Verbosity level. Default is 1.

Returns:

Fully initialized reflection data with all cleanup applied.

Return type:

ReflectionData

load_mtz(path, column_names=None)[source]

Load reflection data from MTZ file.

Parameters:
  • path (str) – Path to MTZ file.

  • column_names (dict, optional) – Explicit column name mapping to override automatic detection. Supported keys: "F", "SIGF", "I", "SIGI". Example: {"F": "DFo", "SIGF": "sig_DFo"}.

Returns:

Self, for method chaining.

Return type:

ReflectionData

load_cif(path, data_block=None)[source]

Load reflection data from CIF file.

Parameters:
  • path (str) – Path to CIF file.

  • data_block (str, optional) – Specific data block name to read (e.g., ‘r1vlmsf’). If None, reads the first data block. Useful for multi-dataset CIF files.

Returns:

Self, for method chaining.

Return type:

ReflectionData

static list_cif_data_blocks(path)[source]

List all data blocks available in a CIF file without loading data.

Useful for multi-dataset CIF files to inspect available blocks before loading a specific one.

Parameters:

path (str) – Path to CIF file.

Returns:

Names of all data blocks in the CIF file.

Return type:

list of str

Examples

List and load a specific data block:

blocks = ReflectionData.list_cif_data_blocks('1VLM-sf.cif')
print(blocks)  # ['r1vlmsf', 'r1vlmAsf', 'r1vlmBsf', ...]
data = ReflectionData().load_cif('1VLM-sf.cif', data_block=blocks[1])
get_bins(n_bins=20, min_per_bin=100)[source]

Create resolution bins with approximately equal reflection counts.

Parameters:
  • n_bins (int, optional) – Target number of resolution bins. Default is 20.

  • min_per_bin (int, optional) – Minimum reflections per bin. Default is 100.

Returns:

  • bin_indices (torch.Tensor) – Tensor of shape (N,) with bin index for each reflection.

  • n_bins (int) – Actual number of bins created (may be less than target for small datasets).

Return type:

Tuple[Tensor, int]
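
The equal-count binning described above can be sketched in plain Python. This is a minimal illustration, not torchref's implementation: the helper name equal_count_bins is hypothetical, lists stand in for tensors, and the choice to number bins from low to high resolution is an assumption.

```python
def equal_count_bins(resolution, n_bins=20, min_per_bin=100):
    """Assign each reflection a bin index so bins hold roughly equal counts.

    Hypothetical helper for illustration; not part of torchref.
    """
    n = len(resolution)
    # Reduce the bin count until every bin can hold min_per_bin reflections.
    n_bins = max(1, min(n_bins, n // min_per_bin))
    # Rank reflections from low resolution (large d) to high resolution (small d).
    order = sorted(range(n), key=lambda i: resolution[i], reverse=True)
    bin_indices = [0] * n
    for rank, i in enumerate(order):
        bin_indices[i] = rank * n_bins // n
    return bin_indices, n_bins


# 1000 synthetic d-spacings between 1.5 and 50 Å
d_spacings = [1.5 + 48.5 * (i / 999) ** 3 for i in range(1000)]
bins, n_actual = equal_count_bins(d_spacings, n_bins=20, min_per_bin=100)
counts = [bins.count(b) for b in range(n_actual)]
```

With 1000 reflections and min_per_bin=100, the requested 20 bins cannot all be filled, so the sketch falls back to fewer bins, mirroring the note that the actual bin count may be lower than the target.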

mean_res_per_bin()[source]

Calculate mean resolution for each bin.

Returns:

Mean resolution for each bin in Ångströms.

Return type:

torch.Tensor

Raises:

ValueError – If bins have not been created yet.

mean_F_per_bin()[source]

Calculate mean structure factor amplitude per resolution bin.

Returns:

Mean F per bin of shape (n_bins,).

Return type:

torch.Tensor

Raises:

ValueError – If bins have not been created yet.

mean_sigma_per_bin()[source]

Calculate mean structure factor uncertainty per resolution bin.

Returns:

Mean sigma_F per bin of shape (n_bins,), or None if no uncertainties.

Return type:

torch.Tensor or None

Raises:

ValueError – If bins have not been created yet.

regenerate_rfree_flags(free_fraction=0.02, n_bins=20, min_per_bin=100, seed=None, force=False)[source]

Regenerate R-free flags with resolution-stratified sampling.

Parameters:
  • free_fraction (float, optional) – Fraction of reflections to mark as free. Default is 0.02 (2%).

  • n_bins (int, optional) – Target number of resolution bins. Default is 20.

  • min_per_bin (int, optional) – Minimum reflections per resolution bin. Default is 100.

  • seed (int, optional) – Random seed for reproducibility. Default is None.

  • force (bool, optional) – If True, regenerate even if flags already exist. Default is False.

Examples

Generate R-free flags:

# Generate 2% free reflections with reproducible seed
data.regenerate_rfree_flags(free_fraction=0.02, n_bins=20, seed=42)
# Generate 5% free with 10 bins, overwriting existing
data.regenerate_rfree_flags(free_fraction=0.05, n_bins=10, force=True)
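
Resolution-stratified sampling means that the free set is drawn separately inside each resolution bin, so every bin contributes about the same fraction. A minimal sketch, assuming bin indices are already available; the helper name stratified_free_flags is hypothetical, and the True-means-free convention is used only inside this sketch.

```python
import random


def stratified_free_flags(bin_indices, free_fraction=0.02, seed=None):
    """Return a list of booleans, True marking reflections chosen as free.

    Hypothetical helper for illustration; not part of torchref.
    """
    rng = random.Random(seed)
    # Group reflection indices by resolution bin.
    members = {}
    for i, b in enumerate(bin_indices):
        members.setdefault(b, []).append(i)
    is_free = [False] * len(bin_indices)
    # Draw the same fraction from every bin (at least one reflection per bin).
    for b in sorted(members):
        idx = members[b]
        n_free = max(1, round(free_fraction * len(idx)))
        for i in rng.sample(idx, n_free):
            is_free[i] = True
    return is_free


bin_indices = [i % 10 for i in range(5000)]  # 10 bins with 500 reflections each
is_free = stratified_free_flags(bin_indices, free_fraction=0.02, seed=42)
```

Passing the same seed reproduces the same selection, which is the purpose of the seed parameter documented above.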
get_structure_factors(as_complex=False)[source]

Get structure factors, optionally as complex numbers.

Parameters:

as_complex (bool, optional) – If True and phases available, return F*exp(i*phi). Default is False.

Returns:

Structure factor amplitudes or complex structure factors.

Return type:

torch.Tensor

Raises:

ValueError – If no amplitude data is loaded.
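
The complex form F*exp(i*phi) can be illustrated with the standard library. This sketch assumes phases in radians; MTZ files commonly store phases in degrees, so a conversion may be needed, and the unit torchref expects is not stated here. The helper name to_complex is hypothetical.

```python
import cmath
import math


def to_complex(amplitudes, phases_rad):
    """Combine amplitudes and phases (radians) into complex structure factors."""
    return [f * cmath.exp(1j * phi) for f, phi in zip(amplitudes, phases_rad)]


F = [10.0, 5.0, 2.0]
phi = [0.0, math.pi / 2, math.pi]
F_complex = to_complex(F, phi)

# The modulus of each complex value recovers the original amplitude.
recovered = [abs(z) for z in F_complex]
```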

get_structure_factors_with_sigma()[source]

Get structure factor amplitudes and their uncertainties.

Returns:

  • F (torch.Tensor) – Structure factor amplitudes of shape (N,).

  • F_sigma (torch.Tensor or None) – Uncertainties of shape (N,), or None if not available.

Raises:

ValueError – If no amplitude data is loaded.

Return type:

Tuple[Tensor, Tensor | None]

Examples

Get amplitudes with uncertainties:

F, sigma_F = data.get_structure_factors_with_sigma()
if sigma_F is not None:
    # F_calc: calculated amplitudes of shape (N,), obtained elsewhere (e.g. from a model)
    weighted_residual = (F - F_calc) / sigma_F
get_hkl()[source]

Return Miller indices for valid reflections.

Returns:

Miller indices of shape (N, 3), dtype int32.

Return type:

torch.Tensor

Raises:

ValueError – If no Miller indices are loaded.

filter_by_resolution(d_min=None, d_max=None)[source]

Filter reflections by resolution range.

Adds a boolean mask to self.masks for the specified resolution range.

Parameters:
  • d_min (float, optional) – Minimum resolution / high resolution cutoff (e.g., 1.5 Å).

  • d_max (float, optional) – Maximum resolution / low resolution cutoff (e.g., 50.0 Å).

Returns:

Self, for method chaining.

Return type:

ReflectionData

get_mask()[source]

Return combined mask from all active filters.

Returns:

Boolean mask combining all filter conditions.

Return type:

torch.Tensor

cut_res(highres=None, lowres=None)[source]

Filter reflections by resolution range.

Alias for filter_by_resolution with more intuitive naming.

Parameters:
  • highres (float, optional) – High resolution cutoff (small d-spacing, e.g., 1.5 Å). Keeps reflections with d >= highres.

  • lowres (float, optional) – Low resolution cutoff (large d-spacing, e.g., 50.0 Å). Keeps reflections with d <= lowres.

Returns:

Self, for method chaining.

Return type:

ReflectionData

Examples

Filter by resolution:

# Keep reflections between 50 Å and 1.5 Å
filtered = data.cut_res(highres=1.5, lowres=50.0)
# Keep only the high-resolution shell (1.0 Å to 2.0 Å)
high_res = data.cut_res(highres=1.0, lowres=2.0)
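
Because "high resolution" corresponds to small d-spacing, the two cutoffs are easy to swap. The selection rule stated in the parameter descriptions (keep highres <= d <= lowres) can be sketched as follows; the helper name resolution_mask is hypothetical and lists stand in for tensors.

```python
def resolution_mask(d_spacings, highres=None, lowres=None):
    """True where highres <= d <= lowres; a bound left as None is not applied."""
    mask = []
    for d in d_spacings:
        keep = True
        if highres is not None and d < highres:
            keep = False  # finer than the high-resolution cutoff
        if lowres is not None and d > lowres:
            keep = False  # coarser than the low-resolution cutoff
        mask.append(keep)
    return mask


d = [60.0, 25.0, 3.2, 1.8, 1.2]
mask = resolution_mask(d, highres=1.5, lowres=50.0)
```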
get_rfree_masks()[source]

Get boolean masks for work and test (free) sets.

Returns:

  • work_mask (torch.Tensor or None) – Boolean tensor for work set (flag != 0).

  • test_mask (torch.Tensor or None) – Boolean tensor for test/free set (flag == 0). Both are None if no R-free flags are available.

Return type:

Tuple[Tensor | None, Tensor | None]

Examples

Separate work and test sets:

work_mask, test_mask = data.get_rfree_masks()
if work_mask is not None:
    F_work = data.F[work_mask]
    F_test = data.F[test_mask]
get_work_set()[source]

Get structure factors for the work set (R-free flag != 0).

Returns:

  • F_work (torch.Tensor) – Structure factors for work set.

  • sigma_work (torch.Tensor or None) – Uncertainties for work set, or None if not available.

Return type:

Tuple[Tensor, Tensor | None]

Notes

Returns the full dataset, with a warning, if no R-free flags are available.

get_test_set()[source]

Get structure factors for the test set (R-free flag == 0).

Returns:

  • F_test (torch.Tensor) – Structure factors for test/free set.

  • sigma_test (torch.Tensor or None) – Uncertainties for test set, or None if not available.

Raises:

ValueError – If no R-free flags are available.

Return type:

Tuple[Tensor, Tensor | None]

get_max_res()[source]

Return maximum resolution (lowest d-spacing).

Returns:

Maximum resolution in Ångströms.

Return type:

float

get_min_res()[source]

Return minimum resolution (highest d-spacing).

Returns:

Minimum resolution in Ångströms.

Return type:

float

__len__()[source]

Return number of reflections.

Returns:

Number of reflections in the dataset.

Return type:

int

property d_min: float | None

Return maximum resolution (lowest d-spacing).

Returns:

Maximum resolution in Ångströms.

Return type:

float

__repr__()[source]

Return string representation.

Returns:

Summary of reflection data including count, sources, and resolution.

Return type:

str

get_valid_mask()[source]

Return the combined validity mask for all reflections.

This is the mask used to filter reflections in forward(). True indicates a valid (included) reflection, False indicates an excluded one.

Returns:

Boolean mask of shape (N,) where N is the total number of reflections. True = valid/included, False = invalid/excluded.

Return type:

torch.Tensor

Examples

Check validity:

mask = data.get_valid_mask()
print(f"{mask.sum()} of {len(mask)} reflections are valid")
data_indexed()[source]

Return reflection data as indexed (filtered) tensors.

This method filters out invalid reflections and returns smaller tensors containing only valid data. Useful for operations that don’t support MaskedTensors or for writing output files.

Returns:

  • hkl (torch.Tensor) – Miller indices of shape (M, 3) where M is number of valid reflections.

  • F (torch.Tensor) – Structure factor amplitudes of shape (M,).

  • F_sigma (torch.Tensor or None) – Uncertainties of shape (M,) or None.

  • rfree_flags (torch.Tensor or None) – R-free flags of shape (M,) or None.

Return type:

Tuple[Tensor, Tensor, Tensor | None, Tensor | None]

See also

forward

Main method returning MaskedTensors.

Examples

Get indexed data for file writing:

hkl, F, sigma, rfree = data.data_indexed()
F_np = F.cpu().numpy()  # Safe for writing to files
__call__(mask=True, scale=True)[source]

Return core reflection data with MaskedTensors for F and sigma.

F and F_sigma are returned as MaskedTensors which keep all reflections but mark invalid ones as masked. Aggregation operations (sum, mean, etc.) automatically skip masked values. HKL and rfree_flags remain regular tensors.

Parameters:
  • mask (bool, optional) – If True, apply current masks to F and sigma. Default is True.

  • scale (bool, optional) – If True, apply scaling to F and sigma before returning.

Returns:

  • hkl (torch.Tensor) – Miller indices of shape (N, 3). Full size, unfiltered.

  • F (MaskedTensor) – Structure factor amplitudes of shape (N,) with invalid reflections masked.

  • F_sigma (MaskedTensor or None) – Uncertainties of shape (N,) with invalid reflections masked, or None.

  • rfree_flags (torch.Tensor or None) – R-free flags of shape (N,) or None. Full size, unfiltered. 1=work, 0=free.

Return type:

Tuple[Tensor, MaskedTensor, MaskedTensor, Tensor]

Notes

MaskedTensors:

  • Are PyTorch tensors with an associated boolean mask

  • Aggregations (sum, mean, etc.) skip masked values automatically

  • Element-wise operations preserve the mask

  • Use .get_data() and .get_mask() to access underlying data

  • Use .to_tensor(fill_value) to convert back to regular tensor

  • Note: MaskedTensor is in prototype stage in PyTorch

Loss functions and targets extract valid data from MaskedTensors before computation to work correctly with complex F_calc values.

Examples

Access reflection data with MaskedTensors:

hkl, F, sigma, rfree = data()
print(F.shape)  # Full shape
print(F.sum())  # Only sums valid (unmasked) values

# Access underlying data
valid_mask = F.get_mask()
F_values = F.get_data()[valid_mask]
data_fill_masked(mode='mean')[source]

Return data tensors with missing or flagged reflections filled in.

Parameters:

mode (str, optional) – Fill strategy. Default is ‘mean’.

  • ‘mean’: fill missing/flagged reflections with the mean of the present data

  • ‘zero’: fill missing/flagged reflections with zero
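
The two fill modes can be sketched as follows. This is an illustration under assumptions, not the library code: the helper name fill_masked is hypothetical, and a separate boolean list marks which values are present.

```python
def fill_masked(values, valid, mode="mean"):
    """Replace entries where valid is False by the mean of valid entries or by zero."""
    if mode == "mean":
        present = [v for v, ok in zip(values, valid) if ok]
        fill = sum(present) / len(present) if present else 0.0
    elif mode == "zero":
        fill = 0.0
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    return [v if ok else fill for v, ok in zip(values, valid)]


F = [10.0, float("nan"), 30.0, 20.0]
valid = [True, False, True, True]
filled_mean = fill_masked(F, valid, mode="mean")
filled_zero = fill_masked(F, valid, mode="zero")
```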

__getitem__(key)[source]

Index into the reflection dataset.

Parameters:

key (torch.Tensor) – Boolean mask or integer indices for selection.

Returns:

New ReflectionData object with selected reflections.

Return type:

ReflectionData

__select__(indices, op=None)[source]

Select reflections by boolean mask or integer indices.

Iterates over all dataclass fields generically, so new tensor fields are handled automatically.

Parameters:
  • indices (torch.Tensor) – Boolean mask of shape (N,) or integer indices for selection.

  • op (str, optional) – Operation name for tracking purposes.

Returns:

New ReflectionData object with selected reflections.

Return type:

ReflectionData

sanitize_F()[source]

Remove invalid values from structure factors.

Adds a mask to filter out NaN, Inf, and non-positive values from F and F_sigma.

check_all_data_types()[source]
validate_hkl(hkl_ref)[source]

Expand dataset to match a reference HKL set.

Reorders and expands the current dataset to match the reference HKL set. Reflections present in the reference but missing from this dataset are filled with placeholder values and masked out. This ensures all datasets aligned to the same reference have identical shapes and can be processed together without data loss from intersection operations.

Parameters:

hkl_ref (torch.Tensor) – Reference Miller indices tensor of shape (N, 3), dtype int32. This defines the canonical HKL ordering for all aligned datasets.

Returns:

Self, modified in-place with expanded arrays matching hkl_ref.

Return type:

ReflectionData

Notes

After calling this method:

  • self.hkl will equal hkl_ref exactly

  • All data arrays (F, F_sigma, rfree_flags, etc.) are reordered/expanded

  • Missing reflections are filled with 0 (or appropriate defaults)

  • A mask ‘hkl_present’ is added marking which reflections have real data

  • forward() will return MaskedTensors that skip missing reflections

This approach avoids the problem where intersecting many datasets with different outliers/missing reflections causes exponential data loss.

Examples

Align multiple datasets to a common HKL set:

reference_hkl = data1.hkl.clone()
data1.validate_hkl(reference_hkl)
data2.validate_hkl(reference_hkl)
# Now data1 and data2 have identical shapes
assert data1.hkl.shape == data2.hkl.shape
find_outliers(model, scaler, z_threshold=4.0)[source]

Identify outlier reflections based on log-ratio distribution.

Uses the fact that log(F_obs) - log(F_calc) should be normally distributed. Outliers are reflections where |log_ratio - mean| > z_threshold * std_dev.

Parameters:
  • model (ModelFT) – ModelFT object to compute structure factors.

  • scaler (Scaler) – Scaler object to scale calculated structure factors.

  • z_threshold (float, optional) – Z-score threshold to classify outliers. Default is 4.0.

Returns:

Boolean mask where True indicates outliers.

Return type:

torch.Tensor
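
The z-score rule above can be sketched with the standard library. This is a minimal illustration that takes already scaled F_calc amplitudes as input instead of a model and scaler; the helper name log_ratio_outliers is hypothetical, and the use of the population standard deviation is an assumption.

```python
import math
import statistics


def log_ratio_outliers(F_obs, F_calc, z_threshold=4.0):
    """Flag reflections whose log(F_obs) - log(F_calc) deviates strongly from the mean."""
    log_ratio = [math.log(fo) - math.log(fc) for fo, fc in zip(F_obs, F_calc)]
    mean = statistics.fmean(log_ratio)
    std = statistics.pstdev(log_ratio)
    if std == 0.0:
        return [False] * len(log_ratio)
    return [abs(r - mean) > z_threshold * std for r in log_ratio]


# 49 observations within 2% of F_calc, plus one observation ten times too large
F_calc = [100.0] * 50
F_obs = [100.0 * (1.0 + 0.01 * ((i % 5) - 2)) for i in range(49)] + [1000.0]
outliers = log_ratio_outliers(F_obs, F_calc, z_threshold=4.0)
```

Only the last reflection is flagged; the small scatter of the remaining observations stays well inside the threshold.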

get_log_ratio(model, scaler)[source]

Compute log-ratio between observed and calculated structure factors.

Parameters:
  • model (ModelFT) – ModelFT object to compute structure factors.

  • scaler (Scaler) – Scaler object to scale calculated structure factors.

Returns:

Log-ratio values: log(F_obs) - log(F_calc).

Return type:

torch.Tensor

get_outlier_statistics()[source]

Get statistics about flagged outliers.

Returns:

Dictionary containing:

  • n_outliers : int

  • n_total : int

  • fraction_outliers : float

  • detection_params : dict or None

  • outlier_resolution_stats : dict (if resolution available)

Return type:

dict

unpack_one()[source]

Unpack one level of source.

Does not recurse fully and does not flag.

Returns:

Parent source or self if no source.

Return type:

ReflectionData

flag_suspicious_sigma(z_threshold=5.0)[source]

Flag sigma values that deviate significantly from expected distribution.

Sigma values from a detector should follow a log-normal distribution. Values with z-scores beyond threshold are flagged as suspicious.

Parameters:

z_threshold (float, optional) – Z-score threshold for flagging outliers. Default is 5.0.

dump()[source]

Dump all reflection data to console for debugging.

Prints type, shape, and device information for all attributes.

write_mtz(fname, fcalc=None, model_ft=None)[source]

Write reflection data to MTZ file with optional map coefficients.

Parameters:
  • fname (str) – Output MTZ filename.

  • fcalc (torch.Tensor, optional) – Complex calculated structure factors of shape (N,). If provided, computes phases and map coefficients.

  • model_ft (ModelFT, optional) – ModelFT object to compute fcalc if not provided.

Notes

The MTZ file will contain canonical column names:
  • FP, SIGFP: Observed amplitudes and uncertainties

  • I, SIGI: Observed intensities and uncertainties (if available)

  • FreeR_flag: R-free test set flags

  • FWT, PHWT: 2mFo-DFc map coefficients (if fcalc provided)

  • DELFWT, PHDELWT: mFo-DFc map coefficients (if fcalc provided)

Map coefficients are computed as:
  • 2mFo-DFc: 2*Fo - Fc (filled to resolution limit)

  • mFo-DFc: Fo - Fc

Examples

Write MTZ with map coefficients:

data = ReflectionData().load_mtz('observed.mtz')
model = Model().load_pdb('model.pdb')
model_ft = ModelFT(model, data.cell, data.spacegroup)
fcalc = model_ft.forward(data.hkl)
data.write_mtz('output.mtz', fcalc=fcalc)
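
The simplified map coefficients from the notes (2*Fo - Fc and Fo - Fc) can be sketched as below. This takes the weights m and D as 1, as the formulas in the notes do, and assumes both maps carry the phase of F_calc, a common convention that is not spelled out here. The helper name map_coefficients is hypothetical.

```python
import cmath


def map_coefficients(F_obs, F_calc):
    """Return (2Fo-Fc, Fo-Fc) as complex coefficients phased with F_calc."""
    two_fo_fc, fo_fc = [], []
    for fo, fc in zip(F_obs, F_calc):
        phasor = cmath.exp(1j * cmath.phase(fc))  # unit phasor of F_calc
        two_fo_fc.append((2.0 * fo - abs(fc)) * phasor)
        fo_fc.append((fo - abs(fc)) * phasor)
    return two_fo_fc, fo_fc


F_obs = [12.0, 8.0]                   # observed amplitudes
F_calc = [10.0 + 0.0j, 0.0 + 10.0j]   # complex calculated structure factors
fwt, delfwt = map_coefficients(F_obs, F_calc)
```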
property centric

Get boolean mask for centric reflections (full size, unfiltered).

Calculates it if not already present. Returns unfiltered centric flags matching the full HKL array size, consistent with how forward() returns full-size arrays when using MaskedTensors.

Returns:

Boolean tensor of shape (N,) where N is total reflections. True indicates centric reflection, False indicates acentric.

Return type:

torch.Tensor or None

calc_patterson(grid_size=None, grid_sampling=1)[source]

Calculate Patterson map of the dataset.

The Patterson function P(u,v,w) = Σ|F(hkl)|² exp(-2πi(hu+kv+lw)) is computed via inverse FFT of F². Data is expanded to P1 symmetry using only observed reflections (no filling of missing data).

Parameters:
  • grid_size (tuple of int, optional) – Grid dimensions (Nx, Ny, Nz). If None, automatically determined from unit cell and resolution.

  • grid_sampling (float, optional) – Sampling interval for the grid. Default is 1, which sets the grid so that it is sampled twice as finely as normal for a given resolution.

Returns:

Real-valued Patterson map of shape (Nx, Ny, Nz). Origin is at grid position [0, 0, 0].

Return type:

torch.Tensor
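
A one-dimensional toy example makes the Patterson formula concrete. The sketch below evaluates the sum directly rather than by FFT, for two equal point atoms in a 1D cell; all names and values are illustrative assumptions. The map peaks at the origin and at the interatomic vector (0.25 and its mate 0.75 in fractional units).

```python
import cmath
import math

atoms = [0.10, 0.35]  # fractional coordinates of two equal point atoms
h_max = 5             # include reflections h = -5 .. 5
n_grid = 20           # Patterson grid points along the cell edge

# F(h) = sum_j exp(2*pi*i*h*x_j); only |F|^2 enters the Patterson function.
intensity = {}
for h in range(-h_max, h_max + 1):
    F_h = sum(cmath.exp(2j * math.pi * h * x) for x in atoms)
    intensity[h] = abs(F_h) ** 2

# P(u) = sum_h |F(h)|^2 exp(-2*pi*i*h*u); imaginary parts cancel by Friedel symmetry.
patterson = []
for k in range(n_grid):
    u = k / n_grid
    total = sum(I_h * cmath.exp(-2j * math.pi * h * u) for h, I_h in intensity.items())
    patterson.append(total.real)
```

Grid index 0 holds the origin peak, consistent with the statement above that the origin sits at grid position [0, 0, 0].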

possible_hkl()[source]

Generate all possible HKL indices within the resolution limit.

Returns:

Tensor of shape (M, 3) containing all possible Miller indices within the resolution limit defined by self.resolution.

Return type:

torch.Tensor
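
Enumerating the indices inside a resolution limit can be sketched for the simplest case. This illustration assumes an orthorhombic cell and returns the full P1 sphere (Friedel mates included, no reduction to an asymmetric unit), which may differ from what possible_hkl() returns for a given space group. The helper name is hypothetical.

```python
import math


def possible_hkl_orthorhombic(a, b, c, d_min):
    """List all (h, k, l) with d-spacing >= d_min for an orthorhombic cell."""
    h_max, k_max, l_max = int(a / d_min), int(b / d_min), int(c / d_min)
    out = []
    for h in range(-h_max, h_max + 1):
        for k in range(-k_max, k_max + 1):
            for l in range(-l_max, l_max + 1):
                if (h, k, l) == (0, 0, 0):
                    continue  # F(000) is not a measurable reflection
                inv_d_sq = (h / a) ** 2 + (k / b) ** 2 + (l / c) ** 2
                d = 1.0 / math.sqrt(inv_d_sq)
                if d >= d_min - 1e-9:  # small tolerance for rounding
                    out.append((h, k, l))
    return out


hkl = possible_hkl_orthorhombic(10.0, 10.0, 10.0, d_min=5.0)
```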

remap(new_hkl, index_mapping, phase_shifts=None, spacegroup=None, op_name='remap')[source]

Create new ReflectionData with remapped HKL set and data.

This is the core method for index-based transformations of reflection data. It handles remapping all tensor fields based on an index mapping, with support for missing reflections (indicated by -1 in index_mapping).

Parameters:
  • new_hkl (torch.Tensor, shape (M, 3)) – New Miller indices.

  • index_mapping (torch.Tensor, shape (M,), dtype int64) – Maps new indices to original: new[i] = old[index_mapping[i]]. Values of -1 indicate missing reflections (filled with defaults).

  • phase_shifts (torch.Tensor, optional, shape (M,)) – Phase offsets to apply (e.g., from symmetry translations).

  • spacegroup (str, int, gemmi.SpaceGroup, or None) – New spacegroup. If None, keeps original.

  • op_name (str) – Operation name for provenance tracking.

Returns:

New object with remapped data. Missing reflections get:

  • 0.0 for F, I, phase, fom

  • 1.0 for F_sigma, I_sigma (conservative uncertainty)

  • True for masks[‘missing’]

Return type:

ReflectionData
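
The index-mapping convention (new[i] = old[index_mapping[i]], with -1 for missing reflections) can be sketched in plain Python. The helper names build_index_mapping and gather are hypothetical, and tuples and lists stand in for tensors.

```python
def build_index_mapping(old_hkl, new_hkl):
    """For each new reflection, find its position in the old list, or -1 if absent."""
    lookup = {hkl: i for i, hkl in enumerate(old_hkl)}
    return [lookup.get(hkl, -1) for hkl in new_hkl]


def gather(values, index_mapping, default):
    """Apply new[i] = old[index_mapping[i]], using default where the index is -1."""
    return [values[i] if i >= 0 else default for i in index_mapping]


old_hkl = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
old_F = [10.0, 20.0, 30.0]
old_sigma = [0.5, 0.6, 0.7]
new_hkl = [(0, 0, 1), (1, 1, 0), (1, 0, 0)]  # (1, 1, 0) was never measured

mapping = build_index_mapping(old_hkl, new_hkl)
new_F = gather(old_F, mapping, default=0.0)          # amplitudes default to 0.0
new_sigma = gather(old_sigma, mapping, default=1.0)  # sigmas default to 1.0
missing = [i < 0 for i in mapping]
```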

fill(d_min=None)[source]

Fill missing reflections within resolution limit.

Generates all possible reflections for the current spacegroup within the resolution limit, identifies which are missing, and creates a complete dataset. Missing reflections are filled with default values.

Parameters:

d_min (float, optional) – High resolution limit in Angstroms. If None, uses the minimum resolution from the current dataset.

Returns:

New ReflectionData with complete set of reflections. Missing reflections have:

  • F, I, phase, fom = 0.0

  • F_sigma, I_sigma = 1.0

  • masks[‘missing’] = True

Return type:

ReflectionData

Examples

Fill to completeness:

data = ReflectionData().load_mtz('data.mtz')
data_filled = data.fill(d_min=2.0)
print(f"Original: {len(data)}, Filled: {len(data_filled)}")
expand_to_p1(include_friedel=True, remove_absences=True)[source]

Expand reflection data from asymmetric unit to P1.

Applies all symmetry operations from the current space group to generate all symmetry-equivalent reflections. Returns a NEW ReflectionData object with expanded reflections; does not modify self.

Parameters:
  • include_friedel (bool, default True) – Include Friedel mates (-h, -k, -l). For normal (non-anomalous) scattering, Friedel pairs have identical amplitudes.

  • remove_absences (bool, default True) – Remove systematically absent reflections from output.

Returns:

New ReflectionData object with:

  • All symmetry-equivalent reflections

  • spacegroup set to P1

  • Expanded tensor fields (F, F_sigma, I, I_sigma, phase, fom, rfree_flags)

  • Recalculated resolution values

  • Provenance tracking via source and last_op attributes

Return type:

ReflectionData

Notes

The expansion handles tensor fields as follows:

  • hkl: Symmetry operations applied, Friedel mates added, duplicates removed

  • F, F_sigma, I, I_sigma, fom, rfree_flags: Indexed from original

  • phase: Indexed + phase shift from translation component

  • resolution: Recalculated from expanded hkl + cell

  • bin_indices: Cleared (invalidated by expansion)

Examples

Expand to P1 for calculations:

# Load data in original space group
data = ReflectionData().load_mtz('data.mtz')
print(f"Original: {len(data)} reflections, {data.spacegroup}")

# Expand to P1
data_p1 = data.expand_to_p1()
print(f"Expanded: {len(data_p1)} reflections, {data_p1.spacegroup}")

# Without Friedel mates (for anomalous data)
data_p1_anom = data.expand_to_p1(include_friedel=False)
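
The index part of the expansion can be sketched as follows. This illustration handles only the Miller indices (no phase shifts, no removal of systematic absences), uses hand-written rotation matrices for a twofold axis along b instead of looking them up from a space group, and treats (h, k, l) as a row vector multiplied by the rotation matrix. The helper name expand_hkl is hypothetical.

```python
def expand_hkl(hkl_list, rotations, include_friedel=True):
    """Apply each rotation to each reflection and return the unique indices."""
    expanded = set()
    for h, k, l in hkl_list:
        for R in rotations:
            # Row vector (h, k, l) times the 3x3 rotation matrix R.
            hr = tuple(h * R[0][j] + k * R[1][j] + l * R[2][j] for j in range(3))
            expanded.add(hr)
            if include_friedel:
                expanded.add(tuple(-x for x in hr))
    return sorted(expanded)


IDENTITY = ((1, 0, 0), (0, 1, 0), (0, 0, 1))
TWOFOLD_B = ((-1, 0, 0), (0, 1, 0), (0, 0, -1))  # (h, k, l) -> (-h, k, -l)

asu = [(1, 2, 3), (0, 2, 0)]
p1 = expand_hkl(asu, [IDENTITY, TWOFOLD_B])
p1_no_friedel = expand_hkl(asu, [IDENTITY, TWOFOLD_B], include_friedel=False)
```

The reflection (0, 2, 0) lies on the twofold axis, so both operations map it onto itself and the set removes the duplicate, matching the "duplicates removed" note above.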
reduce_to_spacegroup(spacegroup, include_friedel=True, aggregation='mean')[source]

Reduce P1 reflection data to asymmetric unit of a target spacegroup.

This is the inverse of expand_to_p1(). Takes reflection data in P1 and merges symmetry-equivalent reflections into single ASU reflections using the specified aggregation function.

Parameters:
  • spacegroup (str, int, or gemmi.SpaceGroup) – Target space group specification.

  • include_friedel (bool, default True) – If True, also merge Friedel mates when reducing.

  • aggregation (str, default 'mean') – Aggregation function for merging equivalent reflections: ‘mean’ averages values (default, good for amplitudes); ‘sum’ sums values; ‘first’ takes the first valid value (no averaging).

Returns:

New ReflectionData with merged reflections in the target spacegroup.

Return type:

ReflectionData

Notes

Field handling during reduction:

  • F, I: Aggregated using specified function

  • F_sigma, I_sigma: Propagated as sqrt(sum(sigma^2) / n) for ‘mean’

  • phase: Aggregated via complex averaging (handles phase wrapping)

  • fom: Weighted by amplitude during phase averaging

  • rfree_flags: OR operation (True if any equivalent is True)

Examples

Reduce symmetry-expanded data:

# Expand to P1 for calculations, then reduce back
data_p1 = data.expand_to_p1()
# ... modify F_p1 ...
data_merged = data_p1.reduce_to_spacegroup('P21')

# Reduce with sum instead of mean
data_summed = data_p1.reduce_to_spacegroup('P21', aggregation='sum')
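
The merging rules for 'mean' aggregation listed in the notes can be sketched for a single group of equivalent reflections. The helper names are hypothetical, phases are assumed to be in radians, and the phase average here uses unweighted unit phasors, which may differ from the weighting torchref applies.

```python
import cmath
import math


def merge_mean(values):
    """Plain average, as used for F and I under 'mean' aggregation."""
    return sum(values) / len(values)


def merge_sigma(sigmas):
    """sqrt(sum(sigma^2) / n), the rule quoted in the notes for 'mean'."""
    return math.sqrt(sum(s * s for s in sigmas) / len(sigmas))


def merge_phase(phases_rad):
    """Average phases as unit phasors so that 350 and 10 degrees average to 0."""
    return cmath.phase(sum(cmath.exp(1j * p) for p in phases_rad))


F_equiv = [10.0, 12.0]
sigma_equiv = [3.0, 4.0]
phase_equiv = [math.radians(350.0), math.radians(10.0)]

F_merged = merge_mean(F_equiv)
sigma_merged = merge_sigma(sigma_equiv)
phase_merged = merge_phase(phase_equiv)
```

An arithmetic mean of the two phases would give 180 degrees, the opposite of the correct answer; averaging on the unit circle avoids that wrap-around error.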
canonicalize(include_friedel=True)[source]

Return new ReflectionData with HKL in standard CCP4 ASU form.

Remaps all Miller indices to the canonical CCP4 asymmetric unit representative using gemmi.ReciprocalAsu, adjusts phases accordingly, and sorts reflections lexicographically by (h, k, l).

Parameters:

include_friedel (bool, default True) – Whether Friedel mates are considered equivalent.

Returns:

New object with canonicalized, sorted Miller indices.

Return type:

ReflectionData

get_scattering_vectors()[source]

Get scattering vectors (s-vectors) from hkl and cell.

The s-vector for a reflection hkl is defined as:

s = B* @ hkl

where B* is the reciprocal basis matrix.

Returns:

s_vectors – Reciprocal space vectors in Angstroms^-1, shape (N, 3).

Return type:

torch.Tensor

Raises:

ValueError – If hkl or cell is not available.
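
The relation s = B* @ hkl can be sketched without any crystallography library by building the reciprocal basis from the direct lattice vectors with cross products. The monoclinic cell below (a=10, b=20, c=30 Å, beta=120°) and all helper names are assumptions for illustration.

```python
import math


def cross(u, v):
    return (u[1] * v[2] - u[2] * v[1],
            u[2] * v[0] - u[0] * v[2],
            u[0] * v[1] - u[1] * v[0])


def dot(u, v):
    return sum(x * y for x, y in zip(u, v))


def reciprocal_basis(a, b, c):
    """Return (a*, b*, c*) for direct lattice vectors a, b, c in Cartesian Å."""
    volume = dot(a, cross(b, c))
    a_star = tuple(x / volume for x in cross(b, c))
    b_star = tuple(x / volume for x in cross(c, a))
    c_star = tuple(x / volume for x in cross(a, b))
    return a_star, b_star, c_star


def scattering_vector(hkl, basis_star):
    """s = h a* + k b* + l c*, equivalent to B* @ hkl with a*, b*, c* as columns."""
    h, k, l = hkl
    a_s, b_s, c_s = basis_star
    return tuple(h * a_s[i] + k * b_s[i] + l * c_s[i] for i in range(3))


beta = math.radians(120.0)
a_vec = (10.0, 0.0, 0.0)
b_vec = (0.0, 20.0, 0.0)
c_vec = (30.0 * math.cos(beta), 0.0, 30.0 * math.sin(beta))

basis_star = reciprocal_basis(a_vec, b_vec, c_vec)
s = scattering_vector((1, 2, 3), basis_star)
s_100 = scattering_vector((1, 0, 0), basis_star)
d_100 = 1.0 / math.sqrt(dot(s_100, s_100))  # |s| = 1/d
```

The length of s is the reciprocal of the d-spacing, which ties this method to the per-reflection resolution values stored on the dataset.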

get_radial_shells(n_shells=20, d_min=None, d_max=None)[source]

Create uniform radial shells in 1/d space for normalization.

This creates shells with uniform spacing in reciprocal space (1/d), different from get_bins() which creates equal-count bins.

Parameters:
  • n_shells (int) – Number of radial shells. Default is 20.

  • d_min (float, optional) – High resolution limit in Angstroms. If None, uses dataset minimum.

  • d_max (float, optional) – Low resolution limit in Angstroms. If None, uses dataset maximum.

Returns:

  • shell_edges (torch.Tensor) – Shell boundaries in Angstroms^-1, shape (n_shells+1,).

  • shell_centers (torch.Tensor) – Shell centers in Angstroms^-1, shape (n_shells,).

  • shell_indices (torch.Tensor) – Shell index for each reflection, shape (N,). Values -1 for out-of-range.

Return type:

Tuple[Tensor, Tensor, Tensor]
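
Shells that are uniform in 1/d can be sketched as follows. The helper name radial_shells is hypothetical, lists stand in for tensors, and placing a reflection exactly on the upper edge into the last shell is an assumption.

```python
def radial_shells(d_spacings, n_shells=20, d_min=None, d_max=None):
    """Return (edges, centers, indices) for shells of equal width in 1/d."""
    d_min = min(d_spacings) if d_min is None else d_min
    d_max = max(d_spacings) if d_max is None else d_max
    s_lo, s_hi = 1.0 / d_max, 1.0 / d_min
    width = (s_hi - s_lo) / n_shells
    edges = [s_lo + i * width for i in range(n_shells + 1)]
    centers = [0.5 * (edges[i] + edges[i + 1]) for i in range(n_shells)]
    indices = []
    for d in d_spacings:
        s = 1.0 / d
        if s < s_lo or s > s_hi:
            indices.append(-1)  # outside the requested resolution range
        else:
            indices.append(min(int((s - s_lo) / width), n_shells - 1))
    return edges, centers, indices


edges, centers, idx = radial_shells([9.0, 4.0, 2.2, 1.5],
                                    n_shells=4, d_min=2.0, d_max=10.0)
```

Unlike the equal-count bins from get_bins(), these shells have equal width in reciprocal space, so shell populations vary and some shells can be empty.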

fit_anisotropy(n_shells=20, d_min=None, d_max=None, n_iterations=100, verbose=None)[source]

Fit anisotropy correction parameters to minimize CV within shells.

Optimizes U parameters so that corrected F² values have minimal coefficient of variation within each resolution shell.

Parameters:
  • n_shells (int) – Number of resolution shells for variance calculation.

  • d_min (float, optional) – High resolution limit in Angstroms. If None, uses dataset minimum.

  • d_max (float, optional) – Low resolution limit in Angstroms. If None, uses dataset maximum.

  • n_iterations (int) – Number of optimization iterations.

  • verbose (bool, optional) – Print progress. If None, uses self.verbose.

Returns:

U – Fitted anisotropy parameters [u11, u22, u33, u12, u13, u23], shape (6,). Also stored in self.U_aniso.

Return type:

torch.Tensor

Raises:

ValueError – If no amplitude data is available.

setup_anisotropy(U_aniso=None)[source]

Set up anisotropy correction parameters.

Parameters:

U_aniso (torch.Tensor, optional) – Anisotropic parameters [u11, u22, u33, u12, u13, u23], shape (6,). If None, initializes as zeros.

apply_anisotropy_correction(U_aniso=None)[source]

Apply anisotropy correction to F² values.

Parameters:

U_aniso (torch.Tensor, optional) – Anisotropic parameters [u11, u22, u33, u12, u13, u23], shape (6,). If None, uses self.U_aniso (must have called fit_anisotropy first).

Returns:

  • F_corrected (torch.Tensor) – Anisotropy-corrected F values, shape (N,).

  • sigma_F_corrected (torch.Tensor) – Uncertainties of corrected F values, shape (N,).

Raises:

ValueError – If no U parameters available and none provided.

Return type:

Tensor

compute_e_values(n_shells=20, d_min=None, d_max=None, apply_anisotropy=True, fit_anisotropy=True, verbose=None)[source]

Compute E-values with optional anisotropy correction.

E-values are normalized structure factors where <E²> = 1 within each resolution shell. Anisotropy correction can be applied first to account for directional variation in diffraction.

Parameters:
  • n_shells (int) – Number of resolution shells for normalization.

  • d_min (float, optional) – High resolution limit in Angstroms. If None, uses dataset minimum.

  • d_max (float, optional) – Low resolution limit in Angstroms. If None, uses dataset maximum.

  • apply_anisotropy (bool) – If True, apply anisotropy correction before E-value calculation.

  • fit_anisotropy (bool) – If True and apply_anisotropy is True, fit anisotropy parameters. If False and apply_anisotropy is True, uses existing self.U_aniso.

  • verbose (bool, optional) – Print progress. If None, uses self.verbose.

Returns:

E – E-values, shape (N,). Also stored in self.E. self.E_squared is also populated with E² values.

Return type:

torch.Tensor

Raises:

ValueError – If no amplitude data is available.

Examples

Compute E-values with automatic anisotropy correction:

data = ReflectionData().load_mtz('data.mtz')
E = data.compute_e_values(n_shells=30)
print(f"E-values: mean={E.mean():.3f}, std={E.std():.3f}")

Compute E-values without anisotropy correction:

E = data.compute_e_values(apply_anisotropy=False)
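
The shell-wise normalization described for compute_e_values can be sketched in plain Python. This is an illustration only, not torchref's implementation: normalize_to_e_values and its arguments are hypothetical names, reflections are assumed to be already assigned to resolution shells, and symmetry (epsilon) factors and anisotropy are ignored.

```python
import math
from collections import defaultdict


def normalize_to_e_values(amplitudes, shell_ids):
    """Divide each amplitude by sqrt(<F^2>) of its resolution shell."""
    sum_f2 = defaultdict(float)
    count = defaultdict(int)
    for f, shell in zip(amplitudes, shell_ids):
        sum_f2[shell] += f * f
        count[shell] += 1
    # Mean F^2 per shell is the normalization constant for that shell.
    mean_f2 = {shell: sum_f2[shell] / count[shell] for shell in sum_f2}
    return [f / math.sqrt(mean_f2[shell]) for f, shell in zip(amplitudes, shell_ids)]


amplitudes = [120.0, 80.0, 100.0, 12.0, 9.0, 15.0]
shell_ids = [0, 0, 0, 1, 1, 1]  # hypothetical shell assignment
e_values = normalize_to_e_values(amplitudes, shell_ids)
```

By construction the mean of E² is 1 in every shell, regardless of how strongly the raw amplitudes fall off with resolution.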
setup_scale(scale=None)[source]

Set overall scale factor, parametrized in log space.

Parameters:

scale (float, optional) – If provided, sets the scale factor directly. If None, computes scale to make mean F equal to 1.0.

Returns:

The scale factor applied.

Return type:

float

get_corrected_data()[source]

Get scaled amplitude and uncertainty tensors.

Applies the current scale factor (from self.log_scale) to F and F_sigma.

Returns:

Scaled F and F_sigma tensors. If log_scale is None, returns original.

Return type:

Tuple[torch.Tensor, torch.Tensor]
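
A minimal sketch of what a log-space scale factor could look like, assuming the scale is stored as log_scale and applied as exp(log_scale) to both F and F_sigma. The helper names fit_log_scale and apply_log_scale are hypothetical and the code uses plain Python lists rather than tensors.

```python
import math


def fit_log_scale(amplitudes):
    """Log of the scale factor that makes the mean amplitude equal to 1."""
    mean_f = sum(amplitudes) / len(amplitudes)
    return -math.log(mean_f)


def apply_log_scale(amplitudes, sigmas, log_scale=None):
    """Scale amplitudes and sigmas together; pass through if no scale is set."""
    if log_scale is None:
        return list(amplitudes), list(sigmas)
    k = math.exp(log_scale)  # exp() keeps the scale strictly positive
    return [k * f for f in amplitudes], [k * s for s in sigmas]


amplitudes = [50.0, 150.0, 100.0]
sigmas = [5.0, 10.0, 8.0]
log_scale = fit_log_scale(amplitudes)
scaled_f, scaled_sigma = apply_log_scale(amplitudes, sigmas, log_scale)
```

Optimizing the logarithm instead of the scale itself means an optimizer can never drive the scale to zero or below, and F / F_sigma ratios are unchanged because both are multiplied by the same factor.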

parameters()[source]

Get list of learnable parameters for optimization.

Returns:

iter of torch Parameters to be optimized. Includes log_scale and U_aniso if set.

Return type:

iter[Parameter]

__init__(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, _centric=None, _n_bins=None, _FrenchWilson=None, source=None, dataset=None, last_op=None, reader=None)
class torchref.DatasetCollection(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, _datasets=<factory>, _dataset_order=<factory>, _reference_dataset=None, _common_hkl=None, _cell=None, _spacegroup=None, _resolution=None, _scale_factors=<factory>)[source]

Bases: CrystalDataset

Container for multiple related crystal datasets.

All datasets share a common HKL set for efficient computation. Datasets are aligned using the first dataset as a reference, with missing reflections in subsequent datasets masked out.

Parameters:
  • verbose (int, optional) – Verbosity level (0=silent, 1=normal, 2=debug). Default is 1.

  • device (str, optional) – Device for tensors (‘cpu’, ‘cuda’, etc.). Defaults to the configured device.current.

hkl

Common HKL set for all datasets.

Type:

torch.Tensor

n_datasets

Number of datasets in collection.

Type:

int

Examples

from torchref.io import DatasetCollection, ReflectionData

collection = DatasetCollection(device='cuda')

native = ReflectionData().load_mtz('native.mtz')
derivative = ReflectionData().load_mtz('derivative.mtz')

collection.add_dataset('native', native, set_as_reference=True)
collection.add_dataset('derivative', derivative)

for name, dataset in collection:
    print(f"{name}: {len(dataset)} reflections")

# Access by name
native_F = collection['native'].F
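
The alignment onto a common HKL set can be pictured with the following sketch. It is not the library's code: align_to_reference is a hypothetical helper, data are plain Python lists, and filling missing reflections with 0.0 is an assumption made for the illustration.

```python
def align_to_reference(reference_hkl, hkl, amplitudes, fill=0.0):
    """Reorder amplitudes onto the reference HKL list and flag missing ones."""
    lookup = {tuple(index): f for index, f in zip(hkl, amplitudes)}
    aligned = [lookup.get(tuple(index), fill) for index in reference_hkl]
    present = [tuple(index) in lookup for index in reference_hkl]
    return aligned, present


reference_hkl = [(1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 0)]
derivative_hkl = [(0, 0, 1), (1, 0, 0), (2, 0, 0)]  # (2, 0, 0) is not in the reference
derivative_f = [30.0, 10.0, 99.0]

aligned_f, present = align_to_reference(reference_hkl, derivative_hkl, derivative_f)
```

Every aligned dataset then has the same length and ordering as the reference, and the boolean mask records which entries hold real measurements.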
add_dataset(name, dataset, set_as_reference=False)[source]

Add a dataset to the collection.

Parameters:
  • name (str) – Identifier for this dataset.

  • dataset (ReflectionData) – The dataset to add.

  • set_as_reference (bool, optional) – If True, use this dataset’s HKL as the reference. Default is False, but the first dataset added automatically becomes the reference.

Returns:

Self, for method chaining.

Return type:

DatasetCollection

Raises:

ValueError – If a dataset with the same name already exists.

Examples

collection = DatasetCollection()
collection.add_dataset('native', native_data, set_as_reference=True)
collection.add_dataset('derivative', derivative_data)
property hkl: Tensor | None

Common HKL set for all datasets.

property datasets: Dict[str, ReflectionData]

Access all datasets as a dictionary.

property n_datasets: int

Number of datasets in collection.

property reference_dataset: str | None

Name of the reference dataset.

property spacegroup: str | None

Space group of the reference dataset.

__getitem__(name)[source]

Get dataset by name.

Parameters:

name (str) – Name of the dataset.

Returns:

The requested dataset.

Return type:

ReflectionData

Raises:

KeyError – If dataset name not found.

__iter__()[source]

Iterate over (name, dataset) pairs in order of addition.

Yields:

tuple of (str, ReflectionData) – Name and dataset for each dataset in collection.

__len__()[source]

Number of reflections in common HKL set.

__contains__(name)[source]

Check if dataset exists in collection.

__call__(mask=True)[source]

Return the data of all datasets, scaled if scale factors are set.

Parameters:

mask (bool, optional) – Whether to apply masking. Default is True.

Returns:

Dictionary mapping name to (hkl, F, F_sigma, rfree) tuples.

Return type:

dict

scale()[source]

Scale all datasets to a common reference scale.

This method optimizes the scaling parameters of all non-reference datasets to minimize the mean squared error between their structure factors and those of the reference dataset. The optimization corrects for both overall scale differences and anisotropy. It uses the L-BFGS optimizer with strong Wolfe line search and refines the scaling parameters over multiple optimization steps.

Returns:

The collection instance, for method chaining.

Return type:

DatasetCollection

Raises:

ValueError – If no reference dataset has been set before calling this method, or if the collection contains only the reference dataset. At least two datasets are required.

Notes

The reference dataset must be set before calling this method, for example with add_dataset(..., set_as_reference=True). All datasets except the reference will have their scaling parameters optimized.
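
To make the objective concrete, here is a deliberately simplified sketch of scaling one dataset to a reference. It swaps in a closed-form least-squares solution for the L-BFGS optimization, fits a single overall scale only (no anisotropy), and uses hypothetical helper names and made-up amplitudes.

```python
def least_squares_scale(reference_f, other_f):
    """Scale k that minimizes sum((reference - k * other)^2)."""
    numerator = sum(r * o for r, o in zip(reference_f, other_f))
    denominator = sum(o * o for o in other_f)
    return numerator / denominator


def mean_squared_error(reference_f, other_f, k=1.0):
    """Mean squared difference between the reference and the scaled dataset."""
    return sum((r - k * o) ** 2 for r, o in zip(reference_f, other_f)) / len(reference_f)


reference_f = [100.0, 50.0, 25.0, 80.0]
other_f = [41.0, 19.0, 11.0, 31.0]  # roughly reference / 2.5 with some noise
k = least_squares_scale(reference_f, other_f)
```

An iterative optimizer becomes necessary once the scale is combined with anisotropic terms, because the objective is then no longer linear in the parameters.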

keys()[source]

Return list of dataset names.

values()[source]

Return list of datasets.

items()[source]

Return list of (name, dataset) tuples.

__init__(hkl=None, F=None, F_sigma=None, I=None, I_sigma=None, rfree_flags=None, resolution=None, bin_indices=None, outlier_flags=None, phase=None, fom=None, _centric_flags=None, E=None, E_squared=None, F_squared_corrected=None, U_aniso=None, radial_shell_indices=None, cell=None, spacegroup=None, device=<factory>, verbose=1, rfree_source=None, amplitude_source=None, intensity_source=None, phase_source=None, wilson_b=None, wilson_b_structure=None, wilson_b_solvent=None, wilson_k_sol=None, outlier_detection_params=None, _datasets=<factory>, _dataset_order=<factory>, _reference_dataset=None, _common_hkl=None, _cell=None, _spacegroup=None, _resolution=None, _scale_factors=<factory>)
get(name, default=None)[source]

Get dataset by name with default fallback.

__repr__()[source]

String representation of collection.

class torchref.Model(dtype_float=torch.float32, verbose=1, device=device(type='cpu'), strip_H=True)[source]

Bases: DeviceMixin, DebugMixin, Module

Base model class for atomic structure models using PyTorch.

This class provides the foundation for managing atomic structure data including coordinates, atomic displacement parameters (ADPs), and occupancies. It supports both empty initialization for state_dict loading and file-based initialization from PDB/CIF files.

Parameters:
  • dtype_float (torch.dtype, optional) – Data type for floating point tensors. Defaults to the configured dtypes.float.

  • verbose (int, optional) – Verbosity level for logging. Default is 1.

  • device (torch.device, optional) – Computation device. Defaults to the configured device.current.

  • strip_H (bool, optional) – Whether to strip hydrogen atoms when loading. Default is True.

xyz

Atomic coordinates tensor with shape (n_atoms, 3).

Type:

MixedTensor

adp

Atomic displacement parameters (isotropic B-factors) with shape (n_atoms,).

Type:

PositiveMixedTensor

u

Anisotropic displacement parameters with shape (n_atoms, 6).

Type:

MixedTensor

occupancy

Atomic occupancies with values in [0, 1].

Type:

OccupancyTensor

pdb

DataFrame containing atomic model data.

Type:

pandas.DataFrame

cell

Unit cell object with parameters [a, b, c, alpha, beta, gamma].

Type:

Cell

spacegroup

Space group object.

Type:

gemmi.SpaceGroup

symmetry

Symmetry operations handler for this space group.

Type:

Symmetry

initialized

Whether the model has been initialized with data.

Type:

bool

Examples

Empty initialization for state_dict loading:

model = Model()
model.load_state_dict(torch.load('model.pt'))

File-based initialization:

model = Model()
model.load_pdb('structure.pdb')
__init__(dtype_float=torch.float32, verbose=1, device=device(type='cpu'), strip_H=True)[source]

Initialize an empty Model shell.

Creates a model shell ready for file loading via load_pdb()/load_cif() or state restoration via load_state_dict().

Parameters:
  • dtype_float (torch.dtype, optional) – Data type for floating point tensors. Defaults to the configured dtypes.float.

  • verbose (int, optional) – Verbosity level for logging. Default is 1.

  • device (torch.device, optional) – Computation device. Defaults to the configured device.current.

  • strip_H (bool, optional) – Whether to strip hydrogen atoms when loading. Default is True.

__bool__()[source]

Return the initialization status when used in boolean context.

property exclude_H_from_sf: bool

Whether to exclude hydrogen atoms from structure factor calculation.

When True, H atoms are excluded from get_iso() / get_aniso() so they do not contribute to Fcalc. They still participate in geometry and VDW restraints. Default is False.

property cell: Cell | None

Unit cell object with parameters [a, b, c, alpha, beta, gamma].

Returns:

The unit cell object, or None if not set.

Return type:

Cell or None

property spacegroup: SpaceGroup | None

Space group object.

Returns:

The space group object, or None if not set.

Return type:

gemmi.SpaceGroup or None

property symmetry: SpaceGroup | None

Symmetry operations handler for this space group.

Returns the same SpaceGroup object as self.spacegroup — the separate Symmetry wrapper was redundant since Symmetry is just an alias.

Returns:

The space group object, or None if not set.

Return type:

SpaceGroup or None

property inv_fractional_matrix: Tensor

Fractionalization matrix B^-1 (Cartesian -> fractional).

Delegates to Cell for automatic caching and device/dtype handling.

Returns:

Shape (3, 3) fractionalization matrix.

Return type:

torch.Tensor

property fractional_matrix: Tensor

Orthogonalization matrix B (fractional -> Cartesian).

Delegates to Cell for automatic caching and device/dtype handling.

Returns:

Shape (3, 3) orthogonalization matrix.

Return type:

torch.Tensor

property recB: Tensor

Reciprocal basis matrix with [a*, b*, c*] as rows.

Delegates to Cell for automatic caching and device/dtype handling.

Returns:

Shape (3, 3) matrix where rows are the reciprocal basis vectors.

Return type:

torch.Tensor
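
The relationship between the three matrices can be sketched from the unit cell parameters alone. The code below assumes the common PDB-style convention (a along x, b in the xy plane); whether torchref's Cell uses exactly this convention is not stated here, and the function names are hypothetical.

```python
import math


def orthogonalization_matrix(a, b, c, alpha, beta, gamma):
    """Fractional -> Cartesian matrix; columns are the cell vectors a, b, c."""
    ca, cb, cg = (math.cos(math.radians(x)) for x in (alpha, beta, gamma))
    sg = math.sin(math.radians(gamma))
    volume = a * b * c * math.sqrt(1 - ca * ca - cb * cb - cg * cg + 2 * ca * cb * cg)
    return [
        [a, b * cg, c * cb],
        [0.0, b * sg, c * (ca - cb * cg) / sg],
        [0.0, 0.0, volume / (a * b * sg)],
    ]


def invert_upper_triangular(m):
    """Closed-form inverse of a 3x3 upper-triangular matrix."""
    d0, d1, d2 = m[0][0], m[1][1], m[2][2]
    return [
        [1 / d0, -m[0][1] / (d0 * d1), (m[0][1] * m[1][2] - m[0][2] * d1) / (d0 * d1 * d2)],
        [0.0, 1 / d1, -m[1][2] / (d1 * d2)],
        [0.0, 0.0, 1 / d2],
    ]


def matvec(m, v):
    """Multiply a 3x3 matrix by a 3-vector."""
    return [sum(m[i][j] * v[j] for j in range(3)) for i in range(3)]


B = orthogonalization_matrix(50.0, 60.0, 70.0, 90.0, 105.0, 90.0)  # hypothetical monoclinic cell
B_inv = invert_upper_triangular(B)
cartesian = matvec(B, [0.25, 0.5, 0.75])
fractional = matvec(B_inv, cartesian)  # round trip back to fractional
```

In this convention the rows of B_inv satisfy a*·a = 1 and a*·b = 0, so they are the reciprocal basis vectors.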

property Z: Tensor

Atomic numbers for all atoms.

Returns:

Tensor of atomic numbers with shape (n_atoms,).

Return type:

torch.Tensor

get_P1_parameters_iso()[source]

Get model parameters transformed to P1 space for optimization.

This is useful for molecular dynamics (MD) and for optimizers that do not handle symmetry directly.

Returns:

  • xyz_p1 (torch.Tensor) – Fractional coordinates expanded to P1 space.

  • adp_p1 (torch.Tensor) – Isotropic ADPs expanded to P1 space.

  • occupancy_p1 (torch.Tensor) – Occupancies expanded to P1 space.

  • A (torch.Tensor) – Scattering factor A coefficients expanded to P1 space.

  • B (torch.Tensor) – Scattering factor B coefficients expanded to P1 space.

Return type:

tuple[Tensor, Tensor, Tensor, Tensor, Tensor]
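
As a rough picture of what expansion to P1 means, the sketch below applies every symmetry operator to every atom of the asymmetric unit. It is an assumption-laden illustration: expand_to_p1 is a hypothetical helper, the operators are hard-coded for space group P2₁, coordinates are wrapped into the unit cell, and the scattering coefficients are left out.

```python
def expand_to_p1(frac_xyz, adp, occupancy, symmetry_ops):
    """Apply every (rotation, translation) operator to every atom."""
    xyz_p1, adp_p1, occ_p1 = [], [], []
    for rotation, translation in symmetry_ops:
        for position, b, q in zip(frac_xyz, adp, occupancy):
            image = [
                (sum(rotation[i][j] * position[j] for j in range(3)) + translation[i]) % 1.0
                for i in range(3)
            ]
            xyz_p1.append(image)
            adp_p1.append(b)  # scalar ADPs are unchanged by symmetry
            occ_p1.append(q)
    return xyz_p1, adp_p1, occ_p1


# Space group P2_1 (unique axis b): x,y,z and -x,y+1/2,-z
P21_OPS = [
    ([[1, 0, 0], [0, 1, 0], [0, 0, 1]], [0.0, 0.0, 0.0]),
    ([[-1, 0, 0], [0, 1, 0], [0, 0, -1]], [0.0, 0.5, 0.0]),
]

xyz_p1, adp_p1, occ_p1 = expand_to_p1(
    [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], [20.0, 35.0], [1.0, 0.5], P21_OPS
)
```

The expanded arrays contain n_atoms × n_operators entries, which is why code working on them needs no knowledge of the space group.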

get_MD_parameters()[source]

Get model parameters prepared for molecular dynamics simulation.

Returns all P1-expanded parameters plus atomic numbers for MD engines.

Returns:

  • xyz_p1 (torch.Tensor) – Fractional coordinates expanded to P1 space.

  • adp_p1 (torch.Tensor) – Isotropic ADPs expanded to P1 space.

  • occupancy_p1 (torch.Tensor) – Occupancies expanded to P1 space.

  • A (torch.Tensor) – Scattering factor A coefficients expanded to P1 space.

  • B (torch.Tensor) – Scattering factor B coefficients expanded to P1 space.

  • Z_p1 (torch.Tensor) – Atomic numbers expanded to P1 space.

Return type:

tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]

property parametrization

ITC92 parametrization dictionary {element: (A, B)}.

The parametrization is built lazily on first access.

Returns:

Dictionary mapping element symbols to tuples of (A, B) tensors.

Return type:

dict

get_scattering_params_iso()[source]

Get ITC92 scattering parameters (A, B) for isotropic atoms.

Returns:

  • A (torch.Tensor) – ITC92 A parameters (amplitudes) with shape (n_iso_atoms, 5).

  • B (torch.Tensor) – ITC92 B parameters (widths) with shape (n_iso_atoms, 5).

get_scattering_params_aniso()[source]

Get ITC92 scattering parameters (A, B) for anisotropic atoms.

Returns:

  • A (torch.Tensor) – ITC92 A parameters (amplitudes) with shape (n_aniso_atoms, 5).

  • B (torch.Tensor) – ITC92 B parameters (widths) with shape (n_aniso_atoms, 5).
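
How such (A, B) pairs turn into a scattering factor can be sketched as a sum of Gaussians. The example uses the Cromer-Mann coefficients for neutral carbon as commonly tabulated in the International Tables. It assumes, without confirmation from the library, that the fifth column holds the constant term paired with B = 0, and scattering_factor is a hypothetical helper.

```python
import math

# Cromer-Mann coefficients for neutral carbon (four Gaussians plus a constant).
CARBON_A = [2.3100, 1.0200, 1.5886, 0.8650, 0.2156]   # last entry is the constant c
CARBON_B = [20.8439, 10.2075, 0.5687, 51.6512, 0.0]   # B = 0 makes that term constant


def scattering_factor(a_coeffs, b_coeffs, s):
    """f(s) = sum_i A_i * exp(-B_i * s^2), with s = sin(theta) / lambda in 1/Angstrom."""
    return sum(a * math.exp(-b * s * s) for a, b in zip(a_coeffs, b_coeffs))


f_zero = scattering_factor(CARBON_A, CARBON_B, 0.0)    # close to Z = 6 electrons
f_high = scattering_factor(CARBON_A, CARBON_B, 0.25)   # corresponds to d = 2 Angstrom
```

At s = 0 the sum of the A coefficients approximates the number of electrons, and the factor decays with increasing resolution.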

set_restraints_cif(cif_path)[source]

Set CIF path for lazy restraint building.

Parameters:

cif_path (str or list of str) – Path(s) to CIF restraints dictionary file(s).

Returns:

Self, for method chaining.

property restraints

Lazy restraints property.

The restraints are built on first access using the model’s pdb DataFrame and the CIF path set via set_restraints_cif().

Returns:

The restraints object containing bond, angle, torsion, etc. restraints.

Return type:

RestraintsNew

bond_deviations()[source]

Compute bond length deviations using current xyz coordinates.

Returns:

  • deviations (torch.Tensor) – Calculated minus expected bond lengths in Angstroms.

  • sigmas (torch.Tensor) – Standard deviations from CIF library in Angstroms.
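
A deviation/sigma pair of this kind is typically combined into a sigma-weighted restraint term. The sketch below shows that idea with plain Python lists; the helper names, coordinates and restraint values are illustrative and are not taken from the monomer library.

```python
import math


def compute_bond_deviations(xyz, bonds):
    """Return (calculated - ideal) lengths and their sigmas for each bond."""
    deviations, sigmas = [], []
    for i, j, ideal, sigma in bonds:
        deviations.append(math.dist(xyz[i], xyz[j]) - ideal)
        sigmas.append(sigma)
    return deviations, sigmas


def restraint_energy(deviations, sigmas):
    """Sum of squared, sigma-weighted deviations (a Gaussian restraint target)."""
    return sum((d / s) ** 2 for d, s in zip(deviations, sigmas))


xyz = [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [1.5, 1.3, 0.0]]
# (atom i, atom j, ideal length, sigma); values are illustrative only
bonds = [(0, 1, 1.52, 0.02), (1, 2, 1.33, 0.01)]
deviations, sigmas = compute_bond_deviations(xyz, bonds)
energy = restraint_energy(deviations, sigmas)
```

Dividing by sigma expresses each deviation in units of its expected spread, so tightly defined bonds are penalized more strongly than loosely defined ones.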

angle_deviations()[source]

Compute angle deviations using current xyz coordinates.

Returns:

  • deviations (torch.Tensor) – Calculated minus expected angles in radians.

  • sigmas (torch.Tensor) – Standard deviations in radians.

torsion_deviations_with_sigmas()[source]

Compute torsion deviations (wrapped for periodicity) and sigmas.

Returns:

  • deviations_rad (torch.Tensor) – Wrapped deviations in radians.

  • sigmas_deg (torch.Tensor) – Standard deviations in degrees (for von Mises NLL).
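
Wrapping for periodicity can be illustrated as follows. This is a generic sketch, not the library's code: wrap_torsion_deviation is a hypothetical helper, and the periodicity argument (period = 2π / periodicity) is an assumption about how multi-fold torsions might be handled.

```python
import math


def wrap_torsion_deviation(calculated, ideal, periodicity=1):
    """Smallest signed difference between two torsions, in radians."""
    period = 2.0 * math.pi / periodicity
    difference = calculated - ideal
    # Shift into [-period/2, period/2) so that +179 and -179 degrees are 2 degrees apart.
    return (difference + period / 2.0) % period - period / 2.0


deviation = wrap_torsion_deviation(math.radians(179.0), math.radians(-179.0))
threefold = wrap_torsion_deviation(math.radians(175.0), math.radians(60.0), periodicity=3)
```

Without wrapping, the first pair would be reported as 358° apart instead of 2°, which would produce a large and meaningless restraint penalty.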

load(reader)[source]
load_pdb(file)[source]

Load atomic model from PDB file.

Parameters:

file (str) – Path to PDB file.

Returns:

Self, for method chaining.

Return type:

Model

load_cif(file)[source]

Load atomic model from mmCIF file.

Parameters:

file (str) – Path to CIF/mmCIF file.

Returns:

Self, for method chaining.

Return type:

Model

property chain_sequences: List[Tuple[str, str]]

Per-chain amino acid sequences as single-letter codes.

Excludes HETATM records. Gaps in residue numbering are filled with ?. Non-standard residues are mapped to X.

Returns:

Ordered list of (chain_id, sequence_string). E.g. [("A", "MKVL??GAST"), ("B", "ACDEFG")].

Return type:

list of (str, str)
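
The gap-filling and X-mapping rules can be sketched for a single chain. chain_sequence is a hypothetical helper that takes (residue number, residue name) pairs sorted by residue number; it does not read the model's DataFrame.

```python
THREE_TO_ONE = {
    "ALA": "A", "ARG": "R", "ASN": "N", "ASP": "D", "CYS": "C",
    "GLN": "Q", "GLU": "E", "GLY": "G", "HIS": "H", "ILE": "I",
    "LEU": "L", "LYS": "K", "MET": "M", "PHE": "F", "PRO": "P",
    "SER": "S", "THR": "T", "TRP": "W", "TYR": "Y", "VAL": "V",
}


def chain_sequence(residues):
    """Build a one-letter sequence from (resseq, resname) pairs sorted by resseq."""
    letters = []
    previous = None
    for resseq, resname in residues:
        if previous is not None:
            letters.append("?" * (resseq - previous - 1))  # one "?" per missing residue
        letters.append(THREE_TO_ONE.get(resname, "X"))     # non-standard -> "X"
        previous = resseq
    return "".join(letters)


sequence = chain_sequence([(1, "MET"), (2, "LYS"), (3, "VAL"), (4, "LEU"),
                           (7, "GLY"), (8, "MSE"), (9, "SER"), (10, "THR")])
```

Residues 5 and 6 are missing and become ??, while selenomethionine (MSE) is not in the standard table and becomes X.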

get_chain_residues()[source]

Per-chain residue names as 3-letter codes (for IHM/CIF writing).

Excludes HETATM records. Unlike chain_sequences, returns the raw 3-letter codes without gap filling.

Returns:

Ordered list of (chain_id, [resname, ...]).

Return type:

list of (str, list of str)

update_pdb()[source]
get_vdw_radii()[source]

Get van der Waals radii for all atoms based on their elements.

Caches the result in self.vdw_radii for future calls.

Returns:

Van der Waals radii for each atom with shape (n_atoms,).

Return type:

torch.Tensor

to(*args, **kwargs)[source]

Move Model and rebuild device-specific SF indices.

Delegates to DeviceMixin, which walks self.__dict__ (picking up self.cell, self.altloc_pairs, self._restraints and all registered parameters / buffers), refreshes the self.device tracker, and invalidates caches. Afterwards this override rebuilds the precomputed SF indices on the new device.

copy()[source]

Create a deep copy of the Model.

Creates a complete independent copy including all registered buffers, module parameters, PDB DataFrame, and spacegroup information.

Returns:

A new Model instance with copied data.

Return type:

Model

Examples

model = Model().load_pdb('structure.pdb')
model_copy = model.copy()
# model_copy is independent, changes won't affect model
write_pdb(filename, metadata=None)[source]

Write model to PDB file with optional metadata header.

Parameters:
  • filename (str) – Output PDB file path.

  • metadata (RefinementMetadata, optional) – Metadata to render as PDB header (REMARK 3, TITLE, etc.).

write_cif(filename, metadata=None)[source]

Write model to mmCIF file with optional metadata.

Parameters:
  • filename (str) – Output mmCIF file path.

  • metadata (RefinementMetadata, optional) – Metadata to include (refinement statistics, title, etc.).

get_iso()[source]

Return per-atom parameters for the isotropic atom subset.

Selects atoms whose ADP is a single scalar b (i.e. not anisotropic). The subset is defined by ~self.aniso_flag — intersected with self._heavy_atom_mask when _exclude_H_from_sf is enabled — and is precomputed as self._iso_indices at init / whenever the mask changes.

Returns:

  • xyz (torch.Tensor, shape (n_iso, 3)) – Cartesian coordinates of the isotropic atoms (Å).

  • adp (torch.Tensor, shape (n_iso,)) – Isotropic B-factors (Ų).

  • occupancy (torch.Tensor, shape (n_iso,)) – Occupancies in [0, 1].

Notes

When every atom is isotropic and no H exclusion is active — self._iso_covers_all is True, the common protein-refinement case — the per-atom indexing is skipped and self.xyz(), self.adp(), self.occupancy() are returned directly.

Motivation: self.xyz()[idx] is a no-op forward when idx = arange(N), but its backward routes through PyTorch’s aten::_index_put_impl_(accumulate=True), which performs a cub::DeviceRadixSortOnesweepKernel over len(idx) indices followed by a deduplicated scatter (~50-150 µs/iter per gather on A100 / 1DAW). Skipping the gather avoids that cost.

set_default_masks()[source]
PARAM_TYPES: Tuple[str, ...] = ('xyz', 'adp', 'u', 'occupancy')
parameters_of_types(types)[source]

Return the leaf nn.Parameter objects for the named parameter types.

Used by refinement entry points (refine_xyz, refine_adp, …) to construct an optimizer over only the leaves the caller intends to update. LossState.step then uses the optimizer’s param groups as intent and disables requires_grad on any other leaves the loss also touches.

Parameters:

types (Iterable[str]) – Subset of Model.PARAM_TYPES: "xyz", "adp", "u", "occupancy". Unknown names are silently skipped.

Returns:

The refinable_params leaf for each requested type, in the order the types were given.

Return type:

list of nn.Parameter

freeze(target)[source]
freeze_all()[source]
unfreeze_all()[source]
unfreeze(target)[source]
update_mask_from_selection(selection_string, target, mode='set', freeze=True)[source]

Update the refinable mask for a parameter using Phenix-style selection syntax.

This method updates the internal mask buffer (xyz_mask, adp_mask, u_mask, or occupancy_mask) based on the selection. The updated mask is NOT automatically applied to the parameter tensors - use apply_mask_to_parameter() to apply it.

Parameters:
  • selection_string (str) – Phenix-style selection string (see parse_phenix_selection docs).

  • target (str) – Parameter to update: ‘xyz’, ‘adp’, ‘u’, or ‘occupancy’.

  • mode (str, optional) – How to combine with the current mask:

      - ‘set’: Replace mask with selection (default)
      - ‘add’: Add selection to current mask
      - ‘remove’: Remove selection from current mask

  • freeze (bool, optional) – If True (default), selected atoms will be frozen (mask=False). If False, selected atoms will be unfrozen (mask=True).

Raises:

ValueError – If target is not recognized or selection syntax is invalid.

Examples

# Freeze chain A coordinates
model.update_mask_from_selection("chain A", "xyz", mode='set', freeze=True)
model.apply_mask_to_parameter("xyz")

# Unfreeze backbone atoms
model.update_mask_from_selection("name CA or name C or name N", "xyz", freeze=False)
model.apply_mask_to_parameter("xyz")
apply_mask_to_parameter(target)[source]

Apply the current mask buffer to the parameter tensor.

Takes the current state of the mask buffer (xyz_mask, adp_mask, etc.) and applies it to the corresponding parameter tensor’s refinable mask.

Parameters:

target (str) – Parameter to update: ‘xyz’, ‘adp’, ‘u’, or ‘occupancy’.

Raises:

ValueError – If target is not recognized.

Examples

model.update_mask_from_selection("chain A", "xyz", freeze=True)
model.apply_mask_to_parameter("xyz")
freeze_selection(selection_string, targets='all')[source]

Freeze atoms matching a Phenix-style selection for specified parameters.

Convenience method that combines update_mask_from_selection() and apply_mask_to_parameter() into a single call.

Parameters:
  • selection_string (str) – Phenix-style selection string.

  • targets (str or list of str, optional) – Parameter(s) to freeze. Can be:

      - ‘all’: Freeze xyz, adp, u, and occupancy (default)
      - str: Single parameter (‘xyz’, ‘adp’, ‘u’, ‘occupancy’)
      - list: List of parameters, e.g., [‘xyz’, ‘adp’]

Examples

# Freeze all parameters for chain A
model.freeze_selection("chain A", targets='all')

# Freeze only coordinates for residues 10-20
model.freeze_selection("resseq 10:20", targets='xyz')
unfreeze_selection(selection_string, targets='all')[source]

Unfreeze atoms matching a Phenix-style selection for specified parameters.

Convenience method that combines update_mask_from_selection() and apply_mask_to_parameter() into a single call.

Parameters:
  • selection_string (str) – Phenix-style selection string.

  • targets (str or list of str, optional) – Parameter(s) to unfreeze. Can be:

      - ‘all’: Unfreeze xyz, adp, u, and occupancy (default)
      - str: Single parameter (‘xyz’, ‘adp’, ‘u’, ‘occupancy’)
      - list: List of parameters, e.g., [‘xyz’, ‘adp’]

Examples

# Unfreeze all parameters for chain A
model.unfreeze_selection("chain A", targets='all')

# Unfreeze only coordinates for backbone atoms
model.unfreeze_selection("name CA or name C or name N", targets='xyz')
get_aniso()[source]

Return per-atom parameters for the anisotropic atom subset.

Selects atoms whose ADP is the 6-element anisotropic tensor u = (u11, u22, u33, u12, u13, u23). The subset is defined by self.aniso_flag — intersected with self._heavy_atom_mask when _exclude_H_from_sf is enabled — and is precomputed as self._aniso_indices at init / whenever the mask changes.

Returns:

  • xyz (torch.Tensor, shape (n_aniso, 3)) – Cartesian coordinates of the anisotropic atoms (Å). Empty tensor when there are no anisotropic atoms.

  • u (torch.Tensor, shape (n_aniso, 6)) – Anisotropic U components (Ų) in the order (u11, u22, u33, u12, u13, u23). Empty when n_aniso == 0.

  • occupancy (torch.Tensor, shape (n_aniso,)) – Occupancies in [0, 1]. Empty when n_aniso == 0.

Notes

When there are no anisotropic atoms — self._aniso_is_empty is True, the common protein-refinement case — three empty placeholder tensors are returned without calling the MixedTensors at all. This avoids both the wrapped forward .clone() and the slow aten::_index_put_impl_ backward path that the self.xyz()[idx] gather would otherwise generate (see get_iso() for the same rationale).

parameters(recurse=True)[source]

Return an iterator over module parameters.

This is typically passed to an optimizer.

Parameters:

recurse (bool) – If True, yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

Yields:

Parameter – Module parameter.

Examples

for param in model.parameters():
    print(type(param), param.size())
named_mixed_tensors()[source]

Iterate over all MixedTensor attributes with their names.

Yields:

Tuple of (name, MixedTensor)

print_parameters_info()[source]

Print information about all MixedTensor parameters.

register_alternative_conformations()[source]

Identify and register all alternative conformation groups in the structure.

For each residue that has alternative conformations (altloc A, B, C, etc.), this method identifies all atoms belonging to each conformation and stores their indices as tensors in a tuple.

The result is stored in self.altloc_pairs as a list of tuples, where each tuple contains tensors of atom indices for each alternative conformation.

Examples

For a residue with conformations A and B:

# Conformation A has atoms at indices [100, 101, 102, ...]
# Conformation B has atoms at indices [110, 111, 112, ...]
# Result: [(tensor([100, 101, 102, ...]), tensor([110, 111, 112, ...])), ...]

For a residue with conformations A, B, C:

# Result: [(tensor([200, 201, ...]), tensor([210, 211, ...]), tensor([220, 221, ...])), ...]
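
The grouping logic can be sketched on a plain list of atom records. group_altloc_atoms is a hypothetical helper; it returns lists of indices where the method stores tensors, and it assumes that atoms with an empty altloc label are shared by all conformers.

```python
from collections import defaultdict


def group_altloc_atoms(atoms):
    """Group atom indices by residue, then by altloc label."""
    per_residue = defaultdict(lambda: defaultdict(list))
    for index, (chain, resseq, altloc) in enumerate(atoms):
        if altloc:  # atoms without an altloc label are not part of any group
            per_residue[(chain, resseq)][altloc].append(index)
    groups = []
    for conformers in per_residue.values():
        if len(conformers) > 1:
            groups.append(tuple(conformers[label] for label in sorted(conformers)))
    return groups


# (chain, resseq, altloc) per atom; "" marks atoms without alternative conformations
atoms = [
    ("A", 10, ""), ("A", 10, ""),
    ("A", 11, "A"), ("A", 11, "A"), ("A", 11, "B"), ("A", 11, "B"),
    ("A", 12, ""),
]
altloc_groups = group_altloc_atoms(atoms)
```

Each tuple holds one index list per conformer, which makes it straightforward to constrain, for example, the occupancies within a group.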
shake_coords(stddev)[source]

Apply random Gaussian noise to atomic coordinates.

Perturbs the atomic coordinates by adding Gaussian noise with a specified standard deviation. The noise is applied to all atoms.

Parameters:

stddev (float) – Standard deviation of the Gaussian noise to be added, in Angstroms.

shake_adp(stddev)[source]

Apply random Gaussian noise to ADPs (atomic displacement parameters).

Perturbs the ADPs by adding Gaussian noise with a specified standard deviation. The noise is applied to all atoms.

Parameters:

stddev (float) – Standard deviation of the Gaussian noise to be added, in Angstrom^2.

generate_hydrogens(mon_lib_path=None)[source]

Generate hydrogen atoms for the current model using gemmi.

Places hydrogens at ideal geometry using the CCP4 monomer library and gemmi’s topology engine. Returns a new Model instance with hydrogens added; the original model is not modified.

Parameters:

mon_lib_path (str, optional) – Path to CCP4 monomer library directory. If None, uses the monomer library bundled with torchref (covers standard amino acids and common small molecules).

Returns:

A new Model instance with hydrogen atoms added (strip_H=False). Unknown residues are skipped silently.

Return type:

Model

Notes

Requires gemmi (already a torchref dependency). Heavy-atom coordinates from the current model state are used, so call this after any coordinate changes you want reflected in the H positions.

Examples

>>> model_no_h = Model().load_pdb('structure.pdb')
>>> model_with_h = model_no_h.generate_hydrogens()
>>> print(model_with_h.Z.shape)   # more atoms than model_no_h
strip_altlocs()[source]

Return a new model with alternate conformations removed.

For each residue that has multiple altlocs, the conformer with highest average occupancy is kept (ties broken alphabetically). The altloc column is cleared to "" in the returned model. The original model is not modified.
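
The selection rule for one residue can be sketched as below. pick_conformer is a hypothetical helper that maps altloc labels to the occupancies of their atoms; it only illustrates the rule and does not touch a model.

```python
def pick_conformer(occupancies_by_altloc):
    """Return the altloc with the highest mean occupancy; ties go to the earliest label."""
    def mean(values):
        return sum(values) / len(values)

    # Sort key: highest mean occupancy first, then alphabetical label.
    return min(
        occupancies_by_altloc,
        key=lambda label: (-mean(occupancies_by_altloc[label]), label),
    )


kept = pick_conformer({"A": [0.4, 0.4], "B": [0.6, 0.6]})
tie = pick_conformer({"B": [0.5, 0.5], "A": [0.5, 0.5]})
```

Here kept is "B" because it has the higher mean occupancy, and tie is "A" because equal occupancies fall back to alphabetical order.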

strip_hydrogens()[source]

Return a new model with hydrogen atoms removed.

The returned model has consistent DataFrame and tensors (xyz, adp, occupancy) with H atoms excluded. The original model is not modified.

Returns:

New model without hydrogen atoms.

Return type:

Model

hydrogenate(verbose=0, optimize=False, lbfgs_steps=3, max_iter=20)[source]

Return a new model with hydrogen atoms placed via Kabsch alignment.

Uses torchref’s monomer library to identify missing H atoms, places them by SVD-aligning ideal monomer coordinates onto the current model coordinates, then corrects each H to sit at ideal bond length from its parent atom. The original model is not modified.

Parameters:
  • verbose (int, optional) – Verbosity level (0=silent, 1=summary, 2=detailed). Default 0.

  • optimize (bool, optional) – If True, run a short LBFGS geometry optimization on H positions after placement. Default False (Kabsch placement only).

  • lbfgs_steps (int, optional) – Number of LBFGS outer steps (only when optimize=True). Default 3.

  • max_iter (int, optional) – Max line-search iterations per LBFGS step. Default 20.

Returns:

New model with hydrogen atoms added. All parameters are unfrozen in the returned model.

Return type:

Model

adp_loss()[source]

Compute the ADP regularization loss.

This loss encourages ADPs to have similar values across the structure, helping to prevent overfitting during refinement.

Returns:

Scalar tensor representing the ADP loss.

Return type:

torch.Tensor

adp_nll_loss(target_log_std=0.2)[source]

Compute negative log-likelihood of ADPs assuming Gaussian distribution in log-space.

This regularization penalizes ADPs that deviate from a target distribution with a FIXED standard deviation (hyperparameter), avoiding circular dependency on the current distribution’s statistics.

The NLL for a Gaussian distribution in log-space is:

NLL = 0.5 * mean[(log_adp - mu)^2 / sigma^2 + log(2*pi*sigma^2)]

Where mu is the mean of log-space ADPs (computed from current data) and sigma is the FIXED target standard deviation (hyperparameter).

Parameters:

target_log_std (float, optional) – Target standard deviation in log-space. Default is 0.2.

  - 0.1 = very tight (ADPs within ~10% of mean)
  - 0.2 = moderate spread (ADPs within ~20% of mean) [RECOMMENDED]
  - 0.3 = looser spread (ADPs within ~30% of mean)

Returns:

Scalar tensor representing the NLL. Lower values indicate the distribution is closer to the target Gaussian with fixed sigma.

Return type:

torch.Tensor

Examples

# During refinement
structure_factor_loss = compute_structure_factor_loss()
nll_reg = model.adp_nll_loss(target_log_std=0.2)
total_loss = structure_factor_loss + 0.01 * nll_reg
total_loss.backward()

Notes

Uses FIXED sigma (no circular dependency on current distribution). Smaller target_log_std = stronger regularization (tighter distribution).
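
The NLL formula given for adp_nll_loss can be evaluated by hand on a short list of B-factors. The functions below are hypothetical stand-ins written in plain Python; in particular they do not reproduce any gradient handling the library may apply to the mean.

```python
import math


def adp_nll_per_atom(adps, target_log_std=0.2):
    """Per-atom Gaussian NLL of log(ADP) around the mean log(ADP)."""
    log_adps = [math.log(b) for b in adps]
    mu = sum(log_adps) / len(log_adps)
    variance = target_log_std ** 2
    constant = math.log(2.0 * math.pi * variance)
    return [0.5 * ((x - mu) ** 2 / variance + constant) for x in log_adps]


def adp_nll(adps, target_log_std=0.2):
    """Mean of the per-atom values."""
    per_atom = adp_nll_per_atom(adps, target_log_std)
    return sum(per_atom) / len(per_atom)


tight = adp_nll([20.0, 21.0, 19.5, 20.5])    # B-factors close together
spread = adp_nll([10.0, 40.0, 15.0, 60.0])   # B-factors far apart
```

Because sigma is fixed, the constant term is the same for every atom and only the squared distance from the mean in log-space changes the value. A widely spread set of ADPs therefore always scores higher than a tight one.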

adp_nll_loss_per_atom(target_log_std=0.2)[source]

Compute per-atom negative log-likelihood for ADPs in log-space.

Returns the NLL contribution for each individual atom, useful for identifying outliers or applying atom-specific regularization weights.

The per-atom NLL is:

NLL_i = 0.5 * [(log_adp_i - mu)^2 / sigma^2 + log(2*pi*sigma^2)]
Parameters:

target_log_std (float, optional) – Fixed target standard deviation in log-space. Default is 0.2.

Returns:

Tensor of shape (n_atoms,) with per-atom NLL values. Higher values indicate atoms farther from the mean.

Return type:

torch.Tensor

Examples

# Get per-atom NLL
atom_nll = model.adp_nll_loss_per_atom(target_log_std=0.2)
# Identify outlier atoms (high NLL)
threshold = atom_nll.mean() + 2 * atom_nll.std()
outliers = atom_nll > threshold
adp_kl_divergence_loss(target_log_std=0.2)[source]

Compute KL divergence between log ADP distribution and target Gaussian.

Measures how different the current log ADP distribution is from a target Gaussian distribution with the current mean of log ADPs and a fixed target standard deviation.

KL divergence formula for two Gaussians with same mean:

KL(q || p) = log(sigma_target/sigma_data) + sigma_data^2 / (2*sigma_target^2) - 0.5

Parameters:

target_log_std (float, optional) – Target standard deviation in log-space. Default is 0.2. Controls how tightly ADPs should cluster.

Returns:

Scalar KL divergence value (always >= 0). 0 means distributions match perfectly. Higher values mean more deviation from target.

Return type:

torch.Tensor

Examples

# Use in loss function
loss = xray_loss + w_adp * model.adp_kl_divergence_loss(0.2)

Notes

Lower target_log_std = stronger regularization (tighter distribution). Mean is detached so it adapts to the natural scale of the data.
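
The closed-form expression above is easy to evaluate directly. A small sketch in plain Python (the function name is hypothetical; sigma_data stands for the standard deviation of the current log ADPs):

```python
import math

def gaussian_kl_same_mean(sigma_data: float, sigma_target: float) -> float:
    """KL(q || p) for two Gaussians sharing a mean: q is the data, p the target."""
    ratio = sigma_data / sigma_target
    return math.log(1.0 / ratio) + 0.5 * ratio ** 2 - 0.5

for sigma_data in (0.1, 0.2, 0.4):
    print(sigma_data, round(gaussian_kl_same_mean(sigma_data, 0.2), 4))
```

The value is zero only when the two standard deviations agree, and the penalty is asymmetric: in this example a distribution twice as wide as the target costs more than one half as wide.
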

state_dict(destination=None, prefix='', keep_vars=False)[source]

Return a dictionary containing the complete state of the Model.

Includes all registered buffers, model parameters (xyz, b, u, occupancy), PDB DataFrame, and metadata (spacegroup, device, dtype, etc.).

Parameters:
  • destination (dict, optional) – Optional dict to populate with state.

  • prefix (str, optional) – Prefix for parameter names. Default is ''.

  • keep_vars (bool, optional) – Whether to keep variables in computational graph. Default is False.

Returns:

Complete state dictionary.

Return type:

dict

save_state(path)[source]

Save the complete state of the model to a file.

Parameters:

path (str) – Path to save the state dictionary to.

load_state(path, strict=True)[source]

Load the complete state of the model from a file.

Parameters:
  • path (str) – Path to load the state dictionary from.

  • strict (bool, optional) – Whether to strictly enforce that keys match. Default is True.

classmethod create_from_state_dict(state_dict, device=device(type='cpu'), verbose=1, dtype_float=torch.float32)[source]

Create a fully initialized Model from a state dictionary.

This is the recommended way to restore a Model from a saved state. Creates an instance with properly initialized submodules, then loads the state.

Parameters:
  • state_dict (dict) – State dictionary from torch.save(model.state_dict(), …).

  • device (torch.device, optional) – Device to place tensors on. Defaults to the configured device.current.

  • verbose (int, optional) – Verbosity level. Default is 1.

  • dtype_float (torch.dtype, optional) – Float dtype for tensors. Defaults to the configured dtypes.float.

Returns:

Fully initialized instance with restored state.

Return type:

Model

get_selection_mask(selection)[source]

Return a boolean mask for atoms matching a Phenix-style selection.

This is a convenience method that wraps parse_phenix_selection() to return a mask that can be used directly with MixedTensor.set() or other operations requiring atom selection.

Parameters:

selection (str) – Phenix-style selection string. Supports:

  • chain <id>: Select by chain (e.g., “chain A”)

  • resseq <num>: Select by residue number (e.g., “resseq 10”)

  • resseq <start>:<end>: Select residue range (e.g., “resseq 10:20”)

  • resname <name>: Select by residue name (e.g., “resname ALA”)

  • name <atom>: Select by atom name (e.g., “name CA”)

  • element <elem>: Select by element (e.g., “element C”)

  • altloc <id>: Select by alternate location (e.g., “altloc A”)

  • all: Select all atoms

  • not <selection>: Negate selection

  • <sel1> and <sel2>: Intersection

  • <sel1> or <sel2>: Union

  • Parentheses for grouping

Returns:

Boolean tensor of shape (n_atoms,) where True indicates selected atoms.

Return type:

torch.Tensor

Raises:

Examples

model = Model().load_pdb('structure.pdb')
# Get mask for chain A
mask = model.get_selection_mask("chain A")
# Use mask to update coordinates
new_coords = model.xyz()[mask] + translation
model.xyz.set(new_coords, mask)
# Get mask for backbone atoms
backbone_mask = model.get_selection_mask("name CA or name C or name N or name O")
# Complex selection with parentheses
mask = model.get_selection_mask("chain A and (resname ALA or resname GLY)")
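
The logical operators in a selection string map onto element-wise operations on boolean masks. The sketch below illustrates that mapping with plain Python lists; the atom records and helper functions are hypothetical stand-ins, not torchref's parser.

```python
# Hypothetical atom records standing in for the model's atom table.
atoms = [
    {"chain": "A", "resname": "ALA", "name": "CA"},
    {"chain": "A", "resname": "LYS", "name": "CA"},
    {"chain": "A", "resname": "GLY", "name": "N"},
    {"chain": "B", "resname": "ALA", "name": "CA"},
    {"chain": "B", "resname": "HOH", "name": "O"},
]

def field_mask(field, value):
    """One boolean per atom: True where the field equals the value."""
    return [atom[field] == value for atom in atoms]

def mask_and(a, b):
    return [x and y for x, y in zip(a, b)]

def mask_or(a, b):
    return [x or y for x, y in zip(a, b)]

def mask_not(a):
    return [not x for x in a]

# "chain A and (resname ALA or resname GLY)"
selected = mask_and(
    field_mask("chain", "A"),
    mask_or(field_mask("resname", "ALA"), field_mask("resname", "GLY")),
)
# "not resname HOH"
no_water = mask_not(field_mask("resname", "HOH"))
print(selected, no_water)
```
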
select(selection)[source]

Return a new Model containing only atoms matching the Phenix-style selection.

Creates an independent copy of the model containing only the selected atoms. All tensor data (coordinates, ADPs, occupancies, etc.) and metadata are properly subsetted.

Parameters:

selection (str) – Phenix-style selection string. Supports:

  • chain <id>: Select by chain (e.g., “chain A”)

  • resseq <num>: Select by residue number (e.g., “resseq 10”)

  • resseq <start>:<end>: Select residue range (e.g., “resseq 10:20”)

  • resname <name>: Select by residue name (e.g., “resname ALA”)

  • name <atom>: Select by atom name (e.g., “name CA”)

  • element <elem>: Select by element (e.g., “element C”)

  • altloc <id>: Select by alternate location (e.g., “altloc A”)

  • all: Select all atoms

  • not <selection>: Negate selection

  • <sel1> and <sel2>: Intersection

  • <sel1> or <sel2>: Union

  • Parentheses for grouping

Returns:

New instance of the same class containing only selected atoms. If called on a subclass, returns an instance of that subclass.

Return type:

Model

Raises:
  • RuntimeError – If the model has not been initialized.

  • ValueError – If selection syntax is invalid or no atoms are selected.

Examples

model = Model().load_pdb('structure.pdb')
# Select chain A
chain_a = model.select("chain A")
# Select backbone atoms
backbone = model.select("name CA or name C or name N or name O")
# Select residues 10-50 of chain B
region = model.select("chain B and resseq 10:50")
# Select all except water
no_water = model.select("not resname HOH")
# Complex selection with parentheses
complex_sel = model.select("chain A and (resname ALA or resname GLY)")

Notes

This method preserves the class type, so subclasses will return instances of themselves, not the base Model class.

xyz_fractional()[source]

Return atomic coordinates in fractional space.

Converts Cartesian coordinates to fractional coordinates using the inverse fractional matrix.

Returns:

Tensor of shape (n_atoms, 3) with fractional coordinates.

Return type:

torch.Tensor
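
For reference, the Cartesian-to-fractional conversion can be sketched in plain Python. The example assumes the common convention with the a axis along x and the b axis in the xy plane; torchref's internal matrices may follow a different convention, and all function names here are hypothetical.

```python
import math

def orthogonalization_matrix(a, b, c, alpha, beta, gamma):
    """Fractional -> Cartesian matrix (upper triangular), angles in degrees."""
    ca, cb, cg = (math.cos(math.radians(x)) for x in (alpha, beta, gamma))
    sg = math.sin(math.radians(gamma))
    # Unit-cell volume divided by a*b*c
    v = math.sqrt(1.0 - ca * ca - cb * cb - cg * cg + 2.0 * ca * cb * cg)
    return [
        [a, b * cg, c * cb],
        [0.0, b * sg, c * (ca - cb * cg) / sg],
        [0.0, 0.0, c * v / sg],
    ]

def to_cartesian(m, frac):
    return [sum(m[i][j] * frac[j] for j in range(3)) for i in range(3)]

def to_fractional(m, xyz):
    """Solve m @ frac = xyz by back substitution (m is upper triangular)."""
    fz = xyz[2] / m[2][2]
    fy = (xyz[1] - m[1][2] * fz) / m[1][1]
    fx = (xyz[0] - m[0][1] * fy - m[0][2] * fz) / m[0][0]
    return [fx, fy, fz]

m = orthogonalization_matrix(40.0, 50.0, 60.0, 90.0, 105.0, 90.0)  # monoclinic cell
xyz = [12.0, 25.0, 30.0]
frac = to_fractional(m, xyz)
back = to_cartesian(m, frac)
print(frac, back)
```
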

rotate(rotation_matrix, center=None)[source]

Apply rotation to atomic coordinates (in-place).

Rotates all atoms around a specified center point. The rotation is applied using the formula: xyz_new = R @ (xyz - center) + center

Parameters:
  • rotation_matrix (torch.Tensor) – 3x3 rotation matrix. Should be orthogonal (R^T @ R = I).

  • center (torch.Tensor, optional) – Center of rotation with shape (3,). If None, uses the centroid of all atomic coordinates.

Returns:

Self, for method chaining.

Return type:

Model

Examples

# Rotate 90 degrees around Z-axis
import math
angle = math.pi / 2
R = torch.tensor([
    [math.cos(angle), -math.sin(angle), 0],
    [math.sin(angle), math.cos(angle), 0],
    [0, 0, 1]
])
model.rotate(R)

# Rotate around a specific point
center = torch.tensor([10.0, 20.0, 30.0])
model.rotate(R, center=center)
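
The effect of the center argument can be seen in a tensor-free sketch of the same formula, xyz_new = R @ (xyz - center) + center. The helper below is hypothetical and returns new points rather than modifying a model in place.

```python
import math

def rotate_about(points, rotation, center=None):
    """Return R @ (p - center) + center for every point p."""
    if center is None:
        # Mirror the documented default: rotate about the centroid.
        n = len(points)
        center = [sum(p[i] for p in points) / n for i in range(3)]
    rotated = []
    for p in points:
        d = [p[i] - center[i] for i in range(3)]
        rotated.append(
            [sum(rotation[i][j] * d[j] for j in range(3)) + center[i] for i in range(3)]
        )
    return rotated

angle = math.pi / 2  # 90 degrees about Z
rz = [
    [math.cos(angle), -math.sin(angle), 0.0],
    [math.sin(angle), math.cos(angle), 0.0],
    [0.0, 0.0, 1.0],
]
points = [[1.0, 0.0, 0.0], [3.0, 0.0, 0.0]]
about_origin = rotate_about(points, rz, center=[0.0, 0.0, 0.0])
about_centroid = rotate_about(points, rz)
print(about_origin, about_centroid)
```

Rotating about the centroid leaves the centroid where it was, whereas rotating about the origin also moves the molecule as a whole.
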
translate(translation, fractional=False)[source]

Apply translation to atomic coordinates (in-place).

Translates all atoms by a specified vector. The translation can be given in either Cartesian or fractional coordinates.

Parameters:
  • translation (torch.Tensor) – Translation vector with shape (3,).

  • fractional (bool, optional) – If True, the translation is interpreted as fractional coordinates and converted to Cartesian before applying. Default is False (translation is in Cartesian Angstroms).

Returns:

Self, for method chaining.

Return type:

Model

Examples

# Translate by 5 Angstroms along X
model.translate(torch.tensor([5.0, 0.0, 0.0]))

# Translate by half a unit cell along each axis
model.translate(torch.tensor([0.5, 0.5, 0.5]), fractional=True)

get_centroid()[source]

Compute the centroid (geometric center) of all atoms.

Returns:

Centroid coordinates with shape (3,).

Return type:

torch.Tensor

use_internal_coordinates(n_aa_per_segment=5, bond_cutoff=2.0, cif_dict=None, requires_grad=True)[source]

Switch xyz to segmented internal coordinate parametrization.

Replaces the current xyz MixedTensor with a SegmentedInternalCoordinateTensor that parametrizes atomic positions using bond lengths, angles, torsion angles, and per-segment rigid body parameters. The molecule is broken into independent segments to avoid the “lever arm problem” where small torsion changes near the root cause large displacements at distant atoms.

Parameters:
  • n_aa_per_segment (int, optional) – Number of amino acids per segment. Default is 5. Smaller values (1-2) give more segments, shallower trees, and less lever arm; larger values (5-10) give fewer segments, deeper trees, and more lever arm.

  • bond_cutoff (float, optional) – Distance cutoff for bond detection in Angstroms. Default is 2.0. Only used when cif_dict is not provided.

  • cif_dict (dict, optional) – CIF dictionary containing bond definitions per residue type. If provided, bonds are determined from chemical definitions rather than distances, which is more robust for structures with poor geometry. Expected format: cif_dict[resname]['bonds'] is a DataFrame with 'atom1' and 'atom2' columns.

  • requires_grad (bool, optional) – Whether internal coordinate parameters should have gradients. Default is True.

Returns:

Self, for method chaining.

Return type:

Model

Examples

model = Model()
model.load_pdb('structure.pdb')
model.use_internal_coordinates(n_aa_per_segment=3)

# Now model.xyz() returns coordinates reconstructed from
# segmented internal coordinates

# Shake the structure using internal coordinates
new_xyz = model.xyz.shake(magnitude=0.1)

# Each segment has independent internal coordinates and
# rigid body parameters (position + orientation)

Notes

After calling this method, model.xyz will be a SegmentedInternalCoordinateTensor instead of a MixedTensor. This provides:

  • Shallow spanning trees within segments (depth ~10-30 vs ~1000)

  • Independent segments that don’t propagate changes to distant atoms

  • Rigid body parameters (position + orientation) per segment

  • forward() / __call__(): Reconstruct Cartesian coordinates

  • shake(magnitude): Add noise to internal parameters

  • Gradient flow through all internal coordinate parameters
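
The lever arm problem mentioned above is a matter of simple geometry: an atom at distance r from a rotating bond axis moves along a chord of length 2 r sin(delta / 2) when the torsion changes by delta. A short plain-Python illustration (the distances are hypothetical):

```python
import math

def displacement(distance_from_axis: float, delta_degrees: float) -> float:
    """Chord length moved by an atom when a torsion changes by delta_degrees."""
    return 2.0 * distance_from_axis * math.sin(math.radians(delta_degrees) / 2.0)

delta = 1.0  # a one-degree torsion change
for r in (3.0, 30.0, 100.0):  # hypothetical distances from the bond axis, in Angstroms
    print(f"r = {r:6.1f} A -> shift {displacement(r, delta):.3f} A")
```

The shift grows linearly with the distance from the axis, which is why limiting tree depth through short segments keeps the effect of a single torsion local.
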

class torchref.ModelFT(*args, max_res=1.0, radius_angstrom=4.0, gridsize=None, wavelength=1.0, anomalous_threshold=0.5, **kwargs)[source]

Bases: CachedForwardMixin, Model

Model subclass for Fourier Transform-based electron density and structure factor calculations.

ModelFT extends the base Model class with capabilities for computing electron density maps in real space and structure factors via FFT. Uses ITC92 parametrization for electron density calculations.

Parameters:
  • max_res (float, optional) – Maximum resolution for grid spacing in Angstroms. Default is 1.0.

  • radius_angstrom (float, optional) – Radius in Angstroms for density calculation around each atom. Default is 4.0.

  • gridsize (tuple of int, optional) – Explicit grid size (nx, ny, nz). If None, computed from cell and max_res.

  • wavelength (float or None, optional) – X-ray wavelength in Angstroms for anomalous scattering correction. Default is 1.0 (standard synchrotron, ~12.4 keV). Set to None to disable anomalous corrections entirely.

  • anomalous_threshold (float, optional) – Significance threshold for anomalous scattering in electrons. Atoms with |f'| > threshold or |f''| > threshold will have anomalous corrections applied. Default is 0.5.

  • *args – Additional positional arguments passed to parent Model class.

  • **kwargs – Additional keyword arguments passed to parent Model class.
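
The wavelength and photon energy quoted for the wavelength parameter are related by E = hc / λ. A small conversion sketch (the helper names are hypothetical and not part of torchref):

```python
HC_KEV_ANGSTROM = 12.398419843  # Planck constant times speed of light, in keV * Angstrom

def photon_energy_kev(wavelength_angstrom: float) -> float:
    """Photon energy in keV for a wavelength in Angstroms."""
    return HC_KEV_ANGSTROM / wavelength_angstrom

def wavelength_from_energy(energy_kev: float) -> float:
    """Wavelength in Angstroms for a photon energy in keV."""
    return HC_KEV_ANGSTROM / energy_kev

print(round(photon_energy_kev(1.0), 3))     # the default wavelength
print(round(photon_energy_kev(1.5418), 3))  # copper K-alpha home source
```
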

max_res

Maximum resolution for grid spacing.

Type:

float

radius_angstrom

Radius for density calculation.

Type:

float

wavelength

X-ray wavelength for anomalous scattering corrections.

Type:

float or None

anomalous_threshold

Threshold for significant anomalous scattering (electrons).

Type:

float

gridsize

Grid dimensions (nx, ny, nz).

Type:

torch.Tensor

real_space_grid

Real-space coordinate grid with shape (nx, ny, nz, 3).

Type:

torch.Tensor

map

Computed electron density map.

Type:

torch.Tensor or None

parametrization

ITC92 parametrization dictionary {element: (A, B, C)}.

Type:

dict

map_symmetry

Symmetry operator for map calculations.

Type:

MapSymmetry

Examples

Empty initialization for state_dict loading:

model = ModelFT()
model.load_state_dict(torch.load('model.pt'))

File-based initialization:

model = ModelFT(max_res=1.5)
model.load_pdb('structure.pdb')

__init__(*args, max_res=1.0, radius_angstrom=4.0, gridsize=None, wavelength=1.0, anomalous_threshold=0.5, **kwargs)[source]

Initialize an empty ModelFT shell.

Creates a model shell ready for file loading via load_pdb()/load_cif() or state restoration via load_state_dict().

Parameters:
  • max_res (float, optional) – Maximum resolution for grid spacing in Angstroms. Default is 1.0.

  • radius_angstrom (float, optional) – Radius in Angstroms for density calculation. Default is 4.0.

  • gridsize (tuple of int, optional) – Explicit grid size tuple (nx, ny, nz). If None, computed automatically.

  • wavelength (float or None, optional) – X-ray wavelength in Angstroms for anomalous scattering correction. Default is 1.0 (standard synchrotron, ~12.4 keV). Set to None to disable anomalous corrections entirely.

  • anomalous_threshold (float, optional) – Significance threshold for anomalous scattering in electrons. Atoms with |f'| > threshold or |f''| > threshold will have anomalous corrections applied. Default is 0.5.

  • *args – Passed to parent Model class.

  • **kwargs – Passed to parent Model class.

property cell

Unit cell object with parameters [a, b, c, alpha, beta, gamma].

property spacegroup

Space group object.

load_pdb(filename)[source]

Load a PDB file and initialize the model with FT-specific setup.

Parameters:

filename (str) – Path to the PDB file.

Returns:

Self, for method chaining.

Return type:

ModelFT

select(selection)[source]

Return a new Model containing only atoms matching the Phenix-style selection.

Creates an independent copy of the model containing only the selected atoms. All tensor data (coordinates, ADPs, occupancies, etc.) and metadata are properly subsetted.

Parameters:

selection (str) – Phenix-style selection string. Supports:

  • chain <id>: Select by chain (e.g., “chain A”)

  • resseq <num>: Select by residue number (e.g., “resseq 10”)

  • resseq <start>:<end>: Select residue range (e.g., “resseq 10:20”)

  • resname <name>: Select by residue name (e.g., “resname ALA”)

  • name <atom>: Select by atom name (e.g., “name CA”)

  • element <elem>: Select by element (e.g., “element C”)

  • altloc <id>: Select by alternate location (e.g., “altloc A”)

  • all: Select all atoms

  • not <selection>: Negate selection

  • <sel1> and <sel2>: Intersection

  • <sel1> or <sel2>: Union

  • Parentheses for grouping

Returns:

New instance of the same class containing only selected atoms. If called on a subclass, returns an instance of that subclass.

Return type:

Model

Raises:
  • RuntimeError – If the model has not been initialized.

  • ValueError – If selection syntax is invalid or no atoms are selected.

Examples

model = Model().load_pdb('structure.pdb')
# Select chain A
chain_a = model.select("chain A")
# Select backbone atoms
backbone = model.select("name CA or name C or name N or name O")
# Select residues 10-50 of chain B
region = model.select("chain B and resseq 10:50")
# Select all except water
no_water = model.select("not resname HOH")
# Complex selection with parentheses
complex_sel = model.select("chain A and (resname ALA or resname GLY)")

Notes

This method preserves the class type, so subclasses will return instances of themselves, not the base Model class.

load_cif(filename)[source]

Load a CIF file and initialize the model with FT-specific setup.

Parameters:

filename (str) – Path to the CIF/mmCIF file.

Returns:

Self, for method chaining.

Return type:

ModelFT

setup_gridsize(max_res=None)[source]

Compute optimal grid dimensions.

Delegates to FFT.compute_grid_size().

Parameters:

max_res (float, optional) – Maximum resolution in Angstroms. If None, uses self.max_res.

Returns:

Grid dimensions (nx, ny, nz) as int32 tensor.

Return type:

torch.Tensor
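
How FFT.compute_grid_size() chooses the dimensions is not documented here. As a rough illustration only, a common rule of thumb is to keep the grid spacing at or below max_res / 3 and round each dimension up to an integer whose prime factors are 2, 3 and 5. The sketch below implements that rule; the sampling factor, the function names and the cell edges are assumptions, and space-group constraints are ignored.

```python
import math

def is_fft_friendly(n: int) -> bool:
    """True if n has no prime factors other than 2, 3 and 5."""
    for p in (2, 3, 5):
        while n % p == 0:
            n //= p
    return n == 1

def grid_points(cell_length: float, max_res: float, sampling: float = 3.0) -> int:
    """Smallest FFT-friendly count with spacing <= max_res / sampling."""
    n = max(1, math.ceil(cell_length * sampling / max_res))
    while not is_fft_friendly(n):
        n += 1
    return n

cell = (37.0, 50.0, 64.2)  # hypothetical cell edges in Angstroms
print(tuple(grid_points(length, max_res=2.0) for length in cell))
```
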

property A: Tensor

ITC92 A parameters (amplitudes) for all atoms.

Returns:

A parameters with shape (n_atoms, 5).

Return type:

torch.Tensor

property B: Tensor

ITC92 B parameters (widths) for all atoms.

Returns:

B parameters with shape (n_atoms, 5).

Return type:

torch.Tensor
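
The A and B parameters describe each atom's scattering factor as a sum of Gaussians in s = sin(θ)/λ. The sketch below evaluates such a sum for neutral carbon using widely tabulated four-Gaussian-plus-constant coefficients; storing the constant as a fifth term with zero width is an assumption made here to match the (n_atoms, 5) shape, not a statement about torchref's layout.

```python
import math

# Widely tabulated coefficients for neutral carbon (four Gaussians plus a constant).
# The constant is written as a fifth term with b = 0 (assumed layout).
CARBON_A = [2.3100, 1.0200, 1.5886, 0.8650, 0.2156]
CARBON_B = [20.8439, 10.2075, 0.5687, 51.6512, 0.0]

def form_factor(a, b, s, b_iso=0.0):
    """f0(s) = sum_i a_i * exp(-(b_i + B_iso) * s^2), with s = sin(theta) / lambda."""
    return sum(ai * math.exp(-(bi + b_iso) * s * s) for ai, bi in zip(a, b))

for d in (10.0, 3.0, 1.5):   # resolution in Angstroms
    s = 1.0 / (2.0 * d)      # sin(theta) / lambda = 1 / (2 d)
    print(d, round(form_factor(CARBON_A, CARBON_B, s), 3),
          round(form_factor(CARBON_A, CARBON_B, s, b_iso=20.0), 3))
```

At s = 0 the sum of the amplitudes approaches the number of electrons (6 for carbon), and an isotropic B-factor simply adds to every width.
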

property gridsize: Tensor | None

Grid dimensions (nx, ny, nz).

property real_space_grid: Tensor | None

Real-space coordinate grid with shape (nx, ny, nz, 3).

property voxel_size: Tensor | None

Voxel dimensions.

property map_symmetry: MapSymmetry | None

Symmetry operator for map calculations.

get_iso()[source]

Get isotropic atoms with their ITC92 parameters.

Returns:

  • xyz (torch.Tensor) – Atomic coordinates with shape (n_atoms, 3).

  • adp (torch.Tensor) – Atomic displacement parameters (isotropic) with shape (n_atoms,).

  • occupancy (torch.Tensor) – Occupancies with shape (n_atoms,).

  • A (torch.Tensor) – ITC92 A parameters (amplitudes) with shape (n_atoms, 5).

  • B (torch.Tensor) – ITC92 B parameters (widths) with shape (n_atoms, 5).

get_aniso()[source]

Get anisotropic atoms with their ITC92 parameters.

Returns:

  • xyz (torch.Tensor) – Atomic coordinates with shape (n_atoms, 3).

  • u (torch.Tensor) – Anisotropic U parameters with shape (n_atoms, 6).

  • occupancy (torch.Tensor) – Occupancies with shape (n_atoms,).

  • A (torch.Tensor) – ITC92 A parameters (amplitudes) with shape (n_atoms, 5).

  • B (torch.Tensor) – ITC92 B parameters (widths) with shape (n_atoms, 5).

setup_grid(max_res=None, gridsize=None)[source]

Setup real-space grid for electron density calculation.

Delegates to FFT.setup_grid() using the stored cell and spacegroup.

Parameters:
  • max_res (float, optional) – Maximum resolution for grid spacing in Angstroms. If None, uses self.max_res.

  • gridsize (tuple of int, optional) – Explicit grid size (nx, ny, nz). If None, computed automatically using Cell.compute_grid_size() and SpaceGroup.suggest_grid_size().

get_radius(min_radius_Angstrom=4.0)[source]

Get the radius in voxels used for density calculation around each atom.

Parameters:

min_radius_Angstrom (float, optional) – Minimum radius in Angstroms. Default is 4.0.

Returns:

Radius in voxels.

Return type:

int

build_complete_map(radius=None, apply_symmetry=True)[source]

Build electron density map from all atoms.

Uses get_iso() and get_aniso() to get atom data and constructs the complete electron density map.

Parameters:
  • radius (int, optional) – Radius in voxels around each atom to compute density. If None, uses self.radius.

  • apply_symmetry (bool, optional) – If True and space group is not P1, apply symmetry operations to the map. Default is True.

Returns:

Electron density map with symmetry applied if requested.

Return type:

torch.Tensor

build_initial_map(apply_symmetry=True)[source]

Build electron density map from atomic parameters.

Delegates to FFT.build_density_map() using the model’s stored parameters.

Parameters:

apply_symmetry (bool, optional) – If True, apply crystallographic symmetry to the map. Default is True.

Returns:

Electron density map with shape (nx, ny, nz).

Return type:

torch.Tensor

save_map(filename)[source]

Save the electron density map to a CCP4 format file.

Parameters:

filename (str) – Output filename for the map.

Raises:

ValueError – If no map has been computed yet.

get_map_statistics()[source]

Get statistics about the current density map.

rebuild_map(radius=None)[source]

Rebuild the density map from scratch.

Convenience method that clears and rebuilds everything.

Parameters:

radius (int, optional) – Radius in voxels around each atom. If None, uses self.radius. If specified, overrides self.radius.

Returns:

Rebuilt electron density map.

Return type:

torch.Tensor

update_pdb()[source]

Update PDB with current atomic parameters.

reset_cache()[source]

Reset SF cache, anomalous cache, and all wrapper forward caches.

invalidate_cache()[source]

Alias for reset_cache().

get_structure_factor(hkl, recalc=False, apply_anomalous=True)[source]

Get structure factors for given hkl reflections.

Uses CachedForwardMixin to cache the result and auto-invalidate when parameters change or a backward pass propagates through.

Parameters:
  • hkl (torch.Tensor) – Miller indices with shape (n_reflections, 3).

  • recalc (bool, optional) – If True, forces recalculation bypassing the cache. Default is False.

  • apply_anomalous (bool, optional) – If True and wavelength is set, apply anomalous scattering corrections (f' and f'') for heavy atoms. Default is True.

Returns:

Complex structure factors with shape (n_reflections,).

Return type:

torch.Tensor

Notes

The complete scattering factor is:

f(s, λ) = f₀(s) + f'(λ) + i·f''(λ)

where f₀ is the normal (Thomson) scattering factor computed via FFT, and f'/f'' are the wavelength-dependent anomalous corrections.

Anomalous corrections are only computed for atoms where |f'| > anomalous_threshold or |f''| > anomalous_threshold.
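
The threshold rule and its effect on Friedel pairs can be illustrated with a direct summation over atoms. This is a deliberately simplified stand-in for the FFT-based calculation: it ignores occupancies, ADPs and symmetry, the function names are hypothetical, and the f' and f'' numbers are illustrative rather than tabulated values.

```python
import cmath

def scattering_factor(f0, f_prime, f_double_prime, threshold=0.5, apply_anomalous=True):
    """f = f0 + f' + i*f'', with the correction skipped below the threshold."""
    significant = abs(f_prime) > threshold or abs(f_double_prime) > threshold
    if apply_anomalous and significant:
        return complex(f0 + f_prime, f_double_prime)
    return complex(f0, 0.0)

def structure_factor(hkl, atoms):
    """Direct sum F(hkl) = sum_j f_j * exp(2*pi*i * h . x_j) over fractional x_j."""
    total = 0j
    for frac, f in atoms:
        phase = 2.0 * cmath.pi * sum(h * x for h, x in zip(hkl, frac))
        total += f * cmath.exp(1j * phase)
    return total

# Illustrative numbers only, not tabulated values.
light = scattering_factor(6.0, 0.01, 0.01)   # below threshold: stays real
heavy = scattering_factor(30.0, -2.0, 3.5)   # above threshold: becomes complex
atoms = [((0.10, 0.20, 0.30), light), ((0.40, 0.25, 0.75), heavy)]
f_plus = structure_factor((1, 2, 3), atoms)
f_minus = structure_factor((-1, -2, -3), atoms)
print(abs(f_plus), abs(f_minus))
```

With a significant f'' on one atom type, the amplitudes of F(hkl) and F(-h,-k,-l) are no longer equal, which is the signal anomalous methods rely on.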

property fft

Access the SfFFT submodule.

forward(hkl, apply_anomalous=True)[source]

Compute structure factors for given hkl.

This is called by the mixin’s __call__ which handles caching, backward-hook registration, and auto-invalidation.

Parameters:
  • hkl (torch.Tensor) – Miller indices with shape (n_reflections, 3).

  • apply_anomalous (bool, optional) – If True and wavelength is set, apply anomalous scattering corrections (f' and f'') for heavy atoms. Default is True.

Returns:

Calculated complex structure factors with shape (n_reflections,).

Return type:

torch.Tensor

copy(detach=True)[source]

Create a deep copy of the ModelFT.

Creates a complete independent copy including all Model base class data, FFT submodule state (gridsize, real_space_grid, voxel_size, map_symmetry), ITC92 parametrization, and scalar attributes. Cache is reset to empty.

Parameters:

detach (bool, optional) – If True, the copy’s parameters will be detached from the computation graph (default: True).

Returns:

A new ModelFT instance with copied data.

Return type:

ModelFT

Examples

model = ModelFT().load_pdb('structure.pdb')
model_copy = model.copy()
# model_copy is independent, changes won't affect model

state_dict(destination=None, prefix='', keep_vars=False)[source]

Return a dictionary containing the complete state of the ModelFT.

Extends parent Model.state_dict() with FT-specific parameters including max_res, radius_angstrom. Grid state is handled by the FFT submodule.

Parameters:
  • destination (dict, optional) – Optional dict to populate.

  • prefix (str, optional) – Prefix for parameter names. Default is ''.

  • keep_vars (bool, optional) – Whether to keep variables in computational graph. Default is False.

Returns:

Complete state dictionary.

Return type:

dict

classmethod create_from_state_dict(state_dict, device=device(type='cpu'), verbose=1, dtype_float=torch.float32)[source]

Create a fully initialized ModelFT from a state dictionary.

This is the recommended way to restore a ModelFT from a saved state. Creates an instance with properly initialized submodules, then loads the state.

Parameters:
  • state_dict (dict) – State dictionary from torch.save(model.state_dict(), …).

  • device (torch.device, optional) – Device to place tensors on. Defaults to the configured device.current.

  • verbose (int, optional) – Verbosity level. Default is 1.

  • dtype_float (torch.dtype, optional) – Float dtype for tensors. Default is dtypes.float.

Returns:

Fully initialized instance with restored state.

Return type:

ModelFT

class torchref.Refinement(data_file=None, pdb=None, cif=None, verbose=1, max_res=None, device=None, nbins=10, manual_weights=None, component_weights=None, column_names=None)[source]

Bases: DeviceMixin, DebugMixin, Module

Refinement class to handle the overall crystallographic refinement process.

Supports two initialization patterns:

  1. Empty initialization (for state_dict loading):

    refinement = Refinement()  # Creates empty shell with submodules
    refinement.load_state_dict(torch.load('refinement.pt'))
    
  2. Full initialization with file paths:

    refinement = Refinement(data_file='data.mtz', pdb='model.pdb')
    
Parameters:
  • data_file (str, optional) – Path to MTZ or CIF file containing reflection data.

  • pdb (str, optional) – Path to PDB or CIF file containing initial model.

  • cif (str, optional) – Path to CIF file for restraints.

  • verbose (int, optional) – Verbosity level. Default is 1.

  • max_res (float, optional) – Maximum resolution for reflections.

  • device (torch.device, optional) – Computation device. Defaults to the configured device.current.

  • weighter (LossWeightingModule, optional) – Loss weighting module. Creates default if None.

  • nbins (int, optional) – Number of resolution bins. Default is 10.

device

Computation device.

Type:

torch.device

verbose

Verbosity level.

Type:

int

reflection_data

Reflection data container.

Type:

ReflectionData

model

Structure factor model (includes lazy restraints via model.restraints).

Type:

ModelFT

scaler

Scale factor calculator.

Type:

Scaler

weighter

Loss weighting module.

Type:

LossWeightingModule

__init__(data_file=None, pdb=None, cif=None, verbose=1, max_res=None, device=None, nbins=10, manual_weights=None, component_weights=None, column_names=None)[source]

Initialize Refinement.

If data_file and pdb are provided, fully initializes the refinement. If not provided (empty init), creates a shell with empty submodules ready for load_state_dict().

Parameters:
  • data_file (str, optional) – Path to MTZ or CIF file containing reflection data.

  • pdb (str, optional) – Path to PDB or CIF file containing initial model.

  • cif (str, optional) – Path to CIF file for restraints.

  • verbose (int, optional) – Verbosity level. Default is 1.

  • max_res (float, optional) – Maximum resolution for reflections.

  • device (torch.device, optional) – Computation device. Defaults to the configured device.current.

  • weighter (LossWeightingModule, optional) – Loss weighting module. Creates default if None.

  • nbins (int, optional) – Number of resolution bins. Default is 10.

set_xray_target_mode(mode)[source]

Change the X-ray target mode.

Parameters:

mode (str) – X-ray target mode. Options are ‘gaussian’, ‘ls’, or ‘ml’.

property data

Expose reflection_data as ‘data’ for weighting module compatibility.

Returns:

The reflection data container.

Return type:

ReflectionData

property loss_state: LossState

Get or create the persistent LossState.

The LossState is created once and reused across refinement cycles. Targets are registered once; weights are updated each cycle.

Returns:

The persistent loss state with targets registered.

Return type:

LossState

property logger: Logger

Get or create the Logger for this refinement.

Returns:

Logger instance linked to the persistent LossState.

Return type:

Logger

reset_loss_state()[source]

Reset the persistent LossState and Logger.

Call this if targets need to be re-registered (e.g., after changing target modes or reinitializing targets).

get_scales()[source]

setup_scaler()[source]

parameters(recurse=True)[source]

Return unique parameters from this module and all submodules.

Uses the default Module.parameters() to gather parameters, then removes duplicates while preserving order to avoid passing the same tensor multiple times to the optimizer.

Parameters:

recurse (bool, optional) – If True, yields parameters of this module and all submodules. Default is True.

Returns:

List of unique parameter tensors.

Return type:

list
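
Deduplication of this kind has to compare objects by identity, since two distinct parameters may hold equal values. A minimal order-preserving sketch, using plain lists as stand-ins for parameter tensors (the helper name is hypothetical and this is not torchref's code):

```python
def unique_by_identity(items):
    """Drop repeated objects (same identity) while keeping first-seen order."""
    seen = set()
    unique = []
    for item in items:
        if id(item) not in seen:
            seen.add(id(item))
            unique.append(item)
    return unique

# Stand-ins for parameter tensors shared between two submodules.
xyz, adp, scale = [0.0], [20.0], [1.0]
collected = [xyz, adp, xyz, scale, adp]
unique = unique_by_identity(collected)
print(len(collected), len(unique))
```
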

get_fcalc(hkl=None, recalc=False)[source]

get_fcalc_scaled(hkl=None, recalc=False)[source]

adp_loss()[source]

Compute total ADP loss using TotalADPTarget.

This combines:

  • Bond-based similarity (SIMU-like)

  • Spread control (tighter than KL)

  • Bounds penalty

Returns:

Total ADP loss value.

Return type:

torch.Tensor

get_F_calc(hkl=None, recalc=False)[source]

get_F_calc_scaled(hkl=None, recalc=False)[source]

nll_xray()[source]

Compute X-ray negative log-likelihood for work and test sets.

Returns:

Tuple of (work_nll, test_nll) tensors.

Return type:

tuple of torch.Tensor

xray_loss_work()[source]

Compute X-ray loss on work set using instantiated target.

Returns:

X-ray loss on work set.

Return type:

torch.Tensor

xray_loss_test()[source]

Compute X-ray loss on test set using instantiated target.

Returns:

X-ray loss on test set.

Return type:

torch.Tensor

bond_loss()[source]

Compute bond length NLL via geometry_target.

Returns:

Bond length NLL loss.

Return type:

torch.Tensor

angle_loss()[source]

Compute angle NLL via geometry_target.

Returns:

Angle NLL loss.

Return type:

torch.Tensor

torsion_loss()[source]

Compute torsion angle NLL via geometry_target.

Returns:

Torsion angle NLL loss.

Return type:

torch.Tensor

geometry_loss()[source]

Compute total geometry NLL using TotalGeometryTarget.

Returns:

Total geometry NLL loss.

Return type:

torch.Tensor

loss()[source]

Compute total loss using LossState pipeline.

Creates a LossState, populates meta, caches losses, updates weights, and returns the aggregated weighted loss.

Returns:

Total weighted loss.

Return type:

torch.Tensor

setup_component_weighting()[source]

Set up component weighting with ResolutionWeighting + OverfittingWeighting.

populate_state_meta(state)[source]

Populate LossState.meta with all model-level data.

Called once per macro cycle before weighting schemes are applied. This is the single location where refinement data is extracted into state.

Parameters:

state (LossState) – State to populate with meta data.

Returns:

State with meta populated.

Return type:

LossState

update_weights(state, multiply=False)[source]

Compute weights from component_weighting and update state. Weights are clipped to [0.01, 100.0] to avoid extreme values.

Parameters:
  • state (LossState) – State with meta populated.

  • multiply (bool, optional) – If True, multiply existing weights by computed weights. If False, replace existing weights with computed weights. Default is False.

Returns:

State with weights updated.

Return type:

LossState
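
The difference between the two modes, together with the clipping range, can be sketched with plain dictionaries. The helper name, the 'xray' key and the assumption that clipping is applied after combining the weights are illustrative; 'geometry/bond' and 'adp/simu' follow the hierarchical naming used for targets.

```python
def apply_weights(existing, computed, multiply=False, low=0.01, high=100.0):
    """Combine weights per target name, then clip to [low, high]."""
    updated = dict(existing)
    for name, weight in computed.items():
        value = updated.get(name, 1.0) * weight if multiply else weight
        updated[name] = min(max(value, low), high)
    return updated

existing = {"xray": 2.0, "geometry/bond": 0.5, "adp/simu": 1.0}
computed = {"xray": 80.0, "geometry/bond": 0.001, "adp/simu": 3.0}
replaced = apply_weights(existing, computed)
multiplied = apply_weights(existing, computed, multiply=True)
print(replaced)
print(multiplied)
```
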

create_loss_state()[source]

Create a configured LossState for optimization.

Deprecated: Use the loss_state property instead for the persistent state. This method is kept for backwards compatibility.

Sets up a LossState with all targets registered as callables with hierarchical naming (e.g., ‘geometry/bond’, ‘adp/simu’). Weights are applied from component_weighting.

Usage:

import torch
from torchref.utils import validate_loss

state = refinement.create_loss_state()
params = list(refinement.parameters())
# The optimizer is not defined in the original snippet; LBFGS is assumed here.
optimizer = torch.optim.LBFGS(params)

# Log initial state
state.aggregate(log_values=True)

# In an LBFGS closure, wrap with validate_loss so non-finite
# losses warn + reject the step instead of poisoning the run.
def closure():
    optimizer.zero_grad()
    loss = state.aggregate()
    loss.backward()
    ok = validate_loss(
        loss, state=state, parameters=params,
        context="my_refinement", raise_on_fail=False,
    )
    if not ok:
        # Clear gradients so the rejected step cannot move the parameters
        for p in params:
            if p.grad is not None:
                p.grad.zero_()
        return torch.full_like(loss.detach(), float("inf"))
    return loss

optimizer.step(closure)

# Log final state
state.new_entry()
state.aggregate(log_values=True)

Returns:

Configured LossState with targets and weights.

Return type:

LossState

complete_loss_state()[source]

Update and return the persistent LossState.

Updates the persistent LossState with current meta, target info, cached losses, and weights. The state is reused across cycles.

The cached active-parameter leaf set is not refreshed here. Stale leaves are not a correctness hazard: a leaf that’s in the set but whose Parameter object was replaced externally (e.g. by Model.freeze) just gets ignored by _freeze_graph_extras, which costs a marginal amount of wasted backward work but never produces wrong answers. If you do call Model.freeze / Model.unfreeze between LossState creation and a refinement step, call state.refresh_loss_leaves() explicitly.

Returns:

Complete LossState with targets, meta, losses, and weights.

Return type:

LossState

xray_loss()[source]

Compute X-ray loss on work set.

Returns:

X-ray loss on work set.

Return type:

torch.Tensor

restraints_loss()[source]

Compute total geometry restraints loss.

Returns:

Total geometry restraints loss.

Return type:

torch.Tensor

collect_metrics()[source]

Collect all metrics from component_weighting.stats().

This is the standard method for gathering refinement metrics for logging. Uses the centralized component_weighting module for all statistics. Returns full unfiltered stats - filtering is done at display time.

Returns:

Dictionary with all metrics (unfiltered, with StatEntry objects).

Return type:

dict

add_target_info_to_state(state)[source]

Add target information from geometry and ADP targets to LossState.meta.

Deprecated: This method is no longer needed. Use complete_loss_state() instead, which handles all state setup in one call.

Parameters:

state (LossState) – Current loss state. Meta will be updated with target info.

Returns:

Updated loss state (unchanged).

Return type:

LossState

get_rfactor()[source]
update_outliers(z_threshold=4.0)[source]
plot_fcalc_vs_fobs(outpath='fcalc_vs_fobs.png')[source]
write_out_mtz(out_mtz_path='refined_output.mtz')[source]
collect_deposition_metadata(metadata=None)[source]

Collect refinement statistics into a RefinementMetadata object.

Reuses existing statistics from collect_metrics(), get_rfactor(), and reflection data attributes.

Parameters:

metadata (RefinementMetadata, optional) – Existing metadata to merge with (e.g. from input file pass-through). Refinement statistics take precedence over pass-through values.

Returns:

Metadata populated with final refinement statistics.

Return type:

RefinementMetadata

write_out_pdb(out_pdb_path='refined_output.pdb', metadata=None)[source]

Write refined PDB with optional metadata header.

Parameters:
  • out_pdb_path (str) – Output PDB file path.

  • metadata (RefinementMetadata, optional) – Metadata for PDB header. If None, auto-collected from refinement.

write_out_cif(out_cif_path='refined_output.cif', metadata=None)[source]

Write refined coordinates as mmCIF with metadata.

Parameters:
  • out_cif_path (str) – Output mmCIF file path.

  • metadata (RefinementMetadata, optional) – Metadata for mmCIF categories. If None, auto-collected from refinement.

save_state(path)[source]

Save the complete state of the refinement to a file.

Parameters:

path (str) – Path to save the state dictionary to.

load_state(path, strict=True)[source]

Load the complete state of the refinement from a file.

Parameters:
  • path (str) – Path to load the state dictionary from.

  • strict (bool, optional) – Whether to strictly enforce that keys match. Default is True.

classmethod create_from_state_dict(state_dict, device=device(type='cpu'), verbose=1)[source]

Create a fully initialized Refinement from a state dictionary.

This is the recommended way to restore a Refinement from a saved state. It creates the proper submodules using their respective create_from_state_dict methods, then calls PyTorch’s default load_state_dict.

Parameters:
  • state_dict (dict) – State dictionary from torch.save(refinement.state_dict(), …) or from loading a checkpoint file.

  • device (torch.device, optional) – Device to place tensors on. Defaults to the configured device.current.

  • verbose (int, optional) – Verbosity level. Default is 1.

Returns:

Fully initialized instance with restored state.

Return type:

Refinement

Examples

Save and load refinement state:

import torch
from torchref import Refinement

# Save
torch.save(refinement.state_dict(), 'refinement.pt')

# Load
state = torch.load('refinement.pt')
refinement = Refinement.create_from_state_dict(state)

# Continue refinement
rwork, rfree = refinement.get_rfactor()
print(f"Restored at R-work={rwork:.4f}, R-free={rfree:.4f}")
class torchref.LBFGSRefinement(*args, target_mode='bhattacharyya', sigma_m_scale=1.0, use_lossstate_scaler=True, **kwargs)[source]

Bases: Refinement

LBFGS-based refinement subclass using the L-BFGS optimizer for fast convergence.

L-BFGS (Limited-memory BFGS) is a quasi-Newton optimization method that approximates the Hessian matrix, leading to much faster convergence than first-order methods.

Key advantages:

  • Converges in 1-2 macro cycles (vs 5+ for Adam)

  • Better final R-factors

  • More stable convergence

  • Automatically handles step size via line search

Parameters:
  • target_mode (str, optional) – X-ray target mode (‘gaussian’, ‘ls’, ‘ml’, or ‘bhattacharyya’). Default is ‘bhattacharyya’.

  • *args – Passed to parent Refinement class.

  • **kwargs – Passed to parent Refinement class.

target_mode

Current X-ray target mode.

Type:

str

Examples

Basic usage:

from torchref.refinement import LBFGSRefinement

refinement = LBFGSRefinement(
    data_file='data.mtz',
    pdb='model.pdb',
    target_mode='ml'
)
refinement.refine(macro_cycles=2)
LBFGS_DEFAULTS = {'history_size': 100, 'line_search_fn': 'strong_wolfe', 'lr': 1.0, 'max_iter': 20}
__init__(*args, target_mode='bhattacharyya', sigma_m_scale=1.0, use_lossstate_scaler=True, **kwargs)[source]

Initialize LBFGS refinement.

Parameters:
  • target_mode (str, optional) – X-ray target mode (‘gaussian’, ‘ls’, ‘ml’, ‘bhattacharyya’). Default is ‘bhattacharyya’.

  • sigma_m_scale (float, optional) – Global multiplier for σ_m in the Bhattacharyya target only. Ignored for other target modes. Default 1.0.

  • use_lossstate_scaler (bool, optional) – If True (default), refine_scaler() uses the full LossState with the body’s x-ray target — so scaler and body steps share one consistent loss. If False, falls back to Scaler.refine_lbfgs which minimises a standalone nll_xray and can pull scales in a different direction than the body optimization.

  • *args – Passed to parent Refinement class.

  • **kwargs – Passed to parent Refinement class.

xray_loss()[source]

Compute X-ray loss using the instantiated target.

Returns:

X-ray loss on work set.

Return type:

torch.Tensor

refine_scaler()[source]

Refine scaler parameters against the full refinement loss.

Builds the body LossState via complete_loss_state(), constructs a fresh LBFGS optimizer over list(self.scaler.parameters()), and delegates to LossState.step(). Because state.step disables requires_grad on every loss leaf outside the optimizer’s intent set, xyz / adp / u / occupancy are pinned for the duration — only scaler parameters move.

The critical property is that the x-ray target used here is the same one the body refine_xyz() and refine_adp() see. The legacy Scaler.refine_lbfgs() minimises a standalone nll_xray + U^2 penalty, which can pull scales in a different direction than a bhattacharyya or ml body loss and leaves the body to chase a scaler that disagrees with its own objective.

When use_lossstate_scaler is False, fall back to the legacy Scaler.refine_lbfgs() path.

Returns:

LossState with history if use_lossstate_scaler is True, otherwise the metrics dict from Scaler.refine_lbfgs().

Return type:

LossState or dict

refine_xyz()[source]

Refine Cartesian coordinates jointly with scaler parameters.

Scaler parameters (log_scale, U, solvent terms) are included in the same LBFGS call as xyz. The joint curvature lets xyz steps see the scaler as an anchor — residuals the scaler can absorb do not have to be chased by atomic motion — and the adp/scaler_U and adp/scaler_log_scale priors bite on every step, so nothing in the scaler drifts between refine_xyz and refine_adp calls.

Returns:

State with history containing before/after loss values.

Return type:

LossState

refine_adp()[source]

Refine ADP / U / occupancy jointly with scaler parameters.

Scaler parameters (log_scale, U, solvent terms) are included in the same LBFGS call as the ADP-block body parameters so the joint curvature can slide along the atomic-B / scaler-U degeneracy ridge together with the adp/scaler_U regularizer. XYZ is left frozen.

Returns:

State with history containing before/after loss values.

Return type:

LossState
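
The atomic-B / scaler-U degeneracy mentioned above can be illustrated with the isotropic Debye-Waller factor exp(-B s²/4), where s = 1/d. The sketch below is plain Python rather than torchref code, and it assumes a single isotropic overall B in the scaler (a simplification of the anisotropic U) and a one-atom toy model. Shifting the atomic B by +Δ while shifting the overall B by -Δ leaves every scaled amplitude unchanged, so the X-ray term alone cannot distinguish the two.

```python
import math


def debye_waller(b_factor, d_spacing):
    """Isotropic attenuation exp(-B * s^2 / 4) with s = 1/d."""
    s_squared = 1.0 / d_spacing ** 2
    return math.exp(-b_factor * s_squared / 4.0)


def scaled_amplitude(atomic_b, overall_b, d_spacing, f0=10.0):
    """Amplitude of a one-atom toy model after atomic and overall attenuation."""
    return f0 * debye_waller(atomic_b, d_spacing) * debye_waller(overall_b, d_spacing)


delta = 5.0  # Å^2 moved from the scaler into the atom (illustrative value)
resolutions = [8.0, 4.0, 2.0, 1.5]

before = [scaled_amplitude(20.0, 3.0, d) for d in resolutions]
after = [scaled_amplitude(20.0 + delta, 3.0 - delta, d) for d in resolutions]

# Largest change in amplitude across all resolutions; zero up to rounding.
max_change = max(abs(a - b) for a, b in zip(before, after))
```

Because the data cannot fix a position along this ridge, a prior on the scaler term (the adp/scaler_U regularizer named above) is what selects one.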

refine_joint()[source]

Joint LBFGS over every refinable parameter in one step.

Optimizes xyz, adp, u, occupancy, and every scaler parameter (log_scale, anisotropic U, solvent terms) in a single LBFGS call. The joint curvature couples all of them through the same x-ray target and through the adp/scaler_U / adp/scaler_log_scale priors — unlike alternating refine_xyz → refine_adp, there’s no “frozen partner” in either half that could lock the step into a locally bad direction.

Returns:

State with history containing before/after loss values.

Return type:

LossState

run_training_trajectory(policy_weighting, n_steps=10, pdb_id='', structure_path='', sf_path='', seed=None, policy_version=None)[source]

Run a training trajectory with policy-guided refinement.

This method runs a sequence of refinement steps using a policy to select component weights. It records state-action-reward tuples for training the policy with AWR or similar algorithms.

Parameters:
  • policy_weighting (PolicyComponentWeighting) – Policy weighting scheme (should be in training mode with sampling).

  • n_steps (int, optional) – Number of refinement steps in the trajectory (default: 10).

  • pdb_id (str, optional) – PDB identifier for recording.

  • structure_path (str, optional) – Path to structure file for recording.

  • sf_path (str, optional) – Path to structure factors file for recording.

  • seed (int, optional) – Random seed for reproducibility.

  • policy_version (str, optional) – Version identifier of the policy being used.

Returns:

Complete trajectory with state-action-reward tuples.

Return type:

TrajectoryData

run_training_trajectory_joint(policy_weighting, n_steps=10, pdb_id='', structure_path='', sf_path='', seed=None, policy_version=None)[source]

Run a training trajectory with joint XYZ+ADP refinement.

Similar to run_training_trajectory() but refines xyz, adp, u, and occupancy together in each step. The LBFGS curvature history is reset at the start of each policy step because the weight updates invalidate any prior Hessian approximation.

Parameters:
  • policy_weighting (PolicyComponentWeighting) – Policy weighting scheme (should be in training mode).

  • n_steps (int, optional) – Number of refinement steps (default: 10).

  • pdb_id (str, optional) – PDB identifier for trajectory recording.

  • structure_path (str, optional) – Path to the structure file, recorded with the trajectory.

  • sf_path (str, optional) – Path to the structure factor file, recorded with the trajectory.

  • seed (int, optional) – Random seed for reproducibility.

  • policy_version (str, optional) – Policy version identifier.

Returns:

Complete trajectory with state-action-reward tuples.

Return type:

TrajectoryData

refine(macro_cycles=5)[source]

Run full LBFGS refinement cycle (ADP + XYZ).

Parameters:

macro_cycles (int, optional) – Number of refinement cycles to perform. Default is 5.

Returns:

History dictionary with all metrics per cycle (hierarchical structure).

Return type:

dict

refine_everything(macro_cycles=5)[source]

Run full LBFGS refinement cycle (ADP + XYZ) without weight screening.

Parameters:

macro_cycles (int, optional) – Number of refinement cycles to perform. Default is 5.

Returns:

History dictionary with all metrics per cycle (hierarchical structure).

Return type:

dict

class torchref.Scaler(model=None, data=None, nbins=20, verbose=1, device=None)[source]

Bases: ScalerBase

Full-featured scaler with Model integration.

Extends ScalerBase by maintaining a reference to a Model object and providing convenience methods that automatically compute F_calc when not provided.

Supports two initialization patterns:

  1. Empty initialization (for state_dict loading):

    scaler = Scaler()  # Creates empty shell
    scaler.load_state_dict(torch.load('scaler.pt'))
    
  2. Full initialization with model and data:

    scaler = Scaler(model, reflection_data, nbins=20)
    scaler.initialize()
    
Parameters:
  • model (Model, optional) – Model object for structure factor calculation.

  • data (ReflectionData, optional) – ReflectionData object with observed data.

  • nbins (int, default 20) – Number of resolution bins.

  • verbose (int, default 1) – Verbosity level.

  • device (torch.device, default: configured device.current) – Computation device.

device

Current computation device.

Type:

torch.device

nbins

Number of resolution bins.

Type:

int

__init__(model=None, data=None, nbins=20, verbose=1, device=None)[source]

Initialize Scaler.

If model and data are provided, fully initializes the scaler. If not provided (empty init), creates a shell ready for load_state_dict().

Parameters:
  • model (Model, optional) – Model object for structure factor calculation.

  • data (ReflectionData, optional) – ReflectionData object with observed data.

  • nbins (int, default 20) – Number of resolution bins.

  • verbose (int, default 1) – Verbosity level.

  • device (torch.device, optional) – Computation device. If None, derived from model then data (model wins on mismatch); otherwise forces both onto the explicit device. See torchref.utils.resolve_device().
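
The resolution order described for the device parameter can be sketched as follows. Plain strings stand in for torch.device objects, pick_device is a hypothetical name, and the fallback used when neither model nor data is given is an assumption. torchref.utils.resolve_device() remains the authority on the actual behaviour.

```python
def pick_device(explicit=None, model_device=None, data_device=None, default="cpu"):
    """Return the device a scaler would use under the documented rule.

    Hypothetical helper: strings stand in for torch.device objects.
    """
    if explicit is not None:
        # An explicit device overrides everything; model and data are moved to it.
        return explicit
    if model_device is not None:
        # The model wins when model and data disagree.
        return model_device
    if data_device is not None:
        return data_device
    # Assumed fallback when neither model nor data is given (empty init).
    return default


choice_mismatch = pick_device(model_device="cuda:0", data_device="cpu")
choice_explicit = pick_device(explicit="cpu", model_device="cuda:0", data_device="cuda:0")
choice_data_only = pick_device(data_device="cuda:1")
choice_empty = pick_device()
```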

property model

Access the model object (not a registered submodule).

set_model_and_data(model, data)[source]

Set model and data references after empty initialization.

This is useful when loading from state_dict and then needing to reconnect to model/data objects.

Parameters:
  • model (Model) – Model object for structure factor calculation.

  • data (ReflectionData) – ReflectionData object with observed data.

initialize(fcalc=None)[source]

Initialize scaling parameters.

If fcalc is not provided, computes it from the internal model.

Parameters:

fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.

compute_fcalc()[source]

Compute F_calc from internal model.

Returns:

Calculated structure factors.

Return type:

torch.Tensor

Raises:

RuntimeError – If no model is set.

calc_initial_scale(fcalc=None)[source]

Calculate initial scale factors.

If fcalc is not provided, computes it from the internal model.

Parameters:

fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.

Returns:

The log scale parameter for each resolution bin.

Return type:

torch.nn.Parameter

fit_anisotropy(fcalc=None, nsteps=100)[source]

Fit anisotropic correction.

If fcalc is not provided, computes it from the internal model.

Parameters:
  • fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.

  • nsteps (int, default 100) – Number of optimization steps.

setup_solvent()[source]

Setup solvent model using internal model.

Creates a SolventModel using the internal model reference.

fit_all_scales(fcalc=None)[source]

Fit all scale parameters.

If fcalc is not provided, computes it from the internal model.

Parameters:

fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.

screen_solvent_params(fcalc=None, steps=15, use_low_res_weighting=True, low_res_cutoff=5.0, fit_on_low_res_only=True, low_res_limit=3.5)[source]

Screen solvent parameters using grid search.

If fcalc is not provided, computes it from the internal model.

Parameters:
  • fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.

  • steps (int, default 15) – Number of grid points for each parameter.

  • use_low_res_weighting (bool, default True) – If True, weight low-resolution reflections more heavily.

  • low_res_cutoff (float, default 5.0) – Resolution cutoff for weighting in Angstroms.

  • fit_on_low_res_only (bool, default True) – If True, fit using only low-resolution reflections.

  • low_res_limit (float, default 3.5) – Resolution limit for low-res only fitting in Angstroms.

refine_lbfgs(fcalc=None, nsteps=3, lr=1.0, max_iter=200, history_size=10, verbose=True)[source]

Refine scale parameters using LBFGS optimizer.

If fcalc is not provided, computes it from the internal model.

Parameters:
  • fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.

  • nsteps (int, default 3) – Number of LBFGS steps.

  • lr (float, default 1.0) – Learning rate (typically 1.0 for LBFGS).

  • max_iter (int, default 200) – Maximum iterations per line search.

  • history_size (int, default 10) – Number of previous gradients to store for Hessian approximation.

  • verbose (bool, default True) – Print progress information.

Returns:

Dictionary with refinement metrics.

Return type:

dict

rfactor(fcalc=None)[source]

Calculate R-factors.

If fcalc is not provided, computes it from the internal model.

Parameters:

fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.

Returns:

R-work and R-free values.

Return type:

tuple

bin_wise_rfactor(fcalc=None)[source]

Calculate bin-wise R-factors.

If fcalc is not provided, computes it from the internal model.

Parameters:

fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.

Returns:

  • mean_res_per_bin (torch.Tensor) – Mean resolution per bin.

  • rwork_per_bin (torch.Tensor) – R-work per bin.

  • rfree_per_bin (torch.Tensor) – R-free per bin.

get_binwise_mean_intensity(fcalc=None)[source]

Get bin-wise mean intensities.

If fcalc is not provided, computes it from the internal model.

Parameters:

fcalc (torch.Tensor, optional) – Calculated structure factors. If None, computed from model.

Returns:

Mean observed intensity, mean calculated intensity, and mean resolution per bin.

Return type:

tuple

state_dict(destination=None, prefix='', keep_vars=False)[source]

Return a dictionary containing the complete state of the Scaler.

This includes:

  • All registered buffers and parameters (via parent class)

  • Scaler-specific metadata (nbins, etc.)

  • Solvent model state (if initialized)

Note: Model and data references are NOT saved (managed separately).

Parameters:
  • destination (dict, optional) – Optional dict to populate.

  • prefix (str, default '') – Prefix for parameter names.

  • keep_vars (bool, default False) – Whether to keep variables in computational graph.

Returns:

Complete state dictionary.

Return type:

dict

load_state_dict(state_dict, strict=True)[source]

Load the Scaler state from a dictionary.

Note: This assumes model and data are already set via __init__, set_model_and_data(), or direct assignment.

Parameters:
  • state_dict (dict) – Dictionary containing scaler state.

  • strict (bool, default True) – Whether to strictly enforce that keys match.

class torchref.ScalerBase(data=None, nbins=20, verbose=1, device=None)[source]

Bases: DeviceMixin, DebugMixin, Module

Base scaler class for crystallographic scaling without model dependency.

All methods that require calculated structure factors (F_calc) take them as input arguments. This allows the scaler to be used independently of any specific model implementation.

Supports two initialization patterns:

  1. Empty initialization (for state_dict loading):

    scaler = ScalerBase()  # Creates empty shell
    scaler.load_state_dict(torch.load('scaler.pt'))
    
  2. Full initialization with data:

    scaler = ScalerBase(data=reflection_data, nbins=20)
    scaler.initialize(fcalc)
    
Parameters:
  • data (ReflectionData, optional) – ReflectionData object with observed data.

  • nbins (int, default 20) – Number of resolution bins.

  • verbose (int, default 1) – Verbosity level.

  • device (torch.device, default: configured device.current) – Computation device.

device

Current computation device.

Type:

torch.device

nbins

Number of resolution bins.

Type:

int

__init__(data=None, nbins=20, verbose=1, device=None)[source]

Initialize ScalerBase.

If data is provided, fully initializes the scaler. If not provided (empty init), creates a shell ready for load_state_dict().

Parameters:
  • data (ReflectionData, optional) – ReflectionData object with observed data.

  • nbins (int, default 20) – Number of resolution bins.

  • verbose (int, default 1) – Verbosity level.

  • device (torch.device, optional) – Computation device. If None, derived from data (if given) or the configured default via torchref.utils.resolve_device(). An explicit value forces data onto that device.

set_data(data)[source]

Set data reference after empty initialization.

This is useful when loading from state_dict and then needing to reconnect to a data object.

Parameters:

data (ReflectionData) – ReflectionData object with observed data.

initialize(fcalc)[source]

Initialize scaling parameters using provided F_calc.

Parameters:

fcalc (torch.Tensor) – Calculated structure factors (complex).

property hkl

Get HKL indices from data.

calc_initial_scale(fcalc)[source]

Calculate the initial scale factor based on the ratio of observed to calculated structure factors.

Excludes reflections with negative intensities to avoid bias from French-Wilson conversion.

Parameters:

fcalc (torch.Tensor) – Calculated structure factors (complex).

Returns:

The log scale parameter for each resolution bin.

Return type:

torch.nn.Parameter

setup_anisotropy_correction()[source]

Initialize anisotropic correction parameters.

anisotropy_correction()[source]

Compute anisotropic correction factors.

Returns:

Anisotropic correction factors for each reflection.

Return type:

torch.Tensor

fit_anisotropy(fcalc, nsteps=100)[source]

Fit anisotropic correction using provided F_calc.

Parameters:
  • fcalc (torch.Tensor) – Calculated structure factors (complex).

  • nsteps (int, default 100) – Number of optimization steps.

set_solvent_model(solvent_model)[source]

Set a pre-configured SolventModel for solvent contribution.

The SolventModel must be initialized externally (requires a Model object).

Parameters:

solvent_model (SolventModel) – Pre-configured solvent model that can compute solvent structure factors.

setup_binwise_solvent_scale()[source]

Setup bin-wise solvent scaling (Phenix-style kmask per bin).

This allows finer control over solvent contribution per resolution bin, which is more flexible than a single global B_sol parameter.

fit_all_scales(fcalc)[source]

Fit all scale parameters using provided F_calc.

Parameters:

fcalc (torch.Tensor) – Calculated structure factors (complex).

fit_simple(fobs, fcalc)[source]

Fit a single global scale factor analytically (least-squares).

This is the simple scaling approach:

k = sum(|F_obs||F_calc|) / sum(|F_calc|²)

Useful for rigid body refinement where only an overall scale is needed.

Parameters:
  • fobs (torch.Tensor) – Observed structure factor amplitudes.

  • fcalc (torch.Tensor) – Calculated structure factors (complex).
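
The formula for k follows from minimizing sum((|F_obs| - k |F_calc|)²): setting the derivative with respect to k to zero gives k = sum(|F_obs||F_calc|) / sum(|F_calc|²). A minimal sketch in plain Python, using lists instead of tensors and a hypothetical function name, checks this on synthetic data with a known scale:

```python
def fit_global_scale(fobs, fcalc):
    """Least-squares scale k minimizing sum((|F_obs| - k * |F_calc|)^2)."""
    fcalc_amp = [abs(f) for f in fcalc]  # amplitudes of the complex F_calc
    numerator = sum(o * c for o, c in zip(fobs, fcalc_amp))
    denominator = sum(c * c for c in fcalc_amp)
    return numerator / denominator


fcalc = [3 + 4j, 1 - 1j, -2 + 0.5j, 0.5 + 6j]
fobs = [2.5 * abs(f) for f in fcalc]  # synthetic amplitudes with a known scale of 2.5

k = fit_global_scale(fobs, fcalc)
```

Because the solution is closed-form, no optimizer or gradient is needed, which suits rigid body refinement where the scale is recomputed often.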

get_scale()[source]

Get the current overall scale factor value.

Returns the mean scale factor across all bins.

Returns:

Current scale factor (not log).

Return type:

float

rfactor(fcalc)[source]

Calculate the R-factor between observed and calculated structure factors.

Parameters:

fcalc (torch.Tensor) – Calculated structure factors (complex).

Returns:

R-work and R-free values.

Return type:

tuple
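
For reference, the conventional crystallographic R-factor is R = sum(||F_obs| - |F_calc||) / sum(|F_obs|), evaluated separately on the work and test sets. The sketch below shows that definition in plain Python. The function names are hypothetical, the convention that True marks a test-set reflection is an assumption, and F_calc is taken to be already on the F_obs scale. How torchref applies scaling and flags internally is not described in this section.

```python
def r_factor(fobs, fcalc_scaled):
    """Conventional R = sum(| |F_obs| - |F_calc| |) / sum(|F_obs|)."""
    diff = sum(abs(o - abs(c)) for o, c in zip(fobs, fcalc_scaled))
    return diff / sum(fobs)


def r_work_free(fobs, fcalc_scaled, free_flags):
    """Split reflections by free flag and return (R-work, R-free)."""
    rows = list(zip(fobs, fcalc_scaled, free_flags))
    work = [(o, c) for o, c, free in rows if not free]
    test = [(o, c) for o, c, free in rows if free]
    return r_factor(*zip(*work)), r_factor(*zip(*test))


fobs = [10.0, 20.0, 30.0, 40.0]
fcalc_scaled = [9 + 0j, 0 + 22j, 30 + 0j, -36 + 0j]  # already on the F_obs scale
free_flags = [False, False, True, True]              # assumed: True marks the test set

rwork, rfree = r_work_free(fobs, fcalc_scaled, free_flags)
```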

bin_wise_rfactor(fcalc)[source]

Calculate the bin-wise R-factor between observed and calculated structure factors.

Parameters:

fcalc (torch.Tensor) – Calculated structure factors (complex).

Returns:

  • mean_res_per_bin (torch.Tensor) – Mean resolution per bin.

  • rwork_per_bin (torch.Tensor) – R-work per bin.

  • rfree_per_bin (torch.Tensor) – R-free per bin.

setup_bin_wise_bfactor()[source]

Initialize bin-wise B-factor correction parameters.

bin_wise_bfactor_correction()[source]

Compute bin-wise B-factor correction factors.

Returns:

B-factor correction factors for each reflection.

Return type:

torch.Tensor

get_binwise_mean_intensity(fcalc)[source]

Get bin-wise mean intensities for observed and calculated structure factors.

Parameters:

fcalc (torch.Tensor) – Calculated structure factors (complex).

Returns:

Mean observed intensity, mean calculated intensity, and mean resolution per bin.

Return type:

tuple

screen_solvent_params(fcalc, steps=15, use_low_res_weighting=True, low_res_cutoff=5.0, fit_on_low_res_only=True, low_res_limit=3.5)[source]

Screen solvent parameters (k_sol, B_sol) using grid search.

The bulk solvent contributes primarily at low resolution. Fitting on low-resolution reflections only (fit_on_low_res_only=True) prevents high-resolution reflections from dominating the optimization and pushing B_sol too low.

Parameters:
  • fcalc (torch.Tensor) – Calculated structure factors (complex).

  • steps (int, default 15) – Number of grid points for each parameter.

  • use_low_res_weighting (bool, default True) – If True, weight low-resolution reflections more heavily since solvent primarily contributes at low resolution.

  • low_res_cutoff (float, default 5.0) – Resolution cutoff for weighting in Angstroms.

  • fit_on_low_res_only (bool, default True) – If True, fit using only low-resolution reflections.

  • low_res_limit (float, default 3.5) – Resolution limit for low-res only fitting in Angstroms.
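
A grid search of this kind can be sketched with the standard flat bulk-solvent model, F_model = F_calc + k_sol · exp(-B_sol s²/4) · F_mask with s = 1/d. The example is plain Python under stated assumptions: the search ranges, the unweighted least-squares residual, and the function names are illustrative and are not taken from torchref.

```python
import math


def solvent_scale(k_sol, b_sol, d_spacing):
    """Flat bulk-solvent scale k_sol * exp(-B_sol * s^2 / 4), with s = 1/d."""
    return k_sol * math.exp(-b_sol / (4.0 * d_spacing ** 2))


def grid_search(fobs, fcalc, fmask, d_spacing, steps=15, low_res_limit=3.5):
    """Return the (k_sol, B_sol) grid point with the smallest squared residual."""
    # Keep only low-resolution reflections (large d), where solvent matters most.
    keep = [i for i, d in enumerate(d_spacing) if d >= low_res_limit]
    # Assumed search ranges; the ranges used by torchref are not documented here.
    k_grid = [0.1 + i * (0.6 - 0.1) / (steps - 1) for i in range(steps)]
    b_grid = [10.0 + i * (80.0 - 10.0) / (steps - 1) for i in range(steps)]
    best = None
    for k_sol in k_grid:
        for b_sol in b_grid:
            residual = 0.0
            for i in keep:
                total = fcalc[i] + solvent_scale(k_sol, b_sol, d_spacing[i]) * fmask[i]
                residual += (fobs[i] - abs(total)) ** 2
            if best is None or residual < best[0]:
                best = (residual, k_sol, b_sol)
    return best[1], best[2]


# Synthetic data generated with k_sol = 0.35 and B_sol = 45, both on the grid.
d_spacing = [20.0, 12.0, 8.0, 6.0, 4.5, 3.6, 2.5, 1.8]
fcalc = [100 + 20j, 80 - 10j, 60 + 5j, 50 + 30j, 40 - 20j, 35 + 0j, 30 + 10j, 20 - 5j]
fmask = [-90 - 15j, -60 + 10j, -40 - 5j, -25 - 10j, -15 + 5j, -8 + 2j, -3 + 0j, -1 + 0j]
fobs = [abs(fc + solvent_scale(0.35, 45.0, d) * fm)
        for fc, fm, d in zip(fcalc, fmask, d_spacing)]

k_best, b_best = grid_search(fobs, fcalc, fmask, d_spacing)
```

Restricting the fit to d ≥ low_res_limit drops the two highest-resolution reflections here, where the solvent term is almost fully damped and contributes little information about k_sol or B_sol.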

refine_lbfgs(fcalc, nsteps=3, lr=1.0, max_iter=200, history_size=10, verbose=True)[source]

Refine scale parameters using LBFGS optimizer.

This method optimizes the anisotropic scaling and B-factor parameters that relate calculated structure factors to observed structure factors. Uses the L-BFGS quasi-Newton optimization method for fast convergence.

Parameters:
  • fcalc (torch.Tensor) – Calculated structure factors (complex).

  • nsteps (int, default 3) – Number of LBFGS steps.

  • lr (float, default 1.0) – Learning rate (typically 1.0 for LBFGS).

  • max_iter (int, default 200) – Maximum iterations per line search.

  • history_size (int, default 10) – Number of previous gradients to store for Hessian approximation.

  • verbose (bool, default True) – Print progress information.

Returns:

Dictionary with refinement metrics including steps, xray_work, xray_test, rwork, rfree.

Return type:

dict

estimate_sigma_eff(fcalc, max_inflation=2.0)[source]

Estimate per-resolution-shell effective sigmas from current residuals.

Pannu & Read / SIGMAA-style correction: detects miscalibrated experimental sigmas by comparing residual variance to the claimed variance, per resolution bin.

For each resolution bin:

    D_bin        = < (F_obs - k * |F_calc|)^2 >      (work set only)
    ratio_bin    = sqrt(D_bin / <sigma_F^2>)
    ratio_capped = clamp(ratio_bin, 1.0, max_inflation)
    sigma_eff    = sigma_F * ratio_capped

Why the cap? At the start of refinement the model is bad, so residuals are dominated by model error (which is fixable by refining), not noise. Uncapped inflation creates a vicious cycle: bad model -> huge sigma_eff -> weak data gradient -> bad model. Capping at max_inflation (default 2.0, i.e. sigmas can grow at most 2x) prevents runaway while still correcting genuinely under-estimated sigmas.

As the model improves, residuals shrink and the ratio drops toward 1, so sigma_eff converges to the raw sigma (good calibration).

Uses the work set only so the test set doesn’t leak into sigma estimation.

Parameters:
  • fcalc (torch.Tensor) – Calculated structure factors (complex, unscaled).

  • max_inflation (float, optional) – Maximum allowed ratio sigma_eff / sigma_raw. Default 2.0.

Returns:

Per-reflection effective sigmas, shape (N,).

Return type:

torch.Tensor
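
The per-bin recipe can be traced on a small example. The sketch below is plain Python with a hypothetical function name, not the torchref implementation. It assumes that the capped ratio of a bin is applied to every reflection in that bin, including test-set reflections, and that bins without work reflections are left untouched.

```python
import math


def sigma_eff_sketch(fobs, fcalc_amp, sigma, bins, work, k=1.0, max_inflation=2.0):
    """Inflate sigmas per resolution bin: sigma_eff = sigma * clamp(ratio, 1, cap)."""
    sigma_eff = list(sigma)
    for b in sorted(set(bins)):
        members = [i for i, bin_id in enumerate(bins) if bin_id == b]
        # Statistics use the work set only, so the test set cannot leak in.
        in_work = [i for i in members if work[i]]
        if not in_work:
            continue  # assumed behaviour: leave sigmas untouched for such bins
        d_bin = sum((fobs[i] - k * fcalc_amp[i]) ** 2 for i in in_work) / len(in_work)
        mean_var = sum(sigma[i] ** 2 for i in in_work) / len(in_work)
        ratio = math.sqrt(d_bin / mean_var)
        ratio_capped = min(max(ratio, 1.0), max_inflation)
        # Assumed: the per-bin factor is applied to every reflection in the bin.
        for i in members:
            sigma_eff[i] = sigma[i] * ratio_capped
    return sigma_eff


bins = [0, 0, 0, 1, 1, 2, 2]
work = [True, True, False, True, True, True, True]
fcalc_amp = [10.0, 10.0, 10.0, 20.0, 20.0, 30.0, 30.0]
fobs = [13.0, 7.0, 90.0, 30.0, 10.0, 30.5, 29.5]
sigma = [2.0, 2.0, 2.0, 1.0, 1.0, 1.0, 1.0]

result = sigma_eff_sketch(fobs, fcalc_amp, sigma, bins, work)
```

Bin 0 has a ratio of 1.5 and is inflated by that amount; the large test-set residual in that bin does not enter the estimate. Bin 1 has a ratio of 10 and is capped at 2. Bin 2 has a ratio of 0.5 and is clamped up to 1, so sigmas are never deflated.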

forward(fcalc, use_mask=True, f_sol_override=None)[source]

Forward pass for the ScalerBase module.

Parameters:
  • fcalc (torch.Tensor) – Calculated structure factors. Expected shape (N,); an additional batch dimension is allowed. N must match the full HKL size.

  • use_mask (bool, default True) – Deprecated parameter, kept for backward compatibility.

  • f_sol_override (torch.Tensor, optional) – Pre-computed raw solvent structure factors. When provided, these replace the internally-cached _f_sol_raw. The scaler’s k_sol / B_sol / phase damping is still applied. This is used by CollectionScaler to supply mixed (fraction-weighted) solvent contributions.

Returns:

Scaled structure factors of same shape as input.

Return type:

torch.Tensor

state_dict(destination=None, prefix='', keep_vars=False)[source]

Return a dictionary containing the complete state of the ScalerBase.

This includes:

  • All registered buffers and parameters (via parent class)

  • Scaler-specific metadata (nbins, etc.)

  • Solvent model state (if set)

Note: Data reference is NOT saved (managed separately).

Parameters:
  • destination (dict, optional) – Optional dict to populate.

  • prefix (str, default '') – Prefix for parameter names.

  • keep_vars (bool, default False) – Whether to keep variables in computational graph.

Returns:

Complete state dictionary.

Return type:

dict

load_state_dict(state_dict, strict=True)[source]

Load the ScalerBase state from a dictionary.

Note: This assumes data is already set via __init__ or set_data().

Parameters:
  • state_dict (dict) – Dictionary containing scaler state.

  • strict (bool, default True) – Whether to strictly enforce that keys match.

save_state(path)[source]

Save the complete state of the scaler to a file.

Parameters:

path (str) – Path to save the state dictionary to.

load_state(path, strict=True)[source]

Load the complete state of the scaler from a file.

Parameters:
  • path (str) – Path to load the state dictionary from.

  • strict (bool, default True) – Whether to strictly enforce that keys match.

class torchref.SolventModel(model=None, radius=1.1, k_solvent=1.1, b_solvent=50.0, erosion_radius=0.9, transition=None, optimize_phase=True, initial_phase_offset=0.0, verbose=1, float_type=torch.float32, device=device(type='cpu'))[source]

Bases: DeviceMixin, DebugMixin, Module

SolventModel to compute solvent contribution to structure factors using Phenix-like approach.

Supports two initialization patterns:

  1. Empty initialization (for state_dict loading):

    solvent = SolventModel()  # Creates empty shell
    solvent.load_state_dict(torch.load('solvent.pt'))
    
  2. Full initialization with model:

    solvent = SolventModel(model, k_solvent=0.35, b_solvent=46.0)
    
model

The atomic model for structure factor calculations.

Type:

ModelFT or None

device

Device for tensor operations.

Type:

torch.device

verbose

Verbosity level.

Type:

int

float_type

Floating point data type.

Type:

torch.dtype

solvent_radius

Probe radius in Angstroms for dilation.

Type:

float

erosion_radius

Radius in Angstroms for erosion step.

Type:

float

optimize_phase

Whether to optimize phase offset parameter.

Type:

bool

log_k_solvent

Log of solvent scattering scale factor.

Type:

torch.nn.Parameter

b_solvent

Solvent B-factor.

Type:

torch.nn.Parameter

phase_offset

Phase offset in radians.

Type:

torch.nn.Parameter or buffer

__init__(model=None, radius=1.1, k_solvent=1.1, b_solvent=50.0, erosion_radius=0.9, transition=None, optimize_phase=True, initial_phase_offset=0.0, verbose=1, float_type=torch.float32, device=device(type='cpu'))[source]

Initialize SolventModel.

If model is provided, fully initializes the solvent model. If not provided (empty init), creates a shell ready for load_state_dict().

Parameters:
  • model (ModelFT, optional) – The atomic model used for structure factor calculations (optional for empty init).

  • radius (float, default 1.1) – Probe radius in Angstroms for dilation (water radius).

  • k_solvent (float, default 1.1) – Solvent scattering scale factor.

  • b_solvent (float, default 50.0) – Solvent B-factor.

  • erosion_radius (float, default 0.9) – Radius in Angstroms for erosion step.

  • transition (float, optional) – Gaussian smoothing sigma for mask edges (default: radius/4 in voxels). Avoids ringing artifacts.

  • optimize_phase (bool, default True) – Whether to optimize phase offset parameter.

  • initial_phase_offset (float, default 0.0) – Initial phase offset in radians.

  • verbose (int, default 1) – Verbosity level.

  • float_type (torch.dtype, default torch.float32) – Floating point data type.

  • device (torch.device, default: configured device.current) – Device for tensor operations.

get_solvent_mask()[source]

Generate solvent mask following Phenix’s three-step process.

Step 1 (dilation): classify voxels around each atom as protein (inside VdW), boundary (between VdW and VdW+solvent_radius), or bulk solvent (further out). Built in chunks over atoms so peak memory is O(atom_chunk_size × N_box_voxels) rather than O(N_atoms × N_box_voxels). This is critical because the dense form is multi-GB for typical macromolecule and grid combinations.

Step 2 (symmetry expansion): transform the sparse ASU protein/boundary voxel indices through each symop and scatter into the P1 grid masks.

Step 3 (erosion): a boundary voxel becomes solvent if any voxel within erosion_radius of it is bulk solvent. Implemented as a single F.conv3d with a precomputed spherical kernel and circular padding. This replaces the previous Python loop with per-voxel neighbourhood expansion, which itself ran out of memory on chunks of 10^6 boundary voxels.

Returns:

Solvent mask (boolean) where True = solvent.

Return type:

torch.Tensor
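
The erosion rule in step 3 can be illustrated on a tiny periodic grid. This is a minimal pure-Python sketch, not the library's F.conv3d implementation: the label constants, the function name erode, and the radius given in voxels are all assumptions made for illustration.

```python
from itertools import product

# Hypothetical voxel labels (not torchref's internal encoding).
PROTEIN, BOUNDARY, BULK = 0, 1, 2

def erode(labels, radius):
    """Return a solvent mask (True = solvent) from a labelled periodic grid.

    labels is a nested list indexed as labels[x][y][z]; radius is in voxels.
    """
    nx, ny, nz = len(labels), len(labels[0]), len(labels[0][0])
    r = int(radius)
    # Spherical neighbourhood offsets, playing the role of the conv kernel.
    offsets = [(i, j, k) for i, j, k in product(range(-r, r + 1), repeat=3)
               if i * i + j * j + k * k <= radius * radius]
    mask = [[[False] * nz for _ in range(ny)] for _ in range(nx)]
    for x, y, z in product(range(nx), range(ny), range(nz)):
        label = labels[x][y][z]
        if label == BULK:
            mask[x][y][z] = True
        elif label == BOUNDARY:
            # Modulo indexing mimics circular padding across cell edges.
            mask[x][y][z] = any(
                labels[(x + i) % nx][(y + j) % ny][(z + k) % nz] == BULK
                for i, j, k in offsets)
    return mask

# A 6 x 1 x 1 line of voxels: only boundary voxels touching bulk are eroded.
line = [[[c]] for c in (PROTEIN, BOUNDARY, BOUNDARY, BULK, BOUNDARY, BOUNDARY)]
print([m[0][0] for m in erode(line, 1.0)])
# [False, False, True, True, True, False]
```

The explicit loop is only practical for toy grids; a convolution evaluates the same neighbourhood test for all voxels at once.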

update_solvent()[source]
smooth_solvent_mask()[source]
get_rec_solvent(hkl)[source]

Compute solvent structure factors.

Uses the standard crystallographic approach: compute SFs from the solvent mask. The mask represents regions where bulk solvent scattering occurs.

Parameters:

hkl (torch.Tensor) – Miller indices.

Returns:

Complex solvent structure factors.

Return type:

torch.Tensor

forward(hkl, update_fsol=False, F_protein=None)[source]

Compute solvent contribution to structure factors at given HKL.

This method is differentiable with respect to k_solvent, b_solvent, and phase_offset parameters.

The solvent model:

  1. Takes the binary solvent mask

  2. Smooths it with Gaussian filter (σ=1.5 voxels) to create soft edges

  3. Computes structure factors via FFT

  4. Applies B-factor damping: exp(-B * s²) where s = sin(θ)/λ

  5. If optimize_phase=True and F_protein is provided: blends mask phases with protein phases. phase_offset controls the blend: 0 = use mask phases, ±π = use protein phases

  6. Scales by k_solvent

Parameters:
  • hkl (torch.Tensor) – Miller indices, shape (N, 3).

  • update_fsol (bool, default False) – Whether to update solvent structure factors.

  • F_protein (torch.Tensor, optional) – Protein structure factors, used for phase blending.

Returns:

Complex solvent structure factors, shape (N,).

Return type:

torch.Tensor
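
Steps 4 and 6 above amount to one multiplication per reflection. The sketch below shows that arithmetic for a single structure factor, assuming that k_solvent is recovered as exp(log_k_solvent) and using s = sin(θ)/λ = 1/(2d) from Bragg's law. The function name is hypothetical, and the phase blending of step 5 is deliberately left out.

```python
import cmath
import math

def solvent_contribution(f_mask, d_spacing, log_k_solvent, b_solvent):
    """Damp and scale one mask structure factor (phase blending omitted).

    f_mask        complex structure factor of the smoothed solvent mask
    d_spacing     resolution of the reflection in Angstroms
    log_k_solvent natural log of the scale (assumed parameterisation)
    b_solvent     solvent B-factor in square Angstroms
    """
    s = 1.0 / (2.0 * d_spacing)            # sin(theta)/lambda
    k_solvent = math.exp(log_k_solvent)    # always positive
    return k_solvent * math.exp(-b_solvent * s * s) * f_mask

f_mask = cmath.rect(100.0, math.pi / 3)    # amplitude 100, phase 60 degrees
f_sol = solvent_contribution(f_mask, 2.0, math.log(0.35), 50.0)
print(abs(f_sol), math.degrees(cmath.phase(f_sol)))
```

Both factors are real and positive, so only the amplitude changes; the phase of the mask structure factor is preserved until any blending is applied.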

parameters()[source]

Return an iterator over module parameters.

This is typically passed to an optimizer.

Args:

recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

Yields:

Parameter: module parameter

Example:

>>> for param in model.parameters():
>>>     print(type(param), param.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
class torchref.Cell(data, *, dtype=torch.float32, device=device(type='cpu'), requires_grad=False)[source]

Bases: DeviceMixin

Dataclass for crystallographic unit cells with cached derived quantities.

Stores 6 parameters: [a, b, c, alpha, beta, gamma]

  • a, b, c: cell lengths in Angstroms

  • alpha, beta, gamma: cell angles in degrees

Derived quantities (fractional_matrix, volume, etc.) are computed on first access and cached. The cache is cleared when the cell is moved to a different device or dtype.

Examples

>>> cell = Cell([50, 60, 70, 90, 90, 90])
>>> cell.volume  # Computed and cached
tensor(210000.)
>>> cell_gpu = cell.to('cuda')  # Move to GPU (returns new Cell)
>>> cell_gpu.device.type
'cuda'
__init__(data, *, dtype=torch.float32, device=device(type='cpu'), requires_grad=False)[source]

Create a new Cell.

Parameters:
  • data (array-like) – Unit cell parameters [a, b, c, alpha, beta, gamma]. Can be a list, numpy array, or torch tensor.

  • dtype (torch.dtype, optional) – Desired data type. Defaults to the configured dtypes.float.

  • device (torch.device or str, optional) – Desired device. Defaults to the configured device.current.

  • requires_grad (bool, optional) – Whether to track gradients. Defaults to False.

Raises:

ValueError – If data does not have exactly 6 elements.

reset_cache()[source]

Clear cached derived quantities (fractional matrix, volume, etc.).

detach()[source]

Return a new Cell with detached tensor (no gradient tracking).

Returns:

New Cell with detached data.

Return type:

Cell

clone()[source]

Return a new Cell with cloned tensor data.

Returns:

New Cell with cloned data.

Return type:

Cell

property device: device

Return the device of the underlying tensor.

property dtype: dtype

Return the dtype of the underlying tensor.

property data: Tensor

Return the underlying tensor (for buffer registration).

property requires_grad: bool

Return whether gradients are tracked.

property a: Tensor

Cell length a in Angstroms.

property b: Tensor

Cell length b in Angstroms.

property c: Tensor

Cell length c in Angstroms.

property alpha: Tensor

Cell angle alpha in degrees.

property beta: Tensor

Cell angle beta in degrees.

property gamma: Tensor

Cell angle gamma in degrees.

property fractional_matrix: Tensor

Orthogonalization matrix B (fractional -> Cartesian).

Returns the 3x3 matrix B such that: cart = frac @ B.T

Returns:

Shape (3, 3) orthogonalization matrix.

Return type:

torch.Tensor

property inv_fractional_matrix: Tensor

Fractionalization matrix B^-1 (Cartesian -> fractional).

Returns the 3x3 matrix B^-1 such that: frac = cart @ B^-1.T

Returns:

Shape (3, 3) fractionalization matrix.

Return type:

torch.Tensor

property volume: Tensor

Unit cell volume in cubic Angstroms.

Returns:

Scalar tensor with the cell volume.

Return type:

torch.Tensor
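
The volume of a general (triclinic) cell follows from the six parameters, and the compute-once behaviour can be mimicked with functools.cached_property. The class below is a hypothetical stand-in that works on plain floats; it is not how Cell is implemented, only a sketch of the same idea.

```python
import math
from functools import cached_property

class LazyCell:
    """Hypothetical stand-in for Cell with one cached derived quantity."""

    def __init__(self, params):
        if len(params) != 6:
            raise ValueError("expected [a, b, c, alpha, beta, gamma]")
        self.params = [float(p) for p in params]

    @cached_property
    def volume(self):
        # V = abc * sqrt(1 - cos^2(al) - cos^2(be) - cos^2(ga)
        #                + 2 cos(al) cos(be) cos(ga))
        a, b, c, al, be, ga = self.params
        ca, cb, cg = (math.cos(math.radians(x)) for x in (al, be, ga))
        return a * b * c * math.sqrt(
            1 - ca * ca - cb * cb - cg * cg + 2 * ca * cb * cg)

    def reset_cache(self):
        # cached_property stores its value in the instance __dict__.
        self.__dict__.pop("volume", None)

cell = LazyCell([50, 60, 70, 90, 90, 90])
print(round(cell.volume))  # 210000
```

Calling reset_cache() forces the next access to recompute, which is the behaviour the documentation describes after a device or dtype change.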

property reciprocal_basis_matrix: Tensor

Reciprocal basis matrix with [a*, b*, c*] as rows.

Returns:

Shape (3, 3) matrix where rows are the reciprocal basis vectors.

Return type:

torch.Tensor

compute_grid_size(max_res, oversampling=3.0)[source]

Compute minimum grid dimensions for a given resolution.

Uses the Shannon-Nyquist sampling criterion to determine the minimum number of grid points needed along each axis.

Parameters:
  • max_res (float) – Maximum resolution in Angstroms.

  • oversampling (float, optional) – Oversampling factor relative to max_res. Default is 3.0 (standard for crystallographic calculations).

Returns:

Minimum grid dimensions (nx, ny, nz).

Return type:

tuple of int

Examples

>>> cell = Cell([50, 60, 70, 90, 90, 90])
>>> cell.compute_grid_size(2.0)
(75, 90, 105)
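
The numbers in this example are consistent with a grid spacing of max_res / oversampling along each axis. The sketch below reproduces them under that assumption for an orthogonal cell; the function name and the rounding rule are illustrative guesses, and non-orthogonal cells would need the reciprocal cell lengths instead.

```python
import math

def min_grid_size(lengths, max_res, oversampling=3.0):
    """Smallest point count per axis with spacing <= max_res / oversampling."""
    return tuple(math.ceil(length * oversampling / max_res)
                 for length in lengths)

print(min_grid_size((50, 60, 70), 2.0))  # (75, 90, 105)
```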
tolist()[source]

Convert Cell parameters to a standard Python list.

Returns:

List of cell parameters [a, b, c, alpha, beta, gamma].

Return type:

list

fractional_to_cartesian(frac_coords)[source]

Convert fractional coordinates to Cartesian coordinates.

Parameters:

frac_coords (torch.Tensor) – Tensor of fractional coordinates, shape (…, 3).

Returns:

Tensor of Cartesian coordinates, shape (…, 3).

Return type:

torch.Tensor

cartesian_to_fractional(cart_coords)[source]

Convert Cartesian coordinates to fractional coordinates.

Parameters:

cart_coords (torch.Tensor) – Tensor of Cartesian coordinates, shape (…, 3).

Returns:

Tensor of fractional coordinates, shape (…, 3).

Return type:

torch.Tensor
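
For reference, the two conversions can be written out with plain floats. The sketch assumes the common PDB-style convention (a along x, b in the xy plane), which may or may not match the matrix stored by Cell; all function names are hypothetical.

```python
import math

def orthogonalization_matrix(a, b, c, alpha, beta, gamma):
    """Upper-triangular B with cart = B @ frac (assumed PDB convention)."""
    ca, cb, cg = (math.cos(math.radians(x)) for x in (alpha, beta, gamma))
    sg = math.sin(math.radians(gamma))
    volume = a * b * c * math.sqrt(
        1 - ca * ca - cb * cb - cg * cg + 2 * ca * cb * cg)
    return [[a, b * cg, c * cb],
            [0.0, b * sg, c * (ca - cb * cg) / sg],
            [0.0, 0.0, volume / (a * b * sg)]]

def frac_to_cart(B, frac):
    return [sum(B[i][j] * frac[j] for j in range(3)) for i in range(3)]

def cart_to_frac(B, cart):
    # B is upper triangular, so back-substitution replaces a full inverse.
    z = cart[2] / B[2][2]
    y = (cart[1] - B[1][2] * z) / B[1][1]
    x = (cart[0] - B[0][1] * y - B[0][2] * z) / B[0][0]
    return [x, y, z]

B = orthogonalization_matrix(50, 60, 70, 90, 110, 90)
xyz = frac_to_cart(B, [0.1, 0.2, 0.3])
print(cart_to_frac(B, xyz))  # recovers approximately [0.1, 0.2, 0.3]
```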

__repr__()[source]

Return string representation.

__getitem__(idx)[source]

Allow indexing like cell[0] for cell length a.

__len__()[source]

Return 6 (number of cell parameters).

class torchref.SpaceGroup(space_group=None, dtype=torch.float32, device=device(type='cpu'))[source]

Bases: DeviceMixin, DebugMixin, Module

Unified space group handler for crystallographic symmetry operations.

This class combines space group normalization with symmetry operations, providing a single interface for:

  • Normalizing input (string, int, gemmi.SpaceGroup) in the constructor

  • Storing symmetry matrices and translations as PyTorch buffers

  • Applying symmetry operations to fractional coordinates

  • Grid size utilities for symmetry-compatible grids

Parameters:
  • space_group (str, int, gemmi.SpaceGroup, SpaceGroup, or None) – Space group specification. Accepts a Hermann-Mauguin symbol (e.g., ‘P21’, ‘P 21 21 21’), a space group number (1-230), a gemmi.SpaceGroup object, another SpaceGroup instance, or None (defaults to P1).

  • dtype (torch.dtype, default torch.float32) – Data type for rotation matrices and translations.

  • device (torch.device, default: configured device.current) – Device for computation.

matrices

Rotation matrices for all symmetry operations (registered buffer).

Type:

torch.Tensor, shape (n_ops, 3, 3)

translations

Translation vectors for all symmetry operations (registered buffer).

Type:

torch.Tensor, shape (n_ops, 3)

n_ops

Number of symmetry operations.

Type:

int

Examples

# Create from various inputs
sg = SpaceGroup('P21')
sg = SpaceGroup('P 21')      # Same result
sg = SpaceGroup(4)           # P21 by number
sg = SpaceGroup(None)        # Returns P1

# Access properties
print(sg.name)          # 'P21' (short name)
print(sg.hm)            # 'P 21' (Hermann-Mauguin with spaces)
print(sg.number)        # 4

# Apply symmetry operations
coords = torch.tensor([[0.1, 0.2, 0.3]])
transformed = sg(coords)  # Apply all symmetry operations

# Grid utilities
req = sg.get_grid_requirements()
suggested = sg.suggest_grid_size((131, 163, 148))
__init__(space_group=None, dtype=torch.float32, device=device(type='cpu'))[source]
property n_ops: int

Number of symmetry operations.

property name: str

Short space group name (e.g., ‘P21’).

property hm: str

Hermann-Mauguin notation with spaces (e.g., ‘P 21’).

property xhm: str

Extended Hermann-Mauguin notation.

property number: int

Space group number (1-230).

property gemmi: SpaceGroup

Access a gemmi.SpaceGroup object (created on demand, not stored).

property point_group: str

Point group symbol (e.g., ‘222’, ‘mmm’).

property crystal_system: str

Crystal system name.

property centrosymmetric: bool

True if space group has inversion center.

property dtype: dtype

Data type used for matrices.

property device: device

Device for matrices.

property spacegroup: SpaceGroup

Alias for gemmi property (backward compatibility).

property space_group: SpaceGroup

Alias for gemmi property (backward compatibility).

property space_group_name: str

Alias for name property (backward compatibility).

property space_group_number: int

Alias for number property (backward compatibility).

short_name()[source]

Get short space group name.

operations()[source]

Get symmetry operations (creates temporary gemmi object on demand).

apply(xyz_fractional, apply_translation=True)[source]

Apply symmetry operations to fractional coordinates (rotation + translation).

For real space coordinates, applies the full symmetry operation: x’ = R·x + t

Parameters:

xyz_fractional (torch.Tensor) – Input tensor of shape (N, 3) representing fractional coordinates.

Returns:

Transformed coordinates of shape (N, 3, ops) where ops is the number of symmetry operations.

Return type:

torch.Tensor

See also

apply_to_hkl

For reciprocal space (Miller indices), rotation only.
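
As a concrete instance of x’ = R·x + t, the sketch below applies the two operations of P21 (unique axis b) to one fractional coordinate. The operator table is written out by hand and the function name is hypothetical. The result is a nested list indexed by operation, whereas the method stacks operations on the last tensor axis. The meaning given to apply_translation here (skip the + t term when False) is an assumption.

```python
# P21, unique axis b: identity and the 2_1 screw axis along y.
P21_OPS = [
    ([[1, 0, 0], [0, 1, 0], [0, 0, 1]], [0.0, 0.0, 0.0]),
    ([[-1, 0, 0], [0, 1, 0], [0, 0, -1]], [0.0, 0.5, 0.0]),
]

def apply_ops(ops, xyz_fractional, apply_translation=True):
    """Return result[op][n] = R_op @ x_n (+ t_op) as nested lists."""
    result = []
    for rot, trans in ops:
        images = []
        for point in xyz_fractional:
            image = [sum(rot[i][j] * point[j] for j in range(3))
                     for i in range(3)]
            if apply_translation:
                image = [v + t for v, t in zip(image, trans)]
            images.append(image)
        result.append(images)
    return result

for images in apply_ops(P21_OPS, [[0.1, 0.2, 0.3]]):
    print(images)
```

Concatenating the per-operation lists gives the N * ops rows that a P1 expansion would contain.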

apply_to_hkl(hkl)[source]

Apply symmetry operations to Miller indices (rotation only, no translation).

For reciprocal space, only the rotational part of symmetry operations applies to Miller indices: h’ = R·h. The translation vector affects the phase of structure factors, not the indices themselves.

Parameters:

hkl (torch.Tensor) – Input tensor of shape (N, 3) representing Miller indices.

Returns:

Transformed Miller indices of shape (N, 3, ops) where ops is the number of symmetry operations.

Return type:

torch.Tensor

See also

apply

For real space coordinates (rotation + translation).
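
The split between indices and phases can be checked numerically. The sketch below rotates a Miller index with the P21 screw-axis rotation, then sums point-atom structure factors over an atom and its symmetry mate. It assumes unit scattering factors, and all names are hypothetical. Because this R is diagonal, the row versus column convention for h’ makes no difference here.

```python
import cmath

ROT = [[-1, 0, 0], [0, 1, 0], [0, 0, -1]]   # rotation part of the 2_1 axis
TRANS = [0.0, 0.5, 0.0]                      # translation part

def rotate_hkl(rot, hkl):
    """h' = R . h, with no translation applied to the indices."""
    return tuple(sum(rot[i][j] * hkl[j] for j in range(3)) for i in range(3))

def structure_factor(hkl, atoms):
    """Sum of exp(2*pi*i h.x) over point atoms with unit scattering."""
    return sum(cmath.exp(2j * cmath.pi * sum(h * x for h, x in zip(hkl, atom)))
               for atom in atoms)

atom = [0.1, 0.2, 0.3]
mate = [sum(ROT[i][j] * atom[j] for j in range(3)) + TRANS[i]
        for i in range(3)]
atoms = [atom, mate]

h = (1, 2, 3)
print(rotate_hkl(ROT, h))                       # (-1, 2, -3)
print(abs(structure_factor(h, atoms)),
      abs(structure_factor(rotate_hkl(ROT, h), atoms)))
print(abs(structure_factor((0, 1, 0), atoms)))  # close to zero
```

The symmetry-related reflections share an amplitude, while the half-cell translation cancels 0k0 reflections with odd k, which is how the translation shows up in the structure factors and not in the indices.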

expand_coords_to_P1(xyz_fractional)[source]

Expand fractional coordinates by applying all symmetry operations.

Parameters:

xyz_fractional (torch.Tensor) – Input tensor of shape (N, 3) representing fractional coordinates.

Returns:

Expanded coordinates of shape (N * ops, 3).

Return type:

torch.Tensor

forward(xyz_fractional)[source]

Forward pass applies symmetry operations.

get_grid_requirements()[source]

Analyze symmetry operations to determine grid size requirements.

Returns:

{‘nx_mod’: int, ‘ny_mod’: int, ‘nz_mod’: int}: required divisibility for each axis.

Return type:

dict

Examples

sg = SpaceGroup('P21')
req = sg.get_grid_requirements()
print(req)  # {'nx_mod': 1, 'ny_mod': 2, 'nz_mod': 1}
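
One plausible way to arrive at such requirements is to take, per axis, the least common multiple of the denominators of the translation components, so that every translation maps grid points onto grid points. The sketch below follows that assumption. It is not necessarily the analysis the method performs, and the function name is hypothetical.

```python
from fractions import Fraction
from math import gcd

def grid_requirements(translations):
    """Per-axis divisibility implied by fractional translations."""
    mods = [1, 1, 1]
    for trans in translations:
        for axis, component in enumerate(trans):
            # limit_denominator tolerates float input such as 1/3.
            den = Fraction(component).limit_denominator(24).denominator
            mods[axis] = mods[axis] * den // gcd(mods[axis], den)
    return {'nx_mod': mods[0], 'ny_mod': mods[1], 'nz_mod': mods[2]}

half = Fraction(1, 2)
print(grid_requirements([(0, 0, 0), (0, half, 0)]))
# {'nx_mod': 1, 'ny_mod': 2, 'nz_mod': 1}
```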
check_grid_compatibility(grid_shape)[source]

Check if a grid size is compatible with the symmetry operations.

Parameters:

grid_shape (tuple of int) – (nx, ny, nz) grid dimensions.

Returns:

Dictionary with keys:

  • ‘compatible’ : bool - True if grid satisfies all requirements

  • ‘symmetry_compatible’ : bool - True if grid satisfies symmetry

  • ‘fft_friendly’ : bool - True if all dimensions are FFT-friendly

  • ‘can_use_direct_indexing’ : bool - True if no interpolation needed

  • ‘issues’ : list of str - Descriptions of incompatibilities

  • ‘requirements’ : dict - Required divisibility

Return type:

dict

Examples

sg = SpaceGroup('P21')
result = sg.check_grid_compatibility((131, 163, 148))
print(result['compatible'])  # False
print(result['issues'])  # ['ny=163 not divisible by 2']
suggest_grid_size(min_grid_shape, make_fft_friendly=True)[source]

Suggest an optimal grid size that satisfies symmetry requirements.

Parameters:
  • min_grid_shape (tuple of int) – Minimum (nx, ny, nz) grid dimensions.

  • make_fft_friendly (bool, default True) – If True, ensures result has only factors of 2, 3, 5.

Returns:

Suggested grid dimensions (nx, ny, nz).

Return type:

tuple of int

Examples

sg = SpaceGroup('P21')
suggested = sg.suggest_grid_size((131, 163, 148))
print(suggested)  # (135, 164, 150) or similar
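
A simple search that meets both constraints is to round each dimension up to the required multiple and keep stepping until only the factors 2, 3 and 5 remain. The sketch below implements that strategy with hypothetical names and takes the per-axis divisibility as a plain tuple. For the input above with divisibility (1, 2, 1) it returns (135, 180, 150); the library's own search may settle on a different triple.

```python
def is_fft_friendly(n):
    """True if n has no prime factors other than 2, 3 and 5."""
    for prime in (2, 3, 5):
        while n % prime == 0:
            n //= prime
    return n == 1

def suggest_grid(min_grid_shape, mods, make_fft_friendly=True):
    """Smallest size per axis that is >= the minimum and meets both rules.

    Assumes each entry of mods is itself built from the factors 2, 3, 5
    (true for 1, 2, 3, 4, 6); otherwise the search would not terminate.
    """
    suggested = []
    for n_min, mod in zip(min_grid_shape, mods):
        n = -(-n_min // mod) * mod          # round up to a multiple of mod
        while make_fft_friendly and not is_fft_friendly(n):
            n += mod
        suggested.append(n)
    return tuple(suggested)

print(suggest_grid((131, 163, 148), (1, 2, 1)))  # (135, 180, 150)
```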
__hash__()[source]

Hash based on space group number.

__eq__(other)[source]

Equality based on space group number.

copy()[source]

Create a deep copy of this SpaceGroup.

Returns:

A new SpaceGroup instance with cloned buffers.

Return type:

SpaceGroup

class torchref.Map(data, model, gridsize=None, map_type='2mFo-DFc', device=None)[source]

Bases: DeviceMixin

Crystallographic electron density map.

Parameters:
  • data (ReflectionData) – Observed reflection data with amplitudes, hkl, cell, and spacegroup.

  • model (ModelFT) – Model for computing Fcalc (structure factors).

  • gridsize (tuple of int, optional) – Grid dimensions (nx, ny, nz). If None, determined automatically from cell parameters and resolution.

  • map_type (str, optional) – Type of map to compute. One of "2mFo-DFc" or "Fcalc". Default is "2mFo-DFc".

VALID_MAP_TYPES = ('2mFo-DFc', 'Fcalc')
__init__(data, model, gridsize=None, map_type='2mFo-DFc', device=None)[source]
reset_cache()[source]

Invalidate the cached map tensor; recomputed on next access.

property map_data: Tensor | None

The computed 3D real-space map, or None if not yet calculated.

calculate()[source]

Compute the electron density map.

Returns:

3D real-space map tensor.

Return type:

torch.Tensor

write(filepath)[source]

Write the map to a CCP4 file.

Automatically computes the map if it hasn’t been calculated yet.

Parameters:

filepath (str) – Output CCP4 map file path.

Returns:

1 on success.

Return type:

int

class torchref.DifferenceMap(data, data_reference, model, gridsize=None, device=None)[source]

Bases: Map

Isomorphous difference map between two datasets.

Scales both datasets to a common reference using DatasetCollection, then computes difference Fourier coefficients: DF * exp(i * phi_calc) where DF = F_data - F_reference.

Parameters:
  • data (ReflectionData) – Reflection data for the perturbed state (e.g., light, derivative).

  • data_reference (ReflectionData) – Reflection data for the reference state (e.g., dark, native).

  • model (ModelFT) – Model for computing phases.

  • gridsize (tuple of int, optional) – Grid dimensions (nx, ny, nz). If None, determined automatically.

__init__(data, data_reference, model, gridsize=None, device=None)[source]
calculate()[source]

Compute the isomorphous difference map.

Returns:

3D real-space difference map tensor.

Return type:

torch.Tensor
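
The difference coefficients described for this class reduce to one complex number per reflection. The sketch below computes DF * exp(i * phi_calc) for plain Python lists, assuming the two amplitude sets are already on a common scale and share the same reflection order; the function name is hypothetical.

```python
import cmath
import math

def difference_coefficients(f_data, f_reference, phi_calc):
    """(F_data - F_reference) * exp(i * phi_calc), one value per reflection."""
    return [(fd - fr) * cmath.exp(1j * phi)
            for fd, fr, phi in zip(f_data, f_reference, phi_calc)]

coeffs = difference_coefficients(
    [120.0, 80.0],          # perturbed-state amplitudes
    [100.0, 100.0],         # reference-state amplitudes
    [math.pi / 2, 0.0],     # model phases in radians
)
print(coeffs)
```

A negative DF keeps the model phase but flips the sign of the coefficient, which is what produces negative density in the resulting map.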

class torchref.DeviceMixin[source]

Bases: object

Unified device/dtype movement.

Inherit alongside nn.Module (place before nn.Module in the MRO):

class Foo(DeviceMixin, nn.Module):
    ...

Or use on a plain Python class / dataclass:

@dataclass
class Bar(DeviceMixin):
    data: torch.Tensor

All of .to(), .cuda(), .cpu(), .float(), .double(), .half() route through _apply(), which:

  1. invokes nn.Module._apply when applicable so parameters, buffers and child modules are moved by the standard PyTorch path,

  2. walks self.__dict__ to pick up plain tensor attributes, nested containers and non-Module sub-objects,

  3. calls reset_forward_cache() and reset_cache() if either is defined.

to(*args, **kwargs)[source]
cuda(device=None)[source]
cpu()[source]

Subpackages

Submodules