Add a new utility class TypeSwitch to ADT.
authorRiver Riddle <riverriddle@google.com>
Tue, 17 Dec 2019 18:07:26 +0000 (10:07 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 17 Dec 2019 18:08:06 +0000 (10:08 -0800)
This class provides a simplified mechanism for defining a switch over a set of types using llvm casting functionality. More specifically, this allows for defining a switch over a value of type T where each case corresponds to a type(CaseT) that can be used with dyn_cast<CaseT>(...). An example is shown below:

// Traditional piece of code:
Operation *op = ...;
if (auto constant = dyn_cast<ConstantOp>(op))
  ...;
else if (auto return = dyn_cast<ReturnOp>(op))
  ...;
else
  ...;

// New piece of code:
Operation *op = ...;
TypeSwitch<Operation *>(op)
  .Case<ConstantOp>([](ConstantOp constant) { ... })
  .Case<ReturnOp>([](ReturnOp return) { ... })
  .Default([](Operation *op) { ... });

Aside from the above, TypeSwitch supports return values, void return, multiple types per case, etc. The usability is intended to be very similar to StringSwitch.

(Using c++14 template lambdas makes everything even nicer)
More complex example of how this makes certain things easier:
LogicalResult process(Constant op);
LogicalResult process(ReturnOp op);
LogicalResult process(FuncOp op);

TypeSwitch<Operation *, LogicalResult>(op)
  .Case<ConstantOp, ReturnOp, FuncOp>([](auto op) { return process(op); })
  .Default([](Operation *op) { return op->emitError() << "could not be processed"; });

PiperOrigin-RevId: 286003613

mlir/include/mlir/ADT/TypeSwitch.h [new file with mode: 0644]
mlir/include/mlir/Support/STLExtras.h
mlir/unittests/ADT/CMakeLists.txt [new file with mode: 0644]
mlir/unittests/ADT/TypeSwitchTest.cpp [new file with mode: 0644]
mlir/unittests/CMakeLists.txt

diff --git a/mlir/include/mlir/ADT/TypeSwitch.h b/mlir/include/mlir/ADT/TypeSwitch.h
new file mode 100644 (file)
index 0000000..75051b6
--- /dev/null
@@ -0,0 +1,185 @@
+//===- TypeSwitch.h - Switch functionality for RTTI casting -*- C++ -*-----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+//  This file implements the TypeSwitch template, which mimics a switch()
+//  statement whose cases are type names.
+//
+//===-----------------------------------------------------------------------===/
+
+#ifndef MLIR_SUPPORT_TYPESWITCH_H
+#define MLIR_SUPPORT_TYPESWITCH_H
+
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/STLExtras.h"
+#include "llvm/ADT/Optional.h"
+
+namespace mlir {
+namespace detail {
+
+template <typename DerivedT, typename T> class TypeSwitchBase {
+public:
+  TypeSwitchBase(const T &value) : value(value) {}
+  TypeSwitchBase(TypeSwitchBase &&other) : value(other.value) {}
+  ~TypeSwitchBase() = default;
+
+  /// TypeSwitchBase is not copyable.
+  TypeSwitchBase(const TypeSwitchBase &) = delete;
+  void operator=(const TypeSwitchBase &) = delete;
+  void operator=(TypeSwitchBase &&other) = delete;
+
+  /// Invoke a case on the derived class with multiple case types.
+  template <typename CaseT, typename CaseT2, typename... CaseTs,
+            typename CallableT>
+  DerivedT &Case(CallableT &&caseFn) {
+    DerivedT &derived = static_cast<DerivedT &>(*this);
+    return derived.template Case<CaseT>(caseFn)
+        .template Case<CaseT2, CaseTs...>(caseFn);
+  }
+
+  /// Invoke a case on the derived class, inferring the type of the Case from
+  /// the first input of the given callable.
+  /// Note: This inference rules for this overload are very simple: strip
+  ///       pointers and references.
+  template <typename CallableT> DerivedT &Case(CallableT &&caseFn) {
+    using Traits = FunctionTraits<std::decay_t<CallableT>>;
+    using CaseT = std::remove_cv_t<std::remove_pointer_t<
+        std::remove_reference_t<typename Traits::template arg_t<0>>>>;
+
+    DerivedT &derived = static_cast<DerivedT &>(*this);
+    return derived.template Case<CaseT>(std::forward<CallableT>(caseFn));
+  }
+
+protected:
+  /// Trait to check whether `ValueT` provides a 'dyn_cast' method with type
+  /// `CastT`.
+  template <typename ValueT, typename CastT>
+  using has_dyn_cast_t =
+      decltype(std::declval<ValueT &>().template dyn_cast<CastT>());
+
+  /// Attempt to dyn_cast the given `value` to `CastT`. This overload is
+  /// selected if `value` already has a suitable dyn_cast method.
+  template <typename CastT, typename ValueT>
+  static auto castValue(
+      ValueT value,
+      typename std::enable_if_t<
+          is_detected<has_dyn_cast_t, ValueT, CastT>::value> * = nullptr) {
+    return value.template dyn_cast<CastT>();
+  }
+
+  /// Attempt to dyn_cast the given `value` to `CastT`. This overload is
+  /// selected if llvm::dyn_cast should be used.
+  template <typename CastT, typename ValueT>
+  static auto castValue(
+      ValueT value,
+      typename std::enable_if_t<
+          !is_detected<has_dyn_cast_t, ValueT, CastT>::value> * = nullptr) {
+    return dyn_cast<CastT>(value);
+  }
+
+  /// The root value we are switching on.
+  const T value;
+};
+} // end namespace detail
+
+/// This class implements a switch-like dispatch statement for a value of 'T'
+/// using dyn_cast functionality. Each `Case<T>` takes a callable to be invoked
+/// if the root value isa<T>, the callable is invoked with the result of
+/// dyn_cast<T>() as a parameter.
+///
+/// Example:
+///  Operation *op = ...;
+///  LogicalResult result = TypeSwitch<Operation *, LogicalResult>(op)
+///    .Case<ConstantOp>([](ConstantOp op) { ... })
+///    .Default([](Operation *op) { ... });
+///
+template <typename T, typename ResultT = void>
+class TypeSwitch : public detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T> {
+public:
+  using BaseT = detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T>;
+  using BaseT::BaseT;
+  using BaseT::Case;
+  TypeSwitch(TypeSwitch &&other) = default;
+
+  /// Add a case on the given type.
+  template <typename CaseT, typename CallableT>
+  TypeSwitch<T, ResultT> &Case(CallableT &&caseFn) {
+    if (result)
+      return *this;
+
+    // Check to see if CaseT applies to 'value'.
+    if (auto caseValue = BaseT::template castValue<CaseT>(this->value))
+      result = caseFn(caseValue);
+    return *this;
+  }
+
+  /// As a default, invoke the given callable within the root value.
+  template <typename CallableT>
+  LLVM_NODISCARD ResultT Default(CallableT &&defaultFn) {
+    if (result)
+      return std::move(*result);
+    return defaultFn(this->value);
+  }
+
+  LLVM_NODISCARD
+  operator ResultT() {
+    assert(result && "Fell off the end of a type-switch");
+    return std::move(*result);
+  }
+
+private:
+  /// The pointer to the result of this switch statement, once known,
+  /// null before that.
+  Optional<ResultT> result;
+};
+
+/// Specialization of TypeSwitch for void returning callables.
+template <typename T>
+class TypeSwitch<T, void>
+    : public detail::TypeSwitchBase<TypeSwitch<T, void>, T> {
+public:
+  using BaseT = detail::TypeSwitchBase<TypeSwitch<T, void>, T>;
+  using BaseT::BaseT;
+  using BaseT::Case;
+  TypeSwitch(TypeSwitch &&other) = default;
+
+  /// Add a case on the given type.
+  template <typename CaseT, typename CallableT>
+  TypeSwitch<T, void> &Case(CallableT &&caseFn) {
+    if (foundMatch)
+      return *this;
+
+    // Check to see if any of the types apply to 'value'.
+    if (auto caseValue = BaseT::template castValue<CaseT>(this->value)) {
+      caseFn(caseValue);
+      foundMatch = true;
+    }
+    return *this;
+  }
+
+  /// As a default, invoke the given callable within the root value.
+  template <typename CallableT> void Default(CallableT &&defaultFn) {
+    if (!foundMatch)
+      defaultFn(this->value);
+  }
+
+private:
+  /// A flag detailing if we have already found a match.
+  bool foundMatch = false;
+};
+} // end namespace mlir
+
+#endif // MLIR_SUPPORT_TYPESWITCH_H
index c98f925..9bae7ac 100644 (file)
@@ -344,6 +344,44 @@ template <typename ContainerTy> bool has_single_element(ContainerTy &&c) {
   auto it = std::begin(c), e = std::end(c);
   return it != e && std::next(it) == e;
 }
+
+//===----------------------------------------------------------------------===//
+//     Extra additions to <type_traits>
+//===----------------------------------------------------------------------===//
+
+/// This class provides various trait information about a callable object.
+///   * To access the number of arguments: Traits::num_args
+///   * To access the type of an argument: Traits::arg_t<i>
+///   * To access the type of the result: Traits::result_t<i>
+template <typename T, bool isClass = std::is_class<T>::value>
+struct FunctionTraits : public FunctionTraits<decltype(&T::operator())> {};
+
+/// Overload for class function types.
+template <typename ClassType, typename ReturnType, typename... Args>
+struct FunctionTraits<ReturnType (ClassType::*)(Args...) const, false> {
+  /// The number of arguments to this function.
+  enum { num_args = sizeof...(Args) };
+
+  /// The result type of this function.
+  using result_t = ReturnType;
+
+  /// The type of an argument to this function.
+  template <size_t i>
+  using arg_t = typename std::tuple_element<i, std::tuple<Args...>>::type;
+};
+/// Overload for non-class function types.
+template <typename ReturnType, typename... Args>
+struct FunctionTraits<ReturnType (*)(Args...), false> {
+  /// The number of arguments to this function.
+  enum { num_args = sizeof...(Args) };
+
+  /// The result type of this function.
+  using result_t = ReturnType;
+
+  /// The type of an argument to this function.
+  template <size_t i>
+  using arg_t = typename std::tuple_element<i, std::tuple<Args...>>::type;
+};
 } // end namespace mlir
 
 // Allow tuples to be usable as DenseMap keys.
