From: River Riddle Date: Tue, 17 Dec 2019 18:07:26 +0000 (-0800) Subject: Add a new utility class TypeSwitch to ADT. X-Git-Tag: llvmorg-11-init~1466^2~52 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=f44cf23297089dd4beb6f81a7fdda4e59466dcdb;p=platform%2Fupstream%2Fllvm.git Add a new utility class TypeSwitch to ADT. This class provides a simplified mechanism for defining a switch over a set of types using llvm casting functionality. More specifically, this allows for defining a switch over a value of type T where each case corresponds to a type(CaseT) that can be used with dyn_cast(...). An example is shown below: // Traditional piece of code: Operation *op = ...; if (auto constant = dyn_cast(op)) ...; else if (auto return = dyn_cast(op)) ...; else ...; // New piece of code: Operation *op = ...; TypeSwitch(op) .Case([](ConstantOp constant) { ... }) .Case([](ReturnOp return) { ... }) .Default([](Operation *op) { ... }); Aside from the above, TypeSwitch supports return values, void return, multiple types per case, etc. The usability is intended to be very similar to StringSwitch. (Using c++14 template lambdas makes everything even nicer) More complex example of how this makes certain things easier: LogicalResult process(Constant op); LogicalResult process(ReturnOp op); LogicalResult process(FuncOp op); TypeSwitch(op) .Case([](auto op) { return process(op); }) .Default([](Operation *op) { return op->emitError() << "could not be processed"; }); PiperOrigin-RevId: 286003613 --- diff --git a/mlir/include/mlir/ADT/TypeSwitch.h b/mlir/include/mlir/ADT/TypeSwitch.h new file mode 100644 index 0000000..75051b6 --- /dev/null +++ b/mlir/include/mlir/ADT/TypeSwitch.h @@ -0,0 +1,185 @@ +//===- TypeSwitch.h - Switch functionality for RTTI casting -*- C++ -*-----===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements the TypeSwitch template, which mimics a switch() +// statement whose cases are type names. +// +//===-----------------------------------------------------------------------===/ + +#ifndef MLIR_SUPPORT_TYPESWITCH_H +#define MLIR_SUPPORT_TYPESWITCH_H + +#include "mlir/Support/LLVM.h" +#include "mlir/Support/STLExtras.h" +#include "llvm/ADT/Optional.h" + +namespace mlir { +namespace detail { + +template class TypeSwitchBase { +public: + TypeSwitchBase(const T &value) : value(value) {} + TypeSwitchBase(TypeSwitchBase &&other) : value(other.value) {} + ~TypeSwitchBase() = default; + + /// TypeSwitchBase is not copyable. + TypeSwitchBase(const TypeSwitchBase &) = delete; + void operator=(const TypeSwitchBase &) = delete; + void operator=(TypeSwitchBase &&other) = delete; + + /// Invoke a case on the derived class with multiple case types. + template + DerivedT &Case(CallableT &&caseFn) { + DerivedT &derived = static_cast(*this); + return derived.template Case(caseFn) + .template Case(caseFn); + } + + /// Invoke a case on the derived class, inferring the type of the Case from + /// the first input of the given callable. + /// Note: This inference rules for this overload are very simple: strip + /// pointers and references. + template DerivedT &Case(CallableT &&caseFn) { + using Traits = FunctionTraits>; + using CaseT = std::remove_cv_t>>>; + + DerivedT &derived = static_cast(*this); + return derived.template Case(std::forward(caseFn)); + } + +protected: + /// Trait to check whether `ValueT` provides a 'dyn_cast' method with type + /// `CastT`. + template + using has_dyn_cast_t = + decltype(std::declval().template dyn_cast()); + + /// Attempt to dyn_cast the given `value` to `CastT`. This overload is + /// selected if `value` already has a suitable dyn_cast method. + template + static auto castValue( + ValueT value, + typename std::enable_if_t< + is_detected::value> * = nullptr) { + return value.template dyn_cast(); + } + + /// Attempt to dyn_cast the given `value` to `CastT`. This overload is + /// selected if llvm::dyn_cast should be used. + template + static auto castValue( + ValueT value, + typename std::enable_if_t< + !is_detected::value> * = nullptr) { + return dyn_cast(value); + } + + /// The root value we are switching on. + const T value; +}; +} // end namespace detail + +/// This class implements a switch-like dispatch statement for a value of 'T' +/// using dyn_cast functionality. Each `Case` takes a callable to be invoked +/// if the root value isa, the callable is invoked with the result of +/// dyn_cast() as a parameter. +/// +/// Example: +/// Operation *op = ...; +/// LogicalResult result = TypeSwitch(op) +/// .Case([](ConstantOp op) { ... }) +/// .Default([](Operation *op) { ... }); +/// +template +class TypeSwitch : public detail::TypeSwitchBase, T> { +public: + using BaseT = detail::TypeSwitchBase, T>; + using BaseT::BaseT; + using BaseT::Case; + TypeSwitch(TypeSwitch &&other) = default; + + /// Add a case on the given type. + template + TypeSwitch &Case(CallableT &&caseFn) { + if (result) + return *this; + + // Check to see if CaseT applies to 'value'. + if (auto caseValue = BaseT::template castValue(this->value)) + result = caseFn(caseValue); + return *this; + } + + /// As a default, invoke the given callable within the root value. + template + LLVM_NODISCARD ResultT Default(CallableT &&defaultFn) { + if (result) + return std::move(*result); + return defaultFn(this->value); + } + + LLVM_NODISCARD + operator ResultT() { + assert(result && "Fell off the end of a type-switch"); + return std::move(*result); + } + +private: + /// The pointer to the result of this switch statement, once known, + /// null before that. + Optional result; +}; + +/// Specialization of TypeSwitch for void returning callables. +template +class TypeSwitch + : public detail::TypeSwitchBase, T> { +public: + using BaseT = detail::TypeSwitchBase, T>; + using BaseT::BaseT; + using BaseT::Case; + TypeSwitch(TypeSwitch &&other) = default; + + /// Add a case on the given type. + template + TypeSwitch &Case(CallableT &&caseFn) { + if (foundMatch) + return *this; + + // Check to see if any of the types apply to 'value'. + if (auto caseValue = BaseT::template castValue(this->value)) { + caseFn(caseValue); + foundMatch = true; + } + return *this; + } + + /// As a default, invoke the given callable within the root value. + template void Default(CallableT &&defaultFn) { + if (!foundMatch) + defaultFn(this->value); + } + +private: + /// A flag detailing if we have already found a match. + bool foundMatch = false; +}; +} // end namespace mlir + +#endif // MLIR_SUPPORT_TYPESWITCH_H diff --git a/mlir/include/mlir/Support/STLExtras.h b/mlir/include/mlir/Support/STLExtras.h index c98f925..9bae7ac 100644 --- a/mlir/include/mlir/Support/STLExtras.h +++ b/mlir/include/mlir/Support/STLExtras.h @@ -344,6 +344,44 @@ template bool has_single_element(ContainerTy &&c) { auto it = std::begin(c), e = std::end(c); return it != e && std::next(it) == e; } + +//===----------------------------------------------------------------------===// +// Extra additions to +//===----------------------------------------------------------------------===// + +/// This class provides various trait information about a callable object. +/// * To access the number of arguments: Traits::num_args +/// * To access the type of an argument: Traits::arg_t +/// * To access the type of the result: Traits::result_t +template ::value> +struct FunctionTraits : public FunctionTraits {}; + +/// Overload for class function types. +template +struct FunctionTraits { + /// The number of arguments to this function. + enum { num_args = sizeof...(Args) }; + + /// The result type of this function. + using result_t = ReturnType; + + /// The type of an argument to this function. + template + using arg_t = typename std::tuple_element>::type; +}; +/// Overload for non-class function types. +template +struct FunctionTraits { + /// The number of arguments to this function. + enum { num_args = sizeof...(Args) }; + + /// The result type of this function. + using result_t = ReturnType; + + /// The type of an argument to this function. + template + using arg_t = typename std::tuple_element>::type; +}; } // end namespace mlir // Allow tuples to be usable as DenseMap keys. diff --git a/mlir/unittests/ADT/CMakeLists.txt b/mlir/unittests/ADT/CMakeLists.txt new file mode 100644 index 0000000..cb12262 --- /dev/null +++ b/mlir/unittests/ADT/CMakeLists.txt @@ -0,0 +1,5 @@ +add_mlir_unittest(MLIRADTTests + TypeSwitchTest.cpp +) + +target_link_libraries(MLIRADTTests PRIVATE MLIRSupport LLVMSupport) diff --git a/mlir/unittests/ADT/TypeSwitchTest.cpp b/mlir/unittests/ADT/TypeSwitchTest.cpp new file mode 100644 index 0000000..b6a78de --- /dev/null +++ b/mlir/unittests/ADT/TypeSwitchTest.cpp @@ -0,0 +1,97 @@ +//===- TypeSwitchTest.cpp - TypeSwitch unit tests -------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/ADT/TypeSwitch.h" +#include "gtest/gtest.h" + +using namespace mlir; + +namespace { +/// Utility classes to setup casting functionality. +struct Base { + enum Kind { DerivedA, DerivedB, DerivedC, DerivedD, DerivedE }; + Kind kind; +}; +template struct DerivedImpl : Base { + DerivedImpl() : Base{DerivedKind} {} + static bool classof(const Base *base) { return base->kind == DerivedKind; } +}; +struct DerivedA : public DerivedImpl {}; +struct DerivedB : public DerivedImpl {}; +struct DerivedC : public DerivedImpl {}; +struct DerivedD : public DerivedImpl {}; +struct DerivedE : public DerivedImpl {}; +} // end anonymous namespace + +TEST(StringSwitchTest, CaseResult) { + auto translate = [](auto value) { + return TypeSwitch(&value) + .Case([](DerivedA *) { return 0; }) + .Case([](DerivedB *) { return 1; }) + .Case([](DerivedC *) { return 2; }) + .Default([](Base *) { return -1; }); + }; + EXPECT_EQ(0, translate(DerivedA())); + EXPECT_EQ(1, translate(DerivedB())); + EXPECT_EQ(2, translate(DerivedC())); + EXPECT_EQ(-1, translate(DerivedD())); +} + +TEST(StringSwitchTest, CasesResult) { + auto translate = [](auto value) { + return TypeSwitch(&value) + .Case([](auto *) { return 0; }) + .Case([](DerivedC *) { return 1; }) + .Default([](Base *) { return -1; }); + }; + EXPECT_EQ(0, translate(DerivedA())); + EXPECT_EQ(0, translate(DerivedB())); + EXPECT_EQ(1, translate(DerivedC())); + EXPECT_EQ(0, translate(DerivedD())); + EXPECT_EQ(-1, translate(DerivedE())); +} + +TEST(StringSwitchTest, CaseVoid) { + auto translate = [](auto value) { + int result = -2; + TypeSwitch(&value) + .Case([&](DerivedA *) { result = 0; }) + .Case([&](DerivedB *) { result = 1; }) + .Case([&](DerivedC *) { result = 2; }) + .Default([&](Base *) { result = -1; }); + return result; + }; + EXPECT_EQ(0, translate(DerivedA())); + EXPECT_EQ(1, translate(DerivedB())); + EXPECT_EQ(2, translate(DerivedC())); + EXPECT_EQ(-1, translate(DerivedD())); +} + +TEST(StringSwitchTest, CasesVoid) { + auto translate = [](auto value) { + int result = -1; + TypeSwitch(&value) + .Case([&](auto *) { result = 0; }) + .Case([&](DerivedC *) { result = 1; }); + return result; + }; + EXPECT_EQ(0, translate(DerivedA())); + EXPECT_EQ(0, translate(DerivedB())); + EXPECT_EQ(1, translate(DerivedC())); + EXPECT_EQ(0, translate(DerivedD())); + EXPECT_EQ(-1, translate(DerivedE())); +} diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt index e80cc91..79a5297 100644 --- a/mlir/unittests/CMakeLists.txt +++ b/mlir/unittests/CMakeLists.txt @@ -5,6 +5,7 @@ function(add_mlir_unittest test_dirname) add_unittest(MLIRUnitTests ${test_dirname} ${ARGN}) endfunction() +add_subdirectory(ADT) add_subdirectory(Dialect) add_subdirectory(IR) add_subdirectory(Pass)