[mlir] Add a new builtin DenseResourceElementsAttr
authorRiver Riddle <riddleriver@gmail.com>
Wed, 20 Jul 2022 01:22:55 +0000 (18:22 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Mon, 1 Aug 2022 19:37:16 +0000 (12:37 -0700)
This attributes is intended cover the current set of use cases that abuse
DenseElementsAttr, e.g. when the data is large. Using resources for large
data is one of the major reasons why they were added; e.g. they can be
deallocated mid-compilation, they support a wide variety of data origins
(e.g, heap allocated, mmap'd, etc.), they can support mutation, etc.

I considered at length not having a builtin variant of this, and instead
having multiple versions of this attribute for dialects that are interested,
but they all boiled down to the exact same attribute definition. Given the
generality of this attribute, it feels more aligned to keep it next to DenseArrayAttr
(given that DenseArrayAttr covers the "small" case, and DenseResourcesElementsAttr
covers the "large" case). The underlying infra used to build this attribute is
general, and having a builtin attribute doesn't preclude users from defining
their own when it makes sense (they can even share a blob manager with the
builtin dialect to avoid data duplication).

Differential Revision: https://reviews.llvm.org/D130022

17 files changed:
mlir/include/mlir/IR/BuiltinAttributes.h
mlir/include/mlir/IR/BuiltinAttributes.td
mlir/include/mlir/IR/OpImplementation.h
mlir/lib/AsmParser/AsmParserImpl.h
mlir/lib/AsmParser/AttributeParser.cpp
mlir/lib/AsmParser/Parser.cpp
mlir/lib/AsmParser/Parser.h
mlir/lib/AsmParser/TokenKinds.def
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/lib/IR/BuiltinDialect.cpp
mlir/lib/IR/DialectResourceBlobManager.cpp
mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
mlir/test/IR/dense-resource-elements-attr.mlir [new file with mode: 0644]
mlir/test/IR/invalid-builtin-attributes.mlir
mlir/test/IR/invalid-file-metadata.mlir
mlir/unittests/IR/AttributeTest.cpp

index 7adec33..eb8f0ca 100644 (file)
 
 namespace mlir {
 class AffineMap;
+class AsmResourceBlob;
 class BoolAttr;
+class BuiltinDialect;
 class DenseIntElementsAttr;
+template <typename T>
+struct DialectResourceBlobHandle;
 class FlatSymbolRefAttr;
 class FunctionType;
 class IntegerSet;
@@ -729,6 +733,13 @@ public:
     return denseAttr && denseAttr.isSplat();
   }
 };
+
+//===----------------------------------------------------------------------===//
+// DenseResourceElementsAttr
+//===----------------------------------------------------------------------===//
+
+using DenseResourceElementsHandle = DialectResourceBlobHandle<BuiltinDialect>;
+
 } // namespace mlir
 
 //===----------------------------------------------------------------------===//
