[llvm][ADT] Allow returning `std::nullopt` in TypeSwitch
authorMarkus Böck <markus.boeck02@gmail.com>
Sat, 17 Dec 2022 21:42:30 +0000 (22:42 +0100)
committerMarkus Böck <markus.boeck02@gmail.com>
Sat, 17 Dec 2022 23:02:03 +0000 (00:02 +0100)
Returning `std::nullopt` from the case of a `TypeSwitch` yields broken results, by either falling through to another case, or falling of the switch entirely and hitting an assertion. This is simply due to the use of `operator=` of what is now `std::optional`, which has an overload specifically for `std::nullopt`, causing the internal optional, used for the `TypeSwitch` result to be reset, instead of a value being constructed from the `std::nullopt`.

The fix is to simply use the `emplace` method of `std::optional`, causing a value to always be constructed from the value returned by the case function.

Differential Revision: https://reviews.llvm.org/D140265

llvm/include/llvm/ADT/TypeSwitch.h
llvm/unittests/ADT/TypeSwitchTest.cpp

index f9323a5..10a2d48 100644 (file)
@@ -119,7 +119,7 @@ public:
 
     // Check to see if CaseT applies to 'value'.
     if (auto caseValue = BaseT::template castValue<CaseT>(this->value))
-      result = caseFn(caseValue);
+      result.emplace(caseFn(caseValue));
     return *this;
   }
 
index 442ac19..c54b798 100644 (file)
@@ -86,3 +86,31 @@ TEST(TypeSwitchTest, CasesVoid) {
   EXPECT_EQ(0, translate(DerivedD()));
   EXPECT_EQ(-1, translate(DerivedE()));
 }
+
+TEST(TypeSwitchTest, CaseOptional) {
+  auto translate = [](auto value) {
+    return TypeSwitch<Base *, std::optional<int>>(&value)
+        .Case([](DerivedA *) { return 0; })
+        .Case([](DerivedC *) { return std::nullopt; })
+        .Default([](Base *) { return -1; });
+  };
+  EXPECT_EQ(0, translate(DerivedA()));
+  EXPECT_EQ(std::nullopt, translate(DerivedC()));
+  EXPECT_EQ(-1, translate(DerivedD()));
+  EXPECT_EQ(std::nullopt,
+            (TypeSwitch<Base *, std::optional<int>>(nullptr).Default(
+                [](Base *) { return std::nullopt; })));
+}
+
+TEST(TypeSwitchTest, CasesOptional) {
+  auto translate = [](auto value) {
+    return TypeSwitch<Base *, std::optional<int>>(&value)
+        .Case<DerivedB, DerivedC>([](auto *) { return std::nullopt; })
+        .Case([](DerivedA *) { return 0; })
+        .Default([](Base *) { return -1; });
+  };
+  EXPECT_EQ(0, translate(DerivedA()));
+  EXPECT_EQ(std::nullopt, translate(DerivedB()));
+  EXPECT_EQ(std::nullopt, translate(DerivedC()));
+  EXPECT_EQ(-1, translate(DerivedD()));
+}