const Instruction *, ArrayRef<Attribute>, SmallVectorImpl<Attribute> &)>;
using DialectExtractElementHook =
std::function<Attribute(const OpaqueElementsAttr, ArrayRef<uint64_t>)>;
-using DialectTypeParserHook =
- std::function<Type(StringRef, Location, MLIRContext *)>;
-using DialectTypePrinterHook = std::function<void(Type, raw_ostream &)>;
/// Dialects are groups of MLIR operations and behavior associated with the
/// entire group. For example, hooks into other systems for constant folding,
return Attribute();
};
- /// Registered parsing/printing hooks for types registered to the dialect.
- DialectTypeParserHook typeParseHook = nullptr;
+ /// Parse a type registered to this dialect.
+ virtual Type parseType(StringRef tyData, Location loc,
+ MLIRContext *context) const;
+
+ /// Print a type registered to this dialect.
/// Note: The data printed for the provided type must not include any '"'
/// characters.
- DialectTypePrinterHook typePrintHook = nullptr;
+ virtual void printType(Type, raw_ostream &) const {
+ assert(0 && "dialect has no registered type printing hook");
+ }
/// Registered hooks for getting identifier aliases for symbols. The
/// identifier is used in place of the symbol when printing textual IR.
llvm::LLVMContext &getLLVMContext() { return llvmContext; }
llvm::Module &getLLVMModule() { return module; }
+ /// Parse a type registered to this dialect.
+ Type parseType(StringRef tyData, Location loc,
+ MLIRContext *context) const override;
+
+ /// Print a type registered to this dialect.
+ void printType(Type type, raw_ostream &os) const override;
+
private:
llvm::LLVMContext llvmContext;
llvm::Module module;
default: {
auto &dialect = type.getDialect();
os << '!' << dialect.getNamespace() << "<\"";
- assert(dialect.typePrintHook && "Expected dialect type printing hook.");
- dialect.typePrintHook(type, os);
+ dialect.printType(type, os);
os << "\">";
return;
}
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectHooks.h"
#include "mlir/IR/MLIRContext.h"
+#include "llvm/ADT/Twine.h"
#include "llvm/Support/ManagedStatic.h"
using namespace mlir;
}
Dialect::~Dialect() {}
+
+/// Parse a type registered to this dialect.
+Type Dialect::parseType(StringRef tyData, Location loc,
+ MLIRContext *context) const {
+ context->emitError(loc, "dialect '" + getNamespace() +
+ "' provides no type parsing hook");
+ return Type();
+}
return Base::get(context, FIRST_LLVM_TYPE, llvmType);
}
-static Type parseLLVMType(StringRef data, Location loc, MLIRContext *ctx) {
- llvm::SMDiagnostic errorMessage;
- auto *llvmDialect =
- static_cast<LLVMDialect *>(ctx->getRegisteredDialect("llvm"));
- assert(llvmDialect && "LLVM dialect not registered");
- llvm::Type *type =
- llvm::parseType(data, errorMessage, llvmDialect->getLLVMModule());
- if (!type) {
- ctx->emitError(loc, errorMessage.getMessage());
- return {};
- }
- return LLVMType::get(ctx, type);
-}
-
-static void printLLVMType(Type ty, raw_ostream &os) {
- auto type = ty.dyn_cast<LLVMType>();
- assert(type && "printing wrong type");
- assert(type.getUnderlyingType() && "no underlying LLVM type");
- type.getUnderlyingType()->print(os);
-}
-
llvm::Type *LLVMType::getUnderlyingType() const {
return static_cast<ImplType *>(type)->underlyingType;
}
addOperations<
#include "mlir/LLVMIR/llvm_ops.inc"
>();
+}
+
+/// Parse a type registered to this dialect.
+Type LLVMDialect::parseType(StringRef tyData, Location loc,
+ MLIRContext *context) const {
+ llvm::SMDiagnostic errorMessage;
+ llvm::Type *type = llvm::parseType(tyData, errorMessage, module);
+ if (!type)
+ return (context->emitError(loc, errorMessage.getMessage()), nullptr);
+ return LLVMType::get(context, type);
+}
- typeParseHook = parseLLVMType;
- typePrintHook = printLLVMType;
+/// Print a type registered to this dialect.
+void LLVMDialect::printType(Type type, raw_ostream &os) const {
+ auto llvmType = type.dyn_cast<LLVMType>();
+ assert(llvmType && "printing wrong type");
+ assert(llvmType.getUnderlyingType() && "no underlying LLVM type");
+ llvmType.getUnderlyingType()->print(os);
}
static DialectRegistration<LLVMDialect> llvmDialect;
return aliasIt->second;
}
- // Otherwise, check for a registered dialect with this name.
- auto *dialect = state.context->getRegisteredDialect(identifier);
- if (dialect) {
- // Make sure that the dialect provides a parsing hook.
- if (!dialect->typeParseHook)
- return (emitError("dialect '" + dialect->getNamespace() +
- "' provides no type parsing hook"),
- nullptr);
- }
+ // Otherwise, we are parsing a dialect-specific type.
// Consume the '<'.
if (parseToken(Token::less, "expected '<' in dialect type"))
Type result;
// If we found a registered dialect, then ask it to parse the type.
- if (dialect) {
- result = dialect->typeParseHook(typeData, loc, state.context);
+ if (auto *dialect = state.context->getRegisteredDialect(identifier)) {
+ result = dialect->parseType(typeData, loc, state.context);
if (!result)
return nullptr;
} else {