Remove the explicit attribute kinds for DenseIntElementsAttr and DenseFPElementsAttr...
authorRiver Riddle <riverriddle@google.com>
Thu, 6 Jun 2019 23:15:42 +0000 (16:15 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 9 Jun 2019 23:22:05 +0000 (16:22 -0700)
PiperOrigin-RevId: 251948820

mlir/include/mlir/IR/Attributes.h
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Attributes.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/unittests/Dialect/QuantOps/QuantizationUtilsTest.cpp

index cedd181..9adcd6c 100644 (file)
@@ -154,8 +154,7 @@ enum Kind {
   Function,
 
   SplatElements,
-  DenseIntElements,
-  DenseFPElements,
+  DenseElements,
   OpaqueElements,
   SparseElements,
   FIRST_ELEMENTS_ATTR = SplatElements,
@@ -497,10 +496,11 @@ public:
 
 /// An attribute that represents a reference to a dense vector or tensor object.
 ///
-class DenseElementsAttr : public ElementsAttr {
+class DenseElementsAttr
+    : public Attribute::AttrBase<DenseElementsAttr, ElementsAttr,
+                                 detail::DenseElementsAttributeStorage> {
 public:
-  using ElementsAttr::ElementsAttr;
-  using ImplType = detail::DenseElementsAttributeStorage;
+  using Base::Base;
 
   /// It assumes the elements in the input array have been truncated to the bits
   /// width specified by the element type. 'type' must be a vector or tensor
@@ -547,8 +547,7 @@ public:
 
   /// Method for support type inquiry through isa, cast and dyn_cast.
   static bool classof(Attribute attr) {
-    return attr.getKind() == StandardAttributes::DenseIntElements ||
-           attr.getKind() == StandardAttributes::DenseFPElements;
+    return attr.getKind() == StandardAttributes::DenseElements;
   }
 
 protected:
@@ -609,15 +608,13 @@ protected:
 
 /// An attribute that represents a reference to a dense integer vector or tensor
 /// object.
-class DenseIntElementsAttr
-    : public Attribute::AttrBase<DenseIntElementsAttr, DenseElementsAttr,
-                                 detail::DenseElementsAttributeStorage> {
+class DenseIntElementsAttr : public DenseElementsAttr {
 public:
   /// DenseIntElementsAttr iterates on APInt, so we can use the raw element
   /// iterator directly.
   using iterator = DenseElementsAttr::RawElementIterator;
 
-  using Base::Base;
+  using DenseElementsAttr::DenseElementsAttr;
   using DenseElementsAttr::get;
   using DenseElementsAttr::getValues;
 
@@ -645,17 +642,13 @@ public:
   iterator begin() const { return raw_begin(); }
   iterator end() const { return raw_end(); }
 
-  /// Method for support type inquiry through isa, cast and dyn_cast.
-  static bool kindof(unsigned kind) {
-    return kind == StandardAttributes::DenseIntElements;
-  }
+  /// Method for supporting type inquiry through isa, cast and dyn_cast.
+  static bool classof(Attribute attr);
 };
 
 /// An attribute that represents a reference to a dense float vector or tensor
 /// object. Each element is stored as a double.
-class DenseFPElementsAttr
-    : public Attribute::AttrBase<DenseFPElementsAttr, DenseElementsAttr,
-                                 detail::DenseElementsAttributeStorage> {
+class DenseFPElementsAttr : public DenseElementsAttr {
 public:
   /// DenseFPElementsAttr iterates on APFloat, so we need to wrap the raw
   /// element iterator.
@@ -669,7 +662,7 @@ public:
   };
   using iterator = ElementIterator;
 
-  using Base::Base;
+  using DenseElementsAttr::DenseElementsAttr;
   using DenseElementsAttr::get;
   using DenseElementsAttr::getValues;
 
@@ -692,10 +685,8 @@ public:
   iterator begin() const;
   iterator end() const;
 
-  /// Method for support type inquiry through isa, cast and dyn_cast.
-  static bool kindof(unsigned kind) {
-    return kind == StandardAttributes::DenseFPElements;
-  }
+  /// Method for supporting type inquiry through isa, cast and dyn_cast.
+  static bool classof(Attribute attr);
 };
 
 /// An opaque attribute that represents a reference to a vector or tensor
index 239d58d..145429d 100644 (file)
@@ -692,8 +692,7 @@ void ModulePrinter::printAttributeOptionalType(Attribute attr,
     os << ", " << '"' << "0x" << llvm::toHex(eltsAttr.getValue()) << '"' << '>';
     break;
   }
-  case StandardAttributes::DenseIntElements:
-  case StandardAttributes::DenseFPElements: {
+  case StandardAttributes::DenseElements: {
     auto eltsAttr = attr.cast<DenseElementsAttr>();
     os << "dense<";
     printType(eltsAttr.getType());
index 0e2b5a1..b5c965d 100644 (file)
@@ -367,8 +367,7 @@ Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
   switch (getKind()) {
   case StandardAttributes::SplatElements:
     return cast<SplatElementsAttr>().getValue();
-  case StandardAttributes::DenseFPElements:
-  case StandardAttributes::DenseIntElements:
+  case StandardAttributes::DenseElements:
     return cast<DenseElementsAttr>().getValue(index);
   case StandardAttributes::OpaqueElements:
     return cast<OpaqueElementsAttr>().getValue(index);
@@ -383,8 +382,7 @@ ElementsAttr ElementsAttr::mapValues(
     Type newElementType,
     llvm::function_ref<APInt(const APInt &)> mapping) const {
   switch (getKind()) {
-  case StandardAttributes::DenseIntElements:
-  case StandardAttributes::DenseFPElements:
+  case StandardAttributes::DenseElements:
     return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
   case StandardAttributes::SplatElements:
     return cast<SplatElementsAttr>().mapValues(newElementType, mapping);
@@ -397,8 +395,7 @@ ElementsAttr ElementsAttr::mapValues(
     Type newElementType,
     llvm::function_ref<APInt(const APFloat &)> mapping) const {
   switch (getKind()) {
-  case StandardAttributes::DenseIntElements:
-  case StandardAttributes::DenseFPElements:
+  case StandardAttributes::DenseElements:
     return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
   case StandardAttributes::SplatElements:
     return cast<SplatElementsAttr>().mapValues(newElementType, mapping);
@@ -542,19 +539,8 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef<char> data) {
   assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
          "type must be ranked tensor or vector");
   assert(type.hasStaticShape() && "type must have static shape");
-  switch (type.getElementType().getKind()) {
-  case StandardTypes::BF16:
-  case StandardTypes::F16:
-  case StandardTypes::F32:
-  case StandardTypes::F64:
-    return AttributeUniquer::get<DenseFPElementsAttr>(
-        type.getContext(), StandardAttributes::DenseFPElements, type, data);
-  case StandardTypes::Integer:
-    return AttributeUniquer::get<DenseIntElementsAttr>(
-        type.getContext(), StandardAttributes::DenseIntElements, type, data);
-  default:
-    llvm_unreachable("unexpected element type");
-  }
+  return Base::get(type.getContext(), StandardAttributes::DenseElements, type,
+                   data);
 }
 
 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
@@ -631,22 +617,17 @@ Attribute DenseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
       readBits(getRawData().data(), valueIndex * storageBitWidth, bitWidth);
 
   // Convert the raw value data to an attribute value.
-  switch (getKind()) {
-  case StandardAttributes::DenseIntElements:
+  if (elementType.isa<IntegerType>())
     return IntegerAttr::get(elementType, rawValueData);
-  case StandardAttributes::DenseFPElements:
-    return FloatAttr::get(
-        elementType, APFloat(elementType.cast<FloatType>().getFloatSemantics(),
-                             rawValueData));
-  default:
-    llvm_unreachable("unexpected element type");
-  }
+  if (auto fType = elementType.dyn_cast<FloatType>())
+    return FloatAttr::get(elementType,
+                          APFloat(fType.getFloatSemantics(), rawValueData));
+  llvm_unreachable("unexpected element type");
 }
 
 void DenseElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
   auto elementType = getType().getElementType();
-  switch (getKind()) {
-  case StandardAttributes::DenseIntElements: {
+  if (elementType.isa<IntegerType>()) {
     // Get the raw APInt values.
     SmallVector<APInt, 8> intValues;
     cast<DenseIntElementsAttr>().getValues(intValues);
@@ -656,7 +637,7 @@ void DenseElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
       values.push_back(IntegerAttr::get(elementType, intVal));
     return;
   }
-  case StandardAttributes::DenseFPElements: {
+  if (elementType.isa<FloatType>()) {
     // Get the raw APFloat values.
     SmallVector<APFloat, 8> floatValues;
     cast<DenseFPElementsAttr>().getValues(floatValues);
@@ -666,9 +647,7 @@ void DenseElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
       values.push_back(FloatAttr::get(elementType, floatVal));
     return;
   }
-  default:
-    llvm_unreachable("unexpected element type");
-  }
+  llvm_unreachable("unexpected element type");
 }
 
 DenseElementsAttr DenseElementsAttr::mapValues(
@@ -810,6 +789,12 @@ DenseElementsAttr DenseIntElementsAttr::mapValues(
   return get(newArrayType, elementData);
 }
 
+/// Method for supporting type inquiry through isa, cast and dyn_cast.
+bool DenseIntElementsAttr::classof(Attribute attr) {
+  return attr.isa<DenseElementsAttr>() &&
+         attr.getType().cast<ShapedType>().getElementType().isa<IntegerType>();
+}
+
 //===----------------------------------------------------------------------===//
 // DenseFPElementsAttr
 //===----------------------------------------------------------------------===//
@@ -859,6 +844,12 @@ DenseFPElementsAttr::iterator DenseFPElementsAttr::end() const {
   return {elementSemantics, raw_end()};
 }
 
+/// Method for supporting type inquiry through isa, cast and dyn_cast.
+bool DenseFPElementsAttr::classof(Attribute attr) {
+  return attr.isa<DenseElementsAttr>() &&
+         attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
+}
+
 //===----------------------------------------------------------------------===//
 // OpaqueElementsAttr
 //===----------------------------------------------------------------------===//
index 375ba55..374185e 100644 (file)
@@ -135,9 +135,9 @@ namespace {
 /// the IR.
 struct BuiltinDialect : public Dialect {
   BuiltinDialect(MLIRContext *context) : Dialect(/*name=*/"", context) {
-    addAttributes<AffineMapAttr, ArrayAttr, BoolAttr, DenseIntElementsAttr,
-                  DenseFPElementsAttr, DictionaryAttr, FloatAttr, FunctionAttr,
-                  IntegerAttr, IntegerSetAttr, OpaqueAttr, OpaqueElementsAttr,
+    addAttributes<AffineMapAttr, ArrayAttr, BoolAttr, DenseElementsAttr,
+                  DictionaryAttr, FloatAttr, FunctionAttr, IntegerAttr,
+                  IntegerSetAttr, OpaqueAttr, OpaqueElementsAttr,
                   SparseElementsAttr, SplatElementsAttr, StringAttr, TypeAttr,
                   UnitAttr>();
     addTypes<ComplexType, FloatType, FunctionType, IndexType, IntegerType,
index 3c8642e..e0da9ed 100644 (file)
@@ -116,7 +116,7 @@ TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
   auto expectedTensorType = realValue.getType().cast<TensorType>();
   EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape());
   EXPECT_EQ(tensorType.getElementType(), convertedType);
-  EXPECT_EQ(returnedValue.getKind(), StandardAttributes::DenseIntElements);
+  EXPECT_TRUE(returnedValue.isa<DenseIntElementsAttr>());
 
   // Check Elements attribute element value is expected.
   auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0});