Convert the dialect type parse/print hooks into virtual functions on the Dialect...
authorRiver Riddle <riverriddle@google.com>
Mon, 25 Feb 2019 21:16:24 +0000 (13:16 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 23:42:05 +0000 (16:42 -0700)
PiperOrigin-RevId: 235589945

mlir/include/mlir/IR/Dialect.h
mlir/include/mlir/LLVMIR/LLVMDialect.h
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Dialect.cpp
mlir/lib/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Parser/Parser.cpp

index 55b6f7efd3650e12eda2cd9aaa2d57d86047a97c..067fe53dad3820d8cfa79191c3c02c5388bbac84 100644 (file)
@@ -35,9 +35,6 @@ using DialectConstantFoldHook = std::function<bool(
     const Instruction *, ArrayRef<Attribute>, SmallVectorImpl<Attribute> &)>;
 using DialectExtractElementHook =
     std::function<Attribute(const OpaqueElementsAttr, ArrayRef<uint64_t>)>;
-using DialectTypeParserHook =
-    std::function<Type(StringRef, Location, MLIRContext *)>;
-using DialectTypePrinterHook = std::function<void(Type, raw_ostream &)>;
 
 /// Dialects are groups of MLIR operations and behavior associated with the
 /// entire group.  For example, hooks into other systems for constant folding,
@@ -80,11 +77,16 @@ public:
         return Attribute();
       };
 
-  /// Registered parsing/printing hooks for types registered to the dialect.
-  DialectTypeParserHook typeParseHook = nullptr;
+  /// Parse a type registered to this dialect.
+  virtual Type parseType(StringRef tyData, Location loc,
+                         MLIRContext *context) const;
+
+  /// Print a type registered to this dialect.
   /// Note: The data printed for the provided type must not include any '"'
   /// characters.
-  DialectTypePrinterHook typePrintHook = nullptr;
+  virtual void printType(Type, raw_ostream &) const {
+    assert(0 && "dialect has no registered type printing hook");
+  }
 
   /// Registered hooks for getting identifier aliases for symbols. The
   /// identifier is used in place of the symbol when printing textual IR.
index cd2b5c9d7085acb13cc3fbdc65b1c65eb4aa5551..6c1716597967c9fd30dcbe8d65cf5674529b3059 100644 (file)
@@ -76,6 +76,13 @@ public:
   llvm::LLVMContext &getLLVMContext() { return llvmContext; }
   llvm::Module &getLLVMModule() { return module; }
 
+  /// Parse a type registered to this dialect.
+  Type parseType(StringRef tyData, Location loc,
+                 MLIRContext *context) const override;
+
+  /// Print a type registered to this dialect.
+  void printType(Type type, raw_ostream &os) const override;
+
 private:
   llvm::LLVMContext llvmContext;
   llvm::Module module;
index 5348125577d8968054d15aa561c70f0167df5044..9af1794cb0523752cb7728477a984738672309b0 100644 (file)
@@ -715,8 +715,7 @@ void ModulePrinter::printType(Type type) {
   default: {
     auto &dialect = type.getDialect();
     os << '!' << dialect.getNamespace() << "<\"";
-    assert(dialect.typePrintHook && "Expected dialect type printing hook.");
-    dialect.typePrintHook(type, os);
+    dialect.printType(type, os);
     os << "\">";
     return;
   }
index 338c918c33965c7134a9da4325b9428405dc1f51..c24d6b1f388d4406814dea2d689ec164246d6efc 100644 (file)
@@ -18,6 +18,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectHooks.h"
 #include "mlir/IR/MLIRContext.h"
+#include "llvm/ADT/Twine.h"
 #include "llvm/Support/ManagedStatic.h"
 using namespace mlir;
 
@@ -65,3 +66,11 @@ Dialect::Dialect(StringRef namePrefix, MLIRContext *context)
 }
 
 Dialect::~Dialect() {}
