From 213c6cdf2e7a30d722cee4cd66b7d48fc396d44b Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Thu, 28 May 2020 08:08:20 +0000 Subject: [PATCH] Harden MLIR detection of misconfiguration when missing dialect registration This changes will catch error where C++ op are used without being registered, either through creation with the OpBuilder or when trying to cast to the C++ op. Differential Revision: https://reviews.llvm.org/D80651 --- mlir/include/mlir/IR/Builders.h | 8 ++++++++ mlir/include/mlir/IR/MLIRContext.h | 3 +++ mlir/include/mlir/IR/OpDefinition.h | 5 ++++- mlir/lib/IR/MLIRContext.cpp | 12 ++++++++++-- 4 files changed, 25 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 424eb98..0dcf4da 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -374,6 +374,10 @@ public: template OpTy create(Location location, Args &&... args) { OperationState state(location, OpTy::getOperationName()); + if (!state.name.getAbstractOperation()) + llvm::report_fatal_error("Building op `" + + state.name.getStringRef().str() + + "` but it isn't registered in this MLIRContext"); OpTy::build(*this, state, std::forward(args)...); auto *op = createOperation(state); auto result = dyn_cast(op); @@ -390,6 +394,10 @@ public: // Create the operation without using 'createOperation' as we don't want to // insert it yet. OperationState state(location, OpTy::getOperationName()); + if (!state.name.getAbstractOperation()) + llvm::report_fatal_error("Building op `" + + state.name.getStringRef().str() + + "` but it isn't registered in this MLIRContext"); OpTy::build(*this, state, std::forward(args)...); Operation *op = Operation::create(state); diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h index da0b0bd..8e75bb6 100644 --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -85,6 +85,9 @@ public: /// directly. std::vector getRegisteredOperations(); + /// Return true if this operation name is registered in this context. + bool isOperationRegistered(StringRef name); + // This is effectively private given that only MLIRContext.cpp can see the // MLIRContextImpl type. MLIRContextImpl &getImpl() { return *impl; } diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index bf5bd70..e92d54e 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1235,7 +1235,10 @@ public: static bool classof(Operation *op) { if (auto *abstractOp = op->getAbstractOperation()) return TypeID::get() == abstractOp->typeID; - return op->getName().getStringRef() == ConcreteType::getOperationName(); + assert(op->getContext()->isOperationRegistered( + ConcreteType::getOperationName()) && + "Casting attempt to an unregistered operation"); + return false; } /// This is the hook used by the AsmParser to parse the custom form of this diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 0728f29..da607a2 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -543,6 +543,13 @@ std::vector MLIRContext::getRegisteredOperations() { return result; } +bool MLIRContext::isOperationRegistered(StringRef name) { + // Lock access to the context registry. + ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled); + + return impl->registeredOperations.count(name); +} + void Dialect::addOperation(AbstractOperation opInfo) { assert((getNamespace().empty() || opInfo.name.split('.').first == getNamespace()) && @@ -621,8 +628,9 @@ Identifier Identifier::get(StringRef str, MLIRContext *context) { static Dialect &lookupDialectForSymbol(MLIRContext *ctx, TypeID typeID) { auto &impl = ctx->getImpl(); auto it = impl.registeredDialectSymbols.find(typeID); - assert(it != impl.registeredDialectSymbols.end() && - "symbol is not registered."); + if (it == impl.registeredDialectSymbols.end()) + llvm::report_fatal_error( + "Trying to create a type that was not registered in this MLIRContext."); return *it->second; } -- 2.7.4