From 42b60d34fc3dfbad2b26568cf1cd903685df3a3e Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Sun, 8 Sep 2019 23:39:34 -0700 Subject: [PATCH] Add `parseGenericOperation()` to the OpAsmParser This method parses an operation in its generic form, from the current parser state. This is the symmetric of OpAsmPrinter::printGenericOp(). An immediate use case is illustrated in the test dialect, where an operation wraps another one in its region and makes use of a single-line pretty-print form. PiperOrigin-RevId: 267930869 --- mlir/include/mlir/IR/OpImplementation.h | 8 +++++ mlir/lib/Parser/Parser.cpp | 17 +++++++++++ mlir/test/IR/wrapping_op.mlir | 14 +++++++++ mlir/test/lib/TestDialect/TestDialect.cpp | 36 +++++++++++++++++++++++ mlir/test/lib/TestDialect/TestOps.td | 14 +++++++++ 5 files changed, 89 insertions(+) create mode 100644 mlir/test/IR/wrapping_op.mlir diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index c97272b66e08..1b89cfc178d4 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -213,6 +213,14 @@ public: // these to be chained together into a linear sequence of || expressions in // many cases. + /// Parse an operation in its generic form. + /// The parsed operation is parsed in the current context and inserted in the + /// provided block and insertion point. The results produced by this operation + /// aren't mapped to any named value in the parser. Returns nullptr on + /// failure. + virtual Operation *parseGenericOperation(Block *insertBlock, + Block::iterator insertPt) = 0; + //===--------------------------------------------------------------------===// // Token Parsing //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index a6ccd76e8065..a22b1d865503 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -2606,6 +2606,11 @@ public: /// Parse an operation instance that is in the generic form. Operation *parseGenericOperation(); + /// Parse an operation instance that is in the generic form and insert it at + /// the provided insertion point. + Operation *parseGenericOperation(Block *insertBlock, + Block::iterator insertPt); + /// Parse an operation instance that is in the op-defined custom form. Operation *parseCustomOperation(); @@ -3255,6 +3260,13 @@ Operation *OperationParser::parseGenericOperation() { return opBuilder.createOperation(result); } +Operation *OperationParser::parseGenericOperation(Block *insertBlock, + Block::iterator insertPt) { + OpBuilder::InsertionGuard restoreInsertionPoint(opBuilder); + opBuilder.setInsertionPoint(insertBlock, insertPt); + return parseGenericOperation(); +} + namespace { class CustomOpAsmParser : public OpAsmParser { public: @@ -3270,6 +3282,11 @@ public: return success(); } + Operation *parseGenericOperation(Block *insertBlock, + Block::iterator insertPt) final { + return parser.parseGenericOperation(insertBlock, insertPt); + } + //===--------------------------------------------------------------------===// // Utilities //===--------------------------------------------------------------------===// diff --git a/mlir/test/IR/wrapping_op.mlir b/mlir/test/IR/wrapping_op.mlir new file mode 100644 index 000000000000..f7b838c2e28a --- /dev/null +++ b/mlir/test/IR/wrapping_op.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt %s | FileCheck %s +// RUN: mlir-opt -mlir-print-op-generic %s | FileCheck %s --check-prefix=CHECK-GENERIC + +// CHECK-LABEL: func @wrapping_op +// CHECK-GENERIC-LABEL: func @wrapping_op +func @wrapping_op(%arg0 : i32, %arg1 : f32) -> (i3, i2, i1) { +// CHECK: %0:3 = test.wrapping_region wraps "some.op"(%arg1, %arg0) {test.attr = "attr"} : (f32, i32) -> (i1, i2, i3) +// CHECK-GENERIC: "test.wrapping_region"() ( { +// CHECK-GENERIC: %[[NESTED_RES:.*]]:3 = "some.op"(%arg1, %arg0) {test.attr = "attr"} : (f32, i32) -> (i1, i2, i3) +// CHECK-GENERIC: "test.return"(%[[NESTED_RES]]#0, %[[NESTED_RES]]#1, %[[NESTED_RES]]#2) : (i1, i2, i3) -> () +// CHECK-GENERIC: }) : () -> (i1, i2, i3) + %res:3 = test.wrapping_region wraps "some.op"(%arg1, %arg0) { test.attr = "attr" } : (f32, i32) -> (i1, i2, i3) + return %res#2, %res#1, %res#0 : i3, i2, i1 +} \ No newline at end of file diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp index ea550b262763..21240a3cd8db 100644 --- a/mlir/test/lib/TestDialect/TestDialect.cpp +++ b/mlir/test/lib/TestDialect/TestDialect.cpp @@ -122,6 +122,42 @@ static void print(OpAsmPrinter *p, IsolatedRegionOp op) { p->printRegion(op.region(), /*printEntryBlockArgs=*/false); } +//===----------------------------------------------------------------------===// +// Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. +//===----------------------------------------------------------------------===// + +static ParseResult parseWrappingRegionOp(OpAsmParser *parser, + OperationState *result) { + if (parser->parseOptionalKeyword("wraps")) + return failure(); + + // Parse the wrapped op in a region + Region &body = *result->addRegion(); + body.push_back(new Block); + Block &block = body.back(); + Operation *wrapped_op = parser->parseGenericOperation(&block, block.begin()); + if (!wrapped_op) + return failure(); + + // Create a return terminator in the inner region, pass as operand to the + // terminator the returned values from the wrapped operation. + SmallVector return_operands(wrapped_op->getResults()); + OpBuilder builder(parser->getBuilder().getContext()); + builder.setInsertionPointToEnd(&block); + builder.create(result->location, return_operands); + + // Get the results type for the wrapping op from the terminator operands. + Operation &return_op = body.back().back(); + result->types.append(return_op.operand_type_begin(), + return_op.operand_type_end()); + return success(); +} + +static void print(OpAsmPrinter *p, WrappingRegionOp op) { + *p << op.getOperationName() << " wraps "; + p->printGenericOp(&op.region().front().front()); +} + //===----------------------------------------------------------------------===// // Test PolyForOp - parse list of region arguments. //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td index da28f0cd4b4d..1fda7d41356f 100644 --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -757,6 +757,20 @@ def IsolatedRegionOp : TEST_Op<"isolated_region", [IsolatedFromAbove]> { let printer = [{ return ::print(p, *this); }]; } +def WrappingRegionOp : TEST_Op<"wrapping_region", + [SingleBlockImplicitTerminator<"TestReturnOp">]> { + let summary = "wrapping region operation"; + let description = [{ + Test op wrapping another op in a region, to test calling + parseGenericOperation from the custom parser. + }]; + + let results = (outs Variadic:$outputs); + let regions = (region SizedRegion<1>:$region); + let parser = [{ return ::parse$cppClass(parser, result); }]; + let printer = [{ return ::print(p, *this); }]; +} + def PolyForOp : TEST_Op<"polyfor"> { let summary = "polyfor operation"; -- 2.34.1