import torch
def _bin_mean(values, index, counts, num_bins):
    """Return the per-bin mean of each row of ``values``.

    Args:
        values (torch.Tensor): Values to average, shape (B, N).
        index (torch.Tensor): int64 bin index of each value, shape (B, N).
        counts (torch.Tensor): Elements per bin, clamped to >= 1 so empty bins
            do not divide by zero; shape (1, num_bins), same dtype as values.
        num_bins (int): Number of output bins.

    Returns:
        torch.Tensor: Shape (B, num_bins); bins with no elements hold 0.
    """
    sums = values.new_zeros(values.shape[0], num_bins)
    return sums.scatter_add(1, index, values) / counts


def binned_correlation(x, y, bins, eps=1e-6):
    """
    Compute the Pearson correlation between x and y separately within each bin.

    Args:
        x (torch.Tensor): First set of values, shape (N,) or (B, N).
        y (torch.Tensor): Second set of values, shape (N,) or (B, N). If one of
            x or y has a single row ((N,) or (1, N)) it is broadcast against
            the other's batch dimension.
        bins (torch.Tensor): Non-negative integer bin index for each value,
            shape (N,). The same binning is applied to every batch row. Any
            integer dtype is accepted; it is converted to int64 internally.
        eps (float): Lower bound applied to sqrt(var_x * var_y) so that bins
            with zero variance (or no elements) do not divide by zero.

    Returns:
        torch.Tensor: Correlation coefficients of shape (B, num_bins), where
        num_bins = bins.max() + 1. When both x and y are 1-D, B is 1, so the
        result has shape (1, num_bins). Bins containing no elements yield 0.
        Computation is done in the promoted floating dtype of x and y
        (integer inputs use torch's default floating dtype).

    Raises:
        ValueError: If x/y are not 1-D or 2-D, their batch sizes cannot be
            broadcast, their row lengths differ, bins does not have shape
            (N,), or bins contains negative values.
    """
    if x.dim() == 1:
        x = x.unsqueeze(0)
    if y.dim() == 1:
        y = y.unsqueeze(0)
    if x.dim() != 2 or y.dim() != 2:
        raise ValueError("x and y must be 1-D or 2-D tensors.")

    # Broadcast a single-row input against the other's batch dimension.
    if x.shape[0] != y.shape[0]:
        if x.shape[0] == 1:
            x = x.expand(y.shape[0], -1)
        elif y.shape[0] == 1:
            y = y.expand(x.shape[0], -1)
        else:
            raise ValueError(
                "x and y must have the same first dimension or one of them "
                "must be 1 or squeezed."
            )

    B, N = x.shape
    if y.shape[1] != N:
        raise ValueError(
            f"x and y must have the same number of values per row, "
            f"got {N} and {y.shape[1]}."
        )
    if bins.dim() != 1 or bins.shape[0] != N:
        raise ValueError(f"bins must have shape ({N},), got {tuple(bins.shape)}.")

    # Work in a common floating dtype: scatter_add requires the accumulator and
    # source dtypes to match, and integer inputs would break the mean/sqrt math.
    dtype = torch.promote_types(x.dtype, y.dtype)
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()
    x = x.to(dtype)
    y = y.to(dtype=dtype, device=x.device)
    # scatter_add / gather require an int64 index on the same device.
    bins = bins.to(device=x.device, dtype=torch.long)

    if N == 0:
        return x.new_zeros(B, 0)
    if bins.min().item() < 0:
        raise ValueError("bins must be non-negative.")
    num_bins = int(bins.max().item()) + 1

    index = bins.unsqueeze(0).expand(B, -1)  # (B, N), shared across batch rows
    # Counts are identical for every batch row; clamp so empty bins divide by 1.
    counts = torch.bincount(bins, minlength=num_bins).to(dtype).clamp(min=1.0)
    counts = counts.unsqueeze(0)  # (1, num_bins) for broadcasting

    # Center each value on its own bin's mean.
    x = x - _bin_mean(x, index, counts, num_bins).gather(1, index)
    y = y - _bin_mean(y, index, counts, num_bins).gather(1, index)

    cov_xy = _bin_mean(x * y, index, counts, num_bins)
    var_x = _bin_mean(x * x, index, counts, num_bins)
    var_y = _bin_mean(y * y, index, counts, num_bins)
    return cov_xy / torch.sqrt(var_x * var_y).clamp(min=eps)