[mlir][NFC] Make 'printOp' public in AsmPrinter
authorDiego Caballero <diegocaballero@google.com>
Wed, 5 Oct 2022 18:31:22 +0000 (18:31 +0000)
committerDiego Caballero <diegocaballero@google.com>
Wed, 5 Oct 2022 19:00:53 +0000 (19:00 +0000)
This patch moves the 'printOp' functionality to the public API of
AsmPrinter and rename it to 'printCustomOrGenericOp'. No 'parseOp'
is needed at this time as existing APIs are able to parse operations
producing results where results are omitted in the textual form
(the LHS of an operation is redundant when it comes to building the
operation itself as it only contains the result names).

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D135006

mlir/include/mlir/IR/OpImplementation.h
mlir/lib/IR/AsmPrinter.cpp
mlir/test/IR/print-op-custom-or-generic.mlir [new file with mode: 0644]
mlir/test/IR/print-op-generic.mlir [deleted file]

index 78843ac..0251394 100644 (file)
@@ -380,6 +380,10 @@ public:
   printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,
                                    ArrayRef<StringRef> elidedAttrs = {}) = 0;
 
+  /// Prints the entire operation with the custom assembly form, if available,
+  /// or the generic assembly form, otherwise.
+  virtual void printCustomOrGenericOp(Operation *op) = 0;
+
   /// Print the entire operation with the default generic assembly form.
   /// If `printOpName` is true, then the operation name is printed (the default)
   /// otherwise it is omitted and the print will start with the operand list.
index f51ea60..53da51c 100644 (file)
@@ -421,8 +421,9 @@ public:
                                       AliasInitializer &initializer)
       : printerFlags(printerFlags), initializer(initializer) {}
 
-  /// Print the given operation.
-  void print(Operation *op) {
+  /// Prints the entire operation with the custom assembly form, if available,
+  /// or the generic assembly form, otherwise.
+  void printCustomOrGenericOp(Operation *op) override {
     // Visit the operation location.
     if (printerFlags.shouldPrintDebugInfo())
       initializer.visit(op->getLoc(), /*canBeDeferred=*/true);
@@ -489,7 +490,7 @@ private:
         std::prev(block->end(),
                   (!hasTerminator || printBlockTerminator) ? 0 : 1));
     for (Operation &op : range)
-      print(&op);
+      printCustomOrGenericOp(&op);
   }
 
   /// Print the given region.
