From: Sebastian Messmer Date: Fri, 19 Apr 2019 00:16:58 +0000 (-0700) Subject: Add tests for argument types (#19290) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~151 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=ce969c0bc42f53cbc36da05ec66b24122105d628;p=platform%2Fupstream%2Fpytorch.git Add tests for argument types (#19290) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19290 Add test cases for the supported argument types And TODOs for some unsupported ones that we might want to support. Reviewed By: dzhulgakov Differential Revision: D14931920 fbshipit-source-id: c47bbb295a54ac9dc62569bf5c273368c834392c --- diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp index 4f46b30..615555b 100644 --- a/aten/src/ATen/core/op_registration/op_registration_test.cpp +++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp @@ -15,6 +15,7 @@ using c10::OperatorKernel; using c10::kernel; using c10::dispatchKey; using c10::Dispatcher; +using c10::IValue; using at::Tensor; namespace { @@ -219,4 +220,271 @@ TEST(OperatorRegistrationTest, givenOpWithMultipleKernels_whenKernelsHaveSameDis }, "Tried to register multiple kernels with same dispatch key"); } + +/** + * This is used to check that a given type works correctly when passed as input + * to or as output from a kernel. + * + * Call ArgTypeTestKernel::test(input, inputExpectation, output, outputExpectation, schema) + * to test that a kernel with `Input` as input type and `Output` as output types, + * when called with `input` fulfills `inputExpectation` inside the kernel, then + * returns `output` and the returned value fulfills `outputExpectation`. + * + * `inputExpectation` and `outputExpectation` should be lambdas that run + * googletest expect macros (or use other ways to assert the expectation is met). + * + * Optionally, you can specify the argument list part of a function schema + * (e.g. "(Tensor a) -> Tensor") as an additional argument to use when + * registering the kernel. In this case, the operator registration logic will + * check that the kernel function signature matches the one you specified. + */ +template +struct ArgTypeTestKernel final : OperatorKernel { + explicit ArgTypeTestKernel(InputType input, std::function inputExpectation, OutputType output) + : input_(std::move(input)), inputExpectation_(std::move(inputExpectation)), output_(std::move(output)) {} + + OutputType operator()(InputType input) const { + inputExpectation_(std::move(input)); + return output_; + } + + static void test(InputType input, std::function inputExpectation, OutputType output, std::function outputExpectation, const std::string& schema = "") { + auto registry = c10::RegisterOperators().op("_test::my_op" + schema, kernel(input, std::move(inputExpectation), std::move(output))); + auto op = Dispatcher::singleton().findSchema("_test::my_op", ""); + ASSERT_TRUE(op.has_value()); // assert schema is registered + auto actualOutput = callOp(*op, std::move(input)); + EXPECT_EQ(1, actualOutput.size()); + outputExpectation(actualOutput[0]); + } + +private: + InputType input_; + std::function inputExpectation_; + OutputType output_; + std::string schema_; +}; + +TEST(OperatorRegistrationTest, testAvailableArgTypes) { + // primitive types + ArgTypeTestKernel::test( + 1.5, [] (const double& v) {EXPECT_EQ(1.5, v);}, + 2.5, [] (const IValue& v) {EXPECT_EQ(2.5, v.toDouble());}); + ArgTypeTestKernel::test( + 1.5, [] (const double& v) {EXPECT_EQ(1.5, v);}, + 2.5, [] (const IValue& v) {EXPECT_EQ(2.5, v.toDouble());}, + "(float a) -> float"); + ArgTypeTestKernel::test( + 1, [] (const int64_t& v) {EXPECT_EQ(1, v);}, + 2, [] (const IValue& v) {EXPECT_EQ(2, v.toInt());}); + ArgTypeTestKernel::test( + 1, [] (const int64_t& v) {EXPECT_EQ(1, v);}, + 2, [] (const IValue& v) {EXPECT_EQ(2, v.toInt());}, + "(int a) -> int"); + ArgTypeTestKernel::test( + true, [] (const bool& v) {EXPECT_EQ(true, v);}, + false, [] (const IValue& v) {EXPECT_EQ(false, v.toBool());}); + ArgTypeTestKernel::test( + true, [] (const bool& v) {EXPECT_EQ(true, v);}, + false, [] (const IValue& v) {EXPECT_EQ(false, v.toBool());}, + "(bool a) -> bool"); + ArgTypeTestKernel::test( + false, [] (const bool& v) {EXPECT_EQ(false, v);}, + true, [] (const IValue& v) {EXPECT_EQ(true, v.toBool());}); + ArgTypeTestKernel::test( + false, [] (const bool& v) {EXPECT_EQ(false, v);}, + true, [] (const IValue& v) {EXPECT_EQ(true, v.toBool());}, + "(bool a) -> bool"); + ArgTypeTestKernel::test( + "string1", [] (const std::string& v) {EXPECT_EQ("string1", v);}, + "string2", [] (const IValue& v) {EXPECT_EQ("string2", v.toString()->string());}); + ArgTypeTestKernel::test( + "string1", [] (const std::string& v) {EXPECT_EQ("string1", v);}, + "string2", [] (const IValue& v) {EXPECT_EQ("string2", v.toString()->string());}, + "(str a) -> str"); + ArgTypeTestKernel::test( + dummyTensor(TensorType1()), [] (const Tensor& v) {EXPECT_EQ(TensorType1(), v.type_id());}, + dummyTensor(TensorType2()), [] (const IValue& v) {EXPECT_EQ(TensorType2(), v.toTensor().type_id());}); + ArgTypeTestKernel::test( + dummyTensor(TensorType1()), [] (const Tensor& v) {EXPECT_EQ(TensorType1(), v.type_id());}, + dummyTensor(TensorType2()), [] (const IValue& v) {EXPECT_EQ(TensorType2(), v.toTensor().type_id());}, + "(Tensor a) -> Tensor"); + + + // optional types (with has_value() == true) + ArgTypeTestKernel>::test( + c10::optional(1.5), [] (const c10::optional& v) {EXPECT_EQ(1.5, v.value());}, + c10::optional(2.5), [] (const IValue& v) {EXPECT_EQ(2.5, v.toDouble());}); + ArgTypeTestKernel>::test( + c10::optional(1.5), [] (const c10::optional& v) {EXPECT_EQ(1.5, v.value());}, + c10::optional(2.5), [] (const IValue& v) {EXPECT_EQ(2.5, v.toDouble());}, + "(float? a) -> float?"); + ArgTypeTestKernel>::test( + c10::optional(1), [] (const c10::optional& v) {EXPECT_EQ(1, v.value());}, + c10::optional(2), [] (const IValue& v) {EXPECT_EQ(2, v.toInt());}); + ArgTypeTestKernel>::test( + c10::optional(1), [] (const c10::optional& v) {EXPECT_EQ(1, v.value());}, + c10::optional(2), [] (const IValue& v) {EXPECT_EQ(2, v.toInt());}, + "(int? a) -> int?"); + ArgTypeTestKernel>::test( + c10::optional(true), [] (const c10::optional& v) {EXPECT_EQ(true, v.value());}, + c10::optional(false), [] (const IValue& v) {EXPECT_EQ(false, v.toBool());}); + ArgTypeTestKernel>::test( + c10::optional(true), [] (const c10::optional& v) {EXPECT_EQ(true, v.value());}, + c10::optional(false), [] (const IValue& v) {EXPECT_EQ(false, v.toBool());}, + "(bool? a) -> bool?"); + ArgTypeTestKernel>::test( + c10::optional(false), [] (const c10::optional& v) {EXPECT_EQ(false, v.value());}, + c10::optional(true), [] (const IValue& v) {EXPECT_EQ(true, v.toBool());}); + ArgTypeTestKernel>::test( + c10::optional(false), [] (const c10::optional& v) {EXPECT_EQ(false, v.value());}, + c10::optional(true), [] (const IValue& v) {EXPECT_EQ(true, v.toBool());}, + "(bool? a) -> bool?"); + ArgTypeTestKernel>::test( + c10::optional("string1"), [] (const c10::optional& v) {EXPECT_EQ("string1", v.value());}, + c10::optional("string2"), [] (const IValue& v) {EXPECT_EQ("string2", v.toString()->string());}); + ArgTypeTestKernel>::test( + c10::optional("string1"), [] (const c10::optional& v) {EXPECT_EQ("string1", v.value());}, + c10::optional("string2"), [] (const IValue& v) {EXPECT_EQ("string2", v.toString()->string());}, + "(str? a) -> str?"); + ArgTypeTestKernel>::test( + c10::optional(dummyTensor(TensorType1())), [] (const c10::optional& v) {EXPECT_EQ(TensorType1(), v.value().type_id());}, + c10::optional(dummyTensor(TensorType2())), [] (const IValue& v) {EXPECT_EQ(TensorType2(), v.toTensor().type_id());}); + ArgTypeTestKernel>::test( + c10::optional(dummyTensor(TensorType1())), [] (const c10::optional& v) {EXPECT_EQ(TensorType1(), v.value().type_id());}, + c10::optional(dummyTensor(TensorType2())), [] (const IValue& v) {EXPECT_EQ(TensorType2(), v.toTensor().type_id());}, + "(Tensor? a) -> Tensor?"); + + + // optional types (with has_value() == false) + ArgTypeTestKernel>::test( + c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, + c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}); + ArgTypeTestKernel>::test( + c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, + c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, + "(float? a) -> float?"); + ArgTypeTestKernel>::test( + c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, + c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}); + ArgTypeTestKernel>::test( + c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, + c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, + "(int? a) -> int?"); + ArgTypeTestKernel>::test( + c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, + c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}); + ArgTypeTestKernel>::test( + c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, + c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, + "(bool? a) -> bool?"); + ArgTypeTestKernel>::test( + c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, + c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}); + ArgTypeTestKernel>::test( + c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, + c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, + "(bool? a) -> bool?"); + ArgTypeTestKernel>::test( + c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, + c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}); + ArgTypeTestKernel>::test( + c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, + c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, + "(str? a) -> str?"); + ArgTypeTestKernel>::test( + c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, + c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}); + ArgTypeTestKernel>::test( + c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, + c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, + "(Tensor? a) -> Tensor?"); + + + // list types (with empty list) + ArgTypeTestKernel, std::vector>::test( + c10::ArrayRef(), [] (c10::ArrayRef v) {EXPECT_EQ(0, v.size());}, + std::vector(), [] (const IValue& v) {EXPECT_EQ(0, v.toDoubleListRef().size());}); + ArgTypeTestKernel, std::vector>::test( + c10::ArrayRef(), [] (c10::ArrayRef v) {EXPECT_EQ(0, v.size());}, + std::vector(), [] (const IValue& v) {EXPECT_EQ(0, v.toDoubleListRef().size());}, + "(float[] a) -> float[]"); + ArgTypeTestKernel, std::vector>::test( + c10::ArrayRef(), [] (c10::ArrayRef v) {EXPECT_EQ(0, v.size());}, + std::vector(), [] (const IValue& v) {EXPECT_EQ(0, v.toIntListRef().size());}); + ArgTypeTestKernel, std::vector>::test( + c10::ArrayRef(), [] (c10::ArrayRef v) {EXPECT_EQ(0, v.size());}, + std::vector(), [] (const IValue& v) {EXPECT_EQ(0, v.toIntListRef().size());}, + "(int[] a) -> int[]"); + // TODO Converting std::vector to ArrayRef doesn't work, so we + // need to find an alternative + // ArgTypeTestKernel, std::vector>::test( + // c10::ArrayRef(), [] (c10::ArrayRef v) {EXPECT_EQ(0, v.size());}, + // std::vector(), [] (const IValue& v) {EXPECT_EQ(0, v.toBoolListRef().size());}); + // ArgTypeTestKernel, std::vector>::test( + // c10::ArrayRef(), [] (c10::ArrayRef v) {EXPECT_EQ(0, v.size());}, + // std::vector(), [] (const IValue& v) {EXPECT_EQ(0, v.toBoolListRef().size());}, + // "(bool[] a) -> bool[]"); + // ArgTypeTestKernel, std::vector>::test( + // c10::ArrayRef(), [] (c10::ArrayRef v) {EXPECT_EQ(0, v.size());}, + // std::vector(), [] (const IValue& v) {EXPECT_EQ(0, v.toBoolListRef().size());}); + // ArgTypeTestKernel, std::vector>::test( + // c10::ArrayRef(), [] (c10::ArrayRef v) {EXPECT_EQ(0, v.size());}, + // std::vector(), [] (const IValue& v) {EXPECT_EQ(0, v.toBoolListRef().size());}, + // "(bool[] a) -> bool[]"); + // TODO We currently don't support str[] (i.e. string list) as type. Do we want to? + // ArgTypeTestKernel, std::vector>::test( + // c10::ArrayRef(), [] (c10::ArrayRef v) {EXPECT_EQ(0, v.size());}, + // std::vector(), [] (const IValue& v) {EXPECT_EQ(0, v.toStringListRef().size());}); + // ArgTypeTestKernel, std::vector>::test( + // c10::ArrayRef(), [] (c10::ArrayRef v) {EXPECT_EQ(0, v.size());}, + // std::vector(), [] (const IValue& v) {EXPECT_EQ(0, v.toStringListRef().size());}, + // "(str[] a) -> str[]"); + + + // list types (with non-empty list) + ArgTypeTestKernel, std::vector>::test( + c10::ArrayRef({1.5, 2.5}), [] (c10::ArrayRef v) {EXPECT_EQ(c10::ArrayRef({1.5, 2.5}), v);}, + std::vector({3.5, 4.5}), [] (const IValue& v) {EXPECT_EQ(std::vector({3.5, 4.5}), v.toDoubleListRef());}); + ArgTypeTestKernel, std::vector>::test( + c10::ArrayRef({1.5, 2.5}), [] (c10::ArrayRef v) {EXPECT_EQ(c10::ArrayRef({1.5, 2.5}), v);}, + std::vector({3.5, 4.5}), [] (const IValue& v) {EXPECT_EQ(std::vector({3.5, 4.5}), v.toDoubleListRef());}, + "(float[] a) -> float[]"); + ArgTypeTestKernel, std::vector>::test( + c10::ArrayRef({1, 2}), [] (c10::ArrayRef v) {EXPECT_EQ(c10::ArrayRef({1, 2}), v);}, + std::vector({3, 4}), [] (const IValue& v) {EXPECT_EQ(std::vector({3, 4}), v.toIntListRef());}); + ArgTypeTestKernel, std::vector>::test( + c10::ArrayRef({1, 2}), [] (c10::ArrayRef v) {EXPECT_EQ(c10::ArrayRef({1, 2}), v);}, + std::vector({3, 4}), [] (const IValue& v) {EXPECT_EQ(std::vector({3, 4}), v.toIntListRef());}, + "(int[] a) -> int[]"); + // TODO When fixing bool[] and str[] (see above), also add them here + ArgTypeTestKernel, std::vector>::test( + c10::ArrayRef({dummyTensor(TensorType1()), dummyTensor(TensorType2())}), [] (c10::ArrayRef v) { + EXPECT_EQ(2, v.size()); + EXPECT_EQ(TensorType1(), v[0].type_id()); + EXPECT_EQ(TensorType2(), v[1].type_id()); + }, + std::vector({dummyTensor(TensorType2()), dummyTensor(TensorType1())}), [] (const IValue& v) { + EXPECT_EQ(2, v.toTensorListRef().size()); + EXPECT_EQ(TensorType2(), v.toTensorListRef()[0].type_id()); + EXPECT_EQ(TensorType1(), v.toTensorListRef()[1].type_id()); + }); + ArgTypeTestKernel, std::vector>::test( + c10::ArrayRef({dummyTensor(TensorType1()), dummyTensor(TensorType2())}), [] (c10::ArrayRef v) { + EXPECT_EQ(2, v.size()); + EXPECT_EQ(TensorType1(), v[0].type_id()); + EXPECT_EQ(TensorType2(), v[1].type_id()); + }, + std::vector({dummyTensor(TensorType2()), dummyTensor(TensorType1())}), [] (const IValue& v) { + EXPECT_EQ(2, v.toTensorListRef().size()); + EXPECT_EQ(TensorType2(), v.toTensorListRef()[0].type_id()); + EXPECT_EQ(TensorType1(), v.toTensorListRef()[1].type_id()); + }, + "(Tensor[] a) -> Tensor[]"); + + + // TODO Do we want to support list of optional / optional of list ? + + // TODO Add tests for dict types +} + } diff --git a/aten/src/ATen/core/op_registration/test_helpers.h b/aten/src/ATen/core/op_registration/test_helpers.h index ecb7dc8..a678741 100644 --- a/aten/src/ATen/core/op_registration/test_helpers.h +++ b/aten/src/ATen/core/op_registration/test_helpers.h @@ -8,9 +8,38 @@ #include #include +namespace detail { +// InputToIValue takes a value and converts it to an IValue to be put on a stack. +template +struct InputToIValue final { + template + static c10::IValue call(T_&& v) { + return c10::IValue(std::forward(v)); + } +}; +template +struct InputToIValue> final { + template + static c10::IValue call(T_&& v) { + if (v.has_value()) { + return c10::IValue(std::move(*v)); + } else { + return c10::IValue(); + } + } +}; +template +struct InputToIValue> final { + template + static c10::IValue call(T_&& v) { + return c10::IValue(v.vec()); + } +}; +} + template inline std::vector makeStack(Inputs&&... inputs) { - return {std::forward(inputs)...}; + return {detail::InputToIValue>::call(std::forward(inputs))...}; } inline at::Tensor dummyTensor(c10::TensorTypeId dispatch_key) { diff --git a/c10/util/ArrayRef.h b/c10/util/ArrayRef.h index 6acb532..989e121 100644 --- a/c10/util/ArrayRef.h +++ b/c10/util/ArrayRef.h @@ -80,9 +80,13 @@ class ArrayRef final { : Data(Vec.data()), Length(Vec.size()) {} /// Construct an ArrayRef from a std::vector. + // The enable_if stuff here makes sure that this isn't used for std::vector, + // because ArrayRef can't work on a std::vector bitfield. template /* implicit */ ArrayRef(const std::vector& Vec) - : Data(Vec.data()), Length(Vec.size()) {} + : Data(Vec.data()), Length(Vec.size()) { + static_assert(!std::is_same::value, "ArrayRef cannot be constructed from a std::vector bitfield."); + } /// Construct an ArrayRef from a std::array template