[mlir][ir] Custom ops' parse/print fall back to dialect hooks
authorMogball <jeffniu22@gmail.com>
Fri, 10 Dec 2021 00:47:48 +0000 (00:47 +0000)
committerMogball <jeffniu22@gmail.com>
Fri, 10 Dec 2021 19:34:25 +0000 (19:34 +0000)
Custom ops that have no parser or printer should fall back to the dialect's parser and/or printer hooks. This avoids the need to define parsers and printers that simply dispatch to the dialect hook.

Reviewed By: mehdi_amini, rriddle

Differential Revision: https://reviews.llvm.org/D115481

mlir/include/mlir/IR/OpDefinition.h
mlir/lib/IR/Operation.cpp
mlir/test/IR/parser.mlir
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td

index ad461e8..c68c6b1 100644 (file)
@@ -173,13 +173,19 @@ protected:
   /// back to this one which accepts everything.
   LogicalResult verify() { return success(); }
 
-  /// Unless overridden, the custom assembly form of an op is always rejected.
-  /// Op implementations should implement this to return failure.
-  /// On success, they should fill in result with the fields to use.
+  /// Parse the custom form of an operation. Unless overridden, this method will
+  /// first try to get an operation parser from the op's dialect. Otherwise the
+  /// custom assembly form of an op is always rejected. Op implementations
+  /// should implement this to return failure. On success, they should fill in
+  /// result with the fields to use.
   static ParseResult parse(OpAsmParser &parser, OperationState &result);
 
-  // The fallback for the printer is to print it the generic assembly form.
-  static void print(Operation *op, OpAsmPrinter &p);
+  /// Print the operation. Unless overridden, this method will first try to get
+  /// an operation printer from the dialect. Otherwise, it prints the operation
+  /// in generic form.
+  static void print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect);
+
+  /// Print an operation name, eliding the dialect prefix if necessary.
   static void printOpName(Operation *op, OpAsmPrinter &p,
                           StringRef defaultDialect);
 
@@ -1781,7 +1787,7 @@ private:
                           OperationName::PrintAssemblyFn>
   getPrintAssemblyFnImpl() {
     return [](Operation *op, OpAsmPrinter &printer, StringRef defaultDialect) {
-      return OpState::print(op, printer);
+      return OpState::print(op, printer, defaultDialect);
     };
   }
   /// The internal implementation of `getPrintAssemblyFn` that is invoked when
index 164aeff..dc224f2 100644 (file)
@@ -580,14 +580,27 @@ Operation *Operation::clone() {
 // OpState trait class.
 //===----------------------------------------------------------------------===//
 
-// The fallback for the parser is to reject the custom assembly form.
+// The fallback for the parser is to try for a dialect operation parser.
+// Otherwise, reject the custom assembly form.
 ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) {
+  if (auto parseFn = result.name.getDialect()->getParseOperationHook(
+          result.name.getStringRef()))
+    return (*parseFn)(parser, result);
   return parser.emitError(parser.getNameLoc(), "has no custom assembly form");
 }
 
-// The fallback for the printer is to print in the generic assembly form.
-void OpState::print(Operation *op, OpAsmPrinter &p) { p.printGenericOp(op); }
-// The fallback for the printer is to print in the generic assembly form.
+// The fallback for the printer is to try for a dialect operation printer.
+// Otherwise, it prints the generic form.
+void OpState::print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect) {
+  if (auto printFn = op->getDialect()->getOperationPrinter(op)) {
+    printOpName(op, p, defaultDialect);
+    printFn(op, p);
+  } else {
+    p.printGenericOp(op);
+  }
+}
+
+/// Print an operation name, eliding the dialect prefix if necessary.
 void OpState::printOpName(Operation *op, OpAsmPrinter &p,
                           StringRef defaultDialect) {
   StringRef name = op->getName().getStringRef();
index 8f2f870..30c273f 100644 (file)
@@ -1425,3 +1425,8 @@ test.graph_region {
 // This is an unregister operation, the printing/parsing is handled by the dialect.
 // CHECK: test.dialect_custom_printer custom_format
 test.dialect_custom_printer custom_format
+
+// This is a registered operation with no custom parser and printer, and should
+// be handled by the dialect.
+// CHECK: test.dialect_custom_format_fallback custom_format_fallback
+test.dialect_custom_format_fallback custom_format_fallback
index 73d4243..a6b317d 100644 (file)
@@ -318,6 +318,11 @@ TestDialect::getParseOperationHook(StringRef opName) const {
       return parser.parseKeyword("custom_format");
     }};
   }
+  if (opName == "test.dialect_custom_format_fallback") {
+    return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
+      return parser.parseKeyword("custom_format_fallback");
+    }};
+  }
   return None;
 }
 
@@ -329,6 +334,11 @@ TestDialect::getOperationPrinter(Operation *op) const {
       printer.getStream() << " custom_format";
     };
   }
+  if (opName == "test.dialect_custom_format_fallback") {
+    return [](Operation *op, OpAsmPrinter &printer) {
+      printer.getStream() << " custom_format_fallback";
+    };
+  }
   return {};
 }
 
index 4f6abed..120749e 100644 (file)
@@ -597,6 +597,10 @@ def AttrSizedResultOp : TEST_Op<"attr_sized_results",
   );
 }
 
+// This is used to test that the fallback for a custom op's parser and printer
+// is the dialect parser and printer hooks.
+def CustomFormatFallbackOp : TEST_Op<"dialect_custom_format_fallback">;
+
 // This is used to test encoding of a string attribute into an SSA name of a
 // pretty printed value name.
 def StringAttrPrettyNameOp