Ensure that every Attribute contains a Type. If an Attribute does not provide...
authorRiver Riddle <riverriddle@google.com>
Tue, 30 Apr 2019 21:26:04 +0000 (14:26 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 6 May 2019 15:23:14 +0000 (08:23 -0700)
--

PiperOrigin-RevId: 246021088

mlir/include/mlir/IR/Attributes.h
mlir/lib/IR/AttributeDetail.h
mlir/lib/IR/Attributes.cpp
mlir/lib/StandardOps/Ops.cpp

index 5deb2bb..9d09c38 100644 (file)
@@ -47,7 +47,6 @@ struct AffineMapAttributeStorage;
 struct IntegerSetAttributeStorage;
 struct TypeAttributeStorage;
 struct FunctionAttributeStorage;
-struct ElementsAttributeStorage;
 struct SplatElementsAttributeStorage;
 struct DenseElementsAttributeStorage;
 struct DenseIntElementsAttributeStorage;
@@ -125,6 +124,9 @@ public:
   /// Return the classification for this attribute.
   Kind getKind() const;
 
+  /// Return the type of this attribute.
+  Type getType() const;
+
   /// Return true if this field is, or contains, a function attribute.
   bool isOrContainsFunction() const;
 
@@ -177,8 +179,6 @@ class NumericAttr : public Attribute {
 public:
   using Attribute::Attribute;
 
-  Type getType() const;
-
   static bool kindof(Kind kind);
 };
 
@@ -192,8 +192,6 @@ public:
 
   bool getValue() const;
 
-  Type getType() const;
-
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool kindof(Kind kind) { return kind == Kind::Bool; }
 };
@@ -211,8 +209,6 @@ public:
   // TODO(jpienaar): Change callers to use getValue instead.
   int64_t getInt() const;
 
-  Type getType() const;
-
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool kindof(Kind kind) { return kind == Kind::Integer; }
 };
@@ -238,8 +234,6 @@ public:
   /// precision.
   double getValueAsDouble() const;
 
-  Type getType() const;
-
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool kindof(Kind kind) { return kind == Kind::Float; }
 };
