[mlir] Add a new fold API using Generic Adaptors
authorMarkus Böck <markus.boeck02@gmail.com>
Sun, 25 Dec 2022 18:29:31 +0000 (19:29 +0100)
committerMarkus Böck <markus.boeck02@gmail.com>
Wed, 11 Jan 2023 13:32:21 +0000 (14:32 +0100)
This is part of the RFC for a better fold API: https://discourse.llvm.org/t/rfc-a-better-fold-api-using-more-generic-adaptors/67374

This patch implements the required foldHook changes and the TableGen machinery for generating `fold` method signatures using `FoldAdaptor` for ops, based on the value of `useFoldAPI` of the dialect. It may be one of 2 values, with convenient named constants to create a quasi enum. The new `fold` method will then be generated if `kEmitFoldAdaptorFolder` is used.

Since the new `FoldAdaptor` approach is strictly better than the old signature, part of this patch updates the documentation and all example to encourage use of the new `fold` signature.
Included are also tests exercising the new API, ensuring proper construction of the `FoldAdaptor` and proper generation by TableGen.

Differential Revision: https://reviews.llvm.org/D140886

17 files changed:
mlir/docs/Canonicalization.md
mlir/docs/DefiningDialects/_index.md
mlir/docs/Tutorials/Toy/Ch-7.md
mlir/examples/toy/Ch7/include/toy/Ops.td
mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
mlir/include/mlir/IR/DialectBase.td
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/TableGen/Dialect.h
mlir/include/mlir/TableGen/Operator.h
mlir/lib/TableGen/Dialect.cpp
mlir/lib/TableGen/Operator.cpp
mlir/test/IR/test-fold-adaptor.mlir [new file with mode: 0644]
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/has-fold-invalid-values.td [new file with mode: 0644]
mlir/test/mlir-tblgen/op-decl-and-defs.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

index d1aed54..d1cba57 100644 (file)
@@ -156,7 +156,7 @@ If the operation has a single result the following will be generated:
 ///     of the operation. The caller will remove the operation and use that
 ///     result instead.
 ///
-OpFoldResult MyOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult MyOp::fold(FoldAdaptor adaptor) {
   ...
 }
 ```
@@ -178,19 +178,19 @@ Otherwise, the following is generated:
 ///     the operation and use those results instead.
 ///
 /// Note that this mechanism cannot be used to remove 0-result operations.
-LogicalResult MyOp::fold(ArrayRef<Attribute> operands,
+LogicalResult MyOp::fold(FoldAdaptor adaptor,
                          SmallVectorImpl<OpFoldResult> &results) {
   ...
 }
 ```
 
-In the above, for each method an `ArrayRef<Attribute>` is provided that
-corresponds to the constant attribute value of each of the operands. These
+In the above, for each method a `FoldAdaptor` is provided with getters for
+each of the operands, returning the corresponding constant attribute. These
 operands are those that implement the `ConstantLike` trait. If any of the
 operands are non-constant, a null `Attribute` value is provided instead. For
 example, if MyOp provides three operands [`a`, `b`, `c`], but only `b` is
-constant then `operands` will be of the form [Attribute(), b-value,
-Attribute()].
+constant then `adaptor` will return Attribute() for `getA()` and `getC()`, 
+and b-value for `getB()`.
 
 Also above, is the use of `OpFoldResult`. This class represents the possible
 result of folding an operation result: either an SSA `Value`, or an
index 0f2b945..ca07c33 100644 (file)
@@ -255,6 +255,31 @@ LogicalResult MyDialect::verifyRegionResultAttribute(Operation *op, unsigned reg
                                                      unsigned argIndex, NamedAttribute attribute);
 ```
 
+#### `useFoldAPI`
+
+There are currently two possible values that are allowed to be assigned to this
+field:
+* `kEmitFoldAdaptorFolder` generates a `fold` method making use of the op's 
+  `FoldAdaptor` to allow access of operands via convenient getter.
+
+  Generated code example:
+  ```cpp
+  OpFoldResult fold(FoldAdaptor adaptor);
+  // or
+  LogicalResult fold(FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult>& results);
+  ```
+* `kEmitRawAttributesFolder` generates the deprecated legacy `fold`
+  method, containing `ArrayRef<Attribute>` in the parameter list instead of
+  the op's `FoldAdaptor`. This API is scheduled for removal and should not be 
+  used by new dialects.
+
+  Generated code example:
+  ```cpp
+  OpFoldResult fold(ArrayRef<Attribute> operands);
+  // or
+  LogicalResult fold(ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult>& results);
+  ```
+
 ### Operation Interface Fallback
 
 Some dialects have an open ecosystem and don't register all of the possible operations. In such
index 6148b92..2114bf6 100644 (file)
@@ -458,16 +458,16 @@ method.
 
 ```c++
 /// Fold constants.
-OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { return value(); }
+OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return value(); }
 
 /// Fold struct constants.
-OpFoldResult StructConstantOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult StructConstantOp::fold(FoldAdaptor adaptor) {
   return value();
 }
 
 /// Fold simple struct access operations that access into a constant.
-OpFoldResult StructAccessOp::fold(ArrayRef<Attribute> operands) {
-  auto structAttr = operands.front().dyn_cast_or_null<mlir::ArrayAttr>();
+OpFoldResult StructAccessOp::fold(FoldAdaptor adaptor) {
+  auto structAttr = adaptor.getInput().dyn_cast_or_null<mlir::ArrayAttr>();
   if (!structAttr)
     return nullptr;
 
index 08671a7..504d316 100644 (file)
@@ -33,6 +33,8 @@ def Toy_Dialect : Dialect {
   // We set this bit to generate the declarations for the dialect's type parsing
   // and printing hooks.
   let useDefaultTypePrinterParser = 1;
+
+  let useFoldAPI = kEmitFoldAdaptorFolder;
 }
 
 // Base class for toy dialect operations. This operation inherits from the base
index 36ba04b..62b00d9 100644 (file)
@@ -24,18 +24,14 @@ namespace {
 } // namespace
 
 /// Fold constants.
-OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
-  return getValue();
-}
+OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
 
 /// Fold struct constants.
-OpFoldResult StructConstantOp::fold(ArrayRef<Attribute> operands) {
-  return getValue();
-}
+OpFoldResult StructConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
 
 /// Fold simple struct access operations that access into a constant.
-OpFoldResult StructAccessOp::fold(ArrayRef<Attribute> operands) {
-  auto structAttr = operands.front().dyn_cast_or_null<mlir::ArrayAttr>();
+OpFoldResult StructAccessOp::fold(FoldAdaptor adaptor) {
+  auto structAttr = adaptor.getInput().dyn_cast_or_null<mlir::ArrayAttr>();
   if (!structAttr)
     return nullptr;
 
index ab14e34..043019e 100644 (file)
 // Dialect definitions
 //===----------------------------------------------------------------------===//
 
+// Generate 'fold' method with 'ArrayRef<Attribute>' parameter.
+// New code should prefer using 'kEmitFoldAdaptorFolder' and
+// consider 'kEmitRawAttributesFolder' deprecated and to be
+// removed in the future.
+defvar kEmitRawAttributesFolder = 0;
+// Generate 'fold' method with 'FoldAdaptor' parameter.
+defvar kEmitFoldAdaptorFolder = 1;
+
 class Dialect {
   // The name of the dialect.
   string name = ?;
@@ -85,6 +93,9 @@ class Dialect {
 
   // If this dialect can be extended at runtime with new operations or types.
   bit isExtensible = 0;
+
+  // Fold API to use for operations in this dialect.
+  int useFoldAPI = kEmitRawAttributesFolder;
 }
 
 #endif // DIALECTBASE_TD
index c185750..96b6e17 100644 (file)
@@ -1686,18 +1686,35 @@ public:
 private:
   /// Trait to check if T provides a 'fold' method for a single result op.
   template <typename T, typename... Args>
-  using has_single_result_fold =
+  using has_single_result_fold_t =
       decltype(std::declval<T>().fold(std::declval<ArrayRef<Attribute>>()));
   template <typename T>
-  using detect_has_single_result_fold =
-      llvm::is_detected<has_single_result_fold, T>;
+  constexpr static bool has_single_result_fold_v =
+      llvm::is_detected<has_single_result_fold_t, T>::value;
   /// Trait to check if T provides a general 'fold' method.
   template <typename T, typename... Args>
-  using has_fold = decltype(std::declval<T>().fold(
+  using has_fold_t = decltype(std::declval<T>().fold(
       std::declval<ArrayRef<Attribute>>(),
       std::declval<SmallVectorImpl<OpFoldResult> &>()));
   template <typename T>
-  using detect_has_fold = llvm::is_detected<has_fold, T>;
+  constexpr static bool has_fold_v = llvm::is_detected<has_fold_t, T>::value;
+  /// Trait to check if T provides a 'fold' method with a FoldAdaptor for a
+  /// single result op.
+  template <typename T, typename... Args>
+  using has_fold_adaptor_single_result_fold_t =
+      decltype(std::declval<T>().fold(std::declval<typename T::FoldAdaptor>()));
+  template <class T>
+  constexpr static bool has_fold_adaptor_single_result_v =
+      llvm::is_detected<has_fold_adaptor_single_result_fold_t, T>::value;
+  /// Trait to check if T provides a general 'fold' method with a FoldAdaptor.
+  template <typename T, typename... Args>
+  using has_fold_adaptor_fold_t = decltype(std::declval<T>().fold(
+      std::declval<typename T::FoldAdaptor>(),
+      std::declval<SmallVectorImpl<OpFoldResult> &>()));
+  template <class T>
+  constexpr static bool has_fold_adaptor_v =
+      llvm::is_detected<has_fold_adaptor_fold_t, T>::value;
+
   /// Trait to check if T provides a 'print' method.
   template <typename T, typename... Args>
   using has_print =
@@ -1746,13 +1763,14 @@ private:
     // If the operation is single result and defines a `fold` method.
     if constexpr (llvm::is_one_of<OpTrait::OneResult<ConcreteType>,
                                   Traits<ConcreteType>...>::value &&
-                  detect_has_single_result_fold<ConcreteType>::value)
+                  (has_single_result_fold_v<ConcreteType> ||
+                   has_fold_adaptor_single_result_v<ConcreteType>))
       return [](Operation *op, ArrayRef<Attribute> operands,
                 SmallVectorImpl<OpFoldResult> &results) {
         return foldSingleResultHook<ConcreteType>(op, operands, results);
       };
     // The operation is not single result and defines a `fold` method.
-    if constexpr (detect_has_fold<ConcreteType>::value)
+    if constexpr (has_fold_v<ConcreteType> || has_fold_adaptor_v<ConcreteType>)
       return [](Operation *op, ArrayRef<Attribute> operands,
                 SmallVectorImpl<OpFoldResult> &results) {
         return foldHook<ConcreteType>(op, operands, results);
@@ -1771,7 +1789,12 @@ private:
   static LogicalResult
   foldSingleResultHook(Operation *op, ArrayRef<Attribute> operands,
                        SmallVectorImpl<OpFoldResult> &results) {
-    OpFoldResult result = cast<ConcreteOpT>(op).fold(operands);
+    OpFoldResult result;
+    if constexpr (has_fold_adaptor_single_result_v<ConcreteOpT>)
+      result = cast<ConcreteOpT>(op).fold(typename ConcreteOpT::FoldAdaptor(
+          operands, op->getAttrDictionary(), op->getRegions()));
+    else
+      result = cast<ConcreteOpT>(op).fold(operands);
 
     // If the fold failed or was in-place, try to fold the traits of the
     // operation.
@@ -1788,7 +1811,15 @@ private:
   template <typename ConcreteOpT>
   static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
                                 SmallVectorImpl<OpFoldResult> &results) {
-    LogicalResult result = cast<ConcreteOpT>(op).fold(operands, results);
+    auto result = LogicalResult::failure();
+    if constexpr (has_fold_adaptor_v<ConcreteOpT>) {
+      result = cast<ConcreteOpT>(op).fold(
+          typename ConcreteOpT::FoldAdaptor(operands, op->getAttrDictionary(),
+                                            op->getRegions()),
+          results);
+    } else {
+      result = cast<ConcreteOpT>(op).fold(operands, results);
+    }
 
     // If the fold failed or was in-place, try to fold the traits of the
     // operation.
index d85342c..8fd519f 100644 (file)
@@ -86,6 +86,15 @@ public:
   /// operations or types.
   bool isExtensible() const;
 
+  enum class FolderAPI {
+    RawAttributes = 0, /// fold method with ArrayRef<Attribute>.
+    FolderAdaptor = 1, /// fold method with the operation's FoldAdaptor.
+  };
+
+  /// Returns the folder API that should be emitted for operations in this
+  /// dialect.
+  FolderAPI getFolderAPI() const;
+
   // Returns whether two dialects are equal by checking the equality of the
   // underlying record.
   bool operator==(const Dialect &other) const;
index 99c3eed..f4a475d 100644 (file)
@@ -314,6 +314,8 @@ public:
   /// Returns the remove name for the accessor of `name`.
   std::string getRemoverName(StringRef name) const;
 
+  bool hasFolder() const;
+
 private:
   /// Populates the vectors containing operands, attributes, results and traits.
   void populateOpStructure();
index 2bbca96..8d6b047 100644 (file)
@@ -102,6 +102,16 @@ bool Dialect::isExtensible() const {
   return def->getValueAsBit("isExtensible");
 }
 
+Dialect::FolderAPI Dialect::getFolderAPI() const {
+  int64_t value = def->getValueAsInt("useFoldAPI");
+  if (value < static_cast<int64_t>(FolderAPI::RawAttributes) ||
+      value > static_cast<int64_t>(FolderAPI::FolderAdaptor))
+    llvm::PrintFatalError(def->getLoc(),
+                          "Invalid value for dialect field `useFoldAPI`");
+
+  return static_cast<FolderAPI>(value);
+}
+
 bool Dialect::operator==(const Dialect &other) const {
   return def == other.def;
 }
index 4417705..150c385 100644 (file)
@@ -745,3 +745,5 @@ std::string Operator::getSetterName(StringRef name) const {
 std::string Operator::getRemoverName(StringRef name) const {
   return "remove" + convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true);
 }
+
+bool Operator::hasFolder() const { return def.getValueAsBit("hasFolder"); }
diff --git a/mlir/test/IR/test-fold-adaptor.mlir b/mlir/test/IR/test-fold-adaptor.mlir
new file mode 100644 (file)
index 0000000..7815e72
--- /dev/null
@@ -0,0 +1,16 @@
+// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s
+
+func.func @test() -> i32 {
+  %c5 = "test.constant"() {value = 5 : i32} : () -> i32
+  %c1 = "test.constant"() {value = 1 : i32} : () -> i32
+  %c2 = "test.constant"() {value = 2 : i32} : () -> i32
+  %c3 = "test.constant"() {value = 3 : i32} : () -> i32
+  %res = test.fold_with_fold_adaptor %c5, [ %c1, %c2], { (%c3), (%c3) } {
+    %c0 = "test.constant"() {value = 0 : i32} : () -> i32
+  }
+  return %res : i32
+}
+
+// CHECK-LABEL: func.func @test
+// CHECK-NEXT: %[[C:.*]] = "test.constant"() {value = 33 : i32}
+// CHECK-NEXT: return %[[C]]
index e710f03..0f48bfc 100644 (file)
@@ -33,6 +33,8 @@
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringSwitch.h"
 
+#include <numeric>
+
 // Include this before the using namespace lines below to
 // test that we don't have namespace dependencies.
 #include "TestOpsDialect.cpp.inc"
@@ -1126,6 +1128,25 @@ OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
   return getOperand();
 }
 
+OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) {
+  int64_t sum = 0;
+  if (auto value = dyn_cast_or_null<IntegerAttr>(adaptor.getOp()))
+    sum += value.getValue().getSExtValue();
+
+  for (Attribute attr : adaptor.getVariadic())
+    if (auto value = dyn_cast_or_null<IntegerAttr>(attr))
+      sum += 2 * value.getValue().getSExtValue();
+
+  for (ArrayRef<Attribute> attrs : adaptor.getVarOfVar())
+    for (Attribute attr : attrs)
+      if (auto value = dyn_cast_or_null<IntegerAttr>(attr))
+        sum += 3 * value.getValue().getSExtValue();
+
+  sum += 4 * std::distance(adaptor.getBody().begin(), adaptor.getBody().end());
+
+  return IntegerAttr::get(getType(), sum);
+}
+
 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
     MLIRContext *, std::optional<Location> location, ValueRange operands,
     DictionaryAttr attributes, RegionRange regions,
index d027db4..7d66136 100644 (file)
@@ -1297,6 +1297,31 @@ def TestOpInPlaceFoldSuccess : TEST_Op<"op_in_place_fold_success"> {
   }];
 }
 
+def TestOpFoldWithFoldAdaptor
+  : TEST_Op<"fold_with_fold_adaptor",
+      [AttrSizedOperandSegments, NoTerminator]> {
+  let arguments = (ins
+    I32:$op,
+    DenseI32ArrayAttr:$attr,
+    Variadic<I32>:$variadic,
+    VariadicOfVariadic<I32, "attr">:$var_of_var
+  );
+
+  let results = (outs I32:$res);
+
+  let regions = (region AnyRegion:$body);
+
+  let assemblyFormat = [{
+    $op `,` `[` $variadic `]` `,` `{` $var_of_var `}` $body attr-dict-with-keyword
+  }];
+
+  let hasFolder = 0;
+
+  let extraClassDeclaration = [{
+    ::mlir::OpFoldResult fold(FoldAdaptor adaptor);
+  }];
+}
+
 // An op that always fold itself.
 def TestPassthroughFold : TEST_Op<"passthrough_fold"> {
   let arguments = (ins AnyType:$op);
diff --git a/mlir/test/mlir-tblgen/has-fold-invalid-values.td b/mlir/test/mlir-tblgen/has-fold-invalid-values.td
new file mode 100644 (file)
index 0000000..09149a5
--- /dev/null
@@ -0,0 +1,15 @@
+// RUN: not mlir-tblgen -gen-op-decls -I %S/../../include %s 2>&1 | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+
+def Test_Dialect : Dialect {
+  let name = "test";
+  let cppNamespace = "NS";
+  let useFoldAPI = 3;
+}
+
+def InvalidValue_Op : Op<Test_Dialect, "invalid_op"> {
+  let hasFolder = 1;
+}
+
+// CHECK: Invalid value for dialect field `useFoldAPI`
index 3a5af12..6f9d24d 100644 (file)
@@ -317,6 +317,29 @@ def NS_LOp : NS_Op<"op_with_same_operands_and_result_types_unwrapped_attr", [Sam
 // CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
 // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
 
+def TestWithNewFold_Dialect : Dialect {
+  let name = "test";
+  let cppNamespace = "::mlir::testWithFold";
+  let useFoldAPI = kEmitFoldAdaptorFolder;
+}
+
+def NS_MOp : Op<TestWithNewFold_Dialect, "op_with_single_result_and_fold_adaptor_fold", []> {
+  let results = (outs AnyType:$res);
+
+  let hasFolder = 1;
+}
+
+// CHECK-LABEL: class MOp :
+// CHECK: ::mlir::OpFoldResult fold(FoldAdaptor adaptor);
+
+def NS_NOp : Op<TestWithNewFold_Dialect, "op_with_multiple_results_and_fold_adaptor_fold", []> {
+  let results = (outs AnyType:$res1, AnyType:$res2);
+
+  let hasFolder = 1;
+}
+
+// CHECK-LABEL: class NOp :
+// CHECK: ::mlir::LogicalResult fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult> &results);
 
 // Test that type defs have the proper namespaces when used as a constraint.
 // ---
index 3b45bb5..2483378 100644 (file)
@@ -2326,25 +2326,29 @@ void OpEmitter::genCanonicalizerDecls() {
 }
 
 void OpEmitter::genFolderDecls() {
+  if (!op.hasFolder())
+    return;
+
+  Dialect::FolderAPI folderApi = op.getDialect().getFolderAPI();
+  SmallVector<MethodParameter> paramList;
+  if (folderApi == Dialect::FolderAPI::RawAttributes)
+    paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands");
+  else
+    paramList.emplace_back("FoldAdaptor", "adaptor");
+
+  StringRef retType;
   bool hasSingleResult =
       op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0;
-
-  if (def.getValueAsBit("hasFolder")) {
-    if (hasSingleResult) {
-      auto *m = opClass.declareMethod(
-          "::mlir::OpFoldResult", "fold",
-          MethodParameter("::llvm::ArrayRef<::mlir::Attribute>", "operands"));
-      ERROR_IF_PRUNED(m, "operands", op);
-    } else {
-      SmallVector<MethodParameter> paramList;
-      paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands");
-      paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &",
-                             "results");
-      auto *m = opClass.declareMethod("::mlir::LogicalResult", "fold",
-                                      std::move(paramList));
-      ERROR_IF_PRUNED(m, "fold", op);
-    }
+  if (hasSingleResult) {
+    retType = "::mlir::OpFoldResult";
+  } else {
+    paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &",
+                           "results");
+    retType = "::mlir::LogicalResult";
   }
+
+  auto *m = opClass.declareMethod(retType, "fold", std::move(paramList));
+  ERROR_IF_PRUNED(m, "fold", op);
 }
 
 void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {