Improve support for opaque types in MLIR, allowing dialects to opt into
authorChris Lattner <clattner@google.com>
Wed, 7 Aug 2019 18:49:56 +0000 (11:49 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 7 Aug 2019 18:50:26 +0000 (11:50 -0700)
supporting opaque types, and providing ODS support for matching them.

PiperOrigin-RevId: 262183028

mlir/include/mlir/IR/Dialect.h
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/TypeUtilities.h
mlir/lib/IR/Dialect.cpp
mlir/lib/IR/TypeUtilities.cpp

index 84a0331..eef7711 100644 (file)
@@ -57,7 +57,12 @@ public:
   /// Returns true if this dialect allows for unregistered operations, i.e.
   /// operations prefixed with the dialect namespace but not registered with
   /// addOperation.
-  bool allowsUnknownOperations() const { return allowUnknownOps; }
+  bool allowsUnknownOperations() const { return unknownOpsAllowed; }
+
+  /// Return true if this dialect allows for unregistered types, i.e., types
+  /// prefixed with the dialect namespace but not registered with addType.
+  /// These are represented with OpaqueType.
+  bool allowsUnknownTypes() const { return unknownTypesAllowed; }
 
   //===--------------------------------------------------------------------===//
   // Constant Hooks
@@ -226,8 +231,11 @@ protected:
     }
   };
 
-  // Enable support for unregistered operations.
-  void allowUnknownOperations(bool allow = true) { allowUnknownOps = allow; }
+  /// Enable support for unregistered operations.
+  void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; }
+
+  /// Enable support for unregistered types.
+  void allowUnknownTypes(bool allow = true) { unknownTypesAllowed = allow; }
 
 private:
   // Register a symbol(e.g. type) with its given unique class identifier.
@@ -246,10 +254,15 @@ private:
   /// This is the context that owns this Dialect object.
   MLIRContext *context;
 
-  /// Flag that toggles if this dialect supports unregistered operations, i.e.
-  /// operations prefixed with the dialect namespace but not registered with
-  /// addOperation.
-  bool allowUnknownOps;
+  /// Flag that specifies whether this dialect supports unregistered operations,
+  /// i.e. operations prefixed with the dialect namespace but not registered
+  /// with addOperation.
+  bool unknownOpsAllowed = false;
+
+  /// Flag that specifies whether this dialect allows unregistered types, i.e.
+  /// types prefixed with the dialect namespace but not registered with addType.
+  /// These types are represented with OpaqueType.
+  bool unknownTypesAllowed = false;
 };
 
 using DialectAllocatorFunction = std::function<void(MLIRContext *)>;
index a4b815c..3cf3efc 100644 (file)
@@ -334,6 +334,10 @@ def F64 : F<64>;
 def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
            BuildableType<"getBF16Type()">;
 
+class OpaqueType<string dialect, string name, string description>
+  : Type<CPred<"isOpaqueTypeWithName($_self, \""#dialect#"\", \""#name#"\")">,
+         description>;
+
 // Function Type
 
 // Any function type.
index 5d56d5b..ce0169f 100644 (file)
@@ -48,6 +48,10 @@ Type getElementTypeOrSelf(Value &val);
 /// handles storage concerns, which is tricky to do in tablegen.
 SmallVector<Type, 10> getFlattenedTypes(TupleType t);
 
+/// Return true if the specified type is an opaque type with the specified
+/// dialect and typeData.
+bool isOpaqueTypeWithName(Type type, StringRef dialect, StringRef typeData);
+
 //===----------------------------------------------------------------------===//
 // Utility Iterators
 //===----------------------------------------------------------------------===//
index 17dea1f..1170e06 100644 (file)
@@ -62,7 +62,7 @@ void mlir::registerAllDialects(MLIRContext *context) {
 }
 
 Dialect::Dialect(StringRef name, MLIRContext *context)
-    : name(name), context(context), allowUnknownOps(false) {
+    : name(name), context(context) {
   assert(isValidNamespace(name) && "invalid dialect namespace");
   registerDialect(context);
 }
@@ -88,6 +88,12 @@ Attribute Dialect::parseAttribute(StringRef attrData, Type type,
 
 /// Parse a type registered to this dialect.
 Type Dialect::parseType(StringRef tyData, Location loc) const {
+  // If this dialect allows unknown types, then represent this with OpaqueType.
+  if (allowsUnknownTypes()) {
+    auto ns = Identifier::get(getNamespace(), getContext());
+    return OpaqueType::get(ns, tyData, getContext());
+  }
+
   emitError(loc) << "dialect '" << getNamespace()
                  << "' provides no type parsing hook";
   return Type();
index 63543f4..95895af 100644 (file)
@@ -51,6 +51,16 @@ SmallVector<Type, 10> mlir::getFlattenedTypes(TupleType t) {
   return fTypes;
 }
 
+/// Return true if the specified type is an opaque type with the specified
+/// dialect and typeData.
+bool mlir::isOpaqueTypeWithName(Type type, StringRef dialect,
+                                StringRef typeData) {
+  if (auto opaque = type.dyn_cast<mlir::OpaqueType>())
+    return opaque.getDialectNamespace().is(dialect) &&
+           opaque.getTypeData() == typeData;
+  return false;
+}
+
 OperandElementTypeIterator::OperandElementTypeIterator(OperandIterator it)
     : llvm::mapped_iterator<OperandIterator, Type (*)(Value *)>(it, &unwrap) {}