Change Dispatch.h to use ScalarType over Type
author     Roy Li <royboy@fb.com>
           Sat, 9 Mar 2019 00:39:04 +0000 (16:39 -0800)
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
           Sat, 9 Mar 2019 00:42:04 +0000 (16:42 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17527

Reviewed By: zou3519

Differential Revision: D14235395

fbshipit-source-id: 3f53e33f6794f1f14c2edf79014b8ef8397822c5
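
In short: the AT_DISPATCH_* macros in Dispatch.h now take an at::ScalarType directly instead of an at::Type whose scalarType() was read internally, and call sites pass tensor.scalar_type() (or iter.dtype()) rather than tensor.type(). A condensed before/after sketch of the call-site pattern, taken from the Activation.cpp hunk below:

    // Before: the macro received an at::Type and called .scalarType() on it.
    AT_DISPATCH_FLOATING_TYPES(self.type(), "hardshrink_cpu", [&] {
      auto lambd_val = lambd.to<scalar_t>();   // scalar_t bound per dispatched case
      /* ... */
    });

    // After: the macro receives the ScalarType enum value itself.
    AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "hardshrink_cpu", [&] {
      auto lambd_val = lambd.to<scalar_t>();
      /* ... */
    });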

82 files changed:
aten/src/ATen/Dispatch.h
aten/src/ATen/detail/ScalarTypeConversions.h
aten/src/ATen/native/Activation.cpp
aten/src/ATen/native/AdaptiveAveragePooling.cpp
aten/src/ATen/native/BatchLinearAlgebra.cpp
aten/src/ATen/native/Copy.cpp
aten/src/ATen/native/Distributions.cpp
aten/src/ATen/native/EmbeddingBag.cpp
aten/src/ATen/native/FractionalMaxPool2d.cpp
aten/src/ATen/native/FractionalMaxPool3d.cpp
aten/src/ATen/native/GridSampler.cpp
aten/src/ATen/native/Lerp.cpp
aten/src/ATen/native/LinearAlgebra.cpp
aten/src/ATen/native/Loss.cpp
aten/src/ATen/native/LossCTC.cpp
aten/src/ATen/native/Normalization.cpp
aten/src/ATen/native/RangeFactories.cpp
aten/src/ATen/native/ReflectionPad.cpp
aten/src/ATen/native/ReplicationPadding.cpp
aten/src/ATen/native/Scalar.cpp
aten/src/ATen/native/SoftMax.cpp
aten/src/ATen/native/Sorting.cpp
aten/src/ATen/native/SummaryOps.cpp
aten/src/ATen/native/TensorCompare.cpp
aten/src/ATen/native/TensorFactories.cpp
aten/src/ATen/native/TensorIterator.h
aten/src/ATen/native/TensorTransformations.cpp
aten/src/ATen/native/TypeProperties.cpp
aten/src/ATen/native/Unique.cpp
aten/src/ATen/native/cpu/Activation.cpp
aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
aten/src/ATen/native/cpu/CopyKernel.cpp
aten/src/ATen/native/cpu/DistanceOpsKernel.cpp
aten/src/ATen/native/cpu/GridSamplerKernel.cpp
aten/src/ATen/native/cpu/IndexKernel.cpp
aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
aten/src/ATen/native/cpu/SoftMaxKernel.cpp
aten/src/ATen/native/cpu/TensorCompareKernel.cpp
aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
aten/src/ATen/native/cuda/Activation.cu
aten/src/ATen/native/cuda/AdaptiveAveragePooling.cu
aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
aten/src/ATen/native/cuda/BinaryOpsKernel.cu
aten/src/ATen/native/cuda/CUDAScalar.cu
aten/src/ATen/native/cuda/Copy.cu
aten/src/ATen/native/cuda/DistanceKernel.cu
aten/src/ATen/native/cuda/Distributions.cu
aten/src/ATen/native/cuda/Dropout.cu
aten/src/ATen/native/cuda/Embedding.cu
aten/src/ATen/native/cuda/EmbeddingBag.cu
aten/src/ATen/native/cuda/FractionalMaxPool2d.cu
aten/src/ATen/native/cuda/FractionalMaxPool3d.cu
aten/src/ATen/native/cuda/GridSampler.cu
aten/src/ATen/native/cuda/IndexKernel.cu
aten/src/ATen/native/cuda/Lerp.cu
aten/src/ATen/native/cuda/Loss.cu
aten/src/ATen/native/cuda/LossCTC.cu
aten/src/ATen/native/cuda/Normalization.cu
aten/src/ATen/native/cuda/RNN.cu
aten/src/ATen/native/cuda/RangeFactories.cu
aten/src/ATen/native/cuda/ReduceOpsKernel.cu
aten/src/ATen/native/cuda/ReflectionPad.cu
aten/src/ATen/native/cuda/ReplicationPadding.cu
aten/src/ATen/native/cuda/SoftMax.cu
aten/src/ATen/native/cuda/SortingKthValue.cu
aten/src/ATen/native/cuda/SpectralOps.cu
aten/src/ATen/native/cuda/SummaryOps.cu
aten/src/ATen/native/cuda/TensorCompare.cu
aten/src/ATen/native/cuda/TensorFactories.cu
aten/src/ATen/native/cuda/TensorTransformations.cu
aten/src/ATen/native/cuda/Unique.cu
aten/src/ATen/native/cuda/WeightNorm.cu
aten/src/ATen/native/mkl/LinearAlgebra.cpp
aten/src/ATen/native/mkl/SpectralOps.cpp
aten/src/ATen/native/sparse/SparseTensor.cpp
aten/src/ATen/native/sparse/SparseTensorMath.cpp
aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu
aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu
aten/src/ATen/test/apply_utils_test.cpp
aten/src/ATen/test/scalar_test.cpp
test/test_cpp_extensions.py
torch/csrc/TypeInfo.cpp

diff --git a/aten/src/ATen/Dispatch.h b/aten/src/ATen/Dispatch.h
index 20d7f2d..51dcae6 100644
 
 #define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)                          \
   [&] {                                                                      \
-    const at::Type& the_type = TYPE;                                         \
-    switch (the_type.scalarType()) {                                         \
+    switch (TYPE) {                                                          \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__)      \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__)        \
       default:                                                               \
-        AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \
+        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");      \
     }                                                                        \
   }()
 
 #define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...)                 \
   [&] {                                                                      \
-    const at::Type& the_type = TYPE;                                         \
-    switch (the_type.scalarType()) {                                         \
+    switch (TYPE) {                                                          \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__)      \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__)        \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__)      \
       default:                                                               \
-        AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \
+        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");      \
     }                                                                        \
   }()
 
 #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...)              \
   [&] {                                                                      \
-    const at::Type& the_type = TYPE;                                         \
-    switch (the_type.scalarType()) {                                         \
+    switch (TYPE) {                                                          \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__)      \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__)        \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__)      \
       AT_PRIVATE_CASE_TYPE(                                                  \
           at::ScalarType::ComplexHalf, std::complex<at::Half>, __VA_ARGS__)  \
       default:                                                               \
-        AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \
+        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
     }                                                                        \
   }()
 
 #define AT_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...)                          \
   [&] {                                                                      \
-    const at::Type& the_type = TYPE;                                         \
-    switch (the_type.scalarType()) {                                         \
+    switch (TYPE) {                                                          \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__)       \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__)        \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__)        \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__)       \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__)      \
       default:                                                               \
-        AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \
+        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");      \
     }                                                                        \
   }()
 
 #define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...)                               \
   [&] {                                                                      \
-    const at::Type& the_type = TYPE;                                         \
-    switch (the_type.scalarType()) {                                         \
+    switch (TYPE) {                                                          \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__)       \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__)        \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__)      \
@@ -77,7 +72,7 @@
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__)       \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__)      \
       default:                                                               \
-        AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \
+        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");      \
     }                                                                        \
   }()
 
@@ -91,8 +86,7 @@ struct MyTemplate<at::ScalarType::Half> {
 
 #define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...)                    \
   [&] {                                                                           \
-    const at::Type& the_type = TYPE;                                              \
-    switch (the_type.scalarType()) {                                              \
+    switch (TYPE) {                                                               \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__)            \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__)             \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__)           \
@@ -102,14 +96,13 @@ struct MyTemplate<at::ScalarType::Half> {
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__)           \
       AT_PRIVATE_CASE_TYPE(SCALARTYPE, MyTemplate<SCALARTYPE>::type, __VA_ARGS__) \
       default:                                                                    \
-        AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'");      \
+        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");           \
     }                                                                             \
   }()
 
 #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, TYPE, NAME, ...)        \
   [&] {                                                                           \
-    const at::Type& the_type = TYPE;                                              \
-    switch (the_type.scalarType()) {                                              \
+    switch (TYPE) {                                                               \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__)            \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__)             \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__)           \
@@ -123,6 +116,6 @@ struct MyTemplate<at::ScalarType::Half> {
       AT_PRIVATE_CASE_TYPE(                                                       \
           at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__)       \
       default:                                                                    \
-        AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'");      \
+        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");           \
     }                                                                             \
   }()
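
For reference, each case in the switches above is generated by the AT_PRIVATE_CASE_TYPE helper defined earlier in Dispatch.h (not part of this hunk); it expands to roughly the following, which is what binds scalar_t inside the dispatched lambda:

    // Approximate expansion of the per-case helper (sketch, not verbatim from this diff):
    #define AT_PRIVATE_CASE_TYPE(enum_type, type, ...) \
      case enum_type: {                                \
        using scalar_t = type;                         \
        return __VA_ARGS__();                          \
      }
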
diff --git a/aten/src/ATen/detail/ScalarTypeConversions.h b/aten/src/ATen/detail/ScalarTypeConversions.h
index 76fb0dc..ef04271 100644
@@ -9,14 +9,14 @@ namespace at { namespace detail {
 
 template <typename T>
 inline T load(const void* data, ScalarType src_type) {
-  return AT_DISPATCH_ALL_TYPES(CPU(src_type), "load", [&]() {
+  return AT_DISPATCH_ALL_TYPES(src_type, "load", [&]() {
     return at::convert<T>(*(scalar_t*)data);
   });
 }
 
 template <typename T>
 inline void store(T value, void* dst, ScalarType dst_type) {
-  AT_DISPATCH_ALL_TYPES(CPU(dst_type), "store", [&]() {
+  AT_DISPATCH_ALL_TYPES(dst_type, "store", [&]() {
     *(scalar_t*)dst = at::convert<scalar_t>(value);
   });
 }
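
These helpers now take the ScalarType directly instead of wrapping it in a CPU at::Type (the old CPU(src_type) call). A hypothetical usage sketch of load/store as declared above (buffer names and values are illustrative):

    // Read a float buffer as double, then write a double into an int32 buffer.
    float src = 1.5f;
    double as_double = at::detail::load<double>(&src, at::ScalarType::Float);  // 1.5

    int32_t dst = 0;
    at::detail::store<double>(2.0, &dst, at::ScalarType::Int);                 // dst == 2
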
diff --git a/aten/src/ATen/native/Activation.cpp b/aten/src/ATen/native/Activation.cpp
index acbad68..5426c86 100644
@@ -150,7 +150,7 @@ Tensor prelu_cpu(const Tensor& self, const Tensor& weight_) {
 
   // case1: shared weight for all channels
   if (weight_num == 1) {
-    AT_DISPATCH_FLOATING_TYPES(input.type(), "prelu_cpu", [&] {
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "prelu_cpu", [&] {
       prelu_cpu_kernel_share_weights<scalar_t>(result, input, weight);
     });
   }
@@ -171,7 +171,7 @@ Tensor prelu_cpu(const Tensor& self, const Tensor& weight_) {
       "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num,
       " and channel size = ", channel_size, ".");
 
-    AT_DISPATCH_FLOATING_TYPES(input.type(), "prelu_cpu", [&] {
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "prelu_cpu", [&] {
       prelu_cpu_kernel_multi_weights<scalar_t>(
         result,
         input,
@@ -277,7 +277,7 @@ std::tuple<Tensor, Tensor> prelu_backward_cpu(const Tensor& grad_out_, const Ten
 
   // case1: shared parameter for all channels
   if (weight_num == 1) {
-    AT_DISPATCH_FLOATING_TYPES(input.type(), "prelu_backward_cpu", [&] {
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "prelu_backward_cpu", [&] {
       prelu_cpu_backward_kernel_share_weights<scalar_t>(input, weight, grad_out, input_grad, weight_grad);
     });
   }
@@ -298,7 +298,7 @@ std::tuple<Tensor, Tensor> prelu_backward_cpu(const Tensor& grad_out_, const Ten
       "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num,
       " and channel size = ", channel_size, ".");
 
-    AT_DISPATCH_FLOATING_TYPES(input.type(), "prelu_backward_cpu", [&] {
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "prelu_backward_cpu", [&] {
       prelu_cpu_backward_kernel_multi_weights<scalar_t>(
         input,
         weight,
@@ -326,7 +326,7 @@ std::tuple<Tensor, Tensor> prelu_backward_cpu(const Tensor& grad_out_, const Ten
 // -----------------------------------
 Tensor hardshrink_cpu(const Tensor & self, Scalar lambd) {
   auto out_tensor = at::empty_like(self);
-  AT_DISPATCH_FLOATING_TYPES(self.type(), "hardshrink_cpu", [&] {
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "hardshrink_cpu", [&] {
     auto lambd_val = lambd.to<scalar_t>();
     at::CPU_tensor_apply2<scalar_t, scalar_t>(
       self,
@@ -342,7 +342,7 @@ Tensor hardshrink_cpu(const Tensor & self, Scalar lambd) {
 
 Tensor hardshrink_backward_cpu(const Tensor & grad, const Tensor & self, Scalar lambd) {
   auto out_tensor = at::empty_like(self);
-  AT_DISPATCH_FLOATING_TYPES(self.type(), "hardshrink_backward_cpu", [&] {
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "hardshrink_backward_cpu", [&] {
     auto lambd_val = lambd.to<scalar_t>();
     at::CPU_tensor_apply3<scalar_t, scalar_t, scalar_t>(
       self,
diff --git a/aten/src/ATen/native/AdaptiveAveragePooling.cpp b/aten/src/ATen/native/AdaptiveAveragePooling.cpp
index a31211b..94314f5 100644
@@ -102,7 +102,7 @@ namespace {
     {
       output.resize_({sizeD, osizeH, osizeW});
 
-      AT_DISPATCH_FLOATING_TYPES(input.type(), "adaptive_avg_pool2d", [&] {
+      AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "adaptive_avg_pool2d_cpu", [&] {
           auto input_data = input.data<scalar_t>();
           auto output_data = output.data<scalar_t>();
           adaptive_avg_pool2d_out_frame<scalar_t>(input_data, output_data,
@@ -121,7 +121,7 @@ namespace {
     #pragma omp parallel for private(b)
       for (b = 0; b < input.size(0); b++)
       {
-        AT_DISPATCH_FLOATING_TYPES(input.type(), "adaptive_avg_pool2d", [&] {
+        AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "adaptive_avg_pool2d_cpu", [&] {
             auto input_data = input.data<scalar_t>();
             auto output_data = output.data<scalar_t>();
             adaptive_avg_pool2d_out_frame<scalar_t>(input_data+b*input.stride(0), output_data+b*sizeD*osizeH*osizeW,
@@ -203,7 +203,7 @@ namespace {
     if (input.ndimension() == 3)
     {
       AT_DISPATCH_FLOATING_TYPES(
-        input.type(), "adaptive_avg_pool2d_backward", [&] {
+        input.scalar_type(), "adaptive_avg_pool2d_backward_cpu", [&] {
           /* get raw pointers */
           scalar_t *gradInput_data = gradInput.data<scalar_t>();
           scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
@@ -223,7 +223,7 @@ namespace {
       for (b = 0; b < input.size(0); b++)
       {
         AT_DISPATCH_FLOATING_TYPES(
-          input.type(), "adaptive_avg_pool2d_backward", [&] {
+          input.scalar_type(), "adaptive_avg_pool2d_backward_cpu", [&] {
             /* get raw pointers */
             scalar_t *gradInput_data = gradInput.data<scalar_t>();
             scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
@@ -262,7 +262,7 @@ namespace {
     return output;
   }
 
-  Tensor adaptive_avg_pool2d(  
+  Tensor adaptive_avg_pool2d(
     at::Tensor const& input,
     IntArrayRef output_size){
     if (output_size[0] == 1 && output_size[1] == 1) {
diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index b373181..767368a 100644
@@ -149,13 +149,13 @@ std::tuple<Tensor, Tensor> _gesv_helper_cpu(const Tensor& self, const Tensor& A)
   auto self_working_copy = cloneBatchedColumnMajor(self);
   auto A_working_copy = cloneBatchedColumnMajor(A);
   std::vector<int64_t> infos(batchCount(self), 0);
-  AT_DISPATCH_FLOATING_TYPES(self.type(), "gesv", [&]{
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "gesv_cpu", [&]{
     apply_gesv<scalar_t>(self_working_copy, A_working_copy, infos);
   });
   if (self.dim() > 2) {
-    batchCheckErrors(infos, "gesv");
+    batchCheckErrors(infos, "gesv_cpu");
   } else {
-    singleCheckErrors(infos[0], "gesv");
+    singleCheckErrors(infos[0], "gesv_cpu");
   }
   return std::tuple<Tensor, Tensor>(self_working_copy, A_working_copy);
 }
@@ -172,7 +172,7 @@ std::tuple<Tensor,Tensor> gesv(const Tensor& self, const Tensor& A) {
 }
 
 std::tuple<Tensor&,Tensor&> gesv_out(Tensor& solution, Tensor& lu, const Tensor& self, const Tensor& A) {
-  AT_CHECK(self.dim() == 2 && A.dim() == 2, 
+  AT_CHECK(self.dim() == 2 && A.dim() == 2,
            "torch.gesv() with the `out` keyword does not support batching. "
            "b.dim() (", self.dim(), ") and A.dim() (", A.dim(), ") must both be 2.");
   Tensor solution_tmp, lu_tmp;
@@ -229,10 +229,10 @@ static void apply_inverse(Tensor& self, std::vector<int64_t>& infos) {
 Tensor _inverse_helper_cpu(const Tensor& self) {
   std::vector<int64_t> infos(batchCount(self), 0);
   auto self_working_copy = cloneBatchedColumnMajor(self);
-  AT_DISPATCH_FLOATING_TYPES(self.type(), "inverse", [&]{
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "inverse_cpu", [&]{
     apply_inverse<scalar_t>(self_working_copy, infos);
   });
-  batchCheckErrors(infos, "inverse");
+  batchCheckErrors(infos, "inverse_cpu");
   return self_working_copy;
 }
 
@@ -294,13 +294,13 @@ Tensor _cholesky_solve_helper_cpu(const Tensor& self, const Tensor& A, bool uppe
   auto self_working_copy = cloneBatchedColumnMajor(self);
   auto A_working_copy = cloneBatchedColumnMajor(A);
   std::vector<int64_t> infos(batchCount(self), 0);
-  AT_DISPATCH_FLOATING_TYPES(self.type(), "cholesky_solve", [&]{
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "cholesky_solve_cpu", [&]{
     apply_cholesky_solve<scalar_t>(self_working_copy, A_working_copy, upper, infos);
   });
   if (self.dim() > 2) {
-    batchCheckErrors(infos, "cholesky_solve");
+    batchCheckErrors(infos, "cholesky_solve_cpu");
   } else {
-    singleCheckErrors(infos[0], "cholesky_solve");
+    singleCheckErrors(infos[0], "cholesky_solve_cpu");
   }
   return self_working_copy;
 }
@@ -358,13 +358,13 @@ static void apply_cholesky(Tensor& self, bool upper, std::vector<int64_t>& infos
 Tensor _cholesky_helper_cpu(const Tensor& self, bool upper) {
   std::vector<int64_t> infos(batchCount(self), 0);
   auto self_working_copy = cloneBatchedColumnMajor(self);
-  AT_DISPATCH_FLOATING_TYPES(self.type(), "cholesky", [&]{
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "cholesky_cpu", [&]{
     apply_cholesky<scalar_t>(self_working_copy, upper, infos);
   });
   if (self.dim() > 2) {
-    batchCheckErrors(infos, "cholesky");
+    batchCheckErrors(infos, "cholesky_cpu");
   } else {
-    singleCheckErrors(infos[0], "cholesky");
+    singleCheckErrors(infos[0], "cholesky_cpu");
   }
   return self_working_copy;
 }
@@ -474,7 +474,7 @@ Tensor& tril_cpu_(Tensor &self, int64_t k) {
   bool inplace = checkTrilTriuBatchContiguous(self);
   Tensor self_c = inplace ? self : self.contiguous();
   Tensor result = inplace ? self : at::empty_like(self);
-  AT_DISPATCH_ALL_TYPES(self.type(), "tril", [&]{
+  AT_DISPATCH_ALL_TYPES(self.scalar_type(), "tril", [&]{
     apply_triu_tril<scalar_t, false>(result, self_c, inplace, k);
   });
   if (!inplace) self.copy_(result);
@@ -489,7 +489,7 @@ Tensor& tril_cpu_out(Tensor &result, const Tensor& self, int64_t k) {
     return result;
   }
   Tensor self_c = checkTrilTriuBatchContiguous(self) ? self : self.contiguous();
-  AT_DISPATCH_ALL_TYPES(self.type(), "tril", [&]{
+  AT_DISPATCH_ALL_TYPES(self.scalar_type(), "tril", [&]{
     apply_triu_tril<scalar_t, false>(result, self_c, false, k);
   });
   return result;
@@ -508,7 +508,7 @@ Tensor& triu_cpu_(Tensor &self, int64_t k) {
   bool inplace = checkTrilTriuBatchContiguous(self);
   Tensor self_c = inplace ? self : self.contiguous();
   Tensor result = inplace ? self : at::empty_like(self);
-  AT_DISPATCH_ALL_TYPES(self.type(), "triu", [&]{
+  AT_DISPATCH_ALL_TYPES(self.scalar_type(), "triu", [&]{
     apply_triu_tril<scalar_t, true>(result, self_c, inplace, k);
   });
   if (!inplace) self.copy_(result);
@@ -523,7 +523,7 @@ Tensor& triu_cpu_out(Tensor &result, const Tensor& self, int64_t k) {
     return result;
   }
   Tensor self_c = checkTrilTriuBatchContiguous(self) ? self : self.contiguous();
-  AT_DISPATCH_ALL_TYPES(self.type(), "triu", [&]{
+  AT_DISPATCH_ALL_TYPES(self.scalar_type(), "triu", [&]{
     apply_triu_tril<scalar_t, true>(result, self_c, false, k);
   });
   return result;
diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp
index a8c7bab..0ff0b0f 100644
@@ -20,7 +20,7 @@ void _copy__cpu(at::Tensor& self, const at::Tensor& src) {
 template <typename self_T>
 void _copy__cpu(at::Tensor& self, const at::Tensor& src) {
   AT_CHECK(self.numel() == src.numel(), "sizes do not match");
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.type(), "_copy__cpu", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_copy__cpu", [&]() {
     _copy__cpu<self_T, scalar_t>(self, src);
   });
 }
@@ -43,7 +43,7 @@ Tensor& _s_copy__cpu(Tensor& self, const Tensor& src, bool non_blocking) {
     return self;
   }
   AT_DISPATCH_ALL_TYPES_AND(
-    at::ScalarType::Half, self.type(), "_copy__cpu", [&]() { ::_copy__cpu<scalar_t>(self, src); });
+    at::ScalarType::Half, self.scalar_type(), "_copy__cpu", [&]() { ::_copy__cpu<scalar_t>(self, src); });
   return self;
 }
 
@@ -59,7 +59,7 @@ void _copy_same_type_transpose_(Tensor& self, const Tensor& src) {
   Tensor buf = empty({BLOCK_SZ, BLOCK_SZ}, self.options());
 
   AT_DISPATCH_ALL_TYPES_AND(
-    at::ScalarType::Half, self.type(), "_copy_same_type_transpose_", [&]() {
+    at::ScalarType::Half, self.scalar_type(), "_copy_same_type_transpose_", [&]() {
         scalar_t* sp = src.data<scalar_t>();
         scalar_t* rp = self.data<scalar_t>();
         scalar_t* bp = buf.data<scalar_t>();
@@ -115,7 +115,7 @@ void _copy_same_type__cpu(Tensor& self, const Tensor& src) {
 #ifdef _OPENMP
       if (!in_parallel_region()) {
         AT_DISPATCH_ALL_TYPES_AND(
-          at::ScalarType::Half, self.type(), "_copy_same_type_", [&]() {
+          at::ScalarType::Half, self.scalar_type(), "_copy_same_type_", [&]() {
             at::CPU_tensor_parallel_apply2<scalar_t, scalar_t>(
                 self, src, [](scalar_t& self_val, const scalar_t& src_val) {
                   self_val = src_val;
@@ -134,7 +134,7 @@ void _copy_same_type__cpu(Tensor& self, const Tensor& src) {
 
   if (serial_path) {
     AT_DISPATCH_ALL_TYPES_AND(
-      at::ScalarType::Half, self.type(), "_copy_same_type_", [&]() {
+      at::ScalarType::Half, self.scalar_type(), "_copy_same_type_", [&]() {
         at::CPU_tensor_apply2<scalar_t, scalar_t>(
             self, src, [](scalar_t& self_val, const scalar_t& src_val) {
               self_val = src_val;
diff --git a/aten/src/ATen/native/Distributions.cpp b/aten/src/ATen/native/Distributions.cpp
index 30ffb76..6ef1848 100644
@@ -126,7 +126,7 @@ Tensor& bernoulli_out(Tensor& result, const Tensor& self, Generator* gen) {
 }
 
 Tensor& bernoulli_tensor_cpu_(Tensor& self, const Tensor& p_, Generator* gen) {
-  AT_DISPATCH_ALL_TYPES(self.type(), "bernoulli_tensor_cpu_self_", [&] {
+  AT_DISPATCH_ALL_TYPES(self.scalar_type(), "bernoulli_tensor_cpu_self_", [&] {
     THGenerator* generator = get_generator(gen);
     std::lock_guard<std::mutex> lock(generator->mutex);
     using self_t = scalar_t;
@@ -137,7 +137,7 @@ Tensor& bernoulli_tensor_cpu_(Tensor& self, const Tensor& p_, Generator* gen) {
           ret_val = static_cast<self_t>(THRandom_bernoulli(generator, p_val));
         });
     } else {
-      AT_DISPATCH_FLOATING_TYPES(p_.type(), "bernoulli_tensor_cpu_p_", [&] {
+      AT_DISPATCH_FLOATING_TYPES(p_.scalar_type(), "bernoulli_tensor_cpu_p_", [&] {
         auto p = std::get<0>(expand_inplace(self, p_.to(kCPU)));
         using p_t = scalar_t;
         CPU_tensor_apply2<self_t, p_t>(
@@ -160,7 +160,7 @@ Tensor& bernoulli_scalar_cpu_(Tensor& self, double p, Generator* gen) {
     return self;
   }
 #endif
-  AT_DISPATCH_ALL_TYPES(self.type(), "bernoulli_scalar_cpu_", [&] {
+  AT_DISPATCH_ALL_TYPES(self.scalar_type(), "bernoulli_scalar_cpu_", [&] {
     THGenerator* generator = get_generator(gen);
     std::lock_guard<std::mutex> lock(generator->mutex);
     CPU_tensor_apply1<scalar_t>(
@@ -174,7 +174,7 @@ Tensor& bernoulli_scalar_cpu_(Tensor& self, double p, Generator* gen) {
 
 Tensor _standard_gamma_grad_cpu(const Tensor& self, const Tensor& output) {
   Tensor ret = at::empty(self.sizes(), self.options());
-  AT_DISPATCH_FLOATING_TYPES(self.type(), "_standard_gamma_grad", [&] {
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "_standard_gamma_grad_cpu", [&] {
     CPU_tensor_apply3<scalar_t, scalar_t, scalar_t>(ret, self, output,
       [](scalar_t& ret_val, const scalar_t& self_val, const scalar_t &output_val) {
         ret_val = standard_gamma_grad_one<scalar_t, double>(self_val, output_val);
@@ -190,7 +190,7 @@ Tensor _standard_gamma_grad_cpu(const Tensor& self, const Tensor& output) {
 
 Tensor _s_poisson_cpu(const Tensor& lambda, Generator *gen) {
   Tensor ret = at::zeros(lambda.sizes(), lambda.options());
-  AT_DISPATCH_FLOATING_TYPES(ret.type(), "poisson", [&] {
+  AT_DISPATCH_FLOATING_TYPES(ret.scalar_type(), "poisson_cpu", [&] {
     THGenerator* generator = get_generator(gen);
     std::lock_guard<std::mutex> lock(generator->mutex);
     CPU_tensor_apply2<scalar_t, scalar_t>(ret, lambda,
@@ -204,7 +204,7 @@ Tensor _s_poisson_cpu(const Tensor& lambda, Generator *gen) {
 
 Tensor _s_gamma_cpu(const Tensor& alpha, Generator *gen) {
   Tensor ret = at::zeros(alpha.sizes(), alpha.options());
-  AT_DISPATCH_FLOATING_TYPES(ret.type(), "gamma", [&] {
+  AT_DISPATCH_FLOATING_TYPES(ret.scalar_type(), "gamma_cpu", [&] {
     THGenerator* generator = get_generator(gen);
     std::lock_guard<std::mutex> lock(generator->mutex);
     CPU_tensor_apply2<scalar_t, scalar_t>(ret, alpha,
diff --git a/aten/src/ATen/native/EmbeddingBag.cpp b/aten/src/ATen/native/EmbeddingBag.cpp
index 11e60e5..0ed1705 100644
@@ -199,7 +199,7 @@ _embedding_bag_cpu(const Tensor &weight, const Tensor &indices,
     return std::tuple<Tensor, Tensor, Tensor, Tensor>(ret, offset2bag, bag_size, bag_size);
   } else { // MODE_MAX
     return AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      weight.type(), "embedding_bag_cpu_max", [&]() {
+      weight.scalar_type(), "embedding_bag_cpu_max", [&]() {
         return embedding_bag_cpu_max<scalar_t>(weight, indices, offset2bag, output, bag_size, offsets);
       }
     );
diff --git a/aten/src/ATen/native/FractionalMaxPool2d.cpp b/aten/src/ATen/native/FractionalMaxPool2d.cpp
index 54e362c..1980306 100644
@@ -178,7 +178,7 @@ void fractional_max_pool2d_out_cpu_template(
     indices.resize_({numBatch, numPlanes, outputH, outputW});
   }
 
-  AT_DISPATCH_FLOATING_TYPES(input.type(),
+  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(),
   "fractional_max_pool2d_out_frame", [&] {
     auto input_data = input.data<scalar_t>();
     auto output_data = output.data<scalar_t>();
@@ -295,7 +295,7 @@ Tensor& fractional_max_pool2d_backward_out_cpu_template(
 
   /* backprop */
   AT_DISPATCH_FLOATING_TYPES(
-    input.type(), "fractional_max_pool2d_backward_out_frame", [&] {
+    input.scalar_type(), "fractional_max_pool2d_backward_out_frame", [&] {
       auto gradInput_data = gradInput.data<scalar_t>();
       auto gradOutput_data = gradOutput.data<scalar_t>();
       auto indices_data = indices.data<int64_t>();
diff --git a/aten/src/ATen/native/FractionalMaxPool3d.cpp b/aten/src/ATen/native/FractionalMaxPool3d.cpp
index 72e6722..30dc2b4 100644
@@ -201,7 +201,7 @@ void fractional_max_pool3d_out_cpu_template(
     indices.resize_({numBatch, numPlanes, outputT, outputH, outputW});
   }
   AT_DISPATCH_FLOATING_TYPES(
-    input.type(),
+    input.scalar_type(),
     "fractional_max_pool3d_out_frame",
     [&] {
       fractional_max_pool3d_out_frame<scalar_t>(
@@ -330,7 +330,7 @@ void fractional_max_pool3d_backward_out_cpu_template(
 
   /* backprop */
   AT_DISPATCH_FLOATING_TYPES(
-    input.type(),
+    input.scalar_type(),
     "fractional_max_pool3d_backward_out_frame",
     [&]{
       fractional_max_pool3d_backward_out_frame<scalar_t>(
diff --git a/aten/src/ATen/native/GridSampler.cpp b/aten/src/ATen/native/GridSampler.cpp
index e9bb623..65af517 100644
@@ -534,7 +534,7 @@ DEFINE_DISPATCH(grid_sampler_2d_cpu_kernel);
 // No shape checking needed here. See # NOTE [ grid_sampler Native Functions ].
 Tensor grid_sampler_3d_cpu(const Tensor& input, const Tensor& grid,
                            int64_t interpolation_mode, int64_t padding_mode) {
-  return AT_DISPATCH_FLOATING_TYPES(input.type(), "grid_sampler3d_cpu", [&] {
+  return AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler3d_cpu", [&] {
     return grid_sampler_3d_cpu_impl<scalar_t>(
       input, grid, static_cast<GridSamplerInterpolation>(interpolation_mode),
       static_cast<GridSamplerPadding>(padding_mode));
@@ -554,7 +554,7 @@ DEFINE_DISPATCH(grid_sampler_2d_backward_cpu_kernel);
 std::tuple<Tensor, Tensor>
 grid_sampler_3d_backward_cpu(const Tensor& grad_output, const Tensor& input, const Tensor& grid,
                              int64_t interpolation_mode, int64_t padding_mode) {
-  return AT_DISPATCH_FLOATING_TYPES(input.type(), "grid_sampler_3d_backward_cpu", [&] {
+  return AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler_3d_backward_cpu", [&] {
     return grid_sampler_3d_backward_cpu_impl<scalar_t>(
       grad_output, input, grid,
       static_cast<GridSamplerInterpolation>(interpolation_mode),
diff --git a/aten/src/ATen/native/Lerp.cpp b/aten/src/ATen/native/Lerp.cpp
index a541cfe..96f9534 100644
@@ -38,9 +38,9 @@ Tensor& lerp_cpu_tensor_out(Tensor& result, const Tensor& self,
   Tensor b_self, b_end, b_weight;
   AT_CHECK(weight.dim() <= std::max(self.dim(), end.dim()),
            "weight should be of dimension max(self.dim(), end.dim()) or lesser");
-  std::tie(b_self, b_end, b_weight) = expand_outplace(self, end, weight, "lerp_out");
+  std::tie(b_self, b_end, b_weight) = expand_outplace(self, end, weight, "lerp_out_cpu");
   result.resize_as_(b_self);
-  AT_DISPATCH_FLOATING_TYPES(self.type(), "lerp_out", [&]{
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lerp_out_cpu", [&]{
     lerp_cpu<scalar_t>(result, b_self, b_end, b_weight);
   });
   return result;
@@ -49,9 +49,9 @@ Tensor& lerp_cpu_tensor_out(Tensor& result, const Tensor& self,
 Tensor& lerp_cpu_scalar_out(Tensor& result, const Tensor& self,
                             const Tensor& end, Scalar weight) {
   Tensor b_self, b_end;
-  std::tie(b_self, b_end) = expand_outplace(self, end, "lerp_out");
+  std::tie(b_self, b_end) = expand_outplace(self, end, "lerp_out_cpu");
   result.resize_as_(b_self);
-  AT_DISPATCH_FLOATING_TYPES(self.type(), "lerp_out", [&]{
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lerp_out_cpu", [&]{
     lerp_cpu<scalar_t>(result, b_self, b_end, weight.to<scalar_t>());
   });
   return result;
@@ -59,13 +59,13 @@ Tensor& lerp_cpu_scalar_out(Tensor& result, const Tensor& self,
 
 Tensor& lerp_cpu_tensor_(Tensor& self, const Tensor& end, const Tensor& weight) {
   Tensor b_self, b_end, b_weight;
-  std::tie(b_self, b_end, b_weight) = expand_outplace(self, end, weight, "lerp_");
+  std::tie(b_self, b_end, b_weight) = expand_outplace(self, end, weight, "lerp__cpu");
   AT_CHECK(b_self.sizes() == self.sizes(),
            "output with shape ", self.sizes(),
            " doesn't match the broadcast shape ", b_self.sizes());
   AT_CHECK(weight.dim() <= std::max(self.dim(), end.dim()),
            "weight should be of dimension max(self.dim(), end.dim()) or lesser");
-  AT_DISPATCH_FLOATING_TYPES(self.type(), "lerp_", [&]{
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lerp__cpu", [&]{
     lerp_cpu<scalar_t>(self, b_self, b_end, b_weight);
   });
   return self;
@@ -73,11 +73,11 @@ Tensor& lerp_cpu_tensor_(Tensor& self, const Tensor& end, const Tensor& weight)
 
 Tensor& lerp_cpu_scalar_(Tensor& self, const Tensor& end, Scalar weight) {
   Tensor b_self, b_end;
-  std::tie(b_self, b_end) = expand_outplace(self, end, "lerp_");
+  std::tie(b_self, b_end) = expand_outplace(self, end, "lerp__cpu");
   AT_CHECK(b_self.sizes() == self.sizes(),
            "output with shape ", self.sizes(),
            " doesn't match the broadcast shape ", b_self.sizes());
-  AT_DISPATCH_FLOATING_TYPES(self.type(), "lerp_", [&]{
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lerp__cpu", [&]{
     lerp_cpu<scalar_t>(self, b_self, b_end, weight.to<scalar_t>());
   });
   return self;
@@ -87,9 +87,9 @@ Tensor lerp_cpu_tensor(const Tensor& self, const Tensor& end, const Tensor& weig
   Tensor b_self, b_end, b_weight;
   AT_CHECK(weight.dim() <= std::max(self.dim(), end.dim()),
            "weight should be of dimension max(self.dim(), end.dim()) or lesser");
-  std::tie(b_self, b_end, b_weight) = expand_outplace(self, end, weight, "lerp");
+  std::tie(b_self, b_end, b_weight) = expand_outplace(self, end, weight, "lerp_cpu");
   Tensor result = at::empty_like(b_self);
-  AT_DISPATCH_FLOATING_TYPES(result.type(), "lerp", [&]{
+  AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "lerp_cpu", [&]{
     lerp_cpu<scalar_t>(result, b_self, b_end, b_weight);
   });
   return result;
@@ -97,9 +97,9 @@ Tensor lerp_cpu_tensor(const Tensor& self, const Tensor& end, const Tensor& weig
 
 Tensor lerp_cpu_scalar(const Tensor& self, const Tensor& end, Scalar weight) {
   Tensor b_self, b_end;
-  std::tie(b_self, b_end) = expand_outplace(self, end, "lerp");
+  std::tie(b_self, b_end) = expand_outplace(self, end, "lerp_cpu");
   Tensor result = at::empty_like(b_self);
-  AT_DISPATCH_FLOATING_TYPES(result.type(), "lerp", [&]{
+  AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "lerp_cpu", [&]{
     lerp_cpu<scalar_t>(result, b_self, b_end, weight.to<scalar_t>());
   });
   return result;
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 42cad7c..c39fcbb 100644
@@ -299,11 +299,11 @@ static inline Tensor& bmm_out_or_baddbmm_(Tensor& self_or_result, const Tensor&
 
   if (contraction_size * res_rows * res_cols < 400) {
     if (is_bmm_out) {
-      AT_DISPATCH_ALL_TYPES(batch1.type(), "bmm", [&] {
+      AT_DISPATCH_ALL_TYPES(batch1.scalar_type(), "bmm", [&] {
           baddbmm_cpu_kernel<scalar_t, true>(self_or_result, batch1, batch2, beta, alpha);
         });
     } else {
-      AT_DISPATCH_ALL_TYPES(batch1.type(), "baddbmm", [&] {
+      AT_DISPATCH_ALL_TYPES(batch1.scalar_type(), "baddbmm", [&] {
           baddbmm_cpu_kernel<scalar_t, false>(self_or_result, batch1, batch2, beta, alpha);
         });
     }
diff --git a/aten/src/ATen/native/Loss.cpp b/aten/src/ATen/native/Loss.cpp
index 6722d85..7def0da 100644
@@ -70,7 +70,7 @@ Tensor kl_div(const Tensor& input, const Tensor& target, int64_t reduction) {
 Tensor kl_div_backward_cpu(const Tensor& grad, const Tensor& input, const Tensor& target, int64_t reduction) {
   auto grad_input = at::zeros_like(input);
   auto grad_expand = grad.expand_as(input);
-  AT_DISPATCH_FLOATING_TYPES(input.type(), "kl_div_backward", [&]() {
+  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "kl_div_backward_cpu", [&]() {
     at::CPU_tensor_apply3<scalar_t, scalar_t, scalar_t>(
         grad_input,
         target,
diff --git a/aten/src/ATen/native/LossCTC.cpp b/aten/src/ATen/native/LossCTC.cpp
index 5b7f376..f6d8906 100644
@@ -307,7 +307,7 @@ Tensor ctc_loss_backward_cpu_template(const Tensor& grad_out, const Tensor& log_
 
 std::tuple<Tensor, Tensor> ctc_loss_cpu(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK, bool zero_infinity) {
   (void)zero_infinity; // only used for backwards
-  return AT_DISPATCH_FLOATING_TYPES(log_probs.type(), "ctc_loss", [&] {
+  return AT_DISPATCH_FLOATING_TYPES(log_probs.scalar_type(), "ctc_loss_cpu", [&] {
       if (targets.scalar_type() == kLong) {
        return ctc_loss_cpu_template<scalar_t, kLong>(log_probs, targets, input_lengths, target_lengths, BLANK);
       } else {
@@ -318,7 +318,7 @@ std::tuple<Tensor, Tensor> ctc_loss_cpu(const Tensor& log_probs, const Tensor& t
 
 Tensor ctc_loss_backward_cpu(const Tensor& grad, const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths,
                              const Tensor& neg_log_likelihood, const Tensor& log_alpha, int64_t BLANK, bool zero_infinity) {
-  return AT_DISPATCH_FLOATING_TYPES(log_probs.type(), "ctc_loss_backward", [&] {
+  return AT_DISPATCH_FLOATING_TYPES(log_probs.scalar_type(), "ctc_loss_backward_cpu", [&] {
       if (targets.scalar_type() == kLong) {
        return ctc_loss_backward_cpu_template<scalar_t,kLong>(grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, BLANK, zero_infinity);
       } else {
diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp
index f62bd4b..e4be451 100644
@@ -486,7 +486,7 @@ Tensor group_norm(const Tensor& input, int64_t num_groups,
 
 std::tuple<Tensor, Tensor> batch_norm_update_stats_cpu(
         const Tensor& self, const Tensor& running_mean, const Tensor& running_var, double momentum) {
-  return AT_DISPATCH_FLOATING_TYPES(self.type(), "batch_norm_update_stats", [&] {
+  return AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "batch_norm_update_stats_cpu", [&] {
       return batch_norm_cpu_update_stats_template<scalar_t, Var>(self, running_mean, running_var, momentum, 0);
     });
 }
@@ -496,7 +496,7 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_cpu(const Tensor& self, const Tens
                                                   bool train, double momentum, double eps) {
   checkBackend("batch_norm_cpu", {self, weight, bias, running_mean, running_var}, Backend::CPU);
 
-  return AT_DISPATCH_FLOATING_TYPES(self.type(), "batch_norm", [&] {
+  return AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "batch_norm", [&] {
       if (!train) {
         return batch_norm_cpu_transform_input_template<scalar_t>(self, weight, bias, {}, {}, running_mean, running_var, train, eps);
       } else {
@@ -509,7 +509,7 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_cpu(const Tensor& self, const Tens
 std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu(const Tensor& grad_out, const Tensor& self, const Tensor& weight,
                                                            const Tensor& running_mean, const Tensor& running_var, const Tensor& save_mean, const Tensor& save_invstd,
                                                            bool train, double eps, std::array<bool,3> grad_input_mask) {
-  return AT_DISPATCH_FLOATING_TYPES(self.type(), "batch_norm_backward", [&] {
+  return AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "batch_norm_backward_cpu", [&] {
       return batch_norm_backward_cpu_template<scalar_t>(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, eps, grad_input_mask);
     });
 }
diff --git a/aten/src/ATen/native/RangeFactories.cpp b/aten/src/ATen/native/RangeFactories.cpp
index 2f71149..02abdab 100644
@@ -21,7 +21,7 @@ Tensor& linspace_cpu_out(Tensor& result, Scalar start, Scalar end, int64_t steps
   } else if (steps == 1) {
     r.fill_(start);
   } else {
-    AT_DISPATCH_FLOATING_TYPES(r.type(), "linspace", [&]() {
+    AT_DISPATCH_FLOATING_TYPES(r.scalar_type(), "linspace_cpu", [&]() {
       scalar_t scalar_start = start.to<scalar_t>();
       scalar_t scalar_end = end.to<scalar_t>();
       scalar_t *data_ptr = r.data<scalar_t>();
@@ -54,7 +54,7 @@ Tensor& logspace_cpu_out(Tensor& result, Scalar start, Scalar end, int64_t steps
   } else if (steps == 1) {
     r.fill_(std::pow(10.0, start.to<double>()));
   } else {
-    AT_DISPATCH_FLOATING_TYPES(r.type(), "logspace", [&]() {
+    AT_DISPATCH_FLOATING_TYPES(r.scalar_type(), "logspace_cpu", [&]() {
       scalar_t base10 = 10;
       scalar_t scalar_start = start.to<scalar_t>();
       scalar_t scalar_end = end.to<scalar_t>();
@@ -76,7 +76,7 @@ Tensor& logspace_cpu_out(Tensor& result, Scalar start, Scalar end, int64_t steps
 }
 
 Tensor& range_cpu_out(Tensor& result, Scalar start, Scalar end, Scalar step) {
-  AT_DISPATCH_ALL_TYPES(result.type(), "range", [&]() {
+  AT_DISPATCH_ALL_TYPES(result.scalar_type(), "range_cpu", [&]() {
     using accscalar_t = at::acc_type<scalar_t, false>;
     auto xstart = start.to<accscalar_t>();
     auto xend = end.to<accscalar_t>();
@@ -110,7 +110,7 @@ Tensor& range_cpu_out(Tensor& result, Scalar start, Scalar end, Scalar step) {
 }
 
 Tensor& arange_cpu_out(Tensor& result, Scalar start, Scalar end, Scalar step) {
-  AT_DISPATCH_ALL_TYPES(result.type(), "arange", [&]() {
+  AT_DISPATCH_ALL_TYPES(result.scalar_type(), "arange_cpu", [&]() {
     using accscalar_t = at::acc_type<scalar_t, false>;
     auto xstart = start.to<accscalar_t>();
     auto xend = end.to<accscalar_t>();
diff --git a/aten/src/ATen/native/ReflectionPad.cpp b/aten/src/ATen/native/ReflectionPad.cpp
index 0badbe3..b8ee136 100644
@@ -91,7 +91,7 @@ void reflection_pad1d_out_template(
   /* resize output */
   if (input.ndimension() == 2) {
     output.resize_({nplane, output_w});
-    AT_DISPATCH_FLOATING_TYPES(input.type(), "reflection_pad1d", [&] {
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "reflection_pad1d", [&] {
       reflection_pad1d_out_frame<scalar_t>(
         input.data<scalar_t>(), output.data<scalar_t>(),
         nplane,
@@ -100,7 +100,7 @@ void reflection_pad1d_out_template(
     });
   } else {
     output.resize_({nbatch, nplane, output_w});
-    AT_DISPATCH_FLOATING_TYPES(input.type(), "reflection_pad1d", [&] {
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "reflection_pad1d", [&] {
       reflection_pad1d_out_loop<scalar_t>(
         input.data<scalar_t>(), output.data<scalar_t>(),
         nbatch, nplane,
@@ -187,7 +187,7 @@ void reflection_pad1d_backward_out_template(
   /* backprop */
   if (input.ndimension() == 2) {
     AT_DISPATCH_FLOATING_TYPES(
-      grad_input.type(), "reflection_pad1d_backward", [&] {
+      grad_input.scalar_type(), "reflection_pad1d_backward", [&] {
         reflection_pad1d_backward_out_frame(
           grad_input.data<scalar_t>(), grad_output.data<scalar_t>(),
           nplane,
@@ -197,7 +197,7 @@ void reflection_pad1d_backward_out_template(
     );
   } else {
     AT_DISPATCH_FLOATING_TYPES(
-      grad_input.type(), "reflection_pad1d_backward", [&] {
+      grad_input.scalar_type(), "reflection_pad1d_backward", [&] {
         reflection_pad1d_backward_out_loop(
           grad_input.data<scalar_t>(),
           grad_output.data<scalar_t>(),
@@ -322,7 +322,7 @@ void reflection_pad2d_out_template(
   if (input.ndimension() == 3) {
     /* resize output */
     output.resize_({nplane, output_h, output_w});
-    AT_DISPATCH_FLOATING_TYPES(input.type(), "reflection_pad2d", [&] {
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "reflection_pad2d", [&] {
       reflection_pad2d_out_frame(
         input.data<scalar_t>(), output.data<scalar_t>(),
         nplane,
@@ -332,7 +332,7 @@ void reflection_pad2d_out_template(
   } else {
     /* resize output */
     output.resize_({nbatch, nplane, output_h, output_w});
-    AT_DISPATCH_FLOATING_TYPES(input.type(), "reflection_pad2d", [&] {
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "reflection_pad2d", [&] {
       reflection_pad2d_out_loop(
         input.data<scalar_t>(), output.data<scalar_t>(),
         nbatch, nplane,
@@ -448,7 +448,7 @@ void reflection_pad2d_backward_out_template(
   /* backprop */
   if (input.ndimension() == 3) {
     AT_DISPATCH_FLOATING_TYPES(
-      grad_output.type(), "reflection_pad2d_backward", [&] {
+      grad_output.scalar_type(), "reflection_pad2d_backward", [&] {
         reflection_pad2d_backward_out_frame(
           grad_input.data<scalar_t>(), grad_output.data<scalar_t>(),
           nplane,
@@ -458,7 +458,7 @@ void reflection_pad2d_backward_out_template(
     );
   } else {
     AT_DISPATCH_FLOATING_TYPES(
-      grad_output.type(), "reflection_pad2d_backward", [&] {
+      grad_output.scalar_type(), "reflection_pad2d_backward", [&] {
         reflection_pad2d_backward_out_loop(
           grad_input.data<scalar_t>(), grad_output.data<scalar_t>(),
           nbatch, nplane,
diff --git a/aten/src/ATen/native/ReplicationPadding.cpp b/aten/src/ATen/native/ReplicationPadding.cpp
index 969dda5..0431cd0 100644
@@ -97,7 +97,7 @@ void replication_pad1d_out_cpu_template(
   if (input.ndimension() == 2)
   {
     output.resize_({nslices, owidth});
-    AT_DISPATCH_FLOATING_TYPES(input.type(), "replication_pad1d", [&] {
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "replication_pad1d_cpu", [&] {
       auto input_data = input.data<scalar_t>();
       auto output_data = output.data<scalar_t>();
       replication_pad1d_out_frame<scalar_t>(
@@ -113,7 +113,7 @@ void replication_pad1d_out_cpu_template(
   else
   {
     output.resize_({nbatch, nslices, owidth});
-    AT_DISPATCH_FLOATING_TYPES(input.type(), "replication_pad1d", [&] {
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "replication_pad1d_cpu", [&] {
       auto input_data = input.data<scalar_t>();
       auto output_data = output.data<scalar_t>();
       replication_pad1d_out_batch<scalar_t>(
@@ -219,7 +219,7 @@ Tensor& replication_pad1d_backward_out_cpu_template(
   if (input.ndimension() == 2)
   {
     AT_DISPATCH_FLOATING_TYPES(
-      input.type(), "replication_pad1d_backward", [&] {
+      input.scalar_type(), "replication_pad1d_backward_cpu", [&] {
       scalar_t *gradInput_data = gradInput.data<scalar_t>();
       scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
 
@@ -236,7 +236,7 @@ Tensor& replication_pad1d_backward_out_cpu_template(
   else
   {
     AT_DISPATCH_FLOATING_TYPES(
-      input.type(), "replication_pad1d_backward", [&] {
+      input.scalar_type(), "replication_pad1d_backward_cpu", [&] {
       scalar_t *gradInput_data = gradInput.data<scalar_t>();
       scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
 
@@ -365,7 +365,7 @@ void replication_pad2d_out_cpu_template(Tensor& output,
   if (input.dim() == 3)
   {
     output.resize_({nslices, oheight, owidth});
-    AT_DISPATCH_FLOATING_TYPES(input.type(), "replication_pad2d", [&] {
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "replication_pad2d_cpu", [&] {
       auto input_data = input.data<scalar_t>();
       auto output_data = output.data<scalar_t>();
       replication_pad2d_out_frame<scalar_t> (input_data, output_data,
@@ -380,7 +380,7 @@ void replication_pad2d_out_cpu_template(Tensor& output,
   else
   {
     output.resize_({nbatch, nslices, oheight, owidth});
-    AT_DISPATCH_FLOATING_TYPES(input.type(), "replication_pad2d", [&] {
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "replication_pad2d_cpu", [&] {
       auto input_data = input.data<scalar_t>();
       auto output_data = output.data<scalar_t>();
       replication_pad2d_out_batch<scalar_t> (input_data, output_data,
@@ -511,7 +511,7 @@ Tensor& replication_pad2d_backward_out_cpu_template(
   if (input.dim() == 3)
   {
     AT_DISPATCH_FLOATING_TYPES(
-      input.type(), "replication_pad2d_backward", [&] {
+      input.scalar_type(), "replication_pad2d_backward_cpu", [&] {
       replication_pad2d_backward_out_frame<scalar_t>(
         gradInput.data<scalar_t>(),
         gradOutput.data<scalar_t>(),
@@ -526,7 +526,7 @@ Tensor& replication_pad2d_backward_out_cpu_template(
   else
   {
     AT_DISPATCH_FLOATING_TYPES(
-      input.type(), "replication_pad2d_backward", [&] {
+      input.scalar_type(), "replication_pad2d_backward_cpu", [&] {
       replication_pad2d_backward_out_batch<scalar_t>(
         gradInput.data<scalar_t>(),
         gradOutput.data<scalar_t>(),
@@ -709,7 +709,7 @@ void replication_pad3d_out_cpu_template(
   if (input.dim() == 4)
   {
     output.resize_({nslices, odepth, oheight, owidth});
-    AT_DISPATCH_FLOATING_TYPES(input.type(), "replication_pad3d", [&] {
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "replication_pad3d_cpu", [&] {
       auto input_data = input.data<scalar_t>();
       auto output_data = output.data<scalar_t>();
       replication_pad3d_out_frame<scalar_t>(
@@ -722,7 +722,7 @@ void replication_pad3d_out_cpu_template(
   else
   {
     output.resize_({nbatch, nslices, odepth, oheight, owidth});
-    AT_DISPATCH_FLOATING_TYPES(input.type(), "replication_pad3d", [&] {
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "replication_pad3d_cpu", [&] {
       auto input_data = input.data<scalar_t>();
       auto output_data = output.data<scalar_t>();
       replication_pad3d_out_batch<scalar_t>(
@@ -871,7 +871,7 @@ Tensor& replication_pad3d_backward_out_cpu_template(
   if (input.dim() == 4)
   {
     AT_DISPATCH_FLOATING_TYPES(
-      input.type(), "replication_pad3d_backward", [&] {
+      input.scalar_type(), "replication_pad3d_backward_cpu", [&] {
       replication_pad3d_backward_out_frame<scalar_t> (
         gradInput.data<scalar_t>(),
         gradOutput.data<scalar_t>(),
@@ -887,7 +887,7 @@ Tensor& replication_pad3d_backward_out_cpu_template(
   else
   {
     AT_DISPATCH_FLOATING_TYPES(
-      input.type(), "replication_pad3d_backward", [&] {
+      input.scalar_type(), "replication_pad3d_backward_cpu", [&] {
       replication_pad3d_backward_out_batch<scalar_t> (
         gradInput.data<scalar_t>(),
         gradOutput.data<scalar_t>(),
diff --git a/aten/src/ATen/native/Scalar.cpp b/aten/src/ATen/native/Scalar.cpp
index c635c9d..94b8df9 100644
@@ -19,7 +19,7 @@ Scalar item(const Tensor& self) {
 Scalar _local_scalar_dense_cpu(const Tensor& self) {
   Scalar r;
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(
-    at::ScalarType::Half, self.type(), "_local_scalar_dense_cpu", [&] {
+    at::ScalarType::Half, self.scalar_type(), "_local_scalar_dense_cpu", [&] {
         scalar_t value = *self.data<scalar_t>();
         r = Scalar(value);
       });
diff --git a/aten/src/ATen/native/SoftMax.cpp b/aten/src/ATen/native/SoftMax.cpp
index f34c3e1..64d259f 100644
@@ -132,7 +132,7 @@ Tensor softmax_cpu(const Tensor& input_, const int64_t dim_, const bool half_to_
   if (input.ndimension() > 0 && dim == input.ndimension() - 1) {
     softmax_lastdim_kernel(kCPU, output, input);
   } else {
-    AT_DISPATCH_FLOATING_TYPES(input.type(), "softmax", [&] {
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "softmax", [&] {
       host_softmax<scalar_t, false>(output, input, dim);
     });
   }
@@ -152,7 +152,7 @@ Tensor log_softmax_cpu(const Tensor& input_, const int64_t dim_, const bool half
   if (input.ndimension() > 0 && dim == input.ndimension() - 1) {
     log_softmax_lastdim_kernel(kCPU, output, input);
   } else {
-    AT_DISPATCH_FLOATING_TYPES(input.type(), "log_softmax", [&] {
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "log_softmax", [&] {
       host_softmax<scalar_t, true>(output, input, dim);
     });
   }
@@ -181,7 +181,7 @@ Tensor softmax_backward_cpu(
   if (grad.ndimension() > 0 && dim == grad.ndimension() - 1) {
     softmax_backward_lastdim_kernel(kCPU, grad_input, grad, output);
   } else {
-    AT_DISPATCH_FLOATING_TYPES(grad.type(), "softmax_backward", [&] {
+    AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "softmax_backward", [&] {
       host_softmax_backward<scalar_t, false>(grad_input, grad, output, dim);
     });
   }
@@ -210,7 +210,7 @@ Tensor log_softmax_backward_cpu(
   if (grad.ndimension() > 0 && dim == grad.ndimension() - 1) {
     log_softmax_backward_lastdim_kernel(kCPU, grad_input, grad, output);
   } else {
-    AT_DISPATCH_FLOATING_TYPES(grad.type(), "log_softmax_backward", [&] {
+    AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "log_softmax_backward", [&] {
       host_softmax_backward<scalar_t, true>(grad_input, grad, output, dim);
     });
   }
diff --git a/aten/src/ATen/native/Sorting.cpp b/aten/src/ATen/native/Sorting.cpp
index 4bc3c2c..9f52b2a 100644
@@ -153,7 +153,7 @@ std::tuple<Tensor&, Tensor&> kthvalue_out_cpu(
   }
   auto tmp_values = self.clone();
   auto tmp_indices = at::empty(self.sizes(), self.options().dtype(kLong));
-  AT_DISPATCH_ALL_TYPES(self.type(), "kthvalue", [&] {
+  AT_DISPATCH_ALL_TYPES(self.scalar_type(), "kthvalue_cpu", [&] {
     dim_apply(
         {tmp_values, tmp_indices, values, indices},
         dim,
diff --git a/aten/src/ATen/native/SummaryOps.cpp b/aten/src/ATen/native/SummaryOps.cpp
index 10d7b7b..976dcd7 100644
@@ -55,7 +55,7 @@ Tensor _bincount_cpu_template(
 
 Tensor
 _bincount_cpu(const Tensor& self, const Tensor& weights, int64_t minlength) {
-  return AT_DISPATCH_INTEGRAL_TYPES(self.type(), "bincount", [&] {
+  return AT_DISPATCH_INTEGRAL_TYPES(self.scalar_type(), "bincount_cpu", [&] {
     const auto scalar = weights.scalar_type();
     if (scalar == ScalarType::Undefined || scalar == ScalarType::Float)
       return _bincount_cpu_template<scalar_t, float>(self.contiguous(), weights.contiguous(), minlength);
diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp
index e59988b..c5360b5 100644
@@ -91,7 +91,7 @@ Tensor where(const Tensor& condition, const Tensor& self, const Tensor& other) {
 
 Tensor _s_where_cpu(const Tensor& condition, const Tensor& self, const Tensor& other) {
   Tensor ret = at::empty(self.sizes(), self.options());
-  AT_DISPATCH_ALL_TYPES(ret.type(), "where", [&] {
+  AT_DISPATCH_ALL_TYPES(ret.scalar_type(), "where_cpu", [&] {
     where_cpu<scalar_t>(ret, condition, self, other);
   });
   return ret;
diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp
index 1e712ad..b5d6ec8 100644
@@ -184,7 +184,7 @@ Tensor& eye_out_cpu(Tensor& result, int64_t n, int64_t m) {
   result.zero_();
 
   int64_t sz = std::min<int64_t>(n, m);
-  AT_DISPATCH_ALL_TYPES(result.type(), "eye", [&]() -> void {
+  AT_DISPATCH_ALL_TYPES(result.scalar_type(), "eye", [&]() -> void {
     scalar_t* result_data = result.data<scalar_t>();
     for(int64_t i = 0; i < sz; i++) {
       result_data[i*(result.strides()[0] + result.strides()[1])] = 1;
@@ -453,7 +453,7 @@ Tensor& randperm_out_cpu(Tensor& result, int64_t n, Generator* generator) {
   AT_CHECK(n >= 0, "n must be non-negative, got", n);
   result.resize_({n});
   auto gen = get_generator(generator);
-  AT_DISPATCH_ALL_TYPES(result.type(), "randperm", [&]() -> void {
+  AT_DISPATCH_ALL_TYPES(result.scalar_type(), "randperm", [&]() -> void {
     randperm_cpu<scalar_t>(result, n, gen);
   });
 
@@ -501,7 +501,7 @@ Tensor tril_indices_cpu(
   //
   // 3. sequential RAM + transpose: create an n X 2 Tensor, fill the Tensor
   //    sequentially, and then transpose it.
-  AT_DISPATCH_ALL_TYPES(result.type(), "tril_indices", [&]() -> void {
+  AT_DISPATCH_ALL_TYPES(result.scalar_type(), "tril_indices", [&]() -> void {
     // fill the Tensor with correct values
     scalar_t* result_data = result.data<scalar_t>();
     int64_t i = 0;
@@ -534,7 +534,7 @@ Tensor triu_indices_cpu(
   // create an empty Tensor with correct size
   auto result = at::empty({2, triu_size}, options);
 
-  AT_DISPATCH_ALL_TYPES(result.type(), "triu_indices", [&]() -> void {
+  AT_DISPATCH_ALL_TYPES(result.scalar_type(), "triu_indices", [&]() -> void {
     // fill the Tensor with correct values
     scalar_t* result_data = result.data<scalar_t>();
     int64_t i = 0;
@@ -705,7 +705,7 @@ template <typename T>
 Tensor tensor_cpu(ArrayRef<T> values, const TensorOptions& options) {
   auto result = at::empty(values.size(), options);
   AT_ASSERT(result.is_contiguous());
-  AT_DISPATCH_ALL_TYPES(result.type(), "tensor_cpu", [&] {
+  AT_DISPATCH_ALL_TYPES(result.scalar_type(), "tensor_cpu", [&] {
     std::copy(values.begin(), values.end(), result.template data<scalar_t>());
   });
   return result;
diff --git a/aten/src/ATen/native/TensorIterator.h b/aten/src/ATen/native/TensorIterator.h
index 7d79d66..274b5e9 100644
@@ -151,7 +151,7 @@ struct CAFFE2_API TensorIterator {
     AT_ASSERT(operands_[arg].type);
     return *operands_[arg].type;
   }
-  ScalarType dtype(int arg) const { return type(arg).scalarType(); }
+  ScalarType dtype(int arg=0) const { return type(arg).scalarType(); }
   DeviceType device_type(int arg=0) const { return type(arg).device_type(); }
   int64_t element_size(int arg) const { return type(arg).elementSizeInBytes(); }
   bool is_scalar(int arg) const;
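
Giving dtype() a default argument of 0 lets element-wise kernels query the iterator's dtype without naming an operand index; the cpu kernel changes further down (threshold, add, mul, div) rely on exactly this. A minimal sketch of the resulting call-site pattern, assuming operand 0 is the output as is conventional for TensorIterator:

    // Dispatch on the iterator's dtype instead of going through an at::Type.
    AT_DISPATCH_ALL_TYPES(iter.dtype(), "threshold_cpu", [&] {
      // scalar_t is the C++ type matching iter.dtype()
      /* ... vectorized element-wise loop ... */
    });
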
diff --git a/aten/src/ATen/native/TensorTransformations.cpp b/aten/src/ATen/native/TensorTransformations.cpp
index 9f6f0b4..fcb3d0a 100644
@@ -60,7 +60,7 @@ Tensor flip_cpu(const Tensor& self, IntArrayRef dims) {
     }
   }
 
-  AT_DISPATCH_ALL_TYPES(in_tensor.type(), "flip_cpu", [&] {
+  AT_DISPATCH_ALL_TYPES(in_tensor.scalar_type(), "flip_cpu", [&] {
     flip_cpu_kernel<scalar_t>(
       total_dims,
       stride_contiguous_v,
index 398c599..9ed7648 100644 (file)
@@ -25,7 +25,7 @@ bool is_signed(const Tensor &self) {
   if (self.scalar_type() == ScalarType::Half) {
     return true;
   }
-  return AT_DISPATCH_ALL_TYPES(self.type(), "is_signed", [&]() -> bool {
+  return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "is_signed", [&]() -> bool {
     return std::is_signed<scalar_t>();
   });
 }
index 8f6cfae..8cc867f 100644 (file)
@@ -127,14 +127,14 @@ std::tuple<Tensor, Tensor> _unique_dim_cpu_template(
 
 std::tuple<Tensor, Tensor>
 _unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse) {
-  return AT_DISPATCH_ALL_TYPES(self.type(), "unique", [&] {
+  return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_cpu", [&] {
     return _unique_cpu_template<scalar_t>(self, sorted, return_inverse);
   });
 }
 
 std::tuple<Tensor, Tensor>
 _unique_dim_cpu(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) {
-  return AT_DISPATCH_ALL_TYPES(self.type(), "unique_dim", [&] {
+  return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
     // The current implementation using `dim` always sorts due to unhashable tensors
     return _unique_dim_cpu_template<scalar_t>(self, dim, return_inverse);
   });
index 1d6e75c..a2e7bd3 100644 (file)
@@ -9,7 +9,7 @@ namespace at { namespace native {
 namespace {
 
 static void threshold_kernel(TensorIterator& iter, Scalar threshold_scalar, Scalar value_scalar) {
-  AT_DISPATCH_ALL_TYPES(iter.type(), "threshold", [&] {
+  AT_DISPATCH_ALL_TYPES(iter.dtype(), "threshold_cpu", [&] {
     using Vec = Vec256<scalar_t>;
     scalar_t threshold = threshold_scalar.to<scalar_t>();
     scalar_t value = value_scalar.to<scalar_t>();
index 2ba95f7..431a6dc 100644 (file)
@@ -14,7 +14,7 @@ namespace {
 using namespace vec256;
 
 void add_kernel(TensorIterator& iter, Scalar alpha_scalar) {
-  AT_DISPATCH_ALL_TYPES(iter.type(), "add", [&]() {
+  AT_DISPATCH_ALL_TYPES(iter.dtype(), "add_cpu", [&]() {
     auto alpha = alpha_scalar.to<scalar_t>();
     auto alpha_vec = Vec256<scalar_t>(alpha);
     binary_kernel_vec(iter,
@@ -30,7 +30,7 @@ void sub_kernel(TensorIterator& iter, Scalar alpha_scalar) {
 }
 
 void mul_kernel(TensorIterator& iter) {
-  AT_DISPATCH_ALL_TYPES(iter.type(), "mul", [&]() {
+  AT_DISPATCH_ALL_TYPES(iter.dtype(), "mul_cpu", [&]() {
     binary_kernel_vec(iter,
       [=](scalar_t a, scalar_t b) -> scalar_t { return a * b; },
       [=](Vec256<scalar_t> a, Vec256<scalar_t> b) {
@@ -40,16 +40,16 @@ void mul_kernel(TensorIterator& iter) {
 }
 
 void div_kernel(TensorIterator& iter) {
-  if (isIntegralType(iter.type().scalarType())) {
+  if (isIntegralType(iter.dtype())) {
     // There's no SIMD integer division, so don't try to vectorize it.
     // TODO: if the divisor is a scalar, rewrite as multiplication by a constant.
-    AT_DISPATCH_INTEGRAL_TYPES(iter.type(), "div", [&]() {
+    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "div_cpu", [&]() {
       binary_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
         return a / b;
       });
     });
   } else {
-    AT_DISPATCH_FLOATING_TYPES(iter.type(), "div", [&]() {
+    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "div_cpu", [&]() {
       binary_kernel_vec(iter,
         [=](scalar_t a, scalar_t b) __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
            return a / b;
index 03ef8df..d8cf81d 100644 (file)
@@ -15,7 +15,7 @@ constexpr int64_t COPY_GRAIN_SIZE = 20000;
 
 static void copy_kernel_impl(Tensor& dst, const Tensor& src) {
   AT_DISPATCH_ALL_TYPES_AND(
-    at::ScalarType::Half, dst.type(), "copy_kernel_impl", [&]() {
+    at::ScalarType::Half, dst.scalar_type(), "copy_kernel_impl", [&]() {
       scalar_t* self_ptr = dst.data<scalar_t>();
       scalar_t* src_ptr = src.data<scalar_t>();
 
index f9ff60c..237d907 100644 (file)
@@ -270,19 +270,19 @@ struct PDist {
 };
 
 void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, const double p) {
-  AT_DISPATCH_FLOATING_TYPES(self.type(), "pdist", [&] {
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist", [&] {
     PDist<scalar_t>::apply(result, self, p);
   });
 }
 
 static void pdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor& self, const double p, const Tensor& dist) {
-  AT_DISPATCH_FLOATING_TYPES(self.type(), "pdist_backward", [&] {
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist_backward", [&] {
     PDist<scalar_t>::apply_backward(result, grad, self, p, dist);
   });
 }
 
 static void cdist_kernel_impl(Tensor& result, const Tensor& x1, const Tensor& x2, const double p) {
-  AT_DISPATCH_FLOATING_TYPES(result.type(), "cdist", [&] {
+  AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "cdist", [&] {
     PDist<scalar_t>::apply_cdist(result, x1, x2, p);
   });
 }
index 7faff9b..d96589d 100644 (file)
@@ -308,8 +308,8 @@ static inline void
 mask_scatter_add(const scalar_t *src, scalar_t* base_addr,
                  const int_same_size_t<scalar_t> *offsets,
                  const int_same_size_t<scalar_t> *mask, int64_t len) {
-  #ifndef _MSC_VER  
-  # pragma unroll  
+  #ifndef _MSC_VER
+  # pragma unroll
   #endif
   for (int64_t i = 0; i < len; i++) {
     if (mask[i] & 0x01) {
@@ -431,8 +431,8 @@ struct ApplyGridSample<scalar_t, 2, GridSamplerInterpolation::Bilinear, padding>
     auto i_sw_offset = i_nw_offset + iVec(inp_sH);
     auto i_se_offset = i_sw_offset + iVec(inp_sW);
 
-    #ifndef _MSC_VER  
-    # pragma unroll  
+    #ifndef _MSC_VER
+    # pragma unroll
     #endif
     for (int64_t c = 0; c < C; ++c) {
       auto inp_slice_C_ptr = inp_slice[c].data();
@@ -505,8 +505,8 @@ struct ApplyGridSample<scalar_t, 2, GridSamplerInterpolation::Bilinear, padding>
     scalar_t gInp_corner_arr[Vec::size()];
 
     auto gx = Vec(0), gy = Vec(0);
-    #ifndef _MSC_VER  
-    # pragma unroll  
+    #ifndef _MSC_VER
+    # pragma unroll
     #endif
     for (int64_t c = 0; c < C; ++c) {
       auto inp_slice_C_ptr = inp_slice[c].data();
@@ -598,8 +598,8 @@ struct ApplyGridSample<scalar_t, 2, GridSamplerInterpolation::Nearest, padding>
     auto out_ptr = out_slice.data() + offset;
     auto out_sC = out_slice.stride(0);
     auto inp_slice_ptr = inp_slice.data();
-    #ifndef _MSC_VER  
-    # pragma unroll  
+    #ifndef _MSC_VER
+    # pragma unroll
     #endif
     for (int c = 0; c < C; ++c, out_ptr += out_sC, inp_slice_ptr += inp_sC) {
       // mask_gather zeros out the mask, so we need to make a copy
@@ -635,8 +635,8 @@ struct ApplyGridSample<scalar_t, 2, GridSamplerInterpolation::Nearest, padding>
     integer_t gInp_offset_arr[iVec::size()];
     i_gInp_offset.store(gInp_offset_arr);
 
-    #ifndef _MSC_VER  
-    # pragma unroll  
+    #ifndef _MSC_VER
+    # pragma unroll
     #endif
     for (int64_t c = 0; c < C; ++c) {
       mask_scatter_add(gOut_slice[c].data() + offset, gInp_slice[c].data(),
@@ -743,15 +743,15 @@ static inline void grid_sample_2d_grid_slice_iterator(
     auto spatial_offset = 0;
     auto i_offsets_delta = iVec(grid_sW * step);
 
-    #ifndef _MSC_VER  
-    # pragma unroll  
+    #ifndef _MSC_VER
+    # pragma unroll
     #endif
     for (int64_t h = 0; h < out_H; h++) {
       auto grid_ptr_x = grid_ptr + h * grid_sH;
       auto grid_ptr_y = grid_ptr_x + grid_sCoor;
       auto i_offsets = iVec::arange(0, grid_sW);
-      #ifndef _MSC_VER  
-      # pragma unroll  
+      #ifndef _MSC_VER
+      # pragma unroll
       #endif
       for (int64_t w = 0; w < out_W; w += step) {
         auto len = std::min(step, out_W - w);
@@ -815,7 +815,7 @@ Tensor grid_sampler_2d_cpu_kernel_impl(const Tensor& input, const Tensor& grid,
     return;                                                            \
   }
 
-  AT_DISPATCH_FLOATING_TYPES(input.type(), "grid_sampler_2d_cpu_kernel_impl", [&] {
+  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler_2d_cpu_kernel_impl", [&] {
     auto out_acc = output.accessor<scalar_t, 4>();
     auto inp_acc = input.accessor<scalar_t, 4>();
     auto grid_acc = grid.accessor<scalar_t, 4>();
@@ -878,7 +878,7 @@ grid_sampler_2d_backward_cpu_kernel_impl(const Tensor& grad_output_,
     return;                                                            \
   }
 
-  AT_DISPATCH_FLOATING_TYPES(input.type(), "grid_sampler_2d_backward_cpu_kernel_impl", [&] {
+  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler_2d_backward_cpu_kernel_impl", [&] {
     auto gInp_acc = grad_input.accessor<scalar_t, 4>();
     auto gGrid_acc = grad_grid.accessor<scalar_t, 4>();
     auto inp_acc = input.accessor<scalar_t, 4>();
index 0d93112..698b40b 100644 (file)
@@ -92,7 +92,7 @@ void cpu_index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef
 }
 
 void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(0), "index", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "index_cpu", [&] {
     cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
       *(scalar_t*)dst = *(scalar_t*)(src + offset);
     });
@@ -101,7 +101,7 @@ void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef inde
 
 void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate) {
   // NOTE: duplicate indices are only supported if accumulate is true.
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(0), "index_put", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "index_put", [&] {
     if (accumulate) {
       // TODO: investigate parallelization of the accumulate kernel. Unlike the non-accumulate case,
       // this needs to be thread-safe.
index 2f78b59..b895636 100644 (file)
@@ -15,7 +15,7 @@ namespace at { namespace native { namespace {
 using namespace vec256;
 
 static void sum_kernel_impl(TensorIterator& iter) {
-  AT_DISPATCH_ALL_TYPES(iter.type(), "sum", [&] {
+  AT_DISPATCH_ALL_TYPES(iter.dtype(), "sum_cpu", [&] {
     binary_kernel_reduce_vec(
       iter,
       [=](scalar_t a, scalar_t b) -> scalar_t { return a + b; },
@@ -24,7 +24,7 @@ static void sum_kernel_impl(TensorIterator& iter) {
 }
 
 static void mean_kernel_impl(TensorIterator& iter) {
-  AT_DISPATCH_ALL_TYPES(iter.type(), "mean", [&] {
+  AT_DISPATCH_ALL_TYPES(iter.dtype(), "mean_cpu", [&] {
     scalar_t factor = scalar_t(iter.num_output_elements()) / iter.numel();
     binary_kernel_reduce(
       iter,
@@ -35,7 +35,7 @@ static void mean_kernel_impl(TensorIterator& iter) {
 }
 
 static void std_var_kernel_impl(TensorIterator &iter, bool unbiased, bool take_sqrt) {
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.type(), "std", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "std_cpu", [&] {
     binary_kernel_reduce(
       iter,
       WelfordOps<scalar_t, double, int64_t, double> { unbiased, take_sqrt },
@@ -45,7 +45,7 @@ static void std_var_kernel_impl(TensorIterator &iter, bool unbiased, bool take_s
 }
 
 static void prod_kernel_impl(TensorIterator& iter) {
-  AT_DISPATCH_ALL_TYPES(iter.type(), "prod", [&] {
+  AT_DISPATCH_ALL_TYPES(iter.dtype(), "prod_cpu", [&] {
     binary_kernel_reduce_vec(
       iter,
       [=](scalar_t a, scalar_t b) -> scalar_t { return a * b; },
@@ -68,7 +68,7 @@ static void norm_kernel_tensor_iterator_impl(
 
 
   if (val == 0) {
-    AT_DISPATCH_FLOATING_TYPES(iter.type(), "norm", [&] {
+    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "norm_cpu", [&] {
       binary_kernel_reduce(
         iter,
         NormZeroOps<scalar_t>(),
@@ -76,7 +76,7 @@ static void norm_kernel_tensor_iterator_impl(
       );
     });
   } else if (val == 1) {
-    AT_DISPATCH_FLOATING_TYPES(iter.type(), "norm", [&] {
+    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "norm_cpu", [&] {
       binary_kernel_reduce(
         iter,
         NormOneOps<scalar_t>(),
@@ -84,7 +84,7 @@ static void norm_kernel_tensor_iterator_impl(
       );
     });
   } else if (val == INFINITY) {
-    AT_DISPATCH_FLOATING_TYPES(iter.type(), "norm", [&] {
+    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "norm_cpu", [&] {
       binary_kernel_reduce(
         iter,
         AbsMaxOps<scalar_t>(),
@@ -92,7 +92,7 @@ static void norm_kernel_tensor_iterator_impl(
       );
     });
   } else if (val == -INFINITY) {
-    AT_DISPATCH_FLOATING_TYPES(iter.type(), "norm", [&] {
+    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "norm_cpu", [&] {
       binary_kernel_reduce(
         iter,
         AbsMinOps<scalar_t>(),
@@ -100,7 +100,7 @@ static void norm_kernel_tensor_iterator_impl(
       );
     });
   } else {
-    AT_DISPATCH_FLOATING_TYPES(iter.type(), "norm", [&] {
+    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "norm_cpu", [&] {
       binary_kernel_reduce(
         iter,
         NormOps<scalar_t> { scalar_t(val) },
@@ -149,7 +149,7 @@ static void or_kernel_impl(TensorIterator& iter) {
 }
 
 static void min_values_kernel_impl(TensorIterator& iter) {
-  AT_DISPATCH_ALL_TYPES(iter.type(), "min_values", [&iter] {
+  AT_DISPATCH_ALL_TYPES(iter.dtype(), "min_values_cpu", [&iter] {
     binary_kernel_reduce_vec(
       iter,
       [](scalar_t a, scalar_t b) -> scalar_t { return std::min(a, b); },
@@ -158,7 +158,7 @@ static void min_values_kernel_impl(TensorIterator& iter) {
 }
 
 static void max_values_kernel_impl(TensorIterator& iter) {
-  AT_DISPATCH_ALL_TYPES(iter.type(), "min_values", [&iter] {
+  AT_DISPATCH_ALL_TYPES(iter.dtype(), "max_values_cpu", [&iter] {
     binary_kernel_reduce_vec(
       iter,
       [](scalar_t a, scalar_t b) -> scalar_t { return std::max(a, b); },
index 83f0232..d838b86 100644 (file)
@@ -218,7 +218,7 @@ struct vec_host_softmax_backward_lastdim {
 };
 
 static void softmax_lastdim_kernel_impl(Tensor& result, const Tensor& self) {
-  AT_DISPATCH_FLOATING_TYPES(self.type(), "softmax_lastdim_kernel_impl", [&] {
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "softmax_lastdim_kernel_impl", [&] {
     vec_host_softmax_lastdim<scalar_t, false>::apply(result, self);
   });
 }
@@ -227,7 +227,7 @@ static void log_softmax_lastdim_kernel_impl(
     Tensor& result,
     const Tensor& self) {
   AT_DISPATCH_FLOATING_TYPES(
-      self.type(), "log_softmax_lastdim_kernel_impl", [&] {
+      self.scalar_type(), "log_softmax_lastdim_kernel_impl", [&] {
         vec_host_softmax_lastdim<scalar_t, true>::apply(result, self);
       });
 }
@@ -237,7 +237,7 @@ static void softmax_backward_lastdim_kernel_impl(
     const Tensor& grad,
     const Tensor& output) {
   AT_DISPATCH_FLOATING_TYPES(
-      grad.type(), "softmax_backward_lastdim_kernel_impl", [&] {
+      grad.scalar_type(), "softmax_backward_lastdim_kernel_impl", [&] {
         vec_host_softmax_backward_lastdim<scalar_t, false>::apply(
             grad_input, grad, output);
       });
@@ -248,7 +248,7 @@ static void log_softmax_backward_lastdim_kernel_impl(
     const Tensor& grad,
     const Tensor& output) {
   AT_DISPATCH_FLOATING_TYPES(
-      grad.type(), "log_softmax_backward_lastdim_kernel_impl", [&] {
+      grad.scalar_type(), "log_softmax_backward_lastdim_kernel_impl", [&] {
         vec_host_softmax_backward_lastdim<scalar_t, true>::apply(
             grad_input, grad, output);
       });
index 5dc8603..d118e59 100644 (file)
@@ -97,7 +97,7 @@ static void max_kernel_impl(
     Tensor& max_indices,
     const Tensor& self,
     c10::optional<int64_t> dim) {
-  AT_DISPATCH_ALL_TYPES(self.type(), "max", [&] {
+  AT_DISPATCH_ALL_TYPES(self.scalar_type(), "max", [&] {
     Reduction<scalar_t, int64_t>::apply(max, max_indices, self, dim, true);
   });
 }
@@ -107,7 +107,7 @@ static void min_kernel_impl(
     Tensor& min_indices,
     const Tensor& self,
     c10::optional<int64_t> dim) {
-  AT_DISPATCH_ALL_TYPES(self.type(), "min", [&] {
+  AT_DISPATCH_ALL_TYPES(self.scalar_type(), "min", [&] {
     Reduction<scalar_t, int64_t>::apply(min, min_indices, self, dim, false);
   });
 }
index 31263fd..aaa566c 100644 (file)
@@ -81,7 +81,7 @@ int64_t _sigmoid(double* x, double* y, int64_t size) {
 }
 
 static void sigmoid_kernel(Tensor& result, const Tensor& self) {
-  AT_DISPATCH_FLOATING_TYPES(self.type(), "sigmoid", [&] {
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "sigmoid", [&] {
     using Vec = Vec256<scalar_t>;
     CPU_tensor_parallel_kernel_apply2<scalar_t, scalar_t>(
         result,
@@ -133,7 +133,7 @@ void bernoulli_mkl_kernel(Tensor &self, const double p, Generator* gen) {
   int64_t n = self.numel();
   bool contig = self.is_contiguous();
 
-  AT_DISPATCH_ALL_TYPES(self.type(), "bernoulli_scalar_cpu_", [&] {
+  AT_DISPATCH_ALL_TYPES(self.scalar_type(), "bernoulli_scalar_cpu_", [&] {
     at::Tensor tmp_int_tensor;
     if (std::is_same<scalar_t, int>::value && contig) {
       tmp_int_tensor = self;
@@ -177,7 +177,7 @@ void bernoulli_mkl_kernel(Tensor &self, const double p, Generator* gen) {
 #define IMPLEMENT_FLOAT_KERNEL(dispatchtypes, op)                          \
   static void op##_kernel(Tensor& result, const Tensor& self) {            \
     checkBackend(#op, {result}, Backend::CPU);                             \
-    AT_DISPATCH_##dispatchtypes##_TYPES(self.type(), #op, [&] {            \
+    AT_DISPATCH_##dispatchtypes##_TYPES(self.scalar_type(), #op, [&] {     \
       if (self.is_contiguous() && result.is_contiguous()) {                \
         vml::v##op(                                                        \
             result.data<scalar_t>(), self.data<scalar_t>(), self.numel()); \
index 83e879a..e0c2148 100644 (file)
@@ -62,7 +62,7 @@ Tensor prelu_cuda(const Tensor& self, const Tensor& weight_) {
 
   // case1: shared weight for all channels
   if (weight_num == 1) {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "prelu_cuda", [&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "prelu_cuda", [&] {
       prelu_cuda_kernel_share_weights<scalar_t>(
         input,
         result,
@@ -94,7 +94,7 @@ Tensor prelu_cuda(const Tensor& self, const Tensor& weight_) {
     cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
     AT_CHECK(cuda::getApplyGrid(input_numel, grid, curDevice), "prelu: input too large or too many dimensions");
 
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "prelu_cuda", [&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "prelu_cuda", [&] {
       prelu_cuda_kernel_multi_weights<scalar_t>
       <<<grid, block, 0, stream>>>(
         result.data<scalar_t>(),
@@ -175,7 +175,7 @@ std::tuple<Tensor, Tensor> prelu_backward_cuda(const Tensor& grad_out_, const Te
   Tensor weight_grad_collector = at::empty_like(input);
   // case1: shared parameter for all channels
   if (weight_num == 1) {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "prelu_backward_cuda", [&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "prelu_backward_cuda", [&] {
       prelu_cuda_backward_kernel_share_weights<scalar_t>(
         input,
         grad_out,
@@ -210,7 +210,7 @@ std::tuple<Tensor, Tensor> prelu_backward_cuda(const Tensor& grad_out_, const Te
     cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
     AT_CHECK(cuda::getApplyGrid(input_numel, grid, curDevice), "prelu_backward_cuda: input too large or too many dimensions");
 
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "prelu_backward_cuda", [&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "prelu_backward_cuda", [&] {
       prelu_cuda_backward_kernel_multi_weights<scalar_t>
       <<<grid, block, 0, stream>>>(
         input.data<scalar_t>(),
@@ -264,7 +264,7 @@ void hardshrink_backward_cuda_kernel(const Tensor& self, Tensor& out_tensor, sca
 
 Tensor hardshrink_cuda(const Tensor & self, Scalar lambd) {
   auto out_tensor = at::empty_like(self);
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "hardshrink_cuda", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "hardshrink_cuda", [&] {
     hardshrink_cuda_kernel<scalar_t>(self, out_tensor, lambd.to<scalar_t>());
   });
   return out_tensor;
@@ -272,7 +272,7 @@ Tensor hardshrink_cuda(const Tensor & self, Scalar lambd) {
 
 Tensor hardshrink_backward_cuda(const Tensor & grad, const Tensor & self, Scalar lambd) {
   auto out_tensor = at::empty_like(grad);
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "hardshrink_backward_cuda", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "hardshrink_backward_cuda", [&] {
     hardshrink_backward_cuda_kernel<scalar_t>(self, out_tensor, lambd.to<scalar_t>(), grad);
   });
   return out_tensor;
@@ -286,7 +286,7 @@ void threshold_kernel_impl(TensorIterator& iter, scalar_t threshold, scalar_t va
 }
 
 static void threshold_kernel(TensorIterator& iter, Scalar threshold, Scalar value) {
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(), "threshold", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "threshold_cuda", [&] {
     threshold_kernel_impl<scalar_t>(iter, threshold.to<scalar_t>(), value.to<scalar_t>());
   });
 }
index 5828248..7211aa3 100644 (file)
@@ -244,7 +244,7 @@ namespace {
        output.resize_({sizeD, osizeH, osizeW});
     }
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-        input_.type(), "adaptive_avg_pool2d", [&] {
+        input_.scalar_type(), "adaptive_avg_pool2d_cuda", [&] {
           scalar_t *input_data = input_.data<scalar_t>();
           scalar_t *output_data = output.data<scalar_t>();
 
@@ -284,13 +284,13 @@ namespace {
 
     int64_t osizeH = gradOutput.size(-2);
     int64_t osizeW = gradOutput.size(-1);
-    
+
     int64_t grid_x = sizeD;
     if (input.ndimension() == 4) grid_x *= input.size(-4);
 
       //bool atomic = (isizeW%osizeW != 0) || (isizeH%osizeH != 0);
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-        input.type(), "adaptive_avg_pool2d_backward", [&] {
+        input.scalar_type(), "adaptive_avg_pool2d_backward_cuda", [&] {
           scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
           scalar_t *gradInput_data = gradInput.data<scalar_t>();
 
index 1f0184b..291e0b5 100644 (file)
@@ -260,13 +260,13 @@ std::tuple<Tensor, Tensor> _gesv_helper_cuda(const Tensor& self, const Tensor& A
   auto self_working_copy = cloneBatchedColumnMajor(self);
   auto A_working_copy = cloneBatchedColumnMajor(A);
   std::vector<int64_t> infos(batchCount(self), 0);
-  AT_DISPATCH_FLOATING_TYPES(self.type(), "gesv", [&]{
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "gesv_cuda", [&]{
     apply_gesv<scalar_t>(self_working_copy, A_working_copy, infos);
   });
   if (self.dim() > 2) {
-    batchCheckErrors(infos, "gesv");
+    batchCheckErrors(infos, "gesv_cuda");
   } else {
-    singleCheckErrors(infos[0], "gesv");
+    singleCheckErrors(infos[0], "gesv_cuda");
   }
   return std::tuple<Tensor, Tensor>(self_working_copy, A_working_copy);
 }
@@ -327,11 +327,11 @@ Tensor _inverse_helper_cuda(const Tensor& self) {
   std::vector<int64_t> infos(batchCount(self), 0);
   auto self_working_copy = cloneBatchedColumnMajor(self);
   auto self_inv_working_copy = cloneBatchedColumnMajor(self);
-  AT_DISPATCH_FLOATING_TYPES(self.type(), "inverse", [&]{
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "inverse_cuda", [&]{
     apply_inverse<scalar_t>(
       self_working_copy, self_inv_working_copy, infos);
   });
-  batchCheckErrors(infos, "inverse");
+  batchCheckErrors(infos, "inverse_cuda");
   return self_inv_working_copy;
 }
 
@@ -386,7 +386,7 @@ Tensor _cholesky_solve_helper_cuda(const Tensor& self, const Tensor& A, bool upp
   int64_t info = 0;
   auto self_working_copy = cloneBatchedColumnMajor(self);
   auto A_working_copy = cloneBatchedColumnMajor(A);
-  AT_DISPATCH_FLOATING_TYPES(self.type(), "cholesky_solve", [&]{
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "cholesky_solve_cuda", [&]{
     apply_cholesky_solve<scalar_t>(self_working_copy, A_working_copy, upper, info);
   });
   AT_CHECK(info == 0, "MAGMA cholesky_solve : invalid argument: ", -info);
@@ -446,13 +446,13 @@ Tensor _cholesky_helper_cuda(const Tensor& self, bool upper) {
     self_working_copy = cloneBatchedColumnMajor(self);
   }
 
-  AT_DISPATCH_FLOATING_TYPES(self.type(), "cholesky", [&]{
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "cholesky_cuda", [&]{
     apply_cholesky<scalar_t>(self_working_copy, false, infos);
   });
   if (self.dim() > 2) {
-    batchCheckErrors(infos, "cholesky");
+    batchCheckErrors(infos, "cholesky_cuda");
   } else {
-    singleCheckErrors(infos[0], "cholesky");
+    singleCheckErrors(infos[0], "cholesky_cuda");
   }
   if (upper) {
     return self_working_copy.transpose(-1, -2);
@@ -493,7 +493,7 @@ Tensor& triu_tril_cuda_template(Tensor& result, const Tensor& self, int64_t k, c
           self_row_stride = self.stride(-2), self_col_stride = self.stride(-1);
   dim3 dim_block = cuda::getApplyBlock();
   dim3 dim_grid((mat_size + dim_block.x - 1) / dim_block.x, n_batches);
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.type(), name, [&]{
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.scalar_type(), name, [&]{
     triu_tril_kernel<scalar_t, upper>
       <<<dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
         result.data<scalar_t>(), self.data<scalar_t>(), k, mat_size,
index 5cb1212..2b8e338 100644 (file)
@@ -22,7 +22,7 @@ void add_kernel_impl(TensorIterator& iter, Scalar alpha_scalar) {
 }
 
 static void add_kernel_cuda(TensorIterator& iter, Scalar alpha_scalar) {
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(), "add", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "add_cuda", [&]() {
     add_kernel_impl<scalar_t>(iter, alpha_scalar);
   });
 }
@@ -46,21 +46,21 @@ void div_constant_impl(TensorIterator& iter, scalar_t inv_b) {
 }
 
 static void div_kernel_cuda(TensorIterator& iter) {
-  if (isIntegralType(iter.type().scalarType())) {
-    AT_DISPATCH_INTEGRAL_TYPES(iter.type(), "div", [&]() {
+  if (isIntegralType(iter.dtype())) {
+    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "div_cuda", [&]() {
       div_kernel_impl<scalar_t>(iter);
     });
   } else if (iter.is_cpu_scalar(2)) {
     // optimization for floating-point types: if the second operand is a CPU
     // scalar, compute a * reciprocal(b). Note that this may lose one bit of
     // precision compared to computing the division.
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.type(), "div", [&]() {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "div_cuda", [&]() {
       auto inv_b = scalar_t(1.0 / iter.scalar_value<scalar_t>(2));
       iter.remove_operand(2);
       div_constant_impl<scalar_t>(iter, inv_b);
     });
   } else {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.type(), "div", [&]() {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "div_cuda", [&]() {
       div_kernel_impl<scalar_t>(iter);
     });
   }
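
    As a side note on the CPU-scalar branch of div_kernel_cuda above: a toy, host-side illustration (plain C++, not the actual CUDA kernel) of the a * (1/b) rewrite it describes. The reciprocal is rounded once when computed and the product is rounded again, so for some inputs the result can differ from a / b in the last bit, matching the precision caveat in the comment.

    // Hypothetical host-side illustration of dividing via a precomputed reciprocal.
    #include <cstdio>

    int main() {
      float a = 10.0f, b = 3.0f;          // illustrative values only
      float direct = a / b;               // one rounding step
      float inv_b = 1.0f / b;             // rounded here...
      float via_reciprocal = a * inv_b;   // ...and again here; may differ by ~1 ulp
      std::printf("%.9g %.9g\n", direct, via_reciprocal);
      return 0;
    }
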
@@ -74,7 +74,7 @@ void mul_kernel_impl(TensorIterator& iter) {
 }
 
 static void mul_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(), "mul", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "mul_cuda", [&]() {
     mul_kernel_impl<scalar_t>(iter);
   });
 }
index 70b9236..68079ad 100644 (file)
@@ -10,7 +10,7 @@ namespace native {
 Scalar _local_scalar_dense_cuda(const Tensor& self) {
   Scalar r;
   AT_DISPATCH_ALL_TYPES_AND(
-    at::ScalarType::Half, self.type(), "_local_scalar_dense_cuda", [&] {
+    at::ScalarType::Half, self.scalar_type(), "_local_scalar_dense_cuda", [&] {
         scalar_t value;
         cudaStream_t stream = at::cuda::getCurrentCUDAStream();
         AT_CUDA_CHECK(cudaMemcpyAsync(&value, self.data<scalar_t>(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream));
index ab85b48..01c0782 100644 (file)
@@ -169,7 +169,7 @@ void copy_from_cpu(Tensor& dst, const Tensor& src) {
       cudaMemcpyHostToDevice,
       stream));
   AT_CUDA_CHECK(cudaStreamSynchronize(stream));
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.type(), "copy_from_cpu", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "copy_from_cpu", [&]() {
     copy_device_to_device<scalar_t, scalar_t>(dst, dst_contig);
   });
 }
@@ -202,7 +202,7 @@ void copy_from_cpu_async_(Tensor& dst, const Tensor& src) {
   CUDAGuard device_guard(dst.device());
   CUDAStream stream = getCurrentCUDAStream();
 
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.type(), "copy_from_cpu_async", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "copy_from_cpu_async", [&]() {
     AT_CUDA_CHECK(cudaMemcpyAsync(
         dst.data<scalar_t>(),
         src.data<scalar_t>(),
@@ -225,7 +225,7 @@ void copy_to_cpu_async_(Tensor& dst, const Tensor& src) {
   CUDAGuard device_guard(src.device());
   CUDAStream stream = getCurrentCUDAStream();
 
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.type(), "copy_to_cpu_async", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "copy_to_cpu_async", [&]() {
     AT_CUDA_CHECK(cudaMemcpyAsync(
         dst.data<scalar_t>(),
         src.data<scalar_t>(),
@@ -240,7 +240,7 @@ void copy_to_cpu_async_(Tensor& dst, const Tensor& src) {
 template <typename dst_T>
 void _copy__cuda(Tensor& dst, const Tensor& src, bool non_blocking) {
   AT_CHECK(dst.numel() == src.numel(), "sizes do not match");
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.type(), "_copy__cuda", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_copy__cuda", [&]() {
     if (dst.is_cuda() && src.is_cuda()) {
       copy_device_to_device<dst_T, scalar_t>(dst, src);
     } else if (dst.is_cuda()) {
@@ -279,7 +279,7 @@ namespace at {
 namespace native {
 
 Tensor& _s_copy__cuda(Tensor& self, const Tensor& src, bool non_blocking) {
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.type(), "_copy__cuda", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.scalar_type(), "_copy__cuda", [&]() {
     ::_copy__cuda<scalar_t>(self, src, non_blocking);
   });
   return self;
index 37d5c68..8316bde 100644 (file)
@@ -188,7 +188,7 @@ void cdist_kernel_impl(Tensor& result, const Tensor& x1, const Tensor& x2, doubl
   const dim3 grid(r1*r2);
   const dim3 block(forward_threads);
 
-  AT_DISPATCH_FLOATING_TYPES(x1.type(), "cdist_cuda", [&] {
+  AT_DISPATCH_FLOATING_TYPES(x1.scalar_type(), "cdist_cuda", [&] {
     if (p == 0.0) {
       cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::zero><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m);
     } else if (p == 1.0) {
@@ -213,7 +213,7 @@ void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, double p) {
   const double n2 = n - .5;
   const double n2_squared_minus_1 = n2 * n2 - 1;
 
-  AT_DISPATCH_FLOATING_TYPES(self.type(), "pdist_cuda", [&] {
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist_cuda", [&] {
     if (p == 0.0) {
       pdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::zero><<<grid, block>>>(result.data<scalar_t>(), self.data<scalar_t>(), n, m, p, n2, n2_squared_minus_1);
     } else if (p == 1.0) {
@@ -252,7 +252,7 @@ void pdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor
   const double n2_squared_minus_1 = n2 * n2 - 1;
 
   Tensor buffer = at::empty({n - 1, result.size(0), result.size(1)}, result.options());
-  AT_DISPATCH_FLOATING_TYPES(self.type(), "pdist_cuda_backward", [&] {
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist_cuda_backward", [&] {
     if (p == 1.0) {
       pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::one><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), self.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1);
     } else if (p < 2.0) {
index c262384..2176547 100644 (file)
@@ -186,7 +186,7 @@ void bernoulli_scalar_cuda_kernel(
 namespace at { namespace native {
 Tensor _s_poisson_cuda(const Tensor& lambda, Generator* gen) {
   Tensor ret = at::empty(lambda.sizes(), lambda.options());
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(ret.type(), "poisson", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(ret.scalar_type(), "poisson_cuda", [&] {
     poisson_cuda_kernel<scalar_t>(ret, lambda, next_philox_seed(gen, 20));
   });
   return ret;
@@ -194,7 +194,7 @@ Tensor _s_poisson_cuda(const Tensor& lambda, Generator* gen) {
 
 Tensor _s_gamma_cuda(const Tensor& alpha, Generator* gen) {
   Tensor ret = at::empty(alpha.sizes(), alpha.options());
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(ret.type(), "gamma", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(ret.scalar_type(), "gamma_cuda", [&] {
      gamma_cuda_kernel<scalar_t>(ret, alpha, next_philox_seed(gen, 10));
    });
   return ret;
@@ -202,7 +202,7 @@ Tensor _s_gamma_cuda(const Tensor& alpha, Generator* gen) {
 
 Tensor _standard_gamma_grad_cuda(const Tensor& self, const Tensor& output) {
   Tensor ret = at::empty(self.sizes(), self.options());
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "_standard_gamma_grad", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "_standard_gamma_grad_cuda", [&] {
      gamma_grad_cuda_kernel<scalar_t>(ret, self, output);
    });
   return ret;
@@ -211,11 +211,11 @@ Tensor _standard_gamma_grad_cuda(const Tensor& self, const Tensor& output) {
 Tensor& bernoulli_tensor_cuda_(Tensor &self, const Tensor& p_, Generator* gen) {
   auto p = std::get<0>(expand_inplace(self, p_.to(kCUDA)));
   AT_DISPATCH_ALL_TYPES_AND(
-    at::ScalarType::Half, self.type(), "bernoulli_tensor_cuda_self_", [&] {
+    at::ScalarType::Half, self.scalar_type(), "bernoulli_tensor_cuda_self_", [&] {
       const at::Type& p_type = p.type();
       using self_t = scalar_t;
       auto seeds = next_philox_seed(gen, 10);
-      AT_DISPATCH_FLOATING_TYPES_AND_HALF(p.type(), "bernoulli_tensor_cuda_p_", [&] {
+      AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, p.scalar_type(), "bernoulli_tensor_cuda_p_", [&] {
         using p_t = scalar_t;
         return bernoulli_tensor_cuda_kernel<self_t, p_t>(self, p, seeds);
       });
@@ -225,7 +225,7 @@ Tensor& bernoulli_tensor_cuda_(Tensor &self, const Tensor& p_, Generator* gen) {
 
 Tensor& bernoulli_scalar_cuda_(Tensor &self, double p, Generator* gen) {
   AT_CHECK(0 <= p && p <= 1, "bernoulli_ expects p to be in [0, 1], but got p=", p);
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.type(), "bernoulli_scalar_cuda_", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.scalar_type(), "bernoulli_scalar_cuda_", [&] {
     auto seeds = next_philox_seed(gen, 10);
     bernoulli_scalar_cuda_kernel<scalar_t>(self, p, seeds);
    });
index 609a691..16d0b71 100644 (file)
@@ -108,7 +108,7 @@ fused_dropout_cuda(const Tensor& self, double p, Generator * gen){
 //number of times random will be generated per thread, to offset philox counter in thc random state
   int64_t counter_offset = ((nelem - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL;
   if (cuda::detail::canUse32BitIndexMath(self)){
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "fused_dropout", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "fused_dropout", [&] {
       using accscalar_t = acc_type<scalar_t, true>;
       accscalar_t pa = (accscalar_t)(p);
       auto self_info = cuda::detail::getTensorInfo<scalar_t, unsigned int>(self);
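
    A worked example of the counter_offset formula in the hunk above: it is a ceiling division of nelem by the number of elements one "round" of the launch covers (block_size * grid.x * UNROLL), multiplied back by UNROLL. The concrete sizes below are made up for illustration and are not taken from the patch.

    // Hypothetical sizes; the formula itself is copied from the hunk above.
    #include <cstdint>
    #include <cstdio>

    int main() {
      const int64_t nelem = 100000;     // illustrative tensor size
      const int64_t block_size = 256;   // illustrative launch configuration
      const int64_t grid_x = 64;
      const int64_t UNROLL = 4;
      // ceil(100000 / 65536) = 2, times UNROLL = 8 random values per thread.
      int64_t counter_offset =
          ((nelem - 1) / (block_size * grid_x * UNROLL) + 1) * UNROLL;
      std::printf("%lld\n", (long long)counter_offset);  // prints 8
      return 0;
    }
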
@@ -126,7 +126,7 @@ fused_dropout_cuda(const Tensor& self, double p, Generator * gen){
       }
    });
   } else {
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "fused_dropout", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "fused_dropout", [&] {
       using accscalar_t = acc_type<scalar_t, true>;
       accscalar_t pa = (accscalar_t)(p);
       auto self_info = cuda::detail::getTensorInfo<scalar_t, uint64_t>(self);
@@ -151,7 +151,7 @@ fused_dropout_cuda(const Tensor& self, double p, Generator * gen){
 Tensor masked_scale_cuda(const Tensor& self, const Tensor& mask, double scale){
    Tensor ret = at::empty_like(self);
    AT_CHECK(mask.scalar_type() == at::ScalarType::Byte, "mask should be torch.uint8 dtype");
-   AT_DISPATCH_FLOATING_TYPES_AND_HALF(ret.type(), "masked_scale", [&] {
+   AT_DISPATCH_FLOATING_TYPES_AND_HALF(ret.scalar_type(), "masked_scale", [&] {
       using accscalar_t = acc_type<scalar_t, true>;
       accscalar_t pa = (accscalar_t)(scale);
     masked_scale_kernel<scalar_t>(ret, self, mask, pa);
index 44fe18b..8a24923 100644 (file)
@@ -243,7 +243,7 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
     dim3 block(WARP_SIZE, BLOCKDIMY);
 
     AT_DISPATCH_FLOATING_TYPES_AND_HALF
-      (grad.type(),
+      (grad.scalar_type(),
        "embedding_backward",
        [&]
        {
@@ -326,7 +326,7 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
   dim3 grid(THCCeilDiv(num_indices, (int64_t) 4), THCCeilDiv(stride, (int64_t) 128));
   dim3 block(32, 4);
 
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "embedding_backward", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.scalar_type(), "embedding_backward", [&] {
     embedding_backward_kernel<<<grid, block, 0, stream>>>(
       sorted_indices.data<int64_t>(),
       orig_indices.data<int64_t>(),
@@ -371,7 +371,7 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices,
   dim3 block(128);
   int dim = self.stride(0);
 
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "embedding_backward", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "embedding_backward", [&] {
     using accscalar_t = acc_type<scalar_t, true>;
     renorm_kernel<<<grid, block, 128 * sizeof(accscalar_t), stream>>>(
       self.data<scalar_t>(),
index 4721611..dea987e 100644 (file)
@@ -248,7 +248,7 @@ Tensor embedding_bag_backward_cuda_sum_avg(
   dim3 grid(THCCeilDiv(numel, (ptrdiff_t)4), THCCeilDiv(stride, (int64_t)128));
   dim3 block(32, 4);
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      grad.type(), "embedding_bag_backward_cuda_sum_avg_kernel", [&] {
+      grad.scalar_type(), "embedding_bag_backward_cuda_sum_avg_kernel", [&] {
         EmbeddingBag_accGradParametersKernel_sum_avg<
             scalar_t><<<grid, block, 0, stream>>>(
             sorted_indices.data<int64_t>(), orig_indices.data<int64_t>(),
@@ -304,7 +304,7 @@ Tensor embedding_bag_backward_cuda_max(const Tensor &grad,
   int grid = 1024;
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      grad.type(), "embedding_bag_backward_cuda_max", [&] {
+      grad.scalar_type(), "embedding_bag_backward_cuda_max", [&] {
         EmbeddingBag_accGradParametersKernel_max<
             scalar_t><<<grid, block, 0, stream>>>(
             max_indices.data<int64_t>(), grad.data<scalar_t>(),
@@ -353,7 +353,7 @@ _embedding_bag_cuda(const Tensor &weight, const Tensor &indices,
 
   dim3 block = dim3(32, 8);
   int grid = 1024;
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(weight.type(), "embedding_bag_cuda", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(weight.scalar_type(), "embedding_bag_cuda", [&] {
     EmbeddingBag_updateOutputKernel<scalar_t><<<grid, block, 0, stream>>>(
         indices.data<int64_t>(), offsets.data<int64_t>(),
         weight.data<scalar_t>(), output.data<scalar_t>(),
index e8f6939..7a09514 100644 (file)
@@ -195,7 +195,7 @@ void fractional_max_pool2d_out_cuda_template(
             input_.size(0));
   dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize);
 
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(),
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(),
     "fractional_max_pool2d_out_cuda_frame",
     [&] {
       auto devInput = input_.packed_accessor<scalar_t, 4>();
@@ -267,7 +267,7 @@ void fractional_max_pool2d_backward_out_cuda_template(
   dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize);
 
   auto devIndices = indices.packed_accessor<int64_t, 4>();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(gradOutput.type(),
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(gradOutput.scalar_type(),
     "fractional_max_pool2d_backward_out_cuda_frame",
     [&] {
       auto devGradInput = gradInput_.packed_accessor<scalar_t, 4>();
index 79015dd..95f9b1a 100644 (file)
@@ -231,7 +231,7 @@ void fractional_max_pool3d_out_cuda_template(
     dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize);
 
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      input.type(),
+      input.scalar_type(),
       "fractional_max_pool3d_out_frame",
       [&]{
         fractional_max_pool3d_out_frame<scalar_t>
@@ -321,7 +321,7 @@ void fractional_max_pool3d_backward_out_cuda_template(
     dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize);
 
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      gradOutput.type(),
+      gradOutput.scalar_type(),
       "fractional_max_pool3d_backward_out_frame",
       [&] {
         fractional_max_pool3d_backward_out_frame<scalar_t>
index 706cd44..54cf637 100644 (file)
@@ -791,7 +791,7 @@ Tensor grid_sampler_2d_cuda(const Tensor& input, const Tensor& grid,
   auto output = at::empty({N, input.size(1), H, W}, input.options());
   int count = static_cast<int>(N * H * W);
   if (count > 0) {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "grid_sampler_2d_cuda", [&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_2d_cuda", [&] {
       grid_sampler_2d_kernel<scalar_t>
         <<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
           count,
@@ -815,7 +815,7 @@ Tensor grid_sampler_3d_cuda(const Tensor& input, const Tensor& grid,
   auto output = at::empty({N, input.size(1), D, H, W}, input.options());
   int count = static_cast<int>(N * D * H * W);
   if (count > 0) {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "grid_sampler_2d_cuda", [&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_2d_cuda", [&] {
       grid_sampler_3d_kernel<scalar_t>
         <<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
           count,
@@ -840,7 +840,7 @@ grid_sampler_2d_backward_cuda(const Tensor& grad_output, const Tensor& input, co
   auto grad_grid = at::empty_like(grid);
   int count = static_cast<int>(N * H * W);
   if (count > 0) {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "grid_sampler_2d_backward_cuda", [&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_2d_backward_cuda", [&] {
       grid_sampler_2d_backward_kernel<scalar_t>
         <<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
           count,
@@ -868,7 +868,7 @@ grid_sampler_3d_backward_cuda(const Tensor& grad_output, const Tensor& input, co
   auto grad_grid = at::empty_like(grid);
   int count = static_cast<int>(N * D * H * W);
   if (count > 0) {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "grid_sampler_3d_backward_cuda", [&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_3d_backward_cuda", [&] {
       grid_sampler_3d_backward_kernel<scalar_t>
         <<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
           count,
index be69cde..d498179 100644 (file)
@@ -81,7 +81,7 @@ void index_put_kernel_impl(TensorIterator& iter, IntArrayRef index_size, IntArra
 }
 
 static void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(), "index", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "index_cuda", [&] {
     using dtype = OpaqueType<sizeof(scalar_t)>;
     index_kernel_impl<dtype>(iter, index_size, index_stride);
   });
@@ -90,7 +90,7 @@ static void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayR
 
 static void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate) {
   AT_ASSERTM(!accumulate, "index_put does not support accumulate=true");
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(), "index_put", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "index_put", [&] {
     using dtype = OpaqueType<sizeof(scalar_t)>;
     index_put_kernel_impl<dtype>(iter, index_size, index_stride);
   });
index 35a46f1..1946427 100644 (file)
@@ -38,9 +38,9 @@ Tensor& lerp_cuda_tensor_out(Tensor& result, const Tensor& self,
   Tensor b_self, b_end, b_weight;
   AT_CHECK(weight.dim() <= std::max(self.dim(), end.dim()),
            "weight should be of dimension max(self.dim(), end.dim()) or lesser");
-  std::tie(b_self, b_end, b_weight) = expand_outplace(self, end, weight, "lerp_out");
+  std::tie(b_self, b_end, b_weight) = expand_outplace(self, end, weight, "lerp_out_cuda");
   result.resize_as_(b_self);
-  AT_DISPATCH_FLOATING_TYPES(self.type(), "lerp_out", [&]{
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lerp_out_cuda", [&]{
     lerp_cuda<scalar_t>(result, b_self, b_end, b_weight);
   });
   return result;
@@ -49,9 +49,9 @@ Tensor& lerp_cuda_tensor_out(Tensor& result, const Tensor& self,
 Tensor& lerp_cuda_scalar_out(Tensor& result, const Tensor& self,
                             const Tensor& end, Scalar weight) {
   Tensor b_self, b_end;
-  std::tie(b_self, b_end) = expand_outplace(self, end, "lerp_out");
+  std::tie(b_self, b_end) = expand_outplace(self, end, "lerp_out_cuda");
   result.resize_as_(b_self);
-  AT_DISPATCH_FLOATING_TYPES(self.type(), "lerp_out", [&]{
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lerp_out_cuda", [&]{
     lerp_cuda<scalar_t>(result, b_self, b_end, weight.to<scalar_t>());
   });
   return result;
@@ -59,13 +59,13 @@ Tensor& lerp_cuda_scalar_out(Tensor& result, const Tensor& self,
 
 Tensor& lerp_cuda_tensor_(Tensor& self, const Tensor& end, const Tensor& weight) {
   Tensor b_self, b_end, b_weight;
-  std::tie(b_self, b_end, b_weight) = expand_outplace(self, end, weight, "lerp_");
+  std::tie(b_self, b_end, b_weight) = expand_outplace(self, end, weight, "lerp__cuda");
   AT_CHECK(b_self.sizes() == self.sizes(),
            "output with shape ", self.sizes(),
            " doesn't match the broadcast shape ", b_self.sizes());
   AT_CHECK(weight.dim() <= std::max(self.dim(), end.dim()),
            "weight should be of dimension max(self.dim(), end.dim()) or lesser");
-  AT_DISPATCH_FLOATING_TYPES(self.type(), "lerp_", [&]{
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lerp__cuda", [&]{
     lerp_cuda<scalar_t>(self, b_self, b_end, b_weight);
   });
   return self;
@@ -73,11 +73,11 @@ Tensor& lerp_cuda_tensor_(Tensor& self, const Tensor& end, const Tensor& weight)
 
 Tensor& lerp_cuda_scalar_(Tensor& self, const Tensor& end, Scalar weight) {
   Tensor b_self, b_end;
-  std::tie(b_self, b_end) = expand_outplace(self, end, "lerp_");
+  std::tie(b_self, b_end) = expand_outplace(self, end, "lerp__cuda");
   AT_CHECK(b_self.sizes() == self.sizes(),
            "output with shape ", self.sizes(),
            " doesn't match the broadcast shape ", b_self.sizes());
-  AT_DISPATCH_FLOATING_TYPES(self.type(), "lerp_", [&]{
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lerp__cuda", [&]{
     lerp_cuda<scalar_t>(self, b_self, b_end, weight.to<scalar_t>());
   });
   return self;
@@ -87,9 +87,9 @@ Tensor lerp_cuda_tensor(const Tensor& self, const Tensor& end, const Tensor& wei
   Tensor b_self, b_end, b_weight;
   AT_CHECK(weight.dim() <= std::max(self.dim(), end.dim()),
            "weight should be of dimension max(self.dim(), end.dim()) or lesser");
-  std::tie(b_self, b_end, b_weight) = expand_outplace(self, end, weight, "lerp");
+  std::tie(b_self, b_end, b_weight) = expand_outplace(self, end, weight, "lerp_cuda");
   Tensor result = at::empty_like(b_self);
-  AT_DISPATCH_FLOATING_TYPES(result.type(), "lerp", [&]{
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lerp_cuda", [&]{
     lerp_cuda<scalar_t>(result, b_self, b_end, b_weight);
   });
   return result;
@@ -97,9 +97,9 @@ Tensor lerp_cuda_tensor(const Tensor& self, const Tensor& end, const Tensor& wei
 
 Tensor lerp_cuda_scalar(const Tensor& self, const Tensor& end, Scalar weight) {
   Tensor b_self, b_end;
-  std::tie(b_self, b_end) = expand_outplace(self, end, "lerp");
+  std::tie(b_self, b_end) = expand_outplace(self, end, "lerp_cuda");
   Tensor result = at::empty_like(b_self);
-  AT_DISPATCH_FLOATING_TYPES(result.type(), "lerp", [&]{
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lerp_cuda", [&]{
     lerp_cuda<scalar_t>(result, b_self, b_end, weight.to<scalar_t>());
   });
   return result;
index 50678c1..b60de0f 100644 (file)
@@ -28,7 +28,7 @@ namespace at { namespace native {
 Tensor kl_div_backward_cuda(const Tensor& grad, const Tensor& input, const Tensor& target, int64_t reduction) {
   auto grad_input = at::zeros_like(input);
   Tensor grad_expand = grad.expand_as(input);
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "kl_div_backward", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "kl_div_backward_cuda", [&]() {
     kl_div_backward_kernel<scalar_t>(grad_input, target, grad_expand);
   });
   if (reduction == Reduction::Mean) {
index 646a541..547dd6c 100644 (file)
@@ -102,8 +102,7 @@ ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
     bool have_three;            // flag which of the two cases in eq (6) we have
     if (s < 2*target_length+1) {
       current_char = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
-      have_three = ((s > 1) && (get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s-2, BLANK) !=
-                               current_char));
+      have_three = ((s > 1) && (get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s-2, BLANK) != current_char));
     } else {
       current_char = BLANK;
       have_three = false;
@@ -631,7 +630,7 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
 
 std::tuple<Tensor, Tensor> ctc_loss_gpu(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK, bool zero_infinity) {
   (void)zero_infinity; // only used for backward
-  return AT_DISPATCH_FLOATING_TYPES(log_probs.type(), "ctc_loss", [&] {
+  return AT_DISPATCH_FLOATING_TYPES(log_probs.scalar_type(), "ctc_loss_cuda", [&] {
       if (targets.scalar_type() == kLong) {
        return ctc_loss_gpu_template<scalar_t, kLong>(log_probs, targets, input_lengths, target_lengths, BLANK);
       } else {
@@ -642,7 +641,7 @@ std::tuple<Tensor, Tensor> ctc_loss_gpu(const Tensor& log_probs, const Tensor& t
 
 Tensor ctc_loss_backward_gpu(const Tensor& grad, const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths,
                              const Tensor& neg_log_likelihood, const Tensor& log_alpha, int64_t BLANK, bool zero_infinity) {
-  return AT_DISPATCH_FLOATING_TYPES(log_probs.type(), "ctc_loss_backward", [&] {
+  return AT_DISPATCH_FLOATING_TYPES(log_probs.scalar_type(), "ctc_loss_backward_cuda", [&] {
       if (targets.scalar_type() == kLong) {
        return ctc_loss_backward_gpu_template<scalar_t, kLong>(grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, BLANK, zero_infinity);
       } else {
index d77236e..8f92acd 100644 (file)
@@ -4,7 +4,7 @@ namespace at { namespace native {
 
 std::tuple<Tensor, Tensor, Tensor> batch_norm_cuda(const Tensor& self, const Tensor& weight, const Tensor& bias,
                                                    const Tensor& running_mean, const Tensor& running_var, bool train, double momentum, double epsilon) {
-  return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm", [&] {
+  return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "batch_norm_cuda", [&] {
       if (cuda::detail::canUse32BitIndexMath(self)) {
         return batch_norm_cuda_template<scalar_t, int32_t>(self, weight, bias, running_mean, running_var, train, momentum, epsilon);
       } else {
@@ -15,7 +15,7 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_cuda(const Tensor& self, const Ten
 
 std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda(const Tensor& grad_out, const Tensor& self, const Tensor& weight, const Tensor& running_mean, const Tensor& running_var,
                                                             const Tensor& save_mean, const Tensor& save_invstd, bool train, double epsilon, std::array<bool,3> grad_input_mask) {
-  return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_backward", [&] {
+  return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "batch_norm_backward_cuda", [&] {
       if (cuda::detail::canUse32BitIndexMath(self)) {
         return batch_norm_backward_cuda_template<scalar_t, int32_t>(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask);
       } else {
@@ -25,7 +25,7 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda(const Tensor& grad_o
 }
 
 std::tuple<Tensor, Tensor> batch_norm_stats_cuda(const Tensor& self, double epsilon) {
-  return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_stats", [&] {
+  return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "batch_norm_stats_cuda", [&] {
       if (cuda::detail::canUse32BitIndexMath(self)) {
         return batch_norm_stats_cuda_template<scalar_t, int32_t>(self, epsilon);
       } else {
@@ -36,7 +36,7 @@ std::tuple<Tensor, Tensor> batch_norm_stats_cuda(const Tensor& self, double epsi
 
 Tensor batch_norm_elemt_cuda(const Tensor& self, const Tensor& weight, const Tensor& bias,
                              const Tensor& mean, const Tensor& invstd, double epsilon) {
-  return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_elemt", [&] {
+  return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "batch_norm_elemt", [&] {
       if (cuda::detail::canUse32BitIndexMath(self)) {
         return batch_norm_elemt_cuda_template<scalar_t, int32_t>(self, weight, bias, mean, invstd, epsilon);
       } else {
@@ -48,7 +48,7 @@ Tensor batch_norm_elemt_cuda(const Tensor& self, const Tensor& weight, const Ten
 // accepting input(self) here to determine template data types, since running_mean/running_var are optional
 std::tuple<Tensor, Tensor> batch_norm_gather_stats_cuda(const Tensor& self, const Tensor& mean, const Tensor& invstd, const Tensor& running_mean,
                                                         const Tensor& running_var, double momentum, double epsilon, int64_t count) {
-  return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_update_stats", [&] {
+  return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "batch_norm_update_stats_cuda", [&] {
       int world_size = mean.size(1);
       using accscalar_t = at::acc_type<scalar_t, true>;
       if (cuda::detail::canUse32BitIndexMath(self)) {
@@ -61,7 +61,7 @@ std::tuple<Tensor, Tensor> batch_norm_gather_stats_cuda(const Tensor& self, cons
 
 std::tuple<Tensor, Tensor, Tensor, Tensor> batch_norm_backward_reduce_cuda(const Tensor& self, const Tensor& input, const Tensor& mean,
                                                                            const Tensor& invstd, bool input_g, bool weight_g, bool bias_g) {
-  return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_backward_reduce", [&] {
+  return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "batch_norm_backward_reduce", [&] {
       if (cuda::detail::canUse32BitIndexMath(self)) {
         return batch_norm_backward_reduce_cuda_template<scalar_t, int32_t>(self, input, mean, invstd, input_g, weight_g, bias_g);
       } else {
@@ -72,7 +72,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> batch_norm_backward_reduce_cuda(const
 
 Tensor batch_norm_backward_elemt_cuda(const Tensor& self, const Tensor& input, const Tensor& mean, const Tensor& invstd,
                                       const Tensor& weight, const Tensor& mean_dy, const Tensor& mean_dy_xmu) {
-  return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_backward_elemt", [&] {
+  return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "batch_norm_backward_elemt", [&] {
       if (cuda::detail::canUse32BitIndexMath(self)) {
         return batch_norm_backward_elemt_cuda_template<scalar_t, int32_t>(self, input, mean, invstd, weight, mean_dy, mean_dy_xmu);
       } else {
@@ -83,7 +83,7 @@ Tensor batch_norm_backward_elemt_cuda(const Tensor& self, const Tensor& input, c
 
 std::tuple<Tensor, Tensor> batch_norm_update_stats_cuda(
         const Tensor& self, const Tensor& running_mean, const Tensor& running_var, double momentum) {
-  return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_backward", [&] {
+  return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "batch_norm_backward", [&] {
       auto mean_st = running_mean.dtype();
       auto var_st = running_var.dtype();
       AT_CHECK(mean_st == var_st, "running_mean and running_var need to have the same data types");
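
Every hunk on this page follows the same shape, so one illustrative sketch of the new calling convention may help when reading the diff outside an IDE. This is a made-up CPU example, not code from the patch; it assumes only what the change itself establishes, namely that the AT_DISPATCH_* macros now take an at::ScalarType (as returned by Tensor::scalar_type()) where they previously took an at::Type.

    // Hypothetical caller (illustration only): dispatch keys off the ScalarType
    // enum rather than the Type object.
    #include <ATen/ATen.h>
    #include <ATen/Dispatch.h>

    at::Tensor double_elements(const at::Tensor& self_) {
      auto self = self_.contiguous();
      auto out = at::empty_like(self);
      AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "double_elements", [&] {
        // scalar_t is bound to the C++ type matching self's ScalarType
        // (double, float, or at::Half here).
        const scalar_t* src = self.data<scalar_t>();
        scalar_t* dst = out.data<scalar_t>();
        for (int64_t i = 0; i < self.numel(); ++i) {
          dst[i] = static_cast<scalar_t>(src[i] + src[i]);  // doubles each element
        }
      });
      return out;
    }
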
index a510060..67b3a7b 100644 (file)
@@ -501,7 +501,7 @@ std::tuple<Tensor, Tensor, Tensor> _thnn_fused_lstm_cell_cuda(
   auto workspace = at::empty_like(input_gates);
   auto hy = at::empty_like(cx);
   auto cy = at::empty_like(cx);
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_gates.type(), "_thnn_fused_lstm_cell_cuda", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_gates.scalar_type(), "_thnn_fused_lstm_cell_cuda", [&] {
     if (canUse32BitIndexMath(workspace)) { // See Note [64-bit index math check elision]
       lstm_forward_impl<scalar_t, int32_t>(input_gates, hidden_gates, input_bias, hidden_bias, cx, hy, cy, workspace);
     } else {
@@ -540,7 +540,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _thnn_fused_lstm_cell_backwar
 
   auto grad_gates = at::empty_like(workspace);
   auto grad_cx = at::empty_like(cx);
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(workspace.type(), "_thnn_fused_lstm_cell_cuda_backward", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(workspace.scalar_type(), "_thnn_fused_lstm_cell_cuda_backward", [&] {
     if (canUse32BitIndexMath(workspace)) { // See Note [64-bit index math check elision]
       lstm_backward_impl<scalar_t, int32_t>(grad_hy, grad_cy, cx, cy, workspace, grad_gates, grad_cx);
     } else {
@@ -565,7 +565,7 @@ std::tuple<Tensor, Tensor> _thnn_fused_gru_cell_cuda(
 
   auto workspace = at::empty({hx.size(0), hx.size(1) * GRU_WORKSPACE_MULTIPLIER}, hx.options());
   auto hy = at::empty_like(hx);
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_gates.type(), "_thnn_fused_gru_cell_cuda", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_gates.scalar_type(), "_thnn_fused_gru_cell_cuda", [&] {
     if (canUse32BitIndexMath(workspace)) { // See Note [64-bit index math check elision]
       gru_forward_impl<scalar_t, int32_t>(input_gates, hidden_gates, input_bias, hidden_bias, hx, hy, workspace);
     } else {
@@ -589,7 +589,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _thnn_fused_gru_cell_backward
   auto grad_input_gates = at::empty({workspace.size(0), hidden_size * 3}, workspace.options());
   auto grad_hidden_gates = at::empty({workspace.size(0), hidden_size * 3}, workspace.options());
   auto grad_hx = at::empty_like(grad_hy);
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_hy.type(), "_thnn_fused_gru_cell_cuda_backward", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_hy.scalar_type(), "_thnn_fused_gru_cell_cuda_backward", [&] {
     if (canUse32BitIndexMath(workspace)) { // See Note [64-bit index math check elision]
       gru_backward_impl<scalar_t, int32_t>(grad_hy, workspace, grad_input_gates, grad_hidden_gates, grad_hx);
     } else {
index 227fd81..5bd7d86 100644 (file)
@@ -51,7 +51,7 @@ Tensor& linspace_cuda_out(Tensor& result, Scalar start, Scalar end, int64_t step
   } else if (steps == 1) {
     r.fill_(start);
   } else {
-    AT_DISPATCH_FLOATING_TYPES(r.type(), "linspace", [&]() {
+    AT_DISPATCH_FLOATING_TYPES(r.scalar_type(), "linspace_cuda", [&]() {
       scalar_t scalar_start = start.to<scalar_t>();
       scalar_t scalar_end = end.to<scalar_t>();
       scalar_t step = (scalar_end - scalar_start) / static_cast<scalar_t>(steps - 1);
@@ -81,7 +81,7 @@ Tensor& logspace_cuda_out(Tensor& result, Scalar start, Scalar end, int64_t step
   } else if (steps == 1) {
     r.fill_(std::pow(10.0, start.to<double>()));
   } else {
-    AT_DISPATCH_FLOATING_TYPES(r.type(), "logspace", [&]() {
+    AT_DISPATCH_FLOATING_TYPES(r.scalar_type(), "logspace_cuda", [&]() {
       scalar_t scalar_start = start.to<scalar_t>();
       scalar_t scalar_end = end.to<scalar_t>();
       scalar_t step = (scalar_end - scalar_start) / static_cast<scalar_t>(steps - 1);
@@ -99,7 +99,7 @@ Tensor& logspace_cuda_out(Tensor& result, Scalar start, Scalar end, int64_t step
 }
 
 Tensor& range_cuda_out(Tensor& result, Scalar start, Scalar end, Scalar step) {
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, result.type(), "range", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, result.scalar_type(), "range_cuda", [&]() {
     using accscalar_t = at::acc_type<scalar_t, true>;
     auto xstart = start.to<accscalar_t>();
     auto xend = end.to<accscalar_t>();
@@ -130,7 +130,7 @@ Tensor& range_cuda_out(Tensor& result, Scalar start, Scalar end, Scalar step) {
 }
 
 Tensor& arange_cuda_out(Tensor& result, Scalar start, Scalar end, Scalar step) {
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, result.type(), "arange", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, result.scalar_type(), "arange_cuda", [&]() {
     using accscalar_t = at::acc_type<scalar_t, true>;
     auto xstart = start.to<accscalar_t>();
     auto xend = end.to<accscalar_t>();
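
A note on the renamed name strings ("range" to "range_cuda", "arange" to "arange_cuda", and so on): the string argument only surfaces in the error raised when the switch over scalar types falls through, so the suffix tells you which backend's kernel rejected the dtype. The following is a hand-simplified approximation of what such a macro expands to after this change; it is not the literal contents of Dispatch.h, just enough structure to show where the name ends up.

    // Rough approximation of an AT_DISPATCH_*-style macro (illustration only).
    // The caller's lambda is pasted into each case, so it sees the local
    // `using scalar_t = ...` alias, as in the real Dispatch.h.
    #define MY_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)                        \
      [&] {                                                                    \
        const at::ScalarType _st = TYPE;                                       \
        switch (_st) {                                                         \
          case at::ScalarType::Double: {                                       \
            using scalar_t = double;                                           \
            return __VA_ARGS__();                                              \
          }                                                                    \
          case at::ScalarType::Float: {                                        \
            using scalar_t = float;                                            \
            return __VA_ARGS__();                                              \
          }                                                                    \
          default:                                                             \
            /* NAME is only user-visible here, hence "range_cuda" etc. */      \
            AT_ERROR(NAME, " not implemented for '", toString(_st), "'");      \
        }                                                                      \
      }()

Invoked on a Long tensor, such a dispatch would raise something like "range_cuda not implemented for 'Long'", which now identifies the CUDA kernel rather than an ambiguous "range".
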
index 3a47902..715411d 100644 (file)
@@ -41,7 +41,7 @@ void prod_kernel_impl(TensorIterator& iter) {
 }
 
 static void std_var_kernel_cuda(TensorIterator& iter, bool unbiased, bool take_sqrt) {
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.type(), "std", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "std", [&]() {
     std_var_kernel_impl<scalar_t>(iter, unbiased, take_sqrt);
   });
 }
@@ -77,46 +77,46 @@ void norm_kernel_cuda_impl(TensorIterator& iter, Scalar val) {
 }
 
 static void sum_kernel_cuda(TensorIterator& iter) {
-  if (iter.type().scalarType() == kHalf) {
+  if (iter.dtype() == kHalf) {
     return sum_kernel_impl<at::Half, float>(iter);
-  } else if (iter.type(1).scalarType() == kHalf && iter.type().scalarType() == kFloat) {
+  } else if (iter.dtype(1) == kHalf && iter.dtype() == kFloat) {
     // type promotion that does cast and reduction in a single kernel
     return sum_kernel_impl<at::Half, float, float>(iter);
   }
-  AT_DISPATCH_ALL_TYPES(iter.type(), "sum", [&]() {
+  AT_DISPATCH_ALL_TYPES(iter.dtype(), "sum_cuda", [&]() {
     sum_kernel_impl<scalar_t>(iter);
   });
 }
 
 static void prod_kernel_cuda(TensorIterator& iter) {
-  if (iter.type().scalarType() == kHalf) {
+  if (iter.dtype() == kHalf) {
     return prod_kernel_impl<at::Half, float>(iter);
   }
-  AT_DISPATCH_ALL_TYPES(iter.type(), "prod", [&]() {
+  AT_DISPATCH_ALL_TYPES(iter.dtype(), "prod_cuda", [&]() {
     prod_kernel_impl<scalar_t>(iter);
   });
 }
 
 static void mean_kernel_cuda(TensorIterator& iter) {
-  if (iter.type().scalarType() == kHalf) {
+  if (iter.dtype() == kHalf) {
     return mean_kernel_impl<at::Half, float>(iter);
-  } else if (iter.type(1).scalarType() == kHalf && iter.type().scalarType() == kFloat) {
+  } else if (iter.dtype(1) == kHalf && iter.dtype() == kFloat) {
     // type promotion that does cast and reduction in a single kernel
     return mean_kernel_impl<at::Half, float, float>(iter);
   }
-  AT_DISPATCH_ALL_TYPES(iter.type(), "mean", [&]() {
+  AT_DISPATCH_ALL_TYPES(iter.dtype(), "mean_cuda", [&]() {
     mean_kernel_impl<scalar_t>(iter);
   });
 }
 
 static void norm_kernel_cuda(TensorIterator& iter, Scalar p) {
-  if (iter.type().scalarType() == kHalf) {
+  if (iter.dtype() == kHalf) {
     return norm_kernel_cuda_impl<at::Half, float>(iter, p);
-  } else if (iter.type(1).scalarType() == kHalf && iter.type().scalarType() == kFloat) {
+  } else if (iter.dtype(1) == kHalf && iter.dtype() == kFloat) {
     // type promotion that does cast and reduction in a single kernel
     return norm_kernel_cuda_impl<at::Half, float, float>(iter, p);
   }
-  AT_DISPATCH_FLOATING_TYPES(iter.type(), "norm", [&]() {
+  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "norm_cuda", [&]() {
     norm_kernel_cuda_impl<scalar_t>(iter, p);
   });
 }
@@ -152,13 +152,13 @@ void min_values_kernel_cuda_impl(TensorIterator& iter) {
 }
 
 void max_values_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_ALL_TYPES(iter.type(), "max_values", [&]() {
+  AT_DISPATCH_ALL_TYPES(iter.dtype(), "max_values_cuda", [&]() {
     max_values_kernel_cuda_impl<scalar_t>(iter);
   });
 }
 
 void min_values_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_ALL_TYPES(iter.type(), "min_values", [&]() {
+  AT_DISPATCH_ALL_TYPES(iter.dtype(), "min_values_cuda", [&]() {
     min_values_kernel_cuda_impl<scalar_t>(iter);
   });
 }
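
The ReduceOpsKernel.cu hunks above touch a slightly different API: with a TensorIterator, iter.dtype() (the same as iter.dtype(0)) is the ScalarType of the output operand and iter.dtype(1) is that of the first input, which is how the Half-in/float-out promotion checks now read. Below is a sketch of that selection logic with placeholder names; my_sum_impl is not a real function, and the stub exists only so the fragment is self-contained inside the ATen source tree, where TensorIterator.h is an internal header.

    #include <ATen/ATen.h>
    #include <ATen/Dispatch.h>
    #include <ATen/native/TensorIterator.h>  // internal header, in-tree builds only

    // Placeholder for the real reduction launcher (illustration only).
    template <typename in_t, typename acc_t = in_t, typename out_t = acc_t>
    void my_sum_impl(at::TensorIterator& iter) {
      // the real kernels launch their reduction here
    }

    static void my_sum_kernel(at::TensorIterator& iter) {
      if (iter.dtype() == at::kHalf) {
        // Half output: accumulate in float.
        return my_sum_impl<at::Half, float>(iter);
      } else if (iter.dtype(1) == at::kHalf && iter.dtype() == at::kFloat) {
        // Half input, float output: cast and reduce in a single kernel.
        return my_sum_impl<at::Half, float, float>(iter);
      }
      AT_DISPATCH_ALL_TYPES(iter.dtype(), "my_sum", [&] {
        my_sum_impl<scalar_t>(iter);
      });
    }
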
index 5e6e6bf..0a1cacf 100644 (file)
@@ -191,7 +191,7 @@ void reflection_pad1d_out_template(
   Tensor input = input_.contiguous();
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-    input.type(), "reflection_pad1d_out_template", [&] {
+    input.scalar_type(), "reflection_pad1d_out_template", [&] {
       reflection_pad1d_out_kernel<<<
         grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
           input.data<scalar_t>(), output.data<scalar_t>(),
@@ -239,7 +239,7 @@ void reflection_pad1d_backward_out_template(
   dim3 grid_size((int) ::ceil(output_w / 256.0), nplane, nbatch);
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-    grad_input.type(), "reflection_pad1d_backward_out_template", [&] {
+    grad_input.scalar_type(), "reflection_pad1d_backward_out_template", [&] {
       reflection_pad1d_backward_out_kernel<<<
         grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
           grad_input.data<scalar_t>(), grad_output.data<scalar_t>(),
@@ -311,7 +311,7 @@ void reflection_pad2d_out_template(
     (int) std::ceil(output_plane_size/256.0), nplane, nbatch);
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-    input.type(), "reflection_pad2d_out_template", [&] {
+    input.scalar_type(), "reflection_pad2d_out_template", [&] {
       reflection_pad2d_out_kernel<<<
         grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
           input.data<scalar_t>(), output.data<scalar_t>(),
@@ -368,7 +368,7 @@ void reflection_pad2d_backward_out_template(
     (int) std::ceil(output_plane_size/256.0), nplane, nbatch);
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-    input.type(), "reflection_pad2d_backward_out_template", [&] {
+    input.scalar_type(), "reflection_pad2d_backward_out_template", [&] {
       reflection_pad2d_backward_out_kernel<<<
         grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
           grad_input.data<scalar_t>(), grad_output.data<scalar_t>(),
index a9790df..867ebf2 100644 (file)
@@ -235,7 +235,7 @@ void replication_pad1d_out_cuda_template(
 
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      input.type(), "replication_pad1d", [&] {
+      input.scalar_type(), "replication_pad1d_cuda", [&] {
 
 
       if (numInputDims == 2) {
@@ -306,7 +306,7 @@ void replication_pad1d_backward_out_cuda_template(
   gradInput.zero_();
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      input.type(), "replication_pad1d_backward", [&] {
+      input.scalar_type(), "replication_pad1d_backward_cuda", [&] {
 
       auto gradInput_ = gradInput;
       auto gradOutput_ = gradOutput;
@@ -372,7 +372,7 @@ void replication_pad2d_out_cuda_template(
       " Calculated output H: ", outputH, " W: ", outputW);
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      input.type(), "replication_pad2d", [&] {
+      input.scalar_type(), "replication_pad2d_cuda", [&] {
 
 
       if (numInputDims == 3) {
@@ -403,7 +403,7 @@ void replication_pad2d_out_cuda_template(
         dim3 blockSize(outputPlaneSize > 256 ? 256 : outputPlaneSize);
 
         replication_pad_forward_kernel2d <<<gridSize, blockSize, 0,
-                                       at::cuda::getCurrentCUDAStream()>>>(devInput, devOutput, 
+                                       at::cuda::getCurrentCUDAStream()>>>(devInput, devOutput,
                                            padT, padB, padL, padR);
       }
       }
@@ -427,7 +427,7 @@ void replication_pad2d_backward_out_cuda_template(
   int padL = paddingSize[0];
   int padR = paddingSize[1];
   int padT = paddingSize[2];
-  int padB = paddingSize[3]; 
+  int padB = paddingSize[3];
   int planeDim = 0;
   int dimh = 1;
   int dimw = 2;
@@ -454,7 +454,7 @@ void replication_pad2d_backward_out_cuda_template(
   gradInput.zero_();
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      input.type(), "replication_pad2d_backward", [&] {
+      input.scalar_type(), "replication_pad2d_backward_cuda", [&] {
 
         auto gradInput_ = gradInput;
         auto gradOutput_ = gradOutput;
@@ -483,7 +483,7 @@ static inline void shapeCheck3d(
     int pleft, int pright,
     int ptop, int pbottom,
     int pfront, int pback) {
-  AT_CHECK(at::cuda::detail::canUse32BitIndexMath(input), 
+  AT_CHECK(at::cuda::detail::canUse32BitIndexMath(input),
       "input tensor must fit into 32-bit index math");
   int numInputDims = input.dim();
 
@@ -521,7 +521,7 @@ static inline void shapeAndGradOutputCheck3d(
     int pleft, int pright,
     int ptop, int pbottom,
     int pfront, int pback) {
-  AT_CHECK(at::cuda::detail::canUse32BitIndexMath(input), 
+  AT_CHECK(at::cuda::detail::canUse32BitIndexMath(input),
       "input tensor must fit into 32-bit index math");
   int numInputDims = input.dim();
 
@@ -579,7 +579,7 @@ void replication_pad3d_out_cuda_template(
   int ptop = paddingSize[2];
   int pbottom = paddingSize[3];
   int pfront = paddingSize[4];
-  int pback = paddingSize[5]; 
+  int pback = paddingSize[5];
   shapeCheck3d(input, pleft, pright, ptop,
       pbottom, pfront, pback);
 
@@ -608,7 +608,7 @@ void replication_pad3d_out_cuda_template(
   int outputW  = inputW + pleft + pright;
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      input.type(), "replication_pad3d", [&] {
+      input.scalar_type(), "replication_pad3d_cuda", [&] {
 
       if (numInputDims == 4) {
         output.resize_({numPlanes, outputD, outputH, outputW});
@@ -660,7 +660,7 @@ void replication_pad3d_backward_out_cuda_template(
   int ptop = paddingSize[2];
   int pbottom = paddingSize[3];
   int pfront = paddingSize[4];
-  int pback = paddingSize[5]; 
+  int pback = paddingSize[5];
   shapeAndGradOutputCheck3d(input, gradOutput, pleft, pright, ptop,
       pbottom, pfront, pback);
 
@@ -681,7 +681,7 @@ void replication_pad3d_backward_out_cuda_template(
   gradInput.zero_();
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      input.type(), "replication_pad3d_backward", [&] {
+      input.scalar_type(), "replication_pad3d_backward_cuda", [&] {
 
       auto gradInput_ = gradInput;
       auto gradOutput_ = gradOutput;
index 8808272..ef3031e 100644 (file)
@@ -499,7 +499,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
       const int ILP = 2;
       dim3 grid(outer_size);
       dim3 block = SoftMax_getBlockSize(ILP, dim_size);
-      AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "host_softmax", [&] {
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "host_softmax", [&] {
       using accscalar_t = acc_type<scalar_t, true>;
       if (!half_to_float) {
           cunn_SoftMaxForward<ILP, scalar_t, accscalar_t, scalar_t, Epilogue>
@@ -519,7 +519,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
     } else {
       uint32_t smem_size;
       dim3 grid, block;
-      AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "host_softmax", [&] {
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "host_softmax", [&] {
       using accscalar_t = acc_type<scalar_t, true>;
       if (!half_to_float) {
           SpatialSoftMax_getLaunchSizes<accscalar_t>(
@@ -573,7 +573,7 @@ Tensor host_softmax_backward(const Tensor &grad_, const Tensor &output_, int64_t
     const int ILP = 2;
     dim3 grid(outer_size);
     dim3 block = SoftMax_getBlockSize(ILP, dim_size);
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(gI.type(), "host_softmax_backward", [&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(gI.scalar_type(), "host_softmax_backward", [&] {
     using accscalar_t = acc_type<scalar_t, true>;
     if (!half_to_float) {
         cunn_SoftMaxBackward<ILP, scalar_t, accscalar_t, scalar_t, Epilogue>
@@ -590,7 +590,7 @@ Tensor host_softmax_backward(const Tensor &grad_, const Tensor &output_, int64_t
   } else {
     uint32_t smem_size;
     dim3 grid, block;
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "host_softmax_backward", [&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.scalar_type(), "host_softmax_backward", [&] {
     using accscalar_t = acc_type<scalar_t, true>;
     if (!half_to_float) {
         SpatialSoftMax_getLaunchSizes<accscalar_t>(
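
The softmax and batch-norm kernels above pair the dispatched scalar_t with acc_type<scalar_t, true> so that reduced-precision inputs are accumulated at higher precision. For reference, a few static assertions of how at::acc_type resolves, assuming the AccumulateType.h mappings at this revision (the "If Half, stash norms as float" comment in WeightNorm.cu further down relies on the same behavior):

    #include <ATen/ATen.h>
    #include <ATen/AccumulateType.h>
    #include <type_traits>

    // acc_type<T, /*is_cuda=*/true> is the accumulator type used by CUDA kernels.
    static_assert(std::is_same<at::acc_type<at::Half, true>, float>::value,
                  "Half accumulates in float on CUDA");
    static_assert(std::is_same<at::acc_type<double, true>, double>::value,
                  "double accumulates in double");
    // On the CPU side, float is widened to double for accumulation.
    static_assert(std::is_same<at::acc_type<float, false>, double>::value,
                  "float accumulates in double on CPU");
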
index 4018d1d..50d0b90 100644 (file)
@@ -233,14 +233,14 @@ std::tuple<Tensor&, Tensor&> kthvalue_out_cuda(
     int64_t k,
     int64_t dim,
     bool keepdim) {
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.type(), "kthvalue", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.scalar_type(), "kthvalue_cuda", [&] {
     kthvalue_cuda_template<scalar_t>(values, indices, self, k, dim, keepdim);
   });
   return std::forward_as_tuple(values, indices);
 }
 
 Tensor median_cuda(const Tensor& self) {
-  return AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.type(), "median", [&] {
+  return AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.scalar_type(), "median", [&] {
     return median_cuda_template<scalar_t>(self);
   });
 }
index 6a79f66..19c28ff 100644 (file)
@@ -114,7 +114,7 @@ static void _fft_fill_with_conjugate_symmetry_(Tensor& input,
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
   auto policy = thrust::cuda::par(allocator).on(stream);
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "_fft_fill_with_conjugate_symmetry_", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "_fft_fill_with_conjugate_symmetry_", [&] {
     typedef thrust::device_ptr<scalar_t> device_ptr;
     typedef thrust::counting_iterator<int64_t> counter;
     typedef thrust::transform_iterator<cnt_to_dst_idx_functor, counter> dst_idx_iterator;
index 3903d9b..d2c7be5 100644 (file)
@@ -330,7 +330,7 @@ Tensor _bincount_cuda(
     const Tensor& self,
     const Tensor& weights,
     int64_t minlength) {
-  return AT_DISPATCH_INTEGRAL_TYPES(self.type(), "bincount", [&] {
+  return AT_DISPATCH_INTEGRAL_TYPES(self.scalar_type(), "bincount_cuda", [&] {
     const auto scalar = weights.scalar_type();
     if (scalar == ScalarType::Undefined || scalar == ScalarType::Float)
       return _bincount_cuda_template<scalar_t, float>(self, weights, minlength);
@@ -347,7 +347,7 @@ Tensor _histc_cuda(
   if (self.scalar_type() == ScalarType::Half) {
     AT_ERROR("HalfTensor is not supported");
   }
-  return AT_DISPATCH_ALL_TYPES(self.type(), "histc", [&] {
+  return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "histc", [&] {
     return _histc_cuda_template<scalar_t>(self, nbins, min.to<scalar_t>(), max.to<scalar_t>());
   });
 }
index f7d9b68..d4ccc70 100644 (file)
@@ -33,7 +33,7 @@ Tensor _s_where_cuda(
     const Tensor& self,
     const Tensor& other) {
   Tensor ret = at::empty(self.sizes(), self.options());
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ret.type(), "where", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ret.scalar_type(), "where_cuda", [&] {
     where_cuda<scalar_t>(ret, condition, self, other);
   });
   return ret;
index f0c32cc..c9bb377 100644 (file)
@@ -90,7 +90,7 @@ Tensor& randperm_out_cuda(Tensor& result, int64_t n, Generator* generator) {
     } else {
       // Generate random values for the keys array
       AT_DISPATCH_ALL_TYPES(
-        result.type(), "randperm_out_cuda", [&] {
+        result.scalar_type(), "randperm_out_cuda", [&] {
           auto keys = at::empty(result.sizes(), result.options()).random_(generator);
 
           auto result_data = thrust::device_ptr<scalar_t>(result.data<scalar_t>());
@@ -322,7 +322,7 @@ Tensor tril_indices_cuda(
       cuda::getApplyGrid(tril_size, dim_grid, tensor.get_device()),
       "unable to get dim grid");
 
-    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, tensor.type(), "tril_indices_cuda", [&] {
+    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, tensor.scalar_type(), "tril_indices_cuda", [&] {
       tril_indices_kernel<<<
           dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
         tensor.data<scalar_t>(),
@@ -398,7 +398,7 @@ Tensor triu_indices_cuda(
       cuda::getApplyGrid(triu_size, dim_grid, tensor.get_device()),
       "unable to get dim grid");
 
-    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, tensor.type(), "triu_indices_cuda", [&] {
+    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, tensor.scalar_type(), "triu_indices_cuda", [&] {
       triu_indices_kernel<<<
           dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
         tensor.data<scalar_t>(),
index 5fccc10..00f50f1 100644 (file)
@@ -87,7 +87,7 @@ Tensor flip_cuda(const Tensor& self, IntArrayRef dims) {
 
   // use kernel_pointwise_flip_apply2 only when to-flip dim is the 1st or last dim, where collapseDims can reduce the amount of work
   if (flip_dims_size == 1 && in_tensor.is_contiguous() && (flip_dims[0] == 0 || flip_dims[0] == total_dims - 1)) {
-    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, in_tensor.type(), "flip_cuda", [&] {
+    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, in_tensor.scalar_type(), "flip_cuda", [&] {
       auto in_tensor_info = cuda::detail::getTensorInfo<scalar_t, int64_t>(in_tensor);
       auto out_tensor_info = cuda::detail::getTensorInfo<scalar_t, int64_t>(out_tensor);
       int flip_dim = in_tensor_info.collapseDims(flip_dims[0]);
@@ -119,7 +119,7 @@ Tensor flip_cuda(const Tensor& self, IntArrayRef dims) {
     }
   }
 
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, in_tensor.type(), "flip_cuda", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, in_tensor.scalar_type(), "flip_cuda", [&] {
     flip_cuda_kernel<<<dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
       in_tensor.data<scalar_t>(), out_tensor.data<scalar_t>(), N, flip_dims_t.toType(CUDA(kLong)).data<int64_t>(), flip_dims_size,
       strides_t.toType(CUDA(kLong)).data<int64_t>(), stride_contiguous.toType(CUDA(kLong)).data<int64_t>(), shape_t.toType(CUDA(kLong)).data<int64_t>(), total_dims);
@@ -177,7 +177,7 @@ Tensor roll_cuda(const Tensor& self, IntArrayRef shifts, IntArrayRef dims) {
 
   auto total_dims = in_tensor.dim();
 
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, in_tensor.type(), "roll_cuda", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, in_tensor.scalar_type(), "roll_cuda", [&] {
     roll_cuda_kernel<<<dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
       in_tensor.data<scalar_t>(), out_tensor.data<scalar_t>(), N,
       dim, start,
index 828fb48..0ba6812 100644 (file)
@@ -146,7 +146,7 @@ template <typename scalar_t>
 
 std::tuple<Tensor, Tensor>
 _unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse) {
-  return AT_DISPATCH_ALL_TYPES(self.type(), "unique", [&] {
+  return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_cuda", [&] {
     // The current CUDA implementation of unique always sort due to the
     // lack of hashtable implementation in thrust
     return _unique_cuda_template<scalar_t>(self, return_inverse);
@@ -155,7 +155,7 @@ _unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse) {
 
 std::tuple<Tensor, Tensor>
 _unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) {
-  return AT_DISPATCH_ALL_TYPES(self.type(), "unique_dim", [&] {
+  return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
     return _unique_dim_cuda_template<scalar_t>(self, dim, return_inverse);
   });
 }
index f8e9c14..76f4272 100644 (file)
@@ -7,7 +7,7 @@
 #include <THC/THCDeviceUtils.cuh>
 #include <THC/THCTensorMathReduce.cuh>
 
-namespace at { 
+namespace at {
 namespace native {
 namespace {
 
@@ -15,15 +15,15 @@ namespace {
 // Currently, kernels are non-persistent.
 // Dialing up the block size to, say 1024, can improve performance by
 // increase the amount of cache available per block, which can improve cache hit rate.
-// However, this is less efficient for short rows.  256 is pretty versatile. 
+// However, this is less efficient for short rows.  256 is pretty versatile.
 // May be worth implementing heuristics later.
 #define BLOCK 256
 
 // Block size for weight_norm_*_last_dim_kernel.
-// This is tricker than the first_dim case because we must make blocks 
+// This is tricker than the first_dim case because we must make blocks
 // at least 16 fast elements wide to ensure fully-coalesced half-precision accesses.
-// Since output-element parallelism is along the fast dimension, this reduces the number of 
-// blocks we can launch by 16X.  
+// Since output-element parallelism is along the fast dimension, this reduces the number of
+// blocks we can launch by 16X.
 #define TILE_W 16
 // Somewhat versatile strategy: max out intra-block parallelism by extending
 // blocks across the slow dimension up to the hardware-max block size of 1024.
@@ -31,11 +31,11 @@ namespace {
 
 template<typename T, typename ReduceOp>
 __device__ __forceinline__ void reduce_block_into_lanes
-  (T *x, 
-   T val, 
+  (T *x,
+   T val,
    int lanes, // lanes is intended to be <= 32.
-   ReduceOp reduceOp) 
-{ 
+   ReduceOp reduceOp)
+{
   int tid = threadIdx.x + threadIdx.y*blockDim.x;
   int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32.
 
@@ -44,16 +44,16 @@ __device__ __forceinline__ void reduce_block_into_lanes
     x[tid] = val;
     __syncthreads();
   }
-  
+
   #pragma unroll
-  for(int i = (blockSize >> 1); i >= 64; i >>= 1) 
+  for(int i = (blockSize >> 1); i >= 64; i >>= 1)
   {
     if(tid < i)
       x[tid] = reduceOp(x[tid], x[tid+i]);
     __syncthreads();
   }
 
-  if(tid < 32) 
+  if(tid < 32)
   {
     T final;
     if(blockSize >= 64)
@@ -66,7 +66,7 @@ __device__ __forceinline__ void reduce_block_into_lanes
     for(int i = 16; i >= lanes; i >>= 1)
       final = reduceOp(final, WARP_SHFL_DOWN(final, i));
 
-    if(tid < lanes) 
+    if(tid < lanes)
       x[tid] = final; // EpilogueOp
   }
 
@@ -75,14 +75,14 @@ __device__ __forceinline__ void reduce_block_into_lanes
 }
 
 template
-  <typename scalar_t, 
+  <typename scalar_t,
    typename accscalar_t>
 __global__ void weight_norm_fwd_first_dim_kernel
   (scalar_t* __restrict__ w,
    accscalar_t* __restrict__ norms,
    const scalar_t* __restrict__ v,
    const scalar_t* __restrict__ g,
-   const int rowSize) 
+   const int rowSize)
 {
   // We are norming each slowest-dim row of the tensor separately.
   // For now, assign one block to each row.
@@ -98,11 +98,11 @@ __global__ void weight_norm_fwd_first_dim_kernel
   // extern __shared__ accscalar_t s[]; // error: declaration is incompatible with previous "s"
   extern __shared__ char buf[];
   accscalar_t* s = (accscalar_t*)buf;
-  
+
   accscalar_t thread_sum = 0.f;
-  for(int i = tid; i < rowSize; i += stride ) 
+  for(int i = tid; i < rowSize; i += stride )
   {
-    accscalar_t val_f = scalar_cast<accscalar_t>(v[i+rowStart]); 
+    accscalar_t val_f = scalar_cast<accscalar_t>(v[i+rowStart]);
     thread_sum += val_f*val_f; // AccumOp, could do Kahan here
   }
 
@@ -110,7 +110,7 @@ __global__ void weight_norm_fwd_first_dim_kernel
   accscalar_t result = s[0];
 
   result = sqrtf(result);
-  
+
   if(tid == 0)
     norms[row] = result;
 
@@ -120,7 +120,7 @@ __global__ void weight_norm_fwd_first_dim_kernel
   accscalar_t rnorm = 1.f/result; // for consistency with backward kernel
 
   // Write data to output
-  for(int i = tid; i < rowSize; i += stride ) 
+  for(int i = tid; i < rowSize; i += stride )
   {
     accscalar_t val_f = scalar_cast<accscalar_t>(v[i+rowStart]);
     w[i+rowStart] = scalar_cast<scalar_t>(g_this_row*val_f*rnorm);
@@ -128,7 +128,7 @@ __global__ void weight_norm_fwd_first_dim_kernel
 }
 
 template
-  <typename scalar_t, 
+  <typename scalar_t,
    typename accscalar_t>
 __global__ void weight_norm_fwd_last_dim_kernel
 (
@@ -154,13 +154,13 @@ __global__ void weight_norm_fwd_last_dim_kernel
   if(fast_dim_location < fast_dim_size)
     while(slower_dims_location < slower_dims_size)
     {
-      accscalar_t val_f = scalar_cast<accscalar_t>(v[currentIdx]); 
+      accscalar_t val_f = scalar_cast<accscalar_t>(v[currentIdx]);
       thread_sum += val_f*val_f; // AccumOp, could do Kahan here
       currentIdx += blockDim.y*fast_dim_size;
-      slower_dims_location += blockDim.y; 
+      slower_dims_location += blockDim.y;
     }
 
-  reduce_block_into_lanes(s, thread_sum, blockDim.x, ReduceAdd<accscalar_t>()); 
+  reduce_block_into_lanes(s, thread_sum, blockDim.x, ReduceAdd<accscalar_t>());
 
   // Better to pass an EpilogueOp to reduce_block_into_lanes?
   if(threadIdx.y == 0)
@@ -170,26 +170,26 @@ __global__ void weight_norm_fwd_last_dim_kernel
     norms[fast_dim_location] = norm_this_col;
     rnorms_this_block[threadIdx.x] = 1.f/norm_this_col;
   }
-   
-  __syncthreads(); 
 
-  accscalar_t g_this_col = scalar_cast<accscalar_t>(g[fast_dim_location]);     
-  accscalar_t rnorm = rnorms_this_block[threadIdx.x]; 
+  __syncthreads();
+
+  accscalar_t g_this_col = scalar_cast<accscalar_t>(g[fast_dim_location]);
+  accscalar_t rnorm = rnorms_this_block[threadIdx.x];
 
   slower_dims_location = threadIdx.y;
   currentIdx = fast_dim_location + fast_dim_size*slower_dims_location;
   if(fast_dim_location < fast_dim_size)
     while(slower_dims_location < slower_dims_size)
     {
-      accscalar_t val_f = scalar_cast<accscalar_t>(v[currentIdx]); 
+      accscalar_t val_f = scalar_cast<accscalar_t>(v[currentIdx]);
       w[currentIdx] = scalar_cast<scalar_t>(g_this_col*val_f*rnorm);
       currentIdx += blockDim.y*fast_dim_size;
-      slower_dims_location += blockDim.y; 
-    } 
+      slower_dims_location += blockDim.y;
+    }
 }
 
 template
-  <typename scalar_t, 
+  <typename scalar_t,
    typename accscalar_t>
 __global__ void weight_norm_bwd_first_dim_kernel
   (scalar_t* __restrict__ grad_v,
@@ -213,12 +213,12 @@ __global__ void weight_norm_bwd_first_dim_kernel
   // extern __shared__ accscalar_t s[]; // error: declaration is incompatible with previous "s"
   extern __shared__ char buf[];
   accscalar_t* s = (accscalar_t*)buf;
-  
+
   accscalar_t thread_sum = 0.f;
-  for(int i = tid; i < rowSize; i += stride ) 
+  for(int i = tid; i < rowSize; i += stride )
   {
-    accscalar_t grad_wi = scalar_cast<accscalar_t>(grad_w[i+rowStart]); 
-    accscalar_t saved_vi = scalar_cast<accscalar_t>(saved_v[i+rowStart]); 
+    accscalar_t grad_wi = scalar_cast<accscalar_t>(grad_w[i+rowStart]);
+    accscalar_t saved_vi = scalar_cast<accscalar_t>(saved_v[i+rowStart]);
     thread_sum += grad_wi*saved_vi; // AccumOp, could do Kahan here
   }
 
@@ -228,7 +228,7 @@ __global__ void weight_norm_bwd_first_dim_kernel
   // Could choose to save reciprocal of norm instead I suppose, but norms is probably
   // more handy to keep around.
   // Broadcast load; could use shared memory instead.
-  accscalar_t rnorm = 1.f/saved_norms[row];  
+  accscalar_t rnorm = 1.f/saved_norms[row];
   accscalar_t rnorm3 = rnorm*rnorm*rnorm;
 
   // Write g gradients.
@@ -237,20 +237,20 @@ __global__ void weight_norm_bwd_first_dim_kernel
 
   // Broadcast load, could use shared memory instead.
   accscalar_t g_this_row = scalar_cast<accscalar_t>(saved_g[row]);
-   
-  // Write v gradients.  We are reusing values that were loaded earlier, so there 
+
+  // Write v gradients.  We are reusing values that were loaded earlier, so there
   // is an optimization opportunity here (store values persistently).
-  for(int j = tid; j < rowSize; j += stride ) 
+  for(int j = tid; j < rowSize; j += stride )
   {
-    accscalar_t grad_wj = scalar_cast<accscalar_t>(grad_w[j+rowStart]);  
-    accscalar_t saved_vj = scalar_cast<accscalar_t>(saved_v[j+rowStart]);  
+    accscalar_t grad_wj = scalar_cast<accscalar_t>(grad_w[j+rowStart]);
+    accscalar_t saved_vj = scalar_cast<accscalar_t>(saved_v[j+rowStart]);
     accscalar_t grad_vj = g_this_row*(rnorm*grad_wj - rnorm3*saved_vj*result);
     grad_v[j+rowStart] = scalar_cast<scalar_t>(grad_vj);
   }
 }
 
-template 
-  <typename scalar_t, 
+template
+  <typename scalar_t,
    typename accscalar_t>
 __global__ void weight_norm_bwd_last_dim_kernel
   (scalar_t* __restrict__ grad_v,
@@ -274,18 +274,18 @@ __global__ void weight_norm_bwd_last_dim_kernel
   if(fast_dim_location < fast_dim_size)
     while(slower_dims_location < slower_dims_size)
     {
-      accscalar_t grad_wi = scalar_cast<accscalar_t>(grad_w[currentIdx]); 
-      accscalar_t saved_vi = scalar_cast<accscalar_t>(saved_v[currentIdx]); 
+      accscalar_t grad_wi = scalar_cast<accscalar_t>(grad_w[currentIdx]);
+      accscalar_t saved_vi = scalar_cast<accscalar_t>(saved_v[currentIdx]);
       thread_sum += grad_wi*saved_vi; // AccumOp, could do Kahan here
       currentIdx += blockDim.y*fast_dim_size;
-      slower_dims_location += blockDim.y; 
+      slower_dims_location += blockDim.y;
     }
 
-  reduce_block_into_lanes(s, thread_sum, blockDim.x, ReduceAdd<accscalar_t>()); 
+  reduce_block_into_lanes(s, thread_sum, blockDim.x, ReduceAdd<accscalar_t>());
   accscalar_t result = s[threadIdx.x];
 
   // Broadcast load; could use shared memory instead.
-  accscalar_t rnorm = 1.f/saved_norms[fast_dim_location];  
+  accscalar_t rnorm = 1.f/saved_norms[fast_dim_location];
   accscalar_t rnorm3 = rnorm*rnorm*rnorm;
 
   // Write g gradients.
@@ -301,13 +301,13 @@ __global__ void weight_norm_bwd_last_dim_kernel
   if(fast_dim_location < fast_dim_size)
     while(slower_dims_location < slower_dims_size)
     {
-      accscalar_t grad_wj = scalar_cast<accscalar_t>(grad_w[currentIdx]);  
-      accscalar_t saved_vj = scalar_cast<accscalar_t>(saved_v[currentIdx]);  
+      accscalar_t grad_wj = scalar_cast<accscalar_t>(grad_w[currentIdx]);
+      accscalar_t saved_vj = scalar_cast<accscalar_t>(saved_v[currentIdx]);
       accscalar_t grad_vj = g_this_col*(rnorm*grad_wj - rnorm3*saved_vj*result);
       grad_v[currentIdx] = scalar_cast<scalar_t>(grad_vj);
       currentIdx += blockDim.y*fast_dim_size;
-      slower_dims_location += blockDim.y; 
-    } 
+      slower_dims_location += blockDim.y;
+    }
 }
 
 } // anonymous namespace
@@ -315,7 +315,7 @@ __global__ void weight_norm_bwd_last_dim_kernel
 std::tuple<Tensor,Tensor> weight_norm_cuda
   (const Tensor & v,
    const Tensor & g,
-   int64_t dim) 
+   int64_t dim)
 {
   auto w = at::empty_like(v);
 
@@ -323,17 +323,17 @@ std::tuple<Tensor,Tensor> weight_norm_cuda
   // sends the unpacked g.data() as the argument.  In other words, we expect "g" is a bare Tensor here.
 
   // norms is only needed to stash for backward.
-  // g.scalar_type() may be at::ScalarType::Double, Float, or Half.  
+  // g.scalar_type() may be at::ScalarType::Double, Float, or Half.
   // If Half, stash norms as float.
   at::ScalarType AccType = g.scalar_type() == at::ScalarType::Half ?
                            at::ScalarType::Float : g.scalar_type();
-  // Will this create norms on the same device as g, regardless of what the thread's default 
+  // Will this create norms on the same device as g, regardless of what the thread's default
   // current device is?  I believe so, because Type::* functions are DeviceGuard()ed.
   auto norms = at::empty_strided(g.sizes(), g.strides(), g.options().dtype(AccType));
 
   const int ndims = v.dim();
 
-  if(dim == 0) 
+  if(dim == 0)
   {
     // Find logical size of each flattened slowest-dim row
     int rowSize = 1;
@@ -343,21 +343,21 @@ std::tuple<Tensor,Tensor> weight_norm_cuda
     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
     AT_DISPATCH_FLOATING_TYPES_AND_HALF
-      (v.type(), 
-       "weight_norm_fwd_first_dim_kernel",  
+      (v.scalar_type(),
+       "weight_norm_fwd_first_dim_kernel",
        [&]
        {
          using accscalar_t = acc_type<scalar_t, true>;
 
          weight_norm_fwd_first_dim_kernel<scalar_t, accscalar_t>
-           <<<v.size(0), 
-              BLOCK, 
+           <<<v.size(0),
+              BLOCK,
               BLOCK*sizeof(accscalar_t),
               stream>>>
-           (w.data<scalar_t>(), 
+           (w.data<scalar_t>(),
             norms.data<accscalar_t>(),
-            v.data<scalar_t>(),  
-            g.data<scalar_t>(),  
+            v.data<scalar_t>(),
+            g.data<scalar_t>(),
             rowSize);
        });
   }
@@ -369,16 +369,16 @@ std::tuple<Tensor,Tensor> weight_norm_cuda
       slower_dims_size *= v.size(i);
 
     int fast_dim_size = v.size(ndims-1);
+
     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
     AT_DISPATCH_FLOATING_TYPES_AND_HALF
-      (v.type(), 
-       "weight_norm_fwd_last_dim_kernel",  
+      (v.scalar_type(),
+       "weight_norm_fwd_last_dim_kernel",
        [&]
        {
          using accscalar_t = acc_type<scalar_t, true>;
-        
+
          weight_norm_fwd_last_dim_kernel<scalar_t, accscalar_t>
            <<<(fast_dim_size+TILE_W-1)/TILE_W,
               dim3(TILE_W,TILE_H),
@@ -395,7 +395,7 @@ std::tuple<Tensor,Tensor> weight_norm_cuda
 
   // The kernel execution is asynchronous, so this will only catch errors on the kernel launch,
   // not the kernel's execution.  Errors in kernel execution aren't guaranteed to be caught
-  // until a later error check on a synchronizing CUDA call.  Unfortunately, without manually 
+  // until a later error check on a synchronizing CUDA call.  Unfortunately, without manually
   // synchronizing here, this is the best we can do.
   THCudaCheck(cudaGetLastError());
 
@@ -403,9 +403,9 @@ std::tuple<Tensor,Tensor> weight_norm_cuda
 }
 
 std::tuple<Tensor, Tensor> weight_norm_cuda_backward
-  (const Tensor & grad_w, 
-   const Tensor & saved_v, 
-   const Tensor & saved_g, 
+  (const Tensor & grad_w,
+   const Tensor & saved_v,
+   const Tensor & saved_g,
    const Tensor & saved_norms,
    int64_t dim)
 {
@@ -421,7 +421,7 @@ std::tuple<Tensor, Tensor> weight_norm_cuda_backward
 
   const int ndims = saved_v.dim();
 
-  if(dim == 0) 
+  if(dim == 0)
   {
     // Find logical size of each flattened slowest-dim row
     int rowSize = 1;
@@ -431,15 +431,15 @@ std::tuple<Tensor, Tensor> weight_norm_cuda_backward
     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
     AT_DISPATCH_FLOATING_TYPES_AND_HALF
-      (saved_v.type(), 
-       "weight_norm_bwd_first_dim_kernel",  
+      (saved_v.scalar_type(),
+       "weight_norm_bwd_first_dim_kernel",
        [&]
        {
          using accscalar_t = acc_type<scalar_t, true>;
 
         weight_norm_bwd_first_dim_kernel<scalar_t, accscalar_t>
-          <<<grad_w.size(0), 
-             BLOCK, 
+          <<<grad_w.size(0),
+             BLOCK,
              BLOCK*sizeof(accscalar_t),
               stream>>>
           (grad_v.data<scalar_t>(),
@@ -463,15 +463,15 @@ std::tuple<Tensor, Tensor> weight_norm_cuda_backward
     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
     AT_DISPATCH_FLOATING_TYPES_AND_HALF
-      (saved_v.type(), 
-       "weight_norm_bwd_last_dim_kernel",  
+      (saved_v.scalar_type(),
+       "weight_norm_bwd_last_dim_kernel",
        [&]
        {
          using accscalar_t = acc_type<scalar_t, true>;
 
          weight_norm_bwd_last_dim_kernel<scalar_t, accscalar_t>
            <<<(fast_dim_size+TILE_W-1)/TILE_W,
-              dim3(TILE_W,TILE_H), 
+              dim3(TILE_W,TILE_H),
               (TILE_W*TILE_H + TILE_W)*sizeof(accscalar_t),
               stream>>>
            (grad_v.data<scalar_t>(),
@@ -487,7 +487,7 @@ std::tuple<Tensor, Tensor> weight_norm_cuda_backward
 
   // The kernel execution is asynchronous, so this will only catch errors on the kernel launch,
   // not the kernel's execution.  Errors in kernel execution aren't guaranteed to be caught
-  // until a later error check on a synchronizing CUDA call.  Unfortunately, without manually 
+  // until a later error check on a synchronizing CUDA call.  Unfortunately, without manually
   // synchronizing here, this is the best we can do.
   THCudaCheck(cudaGetLastError());
 
index 97a641c..809bd82 100644 (file)
@@ -83,7 +83,7 @@ static inline void baddbmm_mkl_template(const Tensor& res, const Tensor& mat1, c
 
 Tensor& _baddbmm_mkl_(Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) {
   // checks are done in native/LinearAlgebra.cpp
-  AT_DISPATCH_FLOATING_TYPES(self.type(), "baddbmm__mkl", [&] {
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "baddbmm__mkl", [&] {
       baddbmm_mkl_template<scalar_t>(self, batch1, batch2, beta, alpha);
     });
 
index 31f8697..6ffe4cb 100644 (file)
@@ -145,7 +145,7 @@ static inline void _fft_fill_with_conjugate_symmetry_(Tensor& input,
     {
       int tid = omp_get_thread_num();
       int64_t start = tid * num_slices_per_thread;
-      AT_DISPATCH_FLOATING_TYPES(input.type(), "_fft_fill_with_conjugate_symmetry", [&] {
+      AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "_fft_fill_with_conjugate_symmetry", [&] {
         _fft_fill_with_conjugate_symmetry_slice<scalar_t>(input, signal_ndim, size_last_dim,
             last_dim_start_slice, start, std::min(num_slices_per_thread, num - start));
       });
@@ -153,7 +153,7 @@ static inline void _fft_fill_with_conjugate_symmetry_(Tensor& input,
     return;
   }
 #endif
-  AT_DISPATCH_FLOATING_TYPES(input.type(), "_fft_fill_with_conjugate_symmetry", [&] {
+  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "_fft_fill_with_conjugate_symmetry", [&] {
     _fft_fill_with_conjugate_symmetry_slice<scalar_t>(input, signal_ndim, size_last_dim,
         last_dim_start_slice, 0, num);
   });
@@ -291,4 +291,3 @@ Tensor _fft_mkl(const Tensor& self, int64_t signal_ndim,
 }} // namespace at::native
 
 #endif
-
index f89dce3..d3278c6 100644 (file)
@@ -287,7 +287,7 @@ SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){
   // TODO: it seems like sparse_dim == 0 could be supported even if self.dim() > 0,
   // but this would take some work and doesn't seem particularly useful.
   AT_CHECK(sparse_dim > 0 || self.dim() == 0, "sparse_dim must be >0 if dimensionality > 0");
-  AT_CHECK(sparse_dim <= dims, 
+  AT_CHECK(sparse_dim <= dims,
     "sparse_dim must be less than or equal to self.dim()");
   at::TensorOptions sparse_options = self.options().layout(kSparse);
   std::vector<int64_t> sizes = self.sizes().vec();
@@ -376,7 +376,7 @@ SparseTensor coalesce_sparse_cpu(const SparseTensor& self) {
 
   int64_t i = -1;
   AT_DISPATCH_ALL_TYPES(
-      values.type(), "coalesce", [&] {
+      values.scalar_type(), "coalesce", [&] {
         int64_t prev = -1;
         int64_t blockSize = values.stride(0);
         scalar_t* values_ptr = values.data<scalar_t>();
@@ -483,7 +483,7 @@ SparseTensor& sparse_mask_out_cpu(SparseTensor& r, const Tensor& t, const Sparse
     // TODO: Re-audit this; it used to be an indexSelect directly into r_values
     at::index_select_out(r_values, t_view, 0, indices);
   } else {
-    AT_DISPATCH_ALL_TYPES(r_values.type(), "sparse_mask", [&] {
+    AT_DISPATCH_ALL_TYPES(r_values.scalar_type(), "sparse_mask", [&] {
       sparse_mask_out_cpu_kernel<scalar_t>(
         r_values,
         t,
index 7548961..6e31b23 100644 (file)
@@ -226,7 +226,7 @@ SparseTensor& add_out_sparse_cpu(SparseTensor& r, const SparseTensor& t, const S
   auto src_indices_accessor = src_indices.accessor<int64_t, 2>();
 
   AT_DISPATCH_ALL_TYPES(
-      t_values.type(), "cadd_sparse", [&] {
+      t_values.scalar_type(), "cadd_sparse", [&] {
         scalar_t* t_values_ptr = t_values.data<scalar_t>();
         scalar_t* s_values_ptr = s_values.data<scalar_t>();
         scalar_t* r_values_ptr = r_values.data<scalar_t>();
@@ -347,7 +347,7 @@ Tensor& add_out_dense_sparse_cpu(Tensor& r, const Tensor& dense, SparseTensorRef
     }
   } else {
     AT_DISPATCH_ALL_TYPES(
-        values.type(), "add_dense_sparse", [&] {
+        values.scalar_type(), "add_dense_sparse", [&] {
           add_dense_sparse_worker_cpu<scalar_t>(r, value, sparse, indices, values);
         });
   }
@@ -435,7 +435,7 @@ SparseTensor& mul_out_sparse_cpu(SparseTensor& r, const Tensor& t_, const Tensor
     }
   } else {
     AT_DISPATCH_ALL_TYPES(
-        r_values.type(), "mul_out_sparse", [&] {
+        r_values.scalar_type(), "mul_out_sparse", [&] {
           auto r_accessor = r_values.accessor<scalar_t, 1>();
           auto t_accessor = t_values.accessor<scalar_t, 1>();
           auto s_accessor = s_values.accessor<scalar_t, 1>();
@@ -551,7 +551,7 @@ Tensor& s_addmm_out_sparse_dense_cpu(
   Tensor values      = sparse_._values();
 
   AT_DISPATCH_ALL_TYPES(
-      values.type(), "addmm_sparse_dense", [&] {
+      values.scalar_type(), "addmm_sparse_dense", [&] {
         s_addmm_out_sparse_dense_worker<scalar_t>(nnz, dim_i, dim_j, dim_k, r, beta, t, alpha, indices, values, dense);
       }
   );
@@ -757,7 +757,7 @@ SparseTensor& _sspaddmm_out_cpu(
   int64_t newv_stride0 = newv.stride(0);
 
   AT_DISPATCH_ALL_TYPES(
-      values.type(), "sspmm", [&] {
+      values.scalar_type(), "sspmm", [&] {
         auto values_accessor = values.accessor<scalar_t, 1>();
         scalar_t* dense_ptr = dense.data<scalar_t>();
         scalar_t* newv_ptr = newv.data<scalar_t>();
index fc144d6..3d48a63 100644 (file)
@@ -95,7 +95,7 @@ SparseTensor coalesce_sparse_cuda(const SparseTensor& self) {
     dim3 grid(THCCeilDiv(newNnz, (int64_t) 4), THCCeilDiv(stride, (int64_t) 128));
     dim3 block(32, 4);
     AT_DISPATCH_ALL_TYPES_AND(
-      at::ScalarType::Half,values.type(), "coalesce_sparse_cuda", [&] {
+      at::ScalarType::Half,values.scalar_type(), "coalesce_sparse_cuda", [&] {
           using cuda_accscalar_t = acc_type<scalar_t, /* is_cuda */ true>;
           apply::coalesceValuesKernel<scalar_t, cuda_accscalar_t><<<grid, block, 0, stream>>>(
             uniqueOffsets.data<int64_t>(),
index cac6835..b0f5657 100644 (file)
@@ -91,7 +91,7 @@ Tensor& s_addmm_out_sparse_dense_cuda(Tensor& r_, const Tensor& t, const SparseT
   // No half support, so we don't have to use CUDATypeConversion
   Tensor r__;
   AT_DISPATCH_FLOATING_TYPES(
-      values.type(), "addmm_sparse_cuda", [&] {
+      values.scalar_type(), "addmm_sparse_cuda", [&] {
         scalar_t cast_beta = beta.to<scalar_t>();
         scalar_t cast_alpha = alpha.to<scalar_t>();
         if (cast_beta == 0) {
@@ -296,7 +296,7 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, SparseTensorR
       AT_CHECK(cuda::getApplyGrid(nnz, grid, curDevice), "add: Argument #0: tensor too large or too many dimensions");
 
       AT_DISPATCH_ALL_TYPES_AND(
-        at::ScalarType::Half, values.type(), "add_out_dense_sparse_cuda", [&] {
+        at::ScalarType::Half, values.scalar_type(), "add_out_dense_sparse_cuda", [&] {
             apply::sparseElementwiseKernelScalar<TensorCAddOp<scalar_t>, uint64_t, scalar_t>
               <<<grid, block, 0, stream>>>(
                 TensorCAddOp<scalar_t>(value.to<scalar_t>()),
@@ -310,7 +310,7 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, SparseTensorR
       values = values.contiguous();
 
       AT_DISPATCH_ALL_TYPES_AND(
-        at::ScalarType::Half, values.type(), "add_out_dense_sparse_cuda", [&] {
+        at::ScalarType::Half, values.scalar_type(), "add_out_dense_sparse_cuda", [&] {
             apply::sparseElementwiseKernel<TensorCAddOp<scalar_t>, uint64_t, scalar_t>
               <<<grid, block, 0, stream>>>(
                 TensorCAddOp<scalar_t>(value.to<scalar_t>()),
@@ -324,7 +324,7 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, SparseTensorR
     // FIXME: at some point we can wrap the scale into indexAdd
     // NB: Purposely not inplace!
     AT_DISPATCH_ALL_TYPES_AND(
-      at::ScalarType::Half, values.type(), "add_out_dense_sparse_cuda", [&] {
+      at::ScalarType::Half, values.scalar_type(), "add_out_dense_sparse_cuda", [&] {
           if (value.to<scalar_t>() != static_cast<scalar_t>(1)) {
             values = values.mul(value);
           }
@@ -379,7 +379,7 @@ SparseTensor& add_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t, const
   Tensor s_values_ = src._values();
 
   AT_DISPATCH_ALL_TYPES_AND(
-    at::ScalarType::Half, s_values_.type(), "add_out_sparse_cuda", [&] {
+    at::ScalarType::Half, s_values_.scalar_type(), "add_out_sparse_cuda", [&] {
         if (value.to<scalar_t>() != static_cast<scalar_t>(1)) {
           s_values_ = s_values_.mul(value);
         }
@@ -449,7 +449,7 @@ SparseTensor& mul_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t_, cons
 
   LongTensor resultNnz = at::empty({1}, CUDA(kLong));
   AT_DISPATCH_ALL_TYPES_AND(
-    at::ScalarType::Half, t_values_.type(), "mul_out_sparse_cuda", [&] {
+    at::ScalarType::Half, t_values_.scalar_type(), "mul_out_sparse_cuda", [&] {
         apply::valueSparseIntersectionKernel<TensorMulOp<scalar_t>, uint64_t, scalar_t>
           <<<grid, block, 0, stream>>>(
             TensorMulOp<scalar_t>(),
@@ -620,7 +620,7 @@ Tensor _sparse_sum_backward_cuda(const Tensor& grad_, const SparseTensor& input_
       auto input_indices_ti = getTensorInfo<int64_t, int64_t>(input_indices_1D);
       auto input_indices_pos_ti = getTensorInfo<int64_t, int64_t>(input_indices_pos);
 
-      AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_values.type(), "_sparse_sum_backward_cuda", [&] {
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_values.scalar_type(), "_sparse_sum_backward_cuda", [&] {
         auto grad_values_expand_ti = getTensorInfo<scalar_t, int64_t>(grad_values_expand);
         auto grad_input_values_ti = getTensorInfo<scalar_t, int64_t>(grad_input_values);
 
index 69f1776..cc97c03 100644 (file)
@@ -27,7 +27,7 @@ void test(Type& type, IntArrayRef shape, int64_t a = 0, int64_t b = 1) {
   auto zero_dim = at::empty({}, type);
   zero_dim.fill_(2);
   zero_dim.exp_();
-  AT_DISPATCH_FLOATING_TYPES(zero_dim.type(), "test0", [&] {
+  AT_DISPATCH_FLOATING_TYPES(zero_dim.scalar_type(), "test0", [&] {
     ASSERT(zero_dim.data<scalar_t>()[0] == std::exp(2));
   });
 
@@ -50,7 +50,7 @@ void test(Type& type, IntArrayRef shape, int64_t a = 0, int64_t b = 1) {
     }
   }
 
-  AT_DISPATCH_FLOATING_TYPES(a0.type(), "test1", [&] {
+  AT_DISPATCH_FLOATING_TYPES(a0.scalar_type(), "test1", [&] {
     CPU_tensor_apply2<scalar_t, scalar_t>(
         a0, a1, [](scalar_t& y, const scalar_t& x) { y = x * x; });
     CPU_tensor_apply2<double, scalar_t>(
@@ -62,7 +62,7 @@ void test(Type& type, IntArrayRef shape, int64_t a = 0, int64_t b = 1) {
     }
   });
 
-  AT_DISPATCH_FLOATING_TYPES(a0.type(), "test2", [&] {
+  AT_DISPATCH_FLOATING_TYPES(a0.scalar_type(), "test2", [&] {
     CPU_tensor_apply3<scalar_t, scalar_t, scalar_t>(
         a0, a1, a2, [](scalar_t& y, const scalar_t& x, const scalar_t& z) {
           y = x * x + z;
@@ -79,7 +79,7 @@ void test(Type& type, IntArrayRef shape, int64_t a = 0, int64_t b = 1) {
     }
   });
 
-  AT_DISPATCH_FLOATING_TYPES(a0.type(), "test3", [&] {
+  AT_DISPATCH_FLOATING_TYPES(a0.scalar_type(), "test3", [&] {
     CPU_tensor_apply4<scalar_t, scalar_t, scalar_t, scalar_t>(
         a0,
         a1,
index 3ca0d3c..24c4da8 100644 (file)
@@ -101,7 +101,7 @@ TEST(TestScalar, TestScalar) {
   ASSERT_EQ(scalar_to_tensor(ones({}).item()).scalar_type(), kDouble);
 
   if (x.scalar_type() != ScalarType::Half) {
-    AT_DISPATCH_ALL_TYPES(x.type(), "foo", [&] {
+    AT_DISPATCH_ALL_TYPES(x.scalar_type(), "foo", [&] {
       scalar_t s = 1;
       std::stringstream ss;
       ASSERT_NO_THROW(
index e238071..e07b823 100755 (executable)
@@ -338,7 +338,7 @@ class TestCppExtension(common.TestCase):
 
         torch::Tensor half_test(torch::Tensor input) {
             auto output = torch::empty(1, input.options().dtype(torch::kFloat));
-            AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "half_test", [&] {
+            AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "half_test", [&] {
                 half_test_kernel<scalar_t><<<1, 1>>>(
                     input.data<scalar_t>(),
                     output.data<float>());
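
The final hunks below (the THPFInfo/THPIInfo dtype-info bindings) show the other direction of the signature change: self->type there is already a bare at::ScalarType, so it no longer needs to be wrapped into a CPU Type via at::CPU(...) before being handed to the dispatch macro. A standalone sketch of dispatching on a ScalarType value directly, using a hypothetical helper name:

    #include <ATen/ATen.h>
    #include <ATen/Dispatch.h>
    #include <limits>

    // Hypothetical helper (not part of the patch): machine epsilon for a
    // floating-point ScalarType, obtained by dispatching on the enum directly.
    double eps_for(at::ScalarType st) {
      return AT_DISPATCH_FLOATING_TYPES(st, "eps_for", [&]() -> double {
        return static_cast<double>(std::numeric_limits<scalar_t>::epsilon());
      });
    }

    // eps_for(at::kFloat)  ~= 1.19e-07
    // eps_for(at::kDouble) ~= 2.22e-16
    // eps_for(at::kLong) would hit the macro's "not implemented for 'Long'" error.
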
index 8b853e5..d90bdfd 100644 (file)
@@ -119,7 +119,7 @@ static PyObject* THPDTypeInfo_bits(THPDTypeInfo* self, void*) {
 
 static PyObject* THPFInfo_eps(THPFInfo* self, void*) {
   return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
-      at::CPU(self->type), "epsilon", [] {
+      self->type, "epsilon", [] {
         return PyFloat_FromDouble(
             std::numeric_limits<
                 at::scalar_value_type<scalar_t>::type>::epsilon());
@@ -127,33 +127,33 @@ static PyObject* THPFInfo_eps(THPFInfo* self, void*) {
 }
 
 static PyObject* THPFInfo_max(THPFInfo* self, void*) {
-  return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(at::CPU(self->type), "max", [] {
+  return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self->type, "max", [] {
     return PyFloat_FromDouble(
         std::numeric_limits<at::scalar_value_type<scalar_t>::type>::max());
   });
 }
 
 static PyObject* THPFInfo_min(THPFInfo* self, void*) {
-  return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(at::CPU(self->type), "min", [] {
+  return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self->type, "min", [] {
     return PyFloat_FromDouble(
         std::numeric_limits<at::scalar_value_type<scalar_t>::type>::lowest());
   });
 }
 
 static PyObject* THPIInfo_max(THPFInfo* self, void*) {
-  return AT_DISPATCH_INTEGRAL_TYPES(at::CPU(self->type), "max", [] {
+  return AT_DISPATCH_INTEGRAL_TYPES(self->type, "max", [] {
     return THPUtils_packInt64(std::numeric_limits<scalar_t>::max());
   });
 }
 
 static PyObject* THPIInfo_min(THPFInfo* self, void*) {
-  return AT_DISPATCH_INTEGRAL_TYPES(at::CPU(self->type), "min", [] {
+  return AT_DISPATCH_INTEGRAL_TYPES(self->type, "min", [] {
     return THPUtils_packInt64(std::numeric_limits<scalar_t>::lowest());
   });
 }
 
 static PyObject* THPFInfo_tiny(THPFInfo* self, void*) {
-  return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(at::CPU(self->type), "min", [] {
+  return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self->type, "min", [] {
     return PyFloat_FromDouble(
         std::numeric_limits<at::scalar_value_type<scalar_t>::type>::min());
   });