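Suppose we are using the L1 distance. Insert singleton dimensions so that broadcasting pairs every point in a with every target in b, then sum the absolute differences over the coordinate dimension: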
import torch
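# data: 100 x 16 points in 3-D; target: 5 points in 3-D
a = torch.randn(100, 16, 3)
b = torch.randn(5, 3)
# Insert singleton dims so that a - b broadcasts
a = a.unsqueeze(1)                # (100, 1, 16, 3)
b = b.unsqueeze(0).unsqueeze(2)   # (1, 5, 1, 3)
print(a.shape, b.shape)
# a - b broadcasts to (100, 5, 16, 3); summing |a - b| over the
# last (coordinate) dim gives the L1 distance for every pair
dist = (a-b).abs().sum(3)
print(dist.shape)                 # torch.Size([100, 5, 16])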
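As a cross-check, PyTorch also ships torch.cdist, which computes pairwise p-norm distances between row vectors. The sketch below assumes the same shapes as above and flattens the 100 x 16 points into a single batch of 1600 rows, then reshapes and permutes the result so it lines up with the broadcasting version's (100, 5, 16) layout. The variable names (dist_bcast, dist_cdist, flat) are illustrative only.

```python
import torch

# Same shapes as above: 100 x 16 points and 5 targets, all in 3-D
a = torch.randn(100, 16, 3)
b = torch.randn(5, 3)

# Broadcasting version from the answer: shape (100, 5, 16)
dist_bcast = (a.unsqueeze(1) - b.unsqueeze(0).unsqueeze(2)).abs().sum(3)

# torch.cdist with p=1 gives pairwise L1 distances between rows.
# Flatten the points into one batch: (1, 1600, 3) vs (1, 5, 3) -> (1, 1600, 5)
flat = torch.cdist(a.reshape(1, -1, 3), b.unsqueeze(0), p=1)

# Row i*16 + j of flat corresponds to a[i, j]; restore (100, 16, 5),
# then swap the last two dims to get (100, 5, 16)
dist_cdist = flat.reshape(100, 16, 5).permute(0, 2, 1)

print(dist_cdist.shape)  # torch.Size([100, 5, 16])
print(torch.allclose(dist_bcast, dist_cdist, atol=1e-5))
```

The broadcasting version materializes the full (100, 5, 16, 3) difference tensor, which is fine at this size; torch.cdist avoids that intermediate. For the Euclidean (L2) distance, pass p=2 to torch.cdist, or replace .abs().sum(3) with .pow(2).sum(3).sqrt() in the broadcasting version.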
Credit To: stackoverflow.com