torchref.scaling.solvent module

A class for modelling the bulk-solvent contribution to structure factors.

class torchref.scaling.solvent.SolventModel(model=None, radius=1.1, k_solvent=1.1, b_solvent=50.0, erosion_radius=0.9, transition=None, optimize_phase=True, initial_phase_offset=0.0, verbose=1, float_type=torch.float32, device=device(type='cpu'))[source]

Bases: DeviceMixin, DebugMixin, Module

Computes the solvent contribution to structure factors using a Phenix-like approach.

Supports two initialization patterns:

  1. Empty initialization (for state_dict loading):

    solvent = SolventModel()  # Creates empty shell
    solvent.load_state_dict(torch.load('solvent.pt'))
    
  2. Full initialization with model:

    solvent = SolventModel(model, k_solvent=0.35, b_solvent=46.0)
    
model

The atomic model for structure factor calculations.

Type:

ModelFT or None

device

Device for tensor operations.

Type:

torch.device

verbose

Verbosity level.

Type:

int

float_type

Floating point data type.

Type:

torch.dtype

solvent_radius

Probe radius in Angstroms for dilation.

Type:

float

erosion_radius

Radius in Angstroms for erosion step.

Type:

float

optimize_phase

Whether to optimize phase offset parameter.

Type:

bool

log_k_solvent

Log of solvent scattering scale factor.

Type:

torch.nn.Parameter

b_solvent

Solvent B-factor.

Type:

torch.nn.Parameter

phase_offset

Phase offset in radians.

Type:

torch.nn.Parameter or buffer
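The attribute list stores the scale factor as log_k_solvent rather than k_solvent. A minimal sketch of why a log parameterization can be useful, assuming the scale is recovered as exp(log_k): the recovered value stays positive however large an optimizer step is. The class LogScale is hypothetical and is not part of torchref.

```python
import math
import torch
from torch import nn


class LogScale(nn.Module):
    """Hypothetical stand-in: stores log(k) so that k = exp(log_k) stays positive."""

    def __init__(self, k_init: float = 0.35):
        super().__init__()
        self.log_k = nn.Parameter(torch.tensor(math.log(k_init)))

    @property
    def k(self) -> torch.Tensor:
        return torch.exp(self.log_k)


scale = LogScale(0.35)
opt = torch.optim.SGD(scale.parameters(), lr=10.0)

# A loss that pushes k downward. With this learning rate, the same step
# applied to a raw k parameter (gradient 1) would land at 0.35 - 10 < 0.
loss = scale.k * 1.0
loss.backward()
opt.step()

k_after = scale.k.item()  # smaller than before, but still positive
```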

__init__(model=None, radius=1.1, k_solvent=1.1, b_solvent=50.0, erosion_radius=0.9, transition=None, optimize_phase=True, initial_phase_offset=0.0, verbose=1, float_type=torch.float32, device=device(type='cpu'))[source]

Initialize SolventModel.

If model is provided, fully initializes the solvent model. If not provided (empty init), creates a shell ready for load_state_dict().

Parameters:
  • model (ModelFT, optional) – The atomic model used for structure factor calculations. Omit it for empty initialization.

  • radius (float, default 1.1) – Probe radius in Angstroms for dilation (water radius).

  • k_solvent (float, default 1.1) – Solvent scattering scale factor.

  • b_solvent (float, default 50.0) – Solvent B-factor.

  • erosion_radius (float, default 0.9) – Radius in Angstroms for erosion step.

  • transition (float, optional) – Sigma of the Gaussian that smooths the mask edges to avoid ringing artifacts (default: radius/4 in voxels).

  • optimize_phase (bool, default True) – Whether to optimize phase offset parameter.

  • initial_phase_offset (float, default 0.0) – Initial phase offset in radians.

  • verbose (int, default 1) – Verbosity level.

  • float_type (torch.dtype, default torch.float32) – Floating point data type.

  • device (torch.device, default: configured device.current) – Device for tensor operations.
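The transition parameter is described as a Gaussian sigma in voxels. As a rough illustration of what such smoothing does to a binary mask, the sketch below applies a periodic Gaussian filter by multiplying in Fourier space. Whether torchref smooths in real or reciprocal space is not stated here; smooth_mask_periodic is a hypothetical helper.

```python
import math
import torch


def smooth_mask_periodic(mask: torch.Tensor, sigma_vox: float) -> torch.Tensor:
    """Gaussian-smooth a 3D mask with periodic boundaries (hypothetical helper)."""
    m = mask.to(torch.float32)
    # Frequencies in cycles per voxel along each grid axis.
    fx, fy, fz = (torch.fft.fftfreq(n) for n in m.shape)
    f2 = fx[:, None, None] ** 2 + fy[None, :, None] ** 2 + fz[None, None, :] ** 2
    # Fourier transform of a unit-area Gaussian with standard deviation sigma_vox.
    g_hat = torch.exp(-2.0 * math.pi ** 2 * sigma_vox ** 2 * f2)
    return torch.fft.ifftn(torch.fft.fftn(m) * g_hat).real


mask = torch.zeros(16, 16, 16, dtype=torch.bool)
mask[4:12, 4:12, 4:12] = True            # a cubic block of "solvent"
soft = smooth_mask_periodic(mask, sigma_vox=1.5)
```

Because the zero-frequency term is left untouched, the smoothed mask keeps the same total solvent content as the binary one; only the edges are softened.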

get_solvent_mask()[source]

Generate solvent mask following Phenix’s three-step process.

Step 1 (dilation): classify voxels around each atom as protein (inside VdW), boundary (between VdW and VdW + solvent_radius), or bulk solvent (further out). The classification is built in chunks over atoms so that peak memory is O(atom_chunk_size × N_box_voxels) rather than O(N_atoms × N_box_voxels). This matters because the dense form reaches multiple GB for typical macromolecule and grid combinations.

Step 2 (symmetry expansion): transform the sparse ASU protein and boundary voxel indices through each symop and scatter them into the P1 grid masks.

Step 3 (erosion): a boundary voxel becomes solvent if any voxel within erosion_radius of it is bulk solvent. This is implemented as a single F.conv3d with a precomputed spherical kernel and circular padding. It replaces the previous Python loop with per-voxel neighbourhood expansion, which ran out of memory on chunks of 10^6 boundary voxels.
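The chunked classification of Step 1 can be sketched as follows. This is a simplified illustration, not the library's implementation: it assumes an orthogonal, non-periodic grid, measures every voxel against every atom in the chunk instead of using a local box, and uses hypothetical names (classify_voxels and the three label codes).

```python
import torch

PROTEIN, BOUNDARY, BULK = 0, 1, 2  # hypothetical label codes


def classify_voxels(grid_xyz, atom_xyz, vdw, solvent_radius=1.1, atom_chunk_size=64):
    """Label each voxel by its distance to the nearest atom surface.

    grid_xyz: (V, 3) voxel coordinates, atom_xyz: (A, 3) atom centres,
    vdw: (A,) VdW radii, all in Angstroms.
    """
    # Running minimum of (distance to atom centre - VdW radius) over all atoms.
    min_gap = torch.full(
        (grid_xyz.shape[0],), float("inf"), dtype=grid_xyz.dtype, device=grid_xyz.device
    )
    for start in range(0, atom_xyz.shape[0], atom_chunk_size):
        chunk = slice(start, start + atom_chunk_size)
        # Only a (V, chunk) distance matrix exists at any one time.
        dist = torch.cdist(grid_xyz, atom_xyz[chunk])
        gap = (dist - vdw[chunk]).min(dim=1).values
        min_gap = torch.minimum(min_gap, gap)
    labels = torch.full_like(min_gap, BULK, dtype=torch.long)
    labels[min_gap <= solvent_radius] = BOUNDARY
    labels[min_gap <= 0.0] = PROTEIN
    return labels


# One atom with a 1.7 A VdW radius at the centre of a 10 x 10 x 10 grid (1 A spacing).
axis = torch.arange(0.0, 10.0, 1.0)
grid = torch.stack(torch.meshgrid(axis, axis, axis, indexing="ij"), dim=-1).reshape(-1, 3)
atoms = torch.tensor([[5.0, 5.0, 5.0]])
labels = classify_voxels(grid, atoms, vdw=torch.tensor([1.7]))
```

The running minimum is what keeps memory bounded by the chunk size: each chunk contributes its nearest-surface distance and is then discarded.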

Returns:

Solvent mask (boolean) where True = solvent.

Return type:

torch.Tensor
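The erosion step of get_solvent_mask (Step 3) is described as one F.conv3d call with a spherical kernel and circular padding. A minimal sketch of that idea follows. It assumes isotropic voxels and a radius already converted from Angstroms to voxels; the helper names spherical_kernel and erode_boundary are hypothetical.

```python
import torch
import torch.nn.functional as F


def spherical_kernel(radius_vox: float) -> torch.Tensor:
    """Binary ball of the given radius in voxels, shaped (1, 1, D, D, D) for conv3d."""
    r = int(radius_vox)  # largest integer offset that can lie inside the ball
    ax = torch.arange(-r, r + 1, dtype=torch.float32)
    z, y, x = torch.meshgrid(ax, ax, ax, indexing="ij")
    ball = (x ** 2 + y ** 2 + z ** 2) <= radius_vox ** 2
    return ball.to(torch.float32)[None, None]


def erode_boundary(boundary, bulk, radius_vox):
    """Return the solvent mask: bulk solvent plus boundary voxels close to it."""
    kernel = spherical_kernel(radius_vox)
    r = kernel.shape[-1] // 2
    x = bulk.to(torch.float32)[None, None]
    # Circular padding makes the convolution wrap around the unit cell.
    x = F.pad(x, (r, r, r, r, r, r), mode="circular")
    # Counts the bulk-solvent voxels inside the ball centred on every voxel.
    near_bulk = F.conv3d(x, kernel)[0, 0] > 0.5
    return bulk | (boundary & near_bulk)


bulk = torch.zeros(8, 8, 8, dtype=torch.bool)
bulk[0] = True                 # one slab of bulk solvent
boundary = torch.zeros_like(bulk)
boundary[1:3] = True           # two boundary slabs next to it
boundary[7] = True             # touches the bulk slab only through periodicity
solvent = erode_boundary(boundary, bulk, radius_vox=1.0)
```

In this toy grid, slabs 1 and 7 become solvent because they neighbour the bulk slab (slab 7 through the periodic wrap), while slab 2 stays non-solvent because no bulk voxel lies within one voxel of it.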

update_solvent()[source]
smooth_solvent_mask()[source]
get_rec_solvent(hkl)[source]

Compute solvent structure factors.

Uses the standard crystallographic approach: compute SFs from the solvent mask. The mask represents regions where bulk solvent scattering occurs.

Parameters:

hkl (torch.Tensor) – Miller indices.

Returns:

Complex solvent structure factors.

Return type:

torch.Tensor
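As a rough picture of "compute SFs from the solvent mask", the sketch below takes a discrete Fourier transform of a P1 grid mask and reads it out at the requested Miller indices. The sign convention, the volume normalization, and the name mask_structure_factors are assumptions made for illustration. The library may use a different convention.

```python
import torch


def mask_structure_factors(mask, hkl, cell_volume):
    """Sample the discrete Fourier transform of a P1 grid mask at Miller indices.

    mask: (nx, ny, nz) tensor, hkl: (N, 3) integer tensor,
    cell_volume: unit-cell volume in cubic Angstroms.
    """
    nx, ny, nz = mask.shape
    # ifftn uses the exp(+2*pi*i*h.x) sign (assumed convention). Its 1/N factor
    # times the cell volume turns the sum over voxels into an integral over the cell.
    f_grid = torch.fft.ifftn(mask.to(torch.float32)) * cell_volume
    # Negative indices wrap around to the end of each grid axis.
    h, k, l = hkl[:, 0] % nx, hkl[:, 1] % ny, hkl[:, 2] % nz
    return f_grid[h, k, l]


mask = torch.zeros(8, 8, 8)
mask[:4] = 1.0                 # half of the cell is solvent
hkl = torch.tensor([[0, 0, 0], [1, 0, 0], [-1, 0, 0], [0, 1, 0]])
f_mask = mask_structure_factors(mask, hkl, cell_volume=1000.0)
```

With this normalization F(000) equals the solvent volume (500 cubic Angstroms here), and the Friedel pair (1, 0, 0) and (-1, 0, 0) are complex conjugates because the mask is real.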

forward(hkl, update_fsol=False, F_protein=None)[source]

Compute solvent contribution to structure factors at given HKL.

This method is differentiable with respect to the k_solvent (stored as log_k_solvent), b_solvent, and phase_offset parameters.

The solvent model:

  1. Takes the binary solvent mask

  2. Smooths it with Gaussian filter (σ=1.5 voxels) to create soft edges

  3. Computes structure factors via FFT

  4. Applies B-factor damping: exp(-B * s²) where s = sin(θ)/λ

  5. If optimize_phase=True and F_protein is provided: blends mask phases with protein phases. phase_offset controls the blend: 0 = use mask phases, ±π = use protein phases

  6. Scales by k_solvent

Parameters:
  • hkl (torch.Tensor) – Miller indices, shape (N, 3).

  • update_fsol (bool, default False) – Whether to update solvent structure factors.

  • F_protein (torch.Tensor, optional) – Protein structure factors, used for phase blending.

Returns:

Complex solvent structure factors, shape (N,).

Return type:

torch.Tensor
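Steps 4 and 6 of the solvent model in forward (B-factor damping and scaling) can be sketched in a few lines. The example assumes an orthogonal cell so that s² = 1/(4d²) has a simple closed form, and assumes the scale is recovered as exp(log_k_solvent). Both helper functions are hypothetical, and the phase blending of step 5 is left out because its formula is not given here.

```python
import math
import torch


def sin_theta_over_lambda_sq(hkl, cell_lengths):
    """s^2 = (sin(theta)/lambda)^2 = 1/(4 d^2) for an orthogonal cell (assumption)."""
    inv_d_sq = ((hkl.to(cell_lengths.dtype) / cell_lengths) ** 2).sum(dim=1)
    return inv_d_sq / 4.0


def scale_solvent(f_mask, s_sq, log_k_solvent, b_solvent):
    """Apply the B-factor damping exp(-B * s^2) and the scale k = exp(log_k)."""
    return torch.exp(log_k_solvent) * torch.exp(-b_solvent * s_sq) * f_mask


hkl = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 2, 0]])
cell = torch.tensor([10.0, 20.0, 30.0])
f_mask = torch.ones(3, dtype=torch.complex64)   # placeholder mask structure factors
log_k = torch.tensor(math.log(0.35), requires_grad=True)
b_sol = torch.tensor(46.0, requires_grad=True)

f_sol = scale_solvent(f_mask, sin_theta_over_lambda_sq(hkl, cell), log_k, b_sol)
f_sol.abs().sum().backward()                    # gradients reach both parameters
```

Since every operation is a differentiable tensor operation, autograd propagates through the amplitudes to both the scale and the B-factor. A larger B can only reduce the amplitudes, so its gradient is negative here.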

parameters()[source]

Return an iterator over module parameters.

This is typically passed to an optimizer.

Parameters:

recurse (bool) – If True, yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

Yields:

Parameter – module parameter

Example:

>>> for param in model.parameters():
...     print(type(param), param.size())
<class 'torch.nn.parameter.Parameter'> torch.Size([20])
<class 'torch.nn.parameter.Parameter'> torch.Size([20, 1, 5, 5])
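To illustrate "typically passed to an optimizer", the sketch below fits two scalars with torch.optim.Adam. TinyScale is a hypothetical stand-in module, not SolventModel, and the target curve is synthetic.

```python
import torch
from torch import nn


class TinyScale(nn.Module):
    """Hypothetical stand-in for a module with two trainable scalars."""

    def __init__(self):
        super().__init__()
        self.log_k = nn.Parameter(torch.tensor(0.0))
        self.b = nn.Parameter(torch.tensor(50.0))

    def forward(self, s_sq):
        return torch.exp(self.log_k) * torch.exp(-self.b * s_sq)


module = TinyScale()
optimizer = torch.optim.Adam(module.parameters(), lr=0.05)

s_sq = torch.linspace(0.0, 0.01, 20)
target = 0.35 * torch.exp(-46.0 * s_sq)     # synthetic curve to fit

losses = []
for _ in range(200):
    optimizer.zero_grad()
    loss = ((module(s_sq) - target) ** 2).mean()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```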