From 00c95b19d7963dae4e8bdee66a9880d44761cffe Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Thu, 17 Nov 2022 15:54:52 +0100 Subject: [PATCH] [mlir][transform] Add C++ builder to SequenceOp This change adds a builder that populates the body of a SequenceOp. This is useful for constructing SequenceOps from C++. Differential Revision: https://reviews.llvm.org/D137710 --- .../mlir/Dialect/Transform/IR/TransformOps.h | 4 +++ .../mlir/Dialect/Transform/IR/TransformOps.td | 18 ++++++++++++ mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 33 ++++++++++++++++++++++ 3 files changed, 55 insertions(+) diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h index b9f4eed..822c24d 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h @@ -21,6 +21,10 @@ namespace mlir { namespace transform { enum class FailurePropagationMode : uint32_t; class FailurePropagationModeAttr; + +/// A builder function that populates the body of a SequenceOp. +using SequenceBodyBuilderFn = ::llvm::function_ref; } // namespace transform } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index 42f8d5c..d81bea6 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -389,6 +389,10 @@ def SequenceOp : TransformDialectOp<"sequence", IR, typically the root operation of the pass interpreting the transform dialect. Operand omission is only allowed for sequences not contained in another sequence. + + The body of the sequence terminates with an implicit or explicit + `transform.yield` op. The operands of the terminator are returned as the + results of the sequence op. }]; let arguments = (ins FailurePropagationMode:$failure_propagation_mode, @@ -400,6 +404,20 @@ def SequenceOp : TransformDialectOp<"sequence", "($root^ `:` type($root))? (`->` type($results)^)? `failures` `(` " "$failure_propagation_mode `)` attr-dict-with-keyword regions"; + let builders = [ + // Build a sequence with a root. + OpBuilder<(ins + "::mlir::TypeRange":$resultTypes, + "::mlir::transform::FailurePropagationMode":$failure_propagation_mode, + "::mlir::Value":$root, "SequenceBodyBuilderFn":$bodyBuilder)>, + + // Build a sequence without a root but a certain bbArg type. + OpBuilder<(ins + "::mlir::TypeRange":$resultTypes, + "::mlir::transform::FailurePropagationMode":$failure_propagation_mode, + "::mlir::Type":$bbArgType, "SequenceBodyBuilderFn":$bodyBuilder)> + ]; + let extraClassDeclaration = [{ /// Allow the dialect prefix to be omitted. static StringRef getDefaultDialect() { return "transform"; } diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index af759d9..76e0c89 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -765,6 +765,39 @@ void transform::SequenceOp::getRegionInvocationBounds( bounds.emplace_back(1, 1); } +void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, + TypeRange resultTypes, + FailurePropagationMode failurePropagationMode, + Value root, + SequenceBodyBuilderFn bodyBuilder) { + build(builder, state, resultTypes, failurePropagationMode, root); + Region *region = state.regions.back().get(); + auto bbArgType = root.getType(); + Block *bodyBlock = builder.createBlock( + region, region->begin(), TypeRange{bbArgType}, {state.location}); + + // Populate body. + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(bodyBlock); + bodyBuilder(builder, state.location, bodyBlock->getArgument(0)); +} + +void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, + TypeRange resultTypes, + FailurePropagationMode failurePropagationMode, + Type bbArgType, + SequenceBodyBuilderFn bodyBuilder) { + build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value()); + Region *region = state.regions.back().get(); + Block *bodyBlock = builder.createBlock( + region, region->begin(), TypeRange{bbArgType}, {state.location}); + + // Populate body. + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(bodyBlock); + bodyBuilder(builder, state.location, bodyBlock->getArgument(0)); +} + //===----------------------------------------------------------------------===// // WithPDLPatternsOp //===----------------------------------------------------------------------===// -- 2.7.4