From 225f11cff7fb1983cd849fbd253c459a804ce525 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Markus=20B=C3=B6ck?= Date: Thu, 23 Feb 2023 00:06:21 +0100 Subject: [PATCH] [mlir] Partially revert removal of old `fold` method Mehdi noted in https://reviews.llvm.org/D144391 that given the low cost of keeping the old `fold` method signature working and the difficulty of writing a `FoldAdaptor` oneself, it'd be nice to keep the support for the sake of Ops written manually in C++. This patch therefore partially reverts the removal of the old `fold` method by still allowing the old signature to be used. The active use of it is still discouraged and ODS will always generate the new method using `FoldAdaptor`s. I'd also like to note that the previous ought to have broken some manually defined `fold` methods in-tree that are defined here: https://github.com/llvm/llvm-project/blob/23bcd6b86271f1c219a69183a5d90654faca64b8/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h#L245 It seems like these are not part of the regressions tests however... Differential Revision: https://reviews.llvm.org/D144591 --- mlir/include/mlir/IR/OpDefinition.h | 47 +++++++++++++++++++++++------- mlir/test/IR/test-manual-cpp-fold.mlir | 11 +++++++ mlir/test/lib/Dialect/Test/TestDialect.cpp | 9 ++++++ mlir/test/lib/Dialect/Test/TestDialect.h | 17 +++++++++++ 4 files changed, 73 insertions(+), 11 deletions(-) create mode 100644 mlir/test/IR/test-manual-cpp-fold.mlir diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index fe2bdd5..f7d8436 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1688,17 +1688,33 @@ private: /// Trait to check if T provides a 'fold' method for a single result op. template using has_single_result_fold_t = - decltype(std::declval().fold(std::declval())); + decltype(std::declval().fold(std::declval>())); template constexpr static bool has_single_result_fold_v = llvm::is_detected::value; /// Trait to check if T provides a general 'fold' method. template using has_fold_t = decltype(std::declval().fold( - std::declval(), + std::declval>(), std::declval &>())); template constexpr static bool has_fold_v = llvm::is_detected::value; + /// Trait to check if T provides a 'fold' method with a FoldAdaptor for a + /// single result op. + template + using has_fold_adaptor_single_result_fold_t = + decltype(std::declval().fold(std::declval())); + template + constexpr static bool has_fold_adaptor_single_result_v = + llvm::is_detected::value; + /// Trait to check if T provides a general 'fold' method with a FoldAdaptor. + template + using has_fold_adaptor_fold_t = decltype(std::declval().fold( + std::declval(), + std::declval &>())); + template + constexpr static bool has_fold_adaptor_v = + llvm::is_detected::value; /// Trait to check if T provides a 'print' method. template @@ -1748,13 +1764,14 @@ private: // If the operation is single result and defines a `fold` method. if constexpr (llvm::is_one_of, Traits...>::value && - has_single_result_fold_v) + (has_single_result_fold_v || + has_fold_adaptor_single_result_v)) return [](Operation *op, ArrayRef operands, SmallVectorImpl &results) { return foldSingleResultHook(op, operands, results); }; // The operation is not single result and defines a `fold` method. - if constexpr (has_fold_v) + if constexpr (has_fold_v || has_fold_adaptor_v) return [](Operation *op, ArrayRef operands, SmallVectorImpl &results) { return foldHook(op, operands, results); @@ -1773,9 +1790,12 @@ private: static LogicalResult foldSingleResultHook(Operation *op, ArrayRef operands, SmallVectorImpl &results) { - OpFoldResult result = - cast(op).fold(typename ConcreteOpT::FoldAdaptor( - operands, op->getAttrDictionary(), op->getRegions())); + OpFoldResult result; + if constexpr (has_fold_adaptor_single_result_v) + result = cast(op).fold(typename ConcreteOpT::FoldAdaptor( + operands, op->getAttrDictionary(), op->getRegions())); + else + result = cast(op).fold(operands); // If the fold failed or was in-place, try to fold the traits of the // operation. @@ -1792,10 +1812,15 @@ private: template static LogicalResult foldHook(Operation *op, ArrayRef operands, SmallVectorImpl &results) { - LogicalResult result = cast(op).fold( - typename ConcreteOpT::FoldAdaptor(operands, op->getAttrDictionary(), - op->getRegions()), - results); + auto result = LogicalResult::failure(); + if constexpr (has_fold_adaptor_v) { + result = cast(op).fold( + typename ConcreteOpT::FoldAdaptor(operands, op->getAttrDictionary(), + op->getRegions()), + results); + } else { + result = cast(op).fold(operands, results); + } // If the fold failed or was in-place, try to fold the traits of the // operation. diff --git a/mlir/test/IR/test-manual-cpp-fold.mlir b/mlir/test/IR/test-manual-cpp-fold.mlir new file mode 100644 index 0000000..592b949 --- /dev/null +++ b/mlir/test/IR/test-manual-cpp-fold.mlir @@ -0,0 +1,11 @@ +// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s + +func.func @test() -> i32 { + %c5 = "test.constant"() {value = 5 : i32} : () -> i32 + %res = "test.manual_cpp_op_with_fold"(%c5) : (i32) -> i32 + return %res : i32 +} + +// CHECK-LABEL: func.func @test +// CHECK-NEXT: %[[C:.*]] = "test.constant"() {value = 5 : i32} +// CHECK-NEXT: return %[[C]] diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index e62d5a8..dc5f629 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -358,6 +358,7 @@ void TestDialect::initialize() { #define GET_OP_LIST #include "TestOps.cpp.inc" >(); + addOperations(); registerDynamicOp(getDynamicGenericOp(this)); registerDynamicOp(getDynamicOneOperandTwoResultsOp(this)); registerDynamicOp(getDynamicCustomParserPrinterOp(this)); @@ -1634,6 +1635,14 @@ void TestReflectBoundsOp::inferResultRanges( setResultRanges(getResult(), range); } +OpFoldResult ManualCppOpWithFold::fold(ArrayRef attributes) { + // Just a simple fold for testing purposes that reads an operands constant + // value and returns it. + if (!attributes.empty()) + return attributes.front(); + return nullptr; +} + #include "TestOpEnums.cpp.inc" #include "TestOpInterfaces.cpp.inc" #include "TestTypeInterfaces.cpp.inc" diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h index ceb9dc6..ad3ef2a 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -58,6 +58,23 @@ class RewritePatternSet; #include "TestOps.h.inc" namespace test { + +// Op deliberately defined in C++ code rather than ODS to test that C++ +// Ops can still use the old `fold` method. +class ManualCppOpWithFold + : public mlir::Op { +public: + using Op::Op; + + static llvm::StringRef getOperationName() { + return "test.manual_cpp_op_with_fold"; + } + + static llvm::ArrayRef getAttributeNames() { return {}; } + + mlir::OpFoldResult fold(llvm::ArrayRef attributes); +}; + void registerTestDialect(::mlir::DialectRegistry ®istry); void populateTestReductionPatterns(::mlir::RewritePatternSet &patterns); } // namespace test -- 2.7.4