"""
PyTorch implementation of French-Wilson conversion from intensities to structure factors.

Reference: French, S. & Wilson, K. (1978). Acta Cryst. A34, 517-525
Based on the Phenix implementation in cctbx/french_wilson.py

Usage - PyTorch Module (Recommended)::

    import torch
    from french_wilson_pytorch import FrenchWilson

    # Miller indices for your reflections
    hkl = torch.tensor([[1, 2, 3], [2, 0, 0], [0, 3, 0], [1, 1, 1]])
    # Cell: [a, b, c, alpha, beta, gamma] in Å and degrees
    cell = [50.0, 60.0, 70.0, 90.0, 90.0, 90.0]
    # Create module (does all preprocessing)
    fw_module = FrenchWilson(hkl, cell, space_group='P212121')
    # Apply conversion (can be called repeatedly with different I, sigma_I)
    I = torch.tensor([100.0, 50.0, 30.0, 200.0])
    sigma_I = torch.tensor([10.0, 8.0, 7.0, 15.0])
    F, sigma_F = fw_module(I, sigma_I)
    print(f"F = {F}")

Usage - Functional API (for one-off conversions)::

    from french_wilson_pytorch import french_wilson_auto

    # Resolution (d-spacing, in Å) of each reflection
    d_spacings = torch.tensor([2.5, 3.0, 2.8, 2.0])
    F, sigma_F, valid = french_wilson_auto(
        I, sigma_I, hkl, d_spacings, space_group='P212121'
    )
"""
import torch
import torch.nn as nn
from torchref.base import math_torch
from torchref.config import get_float_dtype
from torchref.symmetry import SpaceGroup, SpaceGroupLike
from torchref.utils.device_mixin import DeviceMixin
# Acentric lookup tables from French-Wilson supplement (1978).
# Entry i of each table is the value at h = -4.0 + 0.1 * i (the index mapping
# used by ``interpolate_table``). Each row below spans one unit of h, so the
# 71 acentric entries cover h in [-4.0, 3.0].
# AC_ZJ: posterior intensity table (presumably E[J] / sigma_I, as in the
# cctbx original). It is not read by the F / sigma_F computation here.
# fmt: off
AC_ZJ = torch.tensor(
    [
        0.226, 0.230, 0.235, 0.240, 0.246, 0.251, 0.257, 0.263, 0.270, 0.276,
        0.283, 0.290, 0.298, 0.306, 0.314, 0.323, 0.332, 0.341, 0.351, 0.362,
        0.373, 0.385, 0.397, 0.410, 0.424, 0.439, 0.454, 0.470, 0.487, 0.505,
        0.525, 0.545, 0.567, 0.590, 0.615, 0.641, 0.668, 0.698, 0.729, 0.762,
        0.798, 0.835, 0.875, 0.917, 0.962, 1.009, 1.059, 1.112, 1.167, 1.226,
        1.287, 1.352, 1.419, 1.490, 1.563, 1.639, 1.717, 1.798, 1.882, 1.967,
        2.055, 2.145, 2.236, 2.329, 2.422, 2.518, 2.614, 2.710, 2.808, 2.906,
        3.004,
    ],
    dtype=torch.float32,
)
# fmt: on
# Acentric posterior-intensity standard deviations, on the same 0.1-step grid
# starting at h = -4.0 (71 entries, h in [-4.0, 3.0]); one unit of h per row.
# Not read by the F / sigma_F computation here.
# fmt: off
AC_ZJ_SD = torch.tensor(
    [
        0.217, 0.221, 0.226, 0.230, 0.235, 0.240, 0.245, 0.250, 0.255, 0.261,
        0.267, 0.273, 0.279, 0.286, 0.292, 0.299, 0.307, 0.314, 0.322, 0.330,
        0.339, 0.348, 0.357, 0.367, 0.377, 0.387, 0.398, 0.409, 0.421, 0.433,
        0.446, 0.459, 0.473, 0.488, 0.503, 0.518, 0.535, 0.551, 0.568, 0.586,
        0.604, 0.622, 0.641, 0.660, 0.679, 0.698, 0.718, 0.737, 0.757, 0.776,
        0.795, 0.813, 0.831, 0.848, 0.865, 0.881, 0.895, 0.909, 0.921, 0.933,
        0.943, 0.953, 0.961, 0.968, 0.974, 0.980, 0.984, 0.988, 0.991, 0.994,
        0.996,
    ],
    dtype=torch.float32,
)
# fmt: on
# Acentric amplitude table: F = AC_ZF(h) * sqrt(sigma_I) for h < 3.0.
# Sampled every 0.1 in h starting at h = -4.0 (71 entries, h in [-4.0, 3.0]);
# one unit of h per row.
# fmt: off
AC_ZF = torch.tensor(
    [
        0.423, 0.428, 0.432, 0.437, 0.442, 0.447, 0.453, 0.458, 0.464, 0.469,
        0.475, 0.482, 0.488, 0.495, 0.502, 0.509, 0.516, 0.524, 0.532, 0.540,
        0.549, 0.557, 0.567, 0.576, 0.586, 0.597, 0.608, 0.619, 0.631, 0.643,
        0.656, 0.670, 0.684, 0.699, 0.714, 0.730, 0.747, 0.765, 0.783, 0.802,
        0.822, 0.843, 0.865, 0.887, 0.911, 0.935, 0.960, 0.987, 1.014, 1.042,
        1.070, 1.100, 1.130, 1.161, 1.192, 1.224, 1.257, 1.289, 1.322, 1.355,
        1.388, 1.421, 1.454, 1.487, 1.519, 1.551, 1.583, 1.615, 1.646, 1.676,
        1.706,
    ],
    dtype=torch.float32,
)
# fmt: on
# Acentric amplitude uncertainty table: sigma_F = AC_ZF_SD(h) * sqrt(sigma_I)
# for h < 3.0. Sampled every 0.1 in h starting at h = -4.0 (71 entries);
# one unit of h per row.
# fmt: off
AC_ZF_SD = torch.tensor(
    [
        0.216, 0.218, 0.220, 0.222, 0.224, 0.226, 0.229, 0.231, 0.234, 0.236,
        0.239, 0.241, 0.244, 0.247, 0.250, 0.253, 0.256, 0.259, 0.262, 0.266,
        0.269, 0.272, 0.276, 0.279, 0.283, 0.287, 0.291, 0.295, 0.298, 0.302,
        0.307, 0.311, 0.315, 0.319, 0.324, 0.328, 0.332, 0.337, 0.341, 0.345,
        0.349, 0.353, 0.357, 0.360, 0.364, 0.367, 0.369, 0.372, 0.374, 0.375,
        0.376, 0.377, 0.377, 0.377, 0.376, 0.374, 0.372, 0.369, 0.366, 0.362,
        0.358, 0.353, 0.348, 0.343, 0.338, 0.332, 0.327, 0.321, 0.315, 0.310,
        0.304,
    ],
    dtype=torch.float32,
)
# fmt: on
# Centric lookup tables from French-Wilson supplement (1978).
# Entry i of each table is the value at h = -4.0 + 0.1 * i (the index mapping
# used by ``interpolate_table``). Each row below spans one unit of h, so the
# 81 centric entries cover h in [-4.0, 4.0].
# C_ZJ: posterior intensity table (presumably E[J] / sigma_I, as in the
# cctbx original). It is not read by the F / sigma_F computation here.
# NOTE(review): the final step (3.753 -> 3.962) is roughly twice the size of
# its neighbours. Values are kept exactly as given; confirm against the
# published table.
# fmt: off
C_ZJ = torch.tensor(
    [
        0.114, 0.116, 0.119, 0.122, 0.124, 0.127, 0.130, 0.134, 0.137, 0.141,
        0.145, 0.148, 0.153, 0.157, 0.162, 0.166, 0.172, 0.177, 0.183, 0.189,
        0.195, 0.202, 0.209, 0.217, 0.225, 0.234, 0.243, 0.253, 0.263, 0.275,
        0.287, 0.300, 0.314, 0.329, 0.345, 0.363, 0.382, 0.402, 0.425, 0.449,
        0.475, 0.503, 0.534, 0.567, 0.603, 0.642, 0.684, 0.730, 0.779, 0.833,
        0.890, 0.952, 1.018, 1.089, 1.164, 1.244, 1.327, 1.416, 1.508, 1.603,
        1.703, 1.805, 1.909, 2.015, 2.123, 2.233, 2.343, 2.453, 2.564, 2.674,
        2.784, 2.894, 3.003, 3.112, 3.220, 3.328, 3.435, 3.541, 3.647, 3.753,
        3.962,
    ],
    dtype=torch.float32,
)
# fmt: on
# Centric posterior-intensity standard deviations, on the 0.1-step grid
# starting at h = -4.0 (81 entries, h in [-4.0, 4.0]); one unit of h per row.
# Not read by the F / sigma_F computation here.
# fmt: off
C_ZJ_SD = torch.tensor(
    [
        0.158, 0.161, 0.165, 0.168, 0.172, 0.176, 0.179, 0.184, 0.188, 0.192,
        0.197, 0.202, 0.207, 0.212, 0.218, 0.224, 0.230, 0.236, 0.243, 0.250,
        0.257, 0.265, 0.273, 0.282, 0.291, 0.300, 0.310, 0.321, 0.332, 0.343,
        0.355, 0.368, 0.382, 0.397, 0.412, 0.428, 0.445, 0.463, 0.481, 0.501,
        0.521, 0.543, 0.565, 0.589, 0.613, 0.638, 0.664, 0.691, 0.718, 0.745,
        0.773, 0.801, 0.828, 0.855, 0.881, 0.906, 0.929, 0.951, 0.971, 0.989,
        1.004, 1.018, 1.029, 1.038, 1.044, 1.049, 1.052, 1.054, 1.054, 1.053,
        1.051, 1.049, 1.047, 1.044, 1.041, 1.039, 1.036, 1.034, 1.031, 1.029,
        1.028,
    ],
    dtype=torch.float32,
)
# fmt: on
# Centric amplitude table: F = C_ZF(h) * sqrt(sigma_I) for h < 4.0.
# Sampled every 0.1 in h starting at h = -4.0 (81 entries, h in [-4.0, 4.0]);
# one unit of h per row.
# fmt: off
C_ZF = torch.tensor(
    [
        0.269, 0.272, 0.276, 0.279, 0.282, 0.286, 0.289, 0.293, 0.297, 0.301,
        0.305, 0.309, 0.314, 0.318, 0.323, 0.328, 0.333, 0.339, 0.344, 0.350,
        0.356, 0.363, 0.370, 0.377, 0.384, 0.392, 0.400, 0.409, 0.418, 0.427,
        0.438, 0.448, 0.460, 0.471, 0.484, 0.498, 0.512, 0.527, 0.543, 0.560,
        0.578, 0.597, 0.618, 0.639, 0.662, 0.687, 0.713, 0.740, 0.769, 0.800,
        0.832, 0.866, 0.901, 0.938, 0.976, 1.016, 1.057, 1.098, 1.140, 1.183,
        1.227, 1.270, 1.313, 1.356, 1.398, 1.439, 1.480, 1.519, 1.558, 1.595,
        1.632, 1.667, 1.701, 1.735, 1.767, 1.799, 1.829, 1.859, 1.889, 1.917,
        1.945,
    ],
    dtype=torch.float32,
)
# fmt: on
# Centric amplitude uncertainty table: sigma_F = C_ZF_SD(h) * sqrt(sigma_I)
# for h < 4.0. Sampled every 0.1 in h starting at h = -4.0 (81 entries);
# one unit of h per row.
# fmt: off
C_ZF_SD = torch.tensor(
    [
        0.203, 0.205, 0.207, 0.209, 0.211, 0.214, 0.216, 0.219, 0.222, 0.224,
        0.227, 0.230, 0.233, 0.236, 0.239, 0.243, 0.246, 0.250, 0.253, 0.257,
        0.261, 0.265, 0.269, 0.273, 0.278, 0.283, 0.288, 0.293, 0.298, 0.303,
        0.309, 0.314, 0.320, 0.327, 0.333, 0.340, 0.346, 0.353, 0.361, 0.368,
        0.375, 0.383, 0.390, 0.398, 0.405, 0.413, 0.420, 0.427, 0.433, 0.440,
        0.445, 0.450, 0.454, 0.457, 0.459, 0.460, 0.460, 0.458, 0.455, 0.451,
        0.445, 0.438, 0.431, 0.422, 0.412, 0.402, 0.392, 0.381, 0.370, 0.360,
        0.349, 0.339, 0.330, 0.321, 0.312, 0.304, 0.297, 0.290, 0.284, 0.278,
        0.272,
    ],
    dtype=torch.float32,
)
# fmt: on
def interpolate_table(
    h: torch.Tensor, table: torch.Tensor, h_min: float = -4.0
) -> torch.Tensor:
    """
    Interpolate values from a French-Wilson lookup table.

    Entry ``i`` of ``table`` is the value at ``h = h_min + 0.1 * i``. Values of
    ``h`` between grid points are linearly interpolated. Values outside the
    tabulated range take the first / last table entry.

    Parameters
    ----------
    h : torch.Tensor
        Normalized parameter (tensor of any shape).
    table : torch.Tensor
        Lookup table tensor (1D), sampled every 0.1 in h.
    h_min : float, optional
        Value of h that corresponds to ``table[0]``. Default is -4.0.

    Returns
    -------
    torch.Tensor
        Interpolated values (same shape as h). NaN entries of ``h`` give NaN.
    """
    n = table.shape[0]
    # Fractional table index: point = 10.0 * (h - h_min)
    point = 10.0 * (h - h_min)
    # Casting NaN to an integer index is undefined (it can produce an
    # out-of-bounds index), so NaNs are parked at index 0 and restored below.
    nan_mask = torch.isnan(point)
    point = point.masked_fill(nan_mask, 0.0)
    # Clamp to the tabulated range. Clamping to exactly n - 1 (rather than
    # slightly below it) lets the last grid point be reproduced exactly.
    point = torch.clamp(point, 0.0, float(n - 1))
    # Lower bracketing index, kept <= n - 2 so that pt_2 = pt_1 + 1 stays in
    # range (for a single-entry table both collapse to index 0).
    pt_1 = torch.clamp(torch.floor(point), 0, max(n - 2, 0)).long()
    pt_2 = torch.clamp(pt_1 + 1, max=n - 1)
    # Fractional offset, computed in point's own dtype
    delta = point - pt_1.to(point.dtype)
    result = (1.0 - delta) * table[pt_1] + delta * table[pt_2]
    return result.masked_fill(nan_mask, float("nan"))
