from typing import Optional, Tuple
import gemmi
import numpy as np
import torch
from torchref.base.fourier import fft, ifft
from torchref.model.sf_fft import SfFFT
from torchref.model.model import Model
from torchref.symmetry import SpaceGroup
from torchref.symmetry.map_symmetry import MapSymmetry
from torchref.config import dtypes, get_default_device, get_float_dtype
from torchref.utils.caching import CachedForwardMixin
class ModelFT(CachedForwardMixin, Model):
    """
    Model subclass for Fourier Transform-based electron density and structure factor calculations.

    ModelFT extends the base Model class with capabilities for computing electron
    density maps in real space and structure factors via FFT. Uses ITC92
    parametrization for electron density calculations.

    Parameters
    ----------
    max_res : float, optional
        Maximum resolution for grid spacing in Angstroms. Default is 1.0.
    radius_angstrom : float, optional
        Radius in Angstroms for density calculation around each atom. Default is 4.0.
    gridsize : tuple of int, optional
        Explicit grid size (nx, ny, nz). If None, computed from cell and max_res.
    wavelength : float or None, optional
        X-ray wavelength in Angstroms for anomalous scattering correction.
        Default is 1.0 (standard synchrotron, ~12.4 keV). Set to None to
        disable anomalous corrections entirely.
    anomalous_threshold : float, optional
        Significance threshold for anomalous scattering in electrons.
        Atoms with |f'| > threshold or |f''| > threshold will have
        anomalous corrections applied. Default is 0.5.
    *args
        Additional positional arguments passed to parent Model class.
    **kwargs
        Additional keyword arguments passed to parent Model class.

    Attributes
    ----------
    max_res : float
        Maximum resolution for grid spacing.
    radius_angstrom : float
        Radius for density calculation.
    wavelength : float or None
        X-ray wavelength for anomalous scattering corrections.
    anomalous_threshold : float
        Threshold for significant anomalous scattering (electrons).
    gridsize : torch.Tensor or None
        Grid dimensions (nx, ny, nz); None until the FFT submodule exists.
    real_space_grid : torch.Tensor or None
        Real-space coordinate grid with shape (nx, ny, nz, 3).
    map : torch.Tensor
        Most recently built electron density map (set by ``build_initial_map``).
    parametrization : dict
        ITC92 parametrization dictionary {element: (A, B, C)}.
    map_symmetry : MapSymmetry or None
        Symmetry operator for map calculations.

    Examples
    --------
    Empty initialization for state_dict loading::

        model = ModelFT()
        model.load_state_dict(torch.load('model.pt'))

    File-based initialization::

        model = ModelFT(max_res=1.5)
        model.load_pdb('structure.pdb')
    """

    def __init__(
        self,
        *args,
        max_res=1.0,
        radius_angstrom=4.0,
        gridsize: Optional[Tuple[int, int, int]] = None,
        wavelength: Optional[float] = 1.0,
        anomalous_threshold: float = 0.5,
        **kwargs,
    ):
        """
        Initialize an empty ModelFT shell.

        Creates a model shell ready for file loading via load_pdb()/load_cif()
        or state restoration via load_state_dict().

        Parameters
        ----------
        max_res : float, optional
            Maximum resolution for grid spacing in Angstroms. Default is 1.0.
        radius_angstrom : float, optional
            Radius in Angstroms for density calculation. Default is 4.0.
        gridsize : tuple of int, optional
            Explicit grid size tuple (nx, ny, nz). If None, computed automatically.
        wavelength : float or None, optional
            X-ray wavelength in Angstroms for anomalous scattering correction.
            Default is 1.0 (standard synchrotron, ~12.4 keV). Set to None to
            disable anomalous corrections entirely.
        anomalous_threshold : float, optional
            Significance threshold for anomalous scattering in electrons.
            Atoms with |f'| > threshold or |f''| > threshold will have
            anomalous corrections applied. Default is 0.5.
        *args
            Passed to parent Model class.
        **kwargs
            Passed to parent Model class.
        """
        super().__init__(*args, **kwargs)
        # FT-specific configuration
        self.max_res = max_res
        self.radius_angstrom = radius_angstrom
        self._explicit_gridsize = gridsize
        # Anomalous scattering configuration
        self.wavelength = wavelength
        self.anomalous_threshold = anomalous_threshold
        # Holds (mask, f_prime, f_double_prime, has_anomalous, anomalous_indices)
        self._anomalous_cache = None
        # Cache key (elements + wavelength + threshold) used for invalidation
        self._anomalous_elements_hash = None
        self._fft = None

    @property
    def cell(self):
        """Unit cell object with parameters [a, b, c, alpha, beta, gamma]."""
        return self._cell

    @cell.setter
    def cell(self, value):
        """
        Set the unit cell and initialize FFT if spacegroup is also set.

        Parameters
        ----------
        value : Cell
            The unit cell object to set.
        """
        self._cell = value
        self._maybe_initialize_fft()

    @property
    def spacegroup(self):
        """Space group object."""
        return self._spacegroup

    @spacegroup.setter
    def spacegroup(self, value):
        """
        Set the space group and initialize FFT if cell is also set.

        Parameters
        ----------
        value : SpaceGroup, gemmi.SpaceGroup, str, int, or None
            The space group to set; wrapped in ``SpaceGroup`` unless None.
        """
        if value is not None:
            self._spacegroup = SpaceGroup(
                value, dtype=self.dtype_float, device=self.device
            )
        else:
            self._spacegroup = None
        self._maybe_initialize_fft()

    def _maybe_initialize_fft(self):
        """
        Initialize the SfFFT module if both cell and spacegroup are set.

        Called by the ``cell`` and ``spacegroup`` setters and lazily by the
        ``fft`` property. Whenever both are set, a fresh SfFFT replaces any
        previous instance (including its grid).
        """
        # getattr: one setter may run before the other backing attribute has
        # been assigned at all.
        cell = getattr(self, "_cell", None)
        spacegroup = getattr(self, "_spacegroup", None)
        if cell is None or spacegroup is None:
            return
        self._fft = SfFFT(
            cell=cell,
            spacegroup=spacegroup,
            device=self.device,
            max_res=self.max_res,
            radius_angstrom=self.radius_angstrom,
        )

    def load_pdb(self, filename):
        """
        Load a PDB file and initialize the model with FT-specific setup.

        Builds the ITC92 parametrization and the real-space grid, matching
        what ``load_cif`` does.

        Parameters
        ----------
        filename : str
            Path to the PDB file.

        Returns
        -------
        ModelFT
            Self, for method chaining.
        """
        super().load_pdb(filename)
        self._build_parametrization()
        # FFT is initialized via the cell/spacegroup setters in the parent loader
        self.setup_grid()
        return self

    def select(self, selection):
        """
        Return a sub-model produced by the parent ``select`` with FT setup applied.

        The returned selection gets its ITC92 parametrization built and its
        real-space grid set up.

        Parameters
        ----------
        selection
            Selection specification, forwarded unchanged to ``Model.select``.

        Returns
        -------
        ModelFT
            The selected sub-model.
        """
        selection = super().select(selection)
        selection._build_parametrization()
        # FFT is initialized via the cell/spacegroup setters in the parent select()
        selection.setup_grid()
        return selection

    def load_cif(self, filename):
        """
        Load a CIF file and initialize the model with FT-specific setup.

        Parameters
        ----------
        filename : str
            Path to the CIF/mmCIF file.

        Returns
        -------
        ModelFT
            Self, for method chaining.
        """
        super().load_cif(filename)
        self._build_parametrization()
        # FFT is initialized via the cell/spacegroup setters in the parent loader
        self.setup_grid()
        return self

    def setup_gridsize(self, max_res=None):
        """
        Compute grid dimensions from the unit cell.

        Uses ``self.cell.compute_grid_size(self.max_res)``.

        Parameters
        ----------
        max_res : float, optional
            Maximum resolution in Angstroms. If given, it is stored on the
            model (and on the FFT submodule when one exists). If None, uses
            self.max_res.

        Returns
        -------
        torch.Tensor
            Grid dimensions (nx, ny, nz) as an integer tensor on ``self.device``.
        """
        if max_res is not None:
            self.max_res = max_res
            fft_module = getattr(self, "_fft", None)
            if fft_module is not None:
                fft_module.max_res = max_res
        if self.verbose > 1:
            print(f"Defining grid size for max_res={self.max_res} Å")
        gridsize = self.cell.compute_grid_size(self.max_res)
        return torch.tensor(gridsize, dtype=dtypes.int, device=self.device)

    def _build_parametrization(self):
        """
        Build ITC92 parametrization for all atoms in the model.

        Delegates to the parent Model implementation; kept for API compatibility.
        """
        return super()._build_parametrization()

    # =========================================================================
    # Backward-compatible properties for scattering parameters
    # =========================================================================
    @property
    def A(self) -> torch.Tensor:
        """
        ITC92 A parameters (amplitudes) for all atoms.

        Returns
        -------
        torch.Tensor
            A parameters with shape (n_atoms, 5).
        """
        self._build_parametrization()
        return self._A

    @property
    def B(self) -> torch.Tensor:
        """
        ITC92 B parameters (widths) for all atoms.

        Returns
        -------
        torch.Tensor
            B parameters with shape (n_atoms, 5).
        """
        self._build_parametrization()
        return self._B

    # =========================================================================
    # Backward-compatible properties for FFT grid attributes
    # =========================================================================
    def _require_fft(self, attribute: str):
        """Return the SfFFT submodule (creating it lazily) or raise AttributeError."""
        fft_module = self.fft
        if fft_module is None:
            raise AttributeError(
                f"Cannot set '{attribute}': cell and spacegroup must be set first."
            )
        return fft_module

    @property
    def gridsize(self) -> Optional[torch.Tensor]:
        """Grid dimensions (nx, ny, nz), or None if no FFT submodule exists."""
        fft_module = getattr(self, "_fft", None)
        return None if fft_module is None else fft_module.gridsize

    @gridsize.setter
    def gridsize(self, value):
        """Set grid size (for backward compatibility)."""
        self._require_fft("gridsize").gridsize = value

    @property
    def real_space_grid(self) -> Optional[torch.Tensor]:
        """Real-space coordinate grid (nx, ny, nz, 3), or None if no FFT submodule exists."""
        fft_module = getattr(self, "_fft", None)
        return None if fft_module is None else fft_module.real_space_grid

    @real_space_grid.setter
    def real_space_grid(self, value):
        """Set real space grid (for backward compatibility)."""
        self._require_fft("real_space_grid").real_space_grid = value

    @property
    def voxel_size(self) -> Optional[torch.Tensor]:
        """Voxel dimensions, or None if no FFT submodule exists."""
        fft_module = getattr(self, "_fft", None)
        return None if fft_module is None else fft_module.voxel_size

    @voxel_size.setter
    def voxel_size(self, value):
        """Set voxel size (for backward compatibility)."""
        self._require_fft("voxel_size").voxel_size = value

    @property
    def map_symmetry(self) -> Optional[MapSymmetry]:
        """Symmetry operator for map calculations, or None if no FFT submodule exists."""
        fft_module = getattr(self, "_fft", None)
        return None if fft_module is None else fft_module.map_symmetry

    @map_symmetry.setter
    def map_symmetry(self, value):
        """Set map symmetry (for backward compatibility)."""
        self._require_fft("map_symmetry").map_symmetry = value

    def get_iso(self):
        """
        Get isotropic atoms with their ITC92 parameters.

        Returns
        -------
        xyz : torch.Tensor
            Atomic coordinates with shape (n_atoms, 3).
        adp : torch.Tensor
            Atomic displacement parameters (isotropic) with shape (n_atoms,).
        occupancy : torch.Tensor
            Occupancies with shape (n_atoms,).
        A : torch.Tensor
            ITC92 A parameters (amplitudes) with shape (n_atoms, 5).
        B : torch.Tensor
            ITC92 B parameters (widths) with shape (n_atoms, 5).
        """
        xyz, adp, occupancy = super().get_iso()
        A, B = self.get_scattering_params_iso()
        return xyz, adp, occupancy, A, B

    def get_aniso(self):
        """
        Get anisotropic atoms with their ITC92 parameters.

        Returns
        -------
        xyz : torch.Tensor
            Atomic coordinates with shape (n_atoms, 3).
        u : torch.Tensor
            Anisotropic U parameters with shape (n_atoms, 6).
        occupancy : torch.Tensor
            Occupancies with shape (n_atoms,).
        A : torch.Tensor
            ITC92 A parameters (amplitudes) with shape (n_atoms, 5).
        B : torch.Tensor
            ITC92 B parameters (widths) with shape (n_atoms, 5).
        """
        xyz, u, occupancy = super().get_aniso()
        A, B = self.get_scattering_params_aniso()
        return xyz, u, occupancy, A, B

    def setup_grid(self, max_res=None, gridsize=None):
        """
        Set up the real-space grid for electron density calculation.

        Delegates to ``SfFFT.setup_grid()``; the FFT submodule is created
        lazily from the stored cell and spacegroup if needed.

        Parameters
        ----------
        max_res : float, optional
            Maximum resolution for grid spacing in Angstroms.
            If None, uses self.max_res.
        gridsize : tuple of int or torch.Tensor, optional
            Explicit grid size (nx, ny, nz). If None, falls back to the
            ``gridsize`` given at construction, and otherwise lets the FFT
            module compute one.

        Raises
        ------
        RuntimeError
            If cell or spacegroup is not set, so no FFT module can be created.
        """
        fft_module = self.fft
        if fft_module is None:
            raise RuntimeError(
                "Cannot set up the grid: cell and spacegroup must both be set "
                "before calling setup_grid()."
            )
        if max_res is not None:
            self.max_res = max_res
            fft_module.max_res = max_res
        if self.verbose > 1:
            print(f"Setting up grids with max_res={self.max_res} Å")
        # Explicit None check: `or` would call bool() on a tensor argument.
        gridsize_to_use = gridsize if gridsize is not None else self._explicit_gridsize
        fft_module.setup_grid(
            gridsize=gridsize_to_use,
            max_res=self.max_res,
        )
        if self.verbose > 2:
            print(f"Grid shape: {fft_module.real_space_grid.shape[:-1]}")
            print(f"Voxel size: {fft_module.voxel_size}")

    def get_radius(self, min_radius_Angstrom: float = 4.0):
        """
        Get the radius in voxels used for density calculation around each atom.

        Parameters
        ----------
        min_radius_Angstrom : float, optional
            Minimum radius in Angstroms. Default is 4.0.

        Returns
        -------
        int
            Radius in voxels.
        """
        if not hasattr(self, "real_space_grid") or self.real_space_grid is None:
            self.setup_grid()
        # NOTE(review): the [1,1,1]-[0,0,0] difference equals per-axis voxel
        # lengths only for orthogonal cells; confirm for oblique cells.
        voxel_size = self.real_space_grid[1, 1, 1] - self.real_space_grid[0, 0, 0]
        min_radius = (
            torch.ceil(min_radius_Angstrom / torch.min(voxel_size))
            .to(dtypes.int)
            .item()
        )
        if self.verbose > 1:
            print(
                f"Calculated radius for density calculation: {min_radius} voxels (voxel size: {voxel_size}), this corresponds to at least {min_radius_Angstrom} Å"
            )
        return min_radius

    def build_complete_map(self, radius=None, apply_symmetry=True):
        """
        Build the electron density map from all atoms and store it in ``self.map``.

        Parameters
        ----------
        radius : int, optional
            Accepted for backward compatibility and currently ignored; the
            per-atom cutoff is configured through ``radius_angstrom``.
        apply_symmetry : bool, optional
            If True, apply crystallographic symmetry to the map. Default is True.

        Returns
        -------
        torch.Tensor
            Electron density map with symmetry applied if requested.
        """
        self.map = self.build_initial_map(apply_symmetry=apply_symmetry)
        if self.verbose > 2:
            print(
                f"Density map built. Sum: {self.map.sum():.2f}, Max: {self.map.max():.4f}"
            )
        return self.map

    def build_initial_map(self, apply_symmetry=True):
        """
        Build electron density map from atomic parameters.

        Delegates to ``SfFFT.build_density_map()`` using the model's stored
        parameters, setting up the grid first if it does not exist yet.

        Parameters
        ----------
        apply_symmetry : bool, optional
            If True, apply crystallographic symmetry to the map. Default is True.

        Returns
        -------
        torch.Tensor
            Electron density map with shape (nx, ny, nz).
        """
        fft_module = self.fft
        if fft_module is None or fft_module.real_space_grid is None:
            self.setup_grid()
            fft_module = self._fft
        if self.verbose > 2:
            print(
                f"Building density map with radius={self.radius_angstrom} angstrom..."
            )
        xyz_iso, adp_iso, occ_iso, A_iso, B_iso = self.get_iso()
        if self.verbose > 3:
            for label, tensor in (
                ("A_iso", A_iso),
                ("B_iso", B_iso),
                ("xyz_iso", xyz_iso),
                ("adp_iso", adp_iso),
                ("occ_iso", occ_iso),
            ):
                assert torch.all(
                    torch.isfinite(tensor)
                ), f"Non-finite values found in {label} during map building."
        xyz_aniso, u_aniso, occ_aniso, A_aniso, B_aniso = self.get_aniso()
        # Anisotropic inputs are passed as None when there are no such atoms.
        has_aniso = len(xyz_aniso) > 0
        self.map = fft_module.build_density_map(
            xyz_iso=xyz_iso,
            adp_iso=adp_iso,
            occ_iso=occ_iso,
            A_iso=A_iso,
            B_iso=B_iso,
            xyz_aniso=xyz_aniso if has_aniso else None,
            u_aniso=u_aniso if has_aniso else None,
            occ_aniso=occ_aniso if has_aniso else None,
            A_aniso=A_aniso if has_aniso else None,
            B_aniso=B_aniso if has_aniso else None,
            apply_symmetry=apply_symmetry,
        )
        if self.verbose > 3:
            assert torch.all(
                torch.isfinite(self.map)
            ), "Non-finite values found in map."
        return self.map

    def save_map(self, filename):
        """
        Save the electron density map to a CCP4 format file.

        The map is written as float32 with the model's cell and space group P1.

        Parameters
        ----------
        filename : str
            Output filename for the map.

        Raises
        ------
        ValueError
            If no map has been computed yet.
        """
        density = getattr(self, "map", None)
        if density is None:
            raise ValueError("No map to save. Call build_complete_map() first.")
        np_map = density.detach().cpu().numpy().astype(np.float32)
        cell = self.cell.tolist()
        if self.verbose > 0:
            print(f"Saving map to {filename}")
            print(f"  Map shape: {density.shape}")
            print(f"  Map sum: {density.sum():.2f}")
            print(f"  Map range: [{density.min():.4f}, {density.max():.4f}]")
        map_ccp = gemmi.Ccp4Map()
        map_ccp.grid = gemmi.FloatGrid(
            np_map, gemmi.UnitCell(*cell), SpaceGroup("P1")._gemmi
        )
        map_ccp.setup(0.0)
        map_ccp.update_ccp4_header()
        map_ccp.write_ccp4_map(filename)
        if self.verbose > 0:
            print("Map saved successfully")

    def get_map_statistics(self):
        """
        Get statistics about the current density map.

        Returns
        -------
        dict or None
            Shape, sum, mean, std, min, max and positive/negative voxel
            counts, or None if no map has been built.
        """
        density = getattr(self, "map", None)
        if density is None:
            return None
        return {
            "shape": density.shape,
            "sum": float(density.sum()),
            "mean": float(density.mean()),
            "std": float(density.std()),
            "min": float(density.min()),
            "max": float(density.max()),
            "n_positive": int((density > 0).sum()),
            "n_negative": int((density < 0).sum()),
        }

    def rebuild_map(self, radius=None):
        """
        Rebuild the density map from scratch.

        Convenience wrapper around ``build_complete_map``.

        Parameters
        ----------
        radius : int, optional
            Forwarded to ``build_complete_map`` (which currently ignores it).

        Returns
        -------
        torch.Tensor
            Rebuilt electron density map.
        """
        if self.verbose > 1:
            print("Rebuilding density map from scratch...")
        return self.build_complete_map(radius=radius)

    def update_pdb(self):
        """Update PDB with current atomic parameters (delegates to Model)."""
        return super().update_pdb()

    def reset_cache(self):
        """Reset SF cache, anomalous cache, and all wrapper forward caches."""
        self.reset_forward_cache()
        # Drop the anomalous scattering cache; it is recomputed on next use
        # and would otherwise hold tensors on the previous device.
        self._anomalous_cache = None
        self._anomalous_elements_hash = None
        for module in self.children():
            if hasattr(module, "reset_forward_cache"):
                module.reset_forward_cache()

    def invalidate_cache(self):
        """Alias for ``reset_cache()``."""
        self.reset_cache()

    # =========================================================================
    # Anomalous Scattering Correction Methods
    # =========================================================================
    def _get_anomalous_cache(
        self,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool, Optional[torch.Tensor]]:
        """
        Lazily compute and cache anomalous correction data.

        Recomputed whenever the element list, ``wavelength`` or
        ``anomalous_threshold`` changes.

        Returns
        -------
        mask : torch.Tensor
            Boolean mask of shape (n_atoms,) - True for atoms needing correction.
        f_prime : torch.Tensor
            f' values as returned by ``get_anomalous_corrections_by_indices``.
        f_double_prime : torch.Tensor
            f'' values as returned by ``get_anomalous_corrections_by_indices``.
        has_anomalous : bool
            Whether any atom is selected by ``mask``.
        anomalous_indices : torch.Tensor or None
            Integer indices of the selected atoms, or None if there are none.
        """
        from torchref.base.scattering.anomalous_table import (
            get_significant_elements,
            get_anomalous_corrections_by_indices,
        )

        element_list = self.pdb["element"].tolist()
        # Key on everything the corrections depend on, so changing the
        # wavelength or threshold does not return stale f'/f'' values.
        cache_key = hash(
            (tuple(element_list), self.wavelength, self.anomalous_threshold)
        )
        if (
            self._anomalous_cache is None
            or self._anomalous_elements_hash != cache_key
        ):
            unique_elements = list(set(element_list))
            significant = get_significant_elements(
                unique_elements, self.wavelength, self.anomalous_threshold
            )
            if self.verbose > 1 and significant:
                print(
                    f"Anomalous scatterers at {self.wavelength:.4f} Å: "
                    f"{list(significant.keys())}"
                )
            mask, f_prime, f_double_prime = get_anomalous_corrections_by_indices(
                element_list, significant, self.device, self.dtype_float
            )
            # Pre-compute integer indices to avoid boolean indexing GPU sync
            has_anomalous = bool(mask.any().item())
            anomalous_indices = mask.nonzero(as_tuple=True)[0] if has_anomalous else None
            self._anomalous_cache = (
                mask,
                f_prime,
                f_double_prime,
                has_anomalous,
                anomalous_indices,
            )
            self._anomalous_elements_hash = cache_key
        return self._anomalous_cache

    def _apply_anomalous_correction(
        self,
        sf: torch.Tensor,
        hkl: torch.Tensor,
    ) -> torch.Tensor:
        """
        Apply anomalous scattering correction to structure factors.

        Adds ΔF(h) = Σ_significant (f' + i f'') × exp(2πi h·r) × occ, summed
        over the atoms selected by the anomalous cache.

        Parameters
        ----------
        sf : torch.Tensor
            Complex structure factors from FFT with shape (n_reflections,).
        hkl : torch.Tensor
            Miller indices with shape (n_reflections, 3).

        Returns
        -------
        torch.Tensor
            Corrected complex structure factors with shape (n_reflections,).
        """
        _, f_prime, f_double_prime, has_anomalous, anomalous_indices = (
            self._get_anomalous_cache()
        )
        if not has_anomalous:
            return sf  # No significant anomalous scatterers
        # NOTE(review): this sum uses only the model's own atom positions and
        # occupancies: no symmetry mates and no ADP/B-factor damping, while the
        # FFT term is computed with apply_symmetry=True. Confirm intended.
        # Integer indices avoid a boolean-indexing GPU sync.
        xyz_frac = self.xyz_fractional()[anomalous_indices]  # (n_significant, 3)
        occ = self.occupancy()[anomalous_indices]  # (n_significant,)
        h_dot_r = torch.matmul(
            hkl.to(dtype=self.dtype_float, device=xyz_frac.device), xyz_frac.T
        )  # (n_refl, n_significant)
        phase = 2 * torch.pi * h_dot_r
        cos_phase = torch.cos(phase)
        sin_phase = torch.sin(phase)
        f_prime_occ = f_prime * occ
        f_double_prime_occ = f_double_prime * occ
        # (f' + i f'')(cos φ + i sin φ):
        #   real = f' cos φ - f'' sin φ ; imag = f' sin φ + f'' cos φ
        delta_real = torch.sum(
            f_prime_occ * cos_phase - f_double_prime_occ * sin_phase, dim=-1
        )
        delta_imag = torch.sum(
            f_prime_occ * sin_phase + f_double_prime_occ * cos_phase, dim=-1
        )
        return sf + torch.complex(delta_real, delta_imag)

    def get_structure_factor(
        self, hkl: torch.Tensor, recalc=False, apply_anomalous: bool = True
    ) -> torch.Tensor:
        """
        Get structure factors for given hkl reflections.

        Uses ``CachedForwardMixin`` to cache the result and auto-invalidate
        when parameters change or a backward pass propagates through.

        Parameters
        ----------
        hkl : torch.Tensor
            Miller indices with shape (n_reflections, 3).
        recalc : bool, optional
            If True, forces recalculation bypassing the cache.
            Default is False.
        apply_anomalous : bool, optional
            If True and wavelength is set, apply anomalous scattering
            corrections (f' and f'') for heavy atoms. Default is True.

        Returns
        -------
        torch.Tensor
            Complex structure factors with shape (n_reflections,).

        Notes
        -----
        The complete scattering factor is:
            f(s, λ) = f₀(s) + f'(λ) + i·f''(λ)
        where f₀ is the normal (Thomson) scattering factor computed via FFT,
        and f'/f'' are the wavelength-dependent anomalous corrections.
        Anomalous corrections are only computed for atoms where
        |f'| > anomalous_threshold or |f''| > anomalous_threshold.
        """
        return self(hkl, recalc=recalc, apply_anomalous=apply_anomalous)

    @property
    def fft(self):
        """Access the SfFFT submodule, creating it lazily if cell and spacegroup are set."""
        if getattr(self, "_fft", None) is None:
            self._maybe_initialize_fft()
        return getattr(self, "_fft", None)

    def forward(
        self, hkl, apply_anomalous: bool = True
    ) -> torch.Tensor:
        """
        Compute structure factors for given hkl.

        This is called by the mixin's ``__call__`` which handles caching,
        backward-hook registration, and auto-invalidation.

        Parameters
        ----------
        hkl : torch.Tensor
            Miller indices with shape (n_reflections, 3).
        apply_anomalous : bool, optional
            If True and wavelength is set, apply anomalous scattering
            corrections (f' and f'') for heavy atoms. Default is True.

        Returns
        -------
        torch.Tensor
            Calculated complex structure factors with shape (n_reflections,).
        """
        sf, self.ed = self.fft.compute_structure_factors(
            hkl,
            *self.get_iso(),
            *self.get_aniso(),
            apply_symmetry=True,
        )
        if apply_anomalous and self.wavelength is not None:
            sf = self._apply_anomalous_correction(sf, hkl)
        if self.verbose > 2:
            assert torch.all(
                torch.isfinite(sf)
            ), "Non-finite values found while calculating fcalc."
        return sf

    @staticmethod
    def _as_gridsize_tuple(gridsize) -> Tuple[int, ...]:
        """Convert a gridsize tensor or sequence into a tuple of Python ints."""
        values = gridsize.tolist() if isinstance(gridsize, torch.Tensor) else gridsize
        return tuple(int(x) for x in values)

    def copy(self, detach: bool = True) -> "ModelFT":
        """
        Create a deep copy of the ModelFT.

        Copies the Model base class data, the FFT submodule (via its
        ``copy()``), the ITC92 parametrization and scalar attributes. The
        copy's grid is rebuilt with the original's current gridsize. Caches
        are reset on the copy.

        Parameters
        ----------
        detach : bool, optional
            If True, copied buffers are detached from the computation graph
            (default: True).

        Returns
        -------
        ModelFT
            A new ModelFT instance with copied data.

        Raises
        ------
        RuntimeError
            If this model has not been initialized.

        Examples
        --------
        ::

            model = ModelFT().load_pdb('structure.pdb')
            model_copy = model.copy()
            # model_copy is independent, changes won't affect model
        """
        if not self.initialized:
            raise RuntimeError("Cannot copy an uninitialized ModelFT. Load data first.")
        model_copy = ModelFT(
            dtype_float=self.dtype_float,
            verbose=self.verbose,
            device=self.device,
            strip_H=self.strip_H,
            max_res=self.max_res,
            radius_angstrom=self.radius_angstrom,
            gridsize=self._explicit_gridsize,
            wavelength=self.wavelength,
            anomalous_threshold=self.anomalous_threshold,
        )
        model_copy.pdb = self.pdb.copy(deep=True)
        if self._spacegroup is not None:
            model_copy._spacegroup = self._spacegroup.copy()
        else:
            model_copy._spacegroup = None
        model_copy.initialized = True
        # Setting the cell also (re)creates the copy's FFT since the spacegroup is set
        if self.cell is not None:
            model_copy.cell = self.cell.clone()
        for buffer_name, buffer_value in self._buffers.items():
            if buffer_value is not None:
                if detach:
                    model_copy.register_buffer(buffer_name, buffer_value.clone().detach())
                else:
                    model_copy.register_buffer(buffer_name, buffer_value.clone())
        # Copy parameter-wrapper modules; FFT and spacegroup are handled separately
        skip_modules = {"_fft", "_spacegroup", "spacegroup", "_symmetry", "symmetry"}
        for module_name, module in self._modules.items():
            if module_name in skip_modules:
                continue
            if module is not None and hasattr(module, "copy"):
                setattr(model_copy, module_name, module.copy())
        if hasattr(self, "altloc_pairs") and self.altloc_pairs:
            model_copy.altloc_pairs = [
                tuple(tensor.clone() for tensor in group) for group in self.altloc_pairs
            ]
        else:
            model_copy.altloc_pairs = []
        if hasattr(self, "_parametrization") and self._parametrization is not None:
            import copy as copy_module

            model_copy._parametrization = copy_module.deepcopy(self._parametrization)
        if self._fft is not None:
            model_copy._fft = self._fft.copy()
            if self._fft.real_space_grid is not None:
                # Reuse the current grid dimensions so a gridsize chosen via
                # setup_grid(gridsize=...) is not silently recomputed.
                current_gridsize = self._fft.gridsize
                model_copy.setup_grid(
                    max_res=self.max_res,
                    gridsize=(
                        self._as_gridsize_tuple(current_gridsize)
                        if current_gridsize is not None
                        else None
                    ),
                )
        # Don't share cached structure factors with the original
        model_copy.reset_cache()
        if self.verbose > 0:
            print(f"✓ ModelFT copied successfully ({len(model_copy.pdb)} atoms)")
        return model_copy

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """
        Return a dictionary containing the complete state of the ModelFT.

        Extends parent Model.state_dict() with the FT-specific scalars
        max_res, radius_angstrom, wavelength and anomalous_threshold.

        Parameters
        ----------
        destination : dict, optional
            Optional dict to populate.
        prefix : str, optional
            Prefix for parameter names. Default is ''.
        keep_vars : bool, optional
            Whether to keep variables in computational graph. Default is False.

        Returns
        -------
        dict
            Complete state dictionary.
        """
        state = super().state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars
        )
        state[prefix + "max_res"] = self.max_res
        state[prefix + "radius_angstrom"] = self.radius_angstrom
        state[prefix + "wavelength"] = self.wavelength
        state[prefix + "anomalous_threshold"] = self.anomalous_threshold
        # Submodule state (e.g. "_fft." entries) comes from the parent call.
        # _parametrization, the forward cache and _anomalous_cache are not
        # saved; they are rebuilt on demand.
        return state

    @classmethod
    def create_from_state_dict(
        cls,
        state_dict: dict,
        device: Optional[torch.device] = None,
        verbose: int = 1,
        dtype_float: Optional[torch.dtype] = None,
    ) -> "ModelFT":
        """
        Create a fully initialized ModelFT from a state dictionary.

        This is the recommended way to restore a ModelFT from a saved state.
        Creates an instance with properly initialized submodules, then loads
        the state. The caller's ``state_dict`` is not modified.

        Parameters
        ----------
        state_dict : dict
            State dictionary from torch.save(model.state_dict(), ...).
        device : torch.device, optional
            Device to place tensors on. If None, ``get_default_device()`` is
            evaluated at call time.
        verbose : int, optional
            Verbosity level. Default is 1.
        dtype_float : torch.dtype, optional
            Float dtype used when the state dict does not record one. If
            None, falls back to ``dtypes.float``.

        Returns
        -------
        ModelFT
            Fully initialized instance with restored state.
        """
        from torchref.symmetry import Cell

        # Shallow copy: metadata keys are popped below and must not be
        # removed from the caller's dictionary.
        state_dict = dict(state_dict)
        # Resolve defaults at call time, not once at class-definition time.
        if device is None:
            device = get_default_device()
        fallback_dtype = dtype_float if dtype_float is not None else dtypes.float

        # ModelFT-specific metadata
        max_res = state_dict.pop("max_res", 1.0)
        radius_angstrom = state_dict.pop("radius_angstrom", 4.0)
        wavelength = state_dict.pop("wavelength", 1.0)
        anomalous_threshold = state_dict.pop("anomalous_threshold", 0.5)
        # Model metadata
        pdb = state_dict.pop("pdb", None)
        spacegroup_str = state_dict.pop("spacegroup", None)
        cell_tensor = state_dict.pop("cell", None)
        initialized = state_dict.pop("initialized", False)
        saved_dtype = state_dict.pop("dtype_float", fallback_dtype)
        state_dict.pop("device", None)  # Remove but don't use (use provided device)
        strip_H = state_dict.pop("strip_H", True)
        altloc_pairs = state_dict.pop("altloc_pairs", [])
        # Grid size from the FFT submodule ("_fft." prefix), or old-style key
        gridsize = state_dict.pop("_fft.gridsize", None)
        if gridsize is None:
            gridsize = state_dict.pop("gridsize", None)

        instance = cls(
            dtype_float=saved_dtype,
            verbose=verbose,
            device=device,
            strip_H=strip_H,
            max_res=max_res,
            radius_angstrom=radius_angstrom,
            wavelength=wavelength,
            anomalous_threshold=anomalous_threshold,
        )
        instance.pdb = pdb
        instance.initialized = initialized
        instance.altloc_pairs = altloc_pairs
        # Spacegroup setter wraps the value in SpaceGroup
        instance.spacegroup = spacegroup_str
        # Cell setter initializes the FFT since the spacegroup is already set
        if cell_tensor is not None:
            instance.cell = Cell(cell_tensor, dtype=saved_dtype, device=device)

        # Create parameter wrappers with correct shapes so load_state_dict can fill them
        if pdb is not None:
            from torchref.model.parameter_wrappers import (
                MixedTensor,
                OccupancyTensor,
                PositiveMixedTensor,
            )

            n_atoms = len(pdb)
            xyz_mask = state_dict.get("xyz.refinable_mask")
            b_mask = state_dict.get("b.refinable_mask")
            u_mask = state_dict.get("u.refinable_mask")
            instance.xyz = MixedTensor(
                torch.tensor(pdb[["x", "y", "z"]].values, dtype=saved_dtype),
                refinable_mask=xyz_mask,
                name="xyz",
            )
            instance.b = PositiveMixedTensor(
                torch.tensor(pdb["tempfactor"].values, dtype=saved_dtype),
                refinable_mask=b_mask,
                name="b_factor",
            )
            instance.u = MixedTensor(
                torch.tensor(
                    pdb[["u11", "u22", "u33", "u12", "u13", "u23"]].values,
                    dtype=saved_dtype,
                ),
                refinable_mask=u_mask,
                name="aniso_U",
            )
            initial_occ = torch.tensor(pdb["occupancy"].values, dtype=saved_dtype)
            sharing_groups, altloc_groups, refinable_mask = (
                instance._create_occupancy_groups(pdb, initial_occ)
            )
            saved_occ_mask = state_dict.get("occupancy.refinable_mask")
            if saved_occ_mask is not None:
                if saved_occ_mask.device != sharing_groups.device:
                    saved_occ_mask = saved_occ_mask.to(sharing_groups.device)
                refinable_mask = saved_occ_mask[sharing_groups]
            instance.occupancy = OccupancyTensor(
                initial_values=initial_occ,
                sharing_groups=sharing_groups,
                altloc_groups=altloc_groups,
                refinable_mask=refinable_mask,
                dtype=saved_dtype,
                device=device,
                name="occupancy",
            )
            if "aniso_flag" not in instance._buffers or instance.aniso_flag is None:
                instance.register_buffer(
                    "aniso_flag",
                    torch.tensor(pdb["anisou_flag"].values, dtype=torch.bool),
                )
            for mask_name in ("xyz_mask", "b_mask", "u_mask", "occupancy_mask"):
                instance.register_buffer(
                    mask_name, torch.ones(n_atoms, dtype=torch.bool, device=device)
                )
            if "vdw_radii" in state_dict and state_dict["vdw_radii"] is not None:
                instance.register_buffer(
                    "vdw_radii", torch.zeros_like(state_dict["vdw_radii"], device=device)
                )
            # Scattering parameter buffers: new-style (_A, _B) or old-style (A, B)
            a_key = "_A" if "_A" in state_dict else "A" if "A" in state_dict else None
            b_key = "_B" if "_B" in state_dict else "B" if "B" in state_dict else None
            if a_key and state_dict[a_key] is not None:
                instance.register_buffer(
                    "_A", torch.zeros_like(state_dict[a_key], device=device)
                )
            if b_key and state_dict[b_key] is not None:
                instance.register_buffer(
                    "_B", torch.zeros_like(state_dict[b_key], device=device)
                )

        if gridsize is not None and cell_tensor is not None:
            instance.setup_grid(gridsize=cls._as_gridsize_tuple(gridsize))

        # Drop entries with an empty leading dimension (0-d tensors have no
        # leading dimension and are kept) and remap old A/B keys to _A/_B.
        key_aliases = {"A": "_A", "B": "_B"}
        filtered_state_dict = {}
        for key, value in state_dict.items():
            shape = getattr(value, "shape", None)
            if shape is not None and len(shape) > 0 and shape[0] == 0:
                continue
            filtered_state_dict[key_aliases.get(key, key)] = value
        instance.load_state_dict(filtered_state_dict, strict=False)
        instance.reset_cache()
        if verbose > 0:
            n_atoms = len(instance.pdb) if instance.pdb is not None else 0
            print(f"Created ModelFT from state_dict: {n_atoms} atoms")
        return instance