Add tests for argument types (#19290)
authorSebastian Messmer <messmer@fb.com>
Fri, 19 Apr 2019 00:16:58 +0000 (17:16 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 19 Apr 2019 00:20:13 +0000 (17:20 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19290

Add test cases for the supported argument types
And TODOs for some unsupported ones that we might want to support.

Reviewed By: dzhulgakov

Differential Revision: D14931920

fbshipit-source-id: c47bbb295a54ac9dc62569bf5c273368c834392c

aten/src/ATen/core/op_registration/op_registration_test.cpp
aten/src/ATen/core/op_registration/test_helpers.h
c10/util/ArrayRef.h

index 4f46b30..615555b 100644 (file)
@@ -15,6 +15,7 @@ using c10::OperatorKernel;
 using c10::kernel;
 using c10::dispatchKey;
 using c10::Dispatcher;
+using c10::IValue;
 using at::Tensor;
 
 namespace {
@@ -219,4 +220,271 @@ TEST(OperatorRegistrationTest, givenOpWithMultipleKernels_whenKernelsHaveSameDis
   }, "Tried to register multiple kernels with same dispatch key");
 }
 
+
+/**
+ * This is used to check that a given type works correctly when passed as input
+ * to or as output from a kernel.
+ *
+ * Call ArgTypeTestKernel<Input, Output>::test(input, inputExpectation, output, outputExpectation, schema)
+ * to test that a kernel with `Input` as input type and `Output` as output types,
+ * when called with `input` fulfills `inputExpectation` inside the kernel, then
+ * returns `output` and the returned value fulfills `outputExpectation`.
+ *
+ * `inputExpectation` and `outputExpectation` should be lambdas that run
+ * googletest expect macros (or use other ways to assert the expectation is met).
+ *
+ * Optionally, you can specify the argument list part of a function schema
+ * (e.g. "(Tensor a) -> Tensor") as an additional argument to use when
+ * registering the kernel. In this case, the operator registration logic will
+ * check that the kernel function signature matches the one you specified.
+ */
+template<class InputType, class OutputType = InputType>
+struct ArgTypeTestKernel final : OperatorKernel {
+  explicit ArgTypeTestKernel(InputType input, std::function<void(const InputType&)> inputExpectation, OutputType output)
+  : input_(std::move(input)), inputExpectation_(std::move(inputExpectation)), output_(std::move(output)) {}
+
+  OutputType operator()(InputType input) const {
+    inputExpectation_(std::move(input));
+    return output_;
+  }
+
+  static void test(InputType input, std::function<void(const InputType&)> inputExpectation, OutputType output, std::function<void(const IValue&)> outputExpectation, const std::string& schema = "") {
+    auto registry = c10::RegisterOperators().op("_test::my_op" + schema, kernel<ArgTypeTestKernel>(input, std::move(inputExpectation), std::move(output)));
+    auto op = Dispatcher::singleton().findSchema("_test::my_op", "");
+    ASSERT_TRUE(op.has_value()); // assert schema is registered
+    auto actualOutput = callOp(*op, std::move(input));
+    EXPECT_EQ(1, actualOutput.size());
+    outputExpectation(actualOutput[0]);
+  }
+
+private:
+  InputType input_;
+  std::function<void(const InputType&)> inputExpectation_;
+  OutputType output_;
+  std::string schema_;
+};
+
+TEST(OperatorRegistrationTest, testAvailableArgTypes) {
+  // primitive types
+  ArgTypeTestKernel<double>::test(
+    1.5, [] (const double& v) {EXPECT_EQ(1.5, v);},
+    2.5, [] (const IValue& v) {EXPECT_EQ(2.5, v.toDouble());});
+  ArgTypeTestKernel<double>::test(
+    1.5, [] (const double& v) {EXPECT_EQ(1.5, v);},
+    2.5, [] (const IValue& v) {EXPECT_EQ(2.5, v.toDouble());},
+    "(float a) -> float");
+  ArgTypeTestKernel<int64_t>::test(
+    1, [] (const int64_t& v) {EXPECT_EQ(1, v);},
+    2, [] (const IValue& v) {EXPECT_EQ(2, v.toInt());});
+  ArgTypeTestKernel<int64_t>::test(
+    1, [] (const int64_t& v) {EXPECT_EQ(1, v);},
+    2, [] (const IValue& v) {EXPECT_EQ(2, v.toInt());},
+    "(int a) -> int");
+  ArgTypeTestKernel<bool>::test(
+    true, [] (const bool& v) {EXPECT_EQ(true, v);},
+    false, [] (const IValue& v) {EXPECT_EQ(false, v.toBool());});
+  ArgTypeTestKernel<bool>::test(
+    true, [] (const bool& v) {EXPECT_EQ(true, v);},
+    false, [] (const IValue& v) {EXPECT_EQ(false, v.toBool());},
+    "(bool a) -> bool");
+  ArgTypeTestKernel<bool>::test(
+    false, [] (const bool& v) {EXPECT_EQ(false, v);},
+    true, [] (const IValue& v) {EXPECT_EQ(true, v.toBool());});
+  ArgTypeTestKernel<bool>::test(
+    false, [] (const bool& v) {EXPECT_EQ(false, v);},
+    true, [] (const IValue& v) {EXPECT_EQ(true, v.toBool());},
+    "(bool a) -> bool");
+  ArgTypeTestKernel<std::string>::test(
+    "string1", [] (const std::string& v) {EXPECT_EQ("string1", v);},
+    "string2", [] (const IValue& v) {EXPECT_EQ("string2", v.toString()->string());});
+  ArgTypeTestKernel<std::string>::test(
+    "string1", [] (const std::string& v) {EXPECT_EQ("string1", v);},
+    "string2", [] (const IValue& v) {EXPECT_EQ("string2", v.toString()->string());},
+    "(str a) -> str");
+  ArgTypeTestKernel<Tensor>::test(
+    dummyTensor(TensorType1()), [] (const Tensor& v) {EXPECT_EQ(TensorType1(), v.type_id());},
+    dummyTensor(TensorType2()), [] (const IValue& v) {EXPECT_EQ(TensorType2(), v.toTensor().type_id());});
+  ArgTypeTestKernel<Tensor>::test(
+    dummyTensor(TensorType1()), [] (const Tensor& v) {EXPECT_EQ(TensorType1(), v.type_id());},
+    dummyTensor(TensorType2()), [] (const IValue& v) {EXPECT_EQ(TensorType2(), v.toTensor().type_id());},
+    "(Tensor a) -> Tensor");
+
+
+  // optional types (with has_value() == true)
+  ArgTypeTestKernel<c10::optional<double>>::test(
+    c10::optional<double>(1.5), [] (const c10::optional<double>& v) {EXPECT_EQ(1.5, v.value());},
+    c10::optional<double>(2.5), [] (const IValue& v) {EXPECT_EQ(2.5, v.toDouble());});
+  ArgTypeTestKernel<c10::optional<double>>::test(
+    c10::optional<double>(1.5), [] (const c10::optional<double>& v) {EXPECT_EQ(1.5, v.value());},
+    c10::optional<double>(2.5), [] (const IValue& v) {EXPECT_EQ(2.5, v.toDouble());},
+    "(float? a) -> float?");
+  ArgTypeTestKernel<c10::optional<int64_t>>::test(
+    c10::optional<int64_t>(1), [] (const c10::optional<int64_t>& v) {EXPECT_EQ(1, v.value());},
+    c10::optional<int64_t>(2), [] (const IValue& v) {EXPECT_EQ(2, v.toInt());});
+  ArgTypeTestKernel<c10::optional<int64_t>>::test(
+    c10::optional<int64_t>(1), [] (const c10::optional<int64_t>& v) {EXPECT_EQ(1, v.value());},
+    c10::optional<int64_t>(2), [] (const IValue& v) {EXPECT_EQ(2, v.toInt());},
+    "(int? a) -> int?");
+  ArgTypeTestKernel<c10::optional<bool>>::test(
+    c10::optional<bool>(true), [] (const c10::optional<bool>& v) {EXPECT_EQ(true, v.value());},
+    c10::optional<bool>(false), [] (const IValue& v) {EXPECT_EQ(false, v.toBool());});
+  ArgTypeTestKernel<c10::optional<bool>>::test(
+    c10::optional<bool>(true), [] (const c10::optional<bool>& v) {EXPECT_EQ(true, v.value());},
+    c10::optional<bool>(false), [] (const IValue& v) {EXPECT_EQ(false, v.toBool());},
+    "(bool? a) -> bool?");
+  ArgTypeTestKernel<c10::optional<bool>>::test(
+    c10::optional<bool>(false), [] (const c10::optional<bool>& v) {EXPECT_EQ(false, v.value());},
+    c10::optional<bool>(true), [] (const IValue& v) {EXPECT_EQ(true, v.toBool());});
+  ArgTypeTestKernel<c10::optional<bool>>::test(
+    c10::optional<bool>(false), [] (const c10::optional<bool>& v) {EXPECT_EQ(false, v.value());},
+    c10::optional<bool>(true), [] (const IValue& v) {EXPECT_EQ(true, v.toBool());},
+    "(bool? a) -> bool?");
+  ArgTypeTestKernel<c10::optional<std::string>>::test(
+    c10::optional<std::string>("string1"), [] (const c10::optional<std::string>& v) {EXPECT_EQ("string1", v.value());},
+    c10::optional<std::string>("string2"), [] (const IValue& v) {EXPECT_EQ("string2", v.toString()->string());});
+  ArgTypeTestKernel<c10::optional<std::string>>::test(
+    c10::optional<std::string>("string1"), [] (const c10::optional<std::string>& v) {EXPECT_EQ("string1", v.value());},
+    c10::optional<std::string>("string2"), [] (const IValue& v) {EXPECT_EQ("string2", v.toString()->string());},
+    "(str? a) -> str?");
+  ArgTypeTestKernel<c10::optional<Tensor>>::test(
+    c10::optional<Tensor>(dummyTensor(TensorType1())), [] (const c10::optional<Tensor>& v) {EXPECT_EQ(TensorType1(), v.value().type_id());},
+    c10::optional<Tensor>(dummyTensor(TensorType2())), [] (const IValue& v) {EXPECT_EQ(TensorType2(), v.toTensor().type_id());});
+  ArgTypeTestKernel<c10::optional<Tensor>>::test(
+    c10::optional<Tensor>(dummyTensor(TensorType1())), [] (const c10::optional<Tensor>& v) {EXPECT_EQ(TensorType1(), v.value().type_id());},
+    c10::optional<Tensor>(dummyTensor(TensorType2())), [] (const IValue& v) {EXPECT_EQ(TensorType2(), v.toTensor().type_id());},
+    "(Tensor? a) -> Tensor?");
+
+
+  // optional types (with has_value() == false)
+  ArgTypeTestKernel<c10::optional<double>>::test(
+    c10::optional<double>(), [] (const c10::optional<double>& v) {EXPECT_FALSE(v.has_value());},
+    c10::optional<double>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());});
+  ArgTypeTestKernel<c10::optional<double>>::test(
+    c10::optional<double>(), [] (const c10::optional<double>& v) {EXPECT_FALSE(v.has_value());},
+    c10::optional<double>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
+    "(float? a) -> float?");
+  ArgTypeTestKernel<c10::optional<int64_t>>::test(
+    c10::optional<int64_t>(), [] (const c10::optional<int64_t>& v) {EXPECT_FALSE(v.has_value());},
+    c10::optional<int64_t>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());});
+  ArgTypeTestKernel<c10::optional<int64_t>>::test(
+    c10::optional<int64_t>(), [] (const c10::optional<int64_t>& v) {EXPECT_FALSE(v.has_value());},
+    c10::optional<int64_t>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
+    "(int? a) -> int?");
+  ArgTypeTestKernel<c10::optional<bool>>::test(
+    c10::optional<bool>(), [] (const c10::optional<bool>& v) {EXPECT_FALSE(v.has_value());},
+    c10::optional<bool>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());});
+  ArgTypeTestKernel<c10::optional<bool>>::test(
+    c10::optional<bool>(), [] (const c10::optional<bool>& v) {EXPECT_FALSE(v.has_value());},
+    c10::optional<bool>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
+    "(bool? a) -> bool?");
+  ArgTypeTestKernel<c10::optional<bool>>::test(
+    c10::optional<bool>(), [] (const c10::optional<bool>& v) {EXPECT_FALSE(v.has_value());},
+    c10::optional<bool>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());});
+  ArgTypeTestKernel<c10::optional<bool>>::test(
+    c10::optional<bool>(), [] (const c10::optional<bool>& v) {EXPECT_FALSE(v.has_value());},
+    c10::optional<bool>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
+    "(bool? a) -> bool?");
+  ArgTypeTestKernel<c10::optional<std::string>>::test(
+    c10::optional<std::string>(), [] (const c10::optional<std::string>& v) {EXPECT_FALSE(v.has_value());},
+    c10::optional<std::string>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());});
+  ArgTypeTestKernel<c10::optional<std::string>>::test(
+    c10::optional<std::string>(), [] (const c10::optional<std::string>& v) {EXPECT_FALSE(v.has_value());},
+    c10::optional<std::string>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
+    "(str? a) -> str?");
+  ArgTypeTestKernel<c10::optional<Tensor>>::test(
+    c10::optional<Tensor>(), [] (const c10::optional<Tensor>& v) {EXPECT_FALSE(v.has_value());},
+    c10::optional<Tensor>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());});
+  ArgTypeTestKernel<c10::optional<Tensor>>::test(
+    c10::optional<Tensor>(), [] (const c10::optional<Tensor>& v) {EXPECT_FALSE(v.has_value());},
+    c10::optional<Tensor>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
+    "(Tensor? a) -> Tensor?");
+
+
+  // list types (with empty list)
+  ArgTypeTestKernel<c10::ArrayRef<double>, std::vector<double>>::test(
+    c10::ArrayRef<double>(), [] (c10::ArrayRef<double> v) {EXPECT_EQ(0, v.size());},
+    std::vector<double>(), [] (const IValue& v) {EXPECT_EQ(0, v.toDoubleListRef().size());});
+  ArgTypeTestKernel<c10::ArrayRef<double>, std::vector<double>>::test(
+    c10::ArrayRef<double>(), [] (c10::ArrayRef<double> v) {EXPECT_EQ(0, v.size());},
+    std::vector<double>(), [] (const IValue& v) {EXPECT_EQ(0, v.toDoubleListRef().size());},
+    "(float[] a) -> float[]");
+  ArgTypeTestKernel<c10::ArrayRef<int64_t>, std::vector<int64_t>>::test(
+    c10::ArrayRef<int64_t>(), [] (c10::ArrayRef<int64_t> v) {EXPECT_EQ(0, v.size());},
+    std::vector<int64_t>(), [] (const IValue& v) {EXPECT_EQ(0, v.toIntListRef().size());});
+  ArgTypeTestKernel<c10::ArrayRef<int64_t>, std::vector<int64_t>>::test(
+    c10::ArrayRef<int64_t>(), [] (c10::ArrayRef<int64_t> v) {EXPECT_EQ(0, v.size());},
+    std::vector<int64_t>(), [] (const IValue& v) {EXPECT_EQ(0, v.toIntListRef().size());},
+    "(int[] a) -> int[]");
+  // TODO Converting std::vector<bool> to ArrayRef<bool> doesn't work, so we
+  //      need to find an alternative
+  // ArgTypeTestKernel<c10::ArrayRef<bool>, std::vector<bool>>::test(
+  //   c10::ArrayRef<bool>(), [] (c10::ArrayRef<bool> v) {EXPECT_EQ(0, v.size());},
+  //   std::vector<bool>(), [] (const IValue& v) {EXPECT_EQ(0, v.toBoolListRef().size());});
+  // ArgTypeTestKernel<c10::ArrayRef<bool>, std::vector<bool>>::test(
+  //   c10::ArrayRef<bool>(), [] (c10::ArrayRef<bool> v) {EXPECT_EQ(0, v.size());},
+  //   std::vector<bool>(), [] (const IValue& v) {EXPECT_EQ(0, v.toBoolListRef().size());},
+  //   "(bool[] a) -> bool[]");
+  // ArgTypeTestKernel<c10::ArrayRef<bool>, std::vector<bool>>::test(
+  //   c10::ArrayRef<bool>(), [] (c10::ArrayRef<bool> v) {EXPECT_EQ(0, v.size());},
+  //   std::vector<bool>(), [] (const IValue& v) {EXPECT_EQ(0, v.toBoolListRef().size());});
+  // ArgTypeTestKernel<c10::ArrayRef<bool>, std::vector<bool>>::test(
+  //   c10::ArrayRef<bool>(), [] (c10::ArrayRef<bool> v) {EXPECT_EQ(0, v.size());},
+  //   std::vector<bool>(), [] (const IValue& v) {EXPECT_EQ(0, v.toBoolListRef().size());},
+  //   "(bool[] a) -> bool[]");
+  // TODO We currently don't support str[] (i.e. string list) as type. Do we want to?
+  // ArgTypeTestKernel<c10::ArrayRef<std::string>, std::vector<std::string>>::test(
+  //   c10::ArrayRef<std::string>(), [] (c10::ArrayRef<std::string> v) {EXPECT_EQ(0, v.size());},
+  //   std::vector<std::string>(), [] (const IValue& v) {EXPECT_EQ(0, v.toStringListRef().size());});
+  // ArgTypeTestKernel<c10::ArrayRef<std::string>, std::vector<std::string>>::test(
+  //   c10::ArrayRef<std::string>(), [] (c10::ArrayRef<std::string> v) {EXPECT_EQ(0, v.size());},
+  //   std::vector<std::string>(), [] (const IValue& v) {EXPECT_EQ(0, v.toStringListRef().size());},
+  //   "(str[] a) -> str[]");
+
+
+  // list types (with non-empty list)
+  ArgTypeTestKernel<c10::ArrayRef<double>, std::vector<double>>::test(
+    c10::ArrayRef<double>({1.5, 2.5}), [] (c10::ArrayRef<double> v) {EXPECT_EQ(c10::ArrayRef<double>({1.5, 2.5}), v);},
+    std::vector<double>({3.5, 4.5}), [] (const IValue& v) {EXPECT_EQ(std::vector<double>({3.5, 4.5}), v.toDoubleListRef());});
+  ArgTypeTestKernel<c10::ArrayRef<double>, std::vector<double>>::test(
+    c10::ArrayRef<double>({1.5, 2.5}), [] (c10::ArrayRef<double> v) {EXPECT_EQ(c10::ArrayRef<double>({1.5, 2.5}), v);},
+    std::vector<double>({3.5, 4.5}), [] (const IValue& v) {EXPECT_EQ(std::vector<double>({3.5, 4.5}), v.toDoubleListRef());},
+    "(float[] a) -> float[]");
+  ArgTypeTestKernel<c10::ArrayRef<int64_t>, std::vector<int64_t>>::test(
+    c10::ArrayRef<int64_t>({1, 2}), [] (c10::ArrayRef<int64_t> v) {EXPECT_EQ(c10::ArrayRef<int64_t>({1, 2}), v);},
+    std::vector<int64_t>({3, 4}), [] (const IValue& v) {EXPECT_EQ(std::vector<int64_t>({3, 4}), v.toIntListRef());});
+  ArgTypeTestKernel<c10::ArrayRef<int64_t>, std::vector<int64_t>>::test(
+    c10::ArrayRef<int64_t>({1, 2}), [] (c10::ArrayRef<int64_t> v) {EXPECT_EQ(c10::ArrayRef<int64_t>({1, 2}), v);},
+    std::vector<int64_t>({3, 4}), [] (const IValue& v) {EXPECT_EQ(std::vector<int64_t>({3, 4}), v.toIntListRef());},
+    "(int[] a) -> int[]");
+  // TODO When fixing bool[] and str[] (see above), also add them here
+  ArgTypeTestKernel<c10::ArrayRef<Tensor>, std::vector<Tensor>>::test(
+    c10::ArrayRef<Tensor>({dummyTensor(TensorType1()), dummyTensor(TensorType2())}), [] (c10::ArrayRef<Tensor> v) {
+      EXPECT_EQ(2, v.size());
+      EXPECT_EQ(TensorType1(), v[0].type_id());
+      EXPECT_EQ(TensorType2(), v[1].type_id());
+    },
+    std::vector<Tensor>({dummyTensor(TensorType2()), dummyTensor(TensorType1())}), [] (const IValue& v) {
+      EXPECT_EQ(2, v.toTensorListRef().size());
+      EXPECT_EQ(TensorType2(), v.toTensorListRef()[0].type_id());
+      EXPECT_EQ(TensorType1(), v.toTensorListRef()[1].type_id());
+    });
+  ArgTypeTestKernel<c10::ArrayRef<Tensor>, std::vector<Tensor>>::test(
+    c10::ArrayRef<Tensor>({dummyTensor(TensorType1()), dummyTensor(TensorType2())}), [] (c10::ArrayRef<Tensor> v) {
+      EXPECT_EQ(2, v.size());
+      EXPECT_EQ(TensorType1(), v[0].type_id());
+      EXPECT_EQ(TensorType2(), v[1].type_id());
+    },
+    std::vector<Tensor>({dummyTensor(TensorType2()), dummyTensor(TensorType1())}), [] (const IValue& v) {
+      EXPECT_EQ(2, v.toTensorListRef().size());
+      EXPECT_EQ(TensorType2(), v.toTensorListRef()[0].type_id());
+      EXPECT_EQ(TensorType1(), v.toTensorListRef()[1].type_id());
+    },
+    "(Tensor[] a) -> Tensor[]");
+
+
+  // TODO Do we want to support list of optional / optional of list ?
+
+  // TODO Add tests for dict types
+}
+
 }