diff --git a/mlir/unittests/ADT/CMakeLists.txt b/mlir/unittests/ADT/CMakeLists.txt
new file mode 100644 (file)
index 0000000..cb12262
--- /dev/null
@@ -0,0 +1,5 @@
+add_mlir_unittest(MLIRADTTests
+  TypeSwitchTest.cpp
+)
+
+target_link_libraries(MLIRADTTests PRIVATE MLIRSupport LLVMSupport)
diff --git a/mlir/unittests/ADT/TypeSwitchTest.cpp b/mlir/unittests/ADT/TypeSwitchTest.cpp
new file mode 100644 (file)
index 0000000..b6a78de
--- /dev/null
@@ -0,0 +1,97 @@
+//===- TypeSwitchTest.cpp - TypeSwitch unit tests -------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/ADT/TypeSwitch.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+namespace {
+/// Utility classes to setup casting functionality.
+struct Base {
+  enum Kind { DerivedA, DerivedB, DerivedC, DerivedD, DerivedE };
+  Kind kind;
+};
+template <Base::Kind DerivedKind> struct DerivedImpl : Base {
+  DerivedImpl() : Base{DerivedKind} {}
+  static bool classof(const Base *base) { return base->kind == DerivedKind; }
+};
+struct DerivedA : public DerivedImpl<Base::DerivedA> {};
+struct DerivedB : public DerivedImpl<Base::DerivedB> {};
+struct DerivedC : public DerivedImpl<Base::DerivedC> {};
+struct DerivedD : public DerivedImpl<Base::DerivedD> {};
+struct DerivedE : public DerivedImpl<Base::DerivedE> {};
+} // end anonymous namespace
+
+TEST(StringSwitchTest, CaseResult) {
+  auto translate = [](auto value) {
+    return TypeSwitch<Base *, int>(&value)
+        .Case<DerivedA>([](DerivedA *) { return 0; })
+        .Case([](DerivedB *) { return 1; })
+        .Case([](DerivedC *) { return 2; })
+        .Default([](Base *) { return -1; });
+  };
+  EXPECT_EQ(0, translate(DerivedA()));
+  EXPECT_EQ(1, translate(DerivedB()));
+  EXPECT_EQ(2, translate(DerivedC()));
+  EXPECT_EQ(-1, translate(DerivedD()));
+}
+
+TEST(StringSwitchTest, CasesResult) {
+  auto translate = [](auto value) {
+    return TypeSwitch<Base *, int>(&value)
+        .Case<DerivedA, DerivedB, DerivedD>([](auto *) { return 0; })
+        .Case([](DerivedC *) { return 1; })
+        .Default([](Base *) { return -1; });
+  };
+  EXPECT_EQ(0, translate(DerivedA()));
+  EXPECT_EQ(0, translate(DerivedB()));
+  EXPECT_EQ(1, translate(DerivedC()));
+  EXPECT_EQ(0, translate(DerivedD()));
+  EXPECT_EQ(-1, translate(DerivedE()));
+}
+
+TEST(StringSwitchTest, CaseVoid) {
+  auto translate = [](auto value) {
+    int result = -2;
+    TypeSwitch<Base *>(&value)
+        .Case([&](DerivedA *) { result = 0; })
+        .Case([&](DerivedB *) { result = 1; })
+        .Case([&](DerivedC *) { result = 2; })
+        .Default([&](Base *) { result = -1; });
+    return result;
+  };
+  EXPECT_EQ(0, translate(DerivedA()));
+  EXPECT_EQ(1, translate(DerivedB()));
+  EXPECT_EQ(2, translate(DerivedC()));
+  EXPECT_EQ(-1, translate(DerivedD()));
+}
+
+TEST(StringSwitchTest, CasesVoid) {
+  auto translate = [](auto value) {
+    int result = -1;
+    TypeSwitch<Base *>(&value)
+        .Case<DerivedA, DerivedB, DerivedD>([&](auto *) { result = 0; })
+        .Case([&](DerivedC *) { result = 1; });
+    return result;
+  };
+  EXPECT_EQ(0, translate(DerivedA()));
+  EXPECT_EQ(0, translate(DerivedB()));
+  EXPECT_EQ(1, translate(DerivedC()));
+  EXPECT_EQ(0, translate(DerivedD()));
+  EXPECT_EQ(-1, translate(DerivedE()));
+}
index e80cc91..79a5297 100644 (file)
@@ -5,6 +5,7 @@ function(add_mlir_unittest test_dirname)
   add_unittest(MLIRUnitTests ${test_dirname} ${ARGN})
 endfunction()
 
+add_subdirectory(ADT)
 add_subdirectory(Dialect)
 add_subdirectory(IR)
 add_subdirectory(Pass)