Add a flag to Dialect that allows for dialects to enable support for unregistered...
authorRiver Riddle <riverriddle@google.com>
Mon, 1 Apr 2019 06:30:22 +0000 (23:30 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 1 Apr 2019 17:59:17 +0000 (10:59 -0700)
    Example:

    func @unknown_std_op() {
      %0 = "std.foo_bar_op"() : () -> index
      return
    }

    Will result in:

    error: unregistered operation 'std.foo_bar_op' found in dialect ('std') that does not allow unknown operations

--

PiperOrigin-RevId: 241266009

mlir/include/mlir/IR/Dialect.h
mlir/lib/Analysis/Verifier.cpp
mlir/lib/IR/Dialect.cpp
mlir/lib/LLVMIR/IR/LLVMDialect.cpp
mlir/test/IR/invalid-ops.mlir

index 77bb6b8..386499e 100644 (file)
@@ -49,6 +49,11 @@ public:
 
   StringRef getNamespace() const { return name; }
 
+  /// Returns true if this dialect allows for unregistered operations, i.e.
+  /// operations prefixed with the dialect namespace but not registered with
+  /// addOperation.
+  bool allowsUnknownOperations() const { return allowUnknownOps; }
+
   /// Registered fallback constant fold hook for the dialect. Like the constant
   /// fold hook of each operation, it attempts to constant fold the operation
   /// with the specified constant operand values - the elements in "operands"
@@ -188,6 +193,9 @@ protected:
   // Register a type with its given unqiue type identifer.
   void addType(const TypeID *const typeID);
 
+  // Enable support for unregistered operations.
+  void allowUnknownOperations(bool allow = true) { allowUnknownOps = allow; }
+
 private:
   Dialect(const Dialect &) = delete;
   void operator=(Dialect &) = delete;
@@ -201,6 +209,11 @@ private:
 
   /// This is the context that owns this Dialect object.
   MLIRContext *context;
+
+  /// Flag that toggles if this dialect supports unregistered operations, i.e.
+  /// operations prefixed with the dialect namespace but not registered with
+  /// addOperation.
+  bool allowUnknownOps;
 };
 
 using DialectAllocatorFunction = std::function<void(MLIRContext *)>;
index fddd9ac..e189689 100644 (file)
@@ -70,9 +70,7 @@ public:
   }
 
   /// Returns the registered dialect for a dialect-specific attribute.
-  template <typename ErrorContext>
-  Dialect *getDialectForAttribute(const NamedAttribute &attr,
-                                  const ErrorContext &ctx) {
+  Dialect *getDialectForAttribute(const NamedAttribute &attr) {
     assert(attr.first.strref().contains('.') && "expected dialect attribute");
     auto dialectNamePair = attr.first.strref().split('.');
     return fn.getContext()->getRegisteredDialect(dialectNamePair.first);
@@ -124,6 +122,10 @@ private:
 
   /// Regex checker for attribute names.
   llvm::Regex identifierRegex;
+
+  /// Mapping between dialect namespace and if that dialect supports
+  /// unregistered operations.
+  llvm::StringMap<bool> dialectAllowsUnknownOps;
 };
 } // end anonymous namespace
 
@@ -149,7 +151,7 @@ bool FuncVerifier::verify() {
       return failure("functions may only have dialect attributes", fn);
 
     // Verify this attribute with the defining dialect.
-    if (auto *dialect = getDialectForAttribute(attr, fn))
+    if (auto *dialect = getDialectForAttribute(attr))
       if (dialect->verifyFunctionAttribute(&fn, attr))
         return true;
   }
@@ -172,7 +174,7 @@ bool FuncVerifier::verify() {
                        fn);
 
       // Verify this attribute with the defining dialect.
-      if (auto *dialect = getDialectForAttribute(attr, fn))
+      if (auto *dialect = getDialectForAttribute(attr))
         if (dialect->verifyFunctionArgAttribute(&fn, i, attr))
           return true;
     }
@@ -283,16 +285,15 @@ bool FuncVerifier::verifyOperation(Operation &op) {
     // Check for any optional dialect specific attributes.
     if (!attr.first.strref().contains('.'))
       continue;
-    if (auto *dialect = getDialectForAttribute(attr, op))
+    if (auto *dialect = getDialectForAttribute(attr))
       if (dialect->verifyOperationAttribute(&op, attr))
         return true;
   }
 
   // If we can get operation info for this, check the custom hook.
-  if (auto *opInfo = op.getAbstractOperation()) {
-    if (opInfo->verifyInvariants(&op))
-      return true;
-  }
+  auto *opInfo = op.getAbstractOperation();
+  if (opInfo && opInfo->verifyInvariants(&op))
+    return true;
 
   // Verify that all child blocks are ok.
   for (auto &region : op.getRegions())
@@ -300,6 +301,34 @@ bool FuncVerifier::verifyOperation(Operation &op) {
       if (verifyBlock(b, /*isTopLevel=*/false))
         return true;
 
+  // If this is a registered operation, there is nothing left to do.
+  if (opInfo)
+    return false;
+
+  // Otherwise, verify that the parent dialect allows un-registered operations.
+  auto opName = op.getName().getStringRef();
+  auto dialectPrefix = opName.split('.').first;
+
+  // Check for an existing answer for the operation dialect.
+  auto it = dialectAllowsUnknownOps.find(dialectPrefix);
+  if (it == dialectAllowsUnknownOps.end()) {
+    // If the operation dialect is registered, query it directly.
+    if (auto *dialect = fn.getContext()->getRegisteredDialect(dialectPrefix))
+      it = dialectAllowsUnknownOps
+               .try_emplace(dialectPrefix, dialect->allowsUnknownOperations())
+               .first;
+    // Otherwise, conservatively allow unknown operations.
+    else
+      it = dialectAllowsUnknownOps.try_emplace(dialectPrefix, true).first;
+  }
+
+  if (!it->second) {
+    return failure("unregistered operation '" + opName +
+                       "' found in dialect ('" + dialectPrefix +
+                       "') that does not allow unknown operations",
+                   op);
+  }
+
   return false;
 }
 
index 5b84006..e87358c 100644 (file)
@@ -60,7 +60,7 @@ void mlir::registerAllDialects(MLIRContext *context) {
 }
 
 Dialect::Dialect(StringRef name, MLIRContext *context)
-    : name(name), context(context) {
+    : name(name), context(context), allowUnknownOps(false) {
   assert(isValidNamespace(name) && "invalid dialect namespace");
   registerDialect(context);
 }
index 4e05995..d9a7be0 100644 (file)
@@ -69,6 +69,9 @@ LLVMDialect::LLVMDialect(MLIRContext *context)
 #define GET_OP_LIST
 #include "mlir/LLVMIR/LLVMOps.cpp.inc"
       >();
+
+  // Support unknown operations because not all LLVM operations are registered.
+  allowUnknownOperations();
 }
 
 #define GET_OP_CLASSES
index acc56ba..7d49b7f 100644 (file)
@@ -76,6 +76,14 @@ func @unknown_custom_op() {
 
 // -----
 
+func @unknown_std_op() {
+  // expected-error@+1 {{unregistered operation 'std.foo_bar_op' found in dialect ('std') that does not allow unknown operations}}
+  %0 = "std.foo_bar_op"() : () -> index
+  return
+}
+
+// -----
+
 func @bad_alloc_wrong_dynamic_dim_count() {
 ^bb0:
   %0 = "std.constant"() {value: 7} : () -> index