[mlir] Cleanup DenseArrayAttrBase definition and expose raw API
authorJeff Niu <jeff@modular.com>
Mon, 8 Aug 2022 22:04:05 +0000 (18:04 -0400)
committerJeff Niu <jeff@modular.com>
Tue, 9 Aug 2022 19:43:45 +0000 (15:43 -0400)
This patch cleans up the definition of `DenseArrayAttrBase` by relying
more on ODS-generated methods. It also exposes an API for using the raw
data of a dense array, similar to `DenseIntOrFPElementsAttr::getRaw`.

Reviewed By: lattner, mehdi_amini

Differential Revision: https://reviews.llvm.org/D131450

mlir/include/mlir/IR/BuiltinAttributes.td
mlir/lib/IR/BuiltinAttributes.cpp

index e710d8d..74bdc35 100644 (file)
@@ -143,6 +143,18 @@ def Builtin_ArrayAttr : Builtin_Attr<"Array", [
 // DenseArrayBaseAttr
 //===----------------------------------------------------------------------===//
 
+def Builtin_DenseArrayRawDataParameter : ArrayRefParameter<
+    "char", "64-bit aligned storage for dense array elements"> {
+  let allocator = [{
+    if (!$_self.empty()) {
+      auto *alloc = static_cast<char *>(
+          $_allocator.allocate($_self.size(), alignof(uint64_t)));
+      std::uninitialized_copy($_self.begin(), $_self.end(), alloc);
+      $_dst = ArrayRef<char>(alloc, $_self.size());
+    }
+  }];
+}
+
 def Builtin_DenseArrayBase : Builtin_Attr<
     "DenseArrayBase", [ElementsAttrInterface, TypedAttrInterface]> {
   let summary = "A dense array of i8, i16, i32, i64, f32, or f64.";
@@ -176,8 +188,8 @@ def Builtin_DenseArrayBase : Builtin_Attr<
     ```
   }];
   let parameters = (ins AttributeSelfTypeParameter<"", "ShapedType">:$type,
-                        "DenseArrayBaseAttr::EltType":$eltType,
-                        ArrayRefParameter<"char">:$elements);
+                        "DenseArrayBaseAttr::EltType":$elementType,
+                        Builtin_DenseArrayRawDataParameter:$rawData);
   let extraClassDeclaration = [{
     // All possible supported element type.
     enum class EltType { I1, I8, I16, I32, I64, F32, F64 };
@@ -198,22 +210,12 @@ def Builtin_DenseArrayBase : Builtin_Attr<
     const float *value_begin_impl(OverloadToken<float>) const;
     const double *value_begin_impl(OverloadToken<double>) const;
 
-    /// Returns the shaped type, containing the number of elements in the array
-    /// and the array element type.
-    ShapedType getType() const;
-    /// Returns the element type.
-    EltType getElementType() const;
-
     /// Printer for the short form: will dispatch to the appropriate subclass.
     void print(AsmPrinter &printer) const;
     void print(raw_ostream &os) const;
     /// Print the short form `42, 100, -1` without any braces or prefix.
     void printWithoutBraces(raw_ostream &os) const;
   }];
-  // Do not generate the storage class, we need to handle custom storage alignment.
-  let genStorageClass = 0;
-  let genAccessors = 0;
-  let skipDefaultBuilders = 1;
 }
 
 //===----------------------------------------------------------------------===//
index 334f20c..20c1b0a 100644 (file)
@@ -686,52 +686,6 @@ DenseElementsAttr::ComplexIntElementIterator::operator*() const {
 // DenseArrayAttr
 //===----------------------------------------------------------------------===//
 
-/// Custom storage to ensure proper memory alignment for the allocation of
-/// DenseArray of any element type.
-struct mlir::detail::DenseArrayBaseAttrStorage : public AttributeStorage {
-  using KeyTy =
-      std::tuple<ShapedType, DenseArrayBaseAttr::EltType, ArrayRef<char>>;
-  DenseArrayBaseAttrStorage(ShapedType type,
-                            DenseArrayBaseAttr::EltType eltType,
-                            ArrayRef<char> elements)
-      : type(type), eltType(eltType), elements(elements) {}
-
-  bool operator==(const KeyTy &key) const {
-    return (type == std::get<0>(key)) && (eltType == std::get<1>(key)) &&
-           (elements == std::get<2>(key));
-  }
-
-  static llvm::hash_code hashKey(const KeyTy &key) {
-    return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
-                              std::get<2>(key));
-  }
-
-  static DenseArrayBaseAttrStorage *
-  construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
-    auto type = std::get<0>(key);
-    auto eltType = std::get<1>(key);
-    auto elements = std::get<2>(key);
-    if (!elements.empty()) {
-      char *alloc = static_cast<char *>(
-          allocator.allocate(elements.size(), alignof(uint64_t)));
-      std::uninitialized_copy(elements.begin(), elements.end(), alloc);
-      elements = ArrayRef<char>(alloc, elements.size());
-    }
-    return new (allocator.allocate<DenseArrayBaseAttrStorage>())
-        DenseArrayBaseAttrStorage(type, eltType, elements);
-  }
-
-  ShapedType type;
-  DenseArrayBaseAttr::EltType eltType;
-  ArrayRef<char> elements;
-};
-
-DenseArrayBaseAttr::EltType DenseArrayBaseAttr::getElementType() const {
-  return getImpl()->eltType;
-}
-
-ShapedType DenseArrayBaseAttr::getType() const { return getImpl()->type; }
-
 const bool *DenseArrayBaseAttr::value_begin_impl(OverloadToken<bool>) const {
   return cast<DenseBoolArrayAttr>().asArrayRef().begin();
 }
@@ -880,7 +834,7 @@ Attribute DenseArrayAttr<T>::parse(AsmParser &parser, Type odsType) {
 /// Conversion from DenseArrayAttr<T> to ArrayRef<T>.
 template <typename T>
 DenseArrayAttr<T>::operator ArrayRef<T>() const {
-  ArrayRef<char> raw = getImpl()->elements;
+  ArrayRef<char> raw = getRawData();
   assert((raw.size() % sizeof(T)) == 0);
   return ArrayRef<T>(reinterpret_cast<const T *>(raw.data()),
                      raw.size() / sizeof(T));