Change dialect `printOperation()` hook to `getOperationPrinter()`
authorMehdi Amini <joker.eph@gmail.com>
Sat, 28 Aug 2021 03:02:55 +0000 (03:02 +0000)
committerMehdi Amini <joker.eph@gmail.com>
Tue, 31 Aug 2021 17:52:39 +0000 (17:52 +0000)
This makes the hook return a printer if available, instead of using LogicalResult  to
indicate if a printer was available (and invoked). This allows the caller to detect that
the dialect has a printer for a given operation without actually invoking the printer.
It'll be leveraged in a future revision to move printing the op name itself under control
of the ASMPrinter.

Differential Revision: https://reviews.llvm.org/D108803

mlir/include/mlir/IR/Dialect.h
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Dialect.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td

index f615819..14114e3 100644 (file)
@@ -121,8 +121,8 @@ public:
   /// Print an operation registered to this dialect.
   /// This hook is invoked for registered operation which don't override the
   /// `print()` method to define their own custom assembly.
-  virtual LogicalResult printOperation(Operation *op,
-                                       OpAsmPrinter &printer) const;
+  virtual llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
+  getOperationPrinter(Operation *op) const;
 
   //===--------------------------------------------------------------------===//
   // Verification Hooks
@@ -297,8 +297,7 @@ class DialectRegistry {
 public:
   explicit DialectRegistry();
 
-  template <typename ConcreteDialect>
-  void insert() {
+  template <typename ConcreteDialect> void insert() {
     insert(TypeID::get<ConcreteDialect>(),
            ConcreteDialect::getDialectNamespace(),
            static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) {
@@ -364,8 +363,7 @@ public:
   /// Add an external op interface model for an op that belongs to a dialect,
   /// both provided as template parameters. The dialect must be present in the
   /// registry.
-  template <typename OpTy, typename ModelTy>
-  void addOpInterface() {
+  template <typename OpTy, typename ModelTy> void addOpInterface() {
     StringRef opName = OpTy::getOperationName();
     StringRef dialectName = opName.split('.').first;
     addObjectInterface(dialectName, TypeID::get<OpTy>(),
@@ -426,8 +424,7 @@ private:
 
 namespace llvm {
 /// Provide isa functionality for Dialects.
-template <typename T>
-struct isa_impl<T, ::mlir::Dialect> {
+template <typename T> struct isa_impl<T, ::mlir::Dialect> {
   static inline bool doit(const ::mlir::Dialect &dialect) {
     return mlir::TypeID::get<T>() == dialect.getTypeID();
   }
index 65cbc8a..b0fafc3 100644 (file)
@@ -2508,8 +2508,10 @@ void OperationPrinter::printOperation(Operation *op) {
     }
     // Otherwise try to dispatch to the dialect, if available.
     if (Dialect *dialect = op->getDialect()) {
-      if (succeeded(dialect->printOperation(op, *this)))
+      if (auto opPrinter = dialect->getOperationPrinter(op)) {
+        opPrinter(op, *this);
         return;
+      }
     }
   }
 
index 80c8dab..2f2997e 100644 (file)
@@ -172,11 +172,11 @@ Dialect::getParseOperationHook(StringRef opName) const {
   return None;
 }
 
-LogicalResult Dialect::printOperation(Operation *op,
-                                      OpAsmPrinter &printer) const {
+llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
+Dialect::getOperationPrinter(Operation *op) const {
   assert(op->getDialect() == this &&
          "Dialect hook invoked on non-dialect owned operation");
-  return failure();
+  return nullptr;
 }
 
 /// Utility function that returns if the given string is a valid dialect
index 2a1f371..d12c61a 100644 (file)
@@ -313,14 +313,15 @@ TestDialect::getParseOperationHook(StringRef opName) const {
   return None;
 }
 
-LogicalResult TestDialect::printOperation(Operation *op,
-                                          OpAsmPrinter &printer) const {
+llvm::unique_function<void(Operation *, OpAsmPrinter &)>
+TestDialect::getOperationPrinter(Operation *op) const {
   StringRef opName = op->getName().getStringRef();
   if (opName == "test.dialect_custom_printer") {
-    printer.getStream() << opName << " custom_format";
-    return success();
+    return [](Operation *op, OpAsmPrinter &printer) {
+      printer.getStream() << op->getName().getStringRef() << " custom_format";
+    };
   }
-  return failure();
+  return {};
 }
 
 //===----------------------------------------------------------------------===//
index fbbc766..4dee0a1 100644 (file)
@@ -39,15 +39,17 @@ def Test_Dialect : Dialect {
     void registerTypes();
 
     ::mlir::Attribute parseAttribute(::mlir::DialectAsmParser &parser,
-                             ::mlir::Type type) const override;
+                                     ::mlir::Type type) const override;
     void printAttribute(::mlir::Attribute attr,
                         ::mlir::DialectAsmPrinter &printer) const override;
 
     // Provides a custom printing/parsing for some operations.
     ::llvm::Optional<ParseOpHook>
       getParseOperationHook(::llvm::StringRef opName) const override;
-    ::mlir::LogicalResult printOperation(::mlir::Operation *op,
-                                 ::mlir::OpAsmPrinter &printer) const override;
+    ::llvm::unique_function<void(::mlir::Operation *,
+                                 ::mlir::OpAsmPrinter &printer)>
+     getOperationPrinter(::mlir::Operation *op) const override;
+
   private:
     // Storage for a custom fallback interface.
     void *fallbackEffectOpInterfaces;