"""
Reflection data container for crystallographic datasets.
This module provides the ReflectionData class for handling single-crystal
reflection data including Miller indices, structure factor amplitudes,
intensities, and R-free flags.
"""
import warnings
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from torchref.config import dtypes, get_default_device
from torch.nn import Parameter
from torchref.symmetry import SpaceGroup
import numpy as np
import pandas as pd
import torch
from torchref.io import cif, mtz
from torchref.io.datasets.base import CrystalDataset
from torchref.base import math_torch
from torchref.base.french_wilson import FrenchWilson
from torchref.symmetry import Cell
from torchref.utils.debug_utils import DebugMixin
if TYPE_CHECKING:
from torchref.model.model_ft import ModelFT
# Silence PyTorch's MaskedTensor "prototype stage" UserWarning process-wide
# (warnings.filterwarnings mutates the global filter list). MaskedTensor is
# considered stable enough for the aggregations and element-wise ops used here.
warnings.filterwarnings(
    action="ignore",
    message=".*MaskedTensors is in prototype stage.*",
    category=UserWarning,
)
if TYPE_CHECKING:
from torch.masked import MaskedTensor
[docs]
@dataclass
class ReflectionData(CrystalDataset, DebugMixin):
"""
Container for crystallographic reflection data.

This class handles loading, processing, and accessing reflection data
including Miller indices, structure factor amplitudes, intensities,
and R-free flags. All data is stored as PyTorch tensors so it can be
placed on a GPU.

Instances are populated with ``load_mtz``, ``load_cif`` or ``load`` (given
a reader), or built with the ``from_tensors`` classmethod. After loading,
HKL is canonicalized to the CCP4 asymmetric unit and every per-reflection
tensor is reordered to match.

Parameters
----------
verbose : int, optional
    Verbosity level for logging (0=silent, 1=normal, 2=debug). Default is 1.
device : str, optional
    Device to store tensors on ('cpu', 'cuda', 'cuda:0', etc.). Defaults to
    the configured default device.

Attributes
----------
hkl : torch.Tensor
    Miller indices of shape (N, 3), integer dtype (``dtypes.int``).
F : torch.Tensor
    Structure factor amplitudes of shape (N,), float dtype (``dtypes.float``).
    When only intensities are read, F is derived by French-Wilson.
F_sigma : torch.Tensor
    Amplitude uncertainties of shape (N,). Estimated from F when the input
    provides no sigmas.
I : torch.Tensor
    Intensities of shape (N,), present only when the input had intensities.
I_sigma : torch.Tensor
    Intensity uncertainties of shape (N,).
rfree_flags : torch.Tensor
    R-free flags of shape (N,). Bool when read by ``load`` or passed to
    ``from_tensors``; integer (0 = free, 1 = work) when generated by
    ``_generate_rfree_flags``. In both cases a zero/False entry marks the
    free (test) set (see ``get_rfree_masks``).
cell : Cell
    Unit cell object. ``cell.data`` holds the parameter tensor passed to the
    scattering-vector and French-Wilson routines.
spacegroup : SpaceGroup
    Space group object.
resolution : torch.Tensor
    d-spacing per reflection in Ångströms, shape (N,).
wilson_b : float
    Wilson B-factor in Ų as set by ``_calculate_wilson_b`` (equal to
    ``wilson_b_structure``).

Examples
--------
Load reflection data from an MTZ file::

    data = ReflectionData(verbose=1, device='cuda')
    data.load_mtz('data.mtz')
    print(f"Loaded {len(data.hkl)} reflections")
    # resolution.min() is the high-resolution limit
    print(f"Resolution range: {data.resolution.min():.2f} - {data.resolution.max():.2f} Å")
"""
# Additional fields specific to ReflectionData (beyond CrystalDataset).
# Most fields (hkl, F, masks, ...) are inherited from the CrystalDataset dataclass.
# Cached values (excluded from repr).
# NOTE(review): ``_centric`` is not read or written in this part of the file;
# presumably a lazily computed per-reflection centric flag — confirm.
_centric: Optional[torch.Tensor] = field(default=None, repr=False)
# Number of resolution bins produced by the most recent get_bins() call.
_n_bins: Optional[int] = field(default=None, repr=False)
# French-Wilson converter created by load() when intensities are read.
_FrenchWilson: Optional[FrenchWilson] = field(default=None, repr=False)
# Dynamic fields used by various methods.
# NOTE(review): ``source`` is not used in this part of the file; presumably
# the ReflectionData this instance was derived from — confirm.
source: Optional["ReflectionData"] = field(default=None, repr=False)
# Optional per-reflection table; _canonicalize_in_place() reorders its rows
# with the same permutation as the tensors.
dataset: Optional[pd.DataFrame] = field(default=None, repr=False)
# NOTE(review): not used in this part of the file; presumably names the last
# operation applied — confirm.
last_op: Optional[str] = field(default=None, repr=False)
# Reader object kept by load_cif() (load_mtz() does not store its reader).
reader: Optional[Any] = field(default=None, repr=False)
[docs]
def __post_init__(self):
    """
    Complete construction after the generated dataclass ``__init__``.

    The parent ``__post_init__`` runs first (it initializes the masks).
    It is followed by ``setup_scale()`` and then ``setup_anisotropy()``.
    """
    super().__post_init__()
    for setup_name in ("setup_scale", "setup_anisotropy"):
        getattr(self, setup_name)()
def _canonicalize_in_place(self) -> None:
    """
    Remap HKL to canonical CCP4 ASU form and reorder all data in-place.

    ``canonicalize_hkl`` is called with ``include_friedel=True``. It supplies:

    - the canonical indices;
    - per-reflection phase shifts;
    - Friedel flags;
    - a permutation ``sort_indices``.

    That permutation is applied to every tensor dataclass field whose first
    dimension equals the reflection count. It is also applied to every mask
    in ``self.masks`` and to the optional ``dataset`` DataFrame, so all
    per-reflection data stays aligned with the new HKL order. Phases, if
    present, are corrected. Resolution is recomputed when a cell is set.

    Returns immediately, changing nothing, when ``hkl`` or ``spacegroup``
    is unset.
    """
    from dataclasses import fields as dc_fields
    from torchref.symmetry.reciprocal_symmetry import canonicalize_hkl
    if self.hkl is None or self.spacegroup is None:
        return
    canonical_hkl, phase_shifts, friedel_flags, sort_indices = canonicalize_hkl(
        self.hkl, self.spacegroup, include_friedel=True, device=self.device
    )
    n_refl = len(self.hkl)
    # Reorder every per-reflection tensor field in-place. Any tensor field
    # whose leading dimension equals n_refl is treated as per-reflection.
    # This includes hkl itself, which is overwritten just below.
    for f in dc_fields(self):
        val = getattr(self, f.name)
        if isinstance(val, torch.Tensor) and val.shape and val.shape[0] == n_refl:
            setattr(self, f.name, val[sort_indices])
    # Overwrite HKL with canonical form.
    # NOTE(review): assumes canonical_hkl is already in sort_indices order — confirm.
    self.hkl = canonical_hkl
    # Phase correction: negate the phase of Friedel-flipped reflections, then
    # add the symmetry phase shift.
    # NOTE(review): assumes phase_shifts/friedel_flags are in the reordered
    # order and phase_shifts uses the same units as phase — confirm.
    if self.phase is not None:
        self.phase = torch.where(friedel_flags, -self.phase, self.phase) + phase_shifts
    # Recalculate resolution from canonical HKL
    if self.cell is not None:
        self._calculate_resolution()
    # Reorder masks
    if hasattr(self, "masks") and self.masks is not None:
        for name in list(self.masks.keys()):
            mask_tensor = self.masks[name]
            if mask_tensor is not None:
                # masks is a dict subclass: dict.__setitem__ bypasses its own
                # __setitem__ validation (reordering preserves the True count)
                dict.__setitem__(self.masks, name, mask_tensor[sort_indices])
        # Mark the container as modified; presumably invalidates a cached
        # combined mask — confirm
        self.masks._updated = True
    # Reorder DataFrame (assumes its rows follow the tensor order — confirm)
    if hasattr(self, "dataset") and self.dataset is not None:
        self.dataset = self.dataset.iloc[sort_indices.cpu().numpy()].copy()
[docs]
def load(self, reader):
    """
    Load reflection data using a data reader.

    Steps:

    1. Read HKL, the unit cell and the space group, then compute resolution.
    2. Read the experimental data:
       - If intensities ("I") are present they are used; amplitudes and
         sigmas are derived with French-Wilson.
       - Otherwise amplitudes ("F") are used. Sigmas come from "SIGF" when
         it is present and not None, else they are estimated from F.
    3. Read R-free flags from "R-free-flags"; negative or NaN (missing)
       values are excluded through the ``"flagged_initial"`` mask. When the
       column is absent, flags are generated (2 % free, resolution binned).
    4. Run ``_post_load_cleanup`` (canonicalization, sanitization, ...).

    Parameters
    ----------
    reader : callable
        Data reader object that returns (data_dict, cell, spacegroup) when called.
        Can be MTZ, ReflectionCIFReader, or other compatible reader.

    Returns
    -------
    ReflectionData
        Self, for method chaining.

    Raises
    ------
    ValueError
        If unit cell parameters are missing or no amplitude/intensity data found.
    """
    data_dict, cell, spacegroup = reader()

    def as_float(values):
        # All measured per-reflection quantities share the configured float dtype/device.
        return torch.tensor(
            values, dtype=dtypes.float, device=self.device, requires_grad=False
        )

    self.hkl = torch.tensor(
        data_dict["HKL"], dtype=dtypes.int, device=self.device, requires_grad=False
    )
    if cell is None:
        raise ValueError(
            "Unit cell parameters are required in the data and could not be read."
        )
    self.cell = Cell(cell, dtype=dtypes.float, device=self.device)
    if spacegroup is not None:
        self.spacegroup = SpaceGroup(spacegroup)
    self._calculate_resolution()

    if "I" in data_dict:
        self.I = as_float(data_dict["I"])
        # A SIGI column that is present but None is treated like a missing one
        # (mirrors the SIGF handling); reset so a previous load cannot leak in.
        sig_i = data_dict.get("SIGI")
        self.I_sigma = as_float(sig_i) if sig_i is not None else None
        self.intensity_source = data_dict.get("I_col", "Unknown")
        self._FrenchWilson = FrenchWilson(
            self.hkl, self.cell.data, self.spacegroup, verbose=self.verbose
        )
        self.F, self.F_sigma = self._FrenchWilson(self.I, self.I_sigma)
    elif "F" in data_dict:
        self.F = as_float(data_dict["F"])
        sig_f = data_dict.get("SIGF")
        if sig_f is not None:
            self.F_sigma = as_float(sig_f)
        else:
            self.F_sigma = math_torch.estimate_sigma_F(self.F)
        self.amplitude_source = data_dict.get("F_col", "Unknown")
    else:
        # Readers other than MTZ (e.g. CIF) also go through here.
        raise ValueError("No amplitude or intensity data found in reflection data")

    if "R-free-flags" in data_dict:
        rfree = torch.tensor(
            data_dict["R-free-flags"], device=self.device, requires_grad=False
        )
        # Negative flags mark unusable reflections. NaN (a missing flag in
        # float columns) would otherwise become True (work) after .to(bool).
        flagged = rfree < 0
        if rfree.is_floating_point():
            flagged |= torch.isnan(rfree)
        # 0 -> False (free set), any positive value -> True (work set)
        self.rfree_flags = rfree.clip(min=0, max=1).to(torch.bool)
        self.masks["flagged_initial"] = ~flagged
    else:
        self.masks["flagged_initial"] = torch.ones(
            len(self.hkl), dtype=torch.bool, device=self.device
        )
        self._generate_rfree_flags(free_fraction=0.02, n_bins=20, min_per_bin=100)
    self._post_load_cleanup()
    return self
def _post_load_cleanup(self) -> "ReflectionData":
"""
Run all post-load processing steps on the reflection data.
This is called automatically after ``load()`` and ``from_tensors()``.
It performs:
1. Resolution calculation from HKL + cell
2. Initial flagging mask (marks all reflections as valid if not set)
3. Canonicalization of HKL to CCP4 ASU
4. Sanitization of F (mask NaN/Inf/non-positive)
5. Suspicious sigma detection
Returns
-------
ReflectionData
Self, for method chaining.
"""
if self.resolution is None:
self._calculate_resolution()
if "flagged_initial" not in self.masks:
self.masks["flagged_initial"] = torch.ones(
len(self.hkl), dtype=torch.bool, device=self.device
)
self._canonicalize_in_place()
self.sanitize_F()
self.flag_suspicious_sigma()
return self
[docs]
@classmethod
def from_tensors(
    cls,
    hkl: torch.Tensor,
    F: torch.Tensor,
    F_sigma: torch.Tensor,
    cell: "Cell",
    spacegroup: "SpaceGroup",
    rfree_flags: Optional[torch.Tensor] = None,
    device=None,
    verbose: int = 1,
) -> "ReflectionData":
    """
    Construct ReflectionData directly from tensors.

    Inputs are moved to ``device`` and cast to the same dtypes ``load`` uses
    (``dtypes.int`` for HKL, ``dtypes.float`` for F and sigma). The
    ``"flagged_initial"`` mask marks every reflection valid, and the usual
    post-load cleanup is applied.

    Parameters
    ----------
    hkl : torch.Tensor
        Miller indices of shape (N, 3).
    F : torch.Tensor
        Structure factor amplitudes of shape (N,).
    F_sigma : torch.Tensor
        Amplitude uncertainties of shape (N,).
    cell : Cell
        Unit cell (a ``Cell``-like object with ``.to``) or raw parameters
        accepted by ``Cell``.
    spacegroup : SpaceGroup
        Space group object or a value accepted by ``SpaceGroup``.
    rfree_flags : torch.Tensor, optional
        R-free flags of shape (N,), converted to bool. If None, flags are
        generated automatically (2% free fraction).
    device : str, optional
        Device for tensors. Defaults to the configured default device, looked
        up at call time.
    verbose : int, optional
        Verbosity level. Default is 1.

    Returns
    -------
    ReflectionData
        Fully initialized reflection data with all cleanup applied.
    """
    if device is None:
        # Resolve here: a default in the signature would be frozen at import time.
        device = get_default_device()
    data = cls(device=device, verbose=verbose)
    data.hkl = hkl.to(device=data.device, dtype=dtypes.int)
    data.F = F.to(device=data.device, dtype=dtypes.float)
    data.F_sigma = F_sigma.to(device=data.device, dtype=dtypes.float)
    if hasattr(cell, "to"):
        data.cell = cell.to(device=data.device)
    else:
        data.cell = Cell(cell, dtype=dtypes.float, device=data.device)
    data.spacegroup = (
        spacegroup if isinstance(spacegroup, SpaceGroup) else SpaceGroup(spacegroup)
    )
    # Same initial mask load() sets; get_bins() (used for flag generation)
    # reads the combined mask.
    data.masks["flagged_initial"] = torch.ones(
        len(data.hkl), dtype=torch.bool, device=data.device
    )
    if rfree_flags is not None:
        data.rfree_flags = rfree_flags.to(device=data.device, dtype=torch.bool)
    else:
        data._calculate_resolution()
        data._generate_rfree_flags(free_fraction=0.02, n_bins=20, min_per_bin=100)
    data._post_load_cleanup()
    return data
