Add backwards compatibility and other fixes to Dispatch macros. (#17996)
authorGregory Chanan <gchanan@fb.com>
Fri, 15 Mar 2019 21:16:22 +0000 (14:16 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 15 Mar 2019 21:21:46 +0000 (14:21 -0700)
Summary:
Changes:
1) https://github.com/pytorch/pytorch/pull/17527 changed dispatch macros to be ScalarType based instead of at::Type based.  This broke cpp extensions that relied on dispatch macros.  Since IMO these should be ScalarType based (and some extensions have already updated), we allow either at::Type or at::ScalarType to be passed, but passing at::Type will result in a deprecated warning.

2) Reintroduce macros that were deleted (AT_DISPATCH_ALL_TYPES_AND_HALF, AT_DISPATCH_COMPLEX_TYPES, AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX, AT_DISPATCH_ALL_TYPES_AND_COMPLEX); the AND_HALF ones now give a deprecated warning because there are more extensible macros that were introduced in their place.

3) Makes AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND into a ScalarType based macro (and updates usages).  This was the result of a logical merge conflicts.

4) Adds a new macro, C10_DEPRECATED_MESSAGE for passing a deprecated message to the compiler.  I didn't spend much time seeing if this can be enabled for versions before C++14.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17996

Reviewed By: ezyang

Differential Revision: D14446203

Pulled By: gchanan

fbshipit-source-id: 1da56e2e9c15aa8f913ebbf6bf1110c5b6dc375e

aten/src/ATen/Dispatch.h
aten/src/ATen/native/Scalar.cpp
c10/util/Deprecated.h

index 764cbc6..44a9a55 100644 (file)
     return __VA_ARGS__();                          \
   }
 
+namespace detail {
+
+inline at::ScalarType scalar_type(at::ScalarType s) {
+  return s;
+}
+
+C10_DEPRECATED_MESSAGE("passing at::Type to an AT_DISPATCH macro is deprecated, " \
+                       "pass an at::ScalarType instead")
+inline at::ScalarType scalar_type(const at::Type &t) {
+  return t.scalarType();
+}
+
+C10_DEPRECATED_MESSAGE("AT_DISPATCH_ALL_TYPES_AND_HALF is deprecated, " \
+                       "use AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ...) instead")
+inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF() {}
+
+C10_DEPRECATED_MESSAGE("AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX is deprecated, "            \
+                       "use AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Half, ...) " \
+                       "instead")
+inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
+
+}
+
 #define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)                          \
   [&] {                                                                      \
-    switch (TYPE) {                                                          \
+    const auto& the_type = TYPE;                                             \
+    at::ScalarType _st = ::detail::scalar_type(TYPE);                        \
+    switch (_st) {                                                           \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__)      \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__)        \
       default:                                                               \
-        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");      \
+        AT_ERROR(#NAME, " not implemented for '", toString(_st), "'");       \
     }                                                                        \
   }()
 
 #define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...)                 \
   [&] {                                                                      \
-    switch (TYPE) {                                                          \
+    const auto& the_type = TYPE;                                             \
+    at::ScalarType _st = ::detail::scalar_type(TYPE);                        \
+    switch (_st) {                                                           \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__)      \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__)        \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__)      \
       default:                                                               \
-        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");      \
+        AT_ERROR(#NAME, " not implemented for '", toString(_st), "'");       \
     }                                                                        \
   }()
 
 #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...)              \
   [&] {                                                                      \
-    switch (TYPE) {                                                          \
+    const auto& the_type = TYPE;                                             \
+    at::ScalarType _st = ::detail::scalar_type(TYPE);                        \
+    switch (_st) {                                                           \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__)      \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__)        \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__)      \
       AT_PRIVATE_CASE_TYPE(                                                  \
           at::ScalarType::ComplexHalf, std::complex<at::Half>, __VA_ARGS__)  \
       default:                                                               \
-        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+        AT_ERROR(#NAME, " not implemented for '", toString(_st), "'");       \
     }                                                                        \
   }()
 
 #define AT_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...)                          \
   [&] {                                                                      \
-    switch (TYPE) {                                                          \
+    const auto& the_type = TYPE;                                             \
+    at::ScalarType _st = ::detail::scalar_type(TYPE);                        \
+    switch (_st) {                                                           \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__)       \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__)        \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__)        \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__)       \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__)      \
       default:                                                               \
