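Assuming the L1 (Manhattan) distance, you can compute the distance from every point in a to every target in b in one step by broadcasting: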

import torch
# data and target
a = torch.randn(100, 16, 3)
b = torch.randn(5, 3)

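# Insert singleton dimensions so the two tensors broadcast against each other
a = a.unsqueeze(1)               # (100, 16, 3) -> (100, 1, 16, 3)
b = b.unsqueeze(0).unsqueeze(2)  # (5, 3) -> (1, 5, 1, 3)

print(a.shape, b.shape)  # torch.Size([100, 1, 16, 3]) torch.Size([1, 5, 1, 3])

# a - b broadcasts to (100, 5, 16, 3); summing |a - b| over the last (coordinate) axis
# gives dist[i, j, k] = L1 distance between point k of sample i and target j
dist = (a - b).abs().sum(dim=3)
print(dist.shape)  # torch.Size([100, 5, 16])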
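As a cross-check, here is a minimal sketch (assuming PyTorch 1.9 or later, for torch.linalg.vector_norm) that computes the same tensor two other ways. pairwise_lp is a hypothetical helper name that generalizes the broadcasting trick to any Lp norm; torch.cdist is PyTorch's built-in pairwise-distance function, which takes (B, P, D) and (B, R, D) inputs and returns (B, P, R), so b is expanded to a batch and the result is permuted to match the (100, 5, 16) layout above.

```python
import torch

def pairwise_lp(a: torch.Tensor, b: torch.Tensor, p: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: Lp distance between every point in a (N, K, D)
    and every target in b (M, D). Returns a tensor of shape (N, M, K)."""
    diff = a.unsqueeze(1) - b.unsqueeze(0).unsqueeze(2)  # (N, M, K, D)
    return torch.linalg.vector_norm(diff, ord=p, dim=-1)  # (N, M, K)

torch.manual_seed(0)
a = torch.randn(100, 16, 3)
b = torch.randn(5, 3)

# Manual L1 distance, exactly as in the answer
manual = (a.unsqueeze(1) - b.unsqueeze(0).unsqueeze(2)).abs().sum(dim=3)

# Same result via the (hypothetical) helper with p=1
via_norm = pairwise_lp(a, b, p=1)

# torch.cdist: expand b from (5, 3) to (100, 5, 3) so both inputs are batched
via_cdist = torch.cdist(a, b.expand(a.shape[0], -1, -1), p=1)  # (100, 16, 5)
via_cdist = via_cdist.permute(0, 2, 1)                         # (100, 5, 16)

print(manual.shape, via_norm.shape, via_cdist.shape)
print(torch.allclose(manual, via_norm, atol=1e-5), torch.allclose(manual, via_cdist, atol=1e-5))
```

Changing p in either call (e.g. p=2 for Euclidean distance) gives other metrics without touching the reshaping logic. Note that the manual broadcasting approach materializes the full (100, 5, 16, 3) difference tensor, which can get large for bigger inputs.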