};
void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, const double p) {
- AT_DISPATCH_FLOATING_TYPES(self.type(), "pdist", [&] {
+ AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist", [&] {
Dist<scalar_t>::apply_pdist(result, self, p);
});
}
static void pdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor& self, const double p, const Tensor& dist) {
- AT_DISPATCH_FLOATING_TYPES(self.type(), "pdist_backward", [&] {
+ AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist_backward", [&] {
Dist<scalar_t>::apply_backward_pdist(result, grad, self, p, dist);
});
}
static void cdist_kernel_impl(Tensor& result, const Tensor& x1, const Tensor& x2, const double p) {
- AT_DISPATCH_FLOATING_TYPES(result.type(), "cdist", [&] {
+ AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "cdist", [&] {
Dist<scalar_t>::apply_cdist(result, x1, x2, p);
});
}
static void cdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor& x1, const Tensor& x2, const double p, const Tensor& dist) {
- AT_DISPATCH_FLOATING_TYPES(result.type(), "cdist_backward", [&] {
+ AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "cdist_backward", [&] {
Dist<scalar_t>::apply_backward_cdist(result, grad, x1, x2, p, dist);
});
}