[docs]
def load_mtz(
    self, path: str, column_names: Optional[dict] = None
) -> "ReflectionData":
    """
    Read an MTZ file and load its reflections into this object.

    Parameters
    ----------
    path : str
        Location of the MTZ file.
    column_names : dict, optional
        Column-name overrides handed to ``MTZReader``; when omitted the
        reader detects the columns itself. Recognised keys are ``"F"``,
        ``"SIGF"``, ``"I"`` and ``"SIGI"``, e.g.
        ``{"F": "DFo", "SIGF": "sig_DFo"}``.

    Returns
    -------
    ReflectionData
        ``self``, allowing chained calls.
    """
    mtz_reader = mtz.MTZReader(verbose=self.verbose, column_names=column_names)
    parsed = mtz_reader.read(path)
    return self.load(parsed)
[docs]
def load_cif(self, path: str, data_block: Optional[str] = None) -> "ReflectionData":
    """
    Read a CIF structure-factor file and load its reflections.

    The ``ReflectionCIFReader`` is kept on ``self.reader`` before loading.

    Parameters
    ----------
    path : str
        Location of the CIF file.
    data_block : str, optional
        Name of the data block to read (for example ``'r1vlmsf'``). When
        None the first block is used. This matters for files holding
        several datasets.

    Returns
    -------
    ReflectionData
        ``self``, allowing chained calls.
    """
    cif_reader = cif.ReflectionCIFReader(
        path, verbose=self.verbose, data_block=data_block
    )
    self.reader = cif_reader
    return self.load(self.reader)
[docs]
@staticmethod
def list_cif_data_blocks(path: str) -> List[str]:
    """
    Return the names of all data blocks in a CIF file, without loading data.

    Handy for multi-dataset CIF files: inspect the blocks first, then pass
    one of the names to ``load_cif(..., data_block=...)``.

    Parameters
    ----------
    path : str
        Location of the CIF file.

    Returns
    -------
    list of str
        Data block names in the file.

    Examples
    --------
    Pick a specific data block::

        blocks = ReflectionData.list_cif_data_blocks('1VLM-sf.cif')
        print(blocks)  # ['r1vlmsf', 'r1vlmAsf', 'r1vlmBsf', ...]
        data = ReflectionData().load_cif('1VLM-sf.cif', data_block=blocks[1])
    """
    block_names = cif.list_data_blocks(path)
    return block_names
def _generate_rfree_flags(
    self,
    free_fraction: float = 0.02,
    n_bins: int = 20,
    min_per_bin: int = 100,
    seed: Optional[int] = None,
) -> None:
    """
    Generate R-free flags with resolution-stratified sampling.

    Reflections are binned by resolution with ``get_bins``. In every
    non-empty bin, ``max(1, int(bin_size * free_fraction))`` randomly
    chosen reflections are marked free. This keeps the test set evenly
    spread over all resolution shells for unbiased cross-validation.

    Parameters
    ----------
    free_fraction : float, optional
        Fraction of reflections to mark as free (0.02 = 2%). Default is 0.02.
    n_bins : int, optional
        Target number of resolution bins. Default is 20.
    min_per_bin : int, optional
        Minimum reflections per resolution bin. Default is 100.
    seed : int, optional
        Seed for a private ``torch.Generator``; the same seed gives the same
        flags. The global NumPy/PyTorch RNG state is left untouched. Default
        is None, which uses the global PyTorch RNG.

    Notes
    -----
    Sets ``rfree_flags`` to an integer tensor (``dtypes.int``) where 0 marks
    a free (test) reflection and 1 a work reflection, and sets
    ``rfree_source``. Progress is printed only when ``verbose > 0``.

    Raises
    ------
    ValueError
        If resolution information is not available.
    """
    if self.resolution is None:
        raise ValueError("Resolution information required to generate R-free flags")
    report = self.verbose > 0
    if report:
        print("Generating R-free flags:")
        print(f" Target free fraction: {free_fraction*100:.1f}%")
        print(f" Target bins: {n_bins}")
        print(f" Minimum per bin: {min_per_bin} reflections")
    # A private CPU generator produces the same stream torch.manual_seed(seed)
    # would, without reseeding the process-wide RNGs.
    generator = None
    if seed is not None:
        generator = torch.Generator()
        generator.manual_seed(seed)
    n_refl = len(self.resolution)
    bin_indices, actual_n_bins = self.get_bins(
        n_bins=n_bins, min_per_bin=min_per_bin
    )
    if report:
        print(f" Created {actual_n_bins} resolution bins")
    # Start with every reflection in the work set (1)
    flags = torch.ones(n_refl, dtype=dtypes.int, device=self.device)
    for bin_idx in range(actual_n_bins):
        members = torch.nonzero(bin_indices == bin_idx, as_tuple=True)[0]
        bin_size = members.numel()
        if bin_size == 0:
            continue
        # At least one free reflection per populated bin
        n_free_in_bin = max(1, int(bin_size * free_fraction))
        picks = torch.randperm(bin_size, generator=generator)[:n_free_in_bin]
        # randperm runs on the CPU; move the picks next to the member indices
        flags[members[picks.to(members.device)]] = 0
    self.rfree_flags = flags
    self.rfree_source = "Generated (resolution-binned)"
    if report:
        n_free = int((flags == 0).sum().item())
        n_work = n_refl - n_free
        # Guard the percentage against an empty dataset
        free_pct = 100.0 * n_free / n_refl if n_refl else 0.0
        print(
            f" ✓ Generated flags: {n_free} free ({free_pct:.1f}%), {n_work} work ({100-free_pct:.1f}%)"
        )
        print(" Flags are resolution-binned for unbiased validation")