@@ -743,6 +754,9 @@ public:
 //===----------------------------------------------------------------------===//
 
 namespace mlir {
+//===----------------------------------------------------------------------===//
+// DenseArrayAttr
+
 namespace detail {
 /// Base class for DenseArrayAttr that is instantiated and specialized for each
 /// supported element type below.
@@ -796,6 +810,71 @@ using DenseF32ArrayAttr = detail::DenseArrayAttr<float>;
 using DenseF64ArrayAttr = detail::DenseArrayAttr<double>;
 
 //===----------------------------------------------------------------------===//
+// DenseResourceElementsAttr
+
+namespace detail {
+/// Base class for DenseResourceElementsAttr that is instantiated and
+/// specialized for each supported element type below.
+template <typename T>
+class DenseResourceElementsAttrBase : public DenseResourceElementsAttr {
+public:
+  using DenseResourceElementsAttr::DenseResourceElementsAttr;
+
+  /// A builder that inserts a new resource using the provided blob. The handle
+  /// of the inserted blob is used when building the attribute. The provided
+  /// `blobName` is used as a hint for the key of the new handle for the `blob`
+  /// resource, but may be changed if necessary to ensure uniqueness during
+  /// insertion.
+  static DenseResourceElementsAttrBase<T>
+  get(ShapedType type, StringRef blobName, AsmResourceBlob blob);
+
+  /// Return the data of this attribute as an ArrayRef<T> if it is present,
+  /// returns None otherwise.
+  Optional<ArrayRef<T>> tryGetAsArrayRef() const;
+
+  /// Support for isa<>/cast<>.
+  static bool classof(Attribute attr);
+};
+
+extern template class DenseResourceElementsAttrBase<bool>;
+extern template class DenseResourceElementsAttrBase<int8_t>;
+extern template class DenseResourceElementsAttrBase<int16_t>;
+extern template class DenseResourceElementsAttrBase<int32_t>;
+extern template class DenseResourceElementsAttrBase<int64_t>;
+extern template class DenseResourceElementsAttrBase<uint8_t>;
+extern template class DenseResourceElementsAttrBase<uint16_t>;
+extern template class DenseResourceElementsAttrBase<uint32_t>;
+extern template class DenseResourceElementsAttrBase<uint64_t>;
+extern template class DenseResourceElementsAttrBase<float>;
+extern template class DenseResourceElementsAttrBase<double>;
+} // namespace detail
+
+// Public names for all the supported DenseResourceElementsAttr.
+
+using DenseBoolResourceElementsAttr =
+    detail::DenseResourceElementsAttrBase<bool>;
+using DenseI8ResourceElementsAttr =
+    detail::DenseResourceElementsAttrBase<int8_t>;
+using DenseI16ResourceElementsAttr =
+    detail::DenseResourceElementsAttrBase<int16_t>;
+using DenseI32ResourceElementsAttr =
+    detail::DenseResourceElementsAttrBase<int32_t>;
+using DenseI64ResourceElementsAttr =
+    detail::DenseResourceElementsAttrBase<int64_t>;
+using DenseUI8ResourceElementsAttr =
+    detail::DenseResourceElementsAttrBase<uint8_t>;
+using DenseUI16ResourceElementsAttr =
+    detail::DenseResourceElementsAttrBase<uint16_t>;
+using DenseUI32ResourceElementsAttr =
+    detail::DenseResourceElementsAttrBase<uint32_t>;
+using DenseUI64ResourceElementsAttr =
+    detail::DenseResourceElementsAttrBase<uint64_t>;
+using DenseF32ResourceElementsAttr =
+    detail::DenseResourceElementsAttrBase<float>;
+using DenseF64ResourceElementsAttr =
+    detail::DenseResourceElementsAttrBase<double>;
+
+//===----------------------------------------------------------------------===//
 // BoolAttr
 //===----------------------------------------------------------------------===//
 
index 0b62090..7a771d8 100644 (file)
@@ -17,6 +17,7 @@
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/BuiltinDialect.td"
 include "mlir/IR/BuiltinAttributeInterfaces.td"
+include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/SubElementInterfaces.td"
 
 // TODO: Currently the attributes defined in this file are prefixed with
@@ -425,6 +426,65 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
 }
 
 //===----------------------------------------------------------------------===//
+// DenseResourceElementsAttr
+//===----------------------------------------------------------------------===//
+
+def Builtin_DenseResourceElementsAttr : Builtin_Attr<"DenseResourceElements", [
+    ElementsAttrInterface, TypedAttrInterface
+  ]> {
+  let summary = "An Attribute containing a dense multi-dimensional array "
+                "backed by a resource";
+  let description = [{
+    Syntax:
+
+    ```
+    dense-resource-elements-attribute ::=
+      `dense_resource` `<` resource-handle `>` `:` shaped-type
+    ```
+
+    A dense resource elements attribute is an elements attribute backed by a
+    handle to a builtin dialect resource containing a densely packed array of
+    values. This class provides the low-level attribute, which should only be
+    interacted with in very generic terms, actual access to the underlying
+    resource data is intended to be managed through one of the subclasses, such
+    as; `DenseBoolResourceElementsAttr`, `DenseUI64ResourceElementsAttr`,
+    `DenseI32ResourceElementsAttr`, `DenseF32ResourceElementsAttr`,
+    `DenseF64ResourceElementsAttr`, etc.
+
+    Examples:
+
+    ```mlir
+    // A tensor referencing a builtin dialect resource, `resource_1`, with two
+    // unsigned i32 elements.
+    dense_resource<resource_1> : tensor<2xui32>
+    ```
+  }];
+  let parameters = (ins
+    AttributeSelfTypeParameter<"", "ShapedType">:$type,
+    ResourceHandleParameter<"DenseResourceElementsHandle">:$rawHandle
+  );
+  let builders = [
+    AttrBuilderWithInferredContext<(ins
+      "ShapedType":$type, "DenseResourceElementsHandle":$handle
+    )>
+  ];
+  let extraClassDeclaration = [{
+  protected:
+    /// A builder that inserts a new resource into the builtin dialect's blob
+    /// manager using the provided blob. The handle of the inserted blob is used
+    /// when building the attribute. The provided `blobName` is used as a hint
+    /// for the key of the new handle for the `blob` resource, but may be
+    /// changed if necessary to ensure uniqueness during insertion.
+    static DenseResourceElementsAttr get(
+      ShapedType type, StringRef blobName, AsmResourceBlob blob
+    );
+
+  public:
+  }];
+  let skipDefaultBuilders = 1;
+}
+
+//===----------------------------------------------------------------------===//
 // DictionaryAttr
 //===----------------------------------------------------------------------===//
 
index a69e3b4..5de7888 100644 (file)
@@ -1023,8 +1023,17 @@ public:
   template <typename ResourceT>
   FailureOr<ResourceT> parseResourceHandle() {
     SMLoc handleLoc = getCurrentLocation();
-    FailureOr<AsmDialectResourceHandle> handle = parseResourceHandle(
-        getContext()->getOrLoadDialect<typename ResourceT::Dialect>());
+
+    // Try to load the dialect that owns the handle.
+    auto *dialect =
+        getContext()->getOrLoadDialect<typename ResourceT::Dialect>();
+    if (!dialect) {
+      return emitError(handleLoc)
+             << "dialect '" << ResourceT::Dialect::getDialectNamespace()
+             << "' is unknown";
+    }
+
+    FailureOr<AsmDialectResourceHandle> handle = parseResourceHandle(dialect);
     if (failed(handle))
       return failure();
     if (auto *result = dyn_cast<ResourceT>(&*handle))
index c06eb68..5bc6c79 100644 (file)
@@ -460,7 +460,7 @@ public:
   /// Parse a handle to a resource within the assembly format.
   FailureOr<AsmDialectResourceHandle>
   parseResourceHandle(Dialect *dialect) override {
-    const auto *interface = dyn_cast_or_null<OpAsmDialectInterface>(dialect);
+    const auto *interface = dyn_cast<OpAsmDialectInterface>(dialect);
     if (!interface) {
       return parser.emitError() << "dialect '" << dialect->getNamespace()
                                 << "' does not expect resource handles";
index dff8510..faa60b6 100644 (file)
 #include "AsmParserImpl.h"
 #include "mlir/AsmParser/AsmParserState.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
 #include "mlir/IR/IntegerSet.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Endian.h"
@@ -97,6 +98,10 @@ Attribute Parser::parseAttribute(Type type) {
   case Token::kw_dense:
     return parseDenseElementsAttr(type);
 
+  // Parse a dense resource elements attribute.
+  case Token::kw_dense_resource:
+    return parseDenseResourceElementsAttr(type);
+
   // Parse a dictionary attribute.
   case Token::l_brace: {
     NamedAttrList elements;
@@ -241,6 +246,7 @@ OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
   case Token::kw_affine_map:
   case Token::kw_affine_set:
   case Token::kw_dense:
+  case Token::kw_dense_resource:
   case Token::kw_false:
   case Token::kw_loc:
   case Token::kw_opaque:
@@ -928,6 +934,39 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
   return literalParser.getAttr(loc, type);
 }
 
+Attribute Parser::parseDenseResourceElementsAttr(Type attrType) {
+  auto loc = getToken().getLoc();
+  consumeToken(Token::kw_dense_resource);
+  if (parseToken(Token::less, "expected '<' after 'dense_resource'"))
+    return nullptr;
+
+  // Parse the resource handle.
+  FailureOr<AsmDialectResourceHandle> rawHandle =
+      parseResourceHandle(getContext()->getLoadedDialect<BuiltinDialect>());
+  if (failed(rawHandle) || parseToken(Token::greater, "expected '>'"))
+    return nullptr;
+
+  auto *handle = dyn_cast<DenseResourceElementsHandle>(&*rawHandle);
+  if (!handle)
+    return emitError(loc, "invalid `dense_resource` handle type"), nullptr;
+
+  // Parse the type of the attribute if the user didn't provide one.
+  SMLoc typeLoc = loc;
+  if (!attrType) {
+    typeLoc = getToken().getLoc();
+    if (parseToken(Token::colon, "expected ':'") || !(attrType = parseType()))
+      return nullptr;
+  }
+
+  ShapedType shapedType = attrType.dyn_cast<ShapedType>();
+  if (!shapedType) {
+    emitError(typeLoc, "`dense_resource` expected a shaped type");
+    return nullptr;
+  }
+
+  return DenseResourceElementsAttr::get(shapedType, *handle);
+}
+
 /// Parse an opaque elements attribute.
 Attribute Parser::parseOpaqueElementsAttr(Type attrType) {
   SMLoc loc = getToken().getLoc();
index 9934ca5..6cc96e7 100644 (file)
@@ -340,6 +340,17 @@ Parser::parseResourceHandle(const OpAsmDialectInterface *dialect,
   return entry.second;
 }
 
+FailureOr<AsmDialectResourceHandle>
+Parser::parseResourceHandle(Dialect *dialect) {
+  const auto *interface = dyn_cast<OpAsmDialectInterface>(dialect);
+  if (!interface) {
+    return emitError() << "dialect '" << dialect->getNamespace()
+                       << "' does not expect resource handles";
+  }
+  StringRef resourceName;
+  return parseResourceHandle(interface, resourceName);
+}
+
 //===----------------------------------------------------------------------===//
 // Code Completion
 
index 615f940..d48eeb9 100644 (file)
@@ -160,6 +160,7 @@ public:
   /// Parse a handle to a dialect resource within the assembly format.
   FailureOr<AsmDialectResourceHandle>
   parseResourceHandle(const OpAsmDialectInterface *dialect, StringRef &name);
+  FailureOr<AsmDialectResourceHandle> parseResourceHandle(Dialect *dialect);
 
   //===--------------------------------------------------------------------===//
   // Type Parsing
@@ -272,6 +273,9 @@ public:
   Attribute parseDenseElementsAttr(Type attrType);
   ShapedType parseElementsLiteralType(Type type);
 
+  /// Parse a dense resource elements attribute.
+  Attribute parseDenseResourceElementsAttr(Type attrType);
+
   /// Parse a DenseArrayAttr.
   Attribute parseDenseArrayAttr();
 
index 207af38..f56e048 100644 (file)
@@ -87,6 +87,7 @@ TOK_KEYWORD(bf16)
 TOK_KEYWORD(ceildiv)
 TOK_KEYWORD(complex)
 TOK_KEYWORD(dense)
+TOK_KEYWORD(dense_resource)
 TOK_KEYWORD(f16)
 TOK_KEYWORD(f32)
 TOK_KEYWORD(f64)
index e5fd5ea..433fe22 100644 (file)
@@ -20,6 +20,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/OpImplementation.h"
@@ -1896,6 +1897,10 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
       os << " ";
     denseArrayAttr.printWithoutBraces(os);
     os << "]";
+  } else if (auto resourceAttr = attr.dyn_cast<DenseResourceElementsAttr>()) {
+    os << "dense_resource<";
+    printResourceHandle(resourceAttr.getRawHandle());
+    os << ">";
   } else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
     printLocation(locAttr);
   } else {
index 021da17..ec19881 100644 (file)
@@ -11,6 +11,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/Operation.h"
@@ -36,11 +37,10 @@ using namespace mlir::detail;
 //===----------------------------------------------------------------------===//
 
 void BuiltinDialect::registerAttributes() {
-  addAttributes<AffineMapAttr, ArrayAttr, DenseArrayBaseAttr,
-                DenseIntOrFPElementsAttr, DenseStringElementsAttr,
-                DictionaryAttr, FloatAttr, SymbolRefAttr, IntegerAttr,
-                IntegerSetAttr, OpaqueAttr, OpaqueElementsAttr,
-                SparseElementsAttr, StringAttr, TypeAttr, UnitAttr>();
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/IR/BuiltinAttributes.cpp.inc"
+      >();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1577,6 +1577,130 @@ bool DenseIntElementsAttr::classof(Attribute attr) {
 }
 
 //===----------------------------------------------------------------------===//
+// DenseResourceElementsAttr
+//===----------------------------------------------------------------------===//
+
+DenseResourceElementsAttr
+DenseResourceElementsAttr::get(ShapedType type,
+                               DenseResourceElementsHandle handle) {
+  return Base::get(type.getContext(), type, handle);
+}
+
+DenseResourceElementsAttr DenseResourceElementsAttr::get(ShapedType type,
+                                                         StringRef blobName,
+                                                         AsmResourceBlob blob) {
+  // Extract the builtin dialect resource manager from context and construct a
+  // handle by inserting a new resource using the provided blob.
+  auto &manager =
+      DenseResourceElementsHandle::getManagerInterface(type.getContext());
+  return get(type, manager.insert(blobName, std::move(blob)));
+}
+
+//===----------------------------------------------------------------------===//
+// DenseResourceElementsAttrBase
+
+namespace {
+/// Instantiations of this class provide utilities for interacting with native
+/// data types in the context of DenseResourceElementsAttr.
+template <typename T>
+struct DenseResourceAttrUtil;
+template <size_t width, bool isSigned>
+struct DenseResourceElementsAttrIntUtil {
+  static bool checkElementType(Type eltType) {
+    IntegerType type = eltType.dyn_cast<IntegerType>();
+    if (!type || type.getWidth() != width)
+      return false;
+    return isSigned ? !type.isUnsigned() : !type.isSigned();
+  }
+};
+template <>
+struct DenseResourceAttrUtil<bool> {
+  static bool checkElementType(Type eltType) {
+    return eltType.isSignlessInteger(1);
+  }
+};
+template <>
+struct DenseResourceAttrUtil<int8_t>
+    : public DenseResourceElementsAttrIntUtil<8, true> {};
+template <>
+struct DenseResourceAttrUtil<uint8_t>
+    : public DenseResourceElementsAttrIntUtil<8, false> {};
+template <>
+struct DenseResourceAttrUtil<int16_t>
+    : public DenseResourceElementsAttrIntUtil<16, true> {};
+template <>
+struct DenseResourceAttrUtil<uint16_t>
+    : public DenseResourceElementsAttrIntUtil<16, false> {};
+template <>
+struct DenseResourceAttrUtil<int32_t>
+    : public DenseResourceElementsAttrIntUtil<32, true> {};
+template <>
+struct DenseResourceAttrUtil<uint32_t>
+    : public DenseResourceElementsAttrIntUtil<32, false> {};
+template <>
+struct DenseResourceAttrUtil<int64_t>
+    : public DenseResourceElementsAttrIntUtil<64, true> {};
+template <>
+struct DenseResourceAttrUtil<uint64_t>
+    : public DenseResourceElementsAttrIntUtil<64, false> {};
+template <>
+struct DenseResourceAttrUtil<float> {
+  static bool checkElementType(Type eltType) { return eltType.isF32(); }
+};
+template <>
+struct DenseResourceAttrUtil<double> {
+  static bool checkElementType(Type eltType) { return eltType.isF64(); }
+};
+} // namespace
+
+template <typename T>
+DenseResourceElementsAttrBase<T>
+DenseResourceElementsAttrBase<T>::get(ShapedType type, StringRef blobName,
+                                      AsmResourceBlob blob) {
+  // Check that the blob is in the form we were expecting.
+  assert(blob.getDataAlignment() == alignof(T) &&
+         "alignment mismatch between expected alignment and blob alignment");
+  assert(((blob.getData().size() % sizeof(T)) == 0) &&
+         "size mismatch between expected element width and blob size");
+  assert(DenseResourceAttrUtil<T>::checkElementType(type.getElementType()) &&
+         "invalid shape element type for provided type `T`");
+  return DenseResourceElementsAttr::get(type, blobName, std::move(blob))
+      .template cast<DenseResourceElementsAttrBase<T>>();
+}
+
+template <typename T>
+Optional<ArrayRef<T>>
+DenseResourceElementsAttrBase<T>::tryGetAsArrayRef() const {
+  if (AsmResourceBlob *blob = this->getRawHandle().getBlob())
+    return blob->template getDataAs<T>();
+  return llvm::None;
+}
+
+template <typename T>
+bool DenseResourceElementsAttrBase<T>::classof(Attribute attr) {
+  auto resourceAttr = attr.dyn_cast<DenseResourceElementsAttr>();
+  return resourceAttr && DenseResourceAttrUtil<T>::checkElementType(
+                             resourceAttr.getElementType());
+}
+
+namespace mlir {
+namespace detail {
+// Explicit instantiation for all the supported DenseResourceElementsAttr.
+template class DenseResourceElementsAttrBase<bool>;
+template class DenseResourceElementsAttrBase<int8_t>;
+template class DenseResourceElementsAttrBase<int16_t>;
+template class DenseResourceElementsAttrBase<int32_t>;
+template class DenseResourceElementsAttrBase<int64_t>;
+template class DenseResourceElementsAttrBase<uint8_t>;
+template class DenseResourceElementsAttrBase<uint16_t>;
+template class DenseResourceElementsAttrBase<uint32_t>;
+template class DenseResourceElementsAttrBase<uint64_t>;
+template class DenseResourceElementsAttrBase<float>;
+template class DenseResourceElementsAttrBase<double>;
+} // namespace detail
+} // namespace mlir
+
+//===----------------------------------------------------------------------===//
 // OpaqueElementsAttr
 //===----------------------------------------------------------------------===//
 
index 662bcd8..7df22a9 100644 (file)
@@ -16,6 +16,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeRange.h"
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
-// Builtin Dialect
+// TableGen'erated dialect
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/BuiltinDialect.cpp.inc"
 
+//===----------------------------------------------------------------------===//
+// BuiltinBlobManagerInterface
+//===----------------------------------------------------------------------===//
+
+using BuiltinBlobManagerInterface =
+    ResourceBlobManagerDialectInterfaceBase<DenseResourceElementsHandle>;
+
+//===----------------------------------------------------------------------===//
+// BuiltinOpAsmDialectInterface
+//===----------------------------------------------------------------------===//
+
 namespace {
 struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface {
-  using OpAsmDialectInterface::OpAsmDialectInterface;
+  BuiltinOpAsmDialectInterface(Dialect *dialect,
+                               BuiltinBlobManagerInterface &mgr)
+      : OpAsmDialectInterface(dialect), blobManager(mgr) {}
 
   AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
     if (attr.isa<AffineMapAttr>()) {
@@ -57,6 +71,38 @@ struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface {
     }
     return AliasResult::NoAlias;
   }
+
+  //===------------------------------------------------------------------===//
+  // Resources
+  //===------------------------------------------------------------------===//
+
+  std::string
+  getResourceKey(const AsmDialectResourceHandle &handle) const override {
+    return cast<DenseResourceElementsHandle>(handle).getKey().str();
+  }
+  FailureOr<AsmDialectResourceHandle>
+  declareResource(StringRef key) const final {
+    return blobManager.insert(key);
+  }
+  LogicalResult parseResource(AsmParsedResourceEntry &entry) const final {
+    FailureOr<AsmResourceBlob> blob = entry.parseAsBlob();
+    if (failed(blob))
+      return failure();
+
+    // Update the blob for this entry.
+    blobManager.update(entry.getKey(), std::move(*blob));
+    return success();
+  }
+  void
+  buildResources(Operation *op,
+                 const SetVector<AsmDialectResourceHandle> &referencedResources,
+                 AsmResourceBuilder &provider) const final {
+    blobManager.buildResources(provider, referencedResources.getArrayRef());
+  }
+
+private:
+  /// The blob manager for the dialect.
+  BuiltinBlobManagerInterface &blobManager;
 };
 } // namespace
 
@@ -68,7 +114,9 @@ void BuiltinDialect::initialize() {
 #define GET_OP_LIST
 #include "mlir/IR/BuiltinOps.cpp.inc"
       >();
-  addInterfaces<BuiltinOpAsmDialectInterface>();
+
+  auto &blobInterface = addInterface<BuiltinBlobManagerInterface>();
+  addInterface<BuiltinOpAsmDialectInterface>(blobInterface);
 }
 
 //===----------------------------------------------------------------------===//
index dbfe9c1..60a2fb2 100644 (file)
@@ -57,7 +57,7 @@ auto DialectResourceBlobManager::insert(StringRef name,
     Twine(nameCounter++).toVector(nameStorage);
 
     // Try inserting with the new name.
-    if (BlobEntry *entry = tryInsertion(name))
+    if (BlobEntry *entry = tryInsertion(nameStorage))
       return *entry;
     nameStorage.resize(name.size() + 1);
   } while (true);
index e338876..d0de010 100644 (file)
@@ -712,8 +712,9 @@ public:
 
   /// Signal a completion for an attribute.
   void completeAttribute(const llvm::StringMap<Attribute> &aliases) override {
-    appendSimpleCompletions({"affine_set", "affine_map", "dense", "false",
-                             "loc", "opaque", "sparse", "true", "unit"},
+    appendSimpleCompletions({"affine_set", "affine_map", "dense",
+                             "dense_resource", "false", "loc", "opaque",
+                             "sparse", "true", "unit"},
                             lsp::CompletionItemKind::Field,
                             /*sortText=*/"1");
 
diff --git a/mlir/test/IR/dense-resource-elements-attr.mlir b/mlir/test/IR/dense-resource-elements-attr.mlir
new file mode 100644 (file)
index 0000000..adba979
--- /dev/null
@@ -0,0 +1,13 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s -verify-diagnostics -split-input-file | FileCheck %s
+
+// CHECK: attr = dense_resource<blob1> : tensor<3xi64>
+"test.user_op"() {attr = dense_resource<blob1> : tensor<3xi64> } : () -> ()
+
+{-#
+  dialect_resources: {
+    builtin: {
+      // CHECK: blob1: "0x08000000010000000000000002000000000000000300000000000000"
+      blob1: "0x08000000010000000000000002000000000000000300000000000000"
+    }
+  }
+#-}
index f6df53f..d996c22 100644 (file)
@@ -519,3 +519,23 @@ func.func @duplicate_dictionary_attr_key() {
 "J// -----
 
 "       // expected-error {{expected}}
+
+// -----
+
+// expected-error@+1 {{expected '<' after 'dense_resource'}}
+#attr = dense_resource>
+
+// -----
+
+// expected-error@+1 {{expected '>'}}
+#attr = dense_resource<resource
+
+// -----
+
+// expected-error@+1 {{expected ':'}}
+#attr = dense_resource<resource>
+
+// -----
+
+// expected-error@+1 {{`dense_resource` expected a shaped type}}
+#attr = dense_resource<resource> : i32
index 42f7b8e..352cf19 100644 (file)
 
 // -----
 
-// expected-error@+4 {{unknown 'resource' key 'unknown_entry' for dialect 'builtin'}}
+// expected-error@+4 {{unknown 'resource' key 'unknown_entry' for dialect 'ml_program'}}
 {-#
   dialect_resources: {
-    builtin: {
+    ml_program: {
       unknown_entry: "foo"
     }
   }
index 82a1bcd..5611eac 100644 (file)
@@ -6,6 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "gtest/gtest.h"
 using namespace mlir;
 using namespace mlir::detail;
 
+//===----------------------------------------------------------------------===//
+// DenseElementsAttr
+//===----------------------------------------------------------------------===//
+
 template <typename EltTy>
 static void testSplat(Type eltType, const EltTy &splatElt) {
   RankedTensorType shape = RankedTensorType::get({2, 1}, eltType);
@@ -203,7 +209,119 @@ TEST(DenseScalarTest, ExtractZeroRankElement) {
   auto attr = DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}));
   EXPECT_TRUE(attr.getValues<Attribute>()[0] == value);
 }
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// DenseResourceElementsAttr
+//===----------------------------------------------------------------------===//
+
+template <typename AttrT, typename T>
+static void checkNativeAccess(MLIRContext *ctx, ArrayRef<T> data,
+                              Type elementType) {
+  auto type = RankedTensorType::get(data.size(), elementType);
+  auto attr =
+      AttrT::get(type, "resource", UnmanagedAsmResourceBlob::allocate(data));
+
+  // Check that we can access and iterate the data properly.
+  Optional<ArrayRef<T>> attrData = attr.tryGetAsArrayRef();
+  EXPECT_TRUE(attrData.hasValue());
+  EXPECT_EQ(*attrData, data);
+
+  // Check that we cast to this attribute when possible.
+  Attribute genericAttr = attr;
+  EXPECT_TRUE(genericAttr.template isa<AttrT>());
+}
+template <typename AttrT, typename T>
+static void checkNativeIntAccess(Builder &builder, size_t intWidth) {
+  T data[] = {0, 1, 2};
+  checkNativeAccess<AttrT, T>(builder.getContext(), llvm::makeArrayRef(data),
+                              builder.getIntegerType(intWidth));
+}
+
+namespace {
+TEST(DenseResourceElementsAttrTest, CheckNativeAccess) {
+  MLIRContext context;
+  Builder builder(&context);
+
+  // Bool
+  bool boolData[] = {true, false, true};
+  checkNativeAccess<DenseBoolResourceElementsAttr>(
+      &context, llvm::makeArrayRef(boolData), builder.getI1Type());
+
+  // Unsigned integers
+  checkNativeIntAccess<DenseUI8ResourceElementsAttr, uint8_t>(builder, 8);
+  checkNativeIntAccess<DenseUI16ResourceElementsAttr, uint16_t>(builder, 16);
+  checkNativeIntAccess<DenseUI32ResourceElementsAttr, uint32_t>(builder, 32);
+  checkNativeIntAccess<DenseUI64ResourceElementsAttr, uint64_t>(builder, 64);
+
+  // Signed integers
+  checkNativeIntAccess<DenseI8ResourceElementsAttr, int8_t>(builder, 8);
+  checkNativeIntAccess<DenseI16ResourceElementsAttr, int16_t>(builder, 16);
+  checkNativeIntAccess<DenseI32ResourceElementsAttr, int32_t>(builder, 32);
+  checkNativeIntAccess<DenseI64ResourceElementsAttr, int64_t>(builder, 64);
+
+  // Float
+  float floatData[] = {0, 1, 2};
+  checkNativeAccess<DenseF32ResourceElementsAttr>(
+      &context, llvm::makeArrayRef(floatData), builder.getF32Type());
+
+  // Double
+  double doubleData[] = {0, 1, 2};
+  checkNativeAccess<DenseF64ResourceElementsAttr>(
+      &context, llvm::makeArrayRef(doubleData), builder.getF64Type());
+}
+
+TEST(DenseResourceElementsAttrTest, CheckNoCast) {
+  MLIRContext context;
+  Builder builder(&context);
+
+  // Create a i32 attribute.
+  ArrayRef<uint32_t> data;
+  auto type = RankedTensorType::get(data.size(), builder.getI32Type());
+  Attribute i32ResourceAttr = DenseI32ResourceElementsAttr::get(
+      type, "resource", UnmanagedAsmResourceBlob::allocate(data));
+
+  EXPECT_TRUE(i32ResourceAttr.isa<DenseI32ResourceElementsAttr>());
+  EXPECT_FALSE(i32ResourceAttr.isa<DenseF32ResourceElementsAttr>());
+  EXPECT_FALSE(i32ResourceAttr.isa<DenseBoolResourceElementsAttr>());
+}
 
+TEST(DenseResourceElementsAttrTest, CheckInvalidData) {
+  MLIRContext context;
+  Builder builder(&context);
+
+  // Create a bool attribute with data of the incorrect type.
+  ArrayRef<uint32_t> data;
+  auto type = RankedTensorType::get(data.size(), builder.getI32Type());
+  ASSERT_DEATH(
+      {
+        DenseBoolResourceElementsAttr::get(
+            type, "resource", UnmanagedAsmResourceBlob::allocate(data));
+      },
+      "alignment mismatch between expected alignment and blob alignment");
+}
+
+TEST(DenseResourceElementsAttrTest, CheckInvalidType) {
+  MLIRContext context;
+  Builder builder(&context);
+
+  // Create a bool attribute with incorrect type.
+  ArrayRef<bool> data;
+  auto type = RankedTensorType::get(data.size(), builder.getI32Type());
+  ASSERT_DEATH(
+      {
+        DenseBoolResourceElementsAttr::get(
+            type, "resource", UnmanagedAsmResourceBlob::allocate(data));
+      },
+      "invalid shape element type for provided type `T`");
+}
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// SparseElementsAttr
+//===----------------------------------------------------------------------===//
+
+namespace {
 TEST(SparseElementsAttrTest, GetZero) {
   MLIRContext context;
   context.allowUnregisteredDialects();