index ecb7dc8..a678741 100644 (file)
@@ -8,9 +8,38 @@
 #include <ATen/core/ivalue.h>
 #include <c10/core/CPUAllocator.h>
 
+namespace detail {
+// InputToIValue takes a value and converts it to an IValue to be put on a stack.
+template<class T>
+struct InputToIValue final {
+  template<class T_>
+  static c10::IValue call(T_&& v) {
+    return c10::IValue(std::forward<T_>(v));
+  }
+};
+template<class T>
+struct InputToIValue<c10::optional<T>> final {
+  template<class T_>
+  static c10::IValue call(T_&& v) {
+    if (v.has_value()) {
+      return c10::IValue(std::move(*v));
+    } else {
+      return c10::IValue();
+    }
+  }
+};
+template<class T>
+struct InputToIValue<c10::ArrayRef<T>> final {
+  template<class T_>
+  static c10::IValue call(T_&& v) {
+    return c10::IValue(v.vec());
+  }
+};
+}
+
 template<class... Inputs>
 inline std::vector<c10::IValue> makeStack(Inputs&&... inputs) {
-  return {std::forward<Inputs>(inputs)...};
+  return {detail::InputToIValue<c10::guts::decay_t<Inputs>>::call(std::forward<Inputs>(inputs))...};
 }
 
 inline at::Tensor dummyTensor(c10::TensorTypeId dispatch_key) {
index 6acb532..989e121 100644 (file)
@@ -80,9 +80,13 @@ class ArrayRef final {
       : Data(Vec.data()), Length(Vec.size()) {}
 
   /// Construct an ArrayRef from a std::vector.
+  // The enable_if stuff here makes sure that this isn't used for std::vector<bool>,
+  // because ArrayRef can't work on a std::vector<bool> bitfield.
   template <typename A>
   /* implicit */ ArrayRef(const std::vector<T, A>& Vec)
-      : Data(Vec.data()), Length(Vec.size()) {}
+      : Data(Vec.data()), Length(Vec.size()) {
+    static_assert(!std::is_same<T, bool>::value, "ArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield.");
+  }
 
   /// Construct an ArrayRef from a std::array
   template <size_t N>