-        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");      \
+        AT_ERROR(#NAME, " not implemented for '", toString(_st), "'");       \
     }                                                                        \
   }()
 
 #define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...)                               \
   [&] {                                                                      \
-    switch (TYPE) {                                                          \
+    const auto& the_type = TYPE;                                             \
+    at::ScalarType _st = ::detail::scalar_type(TYPE);                        \
+    switch (_st) {                                                           \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__)       \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__)        \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__)      \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__)        \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__)        \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__)       \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__)      \
+      default:                                                               \
+        AT_ERROR(#NAME, " not implemented for '", toString(_st), "'");       \
+    }                                                                        \
+  }()
+
+#define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...)                      \
+  [&] {                                                                      \
+    detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF();                     \
+    const auto& the_type = TYPE;                                             \
+    at::ScalarType _st = ::detail::scalar_type(TYPE);                        \
+    switch (_st) {                                                           \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__)       \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__)        \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__)      \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__)        \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__)        \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__)       \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__)      \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__)      \
+      default:                                                               \
+        AT_ERROR(#NAME, " not implemented for '", toString(_st), "'");       \
+    }                                                                        \
+  }()
+
+#define AT_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...)                           \
+  [&] {                                                                      \
+    const auto& the_type = TYPE;                                             \
+    at::ScalarType _st = ::detail::scalar_type(TYPE);                        \
+    switch (_st) {                                                           \
+      AT_PRIVATE_CASE_TYPE(                                                  \
+          at::ScalarType::ComplexFloat, std::complex<float>, __VA_ARGS__)    \
+      AT_PRIVATE_CASE_TYPE(                                                  \
+          at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__)  \
+      default:                                                               \
+        AT_ERROR(#NAME, " not implemented for '", toString(_st), "'");       \
+    }                                                                        \
+  }()
+
+#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX(TYPE, NAME, ...)                   \
+  [&] {                                                                      \
+    const auto& the_type = TYPE;                                             \
+    at::ScalarType _st = ::detail::scalar_type(TYPE);                        \
+    switch (_st) {                                                           \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__)       \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__)        \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__)      \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__)        \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__)       \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__)      \
+      AT_PRIVATE_CASE_TYPE(                                                  \
+          at::ScalarType::ComplexFloat, std::complex<float>, __VA_ARGS__)    \
+      AT_PRIVATE_CASE_TYPE(                                                  \
+          at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__)  \
       default:                                                               \
-        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");      \
+        AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
     }                                                                        \
   }()
 
+#define AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX(TYPE, NAME, ...)          \
+  [&] {                                                                      \
+    detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX()          \
+    const auto& the_type = TYPE;                                             \
+    at::ScalarType _st = ::detail::scalar_type(TYPE);                        \
+    switch (_st) {                                                           \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__)       \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__)        \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__)      \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__)        \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__)        \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__)       \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__)      \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__)      \
+      AT_PRIVATE_CASE_TYPE(                                                  \
+          at::ScalarType::ComplexFloat, std::complex<float>, __VA_ARGS__)    \
+      AT_PRIVATE_CASE_TYPE(                                                  \
+          at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__)  \
+      default:                                                               \
+        AT_ERROR(#NAME, " not implemented for '", toString(_st), "'");       \
+    }                                                                        \
+  }()
+
+
 template <at::ScalarType N>
 struct MyTemplate;
 
@@ -107,8 +218,7 @@ struct MyTemplate<at::ScalarType::Bool> {
 
 #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...)        \
   [&] {                                                                                         \
-    const at::Type& the_type = TYPE;                                                            \
-    switch (the_type.scalarType()) {                                                            \
+    switch (TYPE) {                                                                             \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__)                          \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__)                           \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__)                         \
@@ -123,6 +233,6 @@ struct MyTemplate<at::ScalarType::Bool> {
       AT_PRIVATE_CASE_TYPE(                                                                     \
           at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__)                     \
       default:                                                                                  \
-        AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'");                    \
+        AT_ERROR(#NAME, " not implemented for '", TYPE, "'");                                   \
     }                                                                                           \
   }()
index 8e4ae10..918f4c3 100644 (file)
@@ -19,7 +19,7 @@ Scalar item(const Tensor& self) {
 Scalar _local_scalar_dense_cpu(const Tensor& self) {
   Scalar r;
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(
-    at::ScalarType::Half, at::ScalarType::Bool, self.type(), "_local_scalar_dense_cpu", [&] {
+    at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "_local_scalar_dense_cpu", [&] {
         scalar_t value = *self.data<scalar_t>();
         r = Scalar(value);
       });
index d2c3776..294fb3a 100644 (file)
 // portable way to declare something deprecated.
 #if defined(__cplusplus) && __cplusplus > 201402L
 # define C10_DEPRECATED [[deprecated]]
+# define C10_DEPRECATED_MESSAGE(message) [[deprecated(message)]]
 #elif defined(__GNUC__)
 # define C10_DEPRECATED __attribute__((deprecated))
+// TODO: is there some way to implement this?
+# define C10_DEPRECATED_MESSAGE(message) __attribute__((deprecated))
 #elif defined(_MSC_VER)
 # define C10_DEPRECATED __declspec(deprecated)
+// TODO: is there some way to implement this?
+# define C10_DEPRECATED_MESSAGE(message) __declspec(deprecated)
 #else
 # warning "You need to implement C10_DEPRECATED for this compiler"
 # define C10_DEPRECATED