@@ -680,7 +681,7 @@ void AliasInitializer::initialize(
   // attributes/types that will actually be used during printing when
   // considering aliases.
   DummyAliasOperationPrinter aliasPrinter(printerFlags, *this);
-  aliasPrinter.print(op);
+  aliasPrinter.printCustomOrGenericOp(op);
 
   // Initialize the aliases sorted by name.
   initializeAliases(aliasToAttr, attrToAlias, &deferrableAttributes);
@@ -2660,11 +2661,16 @@ public:
   /// Print the given top-level operation.
   void printTopLevelOperation(Operation *op);
 
-  /// Print the given operation with its indent and location.
-  void print(Operation *op);
-  /// Print the bare location, not including indentation/location/etc.
-  void printOperation(Operation *op);
-  /// Print the given operation in the generic form.
+  /// Print the given operation, including its left-hand side and its right-hand
+  /// side, with its indent and location.
+  void printFullOpWithIndentAndLoc(Operation *op);
+  /// Print the given operation, including its left-hand side and its right-hand
+  /// side, but not including indentation and location.
+  void printFullOp(Operation *op);
+  /// Print the right-hand size of the given operation in the custom or generic
+  /// form.
+  void printCustomOrGenericOp(Operation *op) override;
+  /// Print the right-hand side of the given operation in the generic form.
   void printGenericOp(Operation *op, bool printOpName) override;
 
   /// Print the name of the given block.
@@ -2838,7 +2844,7 @@ void OperationPrinter::printTopLevelOperation(Operation *op) {
   state.getAliasState().printNonDeferredAliases(os, newLine);
 
   // Print the module.
-  print(op);
+  printFullOpWithIndentAndLoc(op);
   os << newLine;
 
   // Output the aliases at the top level that can be deferred.
@@ -2934,18 +2940,18 @@ void OperationPrinter::printRegionArgument(BlockArgument arg,
   printTrailingLocation(arg.getLoc(), /*allowAlias*/ false);
 }
 
-void OperationPrinter::print(Operation *op) {
+void OperationPrinter::printFullOpWithIndentAndLoc(Operation *op) {
   // Track the location of this operation.
   state.registerOperationLocation(op, newLine.curLine, currentIndent);
 
   os.indent(currentIndent);
-  printOperation(op);
+  printFullOp(op);
   printTrailingLocation(op->getLoc());
   if (printerFlags.shouldPrintValueUsers())
     printUsersComment(op);
 }
 
-void OperationPrinter::printOperation(Operation *op) {
+void OperationPrinter::printFullOp(Operation *op) {
   if (size_t numResults = op->getNumResults()) {
     auto printResultGroup = [&](size_t resultNo, size_t resultCount) {
       printValueID(op->getResult(resultNo), /*printResultNo=*/false);
@@ -2972,34 +2978,7 @@ void OperationPrinter::printOperation(Operation *op) {
     os << " = ";
   }
 
-  // If requested, always print the generic form.
-  if (!printerFlags.shouldPrintGenericOpForm()) {
-    // Check to see if this is a known operation. If so, use the registered
-    // custom printer hook.
-    if (auto opInfo = op->getRegisteredInfo()) {
-      opInfo->printAssembly(op, *this, defaultDialectStack.back());
-      return;
-    }
-    // Otherwise try to dispatch to the dialect, if available.
-    if (Dialect *dialect = op->getDialect()) {
-      if (auto opPrinter = dialect->getOperationPrinter(op)) {
-        // Print the op name first.
-        StringRef name = op->getName().getStringRef();
-        // Only drop the default dialect prefix when it cannot lead to
-        // ambiguities.
-        if (name.count('.') == 1)
-          name.consume_front((defaultDialectStack.back() + ".").str());
-        os << name;
-
-        // Print the rest of the op now.
-        opPrinter(op, *this);
-        return;
-      }
-    }
-  }
-
-  // Otherwise print with the generic assembly form.
-  printGenericOp(op, /*printOpName=*/true);
+  printCustomOrGenericOp(op);
 }
 
 void OperationPrinter::printUsersComment(Operation *op) {
@@ -3076,6 +3055,37 @@ void OperationPrinter::printUserIDs(Operation *user, bool prefixComma) {
   }
 }
 
+void OperationPrinter::printCustomOrGenericOp(Operation *op) {
+  // If requested, always print the generic form.
+  if (!printerFlags.shouldPrintGenericOpForm()) {
+    // Check to see if this is a known operation. If so, use the registered
+    // custom printer hook.
+    if (auto opInfo = op->getRegisteredInfo()) {
+      opInfo->printAssembly(op, *this, defaultDialectStack.back());
+      return;
+    }
+    // Otherwise try to dispatch to the dialect, if available.
+    if (Dialect *dialect = op->getDialect()) {
+      if (auto opPrinter = dialect->getOperationPrinter(op)) {
+        // Print the op name first.
+        StringRef name = op->getName().getStringRef();
+        // Only drop the default dialect prefix when it cannot lead to
+        // ambiguities.
+        if (name.count('.') == 1)
+          name.consume_front((defaultDialectStack.back() + ".").str());
+        os << name;
+
+        // Print the rest of the op now.
+        opPrinter(op, *this);
+        return;
+      }
+    }
+  }
+
+  // Otherwise print with the generic assembly form.
+  printGenericOp(op, /*printOpName=*/true);
+}
+
 void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
   if (printOpName)
     printEscapedString(op->getName().getStringRef());
@@ -3176,7 +3186,7 @@ void OperationPrinter::print(Block *block, bool printBlockArgs,
       std::prev(block->end(),
                 (!hasTerminator || printBlockTerminator) ? 0 : 1));
   for (auto &op : range) {
-    print(&op);
+    printFullOpWithIndentAndLoc(&op);
     os << newLine;
   }
   currentIndent -= indentWidth;
@@ -3418,7 +3428,7 @@ void Operation::print(raw_ostream &os, AsmState &state) {
     state.getImpl().initializeAliases(this);
     printer.printTopLevelOperation(this);
   } else {
-    printer.print(this);
+    printer.printFullOpWithIndentAndLoc(this);
   }
 }
 
diff --git a/mlir/test/IR/print-op-custom-or-generic.mlir b/mlir/test/IR/print-op-custom-or-generic.mlir
new file mode 100644 (file)
index 0000000..a82089b
--- /dev/null
@@ -0,0 +1,28 @@
+// # RUN: mlir-opt %s -split-input-file | FileCheck %s
+// # RUN: mlir-opt %s -mlir-print-op-generic -split-input-file  | FileCheck %s --check-prefix=GENERIC
+
+// Check that `printCustomOrGenericOp` and `printGenericOp` print the right
+// assembly format. For operations without custom format, both should print the
+// generic format.
+
+// CHECK-LABEL: func @op_with_custom_printer
+// CHECK-GENERIC-LABEL: "func"()
+func.func @op_with_custom_printer() {
+  %x = test.string_attr_pretty_name
+  // CHECK: %x = test.string_attr_pretty_name
+  // GENERIC: %0 = "test.string_attr_pretty_name"()
+  return
+  // CHECK: return
+  // GENERIC: "func.return"()
+}
+
+// -----
+
+// CHECK-LABEL: func @op_without_custom_printer
+// CHECK-GENERIC: "func"()
+func.func @op_without_custom_printer() {
+  // CHECK: "test.result_type_with_trait"() : () -> !test.test_type_with_trait
+  // GENERIC: "test.result_type_with_trait"() : () -> !test.test_type_with_trait
+  "test.result_type_with_trait"() : () -> !test.test_type_with_trait
+  return
+}
diff --git a/mlir/test/IR/print-op-generic.mlir b/mlir/test/IR/print-op-generic.mlir
deleted file mode 100644 (file)
index ed34f84..0000000
+++ /dev/null
@@ -1,13 +0,0 @@
-// # RUN: mlir-opt %s | FileCheck %s
-// # RUN: mlir-opt %s --mlir-print-op-generic  | FileCheck %s --check-prefix=GENERIC
-
-// CHECK-LABEL: func @pretty_names
-// CHECK-GENERIC: "func"()
-func.func @pretty_names() {
-  %x = test.string_attr_pretty_name
-  // CHECK: %x = test.string_attr_pretty_name
-  // GENERIC: %0 = "test.string_attr_pretty_name"()
-  return
-  // CHECK: return
-  // GENERIC: "func.return"()
-}