From fd87963eee23f6cf2aed97bf182a6b3f5e9450ed Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Sat, 28 Aug 2021 03:02:55 +0000 Subject: [PATCH] Change dialect `printOperation()` hook to `getOperationPrinter()` This makes the hook return a printer if available, instead of using LogicalResult to indicate if a printer was available (and invoked). This allows the caller to detect that the dialect has a printer for a given operation without actually invoking the printer. It'll be leveraged in a future revision to move printing the op name itself under control of the ASMPrinter. Differential Revision: https://reviews.llvm.org/D108803 --- mlir/include/mlir/IR/Dialect.h | 13 +++++-------- mlir/lib/IR/AsmPrinter.cpp | 4 +++- mlir/lib/IR/Dialect.cpp | 6 +++--- mlir/test/lib/Dialect/Test/TestDialect.cpp | 11 ++++++----- mlir/test/lib/Dialect/Test/TestOps.td | 8 +++++--- 5 files changed, 22 insertions(+), 20 deletions(-) diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h index f615819..14114e3 100644 --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -121,8 +121,8 @@ public: /// Print an operation registered to this dialect. /// This hook is invoked for registered operation which don't override the /// `print()` method to define their own custom assembly. - virtual LogicalResult printOperation(Operation *op, - OpAsmPrinter &printer) const; + virtual llvm::unique_function + getOperationPrinter(Operation *op) const; //===--------------------------------------------------------------------===// // Verification Hooks @@ -297,8 +297,7 @@ class DialectRegistry { public: explicit DialectRegistry(); - template - void insert() { + template void insert() { insert(TypeID::get(), ConcreteDialect::getDialectNamespace(), static_cast(([](MLIRContext *ctx) { @@ -364,8 +363,7 @@ public: /// Add an external op interface model for an op that belongs to a dialect, /// both provided as template parameters. The dialect must be present in the /// registry. - template - void addOpInterface() { + template void addOpInterface() { StringRef opName = OpTy::getOperationName(); StringRef dialectName = opName.split('.').first; addObjectInterface(dialectName, TypeID::get(), @@ -426,8 +424,7 @@ private: namespace llvm { /// Provide isa functionality for Dialects. -template -struct isa_impl { +template struct isa_impl { static inline bool doit(const ::mlir::Dialect &dialect) { return mlir::TypeID::get() == dialect.getTypeID(); } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 65cbc8a..b0fafc3 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -2508,8 +2508,10 @@ void OperationPrinter::printOperation(Operation *op) { } // Otherwise try to dispatch to the dialect, if available. if (Dialect *dialect = op->getDialect()) { - if (succeeded(dialect->printOperation(op, *this))) + if (auto opPrinter = dialect->getOperationPrinter(op)) { + opPrinter(op, *this); return; + } } } diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp index 80c8dab..2f2997e 100644 --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -172,11 +172,11 @@ Dialect::getParseOperationHook(StringRef opName) const { return None; } -LogicalResult Dialect::printOperation(Operation *op, - OpAsmPrinter &printer) const { +llvm::unique_function +Dialect::getOperationPrinter(Operation *op) const { assert(op->getDialect() == this && "Dialect hook invoked on non-dialect owned operation"); - return failure(); + return nullptr; } /// Utility function that returns if the given string is a valid dialect diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index 2a1f371..d12c61a 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -313,14 +313,15 @@ TestDialect::getParseOperationHook(StringRef opName) const { return None; } -LogicalResult TestDialect::printOperation(Operation *op, - OpAsmPrinter &printer) const { +llvm::unique_function +TestDialect::getOperationPrinter(Operation *op) const { StringRef opName = op->getName().getStringRef(); if (opName == "test.dialect_custom_printer") { - printer.getStream() << opName << " custom_format"; - return success(); + return [](Operation *op, OpAsmPrinter &printer) { + printer.getStream() << op->getName().getStringRef() << " custom_format"; + }; } - return failure(); + return {}; } //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index fbbc766..4dee0a1 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -39,15 +39,17 @@ def Test_Dialect : Dialect { void registerTypes(); ::mlir::Attribute parseAttribute(::mlir::DialectAsmParser &parser, - ::mlir::Type type) const override; + ::mlir::Type type) const override; void printAttribute(::mlir::Attribute attr, ::mlir::DialectAsmPrinter &printer) const override; // Provides a custom printing/parsing for some operations. ::llvm::Optional getParseOperationHook(::llvm::StringRef opName) const override; - ::mlir::LogicalResult printOperation(::mlir::Operation *op, - ::mlir::OpAsmPrinter &printer) const override; + ::llvm::unique_function + getOperationPrinter(::mlir::Operation *op) const override; + private: // Storage for a custom fallback interface. void *fallbackEffectOpInterfaces; -- 2.7.4