def french_wilson_acentric(
    I: torch.Tensor,
    sigma_I: torch.Tensor,
    mean_intensity: torch.Tensor,
    h_min: float = -4.0,
    i_sig_min: float = -3.7,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    French-Wilson conversion for acentric reflections.

    Let h = I/sigma_I - sigma_I/mean_intensity. Reflections with h < 3.0
    take F and sigma_F from the acentric lookup tables, scaled by
    sqrt(sigma_I). Reflections with h >= 3.0 use the asymptotic form
    F = sqrt(h * sigma_I) and sigma_F = sigma_I / (2 * F).

    Parameters
    ----------
    I : torch.Tensor
        Measured intensities (any shape).
    sigma_I : torch.Tensor
        Standard deviations of intensities (same shape as I).
    mean_intensity : torch.Tensor
        Mean intensity for each reflection's resolution bin (same shape as I).
    h_min : float, optional
        Minimum h value for rejection. Smaller h values are clamped to h_min
        before conversion. Default is -4.0.
    i_sig_min : float, optional
        Minimum I/sigma_I for rejection. Default is -3.7 (h_min + 0.3).

    Returns
    -------
    F : torch.Tensor
        Structure factor amplitudes (same shape as I).
    sigma_F : torch.Tensor
        Standard deviations of F (same shape as I).
    valid_mask : torch.Tensor
        Boolean mask indicating valid (not rejected) reflections. It is
        informational only: rejected reflections still receive F and sigma_F.
    """
    device = I.device
    dtype = I.dtype
    # The lookup tables are sampled at h = -4.0 + 0.1 * i. This origin belongs
    # to the tables, not to the rejection threshold. Passing h_min here would
    # shift every lookup whenever a non-default h_min is used.
    table_h0 = -4.0
    # Only the amplitude tables enter F and sigma_F
    ac_zf = AC_ZF.to(device=device, dtype=dtype)
    ac_zf_sd = AC_ZF_SD.to(device=device, dtype=dtype)

    # Normalized parameter h
    h = (I / sigma_I) - (sigma_I / mean_intensity)
    # Very weak reflections (h < h_min) are evaluated at h_min
    h_clamped = torch.clamp(h, min=h_min)

    F = torch.zeros_like(I)
    sigma_F = torch.zeros_like(I)

    # Case 1: h < 3.0 - interpolate the tables (71 entries: h in [-4.0, 3.0])
    small_h_mask = h_clamped < 3.0
    if small_h_mask.any():
        h_small = h_clamped[small_h_mask]
        sqrt_sigma_small = torch.sqrt(sigma_I[small_h_mask])
        zf = interpolate_table(h_small, ac_zf, h_min=table_h0)
        zf_sd = interpolate_table(h_small, ac_zf_sd, h_min=table_h0)
        F[small_h_mask] = zf * sqrt_sigma_small
        sigma_F[small_h_mask] = zf_sd * sqrt_sigma_small

    # Case 2: h >= 3.0 - asymptotic formula with J = h * sigma_I, F = sqrt(J)
    large_h_mask = h_clamped >= 3.0
    if large_h_mask.any():
        h_large = h_clamped[large_h_mask]
        sigma_I_large = sigma_I[large_h_mask]
        F_large = torch.sqrt(h_large * sigma_I_large)
        F[large_h_mask] = F_large
        sigma_F[large_h_mask] = 0.5 * (sigma_I_large / F_large)

    # Rejection flags use the unclamped h; values are deliberately not zeroed
    i_over_sig = I / sigma_I
    valid_mask = (i_over_sig >= i_sig_min) & (h >= h_min)
    return F, sigma_F, valid_mask
def french_wilson_centric(
    I: torch.Tensor,
    sigma_I: torch.Tensor,
    mean_intensity: torch.Tensor,
    h_min: float = -4.0,
    i_sig_min: float = -3.7,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    French-Wilson conversion for centric reflections.

    Let h = I/sigma_I - sigma_I/(2 * mean_intensity). Reflections with
    h < 4.0 take F and sigma_F from the centric lookup tables, scaled by
    sqrt(sigma_I). Reflections with h >= 4.0 use an extended asymptotic
    expansion in 1/h^2.

    Parameters
    ----------
    I : torch.Tensor
        Measured intensities (any shape).
    sigma_I : torch.Tensor
        Standard deviations of intensities (same shape as I).
    mean_intensity : torch.Tensor
        Mean intensity for each reflection's resolution bin (same shape as I).
    h_min : float, optional
        Minimum h value for rejection. Smaller h values are clamped to h_min
        before conversion. Default is -4.0.
    i_sig_min : float, optional
        Minimum I/sigma_I for rejection. Default is -3.7 (h_min + 0.3).

    Returns
    -------
    F : torch.Tensor
        Structure factor amplitudes (same shape as I).
    sigma_F : torch.Tensor
        Standard deviations of F (same shape as I).
    valid_mask : torch.Tensor
        Boolean mask indicating valid (not rejected) reflections. It is
        informational only: rejected reflections still receive F and sigma_F.
    """
    device = I.device
    dtype = I.dtype
    # The lookup tables are sampled at h = -4.0 + 0.1 * i. This origin belongs
    # to the tables, not to the rejection threshold. Passing h_min here would
    # shift every lookup whenever a non-default h_min is used.
    table_h0 = -4.0
    # Only the amplitude tables enter F and sigma_F
    c_zf = C_ZF.to(device=device, dtype=dtype)
    c_zf_sd = C_ZF_SD.to(device=device, dtype=dtype)

    # Normalized parameter h (note the factor of 2 for centric reflections)
    h = (I / sigma_I) - (sigma_I / (2.0 * mean_intensity))
    # Very weak reflections (h < h_min) are evaluated at h_min
    h_clamped = torch.clamp(h, min=h_min)

    F = torch.zeros_like(I)
    sigma_F = torch.zeros_like(I)

    # Case 1: h < 4.0 - interpolate the tables (81 entries: h in [-4.0, 4.0])
    small_h_mask = h_clamped < 4.0
    if small_h_mask.any():
        h_small = h_clamped[small_h_mask]
        sqrt_sigma_small = torch.sqrt(sigma_I[small_h_mask])
        zf = interpolate_table(h_small, c_zf, h_min=table_h0)
        zf_sd = interpolate_table(h_small, c_zf_sd, h_min=table_h0)
        F[small_h_mask] = zf * sqrt_sigma_small
        sigma_F[small_h_mask] = zf_sd * sqrt_sigma_small

    # Case 2: h >= 4.0 - extended asymptotic expansion (Phenix extension)
    large_h_mask = h_clamped >= 4.0
    if large_h_mask.any():
        h_large = h_clamped[large_h_mask]
        sqrt_sigma_large = torch.sqrt(sigma_I[large_h_mask])
        h_2 = 1.0 / (h_large * h_large)
        h_4 = h_2 * h_2
        h_6 = h_2 * h_4
        # Posterior mean of F (in units of sqrt(sigma_I))
        post_F = torch.sqrt(h_large) * (
            1.0 - (3.0 / 8.0) * h_2 - (87.0 / 128.0) * h_4 - (2889.0 / 1024.0) * h_6
        )
        # Posterior standard deviation of F (in units of sqrt(sigma_I))
        post_sig_F = torch.sqrt(
            h_large * ((1.0 / 4.0) * h_2 + (15.0 / 32.0) * h_4 + (273.0 / 128.0) * h_6)
        )
        F[large_h_mask] = post_F * sqrt_sigma_large
        sigma_F[large_h_mask] = post_sig_F * sqrt_sigma_large

    # Rejection flags use the unclamped h; values are deliberately not zeroed
    i_over_sig = I / sigma_I
    valid_mask = (i_over_sig >= i_sig_min) & (h >= h_min)
    return F, sigma_F, valid_mask
def french_wilson(
    I: torch.Tensor,
    sigma_I: torch.Tensor,
    mean_intensity: torch.Tensor,
    is_centric: torch.Tensor = None,
    h_min: float = -4.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    French-Wilson conversion from intensities to structure factors.

    Automatically handles both centric and acentric reflections.

    Parameters
    ----------
    I : torch.Tensor
        Measured intensities of shape (...).
    sigma_I : torch.Tensor
        Standard deviations of intensities of shape (...).
    mean_intensity : torch.Tensor
        Mean intensity for each reflection's resolution bin of shape (...).
    is_centric : torch.Tensor, optional
        Mask of shape (...) that is True (or 1) for centric reflections. It is
        converted to a boolean tensor on I's device. If None, all reflections
        are treated as acentric.
    h_min : float, optional
        Minimum h value for rejection. Default is -4.0. The I/sigma_I
        rejection threshold is h_min + 0.3.

    Returns
    -------
    F : torch.Tensor
        Structure factor amplitudes of shape (...).
    sigma_F : torch.Tensor
        Standard deviations of F of shape (...).
    valid_mask : torch.Tensor
        Boolean mask indicating valid (not rejected) reflections of shape (...).

    Examples
    --------
    ::

        I = torch.tensor([100.0, 5.0, -15.0, 200.0])
        sigma_I = torch.tensor([10.0, 10.0, 10.0, 15.0])
        mean_I = torch.tensor([80.0, 80.0, 80.0, 150.0])
        F, sigma_F, valid = french_wilson(I, sigma_I, mean_I)
        print(f"F = {F}")
    """
    # Rejection threshold on I/sigma_I, 0.3 above h_min
    i_sig_min = h_min + 0.3

    if is_centric is None:
        # All reflections are acentric
        return french_wilson_acentric(I, sigma_I, mean_intensity, h_min, i_sig_min)

    # Normalise the mask to bool on I's device. An integer 0/1 mask would
    # otherwise be used for positional (gather) indexing, and ``~`` would be a
    # bitwise NOT, silently writing results to the wrong reflections.
    is_centric = torch.as_tensor(is_centric, dtype=torch.bool, device=I.device)
    acentric_mask = ~is_centric

    F = torch.zeros_like(I)
    sigma_F = torch.zeros_like(I)
    valid_mask = torch.zeros_like(I, dtype=torch.bool)

    # Process acentric reflections
    if acentric_mask.any():
        F_acen, sigma_F_acen, valid_acen = french_wilson_acentric(
            I[acentric_mask],
            sigma_I[acentric_mask],
            mean_intensity[acentric_mask],
            h_min,
            i_sig_min,
        )
        F[acentric_mask] = F_acen
        sigma_F[acentric_mask] = sigma_F_acen
        valid_mask[acentric_mask] = valid_acen

    # Process centric reflections
    if is_centric.any():
        F_cen, sigma_F_cen, valid_cen = french_wilson_centric(
            I[is_centric],
            sigma_I[is_centric],
            mean_intensity[is_centric],
            h_min,
            i_sig_min,
        )
        F[is_centric] = F_cen
        sigma_F[is_centric] = sigma_F_cen
        valid_mask[is_centric] = valid_cen

    return F, sigma_F, valid_mask
def is_centric_from_hkl(
    hkl: torch.Tensor, space_group: SpaceGroupLike = "P1"
) -> torch.Tensor:
    """
    Flag centric reflections from their Miller indices and space group.

    A reflection h is centric when at least one symmetry operation of the
    space group maps it onto its Friedel mate -h.

    Parameters
    ----------
    hkl : torch.Tensor
        Miller indices of shape (..., 3).
    space_group : str, int, or gemmi.SpaceGroup, optional
        Space group specification. Default is "P1".

    Returns
    -------
    torch.Tensor
        Boolean mask of shape (...), True for centric reflections.
    """
    batch_shape = hkl.shape[:-1]
    float_dtype = get_float_dtype()
    ops = SpaceGroup(space_group, dtype=float_dtype, device=hkl.device)

    # Flatten to (n_reflections, 3) in the configured float precision
    indices = hkl.reshape(-1, 3).to(float_dtype)

    # Images of every reflection under every operation, (n_reflections, 3,
    # n_ops), rounded back to integer Miller indices. Only the rotational part
    # of each operation acts on reciprocal-space indices.
    images = torch.round(ops.apply_to_hkl(indices))

    # Friedel mates shaped (n_reflections, 3, 1) to broadcast over operations
    mates = -indices[..., None]

    # An operation hits the Friedel mate when all three indices agree
    # (|difference| < 0.5); the reflection is centric if any operation does.
    hits_mate = (images - mates).abs().lt(0.5).all(dim=1)
    return hits_mate.any(dim=1).reshape(batch_shape)
def get_centric_acentric_masks(
    hkl: torch.Tensor, space_group: SpaceGroupLike = "P1"
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Return complementary centric and acentric masks for a set of reflections.

    Parameters
    ----------
    hkl : torch.Tensor
        Miller indices of shape (..., 3).
    space_group : str, int, or gemmi.SpaceGroup, optional
        Space group specification. Default is "P1".

    Returns
    -------
    centric_mask : torch.Tensor
        Boolean mask of shape (...), True for centric reflections.
    acentric_mask : torch.Tensor
        Boolean mask of shape (...), True for acentric reflections (the
        element-wise negation of ``centric_mask``).
    """
    centric = is_centric_from_hkl(hkl, space_group)
    return centric, centric.logical_not()
def estimate_mean_intensity_by_resolution(
    I: torch.Tensor, d_spacings: torch.Tensor, n_bins: int = 60, min_per_bin: int = 40
) -> torch.Tensor:
    """
    Estimate mean intensity for each reflection based on resolution binning.

    Reflections are sorted by d-spacing (high to low) and split into
    consecutive, equally sized bins; the last bin absorbs any remainder. Each
    reflection's estimate is interpolated linearly in d between the mean
    intensities of the two bins whose centers bracket its d-spacing.
    Reflections beyond the outermost bin centers get that bin's mean.

    Parameters
    ----------
    I : torch.Tensor
        Measured intensities of shape (n_reflections,).
    d_spacings : torch.Tensor
        Resolution (d-spacing) for each reflection of shape (n_reflections,).
        It may have a different dtype than ``I``.
    n_bins : int, optional
        Target number of resolution bins. Default is 60.
    min_per_bin : int, optional
        Minimum reflections per bin. Default is 40. Values below 1 are
        treated as 1.

    Returns
    -------
    torch.Tensor
        Estimated mean intensity for each reflection of shape (n_reflections,),
        with the dtype of ``I``.
    """
    n_reflections = len(I)
    # Bin geometry uses one dtype that can hold both inputs. scatter_reduce_
    # and searchsorted below reject mixed dtypes, and d-spacings are often
    # computed at a different precision than the intensities.
    geom_dtype = torch.promote_types(d_spacings.dtype, I.dtype)
    d = d_spacings.to(device=I.device, dtype=geom_dtype)

    # Reflections per bin: at least min_per_bin, and never below 1, so the
    # integer divisions below cannot divide by zero.
    reflections_per_bin = max(1, min_per_bin, n_reflections // max(1, n_bins))
    actual_n_bins = max(1, n_reflections // reflections_per_bin)

    # Sort by resolution (d-spacing, high to low)
    sort_idx = torch.argsort(d, descending=True)
    I_sorted = I[sort_idx]
    d_sorted = d[sort_idx]

    # Consecutive runs of sorted reflections form the bins; the clamp merges
    # the remainder into the last bin.
    bin_indices = torch.arange(n_reflections, device=I.device) // reflections_per_bin
    bin_indices = torch.clamp(bin_indices, max=actual_n_bins - 1)

    # Mean intensity per bin via scatter_add
    bin_sums = torch.zeros(actual_n_bins, dtype=I.dtype, device=I.device)
    bin_counts = torch.zeros(actual_n_bins, dtype=torch.long, device=I.device)
    bin_sums.scatter_add_(0, bin_indices, I_sorted)
    bin_counts.scatter_add_(0, bin_indices, torch.ones_like(bin_indices))
    bin_means = bin_sums / bin_counts.to(I.dtype)

    # Bin centers: midpoint of each bin's d range (descending order)
    bin_d_max = torch.full(
        (actual_n_bins,), -float("inf"), dtype=geom_dtype, device=I.device
    )
    bin_d_min = torch.full(
        (actual_n_bins,), float("inf"), dtype=geom_dtype, device=I.device
    )
    bin_d_max.scatter_reduce_(
        0, bin_indices, d_sorted, reduce="amax", include_self=False
    )
    bin_d_min.scatter_reduce_(
        0, bin_indices, d_sorted, reduce="amin", include_self=False
    )
    bin_centers = (bin_d_max + bin_d_min) / 2.0

    # searchsorted needs ascending boundaries; with right=True, insert_idx
    # counts the bin centers that are <= each d.
    insert_idx = torch.searchsorted(bin_centers.flip(0), d, right=True)
    # Back in descending indexing: left_idx is the first bin whose center is
    # <= d, and right_idx is its neighbour whose center is > d.
    left_idx = torch.clamp(actual_n_bins - insert_idx, 0, actual_n_bins - 1)
    right_idx = torch.clamp(actual_n_bins - insert_idx - 1, 0, actual_n_bins - 1)

    # Reflections beyond the first / last bin center take that bin's mean
    beyond_first = d >= bin_centers[0]
    beyond_last = d <= bin_centers[-1]

    d1 = bin_centers[left_idx]
    d2 = bin_centers[right_idx]
    m1 = bin_means[left_idx]
    m2 = bin_means[right_idx]

    # Linear interpolation weight (guarding against coincident centers)
    d_diff = d1 - d2
    safe_d_diff = torch.where(
        torch.abs(d_diff) > 1e-10, d_diff, torch.ones_like(d_diff)
    )
    weight = torch.clamp((d1 - d) / safe_d_diff, 0.0, 1.0).to(I.dtype)

    mean_I = (1 - weight) * m1 + weight * m2
    mean_I = torch.where(beyond_first, bin_means[0], mean_I)
    mean_I = torch.where(beyond_last, bin_means[-1], mean_I)
    return mean_I
def french_wilson_auto(
    I: torch.Tensor,
    sigma_I: torch.Tensor,
    hkl: torch.Tensor,
    d_spacings: torch.Tensor,
    space_group: SpaceGroupLike = "P1",
    n_bins: int = 60,
    min_per_bin: int = 40,
    h_min: float = -4.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    One-call French-Wilson conversion with binning and centric detection.

    Runs the full pipeline:

    1. Bin reflections by resolution and estimate the mean intensity per bin.
    2. Classify each reflection as centric or acentric from its Miller indices.
    3. Apply the matching French-Wilson conversion.

    Parameters
    ----------
    I : torch.Tensor
        Measured intensities of shape (n_reflections,).
    sigma_I : torch.Tensor
        Standard deviations of intensities of shape (n_reflections,).
    hkl : torch.Tensor
        Miller indices of shape (n_reflections, 3).
    d_spacings : torch.Tensor
        Resolution (d-spacing) for each reflection of shape (n_reflections,).
    space_group : str, int, or gemmi.SpaceGroup, optional
        Space group specification. Default is "P1".
    n_bins : int, optional
        Number of resolution bins. Default is 60.
    min_per_bin : int, optional
        Minimum reflections per bin. Default is 40.
    h_min : float, optional
        Minimum h value for rejection. Default is -4.0.

    Returns
    -------
    F : torch.Tensor
        Structure factor amplitudes of shape (n_reflections,).
    sigma_F : torch.Tensor
        Standard deviations of F of shape (n_reflections,).
    valid_mask : torch.Tensor
        Boolean mask indicating valid (not rejected) reflections.

    Examples
    --------
    ::

        hkl = torch.tensor([[1, 2, 3], [2, 0, 0], [0, 3, 0], [1, 1, 1]])
        I = torch.tensor([100.0, 50.0, 30.0, 200.0])
        sigma_I = torch.tensor([10.0, 8.0, 7.0, 15.0])
        d_spacings = torch.tensor([2.5, 3.0, 2.8, 2.0])
        F, sigma_F, valid = french_wilson_auto(I, sigma_I, hkl, d_spacings, "P212121")
    """
    binned_mean = estimate_mean_intensity_by_resolution(
        I, d_spacings, n_bins=n_bins, min_per_bin=min_per_bin
    )
    centric_flags = is_centric_from_hkl(hkl, space_group=space_group)
    return french_wilson(
        I, sigma_I, binned_mean, is_centric=centric_flags, h_min=h_min
    )
class FrenchWilson(DeviceMixin, nn.Module):
    """
    PyTorch module for French-Wilson conversion from intensities to structure factors.

    Pre-computes the per-reflection metadata (d-spacings and centric flags)
    during initialization, so the forward pass only needs I and sigma_I. The
    resolution-binned mean intensity depends on the intensities, so it is
    re-estimated on every forward call.

    Parameters
    ----------
    hkl : torch.Tensor
        Miller indices of shape (n_reflections, 3), integer tensor.
    cell : torch.Tensor
        Unit cell parameters [a, b, c, alpha, beta, gamma] in Å and degrees.
        The examples pass a plain list; presumably any form accepted by
        ``math_torch.get_d_spacing`` works (TODO confirm).
    space_group : str, int, or gemmi.SpaceGroup, optional
        Space group specification (e.g., 'P21', 4, gemmi.SpaceGroup('P 21')). Default is "P1".
    n_bins : int, optional
        Number of resolution bins for mean intensity estimation. Default is 60.
    min_per_bin : int, optional
        Minimum reflections per bin. Default is 40.
    h_min : float, optional
        Minimum h value for rejection. Default is -4.0.
    verbose : int, optional
        Verbosity level (0=silent, 1=basic, 2=detailed). Default is 1.

    Attributes
    ----------
    hkl : torch.Tensor
        Miller indices (registered buffer, long dtype).
    d_spacings : torch.Tensor
        Resolution for each reflection in Å (registered buffer).
    is_centric : torch.Tensor
        Boolean mask for centric reflections (registered buffer).

    Examples
    --------
    ::

        hkl = torch.tensor([[1, 2, 3], [2, 0, 0], [0, 3, 0], [1, 1, 1]])
        cell = [50.0, 60.0, 70.0, 90.0, 90.0, 90.0]
        fw_module = FrenchWilson(hkl, cell, 'P212121')
        I = torch.tensor([100.0, 50.0, 30.0, 200.0])
        sigma_I = torch.tensor([10.0, 8.0, 7.0, 15.0])
        F, sigma_F = fw_module(I, sigma_I)
    """

    def __init__(
        self,
        hkl: torch.Tensor,
        cell: torch.Tensor,
        space_group: SpaceGroupLike = "P1",
        n_bins: int = 60,
        min_per_bin: int = 40,
        h_min: float = -4.0,
        verbose: int = 1,
    ):
        super().__init__()
        # Store parameters
        self.n_reflections = len(hkl)
        self.space_group = space_group
        self.n_bins = n_bins
        self.min_per_bin = min_per_bin
        self.h_min = h_min
        self.verbose = verbose

        # Register HKL as buffer (will be moved to device with model)
        self.register_buffer("hkl", hkl.long())
        # Calculate d-spacings from cell and HKL
        d_spacings = math_torch.get_d_spacing(hkl, cell)
        self.register_buffer("d_spacings", d_spacings)
        # Determine centric reflections
        is_centric = is_centric_from_hkl(hkl, space_group)
        self.register_buffer("is_centric", is_centric)

        # Verbosity level 1: basic initialization summary
        if self.verbose >= 1:
            n_centric = int(is_centric.sum())
            centric_pct = (
                100.0 * n_centric / self.n_reflections if self.n_reflections else 0.0
            )
            # Report the actual class name rather than a hard-coded one
            print(f"{type(self).__name__} initialized:")
            print(f" Reflections: {self.n_reflections}")
            if self.n_reflections:
                # min()/max() raise on an empty tensor, so skip when empty
                print(
                    f" Resolution: {d_spacings.min():.2f} - {d_spacings.max():.2f} Å"
                )
            print(f" Space group: {space_group}")
            print(f" Centric: {n_centric} ({centric_pct:.1f}%)")
        # Verbosity level 2: additional detailed info
        if self.verbose >= 2:
            print(f" Binning: {n_bins} bins, min {min_per_bin} reflections/bin")
            print(f" Rejection threshold: h_min = {h_min}")
            print(f" Device: {hkl.device}")

    def forward(
        self, I: torch.Tensor, sigma_I: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Apply French-Wilson conversion.

        Reflections whose I or sigma_I is NaN are excluded from the mean
        intensity estimate and from the conversion, and get NaN in both
        outputs.

        Parameters
        ----------
        I : torch.Tensor
            Measured intensities, shape (n_reflections,).
        sigma_I : torch.Tensor
            Standard deviations of intensities, shape (n_reflections,).

        Returns
        -------
        F : torch.Tensor
            Structure factor amplitudes, shape (n_reflections,).
        sigma_F : torch.Tensor
            Standard deviations of F, shape (n_reflections,).
        """
        # Reflections with a NaN intensity or sigma are treated as unmeasured
        nan_mask = torch.isnan(I) | torch.isnan(sigma_I)
        # If all values are NaN, return NaN arrays
        if nan_mask.all():
            return torch.full_like(I, float("nan")), torch.full_like(
                sigma_I, float("nan")
            )

        # Filter out NaN values and the corresponding metadata
        measured = ~nan_mask
        I_clean = I[measured]
        sigma_I_clean = sigma_I[measured]
        d_spacings_clean = self.d_spacings[measured]
        is_centric_clean = self.is_centric[measured]

        # Mean intensity by resolution, from measured reflections only
        mean_intensity = estimate_mean_intensity_by_resolution(
            I_clean, d_spacings_clean, n_bins=self.n_bins, min_per_bin=self.min_per_bin
        )

        F_clean, sigma_F_clean, _ = french_wilson(
            I_clean,
            sigma_I_clean,
            mean_intensity,
            is_centric=is_centric_clean,
            h_min=self.h_min,
        )

        # Scatter results back; unmeasured reflections stay NaN
        F_full = torch.full_like(I, float("nan"))
        sigma_F_full = torch.full_like(sigma_I, float("nan"))
        F_full[measured] = F_clean
        sigma_F_full[measured] = sigma_F_clean
        return F_full, sigma_F_full