from typing import TYPE_CHECKING, Any, Dict

import numpy as np
import torch

from torchref.utils.stats import (
    VERBOSITY_DEBUG,
    VERBOSITY_DETAILED,
    VERBOSITY_STANDARD,
    StatEntry,
    stat,
)

from .base import ADPTarget

if TYPE_CHECKING:
    from torchref.model.model import Model
class ADPLocalityTarget(ADPTarget):
    """
    Proximity-based ADP restraint using K nearest neighbors.

    Neighbors are found with a spatial cell list instead of a full N×N
    distance matrix, so the restraint scales to large structures without
    memory issues.

    The search only looks at atoms in the 3×3×3 block of cells around each
    atom's own cell. Every atom closer than ``cell_size`` Å is therefore
    guaranteed to be a candidate. Beyond that distance the result is
    approximate, and atoms in sparse regions may get fewer than
    ``k_neighbors`` real neighbors. Unfilled slots are padded with distance
    ``inf``, which gives them zero weight in :meth:`forward`.

    Parameters
    ----------
    model : Model
        Reference to Model object.
    k_neighbors : int, optional
        Number of nearest neighbors to consider. Default is 50. Clipped to
        ``n_atoms - 1`` when the neighbor list is built.
    correlation_length : float, optional
        Distance scale for weight decay in Angstrom. Default is 5.0.
        Used by :meth:`stats` for its weighted RMS.
    scale : float, optional
        Scaling factor for loss magnitude. Default is 5.0.
        NOTE(review): stored but not applied in :meth:`forward`.
    exclude_bonded : bool, optional
        Exclude directly bonded atoms. Default is True.
        NOTE(review): stored but not consulted by any method of this class.
    verbose : int, optional
        Verbosity level. Default is 0.
    cell_size : float, optional
        Edge length of the neighbor-search cells in Angstrom. Default is
        12.0. For proteins, the k=50 nearest neighbors typically lie within
        ~8-10 Å, so they are found exactly.
    """

    name: str = "adp/locality"

    # Class-level default, so that instances restored from objects created
    # before ``cell_size`` existed still resolve a value.
    cell_size: float = 12.0

    def __init__(
        self,
        model: "Model" = None,
        k_neighbors: int = 50,
        correlation_length: float = 5.0,
        scale: float = 5.0,
        exclude_bonded: bool = True,
        verbose: int = 0,
        cell_size: float = 12.0,
    ):
        super().__init__(model, verbose, target_value=0.3, sigma=0.2)
        # `not x > 0` also rejects NaN.
        if not cell_size > 0:
            raise ValueError(f"cell_size must be positive, got {cell_size!r}")
        self.register_buffer(
            "_k_neighbors", torch.tensor(k_neighbors, dtype=torch.int64)
        )
        self.register_buffer("_correlation_length", torch.tensor(correlation_length))
        self.register_buffer("_scale", torch.tensor(scale))
        # NOTE(review): not read anywhere in this class.
        self.exclude_bonded = exclude_bonded
        # Plain attribute (not a buffer) so state_dict keys are unchanged.
        self.cell_size = float(cell_size)
        # Neighbor-list cache, filled by _build_neighbor_list().
        self._neighbor_indices = None  # (N, k) int64 atom indices
        self._neighbor_distances = None  # (N, k) float32 distances
        self._last_xyz_hash = None  # NOTE(review): never read in this class

    @property
    def k_neighbors(self) -> int:
        return self._k_neighbors.item()

    @k_neighbors.setter
    def k_neighbors(self, value: int):
        self._k_neighbors.fill_(value)

    @property
    def correlation_length(self) -> float:
        return self._correlation_length.item()

    @correlation_length.setter
    def correlation_length(self, value: float):
        self._correlation_length.fill_(value)

    @property
    def scale(self) -> float:
        return self._scale.item()

    @scale.setter
    def scale(self, value: float):
        self._scale.fill_(value)

    # ------------------------------------------------------------------
    # Spatial-hash k-NN
    # ------------------------------------------------------------------
    @staticmethod
    def _knn_cell_list(coords: np.ndarray, k: int, cell_size: float):
        """
        Find approximate k nearest neighbors of every row of ``coords`` via a cell list.

        Parameters
        ----------
        coords : np.ndarray
            Atom coordinates, shape ``(N, 3)``.
        k : int
            Neighbors per atom. The caller guarantees ``1 <= k <= N - 1``.
        cell_size : float
            Cell edge length, in the same units as ``coords``.

        Returns
        -------
        tuple of np.ndarray
            ``(indices, distances)``, shapes ``(N, k)``, dtypes int64 and
            float32. Rows are sorted by increasing distance. Slots without a
            real neighbor hold distance ``inf`` and a placeholder index.
        """
        n_atoms = coords.shape[0]
        nbr_idx = np.zeros((n_atoms, k), dtype=np.int64)
        nbr_dist = np.full((n_atoms, k), np.inf, dtype=np.float32)

        # Integer cell coordinates relative to the bounding-box corner.
        cell_idx = ((coords - coords.min(axis=0)) / cell_size).astype(np.int64)
        gx, gy, gz = (int(v) for v in cell_idx.max(axis=0) + 1)
        gyz = gy * gz
        flat = cell_idx[:, 0] * gyz + cell_idx[:, 1] * gz + cell_idx[:, 2]

        # Group atoms by cell: the atoms of occupied cell c are
        # order[starts[c]:starts[c + 1]].
        order = np.argsort(flat, kind="stable")
        unique_cells, counts = np.unique(flat, return_counts=True)
        starts = np.concatenate(([0], np.cumsum(counts)))
        # Sparse flat-cell -> occupied-cell map. Memory stays proportional to
        # the number of occupied cells, however large the bounding box is.
        cell_lookup = {cell: c for c, cell in enumerate(unique_cells.tolist())}

        offsets = [
            (dx, dy, dz)
            for dx in (-1, 0, 1)
            for dy in (-1, 0, 1)
            for dz in (-1, 0, 1)
        ]

        for c, cell in enumerate(unique_cells.tolist()):
            atoms_a = order[starts[c]:starts[c + 1]]
            cx, rem = divmod(cell, gyz)
            cy, cz = divmod(rem, gz)

            # Candidates come from the 3x3x3 block of cells around this one.
            # The block includes the cell itself, so `chunks` is never empty.
            chunks = []
            for dx, dy, dz in offsets:
                nx, ny, nz = cx + dx, cy + dy, cz + dz
                if not (0 <= nx < gx and 0 <= ny < gy and 0 <= nz < gz):
                    continue
                nb = cell_lookup.get(nx * gyz + ny * gz + nz)
                if nb is not None:
                    chunks.append(order[starts[nb]:starts[nb + 1]])
            cand = np.concatenate(chunks)

            # (len(atoms_a), len(cand)) distance block for this cell.
            diff = coords[atoms_a][:, None, :] - coords[cand][None, :, :]
            dist = np.sqrt((diff * diff).sum(axis=-1))
            # An atom is never its own neighbor.
            dist[atoms_a[:, None] == cand[None, :]] = np.inf

            n_cand = cand.shape[0]
            if n_cand > k:
                sel = np.argpartition(dist, k, axis=1)[:, :k]
            else:
                # Too few candidates: keep all of them (the self entry is inf).
                sel = np.broadcast_to(np.arange(n_cand), dist.shape)
            sel_d = np.take_along_axis(dist, sel, axis=1)
            # Sort each row by distance for a deterministic neighbor order.
            rank = np.argsort(sel_d, axis=1, kind="stable")
            sel = np.take_along_axis(sel, rank, axis=1)
            sel_d = np.take_along_axis(sel_d, rank, axis=1)

            m = sel.shape[1]  # == min(k, n_cand)
            nbr_idx[atoms_a, :m] = cand[sel]
            nbr_dist[atoms_a, :m] = sel_d

        return nbr_idx, nbr_dist

    def _build_neighbor_list(self) -> None:
        """
        Build the k-nearest-neighbor cache from the model's current coordinates.

        Stores ``self._neighbor_indices`` (int64) and
        ``self._neighbor_distances`` (float32) on the coordinates' device.
        Both have shape ``(N, k)`` with ``k = min(k_neighbors, N - 1)``. Empty
        or single-atom models get empty ``(N, 0)`` tables.
        """
        xyz = self.model.xyz()
        device = xyz.device
        n_atoms = int(xyz.shape[0])
        # Clamp at 0 so an empty model (n_atoms - 1 == -1) does not produce a
        # negative array dimension.
        k = max(0, min(self.k_neighbors, n_atoms - 1))

        if k > 0:
            nbr_idx, nbr_dist = self._knn_cell_list(
                xyz.detach().cpu().numpy(), k, float(self.cell_size)
            )
        else:
            nbr_idx = np.zeros((n_atoms, 0), dtype=np.int64)
            nbr_dist = np.zeros((n_atoms, 0), dtype=np.float32)

        self._neighbor_indices = torch.from_numpy(nbr_idx).to(device)
        self._neighbor_distances = torch.from_numpy(nbr_dist).to(device)

        if self.verbose > 1:
            finite = nbr_dist[np.isfinite(nbr_dist)]
            mean_dist = float(finite.mean()) if finite.size else float("nan")
            print(
                f" Built K-NN list (spatial hash): k={k}, "
                f"mean dist={mean_dist:.2f}A"
            )

    def forward(self, recompute_neighbors: bool = False) -> torch.Tensor:
        """
        Sum of inverse-distance-weighted squared log(B) differences to neighbors.

        Computes::

            loss = sum_ij w_ij * ((log(B_i) - log(B_j)) / 0.5) ** 2
            w_ij = 1 / (d_ij + 1e-6)

        The sum runs over each atom i and its cached neighbors j. B is
        clamped to >= 1e-3 before taking the log. Padded neighbor slots
        (d_ij = inf) get weight 0.

        NOTE(review): ``scale`` and ``correlation_length`` (see the class
        docstring) are not applied here, and the 0.5 divisor is hard-coded.
        Confirm the intended form before changing the numbers, since other
        loss weights may be tuned to this magnitude.

        Parameters
        ----------
        recompute_neighbors : bool, optional
            Force a rebuild of the neighbor list from the current coordinates.
            Otherwise the cache is rebuilt only when it is missing or sits on
            a different device than the model coordinates. The neighbor
            distances therefore reflect the coordinates at the last rebuild.

        Returns
        -------
        torch.Tensor
            Scalar loss; 0.0 for an empty model.
        """
        model_device = self.model.xyz().device
        cache_stale = (
            self._neighbor_indices is not None
            and self._neighbor_indices.device != model_device
        )
        if recompute_neighbors or self._neighbor_indices is None or cache_stale:
            self._build_neighbor_list()

        adp = self.model.adp()
        device = adp.device
        if len(adp) == 0 or self._neighbor_indices is None:
            return torch.tensor(0.0, device=device)

        log_adp = torch.log(adp.clamp(min=1e-3))
        diff = log_adp.unsqueeze(1) - log_adp[self._neighbor_indices]
        weights = 1 / (self._neighbor_distances + 1e-6)
        return (weights * (diff / 0.5) ** 2).sum()

    def stats(self) -> Dict[str, Any]:
        """
        Get locality restraint statistics.

        Rebuilds the neighbor list from the current coordinates, then reports
        the loss and the log(B) deviation and neighbor-distance statistics.
        Only real neighbor pairs (finite distance) are included. Padded slots
        have distance ``inf`` and a placeholder atom index; counting them
        would make the distance statistics infinite and skew the deviation
        statistics.

        Returns an empty dict when there are no neighbor pairs.
        """
        self._build_neighbor_list()
        if self._neighbor_indices is None:
            return {}

        # Statistics only: no autograd graph is needed.
        with torch.no_grad():
            adp = self.model.adp().detach()
            log_adp = torch.log(adp.clamp(min=1e-3))
            indices = self._neighbor_indices
            distances = self._neighbor_distances
            valid = torch.isfinite(distances)
            if not bool(valid.any()):
                return {}

            diff = (log_adp.unsqueeze(1) - log_adp[indices])[valid]
            dist = distances[valid]
            weights = torch.exp(-dist / self.correlation_length)
            weighted_rms = torch.sqrt(
                (weights * diff**2).sum() / weights.sum()
            ).item()
            loss = self.forward()

        return {
            "loss": stat(loss.item(), VERBOSITY_STANDARD),
            "n_atoms": stat(len(adp), VERBOSITY_DEBUG),
            "weighted_rms_log": stat(weighted_rms, VERBOSITY_DETAILED),
            "rms_deviation_log": stat(
                torch.sqrt((diff**2).mean()).item(), VERBOSITY_DETAILED
            ),
            "max_deviation_log": stat(diff.abs().max().item(), VERBOSITY_DETAILED),
            "k_neighbors": stat(self.k_neighbors, VERBOSITY_DEBUG),
            "correlation_length": stat(self.correlation_length, VERBOSITY_DEBUG),
            "scale": stat(self.scale, VERBOSITY_DEBUG),
            "avg_neighbor_dist": stat(dist.mean().item(), VERBOSITY_DEBUG),
            "max_neighbor_dist": stat(dist.max().item(), VERBOSITY_DEBUG),
            "avg_weight": stat(weights.mean().item(), VERBOSITY_DEBUG),
        }