[mlir][transform] Add C++ builder to SequenceOp
authorMatthias Springer <springerm@google.com>
Thu, 17 Nov 2022 14:54:52 +0000 (15:54 +0100)
committerMatthias Springer <springerm@google.com>
Thu, 17 Nov 2022 14:58:13 +0000 (15:58 +0100)
This change adds a builder that populates the body of a SequenceOp. This is useful for constructing SequenceOps from C++.

Differential Revision: https://reviews.llvm.org/D137710

mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Transform/IR/TransformOps.cpp

index b9f4eed..822c24d 100644 (file)
@@ -21,6 +21,10 @@ namespace mlir {
 namespace transform {
 enum class FailurePropagationMode : uint32_t;
 class FailurePropagationModeAttr;
+
+/// A builder function that populates the body of a SequenceOp.
+using SequenceBodyBuilderFn = ::llvm::function_ref<void(
+    ::mlir::OpBuilder &, ::mlir::Location, ::mlir::BlockArgument)>;
 } // namespace transform
 } // namespace mlir
 
index 42f8d5c..d81bea6 100644 (file)
@@ -389,6 +389,10 @@ def SequenceOp : TransformDialectOp<"sequence",
     IR, typically the root operation of the pass interpreting the transform
     dialect. Operand omission is only allowed for sequences not contained in
     another sequence.
+
+    The body of the sequence terminates with an implicit or explicit
+    `transform.yield` op. The operands of the terminator are returned as the
+    results of the sequence op.
   }];
 
   let arguments = (ins FailurePropagationMode:$failure_propagation_mode,
@@ -400,6 +404,20 @@ def SequenceOp : TransformDialectOp<"sequence",
     "($root^ `:` type($root))? (`->` type($results)^)? `failures` `(` "
     "$failure_propagation_mode `)` attr-dict-with-keyword regions";
 
+  let builders = [
+    // Build a sequence with a root.
+    OpBuilder<(ins
+        "::mlir::TypeRange":$resultTypes,
+        "::mlir::transform::FailurePropagationMode":$failure_propagation_mode,
+        "::mlir::Value":$root, "SequenceBodyBuilderFn":$bodyBuilder)>,
+
+    // Build a sequence without a root but a certain bbArg type.
+    OpBuilder<(ins
+        "::mlir::TypeRange":$resultTypes,
+        "::mlir::transform::FailurePropagationMode":$failure_propagation_mode,
+        "::mlir::Type":$bbArgType, "SequenceBodyBuilderFn":$bodyBuilder)>
+  ];
+
   let extraClassDeclaration = [{
     /// Allow the dialect prefix to be omitted.
     static StringRef getDefaultDialect() { return "transform"; }
index af759d9..76e0c89 100644 (file)
@@ -765,6 +765,39 @@ void transform::SequenceOp::getRegionInvocationBounds(
   bounds.emplace_back(1, 1);
 }
 
+void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
+                                  TypeRange resultTypes,
+                                  FailurePropagationMode failurePropagationMode,
+                                  Value root,
+                                  SequenceBodyBuilderFn bodyBuilder) {
+  build(builder, state, resultTypes, failurePropagationMode, root);
+  Region *region = state.regions.back().get();
+  auto bbArgType = root.getType();
+  Block *bodyBlock = builder.createBlock(
+      region, region->begin(), TypeRange{bbArgType}, {state.location});
+
+  // Populate body.
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointToStart(bodyBlock);
+  bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
+}
+
+void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
+                                  TypeRange resultTypes,
+                                  FailurePropagationMode failurePropagationMode,
+                                  Type bbArgType,
+                                  SequenceBodyBuilderFn bodyBuilder) {
+  build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value());
+  Region *region = state.regions.back().get();
+  Block *bodyBlock = builder.createBlock(
+      region, region->begin(), TypeRange{bbArgType}, {state.location});
+
+  // Populate body.
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointToStart(bodyBlock);
+  bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
+}
+
 //===----------------------------------------------------------------------===//
 // WithPDLPatternsOp
 //===----------------------------------------------------------------------===//