Refactor dispatcher (#17753)
author    Iurii Zdebskyi <iuriiz@fb.com>
Thu, 7 Mar 2019 21:38:59 +0000 (13:38 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 7 Mar 2019 21:41:54 +0000 (13:41 -0800)
Summary:
This is a side PR for the bool tensor feature. The idea for this change came from feedback received in this [PR](https://github.com/pytorch/pytorch/pull/17376). It replaces the fixed-suffix dispatch macros (e.g. `AT_DISPATCH_ALL_TYPES_AND_HALF`) with parameterized variants that take the extra scalar type as an argument; a usage sketch follows the commit metadata below.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17753

Differential Revision: D14367989

Pulled By: izdeby

fbshipit-source-id: 4fa380e56e20f18e480be68920170dbc3a4eb91c
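
For readers skimming the diff: the refactor folds the fixed-suffix macros (`AT_DISPATCH_ALL_TYPES_AND_HALF`, `AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX`) into parameterized ones (`AT_DISPATCH_ALL_TYPES_AND`, `AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND`) that take the extra `ScalarType` as their first argument. A minimal before/after sketch of a call site; the `fill_with_one` helper is hypothetical and assumes a contiguous CPU tensor:

```cpp
#include <ATen/ATen.h>

// Hypothetical helper, shown only to illustrate the macro migration.
void fill_with_one(at::Tensor& t) {
  // Before: the extra type was baked into the macro name.
  //   AT_DISPATCH_ALL_TYPES_AND_HALF(t.type(), "fill_with_one", [&] { ... });
  //
  // After: the extra type is the first macro argument, so supporting another
  // scalar type only needs a MyTemplate specialization, not a new macro.
  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, t.type(), "fill_with_one", [&] {
    scalar_t* data = t.data<scalar_t>();  // scalar_t is bound per dispatched case
    for (int64_t i = 0; i < t.numel(); ++i) {
      data[i] = static_cast<scalar_t>(1);
    }
  });
}
```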

19 files changed:
aten/src/ATen/Dispatch.h
aten/src/ATen/native/Copy.cpp
aten/src/ATen/native/Scalar.cpp
aten/src/ATen/native/cpu/CopyKernel.cpp
aten/src/ATen/native/cpu/IndexKernel.cpp
aten/src/ATen/native/cuda/Activation.cu
aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
aten/src/ATen/native/cuda/BinaryOpsKernel.cu
aten/src/ATen/native/cuda/CUDAScalar.cu
aten/src/ATen/native/cuda/Copy.cu
aten/src/ATen/native/cuda/Distributions.cu
aten/src/ATen/native/cuda/IndexKernel.cu
aten/src/ATen/native/cuda/RangeFactories.cu
aten/src/ATen/native/cuda/SortingKthValue.cu
aten/src/ATen/native/cuda/TensorCompare.cu
aten/src/ATen/native/cuda/TensorFactories.cu
aten/src/ATen/native/cuda/TensorTransformations.cu
aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu
aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu

aten/src/ATen/Dispatch.h
index 7c8bff3..20d7f2d 100644 (file)
     }                                                                        \
   }()
 
-#define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...)                      \
-  [&] {                                                                      \
-    const at::Type& the_type = TYPE;                                         \
-    switch (the_type.scalarType()) {                                         \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__)       \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__)        \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__)      \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__)        \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__)        \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__)       \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__)      \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__)      \
-      default:                                                               \
-        AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \
-    }                                                                        \
-  }()
+template <at::ScalarType N>
+struct MyTemplate;
 
-#define AT_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...)                           \
-  [&] {                                                                      \
-    const at::Type& the_type = TYPE;                                         \
-    switch (the_type.scalarType()) {                                         \
-      AT_PRIVATE_CASE_TYPE(                                                  \
-          at::ScalarType::ComplexFloat, std::complex<float>, __VA_ARGS__)    \
-      AT_PRIVATE_CASE_TYPE(                                                  \
-          at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__)  \
-      default:                                                               \
-        AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \
-    }                                                                        \
-  }()
+template<>
+struct MyTemplate<at::ScalarType::Half> {
+  using type = at::Half;
+};
 
-#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX(TYPE, NAME, ...)                   \
-  [&] {                                                                      \
-    const at::Type& the_type = TYPE;                                         \
-    switch (the_type.scalarType()) {                                         \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__)       \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__)        \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__)      \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__)        \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__)        \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__)       \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__)      \
-      AT_PRIVATE_CASE_TYPE(                                                  \
-          at::ScalarType::ComplexFloat, std::complex<float>, __VA_ARGS__)    \
-      AT_PRIVATE_CASE_TYPE(                                                  \
-          at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__)  \
-      default:                                                               \
-        AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \
-    }                                                                        \
+#define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...)                    \
+  [&] {                                                                           \
+    const at::Type& the_type = TYPE;                                              \
+    switch (the_type.scalarType()) {                                              \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__)            \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__)             \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__)           \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__)             \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__)             \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__)            \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__)           \
+      AT_PRIVATE_CASE_TYPE(SCALARTYPE, MyTemplate<SCALARTYPE>::type, __VA_ARGS__) \
+      default:                                                                    \
+        AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'");      \
+    }                                                                             \
   }()
 
-#define AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX(TYPE, NAME, ...)          \
-  [&] {                                                                      \
-    const at::Type& the_type = TYPE;                                         \
-    switch (the_type.scalarType()) {                                         \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__)       \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__)        \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__)      \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__)        \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__)        \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__)       \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__)      \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__)      \
-      AT_PRIVATE_CASE_TYPE(                                                  \
-          at::ScalarType::ComplexFloat, std::complex<float>, __VA_ARGS__)    \
-      AT_PRIVATE_CASE_TYPE(                                                  \
-          at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__)  \
-      default:                                                               \
-        AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \
-    }                                                                        \
+#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, TYPE, NAME, ...)        \
+  [&] {                                                                           \
+    const at::Type& the_type = TYPE;                                              \
+    switch (the_type.scalarType()) {                                              \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__)            \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__)             \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__)           \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__)             \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__)             \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__)            \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__)           \
+      AT_PRIVATE_CASE_TYPE(SCALARTYPE, MyTemplate<SCALARTYPE>::type, __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE(                                                       \
+          at::ScalarType::ComplexFloat, std::complex<float>, __VA_ARGS__)         \
+      AT_PRIVATE_CASE_TYPE(                                                       \
+          at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__)       \
+      default:                                                                    \
+        AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'");      \
+    }                                                                             \
   }()
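
The load-bearing piece above is the `MyTemplate` trait: because the `SCALARTYPE` macro argument is a compile-time `at::ScalarType` constant, it can be used as a non-type template argument and mapped to its C++ representation, letting the extra type be passed by value. Below is a self-contained sketch of the same enum-to-type trick outside ATen; every name in it (`ScalarKind`, `HalfLike`, `KindToType`, `DISPATCH_FLOAT_AND`) is illustrative and not part of the ATen API:

```cpp
#include <cstdint>
#include <iostream>
#include <stdexcept>

enum class ScalarKind { Float, Int, Half };

// Stand-in for at::Half in this sketch.
struct HalfLike { float value; };

// Enum-value -> C++-type trait, mirroring MyTemplate in the diff above.
template <ScalarKind K>
struct KindToType;

template <> struct KindToType<ScalarKind::Float> { using type = float; };
template <> struct KindToType<ScalarKind::Int>   { using type = int32_t; };
template <> struct KindToType<ScalarKind::Half>  { using type = HalfLike; };

// One case of the switch: binds scalar_t, then runs the caller's lambda.
#define DISPATCH_CASE(KIND, ...)               \
  case KIND: {                                 \
    using scalar_t = KindToType<KIND>::type;   \
    return __VA_ARGS__();                      \
  }

// Parameterized dispatcher in the style of AT_DISPATCH_ALL_TYPES_AND:
// the extra kind is an argument instead of being baked into the macro name.
#define DISPATCH_FLOAT_AND(EXTRA, DTYPE, ...)           \
  [&] {                                                 \
    switch (DTYPE) {                                    \
      DISPATCH_CASE(ScalarKind::Float, __VA_ARGS__)     \
      DISPATCH_CASE(EXTRA, __VA_ARGS__)                 \
      default:                                          \
        throw std::runtime_error("not implemented");    \
    }                                                   \
  }()

int main() {
  ScalarKind runtime_kind = ScalarKind::Half;  // pretend this came from a tensor
  DISPATCH_FLOAT_AND(ScalarKind::Half, runtime_kind, [&] {
    // scalar_t is HalfLike in this branch and float in the Float branch.
    std::cout << "dispatched, sizeof(scalar_t) = " << sizeof(scalar_t) << "\n";
  });
  return 0;
}
```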
aten/src/ATen/native/Copy.cpp
index 871c274..a8c7bab 100644 (file)
@@ -20,7 +20,7 @@ void _copy__cpu(at::Tensor& self, const at::Tensor& src) {
 template <typename self_T>
 void _copy__cpu(at::Tensor& self, const at::Tensor& src) {
   AT_CHECK(self.numel() == src.numel(), "sizes do not match");
-  AT_DISPATCH_ALL_TYPES_AND_HALF(src.type(), "_copy__cpu", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.type(), "_copy__cpu", [&]() {
     _copy__cpu<self_T, scalar_t>(self, src);
   });
 }
@@ -42,8 +42,8 @@ Tensor& _s_copy__cpu(Tensor& self, const Tensor& src, bool non_blocking) {
     _s_copy_from(src, self, non_blocking);
     return self;
   }
-  AT_DISPATCH_ALL_TYPES_AND_HALF(
-      self.type(), "_copy__cpu", [&]() { ::_copy__cpu<scalar_t>(self, src); });
+  AT_DISPATCH_ALL_TYPES_AND(
+    at::ScalarType::Half, self.type(), "_copy__cpu", [&]() { ::_copy__cpu<scalar_t>(self, src); });
   return self;
 }
 
@@ -58,8 +58,8 @@ void _copy_same_type_transpose_(Tensor& self, const Tensor& src) {
   }
   Tensor buf = empty({BLOCK_SZ, BLOCK_SZ}, self.options());
 
-  AT_DISPATCH_ALL_TYPES_AND_HALF(
-      self.type(), "_copy_same_type_transpose_", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(
+    at::ScalarType::Half, self.type(), "_copy_same_type_transpose_", [&]() {
         scalar_t* sp = src.data<scalar_t>();
         scalar_t* rp = self.data<scalar_t>();
         scalar_t* bp = buf.data<scalar_t>();
@@ -114,12 +114,13 @@ void _copy_same_type__cpu(Tensor& self, const Tensor& src) {
     } else {
 #ifdef _OPENMP
       if (!in_parallel_region()) {
-        AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "_copy_same_type_", [&]() {
-          at::CPU_tensor_parallel_apply2<scalar_t, scalar_t>(
-              self, src, [](scalar_t& self_val, const scalar_t& src_val) {
-                self_val = src_val;
-              });
-        });
+        AT_DISPATCH_ALL_TYPES_AND(
+          at::ScalarType::Half, self.type(), "_copy_same_type_", [&]() {
+            at::CPU_tensor_parallel_apply2<scalar_t, scalar_t>(
+                self, src, [](scalar_t& self_val, const scalar_t& src_val) {
+                  self_val = src_val;
+                });
+          });
       } else {
         serial_path = true;
       }
@@ -132,12 +133,13 @@ void _copy_same_type__cpu(Tensor& self, const Tensor& src) {
   }
 
   if (serial_path) {
-    AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "_copy_same_type_", [&]() {
-      at::CPU_tensor_apply2<scalar_t, scalar_t>(
-          self, src, [](scalar_t& self_val, const scalar_t& src_val) {
-            self_val = src_val;
-          });
-    });
+    AT_DISPATCH_ALL_TYPES_AND(
+      at::ScalarType::Half, self.type(), "_copy_same_type_", [&]() {
+        at::CPU_tensor_apply2<scalar_t, scalar_t>(
+            self, src, [](scalar_t& self_val, const scalar_t& src_val) {
+              self_val = src_val;
+            });
+        });
   }
 }
 
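The copy path above keeps its two-level dispatch: the outer macro fixes the destination dtype and a templated helper dispatches again on the source dtype, so the inner loop is instantiated for every (dst, src) pair. A minimal sketch of that shape with the new macro; `cast_copy` and `cast_copy_impl` are hypothetical names, and the loop assumes contiguous CPU tensors of equal numel:

```cpp
#include <ATen/ATen.h>

// Illustrative only: element-wise cast-and-copy between two dtypes.
template <typename dst_t, typename src_t>
void cast_copy_impl(at::Tensor& dst, const at::Tensor& src) {
  dst_t* dst_ptr = dst.data<dst_t>();
  src_t* src_ptr = src.data<src_t>();
  for (int64_t i = 0; i < src.numel(); ++i) {
    dst_ptr[i] = static_cast<dst_t>(src_ptr[i]);
  }
}

// Inner dispatch: source dtype -> src_t (dst_t already fixed by the caller).
template <typename dst_t>
void cast_copy(at::Tensor& dst, const at::Tensor& src) {
  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.type(), "cast_copy", [&] {
    cast_copy_impl<dst_t, scalar_t>(dst, src);
  });
}

// Outer dispatch: destination dtype -> dst_t.
void cast_copy(at::Tensor& dst, const at::Tensor& src) {
  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, dst.type(), "cast_copy", [&] {
    cast_copy<scalar_t>(dst, src);
  });
}
```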
aten/src/ATen/native/Scalar.cpp
index 4de16d7..c635c9d 100644 (file)
@@ -18,8 +18,8 @@ Scalar item(const Tensor& self) {
 
 Scalar _local_scalar_dense_cpu(const Tensor& self) {
   Scalar r;
-  AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX(
-      self.type(), "_local_scalar_dense_cpu", [&] {
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(
+    at::ScalarType::Half, self.type(), "_local_scalar_dense_cpu", [&] {
         scalar_t value = *self.data<scalar_t>();
         r = Scalar(value);
       });
aten/src/ATen/native/cpu/CopyKernel.cpp
index caf0364..03ef8df 100644 (file)
@@ -14,15 +14,16 @@ namespace {
 constexpr int64_t COPY_GRAIN_SIZE = 20000;
 
 static void copy_kernel_impl(Tensor& dst, const Tensor& src) {
-  AT_DISPATCH_ALL_TYPES_AND_HALF(dst.type(), "copy_kernel_impl", [&]() {
-    scalar_t* self_ptr = dst.data<scalar_t>();
-    scalar_t* src_ptr = src.data<scalar_t>();
-
-    auto sample = [&](int64_t begin, int64_t end) {
-      int64_t len = end - begin;
-      scalar_t* self_seg = self_ptr + begin;
-      scalar_t* src_seg = src_ptr + begin;
-      at::vec256::convert<scalar_t, scalar_t>(src_seg, self_seg, len);
+  AT_DISPATCH_ALL_TYPES_AND(
+    at::ScalarType::Half, dst.type(), "copy_kernel_impl", [&]() {
+      scalar_t* self_ptr = dst.data<scalar_t>();
+      scalar_t* src_ptr = src.data<scalar_t>();
+
+      auto sample = [&](int64_t begin, int64_t end) {
+        int64_t len = end - begin;
+        scalar_t* self_seg = self_ptr + begin;
+        scalar_t* src_seg = src_ptr + begin;
+        at::vec256::convert<scalar_t, scalar_t>(src_seg, self_seg, len);
     };
 
     parallel_for(0, dst.numel(), COPY_GRAIN_SIZE, sample);
aten/src/ATen/native/cpu/IndexKernel.cpp
index c89621d..0d93112 100644 (file)
@@ -92,7 +92,7 @@ void cpu_index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef
 }
 
 void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {
-  AT_DISPATCH_ALL_TYPES_AND_HALF(iter.type(0), "index", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(0), "index", [&] {
     cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
       *(scalar_t*)dst = *(scalar_t*)(src + offset);
     });
@@ -101,7 +101,7 @@ void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef inde
 
 void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate) {
   // NOTE: duplicate indices are only supported if accumulate is true.
-  AT_DISPATCH_ALL_TYPES_AND_HALF(iter.type(0), "index_put", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(0), "index_put", [&] {
     if (accumulate) {
       // TODO: investigate parallelization of the accumulate kernel. Unlike the non-accumulate case,
       // this needs to be thread-safe.
aten/src/ATen/native/cuda/Activation.cu
index b4b7894..83e879a 100644 (file)
@@ -286,7 +286,7 @@ void threshold_kernel_impl(TensorIterator& iter, scalar_t threshold, scalar_t va
 }
 
 static void threshold_kernel(TensorIterator& iter, Scalar threshold, Scalar value) {
-  AT_DISPATCH_ALL_TYPES_AND_HALF(iter.type(), "threshold", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(), "threshold", [&] {
     threshold_kernel_impl<scalar_t>(iter, threshold.to<scalar_t>(), value.to<scalar_t>());
   });
 }
aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
index 2ac7f5f..1f0184b 100644 (file)
@@ -493,7 +493,7 @@ Tensor& triu_tril_cuda_template(Tensor& result, const Tensor& self, int64_t k, c
           self_row_stride = self.stride(-2), self_col_stride = self.stride(-1);
   dim3 dim_block = cuda::getApplyBlock();
   dim3 dim_grid((mat_size + dim_block.x - 1) / dim_block.x, n_batches);
-  AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), name, [&]{
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.type(), name, [&]{
     triu_tril_kernel<scalar_t, upper>
       <<<dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
         result.data<scalar_t>(), self.data<scalar_t>(), k, mat_size,
aten/src/ATen/native/cuda/BinaryOpsKernel.cu
index fa31a2d..5cb1212 100644 (file)
@@ -22,7 +22,7 @@ void add_kernel_impl(TensorIterator& iter, Scalar alpha_scalar) {
 }
 
 static void add_kernel_cuda(TensorIterator& iter, Scalar alpha_scalar) {
-  AT_DISPATCH_ALL_TYPES_AND_HALF(iter.type(), "add", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(), "add", [&]() {
     add_kernel_impl<scalar_t>(iter, alpha_scalar);
   });
 }
@@ -74,7 +74,7 @@ void mul_kernel_impl(TensorIterator& iter) {
 }
 
 static void mul_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_ALL_TYPES_AND_HALF(iter.type(), "mul", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(), "mul", [&]() {
     mul_kernel_impl<scalar_t>(iter);
   });
 }
aten/src/ATen/native/cuda/CUDAScalar.cu
index d4dc8d4..70b9236 100644 (file)
@@ -9,8 +9,8 @@ namespace native {
 
 Scalar _local_scalar_dense_cuda(const Tensor& self) {
   Scalar r;
-  AT_DISPATCH_ALL_TYPES_AND_HALF(
-      self.type(), "_local_scalar_dense_cuda", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(
+    at::ScalarType::Half, self.type(), "_local_scalar_dense_cuda", [&] {
         scalar_t value;
         cudaStream_t stream = at::cuda::getCurrentCUDAStream();
         AT_CUDA_CHECK(cudaMemcpyAsync(&value, self.data<scalar_t>(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream));
aten/src/ATen/native/cuda/Copy.cu
index cb96963..ab85b48 100644 (file)
@@ -169,7 +169,7 @@ void copy_from_cpu(Tensor& dst, const Tensor& src) {
       cudaMemcpyHostToDevice,
       stream));
   AT_CUDA_CHECK(cudaStreamSynchronize(stream));
-  AT_DISPATCH_ALL_TYPES_AND_HALF(src.type(), "copy_from_cpu", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.type(), "copy_from_cpu", [&]() {
     copy_device_to_device<scalar_t, scalar_t>(dst, dst_contig);
   });
 }
@@ -202,7 +202,7 @@ void copy_from_cpu_async_(Tensor& dst, const Tensor& src) {
   CUDAGuard device_guard(dst.device());
   CUDAStream stream = getCurrentCUDAStream();
 
-  AT_DISPATCH_ALL_TYPES_AND_HALF(src.type(), "copy_from_cpu_async", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.type(), "copy_from_cpu_async", [&]() {
     AT_CUDA_CHECK(cudaMemcpyAsync(
         dst.data<scalar_t>(),
         src.data<scalar_t>(),
@@ -225,7 +225,7 @@ void copy_to_cpu_async_(Tensor& dst, const Tensor& src) {
   CUDAGuard device_guard(src.device());
   CUDAStream stream = getCurrentCUDAStream();
 
-  AT_DISPATCH_ALL_TYPES_AND_HALF(src.type(), "copy_to_cpu_async", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.type(), "copy_to_cpu_async", [&]() {
     AT_CUDA_CHECK(cudaMemcpyAsync(
         dst.data<scalar_t>(),
         src.data<scalar_t>(),
@@ -240,7 +240,7 @@ void copy_to_cpu_async_(Tensor& dst, const Tensor& src) {
 template <typename dst_T>
 void _copy__cuda(Tensor& dst, const Tensor& src, bool non_blocking) {
   AT_CHECK(dst.numel() == src.numel(), "sizes do not match");
-  AT_DISPATCH_ALL_TYPES_AND_HALF(src.type(), "_copy__cuda", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.type(), "_copy__cuda", [&]() {
     if (dst.is_cuda() && src.is_cuda()) {
       copy_device_to_device<dst_T, scalar_t>(dst, src);
     } else if (dst.is_cuda()) {
@@ -279,7 +279,7 @@ namespace at {
 namespace native {
 
 Tensor& _s_copy__cuda(Tensor& self, const Tensor& src, bool non_blocking) {
-  AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "_copy__cuda", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.type(), "_copy__cuda", [&]() {
     ::_copy__cuda<scalar_t>(self, src, non_blocking);
   });
   return self;
aten/src/ATen/native/cuda/Distributions.cu
index 00a1b34..c262384 100644 (file)
@@ -210,21 +210,22 @@ Tensor _standard_gamma_grad_cuda(const Tensor& self, const Tensor& output) {
 
 Tensor& bernoulli_tensor_cuda_(Tensor &self, const Tensor& p_, Generator* gen) {
   auto p = std::get<0>(expand_inplace(self, p_.to(kCUDA)));
-  AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "bernoulli_tensor_cuda_self_", [&] {
-    const at::Type& p_type = p.type();
-    using self_t = scalar_t;
-    auto seeds = next_philox_seed(gen, 10);
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(p.type(), "bernoulli_tensor_cuda_p_", [&] {
-      using p_t = scalar_t;
-      return bernoulli_tensor_cuda_kernel<self_t, p_t>(self, p, seeds);
-    });
+  AT_DISPATCH_ALL_TYPES_AND(
+    at::ScalarType::Half, self.type(), "bernoulli_tensor_cuda_self_", [&] {
+      const at::Type& p_type = p.type();
+      using self_t = scalar_t;
+      auto seeds = next_philox_seed(gen, 10);
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(p.type(), "bernoulli_tensor_cuda_p_", [&] {
+        using p_t = scalar_t;
+        return bernoulli_tensor_cuda_kernel<self_t, p_t>(self, p, seeds);
+      });
    });
   return self;
 }
 
 Tensor& bernoulli_scalar_cuda_(Tensor &self, double p, Generator* gen) {
   AT_CHECK(0 <= p && p <= 1, "bernoulli_ expects p to be in [0, 1], but got p=", p);
-  AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "bernoulli_scalar_cuda_", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.type(), "bernoulli_scalar_cuda_", [&] {
     auto seeds = next_philox_seed(gen, 10);
     bernoulli_scalar_cuda_kernel<scalar_t>(self, p, seeds);
    });
aten/src/ATen/native/cuda/IndexKernel.cu
index 6abd9f0..be69cde 100644 (file)
@@ -81,7 +81,7 @@ void index_put_kernel_impl(TensorIterator& iter, IntArrayRef index_size, IntArra
 }
 
 static void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {
-  AT_DISPATCH_ALL_TYPES_AND_HALF(iter.type(), "index", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(), "index", [&] {
     using dtype = OpaqueType<sizeof(scalar_t)>;
     index_kernel_impl<dtype>(iter, index_size, index_stride);
   });
@@ -90,7 +90,7 @@ static void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayR
 
 static void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate) {
   AT_ASSERTM(!accumulate, "index_put does not support accumulate=true");
-  AT_DISPATCH_ALL_TYPES_AND_HALF(iter.type(), "index_put", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(), "index_put", [&] {
     using dtype = OpaqueType<sizeof(scalar_t)>;
     index_put_kernel_impl<dtype>(iter, index_size, index_stride);
   });
aten/src/ATen/native/cuda/RangeFactories.cu
index 6210b47..227fd81 100644 (file)
@@ -99,7 +99,7 @@ Tensor& logspace_cuda_out(Tensor& result, Scalar start, Scalar end, int64_t step
 }
 
 Tensor& range_cuda_out(Tensor& result, Scalar start, Scalar end, Scalar step) {
-  AT_DISPATCH_ALL_TYPES_AND_HALF(result.type(), "range", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, result.type(), "range", [&]() {
     using accscalar_t = at::acc_type<scalar_t, true>;
     auto xstart = start.to<accscalar_t>();
     auto xend = end.to<accscalar_t>();
@@ -130,7 +130,7 @@ Tensor& range_cuda_out(Tensor& result, Scalar start, Scalar end, Scalar step) {
 }
 
 Tensor& arange_cuda_out(Tensor& result, Scalar start, Scalar end, Scalar step) {
-  AT_DISPATCH_ALL_TYPES_AND_HALF(result.type(), "arange", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, result.type(), "arange", [&]() {
     using accscalar_t = at::acc_type<scalar_t, true>;
     auto xstart = start.to<accscalar_t>();
     auto xend = end.to<accscalar_t>();
aten/src/ATen/native/cuda/SortingKthValue.cu
index d3c488e..4018d1d 100644 (file)
@@ -233,14 +233,14 @@ std::tuple<Tensor&, Tensor&> kthvalue_out_cuda(
     int64_t k,
     int64_t dim,
     bool keepdim) {
-  AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "kthvalue", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.type(), "kthvalue", [&] {
     kthvalue_cuda_template<scalar_t>(values, indices, self, k, dim, keepdim);
   });
   return std::forward_as_tuple(values, indices);
 }
 
 Tensor median_cuda(const Tensor& self) {
-  return AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "median", [&] {
+  return AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.type(), "median", [&] {
     return median_cuda_template<scalar_t>(self);
   });
 }
aten/src/ATen/native/cuda/TensorCompare.cu
index f7612d5..f7d9b68 100644 (file)
@@ -33,7 +33,7 @@ Tensor _s_where_cuda(
     const Tensor& self,
     const Tensor& other) {
   Tensor ret = at::empty(self.sizes(), self.options());
-  AT_DISPATCH_ALL_TYPES_AND_HALF(ret.type(), "where", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ret.type(), "where", [&] {
     where_cuda<scalar_t>(ret, condition, self, other);
   });
   return ret;
aten/src/ATen/native/cuda/TensorFactories.cu
index 58b90bc..4a0a7f1 100644 (file)
@@ -322,7 +322,7 @@ Tensor tril_indices_cuda(
       cuda::getApplyGrid(tril_size, dim_grid, tensor.get_device()),
       "unable to get dim grid");
 
-    AT_DISPATCH_ALL_TYPES_AND_HALF(tensor.type(), "tril_indices_cuda", [&] {
+    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, tensor.type(), "tril_indices_cuda", [&] {
       tril_indices_kernel<<<
           dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
         tensor.data<scalar_t>(),
@@ -398,7 +398,7 @@ Tensor triu_indices_cuda(
       cuda::getApplyGrid(triu_size, dim_grid, tensor.get_device()),
       "unable to get dim grid");
 
-    AT_DISPATCH_ALL_TYPES_AND_HALF(tensor.type(), "triu_indices_cuda", [&] {
+    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, tensor.type(), "triu_indices_cuda", [&] {
       triu_indices_kernel<<<
           dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
         tensor.data<scalar_t>(),
aten/src/ATen/native/cuda/TensorTransformations.cu
index 4eb25de..5fccc10 100644 (file)
@@ -87,7 +87,7 @@ Tensor flip_cuda(const Tensor& self, IntArrayRef dims) {
 
   // use kernel_pointwise_flip_apply2 only when to-flip dim is the 1st or last dim, where collapseDims can reduce the amount of work
   if (flip_dims_size == 1 && in_tensor.is_contiguous() && (flip_dims[0] == 0 || flip_dims[0] == total_dims - 1)) {
-    AT_DISPATCH_ALL_TYPES_AND_HALF(in_tensor.type(), "flip_cuda", [&] {
+    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, in_tensor.type(), "flip_cuda", [&] {
       auto in_tensor_info = cuda::detail::getTensorInfo<scalar_t, int64_t>(in_tensor);
       auto out_tensor_info = cuda::detail::getTensorInfo<scalar_t, int64_t>(out_tensor);
       int flip_dim = in_tensor_info.collapseDims(flip_dims[0]);
@@ -119,7 +119,7 @@ Tensor flip_cuda(const Tensor& self, IntArrayRef dims) {
     }
   }
 
-  AT_DISPATCH_ALL_TYPES_AND_HALF(in_tensor.type(), "flip_cuda", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, in_tensor.type(), "flip_cuda", [&] {
     flip_cuda_kernel<<<dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
       in_tensor.data<scalar_t>(), out_tensor.data<scalar_t>(), N, flip_dims_t.toType(CUDA(kLong)).data<int64_t>(), flip_dims_size,
       strides_t.toType(CUDA(kLong)).data<int64_t>(), stride_contiguous.toType(CUDA(kLong)).data<int64_t>(), shape_t.toType(CUDA(kLong)).data<int64_t>(), total_dims);
@@ -177,7 +177,7 @@ Tensor roll_cuda(const Tensor& self, IntArrayRef shifts, IntArrayRef dims) {
 
   auto total_dims = in_tensor.dim();
 
-  AT_DISPATCH_ALL_TYPES_AND_HALF(in_tensor.type(), "roll_cuda", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, in_tensor.type(), "roll_cuda", [&] {
     roll_cuda_kernel<<<dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
       in_tensor.data<scalar_t>(), out_tensor.data<scalar_t>(), N,
       dim, start,
aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu
index 23e4cfb..fc144d6 100644 (file)
@@ -94,8 +94,8 @@ SparseTensor coalesce_sparse_cuda(const SparseTensor& self) {
     int64_t stride = at::prod_intlist(values.sizes().slice(1));
     dim3 grid(THCCeilDiv(newNnz, (int64_t) 4), THCCeilDiv(stride, (int64_t) 128));
     dim3 block(32, 4);
-    AT_DISPATCH_ALL_TYPES_AND_HALF(
-        values.type(), "coalesce_sparse_cuda", [&] {
+    AT_DISPATCH_ALL_TYPES_AND(
+      at::ScalarType::Half, values.type(), "coalesce_sparse_cuda", [&] {
           using cuda_accscalar_t = acc_type<scalar_t, /* is_cuda */ true>;
           apply::coalesceValuesKernel<scalar_t, cuda_accscalar_t><<<grid, block, 0, stream>>>(
             uniqueOffsets.data<int64_t>(),
aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu
index ba0a8de..cac6835 100644 (file)
@@ -295,8 +295,8 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, SparseTensorR
     if (sparse.dense_dim() == 0) {
       AT_CHECK(cuda::getApplyGrid(nnz, grid, curDevice), "add: Argument #0: tensor too large or too many dimensions");
 
-      AT_DISPATCH_ALL_TYPES_AND_HALF(
-          values.type(), "add_out_dense_sparse_cuda", [&] {
+      AT_DISPATCH_ALL_TYPES_AND(
+        at::ScalarType::Half, values.type(), "add_out_dense_sparse_cuda", [&] {
             apply::sparseElementwiseKernelScalar<TensorCAddOp<scalar_t>, uint64_t, scalar_t>
               <<<grid, block, 0, stream>>>(
                 TensorCAddOp<scalar_t>(value.to<scalar_t>()),
@@ -309,8 +309,8 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, SparseTensorR
       // sparseElementwiseKernel needs values to be contiguous too
       values = values.contiguous();
 
-      AT_DISPATCH_ALL_TYPES_AND_HALF(
-          values.type(), "add_out_dense_sparse_cuda", [&] {
+      AT_DISPATCH_ALL_TYPES_AND(
+        at::ScalarType::Half, values.type(), "add_out_dense_sparse_cuda", [&] {
             apply::sparseElementwiseKernel<TensorCAddOp<scalar_t>, uint64_t, scalar_t>
               <<<grid, block, 0, stream>>>(
                 TensorCAddOp<scalar_t>(value.to<scalar_t>()),
@@ -323,8 +323,8 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, SparseTensorR
 
     // FIXME: at some point we can wrap the scale into indexAdd
     // NB: Purposely not inplace!
-    AT_DISPATCH_ALL_TYPES_AND_HALF(
-        values.type(), "add_out_dense_sparse_cuda", [&] {
+    AT_DISPATCH_ALL_TYPES_AND(
+      at::ScalarType::Half, values.type(), "add_out_dense_sparse_cuda", [&] {
           if (value.to<scalar_t>() != static_cast<scalar_t>(1)) {
             values = values.mul(value);
           }
@@ -378,8 +378,8 @@ SparseTensor& add_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t, const
   LongTensor s_indices_ = src._indices();
   Tensor s_values_ = src._values();
 
-  AT_DISPATCH_ALL_TYPES_AND_HALF(
-      s_values_.type(), "add_out_sparse_cuda", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(
+    at::ScalarType::Half, s_values_.type(), "add_out_sparse_cuda", [&] {
         if (value.to<scalar_t>() != static_cast<scalar_t>(1)) {
           s_values_ = s_values_.mul(value);
         }
@@ -448,8 +448,8 @@ SparseTensor& mul_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t_, cons
   AT_CHECK(cuda::getApplyGrid(valueSize, grid, curDevice), "mul: Argument #0: tensor too large or too many dimensions");
 
   LongTensor resultNnz = at::empty({1}, CUDA(kLong));
-  AT_DISPATCH_ALL_TYPES_AND_HALF(
-      t_values_.type(), "mul_out_sparse_cuda", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(
+    at::ScalarType::Half, t_values_.type(), "mul_out_sparse_cuda", [&] {
         apply::valueSparseIntersectionKernel<TensorMulOp<scalar_t>, uint64_t, scalar_t>
           <<<grid, block, 0, stream>>>(
             TensorMulOp<scalar_t>(),