@@ -353,7 +347,6 @@ public:
 class ElementsAttr : public NumericAttr {
 public:
   using NumericAttr::NumericAttr;
-  using ImplType = detail::ElementsAttributeStorage;
 
   VectorOrTensorType getType() const;
 
index 89ac240..3db5ec8 100644 (file)
 
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/Function.h"
 #include "mlir/IR/Identifier.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Support/StorageUniquer.h"
 #include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/PointerIntPair.h"
 #include "llvm/Support/TrailingObjects.h"
 
 namespace mlir {
 namespace detail {
-/// Base storage class appearing in an attribute.
+/// Base storage class appearing in an attribute. Derived storage classes should
+/// only be constructed within the context of the AttributeUniquer.
 struct AttributeStorage : public StorageUniquer::BaseStorage {
-  AttributeStorage(bool isOrContainsFunctionCache = false)
-      : isOrContainsFunctionCache(isOrContainsFunctionCache) {}
-
-  /// This field is true if this is, or contains, a function attribute.
-  bool isOrContainsFunctionCache : 1;
+  /// Construct a new attribute storage instance with the given type and a
+  /// boolean that signals if the derived attribute is or contains a function
+  /// pointer.
+  /// Note: All attributes require a valid type. If a null type is provided
+  ///       here, the type of the attribute will automatically default to
+  ///       NoneType upon initialization in the uniquer.
+  AttributeStorage(Type type = {}, bool isOrContainsFunctionCache = false)
+      : typeAndContainsFunctionAttrPair(type, isOrContainsFunctionCache) {}
+  AttributeStorage(bool isOrContainsFunctionCache)
+      : AttributeStorage(/*type=*/{}, isOrContainsFunctionCache) {}
+
+  bool isOrContainsFunctionCache() const {
+    return typeAndContainsFunctionAttrPair.getInt();
+  }
+
+  Type getType() const { return typeAndContainsFunctionAttrPair.getPointer(); }
+  void setType(Type type) { typeAndContainsFunctionAttrPair.setPointer(type); }
+
+  /// This field is a pair of:
+  ///  - The type of the attribute value.
+  ///  - A boolean that is true if this is, or contains, a function attribute.
+  llvm::PointerIntPair<Type, 1, bool> typeAndContainsFunctionAttrPair;
 };
 
 // A utility class to get, or create, unique instances of attributes within an
@@ -54,7 +74,7 @@ public:
       !std::is_same<typename T::ImplType, AttributeStorage>::value, T>::type
   get(MLIRContext *ctx, Attribute::Kind kind, Args &&... args) {
     return ctx->getAttributeUniquer().getComplex<typename T::ImplType>(
-        /*initFn=*/{}, static_cast<unsigned>(kind),
+        getInitFn(ctx), static_cast<unsigned>(kind),
         std::forward<Args>(args)...);
   }
 
@@ -66,7 +86,7 @@ public:
       std::is_same<typename T::ImplType, AttributeStorage>::value, T>::type
   get(MLIRContext *ctx, Attribute::Kind kind) {
     return ctx->getAttributeUniquer().getSimple<AttributeStorage>(
-        /*initFn=*/{}, static_cast<unsigned>(kind));
+        getInitFn(ctx), static_cast<unsigned>(kind));
   }
 
   /// Erase a uniqued instance of attribute T. This overload is used for
@@ -78,6 +98,15 @@ public:
     return ctx->getAttributeUniquer().eraseComplex<typename T::ImplType>(
         static_cast<unsigned>(kind), std::forward<Args>(args)...);
   }
+
+  /// Generate a functor to initialize a new attribute storage instance.
+  static std::function<void(AttributeStorage *)> getInitFn(MLIRContext *ctx) {
+    return [ctx](AttributeStorage *storage) {
+      // If the attribute did not provide a type, then default to NoneType.
+      if (!storage->getType())
+        storage->setType(NoneType::get(ctx));
+    };
+  }
 };
 
 using AttributeStorageAllocator = StorageUniquer::StorageAllocator;
@@ -86,7 +115,8 @@ using AttributeStorageAllocator = StorageUniquer::StorageAllocator;
 struct BoolAttributeStorage : public AttributeStorage {
   using KeyTy = std::pair<MLIRContext *, bool>;
 
-  BoolAttributeStorage(Type type, bool value) : type(type), value(value) {}
+  BoolAttributeStorage(Type type, bool value)
+      : AttributeStorage(type), value(value) {}
 
   /// We only check equality for and hash with the boolean key parameter.
   bool operator==(const KeyTy &key) const { return key.second == value; }
@@ -100,7 +130,6 @@ struct BoolAttributeStorage : public AttributeStorage {
         BoolAttributeStorage(IntegerType::get(1, key.first), key.second);
   }
 
-  Type type;
   bool value;
 };
 
@@ -111,13 +140,13 @@ struct IntegerAttributeStorage final
   using KeyTy = std::pair<Type, APInt>;
 
   IntegerAttributeStorage(Type type, size_t numObjects)
