torchref.model.model_ft module
- class torchref.model.model_ft.ModelFT(*args, max_res=1.0, radius_angstrom=4.0, gridsize=None, wavelength=1.0, anomalous_threshold=0.5, **kwargs)[source]
Bases: CachedForwardMixin, Model
Model subclass for Fourier Transform-based electron density and structure factor calculations.
ModelFT extends the base Model class with capabilities for computing electron density maps in real space and structure factors via FFT. Uses ITC92 parametrization for electron density calculations.
- Parameters:
max_res (float, optional) – Maximum resolution for grid spacing in Angstroms. Default is 1.0.
radius_angstrom (float, optional) – Radius in Angstroms for density calculation around each atom. Default is 4.0.
gridsize (tuple of int, optional) – Explicit grid size (nx, ny, nz). If None, computed from cell and max_res.
wavelength (float or None, optional) – X-ray wavelength in Angstroms for anomalous scattering correction. Default is 1.0 (about 12.4 keV, a common synchrotron wavelength). Set to None to disable anomalous corrections entirely.
anomalous_threshold (float, optional) – Significance threshold for anomalous scattering in electrons. Atoms with |f'| > threshold or |f''| > threshold will have anomalous corrections applied. Default is 0.5.
*args – Additional positional arguments passed to parent Model class.
**kwargs – Additional keyword arguments passed to parent Model class.
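For reference, the default wavelength of 1.0 Å corresponds to about 12.4 keV through E = hc/λ, with hc ≈ 12.398 keV·Å. A minimal plain-Python sketch of that conversion; the helper names are hypothetical and not part of torchref:

```python
# hc expressed in keV * Angstrom (12398.42 eV * Angstrom)
HC_KEV_ANGSTROM = 12.398419843320026


def wavelength_to_kev(wavelength_angstrom: float) -> float:
    """Convert an X-ray wavelength in Angstroms to photon energy in keV."""
    if wavelength_angstrom <= 0:
        raise ValueError("wavelength must be positive")
    return HC_KEV_ANGSTROM / wavelength_angstrom


def kev_to_wavelength(energy_kev: float) -> float:
    """Convert a photon energy in keV to an X-ray wavelength in Angstroms."""
    if energy_kev <= 0:
        raise ValueError("energy must be positive")
    return HC_KEV_ANGSTROM / energy_kev


print(round(wavelength_to_kev(1.0), 2))  # 12.4 (the ModelFT default)
print(round(kev_to_wavelength(12.658), 4))  # near the Se K edge
```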
- gridsize
Grid dimensions (nx, ny, nz).
- real_space_grid
Real-space coordinate grid with shape (nx, ny, nz, 3).
- Type:
torch.Tensor
- map
Computed electron density map.
- Type:
torch.Tensor or None
- map_symmetry
Symmetry operator for map calculations.
Examples
Empty initialization for state_dict loading:
import torch
from torchref.model.model_ft import ModelFT
model = ModelFT()
model.load_state_dict(torch.load('model.pt'))
File-based initialization:
model = ModelFT(max_res=1.5)
model.load_pdb('structure.pdb')
- __init__(*args, max_res=1.0, radius_angstrom=4.0, gridsize=None, wavelength=1.0, anomalous_threshold=0.5, **kwargs)[source]
Initialize an empty ModelFT shell.
Creates a model shell ready for file loading via load_pdb()/load_cif() or state restoration via load_state_dict().
- Parameters:
max_res (float, optional) – Maximum resolution for grid spacing in Angstroms. Default is 1.0.
radius_angstrom (float, optional) – Radius in Angstroms for density calculation. Default is 4.0.
gridsize (tuple of int, optional) – Explicit grid size tuple (nx, ny, nz). If None, computed automatically.
wavelength (float or None, optional) – X-ray wavelength in Angstroms for anomalous scattering correction. Default is 1.0 (about 12.4 keV, a common synchrotron wavelength). Set to None to disable anomalous corrections entirely.
anomalous_threshold (float, optional) – Significance threshold for anomalous scattering in electrons. Atoms with |f'| > threshold or |f''| > threshold will have anomalous corrections applied. Default is 0.5.
*args – Passed to parent Model class.
**kwargs – Passed to parent Model class.
- property cell
Unit cell object with parameters [a, b, c, alpha, beta, gamma].
- property spacegroup
Space group object.
- select(selection)[source]
Return a new Model containing only atoms matching the Phenix-style selection.
Creates an independent copy of the model containing only the selected atoms. All tensor data (coordinates, ADPs, occupancies, etc.) and metadata are properly subsetted.
- Parameters:
selection (str) – Phenix-style selection string. Supports:
  - chain <id>: Select by chain (e.g., “chain A”)
  - resseq <num>: Select by residue number (e.g., “resseq 10”)
  - resseq <start>:<end>: Select residue range (e.g., “resseq 10:20”)
  - resname <name>: Select by residue name (e.g., “resname ALA”)
  - name <atom>: Select by atom name (e.g., “name CA”)
  - element <elem>: Select by element (e.g., “element C”)
  - altloc <id>: Select by alternate location (e.g., “altloc A”)
  - all: Select all atoms
  - not <selection>: Negate selection
  - <sel1> and <sel2>: Intersection
  - <sel1> or <sel2>: Union
  - Parentheses for grouping
- Returns:
New instance of the same class containing only selected atoms. If called on a subclass, returns an instance of that subclass.
- Return type:
Model (an instance of the calling class, e.g. ModelFT)
- Raises:
RuntimeError – If the model has not been initialized.
ValueError – If selection syntax is invalid or no atoms are selected.
Examples
model = Model().load_pdb('structure.pdb')
# Select chain A
chain_a = model.select("chain A")
# Select backbone atoms
backbone = model.select("name CA or name C or name N or name O")
# Select residues 10-50 of chain B
region = model.select("chain B and resseq 10:50")
# Select all except water
no_water = model.select("not resname HOH")
# Complex selection with parentheses
complex_sel = model.select("chain A and (resname ALA or resname GLY)")
Notes
This method preserves the class type, so subclasses will return instances of themselves, not the base Model class.
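The selection grammar listed above can be illustrated with a small recursive-descent evaluator over plain atom records. This is a sketch only: it assumes `not` binds tighter than `and`, which binds tighter than `or`, that values match exactly (case-sensitive), and that each atom is a dict with the keys shown. None of the names below (`SelectionParser`, `select_atoms`) are torchref API.

```python
import re
from typing import Callable, Dict, List, Optional

Atom = Dict[str, object]
Predicate = Callable[[Atom], bool]
# Keywords that take exactly one value token (field names are assumptions).
KEYWORDS = {"chain", "resseq", "resname", "name", "element", "altloc"}


def tokenize(selection: str) -> List[str]:
    # "chain A and (name CA)" -> ["chain", "A", "and", "(", "name", "CA", ")"]
    return re.findall(r"\(|\)|[^\s()]+", selection)


def _either(a: Predicate, b: Predicate) -> Predicate:
    return lambda atom: a(atom) or b(atom)


def _both(a: Predicate, b: Predicate) -> Predicate:
    return lambda atom: a(atom) and b(atom)


class SelectionParser:
    """Recursive descent with assumed precedence: not > and > or."""

    def __init__(self, selection: str) -> None:
        self.tokens, self.pos = tokenize(selection), 0

    def peek(self) -> Optional[str]:
        return self.tokens[self.pos] if self.pos < len(self.tokens) else None

    def take(self) -> str:
        token = self.peek()
        if token is None:
            raise ValueError("unexpected end of selection")
        self.pos += 1
        return token

    def parse_or(self) -> Predicate:
        left = self.parse_and()
        while self.peek() == "or":
            self.take()
            left = _either(left, self.parse_and())
        return left

    def parse_and(self) -> Predicate:
        left = self.parse_not()
        while self.peek() == "and":
            self.take()
            left = _both(left, self.parse_not())
        return left

    def parse_not(self) -> Predicate:
        if self.peek() == "not":
            self.take()
            inner = self.parse_not()
            return lambda atom: not inner(atom)
        return self.parse_primary()

    def parse_primary(self) -> Predicate:
        token = self.take()
        if token == "(":
            inner = self.parse_or()
            if self.take() != ")":
                raise ValueError("missing closing parenthesis")
            return inner
        if token == "all":
            return lambda atom: True
        if token not in KEYWORDS:
            raise ValueError(f"unknown keyword {token!r}")
        value = self.take()
        if token == "resseq" and ":" in value:  # inclusive range "10:20"
            start, end = (int(v) for v in value.split(":"))
            return lambda atom: start <= atom["resseq"] <= end
        if token == "resseq":
            number = int(value)
            return lambda atom: atom["resseq"] == number
        return lambda atom: atom[token] == value


def select_atoms(atoms: List[Atom], selection: str) -> List[Atom]:
    parser = SelectionParser(selection)
    predicate = parser.parse_or()
    if parser.peek() is not None:
        raise ValueError(f"unexpected token {parser.peek()!r}")
    chosen = [atom for atom in atoms if predicate(atom)]
    if not chosen:
        raise ValueError(f"no atoms selected by {selection!r}")
    return chosen


atoms = [
    {"chain": "A", "resseq": 10, "resname": "ALA", "name": "CA", "element": "C", "altloc": ""},
    {"chain": "A", "resseq": 11, "resname": "GLY", "name": "N", "element": "N", "altloc": ""},
    {"chain": "B", "resseq": 12, "resname": "HOH", "name": "O", "element": "O", "altloc": ""},
]
picked = select_atoms(atoms, "chain A and (resname ALA or resname GLY)")
print([atom["resseq"] for atom in picked])  # [10, 11]
```

Like select(), the sketch raises ValueError when nothing matches; unlike select(), it returns plain records rather than a new model instance.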
- setup_gridsize(max_res=None)[source]
Compute optimal grid dimensions.
Delegates to FFT.compute_grid_size().
- Parameters:
max_res (float, optional) – Maximum resolution in Angstroms. If None, uses self.max_res.
- Returns:
Grid dimensions (nx, ny, nz) as int32 tensor.
- Return type:
torch.Tensor
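One common way to choose grid dimensions is shown below as a hedged sketch, not as what FFT.compute_grid_size() actually does. It samples at a spacing of max_res / 3 (the sampling factor of 3 is an assumption), then rounds each axis up to the next integer whose only prime factors are 2, 3 and 5, which keeps the FFT fast. Cell angles and any space-group divisibility constraints are ignored, and the helper names are hypothetical.

```python
import math
from typing import Tuple


def is_fft_friendly(n: int) -> bool:
    """True if n >= 1 and its only prime factors are 2, 3 and 5."""
    if n < 1:
        return False
    for p in (2, 3, 5):
        while n % p == 0:
            n //= p
    return n == 1


def next_fft_friendly(n: int) -> int:
    """Smallest integer >= n whose only prime factors are 2, 3 and 5."""
    n = max(n, 1)
    while not is_fft_friendly(n):
        n += 1
    return n


def compute_grid_size(
    cell_abc: Tuple[float, float, float], max_res: float, sampling: float = 3.0
) -> Tuple[int, ...]:
    """Grid points per axis for a target spacing of max_res / sampling."""
    spacing = max_res / sampling  # Angstroms between grid points
    return tuple(next_fft_friendly(math.ceil(length / spacing)) for length in cell_abc)


print(compute_grid_size((50.0, 60.0, 70.0), max_res=1.5))  # (100, 120, 144)
```

For a 50 × 60 × 70 Å cell at 1.5 Å, the c axis needs 140 samples, and 140 = 2²·5·7 is bumped to 144 = 2⁴·3².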
- property A: Tensor
ITC92 A parameters (amplitudes) for all atoms.
- Returns:
A parameters with shape (n_atoms, 5).
- Return type:
torch.Tensor
- property B: Tensor
ITC92 B parameters (widths) for all atoms.
- Returns:
B parameters with shape (n_atoms, 5).
- Return type:
torch.Tensor
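Assume the A and B columns are Gaussian amplitudes and widths in the usual form f₀(s) = Σₖ Aₖ exp(−Bₖ s²), with s = sin θ / λ = 1/(2d). Under that assumption, the Thomson scattering factor of one atom can be evaluated as below. The coefficients are the standard four-Gaussian Cromer–Mann values for neutral carbon from International Tables Vol. C. Storing the constant term c as a fifth Gaussian with width 0 is an assumed layout for the (n_atoms, 5) tensors, used here only for illustration. The helper names are hypothetical.

```python
import math
from typing import Sequence

# Neutral carbon: four Gaussians plus the constant c = 0.2156, stored here as
# a fifth term with width 0 (an assumed layout for the (n_atoms, 5) tensors).
CARBON_A = [2.3100, 1.0200, 1.5886, 0.8650, 0.2156]
CARBON_B = [20.8439, 10.2075, 0.5687, 51.6512, 0.0]


def s_from_d(d_spacing: float) -> float:
    """s = sin(theta) / lambda = 1 / (2 d), in inverse Angstroms."""
    return 1.0 / (2.0 * d_spacing)


def thomson_f0(A: Sequence[float], B: Sequence[float], s: float) -> float:
    """f0(s) = sum_k A_k * exp(-B_k * s**2), in electrons."""
    return sum(a * math.exp(-b * s * s) for a, b in zip(A, B))


print(round(thomson_f0(CARBON_A, CARBON_B, 0.0), 4))  # 5.9992, close to Z = 6
print(thomson_f0(CARBON_A, CARBON_B, s_from_d(2.0)))  # f0 at 2 Angstrom resolution
```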
- get_iso()[source]
Get isotropic atoms with their ITC92 parameters.
- Returns:
xyz (torch.Tensor) – Atomic coordinates with shape (n_atoms, 3).
adp (torch.Tensor) – Atomic displacement parameters (isotropic) with shape (n_atoms,).
occupancy (torch.Tensor) – Occupancies with shape (n_atoms,).
A (torch.Tensor) – ITC92 A parameters (amplitudes) with shape (n_atoms, 5).
B (torch.Tensor) – ITC92 B parameters (widths) with shape (n_atoms, 5).
- get_aniso()[source]
Get anisotropic atoms with their ITC92 parameters.
- Returns:
xyz (torch.Tensor) – Atomic coordinates with shape (n_atoms, 3).
u (torch.Tensor) – Anisotropic U parameters with shape (n_atoms, 6).
occupancy (torch.Tensor) – Occupancies with shape (n_atoms,).
A (torch.Tensor) – ITC92 A parameters (amplitudes) with shape (n_atoms, 5).
B (torch.Tensor) – ITC92 B parameters (widths) with shape (n_atoms, 5).
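get_aniso() returns six U components per atom. Assume they are Cartesian and ordered (U11, U22, U33, U12, U13, U23); that ordering is not stated here and should be checked against the source. Under that assumption, the equivalent isotropic displacement is U_eq = (U11 + U22 + U33)/3 and B_eq = 8π² U_eq. A minimal sketch with hypothetical helper names:

```python
import math
from typing import List, Sequence


def u6_to_matrix(u6: Sequence[float]) -> List[List[float]]:
    """Expand (U11, U22, U33, U12, U13, U23) into a symmetric 3x3 tensor.

    The component order is an assumption; confirm it against the source.
    """
    u11, u22, u33, u12, u13, u23 = u6
    return [[u11, u12, u13], [u12, u22, u23], [u13, u23, u33]]


def b_equivalent(u6: Sequence[float]) -> float:
    """B_eq = 8 * pi**2 * U_eq, with U_eq = trace(U) / 3 for Cartesian U."""
    u = u6_to_matrix(u6)
    u_eq = (u[0][0] + u[1][1] + u[2][2]) / 3.0
    return 8.0 * math.pi ** 2 * u_eq


print(round(b_equivalent([0.25, 0.25, 0.25, 0.0, 0.0, 0.0]), 2))  # 19.74
```

The off-diagonal terms do not enter B_eq. They matter only when the full anisotropic Gaussian is evaluated.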
- setup_grid(max_res=None, gridsize=None)[source]
Set up the real-space grid for electron density calculation.
Delegates to FFT.setup_grid() using the stored cell and spacegroup.
- get_radius(min_radius_Angstrom=4.0)[source]
Get the radius in voxels used for density calculation around each atom.
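The following is one plausible way to turn a radius in Å into a voxel count, sketched under stated assumptions. The voxel size per axis is taken as cell length divided by grid points, and the count is the ceiling of radius / voxel size on the finest axis, so the sphere is covered along every axis. The helper name is hypothetical, and the exact semantics of min_radius_Angstrom in torchref may differ.

```python
import math
from typing import Sequence


def radius_in_voxels(
    radius_angstrom: float, cell_abc: Sequence[float], gridsize: Sequence[int]
) -> int:
    """Voxel half-width that covers radius_angstrom along every grid axis."""
    voxel_sizes = [length / n for length, n in zip(cell_abc, gridsize)]
    # The finest axis needs the most voxels, so take the maximum count.
    return max(math.ceil(radius_angstrom / size) for size in voxel_sizes)


print(radius_in_voxels(4.0, (50.0, 60.0, 70.0), (100, 120, 144)))  # 9
```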
- build_complete_map(radius=None, apply_symmetry=True)[source]
Build electron density map from all atoms.
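Retrieves atom data via get_iso() and get_aniso() and constructs the complete electron density map.
- Parameters:
radius (int, optional) – Radius in voxels around each atom. If None, uses self.radius.
apply_symmetry (bool, optional) – If True, apply crystallographic symmetry to the map. Default is True.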
- Returns:
Electron density map with symmetry applied if requested.
- Return type:
torch.Tensor
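For an isotropic atom, each Gaussian term of the scattering factor Fourier-transforms to a real-space Gaussian. With the f₀ convention above and an atomic B-factor B, the standard closed form is ρ(r) = occ · Σₖ Aₖ (4π/(Bₖ + B))^{3/2} exp(−4π² r² / (Bₖ + B)). The sketch below evaluates that form and integrates it over a sphere. The integral shows how much density a finite cutoff retains, which is the role radius_angstrom plays. This is a direct per-atom evaluation swapped in for illustration, not torchref's FFT.build_density_map(). The helper names and the carbon coefficients (constant term stored with width 0) are illustrative assumptions.

```python
import math
from typing import Sequence

# Illustrative carbon coefficients (constant term stored with width 0).
CARBON_A = [2.3100, 1.0200, 1.5886, 0.8650, 0.2156]
CARBON_B = [20.8439, 10.2075, 0.5687, 51.6512, 0.0]


def iso_density(r: float, A: Sequence[float], B: Sequence[float],
                b_iso: float, occupancy: float = 1.0) -> float:
    """Electron density (e / A^3) at distance r (A) from an isotropic atom."""
    total = 0.0
    for a, b in zip(A, B):
        width = b + b_iso  # Gaussian width after B-factor blurring
        if width <= 0:
            raise ValueError("B_k + b_iso must be positive")
        total += a * (4.0 * math.pi / width) ** 1.5 * math.exp(
            -4.0 * math.pi ** 2 * r * r / width
        )
    return occupancy * total


def electrons_within(radius: float, A: Sequence[float], B: Sequence[float],
                     b_iso: float, occupancy: float = 1.0, steps: int = 10000) -> float:
    """Midpoint-rule integral of 4 pi r^2 rho(r) from 0 to radius."""
    h = radius / steps
    total = 0.0
    for i in range(steps):
        r = (i + 0.5) * h
        total += 4.0 * math.pi * r * r * iso_density(r, A, B, b_iso, occupancy)
    return total * h


full = electrons_within(10.0, CARBON_A, CARBON_B, b_iso=20.0)
cut = electrons_within(4.0, CARBON_A, CARBON_B, b_iso=20.0)
print(full, cut)  # full is close to sum(CARBON_A); cut is slightly smaller
```

Integrated over all space, the density recovers occ · Σₖ Aₖ electrons, which is f₀(0). A cutoff captures slightly less, and the shortfall grows as B increases.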
- build_initial_map(apply_symmetry=True)[source]
Build electron density map from atomic parameters.
Delegates to FFT.build_density_map() using the model’s stored parameters.
- Parameters:
apply_symmetry (bool, optional) – If True, apply crystallographic symmetry to the map. Default is True.
- Returns:
Electron density map with shape (nx, ny, nz).
- Return type:
torch.Tensor
- save_map(filename)[source]
Save the electron density map to a CCP4 format file.
- Parameters:
filename (str) – Output filename for the map.
- Raises:
ValueError – If no map has been computed yet.
- rebuild_map(radius=None)[source]
Rebuild the density map from scratch.
Convenience method that clears the previously computed map and rebuilds it from the current atomic parameters.
- Parameters:
radius (int, optional) – Radius in voxels around each atom. If None, uses self.radius. If specified, overrides self.radius.
- Returns:
Rebuilt electron density map.
- Return type:
torch.Tensor
- get_structure_factor(hkl, recalc=False, apply_anomalous=True)[source]
Get structure factors for given hkl reflections.
Uses CachedForwardMixin to cache the result and auto-invalidate it when parameters change or a backward pass propagates through.
- Parameters:
hkl (torch.Tensor) – Miller indices with shape (n_reflections, 3).
recalc (bool, optional) – If True, forces recalculation bypassing the cache. Default is False.
apply_anomalous (bool, optional) – If True and wavelength is set, apply anomalous scattering corrections (f’ and f’’) for heavy atoms. Default is True.
- Returns:
Complex structure factors with shape (n_reflections,).
- Return type:
torch.Tensor
Notes
- The complete scattering factor is:
f(s, λ) = f₀(s) + f’(λ) + i·f’’(λ)
where f₀ is the normal (Thomson) scattering factor computed via FFT, and f’/f’’ are the wavelength-dependent anomalous corrections.
Anomalous corrections are only computed for atoms where |f'| > anomalous_threshold or |f''| > anomalous_threshold.
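The scattering-factor formula in these Notes can be checked with a direct-summation sketch: F(h) = Σⱼ occⱼ [f₀ⱼ(s) + f′ⱼ + i f″ⱼ] exp(−Bⱼ s²) exp(2πi h·xⱼ). As described above, f′ and f″ are applied only when one of them exceeds anomalous_threshold. This sketch swaps in direct summation for the FFT route ModelFT uses. It also assumes an orthogonal cell (α = β = γ = 90°), so that s² = (h²/a² + k²/b² + l²/c²)/4. All names are hypothetical.

```python
import cmath
import math
from dataclasses import dataclass
from typing import Sequence, Tuple


@dataclass
class Atom:
    frac_xyz: Tuple[float, float, float]  # fractional coordinates
    b_iso: float                          # isotropic B-factor, A^2
    occupancy: float
    A: Sequence[float]                    # Gaussian amplitudes of f0
    B: Sequence[float]                    # Gaussian widths of f0
    f_prime: float = 0.0                  # f'(lambda), electrons
    f_double_prime: float = 0.0           # f''(lambda), electrons


def s_squared(hkl: Tuple[int, int, int], cell_abc: Tuple[float, float, float]) -> float:
    """(sin(theta)/lambda)^2 for an orthogonal cell only."""
    return sum((index / length) ** 2 for index, length in zip(hkl, cell_abc)) / 4.0


def structure_factor(hkl, atoms, cell_abc, apply_anomalous=True, anomalous_threshold=0.5):
    s2 = s_squared(hkl, cell_abc)
    total = 0j
    for atom in atoms:
        f = sum(a * math.exp(-b * s2) for a, b in zip(atom.A, atom.B))  # f0(s)
        significant = (abs(atom.f_prime) > anomalous_threshold
                       or abs(atom.f_double_prime) > anomalous_threshold)
        if apply_anomalous and significant:
            f = f + atom.f_prime + 1j * atom.f_double_prime  # f0 + f' + i f''
        phase = 2.0 * math.pi * sum(h * x for h, x in zip(hkl, atom.frac_xyz))
        total += atom.occupancy * f * math.exp(-atom.b_iso * s2) * cmath.exp(1j * phase)
    return total


light = Atom((0.0, 0.0, 0.0), 0.0, 1.0, A=[6.0], B=[0.0])
heavy = Atom((0.25, 0.0, 0.0), 0.0, 1.0, A=[10.0], B=[0.0], f_prime=-2.0, f_double_prime=4.0)
cell = (10.0, 10.0, 10.0)
for hkl in [(1, 0, 0), (-1, 0, 0)]:
    print(hkl, abs(structure_factor(hkl, [light, heavy], cell)))
```

Because the heavy atom carries f″, |F(1,0,0)| and |F(−1,0,0)| differ here (a Bijvoet difference). With apply_anomalous=False they coincide, as Friedel's law requires.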
- property fft
Access the SfFFT submodule.
- forward(hkl, apply_anomalous=True)[source]
Compute structure factors for given hkl.
This is called by the mixin’s __call__, which handles caching, backward-hook registration, and auto-invalidation.
- Parameters:
hkl (torch.Tensor) – Miller indices with shape (n_reflections, 3).
apply_anomalous (bool, optional) – If True and wavelength is set, apply anomalous scattering corrections (f’ and f’’) for heavy atoms. Default is True.
- Returns:
Calculated complex structure factors with shape (n_reflections,).
- Return type:
torch.Tensor
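The caching contract described for get_structure_factor() and forward() can be sketched in plain Python. This is a hypothetical stand-in, not CachedForwardMixin itself. It keys results on the call arguments plus a parameter version counter, and it recomputes when recalc=True or after mark_parameters_changed() has been called. The real mixin's backward-hook registration is not modeled.

```python
from typing import Any, Dict, Tuple


class CachedForwardSketch:
    """Hypothetical stand-in for CachedForwardMixin (not torchref code).

    Results are cached per call arguments together with a parameter version;
    bumping the version invalidates every cached entry.
    """

    def __init__(self) -> None:
        self._param_version = 0
        self._cache: Dict[Any, Tuple[int, Any]] = {}

    def mark_parameters_changed(self) -> None:
        # A real model would trigger this from parameter updates or a backward
        # hook; in this sketch it must be called explicitly.
        self._param_version += 1

    def __call__(self, *args: Any, recalc: bool = False, **kwargs: Any) -> Any:
        key = (args, tuple(sorted(kwargs.items())))  # arguments must be hashable
        cached = self._cache.get(key)
        if not recalc and cached is not None and cached[0] == self._param_version:
            return cached[1]
        result = self.forward(*args, **kwargs)
        self._cache[key] = (self._param_version, result)
        return result

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError


class ScaledSquares(CachedForwardSketch):
    """Toy model: returns scale * (h^2 + k^2 + l^2) for each reflection."""

    def __init__(self, scale: float) -> None:
        super().__init__()
        self.scale = scale
        self.forward_calls = 0

    def forward(self, hkl):
        self.forward_calls += 1
        return tuple(self.scale * (h * h + k * k + l * l) for h, k, l in hkl)


model = ScaledSquares(2.0)
hkl = ((1, 0, 0), (1, 1, 1))
model(hkl)
model(hkl)  # served from the cache
model.scale = 3.0
model.mark_parameters_changed()  # invalidates the cached result
print(model(hkl), model.forward_calls)  # (3.0, 9.0) 2
```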
- copy(detach=True)[source]
Create a deep copy of the ModelFT.
Creates a complete independent copy including all Model base class data, FFT submodule state (gridsize, real_space_grid, voxel_size, map_symmetry), ITC92 parametrization, and scalar attributes. Cache is reset to empty.
- Parameters:
detach (bool, optional) – If True, the copy’s parameters will be detached from the computation graph (default: True).
- Returns:
A new ModelFT instance with copied data.
- Return type:
ModelFT
Examples
model = ModelFT().load_pdb('structure.pdb')
model_copy = model.copy()
# model_copy is independent; changes won't affect model
- state_dict(destination=None, prefix='', keep_vars=False)[source]
Return a dictionary containing the complete state of the ModelFT.
Extends parent Model.state_dict() with FT-specific parameters such as max_res and radius_angstrom. Grid state is handled by the FFT submodule.
- classmethod create_from_state_dict(state_dict, device=device(type='cpu'), verbose=1, dtype_float=torch.float32)[source]
Create a fully initialized ModelFT from a state dictionary.
This is the recommended way to restore a ModelFT from a saved state. Creates an instance with properly initialized submodules, then loads the state.
- Parameters:
state_dict (dict) – State dictionary as returned by model.state_dict(), e.g. saved with torch.save(model.state_dict(), …) and reloaded with torch.load().
device (torch.device, optional) – Device to place tensors on. Defaults to the configured device.current.
verbose (int, optional) – Verbosity level. Default is 1.
dtype_float (torch.dtype, optional) – Float dtype for tensors. Default is dtypes.float.
- Returns:
Fully initialized instance with restored state.
- Return type:
ModelFT