From a456e1e19669fb3c17de75456570aab911f9476f Mon Sep 17 00:00:00 2001 From: Sebastian Messmer Date: Thu, 18 Apr 2019 02:00:49 -0700 Subject: [PATCH] Add either type (#19285) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19285 The either type is a tagged union with two members. This is going to be used in a diff stacked on top to allow a function to return one of two types. Also, generally, either is a great pattern for returning value_or_error from a function without using exceptions and we could use this class for that later. Reviewed By: dzhulgakov Differential Revision: D14931923 fbshipit-source-id: 7d1dd77b3e5b655f331444394dcdeab24772ab3a --- c10/test/CMakeLists.txt | 2 +- c10/test/util/either_test.cpp | 1246 +++++++++++++++++++++++++++++++++++++++++ c10/util/either.h | 213 +++++++ cmake/Dependencies.cmake | 4 +- 4 files changed, 1462 insertions(+), 3 deletions(-) create mode 100644 c10/test/util/either_test.cpp create mode 100644 c10/util/either.h diff --git a/c10/test/CMakeLists.txt b/c10/test/CMakeLists.txt index e711fd6..ec21838 100644 --- a/c10/test/CMakeLists.txt +++ b/c10/test/CMakeLists.txt @@ -6,7 +6,7 @@ if (BUILD_TEST) get_filename_component(test_file_name ${test_src} NAME_WE) set(test_name "c10_${test_file_name}") add_executable(${test_name} "${test_src}") - target_link_libraries(${test_name} c10 gtest_main) + target_link_libraries(${test_name} c10 gmock gtest gtest_main) add_test(NAME ${test_name} COMMAND $) if (INSTALL_TEST) install(TARGETS ${test_name} DESTINATION test) diff --git a/c10/test/util/either_test.cpp b/c10/test/util/either_test.cpp new file mode 100644 index 0000000..eb751ea --- /dev/null +++ b/c10/test/util/either_test.cpp @@ -0,0 +1,1246 @@ +// Originally taken from https://raw.githubusercontent.com/cryfs/cryfs/14ad22570ddacef22d5ff139cdff68a54fc8234d/test/cpp-utils/either_test.cpp + +#include +#include +#include +#include +#include +#include + +using std::string; +using std::vector; +using std::pair; +using std::tuple; +using std::ostringstream; +using c10::either; +using c10::make_left; +using c10::make_right; + +namespace { +class MovableOnly final { +public: + explicit MovableOnly(int value): _value(value) {} + MovableOnly(const MovableOnly&) = delete; + MovableOnly& operator=(const MovableOnly&) = delete; + + MovableOnly(MovableOnly&& rhs): _value(rhs._value) { + rhs._value = 0; + } + + MovableOnly& operator=(MovableOnly&& rhs) { + _value = rhs._value; + rhs._value = 0; + return *this; + } + + int value() const { + return _value; + } + +private: + int _value; +}; + +bool operator==(const MovableOnly& lhs, const MovableOnly& rhs) { + return lhs.value() == rhs.value(); +} + +template +void test_with_matrix(std::vector)>> setups, std::vector> expectations) { + for (const auto& setup: setups) { + for (const auto& expectation: expectations) { + setup(expectation); + } + } +} + +template +std::vector&)>> EXPECT_IS_LEFT(const Left& expected) { + return { + [&] (either& obj) { + EXPECT_TRUE(obj.is_left()); + }, [&] (either& obj) { + EXPECT_FALSE(obj.is_right()); + }, [&] (either& obj) { + EXPECT_EQ(expected, obj.left()); + }, [&] (either& obj) { + EXPECT_EQ(expected, std::move(obj).left()); + }, [&] (either& obj) { + EXPECT_ANY_THROW(obj.right()); + }, [&] (either& obj) { + EXPECT_ANY_THROW(std::move(obj).right()); + } + }; +} + +template +std::vector&)>> EXPECT_IS_RIGHT(const Right& expected) { + return { + [&] (either& obj) { + EXPECT_FALSE(obj.is_left()); + }, [&] (either& obj) { + EXPECT_TRUE(obj.is_right()); + }, [&] (either& obj) { + EXPECT_EQ(expected, obj.right()); + }, [&] (either& obj) { + EXPECT_EQ(expected, std::move(obj).right()); + }, [&] (either& obj) { + EXPECT_ANY_THROW(obj.left()); + }, [&] (either& obj) { + EXPECT_ANY_THROW(std::move(obj).left()); + } + }; +} + +template +std::vector> EXPECT_IS(const Value& v) { + return { + [&] (Value& obj) { + return obj == v; + } + }; +} + +template +struct StoreWith1ByteFlag { + T val; + char flag; +}; + +template +void TestSpaceUsage() { + EXPECT_EQ(std::max(sizeof(StoreWith1ByteFlag), sizeof(StoreWith1ByteFlag)), sizeof(either)); +} +} + +TEST(EitherTest, SpaceUsage) { + TestSpaceUsage(); + TestSpaceUsage(); + TestSpaceUsage(); + TestSpaceUsage(); + TestSpaceUsage>(); +} + +TEST(EitherTest, givenLeft) { + test_with_matrix({ + [] (std::function&)> test) { + either a(4); + test(a); + }, [] (std::function&)> test) { + either a = 4; + test(a); + }, + }, + EXPECT_IS_LEFT(4) + ); +} + +TEST(EitherTest, givenRight) { + test_with_matrix({ + [] (std::function&)> test) { + either a("4"); + test(a); + }, [] (std::function&)> test) { + either a = string("4"); + test(a); + } + }, + EXPECT_IS_RIGHT("4") + ); +} + +TEST(EitherTest, givenMakeLeft) { + test_with_matrix({ + [] (std::function&)> test) { + either a = make_left(4); + test(a); + }, [] (std::function&)> test) { + auto a = make_left(4); + test(a); + }, + }, + EXPECT_IS_LEFT(4) + ); +} + +TEST(EitherTest, givenMakeLeftWithSameType) { + test_with_matrix({ + [] (std::function&)> test) { + either a = make_left(4); + test(a); + }, [] (std::function&)> test) { + auto a = make_left(4); + test(a); + }, + }, + EXPECT_IS_LEFT(4) + ); +} + +TEST(EitherTest, givenMakeRight) { + test_with_matrix({ + [] (std::function&)> test) { + either a = make_right("4"); + test(a); + }, [] (std::function&)> test) { + auto a = make_right("4"); + test(a); + } + }, + EXPECT_IS_RIGHT("4") + ); +} + +TEST(EitherTest, givenMakeRightWithSameType) { + test_with_matrix({ + [] (std::function&)> test) { + either a = make_right("4"); + test(a); + }, [] (std::function&)> test) { + auto a = make_right("4"); + test(a); + } + }, + EXPECT_IS_RIGHT("4") + ); +} + +TEST(EitherTest, givenMovableOnlyMakeLeft) { + test_with_matrix({ + [] (std::function&)> test) { + either a = make_left(3); + test(a); + }, [] (std::function&)> test) { + auto a = make_left(3); + test(a); + }, + }, + EXPECT_IS_LEFT(MovableOnly(3)) + ); +} + +TEST(EitherTest, givenMovableOnlyMakeRight) { + test_with_matrix({ + [] (std::function&)> test) { + either a = make_right(3); + test(a); + }, [] (std::function&)> test) { + auto a = make_right(3); + test(a); + } + }, + EXPECT_IS_RIGHT(MovableOnly(3)) + ); +} + +TEST(EitherTest, givenMultiParamMakeLeft) { + test_with_matrix({ + [] (std::function, string>&)> test) { + either, string> a = make_left, string>(5, 6); + test(a); + }, [] (std::function, string>&)> test) { + auto a = make_left, string>(5, 6); + test(a); + }, + }, + EXPECT_IS_LEFT, string>(pair(5, 6)) + ); +} + +TEST(EitherTest, givenMultiParamMakeRight) { + test_with_matrix({ + [] (std::function>&)> test) { + either> a = make_right>(5, 6); + test(a); + }, [] (std::function>&)> test) { + auto a = make_right>(5, 6); + test(a); + } + }, + EXPECT_IS_RIGHT>(pair(5, 6)) + ); +} + +TEST(EitherTest, givenLeftCopyConstructedFromValue_thenNewIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + string a = "4"; + either b(a); + test(b); + } + }, + EXPECT_IS_LEFT("4") + ); +} + +TEST(EitherTest, givenLeftCopyConstructedFromValue_thenOldIsCorrect) { + test_with_matrix({ + [] (std::function test) { + string a = "4"; + either b(a); + test(a); + } + }, + EXPECT_IS("4") + ); +} + +TEST(EitherTest, givenRightCopyConstructedFromValue_thenNewIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + string a = "4"; + either b(a); + test(b); + } + }, + EXPECT_IS_RIGHT("4") + ); +} + +TEST(EitherTest, givenRightCopyConstructedFromValue_thenOldIsCorrect) { + test_with_matrix({ + [] (std::function test) { + string a = "4"; + either b(a); + test(a); + } + }, + EXPECT_IS("4") + ); +} + +TEST(EitherTest, givenLeftMoveConstructedFromValue_thenNewIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + MovableOnly a(3); + either b(std::move(a)); + test(b); + } + }, + EXPECT_IS_LEFT(MovableOnly(3)) + ); +} + +TEST(EitherTest, givenLeftMoveConstructedFromValue_thenOldIsCorrect) { + test_with_matrix({ + [] (std::function test) { + MovableOnly a(3); + either b(std::move(a)); + test(a); // NOLINT(bugprone-use-after-move) + } + }, + EXPECT_IS(MovableOnly(0)) // 0 is moved-from value + ); +} + +TEST(EitherTest, givenRightMoveConstructedFromValue_thenNewIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + MovableOnly a(3); + either b(std::move(a)); + test(b); + } + }, + EXPECT_IS_RIGHT(MovableOnly(3)) + ); +} + +TEST(EitherTest, givenRightMoveConstructedFromValue_thenOldIsCorrect) { + test_with_matrix({ + [] (std::function test) { + MovableOnly a(3); + either b(std::move(a)); + test(a); // NOLINT(bugprone-use-after-move) + } + }, + EXPECT_IS(MovableOnly(0)) // 0 is moved-from value + ); +} + +TEST(EitherTest, givenLeftCopyAssignedFromValue_thenNewIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + string a = "4"; + either b(2); + b = a; + test(b); + }, [] (std::function&)> test) { + string a = "4"; + either b("2"); + b = a; + test(b); + } + }, + EXPECT_IS_LEFT("4") + ); +} + +TEST(EitherTest, givenLeftCopyAssignedFromValue_thenOldIsCorrect) { + test_with_matrix({ + [] (std::function test) { + string a = "4"; + either b(2); + b = a; + test(a); + }, [] (std::function test) { + string a = "4"; + either b("2"); + b = a; + test(a); + } + }, + EXPECT_IS("4") + ); +} + +TEST(EitherTest, givenRightCopyAssignedFromValue_thenNewIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + string a = "4"; + either b(2); + b = a; + test(b); + }, [] (std::function&)> test) { + string a = "4"; + either b("2"); + b = a; + test(b); + } + }, + EXPECT_IS_RIGHT("4") + ); +} + +TEST(EitherTest, givenRightCopyAssignedFromValue_thenOldIsCorrect) { + test_with_matrix({ + [] (std::function test) { + string a = "4"; + either b(2); + b = a; + test(a); + }, [] (std::function test) { + string a = "4"; + either b("2"); + b = a; + test(a); + } + }, + EXPECT_IS("4") + ); +} + +TEST(EitherTest, givenLeftMoveAssignedFromValue_thenNewIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + MovableOnly a(3); + either b(2); + b = std::move(a); + test(b); + }, [] (std::function&)> test) { + MovableOnly a(3); + either b(MovableOnly(2)); + b = std::move(a); + test(b); + } + }, + EXPECT_IS_LEFT(MovableOnly(3)) + ); +} + +TEST(EitherTest, givenLeftMoveAssignedFromValue_thenOldIsCorrect) { + test_with_matrix({ + [] (std::function test) { + MovableOnly a(3); + either b("2"); + b = std::move(a); + test(a); // NOLINT(bugprone-use-after-move) + }, [] (std::function test) { + MovableOnly a(3); + either b(MovableOnly(0)); + b = std::move(a); + test(a); // NOLINT(bugprone-use-after-move) + } + }, + EXPECT_IS(MovableOnly(0)) + ); +} + +TEST(EitherTest, givenRightMoveAssignedFromValue_thenNewIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + MovableOnly a(3); + either b("2"); + b = std::move(a); + test(b); + }, [] (std::function&)> test) { + MovableOnly a(3); + either b(MovableOnly(2)); + b = std::move(a); + test(b); + } + }, + EXPECT_IS_RIGHT(MovableOnly(3)) + ); +} + +TEST(EitherTest, givenRightMoveAssignedFromValue_thenOldIsCorrect) { + test_with_matrix({ + [] (std::function test) { + MovableOnly a(3); + either b("2"); + b = std::move(a); + test(a); // NOLINT(bugprone-use-after-move) + }, [] (std::function test) { + MovableOnly a(3); + either b(MovableOnly(2)); + b = std::move(a); + test(a); // NOLINT(bugprone-use-after-move) + } + }, + EXPECT_IS(MovableOnly(0)) // 0 is moved-from value + ); +} + +TEST(EitherTest, givenLeftCopyConstructed_thenNewIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a("4"); + either b(a); + test(b); + } + }, + EXPECT_IS_LEFT("4") + ); +} + +TEST(EitherTest, givenLeftCopyConstructed_thenOldIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a("4"); + either b(a); + test(a); + } + }, + EXPECT_IS_LEFT("4") + ); +} + +TEST(EitherTest, givenLeftCopyConstructed_withSameType_thenNewIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a = make_left("4"); + either b(a); + test(b); + } + }, + EXPECT_IS_LEFT("4") + ); +} + +TEST(EitherTest, givenLeftCopyConstructed_withSameType_thenOldIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a = make_left("4"); + either b(a); + test(a); + } + }, + EXPECT_IS_LEFT("4") + ); +} + +TEST(EitherTest, givenRightCopyConstructed_thenNewIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a("4"); + either b(a); + test(b); + } + }, + EXPECT_IS_RIGHT("4") + ); +} + + +TEST(EitherTest, givenRightCopyConstructed_thenOldIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a("4"); + either b(a); + test(a); + } + }, + EXPECT_IS_RIGHT("4") + ); +} + +TEST(EitherTest, givenRightCopyConstructed_withSameType_thenNewIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a = make_right("4"); + either b(a); + test(b); + } + }, + EXPECT_IS_RIGHT("4") + ); +} + + +TEST(EitherTest, givenRightCopyConstructed_withSameType_thenOldIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a = make_right("4"); + either b(a); + test(a); + } + }, + EXPECT_IS_RIGHT("4") + ); +} + +TEST(EitherTest, givenLeftMoveConstructed_thenNewIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a(MovableOnly(3)); + either b(std::move(a)); + test(b); + } + }, + EXPECT_IS_LEFT(MovableOnly(3)) + ); +} + +TEST(EitherTest, givenLeftMoveConstructed_thenOldIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a(MovableOnly(3)); + either b(std::move(a)); + test(a); // NOLINT(bugprone-use-after-move) + } + }, + EXPECT_IS_LEFT(MovableOnly(0)) // 0 is moved-from value + ); +} + +TEST(EitherTest, givenLeftMoveConstructed_withSameType_thenNewIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a = make_left(MovableOnly(3)); + either b(std::move(a)); + test(b); + } + }, + EXPECT_IS_LEFT(MovableOnly(3)) + ); +} + +TEST(EitherTest, givenLeftMoveConstructed_withSameType_thenOldIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a = make_left(MovableOnly(3)); + either b(std::move(a)); + test(a); // NOLINT(bugprone-use-after-move) + } + }, + EXPECT_IS_LEFT(MovableOnly(0)) // 0 is moved-from value + ); +} + +TEST(EitherTest, givenRightMoveConstructed_thenNewIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a(MovableOnly(3)); + either b(std::move(a)); + test(b); + } + }, + EXPECT_IS_RIGHT(MovableOnly(3)) + ); +} + +TEST(EitherTest, givenRightMoveConstructed_thenOldIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a(MovableOnly(3)); + either b(std::move(a)); + test(a); // NOLINT(bugprone-use-after-move) + } + }, + EXPECT_IS_RIGHT(MovableOnly(0)) // 0 is moved-from value + ); +} + +TEST(EitherTest, givenRightMoveConstructed_withSameType_thenNewIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a = make_right(MovableOnly(3)); + either b(std::move(a)); + test(b); + } + }, + EXPECT_IS_RIGHT(MovableOnly(3)) + ); +} + +TEST(EitherTest, givenRightMoveConstructed_withSameType_thenOldIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a = make_right(MovableOnly(3)); + either b(std::move(a)); + test(a); // NOLINT(bugprone-use-after-move) + } + }, + EXPECT_IS_RIGHT(MovableOnly(0)) // 0 is moved-from value + ); +} + +TEST(EitherTest, givenLeftCopyAssigned_thenNewIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a("4"); + either b(2); + b = a; + test(b); + }, [] (std::function&)> test) { + either a("4"); + either b("2"); + b = a; + test(b); + } + }, + EXPECT_IS_LEFT("4") + ); +} + +TEST(EitherTest, givenLeftCopyAssigned_thenOldIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a("4"); + either b(2); + b = a; + test(a); + }, [] (std::function&)> test) { + either a("4"); + either b("2"); + b = a; + test(a); + } + }, + EXPECT_IS_LEFT("4") + ); +} + +TEST(EitherTest, givenLeftCopyAssigned_withSameType_thenNewIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a = make_left("4"); + either b = make_right("2"); + b = a; + test(b); + }, [] (std::function&)> test) { + either a = make_left("4"); + either b = make_left("2"); + b = a; + test(b); + } + }, + EXPECT_IS_LEFT("4") + ); +} + +TEST(EitherTest, givenLeftCopyAssigned_withSameType_thenOldIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a = make_left("4"); + either b = make_right("2"); + b = a; + test(a); + }, [] (std::function&)> test) { + either a = make_left("4"); + either b = make_left("2"); + b = a; + test(a); + } + }, + EXPECT_IS_LEFT("4") + ); +} + +TEST(EitherTest, givenRightCopyAssigned_thenNewIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a("4"); + either b(2); + b = a; + test(b); + }, [] (std::function&)> test) { + either a("4"); + either b("2"); + b = a; + test(b); + } + }, + EXPECT_IS_RIGHT("4") + ); +} + +TEST(EitherTest, givenRightCopyAssigned_thenOldIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a("4"); + either b(2); + b = a; + test(a); + }, [] (std::function&)> test) { + either a("4"); + either b("2"); + b = a; + test(a); + } + }, + EXPECT_IS_RIGHT("4") + ); +} + +TEST(EitherTest, givenRightCopyAssigned_withSameType_thenNewIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a = make_right("4"); + either b = make_left("2"); + b = a; + test(b); + }, [] (std::function&)> test) { + either a = make_right("4"); + either b = make_right("2"); + b = a; + test(b); + } + }, + EXPECT_IS_RIGHT("4") + ); +} + +TEST(EitherTest, givenRightCopyAssigned_withSameType_thenOldIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a = make_right("4"); + either b = make_left("2"); + b = a; + test(a); + }, [] (std::function&)> test) { + either a = make_right("4"); + either b = make_right("2"); + b = a; + test(a); + } + }, + EXPECT_IS_RIGHT("4") + ); +} + +TEST(EitherTest, givenLeftMoveAssigned_thenNewIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a(MovableOnly(3)); + either b(2); + b = std::move(a); + test(b); + }, [] (std::function&)> test) { + either a(MovableOnly(3)); + either b(MovableOnly(2)); + b = std::move(a); + test(b); + } + }, + EXPECT_IS_LEFT(MovableOnly(3)) + ); +} + +TEST(EitherTest, givenLeftMoveAssigned_thenOldIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a(MovableOnly(3)); + either b(2); + b = std::move(a); + test(a); // NOLINT(bugprone-use-after-move) + }, [] (std::function&)> test) { + either a(MovableOnly(3)); + either b(MovableOnly(2)); + b = std::move(a); + test(a); // NOLINT(bugprone-use-after-move) + } + }, + EXPECT_IS_LEFT(MovableOnly(0)) // 0 is moved-from value + ); +} + +TEST(EitherTest, givenLeftMoveAssigned_withSameType_thenNewIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a = make_left(3); + either b = make_right(2); + b = std::move(a); + test(b); + }, [] (std::function&)> test) { + either a = make_left(3); + either b = make_left(2); + b = std::move(a); + test(b); + } + }, + EXPECT_IS_LEFT(MovableOnly(3)) + ); +} + +TEST(EitherTest, givenLeftMoveAssigned_withSameType_thenOldIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a = make_left(3); + either b = make_right(2); + b = std::move(a); + test(a); // NOLINT(bugprone-use-after-move) + }, [] (std::function&)> test) { + either a = make_left(3); + either b = make_left(2); + b = std::move(a); + test(a); // NOLINT(bugprone-use-after-move) + } + }, + EXPECT_IS_LEFT(MovableOnly(0)) // 0 is moved-from value + ); +} + +TEST(EitherTest, givenRightMoveAssigned_thenNewIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a(MovableOnly(3)); + either b("2"); + b = std::move(a); + test(b); + }, [] (std::function&)> test) { + either a(MovableOnly(3)); + either b(MovableOnly(2)); + b = std::move(a); + test(b); + } + }, + EXPECT_IS_RIGHT(MovableOnly(3)) + ); +} + +TEST(EitherTest, givenRightMoveAssigned_thenOldIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a(MovableOnly(3)); + either b("2"); + b = std::move(a); + test(a); // NOLINT(bugprone-use-after-move) + }, [] (std::function&)> test) { + either a(MovableOnly(3)); + either b(MovableOnly(2)); + b = std::move(a); + test(a); // NOLINT(bugprone-use-after-move) + } + }, + EXPECT_IS_RIGHT(MovableOnly(0)) // 0 is moved-from value + ); +} + +TEST(EitherTest, givenRightMoveAssigned_withSameType_thenNewIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a = make_right(3); + either b = make_left(2); + b = std::move(a); + test(b); + }, [] (std::function&)> test) { + either a = make_right(3); + either b = make_right(2); + b = std::move(a); + test(b); + } + }, + EXPECT_IS_RIGHT(MovableOnly(3)) + ); +} + +TEST(EitherTest, givenRightMoveAssigned_withSameType_thenOldIsCorrect) { + test_with_matrix({ + [] (std::function&)> test) { + either a = make_right(3); + either b = make_left(2); + b = std::move(a); + test(a); // NOLINT(bugprone-use-after-move) + }, [] (std::function&)> test) { + either a = make_right(3); + either b = make_right(2); + b = std::move(a); + test(a); // NOLINT(bugprone-use-after-move) + } + }, + EXPECT_IS_RIGHT(MovableOnly(0)) // 0 is moved-from value + ); +} + +TEST(EitherTest, givenLeft_whenModified_thenValueIsChanged) { + test_with_matrix({ + [] (std::function&)> test) { + either a(4); + a.left() = 5; + test(a); + }, [] (std::function&)> test) { + either a(4); + a.left() = 5; + test(a); + } + }, + EXPECT_IS_LEFT(5) + ); +} + +TEST(EitherTest, givenRight_whenModified_thenValueIsChanged) { + test_with_matrix({ + [] (std::function&)> test) { + either a("4"); + a.right() = "5"; + test(a); + }, [] (std::function&)> test) { + either a("4"); + a.right() = "5"; + test(a); + } + }, + EXPECT_IS_RIGHT("5") + ); +} + +TEST(EitherTest, canEmplaceConstructLeft) { + test_with_matrix({ + [] (std::function, tuple>&)> test) { + either, tuple> a(2, 3); + test(a); + } + }, + EXPECT_IS_LEFT, tuple>(tuple(2, 3)) + ); +} + +TEST(EitherTest, canEmplaceConstructRight) { + test_with_matrix({ + [] (std::function, tuple>&)> test) { + either, tuple> a(2, "3", 4); + test(a); + } + }, + EXPECT_IS_RIGHT, tuple>(tuple(2, "3", 4)) + ); +} + +TEST(EitherTest, givenEqualLefts_thenAreEqual) { + either a("3"); + either b("3"); + EXPECT_TRUE(a == b); +} + +TEST(EitherTest, givenEqualLefts_thenAreNotUnequal) { + either a("3"); + either b("3"); + EXPECT_FALSE(a != b); +} + +TEST(EitherTest, givenEqualRights_thenAreEqual) { + either a(3); + either b(3); + EXPECT_TRUE(a == b); +} + +TEST(EitherTest, givenEqualRights_thenAreNotUnequal) { + either a(3); + either b(3); + EXPECT_FALSE(a != b); +} + +TEST(EitherTest, givenLeftAndRight_thenAreNotEqual) { + either a("3"); + either b(3); + EXPECT_FALSE(a == b); + EXPECT_FALSE(b == a); +} + +TEST(EitherTest, givenLeftAndRight_thenAreUnequal) { + either a("3"); + either b(3); + EXPECT_TRUE(a != b); + EXPECT_TRUE(b != a); +} + +TEST(EitherTest, OutputLeft) { + ostringstream str; + str << either("mystring"); + EXPECT_EQ("Left(mystring)", str.str()); +} + +TEST(EitherTest, OutputRight) { + ostringstream str; + str << either("mystring"); + EXPECT_EQ("Right(mystring)", str.str()); +} + +TEST(EitherTest, givenLeftAndRightWithSameType_thenAreNotEqual) { + either a = make_left("3"); + either b = make_right("3"); + EXPECT_FALSE(a == b); + EXPECT_FALSE(b == a); +} + +TEST(EitherTest, givenLeftAndRightWithSameType_thenAreUnequal) { + either a = make_left("3"); + either b = make_right("3"); + EXPECT_TRUE(a != b); + EXPECT_TRUE(b != a); +} + + +namespace { +class DestructorCallback { +public: + MOCK_CONST_METHOD0(call, void()); + + void EXPECT_CALLED(int times = 1) { + EXPECT_CALL(*this, call()).Times(times); + } +}; +class ClassWithDestructorCallback { +public: + ClassWithDestructorCallback(const DestructorCallback *destructorCallback) : _destructorCallback(destructorCallback) {} + ClassWithDestructorCallback(const ClassWithDestructorCallback &rhs): _destructorCallback(rhs._destructorCallback) {} + + ~ClassWithDestructorCallback() { + _destructorCallback->call(); + } + +private: + const DestructorCallback *_destructorCallback; + + ClassWithDestructorCallback &operator=(const ClassWithDestructorCallback &rhs) = delete; +}; +class OnlyMoveableClassWithDestructorCallback { +public: + OnlyMoveableClassWithDestructorCallback(const DestructorCallback *destructorCallback) : _destructorCallback(destructorCallback) { } + OnlyMoveableClassWithDestructorCallback(OnlyMoveableClassWithDestructorCallback &&source): _destructorCallback(source._destructorCallback) {} + + ~OnlyMoveableClassWithDestructorCallback() { + _destructorCallback->call(); + } + +private: + C10_DISABLE_COPY_AND_ASSIGN(OnlyMoveableClassWithDestructorCallback); + const DestructorCallback *_destructorCallback; +}; + +} + +TEST(EitherTest_Destructor, LeftDestructorIsCalled) { + DestructorCallback destructorCallback; + destructorCallback.EXPECT_CALLED(2); //Once for the temp object, once when the either class destructs + + ClassWithDestructorCallback temp(&destructorCallback); + either var = temp; +} + +TEST(EitherTest_Destructor, RightDestructorIsCalled) { + DestructorCallback destructorCallback; + destructorCallback.EXPECT_CALLED(2); //Once for the temp object, once when the either class destructs + + ClassWithDestructorCallback temp(&destructorCallback); + either var = temp; +} + +TEST(EitherTest_Destructor, LeftDestructorIsCalledAfterCopying) { + DestructorCallback destructorCallback; + destructorCallback.EXPECT_CALLED(3); //Once for the temp object, once for var1 and once for var2 + + ClassWithDestructorCallback temp(&destructorCallback); + either var1 = temp; + either var2 = var1; +} + +TEST(EitherTest_Destructor, RightDestructorIsCalledAfterCopying) { + DestructorCallback destructorCallback; + destructorCallback.EXPECT_CALLED(3); //Once for the temp object, once for var1 and once for var2 + + ClassWithDestructorCallback temp(&destructorCallback); + either var1 = temp; + either var2 = var1; +} + +TEST(EitherTest_Destructor, LeftDestructorIsCalledAfterMoving) { + DestructorCallback destructorCallback; + destructorCallback.EXPECT_CALLED(3); //Once for the temp object, once for var1 and once for var2 + + OnlyMoveableClassWithDestructorCallback temp(&destructorCallback); + either var1 = std::move(temp); + either var2 = std::move(var1); +} + +TEST(EitherTest_Destructor, RightDestructorIsCalledAfterMoving) { + DestructorCallback destructorCallback; + destructorCallback.EXPECT_CALLED(3); //Once for the temp object, once for var1 and once for var2 + + OnlyMoveableClassWithDestructorCallback temp(&destructorCallback); + either var1 = std::move(temp); + either var2 = std::move(var1); +} + +TEST(EitherTest_Destructor, LeftDestructorIsCalledAfterAssignment) { + DestructorCallback destructorCallback1; + DestructorCallback destructorCallback2; + destructorCallback1.EXPECT_CALLED(2); //Once for the temp1 object, once at the assignment + destructorCallback2.EXPECT_CALLED(3); //Once for the temp2 object, once in destructor of var2, once in destructor of var1 + + ClassWithDestructorCallback temp1(&destructorCallback1); + either var1 = temp1; + ClassWithDestructorCallback temp2(&destructorCallback2); + either var2 = temp2; + var1 = var2; +} + +TEST(EitherTest_Destructor, RightDestructorIsCalledAfterAssignment) { + DestructorCallback destructorCallback1; + DestructorCallback destructorCallback2; + destructorCallback1.EXPECT_CALLED(2); //Once for the temp1 object, once at the assignment + destructorCallback2.EXPECT_CALLED(3); //Once for the temp2 object, once in destructor of var2, once in destructor of var1 + + ClassWithDestructorCallback temp1(&destructorCallback1); + either var1 = temp1; + ClassWithDestructorCallback temp2(&destructorCallback2); + either var2 = temp2; + var1 = var2; +} + +TEST(EitherTest_Destructor, LeftDestructorIsCalledAfterMoveAssignment) { + DestructorCallback destructorCallback1; + DestructorCallback destructorCallback2; + destructorCallback1.EXPECT_CALLED(2); //Once for the temp1 object, once at the assignment + destructorCallback2.EXPECT_CALLED(3); //Once for the temp2 object, once in destructor of var2, once in destructor of var1 + + OnlyMoveableClassWithDestructorCallback temp1(&destructorCallback1); + either var1 = std::move(temp1); + OnlyMoveableClassWithDestructorCallback temp2(&destructorCallback2); + either var2 = std::move(temp2); + var1 = std::move(var2); +} + +TEST(EitherTest_Destructor, RightDestructorIsCalledAfterMoveAssignment) { + DestructorCallback destructorCallback1; + DestructorCallback destructorCallback2; + destructorCallback1.EXPECT_CALLED(2); //Once for the temp1 object, once at the assignment + destructorCallback2.EXPECT_CALLED(3); //Once for the temp2 object, once in destructor of var2, once in destructor of var1 + + OnlyMoveableClassWithDestructorCallback temp1(&destructorCallback1); + either var1 = std::move(temp1); + OnlyMoveableClassWithDestructorCallback temp2(&destructorCallback2); + either var2 = std::move(temp2); + var1 = std::move(var2); +} diff --git a/c10/util/either.h b/c10/util/either.h new file mode 100644 index 0000000..1a3757e --- /dev/null +++ b/c10/util/either.h @@ -0,0 +1,213 @@ +// Originally taken from +// https://github.com/cryfs/cryfs/blob/14ad22570ddacef22d5ff139cdff68a54fc8234d/src/cpp-utils/either.h + +#pragma once + +#include +#include +#include +#include + +namespace c10 { +/** + * either is a tagged union that holds either an object of type A + * or an object of type B. + */ +template +class either final { + public: + template < + class Head, + class... Tail, + c10::guts::enable_if_t< + std::is_constructible::value && + !std::is_constructible::value>* = nullptr> + either(Head&& construct_left_head_arg, Tail&&... construct_left_tail_args) + : _side(Side::left) { + _construct_left( + std::forward(construct_left_head_arg), + std::forward(construct_left_tail_args)...); + } + + template < + class Head, + class... Tail, + c10::guts::enable_if_t< + !std::is_constructible::value && + std::is_constructible::value>* = nullptr> + either(Head&& construct_right_head_arg, Tail&&... construct_right_tail_args) + : _side(Side::right) { + _construct_right( + std::forward(construct_right_head_arg), + std::forward(construct_right_tail_args)...); + } + + either(const either& rhs) : _side(rhs._side) { + if (_side == Side::left) { + _construct_left( + rhs._left); // NOLINT(cppcoreguidelines-pro-type-union-access) + } else { + _construct_right( + rhs._right); // NOLINT(cppcoreguidelines-pro-type-union-access) + } + } + + either(either&& rhs) noexcept : _side(rhs._side) { + if (_side == Side::left) { + _construct_left(std::move( + rhs._left)); // NOLINT(cppcoreguidelines-pro-type-union-access) + } else { + _construct_right(std::move( + rhs._right)); // NOLINT(cppcoreguidelines-pro-type-union-access) + } + } + + ~either() { + _destruct(); + } + + either& operator=(const either& rhs) { + _destruct(); + _side = rhs._side; + if (_side == Side::left) { + _construct_left( + rhs._left); // NOLINT(cppcoreguidelines-pro-type-union-access) + } else { + _construct_right( + rhs._right); // NOLINT(cppcoreguidelines-pro-type-union-access) + } + return *this; + } + + either& operator=(either&& rhs) { + _destruct(); + _side = rhs._side; + if (_side == Side::left) { + _construct_left(std::move( + rhs._left)); // NOLINT(cppcoreguidelines-pro-type-union-access) + } else { + _construct_right(std::move( + rhs._right)); // NOLINT(cppcoreguidelines-pro-type-union-access) + } + return *this; + } + + bool is_left() const noexcept { + return _side == Side::left; + } + + bool is_right() const noexcept { + return _side == Side::right; + } + + const Left& left() const& { + if (!is_left()) { + throw std::logic_error( + "Tried to get left side of an either which is right."); + } + return _left; // NOLINT(cppcoreguidelines-pro-type-union-access) + } + Left& left() & { + return const_cast( + const_cast*>(this)->left()); + } + Left&& left() && { + return std::move(left()); + } + + const Right& right() const& { + if (!is_right()) { + throw std::logic_error( + "Tried to get right side of an either which is left."); + } + return _right; // NOLINT(cppcoreguidelines-pro-type-union-access) + } + Right& right() & { + return const_cast( + const_cast*>(this)->right()); + } + Right&& right() && { + return std::move(right()); + } + + private: + union { + Left _left; + Right _right; + }; + enum class Side : uint8_t { left, right } _side; + + explicit either(Side side) noexcept : _side(side) {} + + template + void _construct_left(Args&&... args) { + new (&_left) Left(std::forward( + args)...); // NOLINT(cppcoreguidelines-pro-type-union-access) + } + template + void _construct_right(Args&&... args) { + new (&_right) Right(std::forward( + args)...); // NOLINT(cppcoreguidelines-pro-type-union-access) + } + void _destruct() noexcept { + if (_side == Side::left) { + _left.~Left(); // NOLINT(cppcoreguidelines-pro-type-union-access) + } else { + _right.~Right(); // NOLINT(cppcoreguidelines-pro-type-union-access) + } + } + + template + friend either make_left(Args&&... args); + + template + friend either make_right(Args&&... args); +}; + +template +inline bool operator==( + const either& lhs, + const either& rhs) { + if (lhs.is_left() != rhs.is_left()) { + return false; + } + if (lhs.is_left()) { + return lhs.left() == rhs.left(); + } else { + return lhs.right() == rhs.right(); + } +} + +template +inline bool operator!=( + const either& lhs, + const either& rhs) { + return !operator==(lhs, rhs); +} + +template +inline std::ostream& operator<<( + std::ostream& stream, + const either& value) { + if (value.is_left()) { + stream << "Left(" << value.left() << ")"; + } else { + stream << "Right(" << value.right() << ")"; + } + return stream; +} + +template +inline either make_left(Args&&... args) { + either result(either::Side::left); + result._construct_left(std::forward(args)...); + return result; +} + +template +inline either make_right(Args&&... args) { + either result(either::Side::right); + result._construct_right(std::forward(args)...); + return result; +} +} // namespace c10 diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 706bbf7..ebfc2db 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -113,7 +113,7 @@ elseif(BLAS STREQUAL "MKL") message(STATUS "MKL include directory: ${MKL_INCLUDE_DIR}") message(STATUS "MKL OpenMP type: ${MKL_OPENMP_TYPE}") message(STATUS "MKL OpenMP library: ${MKL_OPENMP_LIBRARY}") - include_directories(SYSTEM ${MKL_INCLUDE_DIR}) + include_directories(AFTER SYSTEM ${MKL_INCLUDE_DIR}) list(APPEND Caffe2_PUBLIC_DEPENDENCY_LIBS caffe2::mkl) set(CAFFE2_USE_MKL ON) else() @@ -1302,7 +1302,7 @@ if (NOT BUILD_ATEN_MOBILE) INCLUDE(${CMAKE_CURRENT_LIST_DIR}/public/mkldnn.cmake) IF(MKLDNN_FOUND) SET(AT_MKLDNN_ENABLED 1) - INCLUDE_DIRECTORIES(BEFORE SYSTEM ${MKLDNN_INCLUDE_DIR}) + INCLUDE_DIRECTORIES(AFTER SYSTEM ${MKLDNN_INCLUDE_DIR}) IF(BUILD_CAFFE2_OPS) SET(CAFFE2_USE_MKLDNN ON) LIST(APPEND Caffe2_PUBLIC_DEPENDENCY_LIBS caffe2::mkldnn) -- 2.7.4