-      : type(type), numObjects(numObjects) {
+      : AttributeStorage(type), numObjects(numObjects) {
     assert((type.isIndex() || type.isa<IntegerType>()) && "invalid type");
   }
 
   /// Key equality and hash functions.
   bool operator==(const KeyTy &key) const {
-    return key == KeyTy(type, getValue());
+    return key == KeyTy(getType(), getValue());
   }
   static unsigned hashKey(const KeyTy &key) {
     return llvm::hash_combine(key.first, llvm::hash_value(key.second));
@@ -142,13 +171,12 @@ struct IntegerAttributeStorage final
 
   /// Returns an APInt representing the stored value.
   APInt getValue() const {
-    if (type.isIndex())
+    if (getType().isIndex())
       return APInt(64, {getTrailingObjects<uint64_t>(), numObjects});
-    return APInt(type.getIntOrFloatBitWidth(),
+    return APInt(getType().getIntOrFloatBitWidth(),
                  {getTrailingObjects<uint64_t>(), numObjects});
   }
 
-  Type type;
   size_t numObjects;
 };
 
@@ -160,12 +188,11 @@ struct FloatAttributeStorage final
 
   FloatAttributeStorage(const llvm::fltSemantics &semantics, Type type,
                         size_t numObjects)
-      : semantics(semantics), type(type.cast<FloatType>()),
-        numObjects(numObjects) {}
+      : AttributeStorage(type), semantics(semantics), numObjects(numObjects) {}
 
   /// Key equality and hash functions.
   bool operator==(const KeyTy &key) const {
-    return key.first == type && key.second.bitwiseIsEqual(getValue());
+    return key.first == getType() && key.second.bitwiseIsEqual(getValue());
   }
   static unsigned hashKey(const KeyTy &key) {
     return llvm::hash_combine(key.first, llvm::hash_value(key.second));
@@ -197,7 +224,6 @@ struct FloatAttributeStorage final
   }
 
   const llvm::fltSemantics &semantics;
-  FloatType type;
   size_t numObjects;
 };
 
@@ -307,7 +333,8 @@ struct FunctionAttributeStorage : public AttributeStorage {
   using KeyTy = Function *;
 
   FunctionAttributeStorage(Function *value)
-      : AttributeStorage(/*isOrContainsFunctionCache=*/true), value(value) {}
+      : AttributeStorage(value->getType(), /*isOrContainsFunctionCache=*/true),
+        value(value) {}
 
   /// Key equality function.
   bool operator==(const KeyTy &key) const { return key == value; }
@@ -329,26 +356,17 @@ struct FunctionAttributeStorage : public AttributeStorage {
   Function *value;
 };
 
-/// A base attribute representing a reference to a vector or tensor constant.
-struct ElementsAttributeStorage : public AttributeStorage {
-  ElementsAttributeStorage(VectorOrTensorType type) : type(type) {}
-  VectorOrTensorType type;
-};
-
 /// An attribute representing a reference to a vector or tensor constant,
 /// inwhich all elements have the same value.
-struct SplatElementsAttributeStorage : public ElementsAttributeStorage {
-  using KeyTy = std::pair<VectorOrTensorType, Attribute>;
+struct SplatElementsAttributeStorage : public AttributeStorage {
+  using KeyTy = std::pair<Type, Attribute>;
 
-  SplatElementsAttributeStorage(VectorOrTensorType type, Attribute elt)
-      : ElementsAttributeStorage(type), elt(elt) {}
+  SplatElementsAttributeStorage(Type type, Attribute elt)
+      : AttributeStorage(type), elt(elt) {}
 
   /// Key equality and hash functions.
   bool operator==(const KeyTy &key) const {
-    return key == std::make_pair(type, elt);
-  }
-  static unsigned hashKey(const KeyTy &key) {
-    return llvm::hash_combine(key.first, key.second);
+    return key == std::make_pair(getType(), elt);
   }
 
   /// Construct a new storage instance.
@@ -362,16 +380,15 @@ struct SplatElementsAttributeStorage : public ElementsAttributeStorage {
 };
 
 /// An attribute representing a reference to a dense vector or tensor object.
-struct DenseElementsAttributeStorage : public ElementsAttributeStorage {
-  using KeyTy = std::pair<VectorOrTensorType, ArrayRef<char>>;
+struct DenseElementsAttributeStorage : public AttributeStorage {
+  using KeyTy = std::pair<Type, ArrayRef<char>>;
 
-  DenseElementsAttributeStorage(VectorOrTensorType ty, ArrayRef<char> data)
-      : ElementsAttributeStorage(ty), data(data) {}
+  DenseElementsAttributeStorage(Type ty, ArrayRef<char> data)
+      : AttributeStorage(ty), data(data) {}
 
   /// Key equality and hash functions.
-  bool operator==(const KeyTy &key) const { return key == KeyTy(type, data); }
-  static unsigned hashKey(const KeyTy &key) {
-    return llvm::hash_combine(key.first, key.second);
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(getType(), data);
   }
 
   /// Construct a new storage instance.
@@ -398,16 +415,15 @@ struct DenseElementsAttributeStorage : public ElementsAttributeStorage {
 
 /// An attribute representing a reference to a tensor constant with opaque
 /// content.
-struct OpaqueElementsAttributeStorage : public ElementsAttributeStorage {
-  using KeyTy = std::tuple<VectorOrTensorType, Dialect *, StringRef>;
+struct OpaqueElementsAttributeStorage : public AttributeStorage {
+  using KeyTy = std::tuple<Type, Dialect *, StringRef>;
 
-  OpaqueElementsAttributeStorage(VectorOrTensorType type, Dialect *dialect,
-                                 StringRef bytes)
-      : ElementsAttributeStorage(type), dialect(dialect), bytes(bytes) {}
+  OpaqueElementsAttributeStorage(Type type, Dialect *dialect, StringRef bytes)
+      : AttributeStorage(type), dialect(dialect), bytes(bytes) {}
 
   /// Key equality and hash functions.
   bool operator==(const KeyTy &key) const {
-    return key == std::make_tuple(type, dialect, bytes);
+    return key == std::make_tuple(getType(), dialect, bytes);
   }
   static unsigned hashKey(const KeyTy &key) {
     return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
@@ -429,18 +445,16 @@ struct OpaqueElementsAttributeStorage : public ElementsAttributeStorage {
 };
 
 /// An attribute representing a reference to a sparse vector or tensor object.
-struct SparseElementsAttributeStorage : public ElementsAttributeStorage {
-  using KeyTy =
-      std::tuple<VectorOrTensorType, DenseIntElementsAttr, DenseElementsAttr>;
+struct SparseElementsAttributeStorage : public AttributeStorage {
+  using KeyTy = std::tuple<Type, DenseIntElementsAttr, DenseElementsAttr>;
 
-  SparseElementsAttributeStorage(VectorOrTensorType type,
-                                 DenseIntElementsAttr indices,
+  SparseElementsAttributeStorage(Type type, DenseIntElementsAttr indices,
                                  DenseElementsAttr values)
-      : ElementsAttributeStorage(type), indices(indices), values(values) {}
+      : AttributeStorage(type), indices(indices), values(values) {}
 
   /// Key equality and hash functions.
   bool operator==(const KeyTy &key) const {
-    return key == std::make_tuple(type, indices, values);
+    return key == std::make_tuple(getType(), indices, values);
   }
   static unsigned hashKey(const KeyTy &key) {
     return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
index 4df46df..aa1127c 100644 (file)
@@ -30,8 +30,11 @@ Attribute::Kind Attribute::getKind() const {
   return static_cast<Kind>(attr->getKind());
 }
 
+/// Return the type of this attribute.
+Type Attribute::getType() const { return attr->getType(); }
+
 bool Attribute::isOrContainsFunction() const {
-  return attr->isOrContainsFunctionCache;
+  return attr->isOrContainsFunctionCache();
 }
 
 // Given an attribute that could refer to a function attribute in the remapping
@@ -79,19 +82,6 @@ UnitAttr UnitAttr::get(MLIRContext *context) {
 // NumericAttr
 //===----------------------------------------------------------------------===//
 
-Type NumericAttr::getType() const {
-  if (auto boolAttr = dyn_cast<BoolAttr>())
-    return boolAttr.getType();
-  if (auto intAttr = dyn_cast<IntegerAttr>())
-    return intAttr.getType();
-  if (auto floatAttr = dyn_cast<FloatAttr>())
-    return floatAttr.getType();
-  if (auto elemAttr = dyn_cast<ElementsAttr>())
-    return elemAttr.getType();
-
-  llvm_unreachable("unhandled NumericAttr subclass");
-}
-
 bool NumericAttr::kindof(Kind kind) {
   return BoolAttr::kindof(kind) || IntegerAttr::kindof(kind) ||
          FloatAttr::kindof(kind) || ElementsAttr::kindof(kind);
@@ -109,8 +99,6 @@ BoolAttr BoolAttr::get(bool value, MLIRContext *context) {
 
 bool BoolAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
 
-Type BoolAttr::getType() const { return static_cast<ImplType *>(attr)->type; }
-
 //===----------------------------------------------------------------------===//
 // IntegerAttr
 //===----------------------------------------------------------------------===//
@@ -135,10 +123,6 @@ APInt IntegerAttr::getValue() const {
 
 int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); }
 
-Type IntegerAttr::getType() const {
-  return static_cast<ImplType *>(attr)->type;
-}
-
 //===----------------------------------------------------------------------===//
 // FloatAttr
 //===----------------------------------------------------------------------===//
@@ -185,8 +169,6 @@ APFloat FloatAttr::getValue() const {
   return static_cast<ImplType *>(attr)->getValue();
 }
 
-Type FloatAttr::getType() const { return static_cast<ImplType *>(attr)->type; }
-
 double FloatAttr::getValueAsDouble() const {
   const auto &semantics = getType().cast<FloatType>().getFloatSemantics();
   auto value = getValue();
@@ -281,14 +263,16 @@ Function *FunctionAttr::getValue() const {
   return static_cast<ImplType *>(attr)->value;
 }
 
-FunctionType FunctionAttr::getType() const { return getValue()->getType(); }
+FunctionType FunctionAttr::getType() const {
+  return Attribute::getType().cast<FunctionType>();
+}
 
 //===----------------------------------------------------------------------===//
 // ElementsAttr
 //===----------------------------------------------------------------------===//
 
 VectorOrTensorType ElementsAttr::getType() const {
-  return static_cast<ImplType *>(attr)->type;
+  return Attribute::getType().cast<VectorOrTensorType>();
 }
 
 /// Return the value at the given index. If index does not refer to a valid
@@ -315,8 +299,8 @@ Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
 
 SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType type,
                                          Attribute elt) {
-  assert(elt.cast<NumericAttr>().getType() == type.getElementType() &&
-         "value should be of the given type");
+  assert(elt.getType() == type.getElementType() &&
+         "value should be of the given element type");
   return AttributeUniquer::get<SplatElementsAttr>(
       type.getContext(), Attribute::Kind::SplatElements, type, elt);
 }
index 3aa35cd..a05d83d 100644 (file)
@@ -920,24 +920,11 @@ void ConstantOp::build(Builder *builder, OperationState *result, Type type,
   result->types.push_back(type);
 }
 
-// Extracts and returns a type of an attribute if it has one.  Returns a null
-// type otherwise.  Currently, NumericAttrs and FunctionAttrs have types.
-static Type getAttributeType(Attribute attr) {
-  assert(attr && "expected non-null attribute");
-  if (auto numericAttr = attr.dyn_cast<NumericAttr>())
-    return numericAttr.getType();
-  if (auto functionAttr = attr.dyn_cast<FunctionAttr>())
-    return functionAttr.getType();
-  return {};
-}
-
 /// Builds a constant with the specified attribute value and type extracted
 /// from the attribute.  The attribute must have a type.
 void ConstantOp::build(Builder *builder, OperationState *result,
                        Attribute value) {
-  Type t = getAttributeType(value);
-  assert(t && "expected an attribute with a type");
-  return build(builder, result, t, value);
+  return build(builder, result, value.getType(), value);
 }
 
 void ConstantOp::print(OpAsmPrinter *p) {
@@ -1018,9 +1005,7 @@ LogicalResult ConstantOp::verify() {
     return success();
   }
 
-  auto attrType = getAttributeType(value);
-  if (!attrType)
-    return emitOpError("requires 'value' attribute to have a type");
+  auto attrType = value.getType();
   if (attrType != type)
     return emitOpError("requires the type of the 'value' attribute to match "
                        "that of the operation result");