From d7f0083dcae45e6bf774af23533a2d5e18aaf253 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Fri, 4 Feb 2022 20:47:01 -0800 Subject: [PATCH] [mlir:ODS] Deprecate Op parser/printer fields in favor of a new hasCustomAssemblyFormat field Currently if an operation wants a C++ implemented parser/printer, it specifies inline code blocks. This is quite problematic for various reasons, e.g. it requires defining C++ inside of Tablegen which is discouraged when possible, but mainly because nearly all usages simply forward to static functions (e.g. `static void parseSomeOp(...)`) with users devising their own standards for how these are defined. This commit adds support for a `hasCustomAssemblyFormat` bit field that specifies if a C++ parser/printer is needed, and when set to 1 declares the parse/print methods for operations to override. For migration purposes, the existing behavior is untouched. Upstream usages will be replaced in a followup to keep this patch focused on the new implementation. Differential Revision: https://reviews.llvm.org/D119054 --- mlir/include/mlir/IR/OpBase.td | 16 +++- mlir/test/lib/Dialect/Test/TestDialect.cpp | 111 ++++++++++++++-------------- mlir/test/lib/Dialect/Test/TestOps.td | 36 +++------ mlir/test/mlir-tblgen/op-decl-and-defs.td | 3 +- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 28 ++++++- 5 files changed, 106 insertions(+), 88 deletions(-) diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index e1af6d9..80a0949 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2442,14 +2442,24 @@ class Op props = []> { // provided. bit skipDefaultBuilders = 0; - // Custom parser. + // Custom parser and printer. + // NOTE: These fields are deprecated in favor of `assemblyFormat` or + // `hasCustomAssemblyFormat`, and are slated for deletion. code parser = ?; - - // Custom printer. code printer = ?; // Custom assembly format. + /// This field corresponds to a declarative description of the assembly format + /// for this operation. If populated, the `hasCustomAssemblyFormat` field is + /// ignored. string assemblyFormat = ?; + /// This field indicates that the operation has a custom assembly format + /// implemented in C++. When set to `1` a `parse` and `print` method are generated + /// on the operation class. The operation should implement these methods to + /// support the custom format of the operation. The methods have the form: + /// * ParseResult parse(OpAsmParser &parser, OperationState &result) + /// * void print(OpAsmPrinter &p) + bit hasCustomAssemblyFormat = 0; // A bit indicating if the operation has additional invariants that need to // verified (aside from those verified by other ODS constructs). If set to `1`, diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index e23173d..623d512 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -577,8 +577,8 @@ static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, // Test IsolatedRegionOp - parse passthrough region arguments. //===----------------------------------------------------------------------===// -static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, - OperationState &result) { +ParseResult IsolatedRegionOp::parse(OpAsmParser &parser, + OperationState &result) { OpAsmParser::OperandType argInfo; Type argType = parser.getBuilder().getIndexType(); @@ -593,12 +593,12 @@ static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, /*enableNameShadowing=*/true); } -static void print(OpAsmPrinter &p, IsolatedRegionOp op) { +void IsolatedRegionOp::print(OpAsmPrinter &p) { p << "test.isolated_region "; - p.printOperand(op.getOperand()); - p.shadowRegionArgs(op.getRegion(), op.getOperand()); + p.printOperand(getOperand()); + p.shadowRegionArgs(getRegion(), getOperand()); p << ' '; - p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false); + p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); } //===----------------------------------------------------------------------===// @@ -613,16 +613,15 @@ RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { // Test GraphRegionOp //===----------------------------------------------------------------------===// -static ParseResult parseGraphRegionOp(OpAsmParser &parser, - OperationState &result) { +ParseResult GraphRegionOp::parse(OpAsmParser &parser, OperationState &result) { // Parse the body region, and reuse the operand info as the argument info. Region *body = result.addRegion(); return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); } -static void print(OpAsmPrinter &p, GraphRegionOp op) { +void GraphRegionOp::print(OpAsmPrinter &p) { p << "test.graph_region "; - p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false); + p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); } RegionKind GraphRegionOp::getRegionKind(unsigned index) { @@ -633,24 +632,23 @@ RegionKind GraphRegionOp::getRegionKind(unsigned index) { // Test AffineScopeOp //===----------------------------------------------------------------------===// -static ParseResult parseAffineScopeOp(OpAsmParser &parser, - OperationState &result) { +ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) { // Parse the body region, and reuse the operand info as the argument info. Region *body = result.addRegion(); return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); } -static void print(OpAsmPrinter &p, AffineScopeOp op) { +void AffineScopeOp::print(OpAsmPrinter &p) { p << "test.affine_scope "; - p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false); + p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); } //===----------------------------------------------------------------------===// // Test parser. //===----------------------------------------------------------------------===// -static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser, - OperationState &result) { +ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser, + OperationState &result) { if (parser.parseOptionalColon()) return success(); uint64_t numResults; @@ -663,13 +661,13 @@ static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser, return success(); } -static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) { - if (unsigned numResults = op->getNumResults()) +void ParseIntegerLiteralOp::print(OpAsmPrinter &p) { + if (unsigned numResults = getNumResults()) p << " : " << numResults; } -static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser, - OperationState &result) { +ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser, + OperationState &result) { StringRef keyword; if (parser.parseKeyword(&keyword)) return failure(); @@ -677,15 +675,13 @@ static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser, return success(); } -static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) { - p << " " << op.getKeyword(); -} +void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); } //===----------------------------------------------------------------------===// // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. -static ParseResult parseWrappingRegionOp(OpAsmParser &parser, - OperationState &result) { +ParseResult WrappingRegionOp::parse(OpAsmParser &parser, + OperationState &result) { if (parser.parseKeyword("wraps")) return failure(); @@ -715,9 +711,9 @@ static ParseResult parseWrappingRegionOp(OpAsmParser &parser, return success(); } -static void print(OpAsmPrinter &p, WrappingRegionOp op) { +void WrappingRegionOp::print(OpAsmPrinter &p) { p << " wraps "; - p.printGenericOp(&op.getRegion().front().front()); + p.printGenericOp(&getRegion().front().front()); } //===----------------------------------------------------------------------===// @@ -726,8 +722,8 @@ static void print(OpAsmPrinter &p, WrappingRegionOp op) { // parseCustomOperationName //===----------------------------------------------------------------------===// -static ParseResult parsePrettyPrintedRegionOp(OpAsmParser &parser, - OperationState &result) { +ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser, + OperationState &result) { SMLoc loc = parser.getCurrentLocation(); Location currLocation = parser.getEncodedSourceLoc(loc); @@ -799,11 +795,11 @@ static ParseResult parsePrettyPrintedRegionOp(OpAsmParser &parser, return success(); } -static void print(OpAsmPrinter &p, PrettyPrintedRegionOp op) { +void PrettyPrintedRegionOp::print(OpAsmPrinter &p) { p << ' '; - p.printOperands(op.getOperands()); + p.printOperands(getOperands()); - Operation &innerOp = op.getRegion().front().front(); + Operation &innerOp = getRegion().front().front(); // Assuming that region has a single non-terminator inner-op, if the inner-op // meets some criteria (which in this case is a simple one based on the name // of inner-op), then we can print the entire region in a succinct way. @@ -813,19 +809,19 @@ static void print(OpAsmPrinter &p, PrettyPrintedRegionOp op) { p << " start special.op end"; } else { p << " ("; - p.printRegion(op.getRegion()); + p.printRegion(getRegion()); p << ")"; } p << " : "; - p.printFunctionalType(op); + p.printFunctionalType(*this); } //===----------------------------------------------------------------------===// // Test PolyForOp - parse list of region arguments. //===----------------------------------------------------------------------===// -static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { +ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector ivsInfo; // Parse list of region arguments without a delimiter. if (parser.parseRegionArgumentList(ivsInfo)) @@ -838,6 +834,8 @@ static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { return parser.parseRegion(*body, ivsInfo, argTypes); } +void PolyForOp::print(OpAsmPrinter &p) { p.printGenericOp(*this); } + void PolyForOp::getAsmBlockArgumentNames(Region ®ion, OpAsmSetValueNameFn setNameFn) { auto arrayAttr = getOperation()->getAttrOfType("arg_names"); @@ -1044,8 +1042,8 @@ void SideEffectOp::getEffects( //===----------------------------------------------------------------------===// // This op has fancy handling of its SSA result name. -static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, - OperationState &result) { +ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser, + OperationState &result) { // Add the result types. for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) result.addTypes(parser.getBuilder().getIntegerType(32)); @@ -1081,19 +1079,19 @@ static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, return success(); } -static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { +void StringAttrPrettyNameOp::print(OpAsmPrinter &p) { // Note that we only need to print the "name" attribute if the asmprinter // result name disagrees with it. This can happen in strange cases, e.g. // when there are conflicts. - bool namesDisagree = op.getNames().size() != op.getNumResults(); + bool namesDisagree = getNames().size() != getNumResults(); SmallString<32> resultNameStr; - for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { + for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) { resultNameStr.clear(); llvm::raw_svector_ostream tmpStream(resultNameStr); - p.printOperand(op.getResult(i), tmpStream); + p.printOperand(getResult(i), tmpStream); - auto expectedName = op.getNames()[i].dyn_cast(); + auto expectedName = getNames()[i].dyn_cast(); if (!expectedName || tmpStream.str().drop_front() != expectedName.getValue()) { namesDisagree = true; @@ -1101,9 +1099,9 @@ static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { } if (namesDisagree) - p.printOptionalAttrDictWithKeyword(op->getAttrs()); + p.printOptionalAttrDictWithKeyword((*this)->getAttrs()); else - p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"}); + p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"}); } // We set the SSA name in the asm syntax to the contents of the name @@ -1142,27 +1140,26 @@ LogicalResult AttrWithTraitOp::verify() { // RegionIfOp //===----------------------------------------------------------------------===// -static void print(OpAsmPrinter &p, RegionIfOp op) { +void RegionIfOp::print(OpAsmPrinter &p) { p << " "; - p.printOperands(op.getOperands()); - p << ": " << op.getOperandTypes(); - p.printArrowTypeList(op.getResultTypes()); + p.printOperands(getOperands()); + p << ": " << getOperandTypes(); + p.printArrowTypeList(getResultTypes()); p << " then "; - p.printRegion(op.getThenRegion(), + p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/true, /*printBlockTerminators=*/true); p << " else "; - p.printRegion(op.getElseRegion(), + p.printRegion(getElseRegion(), /*printEntryBlockArgs=*/true, /*printBlockTerminators=*/true); p << " join "; - p.printRegion(op.getJoinRegion(), + p.printRegion(getJoinRegion(), /*printEntryBlockArgs=*/true, /*printBlockTerminators=*/true); } -static ParseResult parseRegionIfOp(OpAsmParser &parser, - OperationState &result) { +ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector operandInfos; SmallVector operandTypes; @@ -1241,17 +1238,17 @@ void AnyCondOp::getRegionInvocationBounds( // SingleNoTerminatorCustomAsmOp //===----------------------------------------------------------------------===// -static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser, - OperationState &state) { +ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser, + OperationState &state) { Region *body = state.addRegion(); if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) return failure(); return success(); } -static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) { +void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) { printer.printRegion( - op.getRegion(), /*printEntryBlockArgs=*/false, + getRegion(), /*printEntryBlockArgs=*/false, // This op has a single block without terminators. But explicitly mark // as not printing block terminators for testing. /*printBlockTerminators=*/false); diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index efaf2e1..3a14be7 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -360,8 +360,7 @@ def SingleNoTerminatorOp : TEST_Op<"single_no_terminator_op", def SingleNoTerminatorCustomAsmOp : TEST_Op<"single_no_terminator_custom_asm_op", [SingleBlock, NoTerminator]> { let regions = (region SizedRegion<1>); - let parser = [{ return ::parseSingleNoTerminatorCustomAsmOp(parser, result); }]; - let printer = [{ return ::print(*this, p); }]; + let hasCustomAssemblyFormat = 1; } def VariadicNoTerminatorOp : TEST_Op<"variadic_no_terminator_op", @@ -644,9 +643,7 @@ def StringAttrPrettyNameOp [DeclareOpInterfaceMethods]> { let arguments = (ins StrArrayAttr:$names); let results = (outs Variadic:$r); - - let printer = [{ return ::print(p, *this); }]; - let parser = [{ return ::parse$cppClass(parser, result); }]; + let hasCustomAssemblyFormat = 1; } // This is used to test the OpAsmOpInterface::getDefaultDialect() feature: @@ -1580,14 +1577,12 @@ def TestSignatureConversionNoConverterOp def ParseIntegerLiteralOp : TEST_Op<"parse_integer_literal"> { let results = (outs Variadic:$results); - let parser = [{ return ::parse$cppClass(parser, result); }]; - let printer = [{ return ::print(p, *this); }]; + let hasCustomAssemblyFormat = 1; } def ParseWrappedKeywordOp : TEST_Op<"parse_wrapped_keyword"> { let arguments = (ins StrAttr:$keyword); - let parser = [{ return ::parse$cppClass(parser, result); }]; - let printer = [{ return ::print(p, *this); }]; + let hasCustomAssemblyFormat = 1; } //===----------------------------------------------------------------------===// @@ -1602,8 +1597,7 @@ def IsolatedRegionOp : TEST_Op<"isolated_region", [IsolatedFromAbove]> { let arguments = (ins Index); let regions = (region SizedRegion<1>:$region); - let parser = [{ return ::parse$cppClass(parser, result); }]; - let printer = [{ return ::print(p, *this); }]; + let hasCustomAssemblyFormat = 1; } def SSACFGRegionOp : TEST_Op<"ssacfg_region", [ @@ -1626,8 +1620,7 @@ def GraphRegionOp : TEST_Op<"graph_region", [ }]; let regions = (region AnyRegion:$region); - let parser = [{ return ::parse$cppClass(parser, result); }]; - let printer = [{ return ::print(p, *this); }]; + let hasCustomAssemblyFormat = 1; } def AffineScopeOp : TEST_Op<"affine_scope", [AffineScope]> { @@ -1637,8 +1630,7 @@ def AffineScopeOp : TEST_Op<"affine_scope", [AffineScope]> { }]; let regions = (region SizedRegion<1>:$region); - let parser = [{ return ::parse$cppClass(parser, result); }]; - let printer = [{ return ::print(p, *this); }]; + let hasCustomAssemblyFormat = 1; } def WrappingRegionOp : TEST_Op<"wrapping_region", @@ -1651,8 +1643,7 @@ def WrappingRegionOp : TEST_Op<"wrapping_region", let results = (outs Variadic); let regions = (region SizedRegion<1>:$region); - let parser = [{ return ::parse$cppClass(parser, result); }]; - let printer = [{ return ::print(p, *this); }]; + let hasCustomAssemblyFormat = 1; } def PrettyPrintedRegionOp : TEST_Op<"pretty_printed_region", @@ -1670,12 +1661,10 @@ def PrettyPrintedRegionOp : TEST_Op<"pretty_printed_region", let results = (outs AnyType); let regions = (region SizedRegion<1>:$region); - let parser = [{ return ::parse$cppClass(parser, result); }]; - let printer = [{ return ::print(p, *this); }]; + let hasCustomAssemblyFormat = 1; } -def PolyForOp : TEST_Op<"polyfor", [OpAsmOpInterface]> -{ +def PolyForOp : TEST_Op<"polyfor", [OpAsmOpInterface]> { let summary = "polyfor operation"; let description = [{ Test op with multiple region arguments, each argument of index type. @@ -1685,7 +1674,7 @@ def PolyForOp : TEST_Op<"polyfor", [OpAsmOpInterface]> mlir::OpAsmSetValueNameFn setNameFn); }]; let regions = (region SizedRegion<1>:$region); - let parser = [{ return ::parse$cppClass(parser, result); }]; + let hasCustomAssemblyFormat = 1; } //===----------------------------------------------------------------------===// @@ -2356,8 +2345,6 @@ def RegionIfOp : TEST_Op<"region_if", parent op. }]; - let printer = [{ return ::print(p, *this); }]; - let parser = [{ return ::parseRegionIfOp(parser, result); }]; let arguments = (ins Variadic); let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$thenRegion, @@ -2375,6 +2362,7 @@ def RegionIfOp : TEST_Op<"region_if", } ::mlir::OperandRange getSuccessorEntryOperands(unsigned index); }]; + let hasCustomAssemblyFormat = 1; } def AnyCondOp : TEST_Op<"any_cond", diff --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td index fbe8438..2c974228 100644 --- a/mlir/test/mlir-tblgen/op-decl-and-defs.td +++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td @@ -38,8 +38,7 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> { ); let builders = [OpBuilder<(ins "Value":$val)>, OpBuilder<(ins CArg<"int", "0">:$integer)>]; - let parser = [{ foo }]; - let printer = [{ bar }]; + let hasCustomAssemblyFormat = 1; let hasCanonicalizer = 1; let hasFolder = 1; diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index c11f848..7060fdf 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -2125,13 +2125,29 @@ void OpEmitter::genTypeInterfaceMethods() { } void OpEmitter::genParser() { - if (!hasStringAttribute(def, "parser") || - hasStringAttribute(def, "assemblyFormat")) + if (hasStringAttribute(def, "assemblyFormat")) + return; + + bool hasCppFormat = def.getValueAsBit("hasCustomAssemblyFormat"); + if (!hasStringAttribute(def, "parser") && !hasCppFormat) return; SmallVector paramList; paramList.emplace_back("::mlir::OpAsmParser &", "parser"); paramList.emplace_back("::mlir::OperationState &", "result"); + + // If this uses the cpp format, only generate a declaration. + if (hasCppFormat) { + auto *method = opClass.declareStaticMethod("::mlir::ParseResult", "parse", + std::move(paramList)); + ERROR_IF_PRUNED(method, "parse", op); + return; + } + + PrintNote(op.getLoc(), + "`parser` and `printer` fields are deprecated and will be removed, " + "please use the `hasCustomAssemblyFormat` field instead"); + auto *method = opClass.addStaticMethod("::mlir::ParseResult", "parse", std::move(paramList)); ERROR_IF_PRUNED(method, "parse", op); @@ -2146,6 +2162,14 @@ void OpEmitter::genPrinter() { if (hasStringAttribute(def, "assemblyFormat")) return; + // If this uses the cpp format, only generate a declaration. + if (def.getValueAsBit("hasCustomAssemblyFormat")) { + auto *method = opClass.declareMethod( + "void", "print", MethodParameter("::mlir::OpAsmPrinter &", "p")); + ERROR_IF_PRUNED(method, "print", op); + return; + } + auto *valueInit = def.getValueInit("printer"); StringInit *stringInit = dyn_cast(valueInit); if (!stringInit) -- 2.7.4