[docs]
def get_bins(
    self, n_bins: int = 20, min_per_bin: int = 100
) -> Tuple[torch.Tensor, int]:
    """
    Create resolution bins with approximately equal numbers of valid reflections.

    Reflections are sorted by d-spacing in ascending order, so bin 0 holds
    the highest-resolution data. They are then split so that every bin
    except the last contains exactly ``total_valid // n_bins`` valid
    reflections, where validity comes from ``self.masks()``; the last bin
    also takes the remainder. Invalid reflections still receive the bin
    index of their position in the sorted order. The result is also stored
    on ``self.bin_indices`` and ``self._n_bins``.

    Parameters
    ----------
    n_bins : int, optional
        Target number of resolution bins (>= 1). Default is 20.
    min_per_bin : int, optional
        Minimum valid reflections per bin (>= 1). Small datasets get fewer
        bins so that this minimum holds. Default is 100.

    Returns
    -------
    bin_indices : torch.Tensor
        Tensor of shape (N,) with bin index for each reflection.
    n_bins : int
        Actual number of bins created (may be less than target for small datasets).

    Raises
    ------
    ValueError
        If ``n_bins`` or ``min_per_bin`` is smaller than 1.
    """
    if n_bins < 1 or min_per_bin < 1:
        raise ValueError(
            f"n_bins and min_per_bin must be >= 1 (got {n_bins}, {min_per_bin})"
        )
    n_refl = len(self.resolution)
    valid_mask = self.masks()
    total_valid = int(valid_mask.sum().item())
    # How many bins the min_per_bin constraint allows
    max_possible_bins = max(1, total_valid // min_per_bin)
    actual_n_bins = min(n_bins, max_possible_bins)
    if actual_n_bins < n_bins and self.verbose > 0:
        print(
            f" Note: Adjusted bins from {n_bins} to {actual_n_bins} (min {min_per_bin} refl/bin)"
        )
    _, sort_indices = torch.sort(self.resolution)
    reflections_per_bin = total_valid // actual_n_bins
    # Running count of valid reflections in resolution order
    cumsum_valid = torch.cumsum(valid_mask[sort_indices].to(dtypes.int), dim=0)
    # Bin k starts right after the reflection that brings the running count to
    # k * reflections_per_bin. searchsorted(right=True) finds the first position
    # whose count EXCEEDS the target, so each preceding bin receives exactly
    # reflections_per_bin valid reflections. The first position where the count
    # reaches the target would leave one reflection short.
    targets = (
        torch.arange(
            1, actual_n_bins, dtype=cumsum_valid.dtype, device=cumsum_valid.device
        )
        * reflections_per_bin
    )
    interior_edges = torch.searchsorted(cumsum_valid, targets, right=True)
    bin_edges = [0, *interior_edges.tolist(), n_refl]
    # Assign bin indices in sorted order, scattered back to the original order
    bin_indices = torch.zeros(n_refl, dtype=dtypes.int, device=self.device)
    for bin_idx, (start, end) in enumerate(zip(bin_edges[:-1], bin_edges[1:])):
        bin_indices[sort_indices[start:end]] = bin_idx
    if self.verbose > 1:
        print(" Resolution bins:")
        for bin_idx in range(min(actual_n_bins, 20)):  # show at most 20 bins
            bin_mask = bin_indices == bin_idx
            if bin_mask.any():
                n_valid = int((bin_mask & valid_mask).sum().item())
                bin_res = self.resolution[bin_mask]
                print(
                    f" Bin {bin_idx:2d}: {n_valid:6d} valid refl, "
                    f"resolution {bin_res.min():.2f}-{bin_res.max():.2f} Å"
                )
        if actual_n_bins > 20:
            print(f" ... ({actual_n_bins - 20} more bins)")
    self.bin_indices = bin_indices
    self._n_bins = actual_n_bins
    return bin_indices, actual_n_bins
[docs]
def mean_res_per_bin(self) -> torch.Tensor:
"""
Calculate mean resolution for each bin.
Returns
-------
torch.Tensor
Mean resolution for each bin in Ångströms.
Raises
------
ValueError
If bins have not been created yet.
"""
if self.bin_indices is None or self.resolution is None:
raise ValueError("Bins have not been created yet")
mean_resolutions = torch.zeros(
self._n_bins, dtype=dtypes.float, device=self.device
)
count_per_bin = torch.zeros(self._n_bins, dtype=dtypes.int, device=self.device)
mask = self.masks()
mean_resolutions = torch.scatter_add(
mean_resolutions,
0,
self.bin_indices[mask].to(torch.int64),
self.resolution[mask],
)
count_per_bin = torch.scatter_add(
count_per_bin,
0,
self.bin_indices[mask].to(torch.int64),
torch.ones_like(self.resolution[mask], dtype=dtypes.int),
)
mean_resolutions = mean_resolutions / count_per_bin.clamp(min=1).float()
return mean_resolutions
[docs]
def mean_F_per_bin(self) -> torch.Tensor:
"""
Calculate mean structure factor amplitude per resolution bin.
Returns
-------
torch.Tensor
Mean F per bin of shape (n_bins,).
Raises
------
ValueError
If bins have not been created yet.
"""
if self.bin_indices is None:
self.get_bins()
if self.F is None:
raise ValueError("No amplitude data loaded")
mean_F = torch.zeros(self._n_bins, dtype=dtypes.float, device=self.device)
count_per_bin = torch.zeros(self._n_bins, dtype=dtypes.int, device=self.device)
mask = self.masks()
mean_F = torch.scatter_add(
mean_F, 0, self.bin_indices[mask].to(torch.int64), self.F[mask]
)
count_per_bin = torch.scatter_add(
count_per_bin,
0,
self.bin_indices[mask].to(torch.int64),
torch.ones_like(self.F[mask], dtype=dtypes.int),
)
mean_F = mean_F / count_per_bin.clamp(min=1).float()
return mean_F
[docs]
def mean_sigma_per_bin(self) -> Optional[torch.Tensor]:
"""
Calculate mean structure factor uncertainty per resolution bin.
Returns
-------
torch.Tensor or None
Mean sigma_F per bin of shape (n_bins,), or None if no uncertainties.
Raises
------
ValueError
If bins have not been created yet.
"""
if self.bin_indices is None:
self.get_bins()
if self.F is None:
raise ValueError("No amplitude data loaded")
mean_sigma = torch.zeros(self._n_bins, dtype=dtypes.float, device=self.device)
count_per_bin = torch.zeros(self._n_bins, dtype=dtypes.int, device=self.device)
mask = self.masks()
mean_sigma = torch.scatter_add(
mean_sigma, 0, self.bin_indices[mask].to(torch.int64), self.F_sigma[mask]
)
count_per_bin = torch.scatter_add(
count_per_bin,
0,
self.bin_indices[mask].to(torch.int64),
torch.ones_like(self.F_sigma[mask], dtype=dtypes.int),
)
mean_sigma = mean_sigma / count_per_bin.clamp(min=1).float()
return mean_sigma
[docs]
def regenerate_rfree_flags(
self,
free_fraction: float = 0.02,
n_bins: int = 20,
min_per_bin: int = 100,
seed: Optional[int] = None,
force: bool = False,
) -> None:
"""
Regenerate R-free flags with resolution-stratified sampling.
Parameters
----------
free_fraction : float, optional
Fraction of reflections to mark as free. Default is 0.02 (2%).
n_bins : int, optional
Target number of resolution bins. Default is 20.
min_per_bin : int, optional
Minimum reflections per resolution bin. Default is 100.
seed : int, optional
Random seed for reproducibility. Default is None.
force : bool, optional
If True, regenerate even if flags already exist. Default is False.
Examples
--------
Generate R-free flags::
# Generate 2% free reflections with reproducible seed
data.regenerate_rfree_flags(free_fraction=0.02, n_bins=20, seed=42)
# Generate 5% free with 10 bins, overwriting existing
data.regenerate_rfree_flags(free_fraction=0.05, n_bins=10, force=True)
"""
if self.rfree_flags is not None and not force:
print("⚠️ WARNING: R-free flags already exist!")
print(f" Current source: {self.rfree_source}")
print(" Use force=True to overwrite existing flags")
return
if self.rfree_flags is not None and force:
print("⚠️ WARNING: Overwriting existing R-free flags")
print(f" Old source: {self.rfree_source}")
self._generate_rfree_flags(
free_fraction=free_fraction,
n_bins=n_bins,
min_per_bin=min_per_bin,
seed=seed,
)
def _calculate_resolution(self) -> None:
    """
    Fill ``self.resolution`` with the d-spacing (Å) of every reflection.

    The d-spacing is the reciprocal of the length of each scattering vector
    returned by ``math_torch.get_scattering_vectors`` for ``hkl`` and the
    cell parameters.

    Raises
    ------
    ValueError
        If ``hkl`` or ``cell`` has not been set (``hkl`` is checked first).
    """
    if self.hkl is None:
        raise ValueError(
            "Miller indices (hkl) are required to calculate resolution"
        )
    if self.cell is None:
        raise ValueError(
            "Unit cell parameters are required to calculate resolution"
        )
    scattering_vectors = math_torch.get_scattering_vectors(self.hkl, self.cell.data)
    vector_length = torch.linalg.norm(scattering_vectors, dim=1)
    self.resolution = 1.0 / vector_length
def _calculate_wilson_b(self, n_bins: int = 30) -> None:
"""
Calculate Wilson B-factors from structure factor amplitudes.
Fits a two-component model separating structure and solvent contributions:
<F²> ∝ A_struct * exp(-2*B_struct*s²) + A_sol * exp(-2*B_sol*s²)
The fitting proceeds in three stages:
1. Estimate B_structure from high-resolution data (d < 3.5 Å) where solvent is negligible
2. Estimate B_solvent from low-resolution data (d > 6 Å) where solvent dominates
3. Refine both together with two-exponential fit across all data
Args:
n_bins: Number of resolution bins for averaging (default: 30)
Sets:
self.wilson_b: Overall Wilson B-factor (weighted average) in Ų
self.wilson_b_structure: Structure B-factor from high-res in Ų
self.wilson_b_solvent: Solvent B-factor from low-res in Ų
self.wilson_k_sol: Relative solvent contribution (0-1)
"""
if self.F is None or self.resolution is None:
return
# Get valid reflections
F = self.F
d = self.resolution
valid = torch.isfinite(F) & (F > 0) & torch.isfinite(d)
if valid.sum() < 100:
if self.verbose > 0:
print(
f" Wilson B: insufficient data ({valid.sum()} reflections), skipping"
)
return
F_valid = F[valid]
d_valid = d[valid]
# Calculate s² = 1/(4d²)
s_sq = 1.0 / (4.0 * d_valid**2)
F_sq = F_valid**2
# Bin the data for noise reduction
s_sq_min, s_sq_max = s_sq.min(), s_sq.max()
bin_edges = torch.linspace(s_sq_min, s_sq_max, n_bins + 1, device=self.device)
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
bin_idx = torch.bucketize(s_sq, bin_edges[1:-1])
# Calculate mean F² per bin
bin_sums = torch.zeros(n_bins, device=self.device, dtype=F_sq.dtype)
bin_counts = torch.zeros(n_bins, device=self.device, dtype=F_sq.dtype)
bin_sums.scatter_add_(0, bin_idx, F_sq)
bin_counts.scatter_add_(0, bin_idx, torch.ones_like(F_sq))
valid_bins = bin_counts > 5
if valid_bins.sum() < 5:
if self.verbose > 0:
print(f" Wilson B: insufficient bins ({valid_bins.sum()}), skipping")
return
mean_F_sq = bin_sums[valid_bins] / bin_counts[valid_bins]
s_sq_bins = bin_centers[valid_bins]
# Convert s² back to d-spacing for resolution-based selection
d_bins = 1.0 / (2.0 * torch.sqrt(s_sq_bins))
# Stage 1: Fit high-resolution region (d < 3.5 Å) for structure B
high_res_mask = d_bins < 3.5
B_struct = self._fit_single_wilson(
s_sq_bins, mean_F_sq, high_res_mask, "high-res"
)
# Stage 2: Fit low-resolution region (d > 6 Å) for solvent B
low_res_mask = d_bins > 6.0
B_sol = self._fit_single_wilson(s_sq_bins, mean_F_sq, low_res_mask, "low-res")
# Stage 3: Two-component fit across all data
B_struct_final, B_sol_final, k_sol = self._fit_two_component_wilson(
s_sq_bins, mean_F_sq, B_struct, B_sol
)
# Store results
self.wilson_b_structure = B_struct_final
self.wilson_b_solvent = B_sol_final
self.wilson_k_sol = k_sol
# Overall Wilson B is the structure B (what people usually mean by "Wilson B")
self.wilson_b = B_struct_final
if self.verbose > 0:
print(f" Wilson B-factor (structure): {B_struct_final:.1f} Ų")
print(f" Wilson B-factor (solvent): {B_sol_final:.1f} Ų")
print(f" Solvent fraction (k_sol): {k_sol:.3f}")
def _fit_single_wilson(
self,
s_sq: torch.Tensor,
mean_F_sq: torch.Tensor,
mask: torch.Tensor,
label: str,
) -> float:
"""
Fit single-exponential Wilson plot to selected resolution range.
Args:
s_sq: s² values for bins
mean_F_sq: Mean F² values for bins
mask: Boolean mask selecting which bins to use
label: Label for error messages
Returns:
B-factor from fit (Ų)
"""
if mask.sum() < 3:
# Not enough data, return reasonable default
if self.verbose > 1:
print(
f" Wilson {label}: insufficient bins ({mask.sum()}), using default"
)
return 50.0 if "struct" in label else 200.0
x = s_sq[mask]
y = torch.log(mean_F_sq[mask])
# Linear regression: ln(F²) = const - 2B*s²
x_mean = x.mean()
y_mean = y.mean()
numerator = ((x - x_mean) * (y - y_mean)).sum()
denominator = ((x - x_mean) ** 2).sum()
if denominator < 1e-12:
return 50.0 if "struct" in label else 200.0
slope = numerator / denominator
B = -slope.item() / 2.0
# Sanity bounds
B = max(0.0, min(B, 300.0))
return B
def _fit_two_component_wilson(
self,
s_sq: torch.Tensor,
mean_F_sq: torch.Tensor,
B_struct_init: float,
B_sol_init: float,
n_iter: int = 50,
) -> Tuple[float, float, float]:
"""
Fit two-component Wilson model using iterative refinement.
Model: F² = A_struct * exp(-2*B_struct*s²) + A_sol * exp(-2*B_sol*s²)
Parameterized as:
F² = A * [(1-k)*exp(-2*B_struct*s²) + k*exp(-2*B_sol*s²)]
where k is the relative solvent contribution at s²=0.
Parameters
----------
s_sq : torch.Tensor
s² values for bins.
mean_F_sq : torch.Tensor
Mean F² values for bins.
B_struct_init : float
Initial structure B-factor.
B_sol_init : float
Initial solvent B-factor.
n_iter : int, optional
Number of refinement iterations. Default is 50.
Returns
-------
B_struct : float
Refined structure B-factor.
B_sol : float
Refined solvent B-factor.
k_sol : float
Relative solvent contribution.
"""
# Normalize F² for numerical stability
F_sq_max = mean_F_sq.max()
y = mean_F_sq / F_sq_max
x = s_sq
# Initialize parameters
B_struct = torch.tensor(B_struct_init, device=self.device, dtype=x.dtype)
B_sol = torch.tensor(B_sol_init, device=self.device, dtype=x.dtype)
# Estimate initial k from ratio of low-res to high-res decay
# At low resolution, solvent contributes more
d_from_s = 1.0 / (2.0 * torch.sqrt(x))
low_res_val = y[d_from_s > 5.0].mean() if (d_from_s > 5.0).any() else y[0]
high_res_val = y[d_from_s < 3.0].mean() if (d_from_s < 3.0).any() else y[-1]
# k estimates solvent fraction - if low res is much higher than expected
# from structure alone, there's solvent contribution
struct_decay = torch.exp(-2 * B_struct * x)
expected_low = (
struct_decay[d_from_s > 5.0].mean()
if (d_from_s > 5.0).any()
else struct_decay[0]
)
if expected_low > 1e-6 and low_res_val > expected_low:
k_init = min(0.5, (low_res_val - expected_low).item() / low_res_val.item())
else:
k_init = 0.1
k = torch.tensor(max(0.01, min(0.5, k_init)), device=self.device, dtype=x.dtype)
# Simple gradient descent refinement
lr = 0.1
for _ in range(n_iter):
# Compute model
struct_term = (1 - k) * torch.exp(-2 * B_struct * x)
sol_term = k * torch.exp(-2 * B_sol * x)
model = struct_term + sol_term
# Compute scale factor analytically
A = (y * model).sum() / (model * model).sum()
model_scaled = A * model
# Compute gradients (simplified, using finite differences for robustness)
eps = 0.1
# B_struct gradient
model_plus = A * ((1 - k) * torch.exp(-2 * (B_struct + eps) * x) + sol_term)
model_minus = A * (
(1 - k) * torch.exp(-2 * (B_struct - eps) * x) + sol_term
)
loss_plus = ((y - model_plus) ** 2).sum()
loss_minus = ((y - model_minus) ** 2).sum()
grad_B_struct = (loss_plus - loss_minus) / (2 * eps)
# B_sol gradient
model_plus = A * (struct_term + k * torch.exp(-2 * (B_sol + eps) * x))
model_minus = A * (struct_term + k * torch.exp(-2 * (B_sol - eps) * x))
loss_plus = ((y - model_plus) ** 2).sum()
loss_minus = ((y - model_minus) ** 2).sum()
grad_B_sol = (loss_plus - loss_minus) / (2 * eps)
# k gradient
eps_k = 0.01
k_plus = min(0.9, k + eps_k)
k_minus = max(0.01, k - eps_k)
model_plus = A * (
(1 - k_plus) * torch.exp(-2 * B_struct * x)
+ k_plus * torch.exp(-2 * B_sol * x)
)
model_minus = A * (
(1 - k_minus) * torch.exp(-2 * B_struct * x)
+ k_minus * torch.exp(-2 * B_sol * x)
)
loss_plus = ((y - model_plus) ** 2).sum()
loss_minus = ((y - model_minus) ** 2).sum()
grad_k = (loss_plus - loss_minus) / (2 * eps_k)
# Update parameters
B_struct = B_struct - lr * grad_B_struct
B_sol = B_sol - lr * grad_B_sol
k = k - lr * 0.1 * grad_k # Slower learning rate for k
# Enforce constraints
B_struct = torch.clamp(B_struct, 1.0, 200.0)
B_sol = torch.clamp(B_sol, 50.0, 500.0)
k = torch.clamp(k, 0.01, 0.9)
# Ensure B_sol > B_struct (solvent is more disordered)
if B_sol < B_struct + 20:
B_sol = B_struct + 20
return B_struct.item(), B_sol.item(), k.item()
[docs]
def get_structure_factors(self, as_complex: bool = False) -> torch.Tensor:
"""
Get structure factors, optionally as complex numbers.
Parameters
----------
as_complex : bool, optional
If True and phases available, return F*exp(i*phi). Default is False.
Returns
-------
torch.Tensor
Structure factor amplitudes or complex structure factors.
Raises
------
ValueError
If no amplitude data is loaded.
"""
if self.F is None:
raise ValueError("No amplitude data loaded")
if as_complex and self.phase is not None:
return self.F * torch.exp(1j * self.phase)
else:
return self.F
[docs]
def get_structure_factors_with_sigma(
self,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Get structure factor amplitudes and their uncertainties.
Returns
-------
F : torch.Tensor
Structure factor amplitudes of shape (N,).
F_sigma : torch.Tensor or None
Uncertainties of shape (N,), or None if not available.
Raises
------
ValueError
If no amplitude data is loaded.
Examples
--------
Get amplitudes with uncertainties::
F, sigma_F = data.get_structure_factors_with_sigma()
if sigma_F is not None:
weighted_residual = (F_obs - F_calc) / sigma_F
"""
if self.F is None:
raise ValueError("No amplitude data loaded")
return self.F, self.F_sigma
[docs]
def get_hkl(self):
"""
Return Miller indices for valid reflections.
Returns
-------
torch.Tensor
Miller indices of shape (N, 3), dtype int32.
Raises
------
ValueError
If no Miller indices are loaded.
"""
if self.hkl is None:
raise ValueError("No Miller indices loaded")
return self.hkl[self.masks()]
[docs]
def filter_by_resolution(
self, d_min: Optional[float] = None, d_max: Optional[float] = None
) -> "ReflectionData":
"""
Filter reflections by resolution range.
Adds a boolean mask to self.masks for the specified resolution range.
Parameters
----------
d_min : float, optional
Minimum resolution / high resolution cutoff (e.g., 1.5 Å).
d_max : float, optional
Maximum resolution / low resolution cutoff (e.g., 50.0 Å).
Returns
-------
ReflectionData
Self, for method chaining.
"""
if self.resolution is None:
self._calculate_resolution()
mask = torch.ones(len(self.hkl), dtype=torch.bool, device=self.device)
if d_min is not None:
mask &= self.resolution >= d_min
if d_max is not None:
mask &= self.resolution <= d_max
self.masks["resolution"] = mask
valid = self.masks().sum().item()
print(
f"Filtering: {mask.sum()}/{len(mask)} reflections in range "
f"[{d_max if d_max else 'inf'} - {d_min if d_min else 'inf'}] "
f"\u00c5 ({valid} valid after all masks)"
)
return self
[docs]
def get_mask(self):
    """
    Return combined mask from all active filters.

    Returns
    -------
    torch.Tensor
        Boolean mask combining all filter conditions (True = the reflection
        passes every registered mask). Same value as ``get_valid_mask()``.
    """
    # self.masks() evaluates the registered masks into one boolean tensor;
    # without this return the method silently yielded None.
    return self.masks()
[docs]
def cut_res(
    self, highres: Optional[float] = None, lowres: Optional[float] = None
) -> "ReflectionData":
    """
    Restrict the active reflections to a resolution shell.

    Thin alias for ``filter_by_resolution`` that uses the crystallographic
    "high/low resolution" naming instead of ``d_min``/``d_max``.

    Parameters
    ----------
    highres : float, optional
        High-resolution limit (small d-spacing, e.g. 1.5 Å), passed as
        ``d_min``: reflections with d >= highres are kept.
    lowres : float, optional
        Low-resolution limit (large d-spacing, e.g. 50.0 Å), passed as
        ``d_max``: reflections with d <= lowres are kept.

    Returns
    -------
    ReflectionData
        The result of ``filter_by_resolution`` (self, for chaining).

    Examples
    --------
    ::

        # Keep reflections between 50 Å and 1.5 Å
        filtered = data.cut_res(highres=1.5, lowres=50.0)
        # Keep only the 2.0-1.0 Å shell
        shell = data.cut_res(highres=1.0, lowres=2.0)
    """
    return self.filter_by_resolution(d_max=lowres, d_min=highres)
[docs]
def get_rfree_masks(self) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Split all reflections into work and test (free) sets.

    Returns
    -------
    work_mask : torch.Tensor or None
        True where the R-free flag is non-zero (work set).
    test_mask : torch.Tensor or None
        True where the R-free flag equals 0 (test/free set).
        Both are None when ``self.rfree_flags`` is None.

    Examples
    --------
    ::

        work_mask, test_mask = data.get_rfree_masks()
        if work_mask is not None:
            F_work = data.F[work_mask]
            F_test = data.F[test_mask]
    """
    flags = self.rfree_flags
    if flags is None:
        return None, None
    is_free = flags == 0
    # Work set is the complement of the free set.
    return ~is_free, is_free
[docs]
def get_work_set(self) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Return amplitudes and uncertainties of the work set (R-free flag != 0).

    Returns
    -------
    F_work : torch.Tensor or None
        ``self.F`` restricted to the work set (None if F is None).
    sigma_work : torch.Tensor or None
        ``self.F_sigma`` restricted to the work set (None if unavailable).

    Notes
    -----
    Without R-free flags a warning is printed and the unfiltered
    ``(self.F, self.F_sigma)`` is returned.
    """
    flags = self.rfree_flags
    if flags is not None:
        in_work = flags != 0
        amplitudes = None if self.F is None else self.F[in_work]
        sigmas = None if self.F_sigma is None else self.F_sigma[in_work]
        return amplitudes, sigmas
    print("WARNING: No R-free flags available, returning full dataset")
    return self.F, self.F_sigma
[docs]
def get_test_set(self) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Return amplitudes and uncertainties of the test set (R-free flag == 0).

    Returns
    -------
    F_test : torch.Tensor or None
        ``self.F`` restricted to the free set (None if F is None).
    sigma_test : torch.Tensor or None
        ``self.F_sigma`` restricted to the free set (None if unavailable).

    Raises
    ------
    ValueError
        If ``self.rfree_flags`` is None.
    """
    flags = self.rfree_flags
    if flags is None:
        raise ValueError("No R-free flags available in dataset")
    in_test = flags == 0

    def _select(values):
        # Optional tensors propagate as None.
        return None if values is None else values[in_test]

    return _select(self.F), _select(self.F_sigma)
[docs]
def get_max_res(self) -> Optional[float]:
    """
    Return the maximum resolution (smallest d-spacing) of valid reflections.

    Only reflections selected by ``self.masks()`` are considered.
    Resolutions are computed first if they are not yet available.

    Returns
    -------
    float or None
        Smallest d-spacing in Ångströms among valid reflections, or None
        when no reflection passes the active masks.
    """
    if self.resolution is None:
        self._calculate_resolution()
    valid = self.resolution[self.masks()]
    # torch's min() raises on an empty tensor; report "no data" as None,
    # which the Optional return annotation already allows.
    if valid.numel() == 0:
        return None
    return float(valid.min().item())
[docs]
def get_min_res(self) -> Optional[float]:
    """
    Return the minimum resolution (largest d-spacing) of valid reflections.

    Only reflections selected by ``self.masks()`` are considered.
    Resolutions are computed first if they are not yet available.

    Returns
    -------
    float or None
        Largest d-spacing in Ångströms among valid reflections, or None
        when no reflection passes the active masks.
    """
    if self.resolution is None:
        self._calculate_resolution()
    valid = self.resolution[self.masks()]
    # torch's max() raises on an empty tensor; report "no data" as None,
    # which the Optional return annotation already allows.
    if valid.numel() == 0:
        return None
    return float(valid.max().item())
[docs]
def __len__(self) -> int:
    """
    Return the total number of reflections (masks are ignored).

    Returns
    -------
    int
        ``len(self.hkl)``, or 0 when no Miller indices are loaded.
    """
    if self.hkl is None:
        return 0
    return len(self.hkl)
@property
def d_min(self) -> Optional[float]:
    """
    High-resolution limit of the valid data (smallest d-spacing), in Å.

    Read-only alias for ``get_max_res()``.
    """
    resolution_limit = self.get_max_res()
    return resolution_limit
[docs]
def __repr__(self) -> str:
    """
    Return string representation.

    Returns
    -------
    str
        Summary with reflection count, amplitude/phase sources (if set),
        the d-spacing range over all reflections (if available and
        non-empty) and the space group.
    """
    if self.hkl is None:
        return "ReflectionData(empty)"
    parts = [f"ReflectionData(n={len(self.hkl)}"]
    if self.amplitude_source:
        parts.append(f"F={self.amplitude_source}")
    if self.phase_source:
        parts.append(f"φ={self.phase_source}")
    # min()/max() raise on an empty tensor (e.g. after selecting zero
    # reflections); a repr must never raise, so skip the range then.
    if self.resolution is not None and self.resolution.numel() > 0:
        parts.append(f"d={self.resolution.min():.2f}-{self.resolution.max():.2f}Å")
    parts.append(f"sg={self.spacegroup}")
    return ", ".join(parts) + ")"
[docs]
def get_valid_mask(self) -> torch.Tensor:
    """
    Return the combined validity mask over all reflections.

    This is the mask applied to F and F_sigma by ``__call__`` when masking
    is enabled.

    Returns
    -------
    torch.Tensor
        Result of ``self.masks()``: True = reflection included,
        False = excluded.

    Examples
    --------
    ::

        mask = data.get_valid_mask()
        print(f"{mask.sum()} of {len(mask)} reflections are valid")
    """
    combined = self.masks()
    return combined
[docs]
def data_indexed(
    self,
) -> Tuple[
    torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]
]:
    """
    Return only the valid reflections as plain (compacted) tensors.

    Every per-reflection array is indexed with ``self.masks()``, so the
    outputs are shorter than the stored arrays. Handy where MaskedTensors
    are unsupported, e.g. when writing files.

    Returns
    -------
    hkl : torch.Tensor
        Miller indices of the M valid reflections.
    F : torch.Tensor
        Amplitudes of the valid reflections.
    F_sigma : torch.Tensor or None
        Uncertainties of the valid reflections, or None.
    rfree_flags : torch.Tensor or None
        R-free flags of the valid reflections, converted to bool, or None.

    See Also
    --------
    __call__ : returns full-size MaskedTensors instead.

    Examples
    --------
    ::

        hkl, F, sigma, rfree = data.data_indexed()
        F_np = F.cpu().numpy()
    """
    keep = self.masks()
    hkl_valid = self.hkl[keep]
    amplitudes = self.F[keep]
    sigmas = None if self.F_sigma is None else self.F_sigma[keep]
    if self.rfree_flags is None:
        flags = None
    else:
        flags = self.rfree_flags[keep].to(torch.bool)
    return hkl_valid, amplitudes, sigmas, flags
[docs]
def __call__(
    self, mask: bool = True, scale: bool = True
) -> Tuple[torch.Tensor, "MaskedTensor", "MaskedTensor", torch.Tensor]:
    """
    Return the core reflection arrays, optionally scaled and masked.

    Parameters
    ----------
    mask : bool, optional
        When True (default), wrap F and F_sigma in MaskedTensors whose mask
        is ``self.masks()``. Invalid reflections stay in place but are
        marked. When False, plain tensors are returned.
    scale : bool, optional
        When True (default), F and F_sigma come from
        ``self.get_corrected_data()`` instead of the raw ``self.F`` and
        ``self.F_sigma``.

    Returns
    -------
    hkl : torch.Tensor
        ``self.hkl``, full size, never filtered.
    F : MaskedTensor or torch.Tensor
        Amplitudes of shape (N,). A MaskedTensor over a detached copy when
        ``mask`` is True.
    F_sigma : MaskedTensor, torch.Tensor or None
        Uncertainties, treated like F, or None if unavailable.
    rfree_flags : torch.Tensor or None
        ``self.rfree_flags``, full size, never filtered (1 = work, 0 = free).

    Raises
    ------
    RuntimeError
        If ``mask`` is True and every reflection is masked out.

    Notes
    -----
    MaskedTensor is a prototype PyTorch feature. Reductions such as
    ``sum``/``mean`` skip masked entries. Use ``.get_data()`` and
    ``.get_mask()`` to reach the underlying tensors. The wrapped data are
    detached copies, so gradients do not flow back to ``self.F`` through
    the masked outputs.

    Examples
    --------
    ::

        hkl, F, sigma, rfree = data()
        print(F.sum())  # sums only unmasked values
        F_values = F.get_data()[F.get_mask()]
    """
    from torch.masked import MaskedTensor

    miller, free_flags = self.hkl, self.rfree_flags
    amplitudes, sigmas = self.F, self.F_sigma
    if scale:
        amplitudes, sigmas = self.get_corrected_data()
    if not mask:
        return miller, amplitudes, sigmas, free_flags

    keep = self.masks()
    if keep.sum() == 0:
        raise RuntimeError(
            "All reflections are masked! Check your filters/masks."
        )
    masked_amplitudes = MaskedTensor(amplitudes.detach().clone(), keep)
    if sigmas is None:
        masked_sigmas = None
    else:
        masked_sigmas = MaskedTensor(sigmas.detach().clone(), keep)
    return miller, masked_amplitudes, masked_sigmas, free_flags
[docs]
def data_fill_masked(
    self, mode="mean"
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Return full-size data tensors with masked-out reflections filled in.

    Reflections excluded by the active masks keep their position but get
    placeholder F / F_sigma values. In the returned R-free flags they are
    assigned to the work set. The arrays stored on ``self`` are not
    modified.

    Parameters
    ----------
    mode : {'mean', 'zero'}, optional
        'mean' : fill masked F and F_sigma with the per-bin values from
                 ``mean_F_per_bin()`` / ``mean_sigma_per_bin()``, looked
                 up through ``self.bin_indices``.
        'zero' : fill masked F and F_sigma with zero.
        Default is 'mean'.

    Returns
    -------
    hkl : torch.Tensor
        Miller indices as returned by ``self()``.
    F : torch.Tensor
        Filled amplitudes (plain tensor, full size).
    F_sigma : torch.Tensor or None
        Filled uncertainties, or None if unavailable.
    rfree : torch.Tensor or None
        Copy of the R-free flags with masked reflections set to True (work
        set), or None if unavailable.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported values.
    """
    if mode not in ("mean", "zero"):
        raise ValueError(f"Unknown fill mode: {mode}")
    hkl, F, F_sigma, rfree = self()
    missing = ~F.get_mask()
    F_data = F.get_data().clone()
    F_sigma_data = None if F_sigma is None else F_sigma.get_data().clone()
    if mode == "mean":
        bins = self.bin_indices[missing]
        F_data[missing] = self.mean_F_per_bin()[bins]
        if F_sigma_data is not None:
            F_sigma_data[missing] = self.mean_sigma_per_bin()[bins]
    else:
        F_data[missing] = 0.0
        if F_sigma_data is not None:
            F_sigma_data[missing] = 0.0
    if rfree is not None:
        # self() returns self.rfree_flags itself; clone so filling does not
        # silently rewrite the stored R-free assignment.
        rfree = rfree.clone()
        rfree[missing] = True  # set missing to work set
    return hkl, F_data, F_sigma_data, rfree
[docs]
def __getitem__(self, key):
    """
    Select reflections with ``data[key]``.

    Parameters
    ----------
    key : torch.Tensor
        Boolean mask or integer indices, forwarded to ``__select__``.

    Returns
    -------
    ReflectionData
        New object holding the selected reflections.

    Raises
    ------
    TypeError
        If ``key`` is not a torch.Tensor.
    """
    if not isinstance(key, torch.Tensor):
        raise TypeError(f"Unsupported index type: {type(key)}")
    return self.__select__(key)
[docs]
def __select__(self, indices: torch.Tensor, op=None) -> "ReflectionData":
    """
    Build a new ReflectionData holding only the selected reflections.

    Every dataclass field with a non-None value is carried over:

    - tensors whose first dimension equals the reflection count are
      indexed with ``indices``;
    - other tensors and ``Cell`` objects are cloned;
    - everything else is shared by reference.

    ``masks`` (not a dataclass field) and the pandas ``dataset`` (if
    present) are subset the same way. The result records ``self`` as its
    ``source`` and ``op`` as its ``last_op``.

    Parameters
    ----------
    indices : torch.Tensor
        Boolean mask of shape (N,) or integer indices.
    op : str, optional
        Operation name stored for provenance.

    Returns
    -------
    ReflectionData
        The new selection.
    """
    from dataclasses import fields as dc_fields
    from torchref.utils.utils import TensorMasks

    n_refl = 0 if self.hkl is None else len(self.hkl)
    subset = ReflectionData(verbose=self.verbose, device=self.device)

    for fld in dc_fields(self):
        value = getattr(self, fld.name)
        if value is None:
            continue
        if isinstance(value, torch.Tensor):
            per_reflection = value.dim() > 0 and value.shape[0] == n_refl
            # Non-per-reflection tensors (e.g. U_aniso of shape (6,)) are
            # copied unchanged rather than indexed.
            value = value[indices] if per_reflection else value.clone()
        elif isinstance(value, Cell):
            value = value.clone()
        setattr(subset, fld.name, value)

    # masks are not a dataclass field, so they are carried separately.
    if hasattr(self, "masks") and self.masks is not None and len(self.masks) > 0:
        carried = TensorMasks(device=self.device)
        for mask_name, mask_tensor in self.masks.items():
            if mask_tensor is not None:
                carried[mask_name] = mask_tensor[indices]
        subset.masks = carried

    if hasattr(self, "dataset") and self.dataset is not None:
        rows = indices.cpu().numpy()
        subset.dataset = self.dataset.iloc[rows].copy()

    subset.source = self
    subset.last_op = op
    return subset
[docs]
def sanitize_F(self):
    """
    Mask out and zero invalid structure-factor entries.

    A reflection is flagged as invalid when:

    - F is NaN, infinite or non-positive, or
    - F_sigma (if present) is NaN or infinite.

    Flagged reflections are excluded through ``self.masks["sanity_F"]``
    (True = keep). Their F and F_sigma values are overwritten with 0 in
    place.

    Returns
    -------
    ReflectionData
        Self, for method chaining. Returned unchanged if no amplitudes are
        loaded.
    """
    if self.F is None:
        # Nothing to validate (the mask length is defined by F).
        return self
    nan_F = torch.isnan(self.F)
    inf_F = torch.isinf(self.F)
    if self.verbose > 0:
        print("found nan F values: ", nan_F.sum().item())
        print("found inf F values: ", inf_F.sum().item())
    mask = nan_F | inf_F
    if self.F_sigma is not None:
        nan_sigma = torch.isnan(self.F_sigma)
        inf_sigma = torch.isinf(self.F_sigma)
        if self.verbose > 0:
            print(
                "found nan F_sigma values: ", nan_sigma.sum().item()
            )
            print(
                "found inf F_sigma values: ", inf_sigma.sum().item()
            )
        mask = mask | nan_sigma | inf_sigma
    neg_mask = self.F <= 0
    if torch.any(neg_mask):
        warnings.warn(
            f"Found {neg_mask.sum().item()} non-positive F values, masking them out. This really should not happen!")
    mask = mask | neg_mask
    self.masks["sanity_F"] = ~mask
    # Zero out invalid values so they can't leak NaN through autograd
    # (masked indexing produces 0 gradients, but 0 * NaN = NaN in IEEE 754)
    if mask.any():
        self.F[mask] = 0.0
        if self.F_sigma is not None:
            self.F_sigma[mask] = 0.0
    return self
[docs]
def check_all_data_types(self):
    """
    Print a one-line summary of every instance attribute.

    Tensors are shown with dtype and shape, None values as ``None``, and any
    other value with its type and value.
    """
    for name, value in self.__dict__.items():
        if value is None:
            print(f"{name}: None")
        elif isinstance(value, torch.Tensor):
            print(
                f"{name}: {value.dtype}, shape: {value.shape}"
            )
        else:
            print(f"{name}: {type(value)}, value: {value}")
[docs]
def validate_hkl(self, hkl_ref: torch.Tensor) -> "ReflectionData":
    """
    Expand dataset to match a reference HKL set.

    Reorders and expands the current dataset to match the reference HKL set.
    Reflections present in the reference but missing from this dataset are
    filled with placeholder values and masked out. Reflections present only
    in this dataset are dropped. All datasets aligned to the same reference
    therefore have identical shapes and can be processed together without
    data loss from intersection operations.

    Parameters
    ----------
    hkl_ref : torch.Tensor
        Reference Miller indices of shape (N, 3). A single (3,) triple is
        also accepted. The tensor is converted to ``dtypes.int`` on
        ``self.device`` and defines the canonical HKL ordering.

    Returns
    -------
    ReflectionData
        Self, modified in-place with arrays matching hkl_ref.

    Raises
    ------
    ValueError
        If no Miller indices are loaded, or hkl_ref's last dimension is
        not 3.
    TypeError
        If hkl_ref is not a torch.Tensor.

    Notes
    -----
    After calling this method:

    - ``self.hkl`` equals ``hkl_ref`` (after the dtype/device conversion).
    - Missing reflections are filled with 0.0 in F, I, phase and fom,
      with 1.0 in F_sigma and I_sigma, and with 1 (work set) in
      rfree_flags.
    - Existing masks whose length matches the old reflection count are
      reordered, with missing reflections set to False. Masks of any other
      length are dropped.
    - A mask ``'hkl_present'`` marks which reflections have real data.
    - Resolution is recomputed. The per-reflection caches ``bin_indices``
      and ``_centric_flags`` are reset to None because they refer to the
      old HKL ordering.
    - If an HKL occurs more than once in this dataset, its last occurrence
      is used.

    Examples
    --------
    Align multiple datasets to a common HKL set::

        reference_hkl = data1.hkl.clone()
        data1.validate_hkl(reference_hkl)
        data2.validate_hkl(reference_hkl)
        assert data1.hkl.shape == data2.hkl.shape
    """
    if self.hkl is None:
        raise ValueError("No Miller indices loaded in ReflectionData")
    if not isinstance(hkl_ref, torch.Tensor):
        raise TypeError(f"hkl_ref must be a torch.Tensor, got {type(hkl_ref)}")
    if hkl_ref.shape[-1] != 3:
        raise ValueError(f"hkl_ref must have shape (N, 3), got {hkl_ref.shape}")
    # Ensure hkl_ref is 2D with the configured integer dtype.
    if hkl_ref.dim() == 1:
        hkl_ref = hkl_ref.unsqueeze(0)
    hkl_ref = hkl_ref.to(dtype=dtypes.int, device=self.device)
    n_ref = len(hkl_ref)
    n_data = len(self.hkl)

    # Map each data HKL triple to its row. Later duplicates overwrite earlier ones.
    hkl_data_np = self.hkl.cpu().numpy()
    data_hkl_to_idx = {tuple(hkl): idx for idx, hkl in enumerate(hkl_data_np)}
    # For each reference HKL, the matching data row (or -1 if missing).
    hkl_ref_np = hkl_ref.cpu().numpy()
    ref_to_data_idx = np.array(
        [data_hkl_to_idx.get(tuple(hkl), -1) for hkl in hkl_ref_np], dtype=np.int64
    )
    presence_mask = torch.from_numpy(ref_to_data_idx >= 0).to(device=self.device)
    valid_indices = torch.from_numpy(ref_to_data_idx).to(device=self.device)
    # Data rows that feed the present reference positions, in reference order.
    present_src = valid_indices[presence_mask]

    def expand_tensor(tensor, fill_value=0.0):
        """Scatter a per-reflection tensor into reference order, filling gaps."""
        if tensor is None:
            return None
        expanded = torch.full(
            (n_ref,) + tuple(tensor.shape[1:]),
            fill_value,
            dtype=tensor.dtype,
            device=self.device,
        )
        expanded[presence_mask] = tensor[present_src]
        return expanded

    old_F = self.F
    old_F_sigma = self.F_sigma
    old_I = self.I
    old_I_sigma = getattr(self, "I_sigma", None)
    old_rfree = self.rfree_flags
    old_phase = getattr(self, "phase", None)
    old_fom = getattr(self, "fom", None)

    self.hkl = hkl_ref
    self.F = expand_tensor(old_F, fill_value=0.0)
    self.F_sigma = expand_tensor(old_F_sigma, fill_value=1.0)
    if old_I is not None:
        self.I = expand_tensor(old_I, fill_value=0.0)
    if old_I_sigma is not None:
        self.I_sigma = expand_tensor(old_I_sigma, fill_value=1.0)
    # Missing reflections default to the work set (1).
    if old_rfree is not None:
        self.rfree_flags = expand_tensor(old_rfree, fill_value=1)

    # Caches indexed by the old HKL order would now be misaligned. Reset
    # them before recomputing resolution so anything it rebuilds survives.
    self.bin_indices = None
    self._centric_flags = None
    self._calculate_resolution()

    if old_phase is not None:
        self.phase = expand_tensor(old_phase, fill_value=0.0)
    if old_fom is not None:
        self.fom = expand_tensor(old_fom, fill_value=0.0)

    # Re-index existing masks. Missing reflections are masked out (False).
    old_masks = dict(self.masks.items())
    self.masks.clear()
    self.masks._updated = True
    for name, old_mask in old_masks.items():
        if old_mask is not None and len(old_mask) == n_data:
            new_mask = torch.zeros(n_ref, dtype=torch.bool, device=self.device)
            new_mask[presence_mask] = old_mask[present_src]
            self.masks[name] = new_mask
    # Marks real data versus placeholders.
    self.masks["hkl_present"] = presence_mask

    n_present = presence_mask.sum().item()
    n_missing = n_ref - n_present
    if self.verbose > 0:
        # Guard the percentages against an empty reference set.
        pct_present = 100 * n_present / n_ref if n_ref else 0.0
        pct_missing = 100 * n_missing / n_ref if n_ref else 0.0
        print("HKL validation (expand mode):")
        print(f" Original dataset: {n_data} reflections")
        print(f" Reference set: {n_ref} reflections")
        print(f" Present in data: {n_present} ({pct_present:.1f}%)")
        print(f" Missing (masked): {n_missing} ({pct_missing:.1f}%)")
    return self
[docs]
def find_outliers(
    self, model: "ModelFT", scaler, z_threshold: float = 4.0
) -> torch.Tensor:
    """
    Identify outlier reflections from the log(F_obs / F_calc) distribution.

    The log-ratio from ``get_log_ratio`` is computed for all reflections,
    ignoring the current masks. Its mean and unbiased standard deviation
    are estimated over the finite values. A reflection is an outlier when
    ``|log_ratio - mean| > z_threshold * std``, or when its log-ratio is
    not finite. The negation of the result is registered as
    ``self.masks["outliers"]`` (True = keep).

    Parameters
    ----------
    model : ModelFT
        Model used to compute calculated structure factors.
    scaler : Scaler
        Callable that scales the calculated structure factors.
    z_threshold : float, optional
        Z-score threshold to classify outliers. Default is 4.0.

    Returns
    -------
    torch.Tensor
        Boolean mask over all reflections where True marks an outlier.
        If no log-ratio is finite, an all-False mask is returned and no
        mask is registered.
    """
    _, F_obs, _, _ = self(mask=False)
    log_ratio = self.get_log_ratio(model, scaler)
    eps = 1e-10
    valid_mask = torch.isfinite(log_ratio)
    if valid_mask.sum() == 0:
        if self.verbose > 0:
            print("Warning: No valid log-ratios found for outlier detection")
        return torch.zeros_like(F_obs, dtype=torch.bool, device=self.device)
    log_ratio_valid = log_ratio[valid_mask]
    mean_log_ratio = torch.mean(log_ratio_valid)
    std_log_ratio = torch.std(log_ratio_valid, unbiased=True)
    z_scores = torch.abs(log_ratio - mean_log_ratio) / (std_log_ratio + eps)
    # NaN ratios give NaN z-scores, which compare False, so flag all
    # non-finite ratios explicitly.
    outlier_mask = (z_scores > z_threshold) | ~valid_mask
    outlier_mask = outlier_mask.to(self.device)
    self.masks["outliers"] = ~outlier_mask
    if self.verbose > 0:
        n_outliers = outlier_mask.sum().item()
        n_total = len(F_obs)
        print(
            f"Outlier detection: {n_outliers}/{n_total} ({100*n_outliers/n_total:.2f}%) outliers found"
        )
        print(
            f" Log-ratio statistics: mean={mean_log_ratio:.3f}, std={std_log_ratio:.3f}"
        )
        print(f" Z-score threshold: {z_threshold:.1f}")
    # Return the mask so callers can inspect it, as the docstring promises.
    return outlier_mask
[docs]
def get_log_ratio(self, model: "ModelFT", scaler) -> torch.Tensor:
    """
    Compute ``log(F_obs) - log(|scaled F_calc|)`` for every reflection.

    F_obs comes from ``self(mask=False)``. F_calc is
    ``model.forward(hkl)`` passed through ``scaler(..., use_mask=False)``,
    and its magnitude is taken. Both amplitudes are clamped to at least
    1e-6 so zeros give a finite logarithm.

    Parameters
    ----------
    model : ModelFT
        Model whose ``forward(hkl)`` returns complex structure factors.
    scaler : Scaler
        Callable applied to the calculated structure factors.

    Returns
    -------
    torch.Tensor
        Log-ratio per reflection.
    """
    floor = 1e-6
    miller, f_obs, _, _ = self(mask=False)
    f_calc_complex = model.forward(miller)
    f_calc = torch.abs(scaler(f_calc_complex, use_mask=False))
    log_obs = torch.log(f_obs.clamp(min=floor))
    log_calc = torch.log(f_calc.clamp(min=floor))
    return log_obs - log_calc
[docs]
def get_outlier_statistics(self) -> Dict:
    """
    Summarise the reflections flagged in ``self.outlier_flags``.

    Returns
    -------
    dict
        Always contains ``n_outliers``, ``n_total`` and
        ``fraction_outliers``. When flags exist it also holds
        ``detection_params`` (from ``self.outlier_detection_params``). When
        resolutions are available and at least one reflection is flagged,
        it adds ``outlier_resolution_stats`` with the min/max/mean/median
        d-spacing of the flagged reflections. Without flags, all counts
        are zero.
    """
    flags = self.outlier_flags
    if flags is None:
        return {"n_outliers": 0, "n_total": 0, "fraction_outliers": 0.0}
    n_outliers = flags.sum().item()
    n_total = len(flags)
    stats = {
        "n_outliers": n_outliers,
        "n_total": n_total,
        "fraction_outliers": n_outliers / n_total if n_total > 0 else 0.0,
        "detection_params": self.outlier_detection_params,
    }
    if self.resolution is None or n_outliers <= 0:
        return stats
    flagged_res = self.resolution[flags]
    if len(flagged_res) > 0:
        stats["outlier_resolution_stats"] = {
            "min": flagged_res.min().item(),
            "max": flagged_res.max().item(),
            "mean": flagged_res.mean().item(),
            "median": flagged_res.median().item(),
        }
    return stats
[docs]
def unpack_one(self):
    """
    Step back one level in the provenance chain.

    Not recursive and does not touch any flags.

    Returns
    -------
    ReflectionData
        ``self.source`` if set, otherwise ``self``.
    """
    parent = self.source
    return self if parent is None else parent
[docs]
def flag_suspicious_sigma(self, z_threshold: float = 5.0) -> None:
    """
    Mask reflections whose sigma is anomalous on a log scale.

    z-scores of ``log(F_sigma)`` are computed against the mean and standard
    deviation of the finite log-sigmas. A reflection is excluded via
    ``self.masks["flagged_sigma"]`` (True = keep) when either:

    - ``|z| > z_threshold``, or
    - its log-sigma is NaN or infinite (sigma <= 0, NaN or inf).

    Parameters
    ----------
    z_threshold : float, optional
        Z-score threshold for flagging. Default is 5.0.

    Raises
    ------
    ValueError
        If ``self.F_sigma`` is None.
    """
    sigmas = self.F_sigma
    if sigmas is None:
        raise ValueError("No sigma data loaded")
    log_sigmas = torch.log(sigmas)
    flagged_initial = torch.isnan(log_sigmas) | torch.isinf(log_sigmas)
    finite_logs = log_sigmas[~flagged_initial]
    mean_log_sigma = torch.mean(finite_logs)
    # Regularise the spread by a small *positive* amount. The mean
    # log-sigma is negative whenever typical sigmas are < 1, and adding it
    # unsigned could shrink the denominator to zero or flip its sign.
    std_log_sigma = torch.std(finite_logs) + 1e-5 * mean_log_sigma.abs() + 1e-12
    z_scores = (log_sigmas - mean_log_sigma) / std_log_sigma
    flagged = (torch.abs(z_scores) > z_threshold) | flagged_initial
    if self.verbose > 0:
        n_flagged = flagged.sum().item()
        n_total = len(sigmas)
        pct = 100 * n_flagged / n_total if n_total else 0.0
        print(
            f"Suspicious sigma detection: {n_flagged}/{n_total} ({pct:.2f}%) reflections flagged"
        )
    self.masks["flagged_sigma"] = ~flagged
[docs]
def dump(self):
    """
    Print every instance attribute for debugging.

    Tensors are listed with dtype, shape and device. All other values
    (including None) are listed with their type and value.
    """
    print("ReflectionData dump:")
    for name, value in self.__dict__.items():
        if not isinstance(value, torch.Tensor):
            print(f" {name}: type={type(value)}, value={value}")
            continue
        print(
            f" {name}: dtype={value.dtype}, shape={value.shape}, device={value.device}"
        )
[docs]
def write_mtz(
    self,
    fname: str,
    fcalc: Optional[torch.Tensor] = None,
    model_ft: Optional["ModelFT"] = None,
) -> None:
    """
    Write reflection data to an MTZ file, optionally with map coefficients.

    All reflections are written, including those excluded by the active
    masks.

    Parameters
    ----------
    fname : str
        Output MTZ filename.
    fcalc : torch.Tensor, optional
        Complex calculated structure factors of shape (N,), aligned with
        ``self.hkl``. If given, phases and map coefficients are added.
    model_ft : ModelFT, optional
        Used to compute ``fcalc = model_ft.forward(self.hkl)`` when
        ``fcalc`` is not provided.

    Raises
    ------
    ValueError
        If ``fcalc`` is not complex, or if map coefficients are requested
        but no observed amplitudes are loaded.

    Notes
    -----
    Columns written (each only when the underlying data exist):

    - H, K, L: Miller indices
    - F-obs, SIGF-obs: observed amplitudes and uncertainties
    - I-obs, SIGI-obs: observed intensities and uncertainties
    - R-free-flags: R-free flags as integers
    - FWT, PHWT: 2Fo-Fc map coefficients (Coot naming)
    - DELFWT, PHDELWT: Fo-Fc map coefficients (Coot naming)
    - F-model, PH-model: |Fcalc| and its phase

    Map coefficients use unit weights (m = D = 1), with phases in degrees
    taken from ``fcalc``:

    - 2Fo-Fc: amplitude ``|2*Fo - |Fc||``. For valid (unmasked)
      reflections where ``2*Fo - |Fc| < 0``, the phase is shifted by 180°.
    - Fo-Fc: ``Fo*exp(i*phi_c) - Fc``, set to 0 for masked reflections.

    Examples
    --------
    Write MTZ with map coefficients::

        data = ReflectionData().load_mtz('observed.mtz')
        model = Model().load_pdb('model.pdb')
        model_ft = ModelFT(model, data.cell, data.spacegroup)
        fcalc = model_ft.forward(data.hkl)
        data.write_mtz('output.mtz', fcalc=fcalc)
    """
    from torchref.io.mtz import write

    hkl_np = self.hkl.detach().cpu().numpy()
    data_dict = {
        "H": hkl_np[:, 0],
        "K": hkl_np[:, 1],
        "L": hkl_np[:, 2],
    }
    # Observed amplitudes (F-obs, SIGF-obs)
    if self.F is not None:
        data_dict["F-obs"] = self.F.detach().cpu().numpy()
    if self.F_sigma is not None:
        data_dict["SIGF-obs"] = self.F_sigma.detach().cpu().numpy()
    # Observed intensities (I-obs, SIGI-obs)
    if self.I is not None:
        data_dict["I-obs"] = self.I.detach().cpu().numpy()
    if self.I_sigma is not None:
        data_dict["SIGI-obs"] = self.I_sigma.detach().cpu().numpy()
    # R-free flags (R-free-flags)
    if self.rfree_flags is not None:
        data_dict["R-free-flags"] = (
            self.rfree_flags.detach().cpu().numpy().astype(int)
        )
    # Compute fcalc if model_ft is provided but fcalc is not
    if fcalc is None and model_ft is not None:
        fcalc = model_ft.forward(self.hkl)
    mask = self.masks().detach().cpu().numpy()
    if fcalc is not None:
        if not torch.is_complex(fcalc):
            raise ValueError("fcalc must be a complex tensor")
        if self.F is None:
            raise ValueError(
                "Observed amplitudes (F) are required to compute map coefficients"
            )
        fcalc_np = fcalc.detach().cpu().numpy()
        F_obs = self.F.detach().cpu().numpy()
        phases = np.angle(fcalc_np, deg=True)
        F_calc_amp = np.abs(fcalc_np)

        # 2Fo-Fc: a negative amplitude is written as |A| on the opposite
        # phase (shifted by 180°, kept within (-180, 180]). Only valid
        # reflections are flipped. Masked ones can carry placeholder F_obs
        # (missing reflections are filled with F = 0 elsewhere in this
        # class) and keep |2Fo - Fc| on the calculated phase.
        two_mfo_dfc_raw = 2.0 * F_obs - F_calc_amp
        two_mfo_dfc_amp = np.abs(two_mfo_dfc_raw)
        flip = (two_mfo_dfc_raw < 0) & mask.astype(bool)
        shifted = np.where(phases > 0.0, phases - 180.0, phases + 180.0)
        two_mfo_dfc_phase = np.where(flip, shifted, phases)

        # Fo-Fc difference map
        mfo_dfc_complex = F_obs * np.exp(1j * np.deg2rad(phases)) - fcalc_np
        mfo_dfc_complex[~mask] = 0.0  # Zero out reflections outside mask
        mfo_dfc_amp = np.abs(mfo_dfc_complex)
        mfo_dfc_phase = np.angle(mfo_dfc_complex, deg=True)

        data_dict["FWT"] = two_mfo_dfc_amp
        data_dict["PHWT"] = two_mfo_dfc_phase
        data_dict["DELFWT"] = mfo_dfc_amp
        data_dict["PHDELWT"] = mfo_dfc_phase
        data_dict["F-model"] = F_calc_amp
        data_dict["PH-model"] = phases
        if self.verbose > 0:
            print("Added map coefficients:")
            print(" 2mFo-DFc: FWT, PHWT")
            print(" mFo-DFc: DELFWT, PHDELWT")
            if self.resolution is not None:
                print(
                    f" Resolution range: {self.resolution.min().item():.2f} - {self.resolution.max().item():.2f} Å"
                )
    df = pd.DataFrame(data_dict)
    write(df, self.cell.data, self.spacegroup, fname)
    if self.verbose > 0:
        print(f"✓ Wrote MTZ file: {fname}")
        print(f" Reflections: {len(df)}")
        print(f" Columns: {', '.join(df.columns)}")
@property
def centric(self):
    """
    Boolean centric flag for every reflection (full size, unfiltered).

    Computed on first access with ``is_centric_from_hkl`` using the dataset's
    space group (``"P1"`` if none is set) and cached in
    ``self._centric_flags``. Later accesses return the cached tensor.

    Returns
    -------
    torch.Tensor or None
        Flags aligned with ``self.hkl`` (True = centric), or None when no
        Miller indices are loaded.
    """
    if self.hkl is None:
        return None
    cached = getattr(self, "_centric_flags", None)
    if cached is None:
        from torchref.base.french_wilson import is_centric_from_hkl

        sg = self.spacegroup if self.spacegroup else "P1"
        cached = is_centric_from_hkl(self.hkl, sg)
        self._centric_flags = cached
    return cached
[docs]
def calc_patterson(
    self,
    grid_size: Optional[Tuple[int, int, int]] = None,
    grid_sampling: Optional[float] = 1,
) -> torch.Tensor:
    """
    Compute the Patterson map of the dataset.

    The data are expanded to P1 with ``expand_to_p1``. The squared
    amplitudes of the valid (unmasked) reflections are placed on a
    reciprocal-space grid without enforcing Hermitian symmetry. An inverse
    FFT over all three axes (``norm="forward"``) then gives the map, and
    its real part is returned. Missing reflections are not filled.

    Parameters
    ----------
    grid_size : tuple of int, optional
        Grid dimensions (Nx, Ny, Nz). If None, ``find_grid_size`` chooses
        them from the cell and the effective resolution below.
    grid_sampling : float, optional
        Factor applied to the smallest d-spacing of the P1 data (over all
        reflections) to obtain the effective resolution passed to
        ``find_grid_size``. Default is 1.

    Returns
    -------
    torch.Tensor
        Real-valued Patterson map of shape (Nx, Ny, Nz), origin at
        grid position [0, 0, 0].
    """
    from torchref.base.fourier import find_grid_size
    from torchref.base.reciprocal import place_on_grid

    p1_data = self.expand_to_p1()
    effective_res = p1_data.resolution.min() * grid_sampling
    if grid_size is None:
        grid_size = find_grid_size(p1_data.cell, effective_res)
    observed_hkl, amplitudes, _, _ = p1_data.data_indexed()
    intensities = amplitudes**2
    reciprocal_grid = place_on_grid(
        observed_hkl, intensities, grid_size, enforce_hermitian=False
    )
    real_space = torch.fft.ifftn(reciprocal_grid, dim=(0, 1, 2), norm="forward")
    return real_space.real
[docs]
def possible_hkl(self) -> torch.Tensor:
    """
    Enumerate every Miller index allowed by the cell up to the data's
    highest resolution.

    Returns
    -------
    torch.Tensor
        Output of ``generate_possible_hkl`` for ``self.cell.data`` and the
        smallest d-spacing in ``self.resolution`` (all reflections,
        unmasked). Expected shape is (M, 3).

    Raises
    ------
    ValueError
        If ``self.cell`` or ``self.resolution`` is None.
    """
    from torchref.base.reciprocal import generate_possible_hkl

    missing_inputs = self.cell is None or self.resolution is None
    if missing_inputs:
        raise ValueError(
            "Cell and resolution must be defined to generate possible HKL"
        )
    d_limit = self.resolution.min().item()
    return generate_possible_hkl(self.cell.data, d_limit, device=self.device)
[docs]
def remap(
    self,
    new_hkl: torch.Tensor,
    index_mapping: torch.Tensor,
    phase_shifts: Optional[torch.Tensor] = None,
    spacegroup=None,
    op_name: str = "remap",
) -> "ReflectionData":
    """
    Create new ReflectionData with a remapped HKL set and data.

    This is the core method for index-based transformations of reflection
    data (fill() and expand_to_p1() are built on it). Every per-reflection
    tensor field that is set is gathered through ``index_mapping``; new
    reflections without a source entry are filled with default values.

    Parameters
    ----------
    new_hkl : torch.Tensor, shape (M, 3)
        New Miller indices.
    index_mapping : torch.Tensor, shape (M,), integer dtype
        Maps new rows to original rows: ``new[i] = old[index_mapping[i]]``.
        Negative values (conventionally -1) mark reflections that have no
        source and are filled with defaults.
    phase_shifts : torch.Tensor, optional, shape (M,)
        Phase offsets added to the remapped phases (e.g. from symmetry
        translations). If the source has no phases they are stored on the
        result as ``_expansion_phase_shifts`` instead.
    spacegroup : str, int, gemmi.SpaceGroup, or None
        New spacegroup, wrapped with ``SpaceGroup(...)``. If None, the
        original ``self.spacegroup`` object is reused.
    op_name : str
        Operation name stored in ``last_op`` for provenance tracking.

    Returns
    -------
    ReflectionData
        New object whose ``source`` is ``self``. Reflections without a
        source entry get:
        - 0.0 for F, I, fom and phase (before ``phase_shifts`` are added)
        - 1.0 for F_sigma, I_sigma (conservative uncertainty)
        - True (fill value 1) for rfree_flags
        - False in ``masks['prior_flagged']`` when a prior mask exists
        ``masks['missing']`` is only added when at least one reflection is
        missing. It stores the complement of the missing selection: True
        for reflections present in the source, False for filled-in ones.
    """
    from torchref.symmetry.spacegroup import SpaceGroup

    n_new = len(new_hkl)
    # Negative entries have no source row. Compute the selection once
    # instead of once per remapped field.
    valid_mask = index_mapping >= 0
    source_idx = index_mapping[valid_mask]

    def _remap_tensor(tensor, fill_value):
        """Gather ``tensor`` rows into the new order; fill rows with no source."""
        if tensor is None:
            return None
        result = torch.full(
            (n_new,) + tensor.shape[1:],
            fill_value,
            dtype=tensor.dtype,
            device=self.device,
        )
        result[valid_mask] = tensor[source_idx]
        return result

    remapped = ReflectionData(verbose=self.verbose, device=self.device)
    remapped.hkl = new_hkl.to(device=self.device)

    # Amplitude, intensity and figure-of-merit fields
    remapped.F = _remap_tensor(self.F, fill_value=0.0)
    remapped.F_sigma = _remap_tensor(self.F_sigma, fill_value=1.0)
    remapped.I = _remap_tensor(self.I, fill_value=0.0)
    remapped.I_sigma = _remap_tensor(self.I_sigma, fill_value=1.0)
    remapped.fom = _remap_tensor(self.fom, fill_value=0.0)

    # Carry forward the prior combined mask; rows without a source get False.
    # NOTE(review): ``self.masks`` is called here but subscripted on the new
    # object below; presumably it is a callable mapping - confirm.
    prior_mask = self.masks()
    if prior_mask is not None:
        remapped.masks["prior_flagged"] = _remap_tensor(
            prior_mask.to(dtype=dtypes.int), fill_value=0
        ).to(torch.bool)

    # rfree_flags: rows without a source get fill value 1 (True).
    # NOTE(review): reduce_to_spacegroup treats 0/False as free and
    # non-zero/True as work; confirm that "work" is intended for these rows.
    if self.rfree_flags is not None:
        remapped.rfree_flags = _remap_tensor(
            self.rfree_flags.to(dtype=dtypes.int), fill_value=1
        ).to(torch.bool)

    # Phases, with optional symmetry-derived shifts
    if self.phase is not None:
        remapped.phase = _remap_tensor(self.phase, fill_value=0.0)
        if phase_shifts is not None:
            remapped.phase = remapped.phase + phase_shifts.to(device=self.device)
    elif phase_shifts is not None:
        # Keep the shifts even when there are no phases to apply them to.
        remapped._expansion_phase_shifts = phase_shifts.to(device=self.device)

    remapped.cell = self.cell.clone() if self.cell is not None else None
    if spacegroup is not None:
        remapped.spacegroup = SpaceGroup(spacegroup)
    else:
        remapped.spacegroup = self.spacegroup

    # Resolution depends on the new HKL set; bins are invalidated by it.
    if remapped.cell is not None and remapped.hkl is not None:
        remapped._calculate_resolution()
    remapped.bin_indices = None

    # Metadata and provenance
    remapped.amplitude_source = self.amplitude_source
    remapped.intensity_source = self.intensity_source
    remapped.phase_source = self.phase_source
    remapped.rfree_source = self.rfree_source
    remapped.source = self
    remapped.last_op = op_name

    # Store ~missing (True = present) only when something is actually missing.
    missing_mask = index_mapping < 0
    if missing_mask.any():
        remapped.masks["missing"] = ~missing_mask.to(device=self.device)
    return remapped
[docs]
def fill(self, d_min: Optional[float] = None) -> "ReflectionData":
    """
    Fill missing reflections within a resolution limit.

    Generates all possible reflections for the current spacegroup (``"P1"``
    if none is set) down to ``d_min`` with ``complete_hkl`` and remaps the
    existing data onto that complete set via ``remap()``.

    Parameters
    ----------
    d_min : float, optional
        High resolution limit in Angstroms. If None, uses the minimum
        value of ``self.resolution``.

    Returns
    -------
    ReflectionData
        New ReflectionData with the complete set of reflections.
        Missing reflections have:
        - F, I, phase, fom = 0.0 (for fields that are set)
        - F_sigma, I_sigma = 1.0
        - rfree_flags = True
        - masks['missing'] = False. This mask is True for reflections that
          were present in the original data, and it is only added when at
          least one reflection was missing.

    Raises
    ------
    ValueError
        If no Miller indices or unit cell are loaded, or if ``d_min`` is
        None and no resolution values are available.

    Examples
    --------
    Fill to completeness::

        data = ReflectionData().load_mtz('data.mtz')
        data_filled = data.fill(d_min=2.0)
        print(f"Original: {len(data)}, Filled: {len(data_filled)}")
    """
    from torchref.symmetry.reciprocal_symmetry import complete_hkl

    if self.hkl is None:
        raise ValueError("ReflectionData has no Miller indices loaded")
    if self.cell is None:
        raise ValueError("ReflectionData has no unit cell defined")

    # Use the current resolution limit if none was given
    if d_min is None:
        if self.resolution is None:
            raise ValueError("Resolution not available - specify d_min")
        d_min = self.resolution.min().item()

    # remap() works from ``indices`` alone (negative entries = missing), so
    # complete_hkl's third return value is not needed here.
    filled_hkl, indices, _missing = complete_hkl(
        self.hkl, self.cell.data, self.spacegroup or "P1", d_min, device=self.device
    )

    return self.remap(
        new_hkl=filled_hkl,
        index_mapping=indices,
        spacegroup=self.spacegroup,  # keep the same spacegroup
        op_name=f"fill(d_min={d_min:.2f})",
    )
[docs]
def expand_to_p1(
    self, include_friedel: bool = True, remove_absences: bool = True
) -> "ReflectionData":
    """
    Expand reflection data from the asymmetric unit to P1.

    The symmetry expansion is computed by ``expand_hkl`` for the current
    space group (``"P1"`` if none is set). The resulting index mapping and
    phase shifts are then applied through ``remap()``. A NEW ReflectionData
    is returned; ``self`` is not modified.

    Parameters
    ----------
    include_friedel : bool, default True
        Include Friedel mates (-h, -k, -l). For normal (non-anomalous)
        scattering, Friedel pairs have identical amplitudes.
    remove_absences : bool, default True
        Remove systematically absent reflections from output.

    Returns
    -------
    ReflectionData
        New ReflectionData object with:
        - the expanded reflection set returned by ``expand_hkl``
        - spacegroup set to P1
        - expanded tensor fields (F, F_sigma, I, I_sigma, phase, fom, rfree_flags)
        - recalculated resolution values (when a cell is available)
        - provenance via ``source`` and ``last_op`` (which records
          ``include_friedel``)

    Notes
    -----
    Field handling (performed by ``remap()``):
    - hkl: replaced by the expanded set from ``expand_hkl``
    - F, F_sigma, I, I_sigma, fom, rfree_flags: gathered from the original
    - phase: gathered, plus the per-reflection phase shift from ``expand_hkl``
    - resolution: recalculated from expanded hkl + cell
    - bin_indices: cleared (invalidated by expansion)

    Examples
    --------
    Expand to P1 for calculations::

        # Load data in original space group
        data = ReflectionData().load_mtz('data.mtz')
        print(f"Original: {len(data)} reflections, {data.spacegroup}")
        # Expand to P1
        data_p1 = data.expand_to_p1()
        print(f"Expanded: {len(data_p1)} reflections, {data_p1.spacegroup}")
        # Without Friedel mates (for anomalous data)
        data_p1_anom = data.expand_to_p1(include_friedel=False)
    """
    from torchref.symmetry.reciprocal_symmetry import expand_hkl

    if self.hkl is None:
        raise ValueError("ReflectionData has no Miller indices loaded")

    source_group = self.spacegroup or "P1"
    expansion = expand_hkl(
        self.hkl,
        source_group,
        include_friedel=include_friedel,
        remove_absences=remove_absences,
        device=self.device,
    )
    hkl_p1, gather_idx, shifts = expansion

    label = f"expand_to_p1(include_friedel={include_friedel})"
    return self.remap(
        new_hkl=hkl_p1,
        index_mapping=gather_idx,
        phase_shifts=shifts,
        spacegroup="P1",
        op_name=label,
    )
[docs]
def reduce_to_spacegroup(
    self, spacegroup, include_friedel: bool = True, aggregation: str = "mean"
) -> "ReflectionData":
    """
    Reduce P1 reflection data to the asymmetric unit of a target spacegroup.

    This is the inverse of expand_to_p1(). ``reduce_hkl`` groups the
    reflections into symmetry-equivalent sets, one row per ASU reflection
    (negative indices mark padding). Each set is then merged with the
    requested aggregation.

    Parameters
    ----------
    spacegroup : str, int, or gemmi.SpaceGroup
        Target space group specification.
    include_friedel : bool, default True
        If True, also merge Friedel mates when reducing.
    aggregation : {'mean', 'sum', 'first'}, default 'mean'
        How equivalent reflections are merged:
        - 'mean': average of the valid equivalents (default)
        - 'sum': sum of the valid equivalents
        - 'first': value of the first valid equivalent (no averaging)

    Returns
    -------
    ReflectionData
        New ReflectionData with merged reflections in the target spacegroup.

    Raises
    ------
    ValueError
        If no Miller indices are loaded or ``aggregation`` is unknown.

    Notes
    -----
    Field handling during reduction:
    - F, I: aggregated with ``aggregation``.
    - F_sigma, I_sigma: 'mean' -> sqrt(sum(sigma^2)) / n;
      'sum' -> sqrt(sum(sigma^2)); 'first' -> sigma of the first valid
      equivalent.
    - phase: angle of the weighted mean of exp(i * (phase + shift)) over
      the equivalents (phases are used as radians). The weights are fom if
      present, else F, else 1.
    - fom: when phases exist, the magnitude of the weight-normalised mean
      unit phase vector; otherwise aggregated like F. If phases exist but
      fom does not, fom is not set on the result.
    - rfree_flags: 0/False = free, non-zero/True = work. The merged flag is
      free (False) if any equivalent is free, i.e. it is True only when
      all equivalents are True.
    - bin_indices: reset to None; masks are not transferred.

    Examples
    --------
    Reduce symmetry-expanded data::

        # Expand to P1 for calculations, then reduce back
        data_p1 = data.expand_to_p1()
        # ... modify F_p1 ...
        data_merged = data_p1.reduce_to_spacegroup('P21')
        # Reduce with sum instead of mean
        data_summed = data_p1.reduce_to_spacegroup('P21', aggregation='sum')
    """
    from torchref.symmetry.reciprocal_symmetry import reduce_hkl
    from torchref.symmetry.spacegroup import SpaceGroup

    if self.hkl is None:
        raise ValueError("ReflectionData has no Miller indices loaded")
    # Validate up front so an unsupported value is rejected even when no
    # amplitude/intensity/sigma field exists to reach the aggregation code.
    if aggregation not in ("mean", "sum", "first"):
        raise ValueError(f"Unknown aggregation: {aggregation}")

    # reduction_indices: one row per ASU reflection, one column per
    # equivalent; phase_shifts are added to the gathered phases below.
    hkl_asu, reduction_indices, phase_shifts = reduce_hkl(
        self.hkl, spacegroup, include_friedel=include_friedel, device=self.device
    )

    n_asu = len(hkl_asu)
    valid_mask = reduction_indices >= 0  # (n_asu, n_equiv)
    # Padding entries are redirected to row 0 so the gather stays in range;
    # valid_mask removes their contribution afterwards.
    safe_idx = reduction_indices.clamp(min=0)
    count_valid = valid_mask.sum(dim=1).clamp(min=1).float()  # (n_asu,)

    def _gather(tensor):
        """Gather per-equivalent values, with padding entries zeroed."""
        gathered = tensor[safe_idx]
        return torch.where(valid_mask, gathered, torch.zeros_like(gathered))

    def _take_first(gathered):
        """Select the first valid equivalent of every ASU reflection."""
        first_valid_idx = valid_mask.to(dtype=dtypes.int).argmax(dim=1)
        return gathered[torch.arange(n_asu, device=self.device), first_valid_idx]

    def _aggregate_tensor(tensor, agg_func="mean"):
        """Merge a per-reflection value tensor with ``agg_func``."""
        if tensor is None:
            return None
        gathered = _gather(tensor)
        if agg_func == "mean":
            return gathered.sum(dim=1) / count_valid
        if agg_func == "sum":
            return gathered.sum(dim=1)
        if agg_func == "first":
            return _take_first(gathered)
        raise ValueError(f"Unknown aggregation: {agg_func}")

    def _aggregate_sigma(tensor, agg_func="mean"):
        """Propagate uncertainties in quadrature (independent errors)."""
        if tensor is None:
            return None
        gathered = _gather(tensor)
        if agg_func == "first":
            return _take_first(gathered)
        if agg_func in ("mean", "sum"):
            root_sum_sq = torch.sqrt((gathered**2).sum(dim=1))
            return root_sum_sq / count_valid if agg_func == "mean" else root_sum_sq
        raise ValueError(f"Unknown aggregation: {agg_func}")

    reduced = ReflectionData(verbose=self.verbose, device=self.device)
    reduced.hkl = hkl_asu.to(device=self.device)

    # Amplitude and intensity fields
    reduced.F = _aggregate_tensor(self.F, aggregation)
    reduced.F_sigma = _aggregate_sigma(self.F_sigma, aggregation)
    reduced.I = _aggregate_tensor(self.I, aggregation)
    reduced.I_sigma = _aggregate_sigma(self.I_sigma, aggregation)

    if self.phase is not None:
        # Bring each equivalent's phase into the ASU setting, then average on
        # the unit circle so wrap-around at +/-pi does not bias the result.
        # NOTE(review): no explicit sign flip is applied to Friedel-related
        # equivalents here; confirm that reduce_hkl's phase_shifts cover it.
        phases = self.phase[safe_idx] + phase_shifts
        phases = torch.where(valid_mask, phases, torch.zeros_like(phases))

        if self.fom is not None:
            weights = self.fom[safe_idx]
        elif self.F is not None:
            weights = self.F[safe_idx]
        else:
            weights = torch.ones_like(phases)
        weights = torch.where(valid_mask, weights, torch.zeros_like(weights))

        unit_vectors = torch.exp(1j * phases)
        complex_mean = (weights * unit_vectors).sum(dim=1) / count_valid
        reduced.phase = torch.angle(complex_mean).float()

        # FOM = length of the weight-normalised mean unit vector
        if self.fom is not None:
            norm_weights = weights / weights.sum(dim=1, keepdim=True).clamp(
                min=1e-10
            )
            reduced.fom = torch.abs((norm_weights * unit_vectors).sum(dim=1)).float()
    else:
        reduced.phase = None
        reduced.fom = _aggregate_tensor(self.fom, aggregation)

    # rfree_flags: 0 = free, non-zero = work. Padding is set to 1 (work) so
    # it cannot pull the minimum down; the min is 0 if any equivalent is free.
    if self.rfree_flags is not None:
        flags = self.rfree_flags[safe_idx].to(dtypes.int)
        flags = torch.where(valid_mask, flags, torch.ones_like(flags))
        reduced.rfree_flags = flags.min(dim=1).values != 0

    reduced.cell = self.cell.clone() if self.cell is not None else None
    reduced.spacegroup = SpaceGroup(spacegroup)

    # Resolution depends on the new HKL set; bins are invalidated by it.
    if reduced.cell is not None and reduced.hkl is not None:
        reduced._calculate_resolution()
    reduced.bin_indices = None

    # Metadata and provenance
    reduced.amplitude_source = self.amplitude_source
    reduced.intensity_source = self.intensity_source
    reduced.phase_source = self.phase_source
    reduced.rfree_source = self.rfree_source
    reduced.source = self
    reduced.last_op = (
        f"reduce_to_spacegroup({spacegroup}, aggregation={aggregation})"
    )
    return reduced
[docs]
def canonicalize(self, include_friedel: bool = True) -> "ReflectionData":
    """Return a new ReflectionData with HKL in standard CCP4 ASU form.

    ``canonicalize_hkl`` maps every Miller index to its canonical CCP4
    asymmetric-unit representative (based on ``gemmi.ReciprocalAsu``). It
    also returns the per-reflection phase shift, a Friedel flag and the
    lexicographic (h, k, l) sort order. All fields are reordered with
    ``__select__``. Phases are negated for Friedel-flipped reflections
    before the shift is added.

    Parameters
    ----------
    include_friedel : bool, default True
        Whether Friedel mates are considered equivalent.

    Returns
    -------
    ReflectionData
        New object with canonicalized, sorted Miller indices.

    Raises
    ------
    ValueError
        If no Miller indices are loaded.
    """
    from torchref.symmetry.reciprocal_symmetry import canonicalize_hkl

    if self.hkl is None:
        raise ValueError("ReflectionData has no Miller indices loaded")

    group = self.spacegroup or "P1"
    hkl_canon, shifts, is_friedel, order = canonicalize_hkl(
        self.hkl, group, include_friedel, device=self.device
    )

    out = self.__select__(
        order, op=f"canonicalize(include_friedel={include_friedel})"
    )
    # canonical HKL is already in the sorted order
    out.hkl = hkl_canon

    if out.phase is not None:
        flipped = torch.where(is_friedel, -out.phase, out.phase)
        out.phase = flipped + shifts

    if out.cell is not None:
        out._calculate_resolution()
    out.bin_indices = None
    return out
# ========== E-VALUE AND ANISOTROPY CORRECTION METHODS ==========
[docs]
def get_scattering_vectors(self) -> torch.Tensor:
    """
    Compute the reciprocal-space scattering vector of every reflection.

    Each vector is ``s = B* @ hkl``, where ``B*`` is the reciprocal basis
    matrix of the unit cell. The computation is delegated to
    ``math_torch.get_scattering_vectors``.

    Returns
    -------
    s_vectors : torch.Tensor
        Reciprocal space vectors in Angstroms^-1, shape (N, 3).

    Raises
    ------
    ValueError
        If Miller indices or the unit cell are missing (checked in that order).
    """
    required = (
        ("hkl", "No Miller indices loaded"),
        ("cell", "No unit cell defined"),
    )
    for attr_name, message in required:
        if getattr(self, attr_name) is None:
            raise ValueError(message)
    return math_torch.get_scattering_vectors(self.hkl, self.cell.data)
[docs]
def get_radial_shells(
    self,
    n_shells: int = 20,
    d_min: Optional[float] = None,
    d_max: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build radial shells with uniform spacing in 1/d for normalization.

    Unlike get_bins(), which aims for equal counts per bin, these shells
    are equally spaced in reciprocal space. The computed shell index of
    each reflection is also cached on ``self.radial_shell_indices``.

    Parameters
    ----------
    n_shells : int
        Number of radial shells. Default is 20.
    d_min : float, optional
        High resolution limit in Angstroms. If None, ``self.get_max_res()``
        is used.
    d_max : float, optional
        Low resolution limit in Angstroms. If None, ``self.get_min_res()``
        is used.

    Returns
    -------
    shell_edges : torch.Tensor
        Shell boundaries in Angstroms^-1, shape (n_shells+1,).
    shell_centers : torch.Tensor
        Shell centers in Angstroms^-1, shape (n_shells,).
    shell_indices : torch.Tensor
        Shell index for each reflection, shape (N,). Values -1 for out-of-range.
    """
    from torchref.base.normalization import assign_to_shells, compute_radial_shells

    if self.resolution is None:
        self._calculate_resolution()

    high_res = self.get_max_res() if d_min is None else d_min
    low_res = self.get_min_res() if d_max is None else d_max

    edges, centers = compute_radial_shells(
        high_res, low_res, n_shells, device=self.device
    )

    # |s| for each reflection decides which shell it falls into
    s_length = torch.linalg.norm(self.get_scattering_vectors(), dim=1)
    indices = assign_to_shells(s_length, edges)

    self.radial_shell_indices = indices
    return edges, centers, indices
[docs]
def fit_anisotropy(
    self,
    n_shells: int = 20,
    d_min: Optional[float] = None,
    d_max: Optional[float] = None,
    n_iterations: int = 100,
    verbose: Optional[bool] = None,
) -> torch.Tensor:
    """
    Fit anisotropy correction parameters by minimising per-shell CV.

    ``F**2`` and the scattering vectors are passed to
    ``torchref.base.fit_anisotropy_correction``, which optimises the U
    parameters. Per its description, the goal is that the corrected F²
    values have minimal coefficient of variation within each resolution
    shell. The fitted parameters are stored on ``self.U_aniso`` and returned.

    Parameters
    ----------
    n_shells : int
        Number of resolution shells for variance calculation.
    d_min : float, optional
        High resolution limit in Angstroms. If None, ``self.get_max_res()``.
    d_max : float, optional
        Low resolution limit in Angstroms. If None, ``self.get_min_res()``.
    n_iterations : int
        Number of optimization iterations.
    verbose : bool, optional
        Print progress. If None, ``self.verbose > 0`` is used.

    Returns
    -------
    U : torch.Tensor
        Fitted anisotropy parameters [u11, u22, u33, u12, u13, u23], shape (6,).

    Raises
    ------
    ValueError
        If no amplitude data is available.
    """
    from torchref.base import fit_anisotropy_correction

    if self.F is None:
        raise ValueError("No amplitude data loaded")

    show_progress = (self.verbose > 0) if verbose is None else verbose
    f_sq = self.F**2
    s_vecs = self.get_scattering_vectors()

    if d_min is None:
        d_min = self.get_max_res()
    if d_max is None:
        d_max = self.get_min_res()

    # The second return value (final CV) is not used here.
    U, _final_cv = fit_anisotropy_correction(
        f_sq,
        s_vecs,
        n_shells=n_shells,
        d_min=d_min,
        d_max=d_max,
        n_iterations=n_iterations,
        verbose=show_progress,
    )
    self.U_aniso = U
    return U
[docs]
def setup_anisotropy(
    self,
    U_aniso: Optional[torch.Tensor] = None,
) -> "ReflectionData":
    """
    Set up the anisotropy correction parameters on ``self.U_aniso``.

    Parameters
    ----------
    U_aniso : torch.Tensor or sequence of float, optional
        Anisotropic parameters [u11, u22, u33, u12, u13, u23], shape (6,).
        Converted to ``dtypes.float`` on ``self.device`` and detached from
        any autograd graph. If None, the parameters are initialised as zeros.

    Returns
    -------
    ReflectionData
        ``self``, to allow chaining.

    Raises
    ------
    ValueError
        If the provided ``U_aniso`` does not have shape (6,).
    """
    if U_aniso is None:
        U_aniso = torch.zeros(
            6, device=self.device, dtype=dtypes.float, requires_grad=False
        )
    else:
        # Tensor.to() has no ``requires_grad`` keyword (passing it raises
        # TypeError); detach() yields a tensor that does not require grad.
        U_aniso = torch.as_tensor(
            U_aniso, dtype=dtypes.float, device=self.device
        ).detach()
        if U_aniso.shape != (6,):
            raise ValueError(
                f"U_aniso must have shape (6,), got {tuple(U_aniso.shape)}"
            )
    self.U_aniso = U_aniso
    return self
[docs]
def apply_anisotropy_correction(
    self,
    U_aniso: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Apply the anisotropy correction to the amplitudes F and their sigmas.

    Both ``self.F`` and ``self.F_sigma`` are passed through
    ``torchref.base.apply_anisotropy_correction`` with the same parameters.
    The raw tensors are used so gradients can flow back to ``U_aniso``.

    Parameters
    ----------
    U_aniso : torch.Tensor, optional
        Anisotropic parameters [u11, u22, u33, u12, u13, u23], shape (6,).
        If None, ``self.U_aniso`` is used (set by fit_anisotropy() or
        setup_anisotropy()).

    Returns
    -------
    F_corrected : torch.Tensor
        Anisotropy-corrected F values, shape (N,).
    sigma_F_corrected : torch.Tensor or None
        Correspondingly corrected uncertainties, shape (N,), or None if
        ``self.F_sigma`` is not set.

    Raises
    ------
    ValueError
        If no U parameters are available (none provided and ``self.U_aniso``
        unset or None), or if no amplitude data is loaded.
    """
    from torchref.base import apply_anisotropy_correction

    if U_aniso is None:
        # getattr: report the documented ValueError rather than an
        # AttributeError when the parameters were never set up.
        U_aniso = getattr(self, "U_aniso", None)
    if U_aniso is None:
        raise ValueError(
            "No anisotropy parameters available. "
            "Call fit_anisotropy() first or provide U_aniso."
        )
    if self.F is None:
        raise ValueError("No amplitude data loaded")

    s_vectors = self.get_scattering_vectors()

    # Use raw tensors directly to preserve gradient flow
    # (MaskedTensor doesn't support autograd operations)
    F = self.F
    sigma = self.F_sigma

    F_corrected = apply_anisotropy_correction(F, s_vectors, U_aniso)
    sigma_F_corrected = (
        apply_anisotropy_correction(sigma, s_vectors, U_aniso)
        if sigma is not None
        else None
    )
    return F_corrected, sigma_F_corrected
[docs]
def compute_e_values(
    self,
    n_shells: int = 20,
    d_min: Optional[float] = None,
    d_max: Optional[float] = None,
    apply_anisotropy: bool = True,
    fit_anisotropy: bool = True,
    verbose: Optional[bool] = None,
) -> torch.Tensor:
    """
    Compute normalized structure factors (E-values), optionally after an
    anisotropy correction.

    E-values are normalized so that <E²> = 1 within each resolution shell;
    the normalization itself is done by ``torchref.base.F_squared_to_E_values``.
    The results are stored on ``self.E``, ``self.E_squared`` and
    ``self.radial_shell_indices``.

    Parameters
    ----------
    n_shells : int
        Number of resolution shells for normalization.
    d_min : float, optional
        High resolution limit in Angstroms. If None, ``self.get_max_res()``.
    d_max : float, optional
        Low resolution limit in Angstroms. If None, ``self.get_min_res()``.
    apply_anisotropy : bool
        If True, use anisotropy-corrected F before computing E-values.
    fit_anisotropy : bool
        If True (and apply_anisotropy is True), fit the anisotropy
        parameters first; otherwise the existing ``self.U_aniso`` is used.
    verbose : bool, optional
        Print summary statistics. If None, ``self.verbose > 0`` is used.

    Returns
    -------
    E : torch.Tensor
        E-values, shape (N,).

    Raises
    ------
    ValueError
        If no amplitude data is available.

    Examples
    --------
    Compute E-values with automatic anisotropy correction::

        data = ReflectionData().load_mtz('data.mtz')
        E = data.compute_e_values(n_shells=30)
        print(f"E-values: mean={E.mean():.3f}, std={E.std():.3f}")

    Compute E-values without anisotropy correction::

        E = data.compute_e_values(apply_anisotropy=False)
    """
    from torchref.base import F_squared_to_E_values

    if self.F is None:
        raise ValueError("No amplitude data loaded")

    chatty = (self.verbose > 0) if verbose is None else verbose
    d_hi = self.get_max_res() if d_min is None else d_min
    d_lo = self.get_min_res() if d_max is None else d_max

    if not apply_anisotropy:
        amplitudes = self.F
    else:
        if fit_anisotropy:
            self.fit_anisotropy(
                n_shells=n_shells, d_min=d_hi, d_max=d_lo, verbose=chatty
            )
        amplitudes = self.apply_anisotropy_correction()[0]
    f_sq = amplitudes**2

    s_vecs = self.get_scattering_vectors()
    E, E_squared, shell_idx = F_squared_to_E_values(
        f_sq, s_vecs, n_shells=n_shells, d_min=d_hi, d_max=d_lo
    )

    self.E = E
    self.E_squared = E_squared
    self.radial_shell_indices = shell_idx

    if chatty:
        print("E-value statistics:")
        print(f" E range: [{E.min():.3f}, {E.max():.3f}]")
        print(f" E mean: {E.mean():.3f}, std: {E.std():.3f}")
        print(f" E² mean: {E_squared.mean():.3f} (should be ~1.0)")
    return E
[docs]
def setup_scale(self, scale: Optional[float] = None) -> "ReflectionData":
    """
    Set the overall scale factor, parametrized in log space.

    The value is stored as a 0-d ``dtypes.float`` tensor on
    ``self.log_scale`` (on ``self.device``), which parameters() exposes for
    optimization.

    Parameters
    ----------
    scale : float, optional
        Positive, finite scale factor; ``log_scale`` is set to ``log(scale)``.
        If None, ``log_scale`` is set to 0.0, i.e. a scale of 1.0.

    Returns
    -------
    ReflectionData
        ``self``, to allow chaining.

    Raises
    ------
    ValueError
        If ``scale`` is not a positive finite number (its log would be
        -inf or NaN).
    """
    if scale is None:
        self.log_scale = torch.tensor(
            0.0, device=self.device, requires_grad=False, dtype=dtypes.float
        )
        return self

    scale = float(scale)
    # The chained comparison is False for NaN, so NaN is rejected as well.
    if not 0.0 < scale < float("inf"):
        raise ValueError(f"scale must be a positive finite number, got {scale}")
    self.log_scale = torch.log(
        torch.tensor(scale, device=self.device, requires_grad=False, dtype=dtypes.float)
    )
    return self
[docs]
def get_corrected_data(self) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Get anisotropy-corrected, scaled amplitudes and uncertainties.

    Runs ``self.apply_anisotropy_correction()`` (which uses ``self.U_aniso``)
    and multiplies the result by ``exp(self.log_scale)``.

    Returns
    -------
    Tuple[torch.Tensor, Optional[torch.Tensor]]
        Scaled F and F_sigma. F_sigma is None when no sigma values are
        available.

    Raises
    ------
    ValueError
        If setup_scale() or setup_anisotropy() has not been called.
    """
    if getattr(self, "log_scale", None) is None:
        raise ValueError("Scale not set up. Call setup_scale() first.")
    if getattr(self, "U_aniso", None) is None:
        raise ValueError("Anisotropy not set up. Call setup_anisotropy() first.")

    F_corrected, F_sigma_corrected = self.apply_anisotropy_correction()
    scale_factor = torch.exp(self.log_scale)

    F_scaled = F_corrected * scale_factor
    # apply_anisotropy_correction() returns None for sigma when F_sigma is
    # absent; propagate it instead of failing on ``None * tensor``.
    F_sigma_scaled = (
        F_sigma_corrected * scale_factor if F_sigma_corrected is not None else None
    )
    return F_scaled, F_sigma_scaled
[docs]
def parameters(self) -> List[torch.Tensor]:
    """
    Get the list of learnable tensors for optimization.

    Returns
    -------
    list of torch.Tensor
        ``self.log_scale`` and then ``self.U_aniso``, each included only if
        it has been set and is not None. The list is empty when neither
        setup_scale() nor setup_anisotropy() (nor fit_anisotropy()) has run.
    """
    # getattr: tolerate attributes that were never created, matching the
    # hasattr-style guards used by get_corrected_data().
    candidates = (getattr(self, "log_scale", None), getattr(self, "U_aniso", None))
    return [tensor for tensor in candidates if tensor is not None]