[mlir:ODS] Deprecate Op parser/printer fields in favor of a new hasCustomAssemblyForm...
authorRiver Riddle <riddleriver@gmail.com>
Sat, 5 Feb 2022 04:47:01 +0000 (20:47 -0800)
committerRiver Riddle <riddleriver@gmail.com>
Tue, 8 Feb 2022 03:03:57 +0000 (19:03 -0800)
Currently if an operation wants a C++ implemented parser/printer, it specifies inline
code blocks. This is quite problematic for various reasons, e.g. it requires defining
C++ inside of Tablegen which is discouraged when possible, but mainly because
nearly all usages simply forward to static functions (e.g. `static void parseSomeOp(...)`)
with users devising their own standards for how these are defined.

This commit adds support for a `hasCustomAssemblyFormat` bit field that specifies if
a C++ parser/printer is needed, and when set to 1 declares the parse/print methods for
operations to override. For migration purposes, the existing behavior is untouched. Upstream
usages will be replaced in a followup to keep this patch focused on the new implementation.

Differential Revision: https://reviews.llvm.org/D119054

mlir/include/mlir/IR/OpBase.td
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/op-decl-and-defs.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

index e1af6d9..80a0949 100644 (file)
@@ -2442,14 +2442,24 @@ class Op<Dialect dialect, string mnemonic, list<Trait> props = []> {
   // provided.
   bit skipDefaultBuilders = 0;
 
-  // Custom parser.
+  // Custom parser and printer.
+  // NOTE: These fields are deprecated in favor of `assemblyFormat` or
+  // `hasCustomAssemblyFormat`, and are slated for deletion.
   code parser = ?;
-
-  // Custom printer.
   code printer = ?;
 
   // Custom assembly format.
+  /// This field corresponds to a declarative description of the assembly format
+  /// for this operation. If populated, the `hasCustomAssemblyFormat` field is
+  /// ignored.
   string assemblyFormat = ?;
+  /// This field indicates that the operation has a custom assembly format
+  /// implemented in C++. When set to `1` a `parse` and `print` method are generated
+  /// on the operation class. The operation should implement these methods to
+  /// support the custom format of the operation. The methods have the form:
+  ///   * ParseResult parse(OpAsmParser &parser, OperationState &result)
+  ///   * void print(OpAsmPrinter &p)
+  bit hasCustomAssemblyFormat = 0;
 
   // A bit indicating if the operation has additional invariants that need to
   // verified (aside from those verified by other ODS constructs). If set to `1`,
index e23173d..623d512 100644 (file)
@@ -577,8 +577,8 @@ static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
 // Test IsolatedRegionOp - parse passthrough region arguments.
 //===----------------------------------------------------------------------===//
 
-static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
-                                         OperationState &result) {
+ParseResult IsolatedRegionOp::parse(OpAsmParser &parser,
+                                    OperationState &result) {
   OpAsmParser::OperandType argInfo;
   Type argType = parser.getBuilder().getIndexType();
 
@@ -593,12 +593,12 @@ static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
                             /*enableNameShadowing=*/true);
 }
 
-static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
+void IsolatedRegionOp::print(OpAsmPrinter &p) {
   p << "test.isolated_region ";
-  p.printOperand(op.getOperand());
-  p.shadowRegionArgs(op.getRegion(), op.getOperand());
+  p.printOperand(getOperand());
+  p.shadowRegionArgs(getRegion(), getOperand());
   p << ' ';
-  p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
+  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
 }
 
 //===----------------------------------------------------------------------===//
@@ -613,16 +613,15 @@ RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
 // Test GraphRegionOp
 //===----------------------------------------------------------------------===//
 
