torchref.scaling.solvent module
A class for modelling solvent contribution to structure factors.
- class torchref.scaling.solvent.SolventModel(model=None, radius=1.1, k_solvent=1.1, b_solvent=50.0, erosion_radius=0.9, transition=None, optimize_phase=True, initial_phase_offset=0.0, verbose=1, float_type=torch.float32, device=device(type='cpu'))
Bases: DeviceMixin, DebugMixin, Module
Solvent model that computes the bulk-solvent contribution to structure factors using a Phenix-like approach.
Supports two initialization patterns:
Empty initialization (for state_dict loading):
    solvent = SolventModel()  # Creates an empty shell
    solvent.load_state_dict(torch.load('solvent.pt'))
Full initialization with model:
    solvent = SolventModel(model, k_solvent=0.35, b_solvent=46.0)
- device
Device for tensor operations.
- Type:
torch.device
- float_type
Floating point data type.
- Type:
torch.dtype
- log_k_solvent
Log of solvent scattering scale factor.
- Type:
torch.nn.Parameter
- b_solvent
Solvent B-factor.
- Type:
torch.nn.Parameter
- __init__(model=None, radius=1.1, k_solvent=1.1, b_solvent=50.0, erosion_radius=0.9, transition=None, optimize_phase=True, initial_phase_offset=0.0, verbose=1, float_type=torch.float32, device=device(type='cpu'))
Initialize SolventModel.
If model is provided, fully initializes the solvent model. If not provided (empty init), creates a shell ready for load_state_dict().
- Parameters:
model (ModelFT, optional) – The atomic model used for structure factor calculations (optional for empty init).
radius (float, default 1.1) – Probe radius in Angstroms for dilation (water radius).
k_solvent (float, default 1.1) – Solvent scattering scale factor.
b_solvent (float, default 50.0) – Solvent B-factor.
erosion_radius (float, default 0.9) – Radius in Angstroms for erosion step.
transition (float, optional) – Gaussian smoothing sigma for mask edges (default: radius/4 in voxels). Avoids ringing artifacts.
optimize_phase (bool, default True) – Whether to optimize phase offset parameter.
initial_phase_offset (float, default 0.0) – Initial phase offset in radians.
verbose (int, default 1) – Verbosity level.
float_type (torch.dtype, default torch.float32) – Floating point data type.
device (torch.device, default: configured device.current) – Device for tensor operations.
- get_solvent_mask()
Generate solvent mask following Phenix’s three-step process.
- Step 1 (dilation): classify voxels around each atom as protein (inside the VdW radius), boundary (between the VdW radius and VdW radius + solvent radius), or bulk solvent (further out). The classification is built in chunks over atoms, so peak memory is O(atom_chunk_size × N_box_voxels) rather than O(N_atoms × N_box_voxels). This matters because, for typical macromolecule and grid combinations, the dense form would require several gigabytes.
- Step 2 (symmetry expansion): transform the sparse ASU protein and boundary voxel indices through each symmetry operator and scatter them into the P1 grid masks.
- Step 3 (erosion): a boundary voxel becomes solvent if any voxel within erosion_radius of it is bulk solvent. This is implemented as a single F.conv3d with a precomputed spherical kernel and circular padding. It replaces an earlier Python loop with per-voxel neighbourhood expansion, which itself ran out of memory on chunks of 10^6 boundary voxels.
- Returns:
Solvent mask (boolean) where True = solvent.
- Return type:
torch.Tensor
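A minimal sketch of Steps 1 and 3 in plain PyTorch, assuming a P1 cell (Step 2, symmetry expansion, is skipped), an orthogonal unit cell, isotropic voxel spacing, and distances computed against the full grid rather than a per-atom box. The function names classify_voxels and erode_boundary and the label encoding (0 = bulk, 1 = boundary, 2 = protein) are hypothetical; this illustrates the technique and is not the torchref implementation.

```python
import math

import torch
import torch.nn.functional as F


def classify_voxels(frac_xyz, vdw_radii, grid_shape, cell_lengths,
                    solvent_radius=1.1, atom_chunk_size=256):
    """Step 1 sketch (hypothetical helper): label voxels 0 = bulk, 1 = boundary, 2 = protein.

    Assumes an orthogonal cell (edge lengths in Å) and fractional atom
    coordinates of shape (N, 3); distances use the minimum-image convention.
    """
    axes = [torch.arange(n, dtype=frac_xyz.dtype) / n for n in grid_shape]
    grid = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1).reshape(-1, 3)
    cell = torch.as_tensor(cell_lengths, dtype=frac_xyz.dtype)
    protein = torch.zeros(grid.shape[0], dtype=torch.bool)
    near = torch.zeros(grid.shape[0], dtype=torch.bool)
    # Chunking over atoms bounds peak memory at O(atom_chunk_size * n_voxels).
    for start in range(0, frac_xyz.shape[0], atom_chunk_size):
        xyz = frac_xyz[start:start + atom_chunk_size]            # (C, 3)
        r = vdw_radii[start:start + atom_chunk_size, None]        # (C, 1)
        diff = grid[None] - xyz[:, None]                           # (C, V, 3)
        diff = diff - torch.round(diff)                            # wrap into [-0.5, 0.5]
        dist = torch.linalg.norm(diff * cell, dim=-1)              # (C, V) in Å
        protein |= (dist < r).any(dim=0)
        near |= (dist < r + solvent_radius).any(dim=0)
    # protein is a subset of near, so the sum yields 0, 1 or 2.
    labels = near.to(torch.uint8) + protein.to(torch.uint8)
    return labels.reshape(*grid_shape)


def erode_boundary(labels, erosion_radius_vox):
    """Step 3 sketch (hypothetical helper): boundary voxels with bulk solvent in range become solvent."""
    r = int(math.ceil(erosion_radius_vox))
    offs = torch.arange(-r, r + 1, dtype=torch.float32)
    d0, d1, d2 = torch.meshgrid(offs, offs, offs, indexing="ij")
    # Spherical kernel: 1 for offsets within the erosion radius, else 0.
    kernel = ((d0**2 + d1**2 + d2**2) <= erosion_radius_vox**2).float()
    bulk = (labels == 0).float()[None, None]                       # (1, 1, X, Y, Z)
    # Circular padding treats the unit cell as periodic.
    padded = F.pad(bulk, (r, r, r, r, r, r), mode="circular")
    near_bulk = F.conv3d(padded, kernel[None, None])[0, 0] > 0.5
    return (labels == 0) | ((labels == 1) & near_bulk)
```

A single convolution counts bulk voxels inside the sphere around every voxel at once, which is why it scales better than expanding neighbourhoods voxel by voxel.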
- get_rec_solvent(hkl)
Compute solvent structure factors.
Uses the standard crystallographic approach: compute structure factors from the solvent mask, which marks the regions where bulk-solvent scattering occurs.
- Parameters:
hkl (torch.Tensor) – Miller indices.
- Returns:
Complex solvent structure factors.
- Return type:
torch.Tensor
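A minimal sketch of computing structure factors from a P1 mask with an FFT, assuming the crystallographic sign convention F(h) = (V/N) Σ mask(x) exp(+2πi h·x) over fractional voxel positions x, and a grid fine enough that every requested index lies within ±N/2 along each axis. The helper name mask_structure_factors and its cell_volume argument are hypothetical, not part of the torchref API.

```python
import torch


def mask_structure_factors(mask, hkl, cell_volume=1.0):
    """Hypothetical helper: structure factors of a P1 mask at integer Miller indices.

    F(h) = (V / N) * sum_x mask(x) * exp(+2*pi*i * h . x), with x = n / grid_shape.
    torch.fft.ifftn applies the +2*pi*i exponent and the 1/N factor, so
    F(h) = V * ifftn(mask)[h mod grid_shape].
    """
    grid = torch.fft.ifftn(mask.to(torch.float64))             # complex128 grid
    shape = torch.tensor(mask.shape, device=hkl.device)
    # Negative indices wrap to the far side of the grid; indices beyond
    # +/- N/2 would alias, so the grid must be fine enough for the data.
    idx = torch.remainder(hkl, shape)
    return cell_volume * grid[idx[:, 0], idx[:, 1], idx[:, 2]]
```

For a real-valued mask the result obeys Friedel's law, F(-h) = conj(F(h)), which is a convenient sanity check on the sign convention.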
- forward(hkl, update_fsol=False, F_protein=None)
Compute solvent contribution to structure factors at given HKL.
This method is differentiable with respect to k_solvent, b_solvent, and phase_offset parameters.
The solvent model:
1. Takes the binary solvent mask.
2. Smooths it with a Gaussian filter (σ = 1.5 voxels) to create soft edges.
3. Computes structure factors via FFT.
4. Applies B-factor damping, exp(-B·s²), where s = sin(θ)/λ.
5. If optimize_phase=True and F_protein is provided, blends the mask phases with the protein phases. phase_offset controls the blend: 0 uses the mask phases, ±π uses the protein phases.
6. Scales the result by k_solvent.
- Parameters:
hkl (torch.Tensor) – Miller indices, shape (N, 3).
update_fsol (bool, default False) – Whether to update solvent structure factors.
F_protein (torch.Tensor, optional) – Protein structure factors, used for phase blending.
- Returns:
Complex solvent structure factors, shape (N,).
- Return type:
torch.Tensor
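The B-factor damping and k_solvent scaling steps can be sketched as a small differentiable function. This assumes k_solvent is stored in log form (as the log_k_solvent attribute suggests) and that a hypothetical recip_basis tensor holds the reciprocal lattice vectors a*, b*, c* as rows in 1/Å, so that |h·recip_basis| = 1/d and s = sin(θ)/λ = 1/(2d). Gaussian mask smoothing and phase blending are deliberately left out; the function name scale_bulk_solvent is hypothetical.

```python
import torch


def scale_bulk_solvent(f_mask, hkl, recip_basis, log_k_solvent, b_solvent):
    """Hypothetical helper: k_solvent * exp(-B * s^2) * F_mask, with s = sin(theta)/lambda.

    recip_basis: (3, 3) tensor whose rows are a*, b*, c* in 1/Å (assumption).
    """
    d_star = hkl.to(recip_basis.dtype) @ recip_basis      # (N, 3), |d_star| = 1/d
    s_sq = (d_star ** 2).sum(dim=-1) / 4.0                 # (sin(theta)/lambda)^2 = 1/(4 d^2)
    k_solvent = torch.exp(log_k_solvent)                   # log form keeps k_solvent > 0
    # Real scale factors broadcast against the complex mask structure factors;
    # gradients flow back to both log_k_solvent and b_solvent.
    return k_solvent * torch.exp(-b_solvent * s_sq) * f_mask
```

With s defined as sin(θ)/λ the exponent is exp(-B·s²); codes that instead use s = 1/d write the same factor as exp(-B·s²/4).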
- parameters()
Return an iterator over module parameters.
This is typically passed to an optimizer.
- Args:
- recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
- Yields:
Parameter: module parameter
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> for param in model.parameters():
>>>     print(type(param), param.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
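Because log_k_solvent and b_solvent are registered as torch.nn.Parameter, a fitting loop can hand parameters() to any torch.optim optimizer. The sketch below uses a hypothetical stand-in module, ToySolventScale, so that it runs without an atomic model; the amplitude loss, the fit helper, and the Adam settings are illustrative assumptions, not torchref's refinement target.

```python
import torch
from torch import nn


class ToySolventScale(nn.Module):
    """Hypothetical stand-in exposing the two refinable parameters documented above."""

    def __init__(self, k_solvent=0.35, b_solvent=46.0):
        super().__init__()
        # Store k in log form so the optimizer can never drive it negative.
        self.log_k_solvent = nn.Parameter(torch.tensor(k_solvent).log())
        self.b_solvent = nn.Parameter(torch.tensor(b_solvent))

    def forward(self, f_mask, s_sq):
        return self.log_k_solvent.exp() * torch.exp(-self.b_solvent * s_sq) * f_mask


def fit(model, f_mask, s_sq, f_target, steps=500, lr=0.05):
    """Hypothetical helper: least-squares fit of |F_solvent| to target amplitudes."""
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        loss = ((model(f_mask, s_sq).abs() - f_target) ** 2).mean()
        loss.backward()
        opt.step()
    return loss.item()
```

In practice the solvent term would be fitted jointly with the protein structure factors against observed data; the toy target here only demonstrates the parameters() to optimizer hand-off.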