# Source code for torchref.maps.difference_map

"""
Isomorphous difference map from two datasets.

Computes a difference Fourier map using DF = F_data - F_reference with
phases from a model, after scaling both datasets to a common reference.
"""

from __future__ import annotations

from typing import Optional, Tuple

import torch

from torchref.base.reciprocal.grid_operations import place_on_grid
from torchref.config import get_default_device
from torchref.io.datasets.collection import DatasetCollection
from torchref.maps.map import Map
from torchref.symmetry.reciprocal_symmetry import expand_hkl
from torchref.utils.device_resolution import resolve_device


class DifferenceMap(Map):
    """Isomorphous difference Fourier map computed from two datasets.

    On construction, both datasets are registered with a
    ``DatasetCollection``, with the reference dataset as the scaling
    reference, and then scaled. :meth:`calculate` builds the Fourier
    coefficients ``DF * exp(i * phi_calc)``. Here ``DF = F_data - F_reference``
    and ``phi_calc`` is the phase of the model's structure factors.

    Parameters
    ----------
    data : ReflectionData
        Reflection data for the perturbed state (e.g. light, derivative).
    data_reference : ReflectionData
        Reflection data for the reference state (e.g. dark, native).
    model : ModelFT
        Model used to compute the phases.
    gridsize : tuple of int, optional
        Grid dimensions (nx, ny, nz). If None, determined automatically.
    device : torch.device, optional
        Requested device. It is passed to ``resolve_device`` together with
        ``data``, ``data_reference`` and ``model``. The resolved device is
        used for the ``DatasetCollection`` and forwarded to ``Map``.
    """

    def __init__(
        self,
        data,
        data_reference,
        model,
        gridsize=None,
        device: Optional[torch.device] = None,
    ):
        """Scale ``data`` against ``data_reference`` and initialise the map.

        The device is resolved before the ``DatasetCollection`` and the base
        ``Map`` are built. Both consume tensors taken from ``data.hkl`` and
        ``model``, and would otherwise inherit whichever device those tensors
        happen to be on.
        """
        target_device = resolve_device(
            data, data_reference, model, device=device
        )

        self.data_reference = data_reference
        self.data_perturbed = data

        # The reference dataset anchors the scaling; the perturbed dataset is
        # added second and scaled onto it. calculate() later reads the scaled
        # amplitudes back through get_corrected_data().
        collection = DatasetCollection(verbose=0, device=str(target_device))
        self._collection = collection
        collection.add_dataset(
            "reference", data_reference, set_as_reference=True
        )
        collection.add_dataset("perturbed", data)
        collection.scale()

        # The base Map takes cell, space group and hkl from the reference
        # dataset. map_type is only a placeholder: this class overrides
        # calculate().
        super().__init__(
            data=data_reference,
            model=model,
            gridsize=gridsize,
            map_type="Fcalc",
            device=target_device,
        )

    def calculate(self) -> torch.Tensor:
        """Evaluate the isomorphous difference map on a real-space grid.

        Returns
        -------
        torch.Tensor
            3D real-space difference map. It is also stored on
            ``self._map``.
        """
        reference = self.data_reference
        perturbed = self.data_perturbed

        # Scaled amplitudes (scale + anisotropy fitted by scale()); the
        # second element of each returned pair is not used here.
        amp_ref_all = reference.get_corrected_data()[0]
        amp_pert_all = perturbed.get_corrected_data()[0]

        # Keep only reflections that are valid in *both* datasets.
        # NOTE(review): the mask, hkl and amplitudes of the two datasets are
        # combined by position, so this assumes both share the reference's
        # reflection ordering. Confirm that DatasetCollection guarantees it.
        shared = reference.masks() & perturbed.masks()
        hkl_asu = reference.hkl[shared]
        delta_asu = amp_pert_all[shared] - amp_ref_all[shared]

        # Expand to P1 by hand, without Friedel mates; expand_to_p1() would
        # reset the scaling applied above.
        space_group = reference.spacegroup or "P1"
        hkl_p1, parent_idx, _ = expand_hkl(
            hkl_asu,
            space_group,
            include_friedel=False,
            remove_absences=True,
            device=hkl_asu.device,
        )

        # Amplitudes are symmetry-invariant, so each P1 reflection takes the
        # difference of its ASU parent. Phases come from the model, evaluated
        # directly at the P1 indices.
        delta_p1 = delta_asu[parent_idx]
        phases = torch.angle(self.model.get_structure_factor(hkl_p1))
        coeffs = delta_p1 * torch.exp(1j * phases)

        grid_shape = (
            self.gridsize
            if self.gridsize is not None
            else self._determine_gridsize()
        )

        # Scatter the coefficients onto the grid with Hermitian symmetry
        # enforced, then FFT over the three spatial axes (1/N normalisation)
        # and keep the real part.
        reciprocal_grid = place_on_grid(
            hkl_p1, coeffs, grid_shape, enforce_hermitian=True
        )
        self._map = torch.fft.fftn(
            reciprocal_grid, dim=(0, 1, 2), norm="forward"
        ).real
        return self._map