Add a correctness check for C++ types to custom operators (#15247)
authorPeter Goldsborough <psag@fb.com>
Tue, 18 Dec 2018 00:07:14 +0000 (16:07 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 18 Dec 2018 00:17:27 +0000 (16:17 -0800)
Summary:
The JIT uses `int64_t` for its integer type and `double` for its floating point type, but users quite often want to write `int` or `float` and that currently fails in not-so-nice ways for custom ops. This PR adds a simple `static_assert` to catch these common failure cases.

zdevito
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15247

Differential Revision: D13493941

Pulled By: goldsborough

fbshipit-source-id: c1cd0d10ab5838c75f167c0bdb57e45a0bc1344e

torch/csrc/jit/custom_operator.h

index 051a7bf..33e3865 100644 (file)
@@ -9,10 +9,31 @@
 #include <c10/util/Metaprogramming.h>
 #include <c10/util/TypeList.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 namespace detail {
+/// Checks the static C++ type `T` for correctness to catch common error cases.
+template <typename T>
+void checkStaticTypes() {
+  // Give nice error messages for some of the common error cases.
+  // Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT
+  static_assert(
+      !std::is_integral<T>::value || std::is_same<T, int64_t>::value,
+      "INVALID TYPE: Only int64_t is supported as an integral argument type");
+  static_assert(
+      !std::is_same<T, float>::value,
+      "INVALID TYPE: float is not supported as an argument type, use double instead");
+}
+
+template <typename First, typename Second, typename... Rest>
+void checkStaticTypes() {
+  checkStaticTypes<First>();
+  checkStaticTypes<Second, Rest...>();
+}
+
 template <typename... Ts, size_t... Is>
 std::vector<Argument> createArgumentVectorFromTypes(Indices<Is...> indices) {
+  checkStaticTypes<decay_t<Ts>...>();
   // Arguments are named "_<index>"
   return {Argument("_" + std::to_string(Is), getTypePtr<decay_t<Ts>>())...};
 }
@@ -33,6 +54,7 @@ std::vector<Argument> createReturns(std::tuple<Ts...>* tuple) {
 /// Create a single-element `vector` for simple (non-tuple) return types.
 template <typename ReturnType>
 std::vector<Argument> createReturns(ReturnType*) {
+  checkStaticTypes<decay_t<ReturnType>>();
   return {Argument("_1", getTypePtr<decay_t<ReturnType>>())};
 }
 
@@ -50,9 +72,11 @@ std::vector<Argument> createArgumentVectorFromTraits(Indices<Is...> indices) {
 template <typename FunctionTraits>
 FunctionSchema createFunctionSchemaFromTraits(const std::string& name) {
   using ReturnType = typename FunctionTraits::return_type;
+
   auto arguments = createArgumentVectorFromTraits<FunctionTraits>(
       typename MakeIndices<FunctionTraits::number_of_parameters>::indices{});
   auto returns = createReturns(static_cast<ReturnType*>(nullptr));
+
   return {name, arguments, returns};
 }