torchref.base.alignment.scoring module

torchref.base.alignment.scoring.binned_correlation(x, y, bins)

Compute the correlation between x and y separately within each bin, where bins assigns a bin index to every element.

Args:

x (torch.Tensor): First set of values, shape (N,) or (B, N). If one of x or y has shape (N,) or (1, N), it is broadcast against the other.
y (torch.Tensor): Second set of values, shape (N,) or (B, N).
bins (torch.Tensor): Bin index for each of the N values, shape (N,).

Returns:

torch.Tensor: One correlation coefficient per bin, shape (num_bins,) for unbatched inputs or (B, num_bins) for batched inputs.
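The docstring does not say which correlation coefficient is computed, how num_bins is derived, or what happens to empty bins. The following is a minimal sketch of one way such a function could work, assuming a per-bin Pearson correlation, floating-point inputs, integer bin indices in [0, num_bins) with num_bins = bins.max() + 1, and a result of 0 for empty or constant bins. The name `binned_correlation_sketch` and its `eps` parameter are hypothetical and not part of torchref; the actual implementation may differ.

```python
import torch


def binned_correlation_sketch(x, y, bins, eps=1e-12):
    """Hypothetical sketch: Pearson correlation of x and y within each bin.

    Assumptions (not taken from torchref): float inputs, integer bin
    indices in [0, num_bins), num_bins = bins.max() + 1, and empty or
    constant bins return 0.
    """
    both_1d = x.dim() == 1 and y.dim() == 1
    dtype = torch.promote_types(x.dtype, y.dtype)
    # Promote (N,) to (1, N), then broadcast (1, N) against (B, N).
    x, y = torch.broadcast_tensors(torch.atleast_2d(x).to(dtype),
                                   torch.atleast_2d(y).to(dtype))
    bins = bins.long()
    num_bins = int(bins.max()) + 1
    batch = x.shape[0]

    def bin_sum(v):
        # out[b, k] = sum of v[b, j] over all j with bins[j] == k
        out = torch.zeros(batch, num_bins, dtype=dtype, device=v.device)
        return out.index_add_(1, bins, v)

    # Elements per bin; clamping to 1 avoids 0/0 for empty bins.
    count = torch.bincount(bins, minlength=num_bins).to(dtype).clamp_min(1)
    mean_x = bin_sum(x) / count
    mean_y = bin_sum(y) / count
    # Center each element by the mean of its own bin.
    dx = x - mean_x[:, bins]
    dy = y - mean_y[:, bins]
    cov = bin_sum(dx * dy)
    var_x = bin_sum(dx * dx)
    var_y = bin_sum(dy * dy)
    # Pearson r = cov / sqrt(var_x * var_y); eps keeps empty bins at 0.
    r = cov / torch.sqrt(var_x * var_y).clamp_min(eps)
    return r.squeeze(0) if both_1d else r
```

Using `index_add_` along the value dimension accumulates every per-bin sum in a single vectorized pass, so no Python loop over bins is needed and the batch dimension is handled for free.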