Source code for torchref.scaling.solvent

"""
A class for modelling solvent contribution to structure factors.
"""

import torch
import torch.nn as nn

from torchref.base import (
    extract_structure_factor_from_grid,
    get_scattering_vectors,
    ifft,
)
from torchref.base.electron_density.main import _get_radius_offsets
from torchref.config import get_default_device, get_float_dtype
from torchref.utils.debug_utils import DebugMixin
from torchref.utils.utils import TensorDict, ModuleReference
from torchref.utils.device_mixin import DeviceMixin


[docs] class SolventModel(DeviceMixin, DebugMixin, nn.Module): """ SolventModel to compute solvent contribution to structure factors using Phenix-like approach. Supports two initialization patterns: 1. Empty initialization (for state_dict loading):: solvent = SolventModel() # Creates empty shell solvent.load_state_dict(torch.load('solvent.pt')) 2. Full initialization with model:: solvent = SolventModel(model, k_solvent=0.35, b_solvent=46.0) Attributes ---------- model : ModelFT or None The atomic model for structure factor calculations. device : torch.device Device for tensor operations. verbose : int Verbosity level. float_type : torch.dtype Floating point data type. solvent_radius : float Probe radius in Angstroms for dilation. erosion_radius : float Radius in Angstroms for erosion step. optimize_phase : bool Whether to optimize phase offset parameter. log_k_solvent : torch.nn.Parameter Log of solvent scattering scale factor. b_solvent : torch.nn.Parameter Solvent B-factor. phase_offset : torch.nn.Parameter or buffer Phase offset in radians. """
[docs] def __init__( self, model=None, radius=1.1, k_solvent=1.1, b_solvent=50.0, erosion_radius=0.9, transition=None, optimize_phase=True, initial_phase_offset=0.0, verbose=1, float_type=get_float_dtype(), device=get_default_device(), ): """ Initialize SolventModel. If model is provided, fully initializes the solvent model. If not provided (empty init), creates a shell ready for load_state_dict(). Parameters ---------- model : ModelFT, optional The atomic model used for structure factor calculations (optional for empty init). radius : float, default 1.1 Probe radius in Angstroms for dilation (water radius). k_solvent : float, default 1.1 Solvent scattering scale factor. b_solvent : float, default 50.0 Solvent B-factor. erosion_radius : float, default 0.9 Radius in Angstroms for erosion step. transition : float, optional Gaussian smoothing sigma for mask edges (default: radius/4 in voxels). Avoids ringing artifacts. optimize_phase : bool, default True Whether to optimize phase offset parameter. initial_phase_offset : float, default 0.0 Initial phase offset in radians. verbose : int, default 1 Verbosity level. float_type : torch.dtype, default torch.float32 Floating point data type. device : torch.device, default: configured device.current Device for tensor operations. """ super(SolventModel, self).__init__() self.device = device self.verbose = verbose self.float_type = float_type self.solvent_radius = radius self.erosion_radius = erosion_radius self.optimize_phase = optimize_phase self._cache = TensorDict() # Empty initialization if model is None: self.model = None self.max_radius_angstrom = None self.transition = transition # Register parameters with default values (will be overwritten by load_state_dict) self.log_k_solvent = nn.Parameter( torch.log( torch.tensor(k_solvent, dtype=self.float_type, device=self.device) ) ) self.b_solvent = nn.Parameter( torch.tensor(b_solvent, dtype=self.float_type, device=self.device) ) if self.optimize_phase: self.phase_offset = nn.Parameter( torch.tensor( initial_phase_offset, dtype=self.float_type, device=self.device ) ) else: self.register_buffer( "phase_offset", torch.tensor(0.0, dtype=self.float_type, device=self.device), ) return # Full initialization with model self.model = ModuleReference(model) # Store reference to model self.model.get_vdw_radii() # Ensure VdW radii are available assert self.model, "Model is not initialized" if model.real_space_grid == None: model.setup_grid() # Phenix-style parameters self.solvent_radius = radius # For dilation (accessible surface) self.erosion_radius = erosion_radius # For erosion (contact surface) # For find_relevant_voxels: need to search far enough to capture accessible surface # Maximum possible distance is max(VdW) + solvent_radius self.max_radius_angstrom = self.model.get_vdw_radii().max() + radius if not isinstance(k_solvent, torch.Tensor): k_solvent = torch.tensor( k_solvent, dtype=self.float_type, device=self.device ) else: k_solvent = k_solvent.to(dtype=self.float_type, device=self.device) if not isinstance(b_solvent, torch.Tensor): b_solvent = torch.tensor( b_solvent, dtype=self.float_type, device=self.device ) else: b_solvent = b_solvent.to(dtype=self.float_type, device=self.device) self.log_k_solvent = nn.Parameter(torch.log(k_solvent)) self.b_solvent = nn.Parameter(b_solvent) # Transition width for Gaussian smoothing (in voxels) if transition is not None: self.transition = transition else: # Default: use a fraction of solvent_radius converted to voxels self.transition = self.model.get_radius(radius) / 4.0 # Phase offset parameter to align solvent phases with protein phases # This is critical because FFT of a mask gives arbitrary phases self.optimize_phase = optimize_phase if self.optimize_phase: self.phase_offset = nn.Parameter( torch.tensor( initial_phase_offset, dtype=self.float_type, device=self.device ) ) else: self.register_buffer( "phase_offset", torch.tensor(0.0, dtype=self.float_type, device=self.device), ) self._cache = TensorDict()
[docs] def get_solvent_mask(self): """ Generate solvent mask following Phenix's three-step process. Step 1 (dilation): classify voxels around each atom as protein (inside VdW), boundary (between VdW and VdW+solvent_radius), or bulk solvent (further out). Built in chunks over atoms so peak memory is O(atom_chunk_size × N_box_voxels) rather than O(N_atoms × N_box_voxels) — critical because for typical macromolecule + grid combinations the dense form is multi-GB. Step 2 (symmetry expansion): transform the sparse ASU protein / boundary voxel indices through each symop and scatter into the P1 grid masks. Step 3 (erosion): a boundary voxel becomes solvent if any voxel within ``erosion_radius`` of it is bulk solvent. Implemented as a single F.conv3d with a precomputed spherical kernel and circular padding — replaces the previous Python-loop + per-voxel-neighbourhood expansion that itself ran out of memory on chunks of 10^6 boundary voxels. Returns ------- torch.Tensor Solvent mask (boolean) where True = solvent. """ import torch.nn.functional as F if self.verbose > 1: print("\n=== Phenix-Style Bulk Solvent Mask Calculation ===") print(f"Solvent radius (dilation): {self.solvent_radius:.2f} Å") print(f"Shrink truncation radius (erosion): {self.erosion_radius:.2f} Å") xyz = self.model.xyz() # (N_atoms, 3) vdw_radii = self.model.get_vdw_radii() # (N_atoms,) self.real_space_grid = self.model.real_space_grid inv_frac = self.model.inv_fractional_matrix frac = self.model.fractional_matrix with torch.no_grad(): spacegroup = self.model.fft.spacegroup n_ops = spacegroup.n_ops grid_shape = self.real_space_grid.shape[:-1] device = self.model.device n_atoms = xyz.shape[0] # --- Step 1: dilation, chunked over atoms --- # Reuses the SF splatting pattern from # `torchref.base.electron_density.main._add_isotropic_cpu_fused`: # - cached spherical voxel offsets via _get_radius_offsets # (avoids rebuilding the meshgrid each call) # - direct fractional voxel positions from integer indices # (no real_space_grid gather) # - PBC via `diff_frac - round(diff_frac)` (no Cartesian # round-trip via two matmuls) # - r² via metric-tensor einsum (single fused op) # ATOM_CHUNK caps working memory at ~chunk * N_voxels_in_sphere # times a handful of (float32, 3)-shape transients. 256 keeps # peak in the few-hundred-MB range even on the finest grids; # the SF code's 1024 OOMs on tight cgroups for 1DAW because # our intermediates are denser per atom. ATOM_CHUNK = 256 voxel_size = (self.real_space_grid[3, 3, 3] - self.real_space_grid[2, 2, 2]) local_offsets = _get_radius_offsets( voxel_size, self.max_radius_angstrom, device ) # (R, 3) int — sphere voxels relative to atom center grid_dims = torch.tensor(grid_shape, dtype=torch.long, device=device) grid_shape_float = grid_dims.float() inv_grid = 1.0 / grid_shape_float G = frac.T @ frac # metric tensor: r²_cart = diff_frac · G · diff_frac xyz_frac = xyz @ inv_frac.T # (N, 3) xyz_frac_wrapped = xyz_frac % 1.0 center_idx = torch.round( xyz_frac_wrapped * grid_shape_float ).long() # (N, 3) protein_chunks = [] boundary_chunks = [] for s in range(0, n_atoms, ATOM_CHUNK): e = min(s + ATOM_CHUNK, n_atoms) # Wrapped voxel indices: (C, R, 3) int vi = ( center_idx[s:e].unsqueeze(1) + local_offsets.unsqueeze(0) ) % grid_dims # Direct fractional voxel positions (skip real_space_grid gather) voxel_frac = vi.float() * inv_grid # (C, R, 3) # PBC fractional diff diff_frac = voxel_frac - xyz_frac[s:e].unsqueeze(1) diff_frac = diff_frac - torch.round(diff_frac) del voxel_frac # Squared Cartesian distance via metric tensor r_sq = torch.einsum("avi,ij,avj->av", diff_frac, G, diff_frac) del diff_frac vdw_c = vdw_radii[s:e] vdw_sq = (vdw_c ** 2).unsqueeze(1) rcut_sq = ((vdw_c + self.solvent_radius) ** 2).unsqueeze(1) is_protein = r_sq < vdw_sq is_boundary = (~is_protein) & (r_sq < rcut_sq) del r_sq voxel_flat = vi.reshape(-1, 3) del vi p_idx = is_protein.flatten().nonzero(as_tuple=True)[0] b_idx = is_boundary.flatten().nonzero(as_tuple=True)[0] del is_protein, is_boundary protein_chunks.append(voxel_flat[p_idx]) boundary_chunks.append(voxel_flat[b_idx]) protein_voxels = torch.cat(protein_chunks, dim=0) if protein_chunks else \ torch.empty((0, 3), dtype=torch.long, device=device) boundary_voxels = torch.cat(boundary_chunks, dim=0) if boundary_chunks else \ torch.empty((0, 3), dtype=torch.long, device=device) del protein_chunks, boundary_chunks # --- Step 2: symmetry expansion via index transform --- protein_mask = torch.zeros(grid_shape, dtype=torch.bool, device=device) boundary_mask = torch.zeros(grid_shape, dtype=torch.bool, device=device) float_dtype = get_float_dtype() for op_idx in range(n_ops): if op_idx == 0: p_idx = protein_voxels b_idx = boundary_voxels else: R = spacegroup.matrices[op_idx].to(device=device, dtype=float_dtype) t = spacegroup.translations[op_idx].to(device=device, dtype=float_dtype) gd = grid_dims.to(float_dtype) p_frac = protein_voxels.to(float_dtype) / gd p_idx = (torch.round((p_frac @ R.T + t) * gd) % grid_dims).long() del p_frac b_frac = boundary_voxels.to(float_dtype) / gd b_idx = (torch.round((b_frac @ R.T + t) * gd) % grid_dims).long() del b_frac protein_mask[p_idx[:, 0], p_idx[:, 1], p_idx[:, 2]] = True boundary_mask[b_idx[:, 0], b_idx[:, 1], b_idx[:, 2]] = True boundary_mask = boundary_mask & (~protein_mask) definitely_solvent = ~(protein_mask | boundary_mask) if self.verbose > 2: total_voxels = protein_mask.numel() print( f"After symmetry: protein={protein_mask.sum().item()} " f"boundary={boundary_mask.sum().item()} " f"solvent={definitely_solvent.sum().item()} / {total_voxels}" ) # --- Step 3: erosion --- # Semantics: a boundary voxel becomes solvent iff any voxel # within `erosion_radius` of it is bulk solvent. Equivalently, # dilate `definitely_solvent` by the spherical structuring # element, then intersect with `boundary_mask`. # # Two paths, dispatched on device — the underlying op picks # extremely different cost profiles on CPU vs GPU: # - CPU: roll-OR over the ~123 cached sphere offsets is # memory-bandwidth-bound and ~5x faster than F.conv3d's # MKLDNN path on a small (~9³) kernel. # - GPU: F.conv3d fuses to a single launch and beats 123 # separate roll launches (each its own kernel) by ~2x. # Both produce the exact same mask. if torch.device(device).type == "cuda": half_k = int( torch.ceil(self.erosion_radius / voxel_size.min()).item() ) K = 2 * half_k + 1 offs = torch.arange( -half_k, half_k + 1, device=device, dtype=voxel_size.dtype ) dx = offs.view(-1, 1, 1) * voxel_size[0] dy = offs.view(1, -1, 1) * voxel_size[1] dz = offs.view(1, 1, -1) * voxel_size[2] kernel = ( (dx * dx + dy * dy + dz * dz) <= self.erosion_radius ** 2 ).to(self.log_k_solvent.dtype).view(1, 1, K, K, K) solv_float = definitely_solvent.to(self.log_k_solvent.dtype).view( 1, 1, *grid_shape ) solv_padded = F.pad( solv_float, (half_k,) * 6, mode="circular" ) neighbour_count = F.conv3d(solv_padded, kernel) dilated_solvent = neighbour_count.squeeze(0).squeeze(0) > 0.5 del solv_float, solv_padded, neighbour_count else: sphere_offsets = _get_radius_offsets( voxel_size, self.erosion_radius, device ) # (R, 3) int dilated_solvent = torch.zeros_like(definitely_solvent) for i in range(sphere_offsets.shape[0]): dz_, dy_, dx_ = sphere_offsets[i].tolist() dilated_solvent |= torch.roll( definitely_solvent, shifts=(dz_, dy_, dx_), dims=(0, 1, 2), ) voxels_to_flip = boundary_mask & dilated_solvent protein_with_boundary = (protein_mask | boundary_mask) & (~voxels_to_flip) solvent_mask = ~protein_with_boundary self.register_buffer("protein_mask", protein_with_boundary) self.register_buffer("solvent_mask", solvent_mask) if self.verbose > 1: total_voxels = self.solvent_mask.numel() n_solv = self.solvent_mask.sum().item() print( f"Total solvent voxels: {n_solv} / {total_voxels} " f"({100.0 * n_solv / total_voxels:.2f}%)" ) assert torch.isfinite( self.solvent_mask.float() ).all(), "Non-finite values in solvent mask" return self.solvent_mask
[docs] def update_solvent(self): self.get_solvent_mask() self.smooth_solvent_mask()
[docs] def smooth_solvent_mask(self): if not hasattr(self, "solvent_mask"): raise ValueError( "Solvent mask not computed. Call get_solvent_mask() first." ) import torch.nn.functional as F mask_float = self.solvent_mask.to(dtype=self.log_k_solvent.dtype) sigma = self.transition kernel_size = int(4 * sigma + 1) if kernel_size % 2 == 0: kernel_size += 1 x = torch.arange( kernel_size, dtype=self.log_k_solvent.dtype, device=self.device ) x = x - kernel_size // 2 gauss_1d = torch.exp(-(x**2) / (2 * sigma**2)) gauss_1d = gauss_1d / gauss_1d.sum() pad = kernel_size // 2 mask = mask_float.unsqueeze(0).unsqueeze(0) # (1, 1, D, H, W) # Separable Gaussian: three sequential 1D conv3d passes (one per # spatial axis). For a separable kernel, this is mathematically # identical to the full 3D outer-product conv but costs O(3·K·V) # instead of O(K³·V). Pad once on all three axes with the circular # mode required by periodic crystallographic boundaries; each # successive 1D conv shrinks only its own axis back to the original # size. mask = F.pad(mask, (pad, pad, pad, pad, pad, pad), mode="circular") mask = F.conv3d(mask, gauss_1d.view(1, 1, 1, 1, kernel_size)) mask = F.conv3d(mask, gauss_1d.view(1, 1, 1, kernel_size, 1)) mask = F.conv3d(mask, gauss_1d.view(1, 1, kernel_size, 1, 1)) mask_smoothed = mask.squeeze(0).squeeze(0) self.register_buffer("mask_smoothed", mask_smoothed) assert torch.isfinite( self.mask_smoothed ).all(), "Non-finite values in solvent mask" return self.mask_smoothed
[docs] def get_rec_solvent(self, hkl): """ Compute solvent structure factors. Uses the standard crystallographic approach: compute SFs from the solvent mask. The mask represents regions where bulk solvent scattering occurs. Parameters ---------- hkl : torch.Tensor Miller indices. Returns ------- torch.Tensor Complex solvent structure factors. """ assert hasattr( self, "mask_smoothed" ), "Smoothed solvent mask not computed. Call smooth_solvent_mask() first." fsol = extract_structure_factor_from_grid( ifft(self.mask_smoothed, self.model.cell.volume), hkl ).detach() assert torch.isfinite( fsol ).all(), "Non-finite values in solvent structure factors" return fsol
[docs] def forward(self, hkl, update_fsol=False, F_protein=None): """ Compute solvent contribution to structure factors at given HKL. This method is differentiable with respect to k_solvent, b_solvent, and phase_offset parameters. The solvent model: 1. Takes the binary solvent mask 2. Smooths it with Gaussian filter (σ=1.5 voxels) to create soft edges 3. Computes structure factors via FFT 4. Applies B-factor damping: exp(-B * s²) where s = sin(θ)/λ 5. If optimize_phase=True and F_protein provided: blends mask phases with protein phases phase_offset controls the blend: 0=use mask phases, ±π=use protein phases 6. Scales by k_solvent Parameters ---------- hkl : torch.Tensor Miller indices, shape (N, 3). update_fsol : bool, default False Whether to update solvent structure factors. F_protein : torch.Tensor, optional Protein structure factors, used for phase blending. Returns ------- torch.Tensor Complex solvent structure factors, shape (N,). """ # Lightweight fingerprint: (data_ptr, version, numel) — avoids SHA-1 hkl_key = (hkl.data_ptr(), hkl._version, hkl.numel()) if not update_fsol and hkl_key in self._cache: f_sol = self._cache[hkl_key] else: f_sol = self.get_rec_solvent(hkl) self._cache[hkl_key] = f_sol # Calculate scattering vector magnitude: s = sin(θ)/λ # Note: get_scattering_vectors returns h* = (h·a*, k·b*, l·c*) # For the Debye-Waller factor, we need s = |h*|/2 = sin(θ)/λ scattering_vectors = get_scattering_vectors( hkl, self.model.cell, recB=self.model.recB ) s = torch.norm(scattering_vectors, dim=1) / 2.0 # This is sin(θ)/λ s_squared = s**2 # Now s² is correct for B-factor formula # Apply B-factor damping: exp(-B * s²) # The Debye-Waller factor for isotropic displacement b_solvent = self.b_solvent k_solvent = torch.exp(self.log_k_solvent.clamp(min=-10.0, max=10.0)) exp = -b_solvent.clamp(min=-500.0, max=500.0) * s_squared b_factor_term = torch.exp( exp.clamp(min=-10.0, max=10.0) ) # Clamp to avoid overflow # Phase handling if self.optimize_phase and F_protein is not None: f_mask_amp = torch.abs(f_sol) mask_phases = torch.angle(f_sol) protein_phases = torch.angle(F_protein) # Interpolate phases using phase_offset as a blending parameter # cos(phase_offset) = 1: use mask phases # cos(phase_offset) = -1: use inverted protein phases blend_factor = torch.cos(self.phase_offset) blended_phase = ( mask_phases * (1 + blend_factor) / 2 + (protein_phases + torch.pi) * (1 - blend_factor) / 2 ) phase_adjusted_f_sol = f_mask_amp * torch.exp(1j * blended_phase) elif self.optimize_phase: # Apply global phase offset phase_adjusted_f_sol = f_sol * torch.exp(1j * self.phase_offset) else: # No phase adjustment - use mask phases as-is4 phase_adjusted_f_sol = f_sol # Scale by k_solvent and apply B-factor f_solvent = k_solvent * phase_adjusted_f_sol * b_factor_term assert torch.isfinite( f_solvent ).all(), "Non-finite values in solvent structure factors" return f_solvent
[docs] def parameters(self): return [self.log_k_solvent, self.b_solvent] + ( [self.phase_offset] if self.optimize_phase else [] )