const int64_t count = dist.numel();
Tensor buffer = at::empty({r2, r1, m}, result.options());
- AT_DISPATCH_FLOATING_TYPES(result.type(), "cdist_cuda_backward", [&] {
+ AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "cdist_cuda_backward", [&] {
if (p == 1.0) {
cdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::one><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(1), p, r1, r2, m, count);
} else if (p < 2.0) {