+
+/// Parse a type registered to this dialect.
+Type Dialect::parseType(StringRef tyData, Location loc,
+                        MLIRContext *context) const {
+  context->emitError(loc, "dialect '" + getNamespace() +
+                              "' provides no type parsing hook");
+  return Type();
+}
index 9c3d31da9ce838adc044d2c892b8a4acbdec662a..3444b0ee4c7d2e7f8c92bc3c749f260669c4e8de 100644 (file)
@@ -57,27 +57,6 @@ LLVMType LLVMType::get(MLIRContext *context, llvm::Type *llvmType) {
   return Base::get(context, FIRST_LLVM_TYPE, llvmType);
 }
 
-static Type parseLLVMType(StringRef data, Location loc, MLIRContext *ctx) {
-  llvm::SMDiagnostic errorMessage;
-  auto *llvmDialect =
-      static_cast<LLVMDialect *>(ctx->getRegisteredDialect("llvm"));
-  assert(llvmDialect && "LLVM dialect not registered");
-  llvm::Type *type =
-      llvm::parseType(data, errorMessage, llvmDialect->getLLVMModule());
-  if (!type) {
-    ctx->emitError(loc, errorMessage.getMessage());
-    return {};
-  }
-  return LLVMType::get(ctx, type);
-}
-
-static void printLLVMType(Type ty, raw_ostream &os) {
-  auto type = ty.dyn_cast<LLVMType>();
-  assert(type && "printing wrong type");
-  assert(type.getUnderlyingType() && "no underlying LLVM type");
-  type.getUnderlyingType()->print(os);
-}
-
 llvm::Type *LLVMType::getUnderlyingType() const {
   return static_cast<ImplType *>(type)->underlyingType;
 }
@@ -91,9 +70,24 @@ LLVMDialect::LLVMDialect(MLIRContext *context)
   addOperations<
 #include "mlir/LLVMIR/llvm_ops.inc"
       >();
+}
+
+/// Parse a type registered to this dialect.
+Type LLVMDialect::parseType(StringRef tyData, Location loc,
+                            MLIRContext *context) const {
+  llvm::SMDiagnostic errorMessage;
+  llvm::Type *type = llvm::parseType(tyData, errorMessage, module);
+  if (!type)
+    return (context->emitError(loc, errorMessage.getMessage()), nullptr);
+  return LLVMType::get(context, type);
+}
 
-  typeParseHook = parseLLVMType;
-  typePrintHook = printLLVMType;
+/// Print a type registered to this dialect.
+void LLVMDialect::printType(Type type, raw_ostream &os) const {
+  auto llvmType = type.dyn_cast<LLVMType>();
+  assert(llvmType && "printing wrong type");
+  assert(llvmType.getUnderlyingType() && "no underlying LLVM type");
+  llvmType.getUnderlyingType()->print(os);
 }
 
 static DialectRegistration<LLVMDialect> llvmDialect;
index f7ab6f1fe119850ec800f7df86b3d94fa77240f0..ce37f56d9aafb886e89a1b34ac856cc076dcf22f 100644 (file)
@@ -496,15 +496,7 @@ Type Parser::parseExtendedType() {
     return aliasIt->second;
   }
 
-  // Otherwise, check for a registered dialect with this name.
-  auto *dialect = state.context->getRegisteredDialect(identifier);
-  if (dialect) {
-    // Make sure that the dialect provides a parsing hook.
-    if (!dialect->typeParseHook)
-      return (emitError("dialect '" + dialect->getNamespace() +
-                        "' provides no type parsing hook"),
-              nullptr);
-  }
+  // Otherwise, we are parsing a dialect-specific type.
 
   // Consume the '<'.
   if (parseToken(Token::less, "expected '<' in dialect type"))
@@ -522,8 +514,8 @@ Type Parser::parseExtendedType() {
   Type result;
 
   // If we found a registered dialect, then ask it to parse the type.
-  if (dialect) {
-    result = dialect->typeParseHook(typeData, loc, state.context);
+  if (auto *dialect = state.context->getRegisteredDialect(identifier)) {
+    result = dialect->parseType(typeData, loc, state.context);
     if (!result)
       return nullptr;
   } else {