-static ParseResult parseGraphRegionOp(OpAsmParser &parser,
-                                      OperationState &result) {
+ParseResult GraphRegionOp::parse(OpAsmParser &parser, OperationState &result) {
   // Parse the body region, and reuse the operand info as the argument info.
   Region *body = result.addRegion();
   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
 }
 
-static void print(OpAsmPrinter &p, GraphRegionOp op) {
+void GraphRegionOp::print(OpAsmPrinter &p) {
   p << "test.graph_region ";
-  p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
+  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
 }
 
 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
@@ -633,24 +632,23 @@ RegionKind GraphRegionOp::getRegionKind(unsigned index) {
 // Test AffineScopeOp
 //===----------------------------------------------------------------------===//
 
-static ParseResult parseAffineScopeOp(OpAsmParser &parser,
-                                      OperationState &result) {
+ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) {
   // Parse the body region, and reuse the operand info as the argument info.
   Region *body = result.addRegion();
   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
 }
 
-static void print(OpAsmPrinter &p, AffineScopeOp op) {
+void AffineScopeOp::print(OpAsmPrinter &p) {
   p << "test.affine_scope ";
-  p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
+  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
 }
 
 //===----------------------------------------------------------------------===//
 // Test parser.
 //===----------------------------------------------------------------------===//
 
-static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
-                                              OperationState &result) {
+ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser,
+                                         OperationState &result) {
   if (parser.parseOptionalColon())
     return success();
   uint64_t numResults;
@@ -663,13 +661,13 @@ static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
   return success();
 }
 
-static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) {
-  if (unsigned numResults = op->getNumResults())
+void ParseIntegerLiteralOp::print(OpAsmPrinter &p) {
+  if (unsigned numResults = getNumResults())
     p << " : " << numResults;
 }
 
-static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
-                                              OperationState &result) {
+ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser,
+                                         OperationState &result) {
   StringRef keyword;
   if (parser.parseKeyword(&keyword))
     return failure();
@@ -677,15 +675,13 @@ static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
   return success();
 }
 
-static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
-  p << " " << op.getKeyword();
-}
+void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); }
 
 //===----------------------------------------------------------------------===//
 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`.
 
-static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
-                                         OperationState &result) {
+ParseResult WrappingRegionOp::parse(OpAsmParser &parser,
+                                    OperationState &result) {
   if (parser.parseKeyword("wraps"))
     return failure();
 
@@ -715,9 +711,9 @@ static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
   return success();
 }
 
-static void print(OpAsmPrinter &p, WrappingRegionOp op) {
+void WrappingRegionOp::print(OpAsmPrinter &p) {
   p << " wraps ";
-  p.printGenericOp(&op.getRegion().front().front());
+  p.printGenericOp(&getRegion().front().front());
 }
 
 //===----------------------------------------------------------------------===//
@@ -726,8 +722,8 @@ static void print(OpAsmPrinter &p, WrappingRegionOp op) {
 //   parseCustomOperationName
 //===----------------------------------------------------------------------===//
 
-static ParseResult parsePrettyPrintedRegionOp(OpAsmParser &parser,
-                                              OperationState &result) {
+ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser,
+                                         OperationState &result) {
 
   SMLoc loc = parser.getCurrentLocation();
   Location currLocation = parser.getEncodedSourceLoc(loc);
@@ -799,11 +795,11 @@ static ParseResult parsePrettyPrintedRegionOp(OpAsmParser &parser,
   return success();
 }
 
-static void print(OpAsmPrinter &p, PrettyPrintedRegionOp op) {
+void PrettyPrintedRegionOp::print(OpAsmPrinter &p) {
   p << ' ';
-  p.printOperands(op.getOperands());
+  p.printOperands(getOperands());
 
-  Operation &innerOp = op.getRegion().front().front();
+  Operation &innerOp = getRegion().front().front();
   // Assuming that region has a single non-terminator inner-op, if the inner-op
   // meets some criteria (which in this case is a simple one  based on the name
   // of inner-op), then we can print the entire region in a succinct way.
@@ -813,19 +809,19 @@ static void print(OpAsmPrinter &p, PrettyPrintedRegionOp op) {
     p << " start special.op end";
   } else {
     p << " (";
-    p.printRegion(op.getRegion());
+    p.printRegion(getRegion());
     p << ")";
   }
 
   p << " : ";
-  p.printFunctionalType(op);
+  p.printFunctionalType(*this);
 }
 
 //===----------------------------------------------------------------------===//
 // Test PolyForOp - parse list of region arguments.
 //===----------------------------------------------------------------------===//
 
-static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
+ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
   // Parse list of region arguments without a delimiter.
   if (parser.parseRegionArgumentList(ivsInfo))
@@ -838,6 +834,8 @@ static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
   return parser.parseRegion(*body, ivsInfo, argTypes);
 }
 
+void PolyForOp::print(OpAsmPrinter &p) { p.printGenericOp(*this); }
+
 void PolyForOp::getAsmBlockArgumentNames(Region &region,
                                          OpAsmSetValueNameFn setNameFn) {
   auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names");
@@ -1044,8 +1042,8 @@ void SideEffectOp::getEffects(
 //===----------------------------------------------------------------------===//
 
 // This op has fancy handling of its SSA result name.
-static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
-                                               OperationState &result) {
+ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser,
+                                          OperationState &result) {
   // Add the result types.
   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
     result.addTypes(parser.getBuilder().getIntegerType(32));
@@ -1081,19 +1079,19 @@ static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
   return success();
 }
 
-static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
+void StringAttrPrettyNameOp::print(OpAsmPrinter &p) {
   // Note that we only need to print the "name" attribute if the asmprinter
   // result name disagrees with it.  This can happen in strange cases, e.g.
   // when there are conflicts.
-  bool namesDisagree = op.getNames().size() != op.getNumResults();
+  bool namesDisagree = getNames().size() != getNumResults();
 
   SmallString<32> resultNameStr;
-  for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
+  for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) {
     resultNameStr.clear();
     llvm::raw_svector_ostream tmpStream(resultNameStr);
-    p.printOperand(op.getResult(i), tmpStream);
+    p.printOperand(getResult(i), tmpStream);
 
-    auto expectedName = op.getNames()[i].dyn_cast<StringAttr>();
+    auto expectedName = getNames()[i].dyn_cast<StringAttr>();
     if (!expectedName ||
         tmpStream.str().drop_front() != expectedName.getValue()) {
       namesDisagree = true;
@@ -1101,9 +1099,9 @@ static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
   }
 
   if (namesDisagree)
-    p.printOptionalAttrDictWithKeyword(op->getAttrs());
+    p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
   else
-    p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"});
+    p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"});
 }
 
 // We set the SSA name in the asm syntax to the contents of the name
@@ -1142,27 +1140,26 @@ LogicalResult AttrWithTraitOp::verify() {
 // RegionIfOp
 //===----------------------------------------------------------------------===//
 
-static void print(OpAsmPrinter &p, RegionIfOp op) {
+void RegionIfOp::print(OpAsmPrinter &p) {
   p << " ";
-  p.printOperands(op.getOperands());
-  p << ": " << op.getOperandTypes();
-  p.printArrowTypeList(op.getResultTypes());
+  p.printOperands(getOperands());
+  p << ": " << getOperandTypes();
+  p.printArrowTypeList(getResultTypes());
   p << " then ";
-  p.printRegion(op.getThenRegion(),
+  p.printRegion(getThenRegion(),
                 /*printEntryBlockArgs=*/true,
                 /*printBlockTerminators=*/true);
   p << " else ";
-  p.printRegion(op.getElseRegion(),
+  p.printRegion(getElseRegion(),
                 /*printEntryBlockArgs=*/true,
                 /*printBlockTerminators=*/true);
   p << " join ";
-  p.printRegion(op.getJoinRegion(),
+  p.printRegion(getJoinRegion(),
                 /*printEntryBlockArgs=*/true,
                 /*printBlockTerminators=*/true);
 }
 
-static ParseResult parseRegionIfOp(OpAsmParser &parser,
-                                   OperationState &result) {
+ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
   SmallVector<Type, 2> operandTypes;
 
@@ -1241,17 +1238,17 @@ void AnyCondOp::getRegionInvocationBounds(
 // SingleNoTerminatorCustomAsmOp
 //===----------------------------------------------------------------------===//
 
-static ParseResult parseSingleNoTerminatorCustomAsmOp(OpAsmParser &parser,
-                                                      OperationState &state) {
+ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser,
+                                                 OperationState &state) {
   Region *body = state.addRegion();
   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
     return failure();
   return success();
 }
 
-static void print(SingleNoTerminatorCustomAsmOp op, OpAsmPrinter &printer) {
+void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) {
   printer.printRegion(
-      op.getRegion(), /*printEntryBlockArgs=*/false,
+      getRegion(), /*printEntryBlockArgs=*/false,
       // This op has a single block without terminators. But explicitly mark
       // as not printing block terminators for testing.
       /*printBlockTerminators=*/false);
index efaf2e1..3a14be7 100644 (file)
@@ -360,8 +360,7 @@ def SingleNoTerminatorOp : TEST_Op<"single_no_terminator_op",
 def SingleNoTerminatorCustomAsmOp : TEST_Op<"single_no_terminator_custom_asm_op",
                                             [SingleBlock, NoTerminator]> {
   let regions = (region SizedRegion<1>);
-  let parser = [{ return ::parseSingleNoTerminatorCustomAsmOp(parser, result); }];
-  let printer = [{ return ::print(*this, p); }];
+  let hasCustomAssemblyFormat = 1;
 }
 
 def VariadicNoTerminatorOp : TEST_Op<"variadic_no_terminator_op",
@@ -644,9 +643,7 @@ def StringAttrPrettyNameOp
            [DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
   let arguments = (ins StrArrayAttr:$names);
   let results = (outs Variadic<I32>:$r);
-
-  let printer = [{ return ::print(p, *this); }];
-  let parser = [{ return ::parse$cppClass(parser, result); }];
+  let hasCustomAssemblyFormat = 1;
 }
 
 // This is used to test the OpAsmOpInterface::getDefaultDialect() feature:
@@ -1580,14 +1577,12 @@ def TestSignatureConversionNoConverterOp
 
 def ParseIntegerLiteralOp : TEST_Op<"parse_integer_literal"> {
   let results = (outs Variadic<Index>:$results);
-  let parser = [{ return ::parse$cppClass(parser, result); }];
-  let printer = [{ return ::print(p, *this); }];
+  let hasCustomAssemblyFormat = 1;
 }
 
 def ParseWrappedKeywordOp : TEST_Op<"parse_wrapped_keyword"> {
   let arguments = (ins StrAttr:$keyword);
-  let parser = [{ return ::parse$cppClass(parser, result); }];
-  let printer = [{ return ::print(p, *this); }];
+  let hasCustomAssemblyFormat = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1602,8 +1597,7 @@ def IsolatedRegionOp : TEST_Op<"isolated_region", [IsolatedFromAbove]> {
 
   let arguments = (ins Index);
   let regions = (region SizedRegion<1>:$region);
-  let parser = [{ return ::parse$cppClass(parser, result); }];
-  let printer = [{ return ::print(p, *this); }];
+  let hasCustomAssemblyFormat = 1;
 }
 
 def SSACFGRegionOp : TEST_Op<"ssacfg_region",  [
@@ -1626,8 +1620,7 @@ def GraphRegionOp : TEST_Op<"graph_region",  [
   }];
 
   let regions = (region AnyRegion:$region);
-  let parser = [{ return ::parse$cppClass(parser, result); }];
-  let printer = [{ return ::print(p, *this); }];
+  let hasCustomAssemblyFormat = 1;
 }
 
 def AffineScopeOp : TEST_Op<"affine_scope", [AffineScope]> {
@@ -1637,8 +1630,7 @@ def AffineScopeOp : TEST_Op<"affine_scope", [AffineScope]> {
   }];
 
   let regions = (region SizedRegion<1>:$region);
-  let parser = [{ return ::parse$cppClass(parser, result); }];
-  let printer = [{ return ::print(p, *this); }];
+  let hasCustomAssemblyFormat = 1;
 }
 
 def WrappingRegionOp : TEST_Op<"wrapping_region",
@@ -1651,8 +1643,7 @@ def WrappingRegionOp : TEST_Op<"wrapping_region",
 
   let results = (outs Variadic<AnyType>);
   let regions = (region SizedRegion<1>:$region);
-  let parser = [{ return ::parse$cppClass(parser, result); }];
-  let printer = [{ return ::print(p, *this); }];
+  let hasCustomAssemblyFormat = 1;
 }
 
 def PrettyPrintedRegionOp : TEST_Op<"pretty_printed_region",
@@ -1670,12 +1661,10 @@ def PrettyPrintedRegionOp : TEST_Op<"pretty_printed_region",
 
   let results = (outs AnyType);
   let regions = (region SizedRegion<1>:$region);
-  let parser = [{ return ::parse$cppClass(parser, result); }];
-  let printer = [{ return ::print(p, *this); }];
+  let hasCustomAssemblyFormat = 1;
 }
 
-def PolyForOp : TEST_Op<"polyfor", [OpAsmOpInterface]>
-{
+def PolyForOp : TEST_Op<"polyfor", [OpAsmOpInterface]> {
   let summary =  "polyfor operation";
   let description = [{
     Test op with multiple region arguments, each argument of index type.
@@ -1685,7 +1674,7 @@ def PolyForOp : TEST_Op<"polyfor", [OpAsmOpInterface]>
                                   mlir::OpAsmSetValueNameFn setNameFn);
   }];
   let regions = (region SizedRegion<1>:$region);
-  let parser = [{ return ::parse$cppClass(parser, result); }];
+  let hasCustomAssemblyFormat = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -2356,8 +2345,6 @@ def RegionIfOp : TEST_Op<"region_if",
     parent op.
   }];
 
-  let printer = [{ return ::print(p, *this); }];
-  let parser = [{ return ::parseRegionIfOp(parser, result); }];
   let arguments = (ins Variadic<AnyType>);
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$thenRegion,
@@ -2375,6 +2362,7 @@ def RegionIfOp : TEST_Op<"region_if",
     }
     ::mlir::OperandRange getSuccessorEntryOperands(unsigned index);
   }];
+  let hasCustomAssemblyFormat = 1;
 }
 
 def AnyCondOp : TEST_Op<"any_cond",
index fbe8438..2c97422 100644 (file)
@@ -38,8 +38,7 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
   );
   let builders = [OpBuilder<(ins "Value":$val)>,
                   OpBuilder<(ins CArg<"int", "0">:$integer)>];
-  let parser = [{ foo }];
-  let printer = [{ bar }];
+  let hasCustomAssemblyFormat = 1;
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
index c11f848..7060fdf 100644 (file)
@@ -2125,13 +2125,29 @@ void OpEmitter::genTypeInterfaceMethods() {
 }
 
 void OpEmitter::genParser() {
-  if (!hasStringAttribute(def, "parser") ||
-      hasStringAttribute(def, "assemblyFormat"))
+  if (hasStringAttribute(def, "assemblyFormat"))
+    return;
+
+  bool hasCppFormat = def.getValueAsBit("hasCustomAssemblyFormat");
+  if (!hasStringAttribute(def, "parser") && !hasCppFormat)
     return;
 
   SmallVector<MethodParameter> paramList;
   paramList.emplace_back("::mlir::OpAsmParser &", "parser");
   paramList.emplace_back("::mlir::OperationState &", "result");
+
+  // If this uses the cpp format, only generate a declaration.
+  if (hasCppFormat) {
+    auto *method = opClass.declareStaticMethod("::mlir::ParseResult", "parse",
+                                               std::move(paramList));
+    ERROR_IF_PRUNED(method, "parse", op);
+    return;
+  }
+
+  PrintNote(op.getLoc(),
+            "`parser` and `printer` fields are deprecated and will be removed, "
+            "please use the `hasCustomAssemblyFormat` field instead");
+
   auto *method = opClass.addStaticMethod("::mlir::ParseResult", "parse",
                                          std::move(paramList));
   ERROR_IF_PRUNED(method, "parse", op);
@@ -2146,6 +2162,14 @@ void OpEmitter::genPrinter() {
   if (hasStringAttribute(def, "assemblyFormat"))
     return;
 
+  // If this uses the cpp format, only generate a declaration.
+  if (def.getValueAsBit("hasCustomAssemblyFormat")) {
+    auto *method = opClass.declareMethod(
+        "void", "print", MethodParameter("::mlir::OpAsmPrinter &", "p"));
+    ERROR_IF_PRUNED(method, "print", op);
+    return;
+  }
+
   auto *valueInit = def.getValueInit("printer");
   StringInit *stringInit = dyn_cast<StringInit>(valueInit);
   if (!stringInit)