From 5f82d59c0ae7d017bdb19a28b23fa1c19a0d6a9e Mon Sep 17 00:00:00 2001 From: Sebastian Messmer Date: Tue, 23 Apr 2019 12:43:24 -0700 Subject: [PATCH] Simplify argument test cases (#19593) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19593 Removes a lot of duplication Reviewed By: dzhulgakov Differential Revision: D15039887 fbshipit-source-id: e90fe024b84220dd337fdd314d8f7e3620baec28 --- .../core/op_registration/op_registration_test.cpp | 107 ++------------------- 1 file changed, 10 insertions(+), 97 deletions(-) diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp index aabea77..3122421 100644 --- a/aten/src/ATen/core/op_registration/op_registration_test.cpp +++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp @@ -248,7 +248,16 @@ struct ArgTypeTestKernel final : OperatorKernel { return output_; } - static void test(InputType input, std::function inputExpectation, OutputType output, std::function outputExpectation, const std::string& schema = "") { + static void test(InputType input, std::function inputExpectation, OutputType output, std::function outputExpectation, const std::string& schema) { + // test with explicitly specified schema + test_(input, inputExpectation, output, outputExpectation, schema); + + // test with inferred schema + test_(input, inputExpectation, output, outputExpectation, ""); + } + +private: + static void test_(InputType input, std::function inputExpectation, OutputType output, std::function outputExpectation, const std::string& schema = "") { auto registry = c10::RegisterOperators().op("_test::my_op" + schema, kernel(input, std::move(inputExpectation), std::move(output))); auto op = Dispatcher::singleton().findSchema("_test::my_op", ""); ASSERT_TRUE(op.has_value()); // assert schema is registered @@ -257,7 +266,6 @@ struct ArgTypeTestKernel final : OperatorKernel { outputExpectation(actualOutput[0]); } -private: InputType input_; std::function inputExpectation_; OutputType output_; @@ -270,44 +278,26 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) { // primitive types ArgTypeTestKernel::test( 1.5, [] (const double& v) {EXPECT_EQ(1.5, v);}, - 2.5, [] (const IValue& v) {EXPECT_EQ(2.5, v.toDouble());}); - ArgTypeTestKernel::test( - 1.5, [] (const double& v) {EXPECT_EQ(1.5, v);}, 2.5, [] (const IValue& v) {EXPECT_EQ(2.5, v.toDouble());}, "(float a) -> float"); ArgTypeTestKernel::test( 1, [] (const int64_t& v) {EXPECT_EQ(1, v);}, - 2, [] (const IValue& v) {EXPECT_EQ(2, v.toInt());}); - ArgTypeTestKernel::test( - 1, [] (const int64_t& v) {EXPECT_EQ(1, v);}, 2, [] (const IValue& v) {EXPECT_EQ(2, v.toInt());}, "(int a) -> int"); ArgTypeTestKernel::test( true, [] (const bool& v) {EXPECT_EQ(true, v);}, - false, [] (const IValue& v) {EXPECT_EQ(false, v.toBool());}); - ArgTypeTestKernel::test( - true, [] (const bool& v) {EXPECT_EQ(true, v);}, false, [] (const IValue& v) {EXPECT_EQ(false, v.toBool());}, "(bool a) -> bool"); ArgTypeTestKernel::test( false, [] (const bool& v) {EXPECT_EQ(false, v);}, - true, [] (const IValue& v) {EXPECT_EQ(true, v.toBool());}); - ArgTypeTestKernel::test( - false, [] (const bool& v) {EXPECT_EQ(false, v);}, true, [] (const IValue& v) {EXPECT_EQ(true, v.toBool());}, "(bool a) -> bool"); ArgTypeTestKernel::test( "string1", [] (const std::string& v) {EXPECT_EQ("string1", v);}, - "string2", [] (const IValue& v) {EXPECT_EQ("string2", v.toString()->string());}); - ArgTypeTestKernel::test( - "string1", [] (const std::string& v) {EXPECT_EQ("string1", v);}, "string2", [] (const IValue& v) {EXPECT_EQ("string2", v.toString()->string());}, "(str a) -> str"); ArgTypeTestKernel::test( dummyTensor(TensorType1()), [] (const Tensor& v) {EXPECT_EQ(TensorType1(), v.type_id());}, - dummyTensor(TensorType2()), [] (const IValue& v) {EXPECT_EQ(TensorType2(), v.toTensor().type_id());}); - ArgTypeTestKernel::test( - dummyTensor(TensorType1()), [] (const Tensor& v) {EXPECT_EQ(TensorType1(), v.type_id());}, dummyTensor(TensorType2()), [] (const IValue& v) {EXPECT_EQ(TensorType2(), v.toTensor().type_id());}, "(Tensor a) -> Tensor"); @@ -315,44 +305,26 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) { // optional types (with has_value() == true) ArgTypeTestKernel>::test( c10::optional(1.5), [] (const c10::optional& v) {EXPECT_EQ(1.5, v.value());}, - c10::optional(2.5), [] (const IValue& v) {EXPECT_EQ(2.5, v.toDouble());}); - ArgTypeTestKernel>::test( - c10::optional(1.5), [] (const c10::optional& v) {EXPECT_EQ(1.5, v.value());}, c10::optional(2.5), [] (const IValue& v) {EXPECT_EQ(2.5, v.toDouble());}, "(float? a) -> float?"); ArgTypeTestKernel>::test( c10::optional(1), [] (const c10::optional& v) {EXPECT_EQ(1, v.value());}, - c10::optional(2), [] (const IValue& v) {EXPECT_EQ(2, v.toInt());}); - ArgTypeTestKernel>::test( - c10::optional(1), [] (const c10::optional& v) {EXPECT_EQ(1, v.value());}, c10::optional(2), [] (const IValue& v) {EXPECT_EQ(2, v.toInt());}, "(int? a) -> int?"); ArgTypeTestKernel>::test( c10::optional(true), [] (const c10::optional& v) {EXPECT_EQ(true, v.value());}, - c10::optional(false), [] (const IValue& v) {EXPECT_EQ(false, v.toBool());}); - ArgTypeTestKernel>::test( - c10::optional(true), [] (const c10::optional& v) {EXPECT_EQ(true, v.value());}, c10::optional(false), [] (const IValue& v) {EXPECT_EQ(false, v.toBool());}, "(bool? a) -> bool?"); ArgTypeTestKernel>::test( c10::optional(false), [] (const c10::optional& v) {EXPECT_EQ(false, v.value());}, - c10::optional(true), [] (const IValue& v) {EXPECT_EQ(true, v.toBool());}); - ArgTypeTestKernel>::test( - c10::optional(false), [] (const c10::optional& v) {EXPECT_EQ(false, v.value());}, c10::optional(true), [] (const IValue& v) {EXPECT_EQ(true, v.toBool());}, "(bool? a) -> bool?"); ArgTypeTestKernel>::test( c10::optional("string1"), [] (const c10::optional& v) {EXPECT_EQ("string1", v.value());}, - c10::optional("string2"), [] (const IValue& v) {EXPECT_EQ("string2", v.toString()->string());}); - ArgTypeTestKernel>::test( - c10::optional("string1"), [] (const c10::optional& v) {EXPECT_EQ("string1", v.value());}, c10::optional("string2"), [] (const IValue& v) {EXPECT_EQ("string2", v.toString()->string());}, "(str? a) -> str?"); ArgTypeTestKernel>::test( c10::optional(dummyTensor(TensorType1())), [] (const c10::optional& v) {EXPECT_EQ(TensorType1(), v.value().type_id());}, - c10::optional(dummyTensor(TensorType2())), [] (const IValue& v) {EXPECT_EQ(TensorType2(), v.toTensor().type_id());}); - ArgTypeTestKernel>::test( - c10::optional(dummyTensor(TensorType1())), [] (const c10::optional& v) {EXPECT_EQ(TensorType1(), v.value().type_id());}, c10::optional(dummyTensor(TensorType2())), [] (const IValue& v) {EXPECT_EQ(TensorType2(), v.toTensor().type_id());}, "(Tensor? a) -> Tensor?"); @@ -360,44 +332,26 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) { // optional types (with has_value() == false) ArgTypeTestKernel>::test( c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, - c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}); - ArgTypeTestKernel>::test( - c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, "(float? a) -> float?"); ArgTypeTestKernel>::test( c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, - c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}); - ArgTypeTestKernel>::test( - c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, "(int? a) -> int?"); ArgTypeTestKernel>::test( c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, - c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}); - ArgTypeTestKernel>::test( - c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, "(bool? a) -> bool?"); ArgTypeTestKernel>::test( c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, - c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}); - ArgTypeTestKernel>::test( - c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, "(bool? a) -> bool?"); ArgTypeTestKernel>::test( c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, - c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}); - ArgTypeTestKernel>::test( - c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, "(str? a) -> str?"); ArgTypeTestKernel>::test( c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, - c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}); - ArgTypeTestKernel>::test( - c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, "(Tensor? a) -> Tensor?"); @@ -405,40 +359,25 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) { // list types (with empty list) ArgTypeTestKernel, std::vector>::test( c10::ArrayRef(), [] (c10::ArrayRef v) {EXPECT_EQ(0, v.size());}, - std::vector(), [] (const IValue& v) {EXPECT_EQ(0, v.toDoubleListRef().size());}); - ArgTypeTestKernel, std::vector>::test( - c10::ArrayRef(), [] (c10::ArrayRef v) {EXPECT_EQ(0, v.size());}, std::vector(), [] (const IValue& v) {EXPECT_EQ(0, v.toDoubleListRef().size());}, "(float[] a) -> float[]"); ArgTypeTestKernel, std::vector>::test( c10::ArrayRef(), [] (c10::ArrayRef v) {EXPECT_EQ(0, v.size());}, - std::vector(), [] (const IValue& v) {EXPECT_EQ(0, v.toIntListRef().size());}); - ArgTypeTestKernel, std::vector>::test( - c10::ArrayRef(), [] (c10::ArrayRef v) {EXPECT_EQ(0, v.size());}, std::vector(), [] (const IValue& v) {EXPECT_EQ(0, v.toIntListRef().size());}, "(int[] a) -> int[]"); // TODO Converting std::vector to ArrayRef doesn't work, so we // need to find an alternative // ArgTypeTestKernel, std::vector>::test( // c10::ArrayRef(), [] (c10::ArrayRef v) {EXPECT_EQ(0, v.size());}, - // std::vector(), [] (const IValue& v) {EXPECT_EQ(0, v.toBoolListRef().size());}); - // ArgTypeTestKernel, std::vector>::test( - // c10::ArrayRef(), [] (c10::ArrayRef v) {EXPECT_EQ(0, v.size());}, // std::vector(), [] (const IValue& v) {EXPECT_EQ(0, v.toBoolListRef().size());}, // "(bool[] a) -> bool[]"); // ArgTypeTestKernel, std::vector>::test( // c10::ArrayRef(), [] (c10::ArrayRef v) {EXPECT_EQ(0, v.size());}, - // std::vector(), [] (const IValue& v) {EXPECT_EQ(0, v.toBoolListRef().size());}); - // ArgTypeTestKernel, std::vector>::test( - // c10::ArrayRef(), [] (c10::ArrayRef v) {EXPECT_EQ(0, v.size());}, // std::vector(), [] (const IValue& v) {EXPECT_EQ(0, v.toBoolListRef().size());}, // "(bool[] a) -> bool[]"); // TODO We currently don't support str[] (i.e. string list) as type. Do we want to? // ArgTypeTestKernel, std::vector>::test( // c10::ArrayRef(), [] (c10::ArrayRef v) {EXPECT_EQ(0, v.size());}, - // std::vector(), [] (const IValue& v) {EXPECT_EQ(0, v.toStringListRef().size());}); - // ArgTypeTestKernel, std::vector>::test( - // c10::ArrayRef(), [] (c10::ArrayRef v) {EXPECT_EQ(0, v.size());}, // std::vector(), [] (const IValue& v) {EXPECT_EQ(0, v.toStringListRef().size());}, // "(str[] a) -> str[]"); @@ -446,16 +385,10 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) { // list types (with non-empty list) ArgTypeTestKernel, std::vector>::test( c10::ArrayRef({1.5, 2.5}), [] (c10::ArrayRef v) {EXPECT_EQ(c10::ArrayRef({1.5, 2.5}), v);}, - std::vector({3.5, 4.5}), [] (const IValue& v) {EXPECT_EQ(std::vector({3.5, 4.5}), v.toDoubleListRef());}); - ArgTypeTestKernel, std::vector>::test( - c10::ArrayRef({1.5, 2.5}), [] (c10::ArrayRef v) {EXPECT_EQ(c10::ArrayRef({1.5, 2.5}), v);}, std::vector({3.5, 4.5}), [] (const IValue& v) {EXPECT_EQ(std::vector({3.5, 4.5}), v.toDoubleListRef());}, "(float[] a) -> float[]"); ArgTypeTestKernel, std::vector>::test( c10::ArrayRef({1, 2}), [] (c10::ArrayRef v) {EXPECT_EQ(c10::ArrayRef({1, 2}), v);}, - std::vector({3, 4}), [] (const IValue& v) {EXPECT_EQ(std::vector({3, 4}), v.toIntListRef());}); - ArgTypeTestKernel, std::vector>::test( - c10::ArrayRef({1, 2}), [] (c10::ArrayRef v) {EXPECT_EQ(c10::ArrayRef({1, 2}), v);}, std::vector({3, 4}), [] (const IValue& v) {EXPECT_EQ(std::vector({3, 4}), v.toIntListRef());}, "(int[] a) -> int[]"); // TODO When fixing bool[] and str[] (see above), also add them here @@ -469,44 +402,24 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) { EXPECT_EQ(2, v.toTensorListRef().size()); EXPECT_EQ(TensorType2(), v.toTensorListRef()[0].type_id()); EXPECT_EQ(TensorType1(), v.toTensorListRef()[1].type_id()); - }); - ArgTypeTestKernel, std::vector>::test( - c10::ArrayRef({dummyTensor(TensorType1()), dummyTensor(TensorType2())}), [] (c10::ArrayRef v) { - EXPECT_EQ(2, v.size()); - EXPECT_EQ(TensorType1(), v[0].type_id()); - EXPECT_EQ(TensorType2(), v[1].type_id()); - }, - std::vector({dummyTensor(TensorType2()), dummyTensor(TensorType1())}), [] (const IValue& v) { - EXPECT_EQ(2, v.toTensorListRef().size()); - EXPECT_EQ(TensorType2(), v.toTensorListRef()[0].type_id()); - EXPECT_EQ(TensorType1(), v.toTensorListRef()[1].type_id()); }, "(Tensor[] a) -> Tensor[]"); // Test optional of list (with nullopt) ArgTypeTestKernel>, c10::optional>>::test( c10::optional>(c10::nullopt), [] (c10::optional> v) {EXPECT_FALSE(v.has_value());}, - c10::optional>(c10::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}); - ArgTypeTestKernel>, c10::optional>>::test( - c10::optional>(c10::nullopt), [] (c10::optional> v) {EXPECT_FALSE(v.has_value());}, c10::optional>(c10::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, "(int[]? a) -> int[]?"); // Test optional of list (with empty list) ArgTypeTestKernel>, c10::optional>>::test( c10::optional>(c10::ArrayRef{}), [] (c10::optional> v) {EXPECT_EQ(0, v.value().size());}, - c10::optional>(std::vector{}), [] (const IValue& v) {EXPECT_EQ(0, v.toIntListRef().size());}); - ArgTypeTestKernel>, c10::optional>>::test( - c10::optional>(c10::ArrayRef{}), [] (c10::optional> v) {EXPECT_EQ(0, v.value().size());}, c10::optional>(std::vector{}), [] (const IValue& v) {EXPECT_EQ(0, v.toIntListRef().size());}, "(int[]? a) -> int[]?"); // Test optional of list (with values) ArgTypeTestKernel>, c10::optional>>::test( c10::optional>({1, 2}), [] (c10::optional> v) {EXPECT_EQ(c10::ArrayRef({1, 2}), v.value());}, - c10::optional>({3, 4}), [] (const IValue& v) {EXPECT_EQ(std::vector({3, 4}), v.toIntListRef());}); - ArgTypeTestKernel>, c10::optional>>::test( - c10::optional>({1, 2}), [] (c10::optional> v) {EXPECT_EQ(c10::ArrayRef({1, 2}), v.value());}, c10::optional>({3, 4}), [] (const IValue& v) {EXPECT_EQ(std::vector({3, 4}), v.toIntListRef());}, "(int[]? a) -> int[]?"); -- 2.7.4