Simplify argument test cases (#19593)
authorSebastian Messmer <messmer@fb.com>
Tue, 23 Apr 2019 19:43:24 +0000 (12:43 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 23 Apr 2019 19:58:35 +0000 (12:58 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19593

Removes a lot of duplication

Reviewed By: dzhulgakov

Differential Revision: D15039887

fbshipit-source-id: e90fe024b84220dd337fdd314d8f7e3620baec28

aten/src/ATen/core/op_registration/op_registration_test.cpp

index aabea77..3122421 100644 (file)
@@ -248,7 +248,16 @@ struct ArgTypeTestKernel final : OperatorKernel {
     return output_;
   }
 
-  static void test(InputType input, std::function<void(const InputType&)> inputExpectation, OutputType output, std::function<void(const IValue&)> outputExpectation, const std::string& schema = "") {
+  static void test(InputType input, std::function<void(const InputType&)> inputExpectation, OutputType output, std::function<void(const IValue&)> outputExpectation, const std::string& schema) {
+    // test with explicitly specified schema
+    test_(input, inputExpectation, output, outputExpectation, schema);
+
+    // test with inferred schema
+    test_(input, inputExpectation, output, outputExpectation, "");
+  }
+
+private:
+  static void test_(InputType input, std::function<void(const InputType&)> inputExpectation, OutputType output, std::function<void(const IValue&)> outputExpectation, const std::string& schema = "") {
     auto registry = c10::RegisterOperators().op("_test::my_op" + schema, kernel<ArgTypeTestKernel>(input, std::move(inputExpectation), std::move(output)));
     auto op = Dispatcher::singleton().findSchema("_test::my_op", "");
     ASSERT_TRUE(op.has_value()); // assert schema is registered
@@ -257,7 +266,6 @@ struct ArgTypeTestKernel final : OperatorKernel {
     outputExpectation(actualOutput[0]);
   }
 
-private:
   InputType input_;
   std::function<void(const InputType&)> inputExpectation_;
   OutputType output_;
@@ -270,44 +278,26 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) {
   // primitive types
   ArgTypeTestKernel<double>::test(
     1.5, [] (const double& v) {EXPECT_EQ(1.5, v);},
-    2.5, [] (const IValue& v) {EXPECT_EQ(2.5, v.toDouble());});
-  ArgTypeTestKernel<double>::test(
-    1.5, [] (const double& v) {EXPECT_EQ(1.5, v);},
     2.5, [] (const IValue& v) {EXPECT_EQ(2.5, v.toDouble());},
     "(float a) -> float");
   ArgTypeTestKernel<int64_t>::test(
     1, [] (const int64_t& v) {EXPECT_EQ(1, v);},
-    2, [] (const IValue& v) {EXPECT_EQ(2, v.toInt());});
-  ArgTypeTestKernel<int64_t>::test(
-    1, [] (const int64_t& v) {EXPECT_EQ(1, v);},
     2, [] (const IValue& v) {EXPECT_EQ(2, v.toInt());},
     "(int a) -> int");
   ArgTypeTestKernel<bool>::test(
     true, [] (const bool& v) {EXPECT_EQ(true, v);},
-    false, [] (const IValue& v) {EXPECT_EQ(false, v.toBool());});
-  ArgTypeTestKernel<bool>::test(
-    true, [] (const bool& v) {EXPECT_EQ(true, v);},
     false, [] (const IValue& v) {EXPECT_EQ(false, v.toBool());},
     "(bool a) -> bool");
   ArgTypeTestKernel<bool>::test(
     false, [] (const bool& v) {EXPECT_EQ(false, v);},
-    true, [] (const IValue& v) {EXPECT_EQ(true, v.toBool());});
-  ArgTypeTestKernel<bool>::test(
-    false, [] (const bool& v) {EXPECT_EQ(false, v);},
     true, [] (const IValue& v) {EXPECT_EQ(true, v.toBool());},
     "(bool a) -> bool");
   ArgTypeTestKernel<std::string>::test(
     "string1", [] (const std::string& v) {EXPECT_EQ("string1", v);},
-    "string2", [] (const IValue& v) {EXPECT_EQ("string2", v.toString()->string());});
-  ArgTypeTestKernel<std::string>::test(
-    "string1", [] (const std::string& v) {EXPECT_EQ("string1", v);},
     "string2", [] (const IValue& v) {EXPECT_EQ("string2", v.toString()->string());},
     "(str a) -> str");
   ArgTypeTestKernel<Tensor>::test(
     dummyTensor(TensorType1()), [] (const Tensor& v) {EXPECT_EQ(TensorType1(), v.type_id());},
-    dummyTensor(TensorType2()), [] (const IValue& v) {EXPECT_EQ(TensorType2(), v.toTensor().type_id());});
-  ArgTypeTestKernel<Tensor>::test(
-    dummyTensor(TensorType1()), [] (const Tensor& v) {EXPECT_EQ(TensorType1(), v.type_id());},
     dummyTensor(TensorType2()), [] (const IValue& v) {EXPECT_EQ(TensorType2(), v.toTensor().type_id());},
     "(Tensor a) -> Tensor");
 
@@ -315,44 +305,26 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) {
   // optional types (with has_value() == true)
   ArgTypeTestKernel<c10::optional<double>>::test(
     c10::optional<double>(1.5), [] (const c10::optional<double>& v) {EXPECT_EQ(1.5, v.value());},
-    c10::optional<double>(2.5), [] (const IValue& v) {EXPECT_EQ(2.5, v.toDouble());});
-  ArgTypeTestKernel<c10::optional<double>>::test(
-    c10::optional<double>(1.5), [] (const c10::optional<double>& v) {EXPECT_EQ(1.5, v.value());},
     c10::optional<double>(2.5), [] (const IValue& v) {EXPECT_EQ(2.5, v.toDouble());},
     "(float? a) -> float?");
   ArgTypeTestKernel<c10::optional<int64_t>>::test(
     c10::optional<int64_t>(1), [] (const c10::optional<int64_t>& v) {EXPECT_EQ(1, v.value());},
-    c10::optional<int64_t>(2), [] (const IValue& v) {EXPECT_EQ(2, v.toInt());});
-  ArgTypeTestKernel<c10::optional<int64_t>>::test(
-    c10::optional<int64_t>(1), [] (const c10::optional<int64_t>& v) {EXPECT_EQ(1, v.value());},
     c10::optional<int64_t>(2), [] (const IValue& v) {EXPECT_EQ(2, v.toInt());},
     "(int? a) -> int?");
   ArgTypeTestKernel<c10::optional<bool>>::test(
     c10::optional<bool>(true), [] (const c10::optional<bool>& v) {EXPECT_EQ(true, v.value());},
-    c10::optional<bool>(false), [] (const IValue& v) {EXPECT_EQ(false, v.toBool());});
-  ArgTypeTestKernel<c10::optional<bool>>::test(
-    c10::optional<bool>(true), [] (const c10::optional<bool>& v) {EXPECT_EQ(true, v.value());},
     c10::optional<bool>(false), [] (const IValue& v) {EXPECT_EQ(false, v.toBool());},
     "(bool? a) -> bool?");
   ArgTypeTestKernel<c10::optional<bool>>::test(
     c10::optional<bool>(false), [] (const c10::optional<bool>& v) {EXPECT_EQ(false, v.value());},
-    c10::optional<bool>(true), [] (const IValue& v) {EXPECT_EQ(true, v.toBool());});
-  ArgTypeTestKernel<c10::optional<bool>>::test(
-    c10::optional<bool>(false), [] (const c10::optional<bool>& v) {EXPECT_EQ(false, v.value());},
     c10::optional<bool>(true), [] (const IValue& v) {EXPECT_EQ(true, v.toBool());},
     "(bool? a) -> bool?");
   ArgTypeTestKernel<c10::optional<std::string>>::test(
     c10::optional<std::string>("string1"), [] (const c10::optional<std::string>& v) {EXPECT_EQ("string1", v.value());},
-    c10::optional<std::string>("string2"), [] (const IValue& v) {EXPECT_EQ("string2", v.toString()->string());});
-  ArgTypeTestKernel<c10::optional<std::string>>::test(
-    c10::optional<std::string>("string1"), [] (const c10::optional<std::string>& v) {EXPECT_EQ("string1", v.value());},
     c10::optional<std::string>("string2"), [] (const IValue& v) {EXPECT_EQ("string2", v.toString()->string());},
     "(str? a) -> str?");
   ArgTypeTestKernel<c10::optional<Tensor>>::test(
     c10::optional<Tensor>(dummyTensor(TensorType1())), [] (const c10::optional<Tensor>& v) {EXPECT_EQ(TensorType1(), v.value().type_id());},
-    c10::optional<Tensor>(dummyTensor(TensorType2())), [] (const IValue& v) {EXPECT_EQ(TensorType2(), v.toTensor().type_id());});
-  ArgTypeTestKernel<c10::optional<Tensor>>::test(
-    c10::optional<Tensor>(dummyTensor(TensorType1())), [] (const c10::optional<Tensor>& v) {EXPECT_EQ(TensorType1(), v.value().type_id());},
     c10::optional<Tensor>(dummyTensor(TensorType2())), [] (const IValue& v) {EXPECT_EQ(TensorType2(), v.toTensor().type_id());},
     "(Tensor? a) -> Tensor?");
 
@@ -360,44 +332,26 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) {
   // optional types (with has_value() == false)
   ArgTypeTestKernel<c10::optional<double>>::test(
     c10::optional<double>(), [] (const c10::optional<double>& v) {EXPECT_FALSE(v.has_value());},
-    c10::optional<double>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());});
-  ArgTypeTestKernel<c10::optional<double>>::test(
-    c10::optional<double>(), [] (const c10::optional<double>& v) {EXPECT_FALSE(v.has_value());},
     c10::optional<double>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
     "(float? a) -> float?");
   ArgTypeTestKernel<c10::optional<int64_t>>::test(
     c10::optional<int64_t>(), [] (const c10::optional<int64_t>& v) {EXPECT_FALSE(v.has_value());},
-    c10::optional<int64_t>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());});
-  ArgTypeTestKernel<c10::optional<int64_t>>::test(
-    c10::optional<int64_t>(), [] (const c10::optional<int64_t>& v) {EXPECT_FALSE(v.has_value());},
     c10::optional<int64_t>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
     "(int? a) -> int?");
   ArgTypeTestKernel<c10::optional<bool>>::test(
     c10::optional<bool>(), [] (const c10::optional<bool>& v) {EXPECT_FALSE(v.has_value());},
-    c10::optional<bool>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());});
-  ArgTypeTestKernel<c10::optional<bool>>::test(
-    c10::optional<bool>(), [] (const c10::optional<bool>& v) {EXPECT_FALSE(v.has_value());},
     c10::optional<bool>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
     "(bool? a) -> bool?");
   ArgTypeTestKernel<c10::optional<bool>>::test(
     c10::optional<bool>(), [] (const c10::optional<bool>& v) {EXPECT_FALSE(v.has_value());},
-    c10::optional<bool>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());});
-  ArgTypeTestKernel<c10::optional<bool>>::test(
-    c10::optional<bool>(), [] (const c10::optional<bool>& v) {EXPECT_FALSE(v.has_value());},
     c10::optional<bool>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
     "(bool? a) -> bool?");
   ArgTypeTestKernel<c10::optional<std::string>>::test(
     c10::optional<std::string>(), [] (const c10::optional<std::string>& v) {EXPECT_FALSE(v.has_value());},
-    c10::optional<std::string>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());});
-  ArgTypeTestKernel<c10::optional<std::string>>::test(
-    c10::optional<std::string>(), [] (const c10::optional<std::string>& v) {EXPECT_FALSE(v.has_value());},
     c10::optional<std::string>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
     "(str? a) -> str?");
   ArgTypeTestKernel<c10::optional<Tensor>>::test(
     c10::optional<Tensor>(), [] (const c10::optional<Tensor>& v) {EXPECT_FALSE(v.has_value());},
-    c10::optional<Tensor>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());});
-  ArgTypeTestKernel<c10::optional<Tensor>>::test(
-    c10::optional<Tensor>(), [] (const c10::optional<Tensor>& v) {EXPECT_FALSE(v.has_value());},
     c10::optional<Tensor>(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
     "(Tensor? a) -> Tensor?");
 
@@ -405,40 +359,25 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) {
   // list types (with empty list)
   ArgTypeTestKernel<c10::ArrayRef<double>, std::vector<double>>::test(
     c10::ArrayRef<double>(), [] (c10::ArrayRef<double> v) {EXPECT_EQ(0, v.size());},
-    std::vector<double>(), [] (const IValue& v) {EXPECT_EQ(0, v.toDoubleListRef().size());});
-  ArgTypeTestKernel<c10::ArrayRef<double>, std::vector<double>>::test(
-    c10::ArrayRef<double>(), [] (c10::ArrayRef<double> v) {EXPECT_EQ(0, v.size());},
     std::vector<double>(), [] (const IValue& v) {EXPECT_EQ(0, v.toDoubleListRef().size());},
     "(float[] a) -> float[]");
   ArgTypeTestKernel<c10::ArrayRef<int64_t>, std::vector<int64_t>>::test(
     c10::ArrayRef<int64_t>(), [] (c10::ArrayRef<int64_t> v) {EXPECT_EQ(0, v.size());},
-    std::vector<int64_t>(), [] (const IValue& v) {EXPECT_EQ(0, v.toIntListRef().size());});
-  ArgTypeTestKernel<c10::ArrayRef<int64_t>, std::vector<int64_t>>::test(
-    c10::ArrayRef<int64_t>(), [] (c10::ArrayRef<int64_t> v) {EXPECT_EQ(0, v.size());},
     std::vector<int64_t>(), [] (const IValue& v) {EXPECT_EQ(0, v.toIntListRef().size());},
     "(int[] a) -> int[]");
   // TODO Converting std::vector<bool> to ArrayRef<bool> doesn't work, so we
   //      need to find an alternative
   // ArgTypeTestKernel<c10::ArrayRef<bool>, std::vector<bool>>::test(
   //   c10::ArrayRef<bool>(), [] (c10::ArrayRef<bool> v) {EXPECT_EQ(0, v.size());},
-  //   std::vector<bool>(), [] (const IValue& v) {EXPECT_EQ(0, v.toBoolListRef().size());});
-  // ArgTypeTestKernel<c10::ArrayRef<bool>, std::vector<bool>>::test(
-  //   c10::ArrayRef<bool>(), [] (c10::ArrayRef<bool> v) {EXPECT_EQ(0, v.size());},
   //   std::vector<bool>(), [] (const IValue& v) {EXPECT_EQ(0, v.toBoolListRef().size());},
   //   "(bool[] a) -> bool[]");
   // ArgTypeTestKernel<c10::ArrayRef<bool>, std::vector<bool>>::test(
   //   c10::ArrayRef<bool>(), [] (c10::ArrayRef<bool> v) {EXPECT_EQ(0, v.size());},
-  //   std::vector<bool>(), [] (const IValue& v) {EXPECT_EQ(0, v.toBoolListRef().size());});
-  // ArgTypeTestKernel<c10::ArrayRef<bool>, std::vector<bool>>::test(
-  //   c10::ArrayRef<bool>(), [] (c10::ArrayRef<bool> v) {EXPECT_EQ(0, v.size());},
   //   std::vector<bool>(), [] (const IValue& v) {EXPECT_EQ(0, v.toBoolListRef().size());},
   //   "(bool[] a) -> bool[]");
   // TODO We currently don't support str[] (i.e. string list) as type. Do we want to?
   // ArgTypeTestKernel<c10::ArrayRef<std::string>, std::vector<std::string>>::test(
   //   c10::ArrayRef<std::string>(), [] (c10::ArrayRef<std::string> v) {EXPECT_EQ(0, v.size());},
-  //   std::vector<std::string>(), [] (const IValue& v) {EXPECT_EQ(0, v.toStringListRef().size());});
-  // ArgTypeTestKernel<c10::ArrayRef<std::string>, std::vector<std::string>>::test(
-  //   c10::ArrayRef<std::string>(), [] (c10::ArrayRef<std::string> v) {EXPECT_EQ(0, v.size());},
   //   std::vector<std::string>(), [] (const IValue& v) {EXPECT_EQ(0, v.toStringListRef().size());},
   //   "(str[] a) -> str[]");
 
@@ -446,16 +385,10 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) {
   // list types (with non-empty list)
   ArgTypeTestKernel<c10::ArrayRef<double>, std::vector<double>>::test(
     c10::ArrayRef<double>({1.5, 2.5}), [] (c10::ArrayRef<double> v) {EXPECT_EQ(c10::ArrayRef<double>({1.5, 2.5}), v);},
-    std::vector<double>({3.5, 4.5}), [] (const IValue& v) {EXPECT_EQ(std::vector<double>({3.5, 4.5}), v.toDoubleListRef());});
-  ArgTypeTestKernel<c10::ArrayRef<double>, std::vector<double>>::test(
-    c10::ArrayRef<double>({1.5, 2.5}), [] (c10::ArrayRef<double> v) {EXPECT_EQ(c10::ArrayRef<double>({1.5, 2.5}), v);},
     std::vector<double>({3.5, 4.5}), [] (const IValue& v) {EXPECT_EQ(std::vector<double>({3.5, 4.5}), v.toDoubleListRef());},
     "(float[] a) -> float[]");
   ArgTypeTestKernel<c10::ArrayRef<int64_t>, std::vector<int64_t>>::test(
     c10::ArrayRef<int64_t>({1, 2}), [] (c10::ArrayRef<int64_t> v) {EXPECT_EQ(c10::ArrayRef<int64_t>({1, 2}), v);},
-    std::vector<int64_t>({3, 4}), [] (const IValue& v) {EXPECT_EQ(std::vector<int64_t>({3, 4}), v.toIntListRef());});
-  ArgTypeTestKernel<c10::ArrayRef<int64_t>, std::vector<int64_t>>::test(
-    c10::ArrayRef<int64_t>({1, 2}), [] (c10::ArrayRef<int64_t> v) {EXPECT_EQ(c10::ArrayRef<int64_t>({1, 2}), v);},
     std::vector<int64_t>({3, 4}), [] (const IValue& v) {EXPECT_EQ(std::vector<int64_t>({3, 4}), v.toIntListRef());},
     "(int[] a) -> int[]");
   // TODO When fixing bool[] and str[] (see above), also add them here
@@ -469,44 +402,24 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) {
       EXPECT_EQ(2, v.toTensorListRef().size());
       EXPECT_EQ(TensorType2(), v.toTensorListRef()[0].type_id());
       EXPECT_EQ(TensorType1(), v.toTensorListRef()[1].type_id());
-    });
-  ArgTypeTestKernel<c10::ArrayRef<Tensor>, std::vector<Tensor>>::test(
-    c10::ArrayRef<Tensor>({dummyTensor(TensorType1()), dummyTensor(TensorType2())}), [] (c10::ArrayRef<Tensor> v) {
-      EXPECT_EQ(2, v.size());
-      EXPECT_EQ(TensorType1(), v[0].type_id());
-      EXPECT_EQ(TensorType2(), v[1].type_id());
-    },
-    std::vector<Tensor>({dummyTensor(TensorType2()), dummyTensor(TensorType1())}), [] (const IValue& v) {
-      EXPECT_EQ(2, v.toTensorListRef().size());
-      EXPECT_EQ(TensorType2(), v.toTensorListRef()[0].type_id());
-      EXPECT_EQ(TensorType1(), v.toTensorListRef()[1].type_id());
     },
     "(Tensor[] a) -> Tensor[]");
 
   // Test optional of list (with nullopt)
   ArgTypeTestKernel<c10::optional<c10::ArrayRef<int64_t>>, c10::optional<std::vector<int64_t>>>::test(
     c10::optional<c10::ArrayRef<int64_t>>(c10::nullopt), [] (c10::optional<c10::ArrayRef<int64_t>> v) {EXPECT_FALSE(v.has_value());},
-    c10::optional<std::vector<int64_t>>(c10::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());});
-  ArgTypeTestKernel<c10::optional<c10::ArrayRef<int64_t>>, c10::optional<std::vector<int64_t>>>::test(
-    c10::optional<c10::ArrayRef<int64_t>>(c10::nullopt), [] (c10::optional<c10::ArrayRef<int64_t>> v) {EXPECT_FALSE(v.has_value());},
     c10::optional<std::vector<int64_t>>(c10::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());},
     "(int[]? a) -> int[]?");
 
   // Test optional of list (with empty list)
   ArgTypeTestKernel<c10::optional<c10::ArrayRef<int64_t>>, c10::optional<std::vector<int64_t>>>::test(
     c10::optional<c10::ArrayRef<int64_t>>(c10::ArrayRef<int64_t>{}), [] (c10::optional<c10::ArrayRef<int64_t>> v) {EXPECT_EQ(0, v.value().size());},
-    c10::optional<std::vector<int64_t>>(std::vector<int64_t>{}), [] (const IValue& v) {EXPECT_EQ(0, v.toIntListRef().size());});
-  ArgTypeTestKernel<c10::optional<c10::ArrayRef<int64_t>>, c10::optional<std::vector<int64_t>>>::test(
-    c10::optional<c10::ArrayRef<int64_t>>(c10::ArrayRef<int64_t>{}), [] (c10::optional<c10::ArrayRef<int64_t>> v) {EXPECT_EQ(0, v.value().size());},
     c10::optional<std::vector<int64_t>>(std::vector<int64_t>{}), [] (const IValue& v) {EXPECT_EQ(0, v.toIntListRef().size());},
     "(int[]? a) -> int[]?");
 
   // Test optional of list (with values)
   ArgTypeTestKernel<c10::optional<c10::ArrayRef<int64_t>>, c10::optional<std::vector<int64_t>>>::test(
     c10::optional<c10::ArrayRef<int64_t>>({1, 2}), [] (c10::optional<c10::ArrayRef<int64_t>> v) {EXPECT_EQ(c10::ArrayRef<int64_t>({1, 2}), v.value());},
-    c10::optional<std::vector<int64_t>>({3, 4}), [] (const IValue& v) {EXPECT_EQ(std::vector<int64_t>({3, 4}), v.toIntListRef());});
-  ArgTypeTestKernel<c10::optional<c10::ArrayRef<int64_t>>, c10::optional<std::vector<int64_t>>>::test(
-    c10::optional<c10::ArrayRef<int64_t>>({1, 2}), [] (c10::optional<c10::ArrayRef<int64_t>> v) {EXPECT_EQ(c10::ArrayRef<int64_t>({1, 2}), v.value());},
     c10::optional<std::vector<int64_t>>({3, 4}), [] (const IValue& v) {EXPECT_EQ(std::vector<int64_t>({3, 4}), v.toIntListRef());},
     "(int[]? a) -> int[]?");