Make it clear that ElementsAttr is only for static shaped vectors or tensors.
authorGeoffrey Martin-Noble <gcmn@google.com>
Wed, 29 May 2019 22:34:50 +0000 (15:34 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 03:09:12 +0000 (20:09 -0700)
    This is in preparation for making MemRef a subclass of ShapedType, but also UnrankedTensor should already be excluded.

--

PiperOrigin-RevId: 250580197

mlir/include/mlir/IR/Attributes.h
mlir/lib/IR/Attributes.cpp

index ed9ae2c..8c2a949 100644 (file)
@@ -390,11 +390,14 @@ public:
   }
 };
 
-/// A base attribute that represents a reference to a vector or tensor constant.
+/// A base attribute that represents a reference to a static shaped tensor or
+/// vector constant.
 class ElementsAttr : public Attribute {
 public:
   using Attribute::Attribute;
 
+  /// Return the type of this ElementsAttr, guaranteed to be a vector or tensor
+  /// with static shape.
   ShapedType getType() const;
 
   /// Return the value at the given index. If index does not refer to a valid
@@ -431,6 +434,7 @@ public:
   using Base::Base;
   using ValueType = Attribute;
 
+  /// 'type' must be a vector or tensor with static shape.
   static SplatElementsAttr get(ShapedType type, Attribute elt);
   Attribute getValue() const;
 
@@ -462,11 +466,13 @@ public:
   using ImplType = detail::DenseElementsAttributeStorage;
 
   /// It assumes the elements in the input array have been truncated to the bits
-  /// width specified by the element type.
+  /// width specified by the element type. 'type' must be a vector or tensor
+  /// with static shape.
   static DenseElementsAttr get(ShapedType type, ArrayRef<char> data);
 
-  // Constructs a dense elements attribute from an array of element values. Each
-  // element attribute value is expected to be an element of 'type'.
+  /// Constructs a dense elements attribute from an array of element values.
+  /// Each element attribute value is expected to be an element of 'type'.
+  /// 'type' must be a vector or tensor with static shape.
   static DenseElementsAttr get(ShapedType type, ArrayRef<Attribute> values);
 
   /// Returns the number of elements held by this attribute.
@@ -558,9 +564,9 @@ protected:
     return RawElementIterator(*this, size());
   }
 
-  // Constructs a dense elements attribute from an array of raw APInt values.
-  // Each APInt value is expected to have the same bitwidth as the element type
-  // of 'type'.
+  /// Constructs a dense elements attribute from an array of raw APInt values.
+  /// Each APInt value is expected to have the same bitwidth as the element type
+  /// of 'type'. 'type' must be a vector or tensor with static shape.
   static DenseElementsAttr get(ShapedType type, ArrayRef<APInt> values);
 };
 
@@ -580,12 +586,13 @@ public:
 
   /// Constructs a dense integer elements attribute from an array of APInt
   /// values. Each APInt value is expected to have the same bitwidth as the
-  /// element type of 'type'.
+  /// element type of 'type'. 'type' must be a vector or tensor with static
+  /// shape.
   static DenseIntElementsAttr get(ShapedType type, ArrayRef<APInt> values);
 
   /// Constructs a dense integer elements attribute from an array of integer
   /// values. Each value is expected to be within the bitwidth of the element
-  /// type of 'type'.
+  /// type of 'type'. 'type' must be a vector or tensor with static shape.
   static DenseIntElementsAttr get(ShapedType type, ArrayRef<int64_t> values);
 
   /// Generates a new DenseElementsAttr by mapping each value attribute, and
@@ -629,9 +636,10 @@ public:
   using DenseElementsAttr::get;
   using DenseElementsAttr::getValues;
 
-  // Constructs a dense float elements attribute from an array of APFloat
-  // values. Each APFloat value is expected to have the same bitwidth as the
-  // element type of 'type'.
+  /// Constructs a dense float elements attribute from an array of APFloat
+  /// values. Each APFloat value is expected to have the same bitwidth as the
+  /// element type of 'type'. 'type' must be a vector or tensor with static
+  /// shape.
   static DenseFPElementsAttr get(ShapedType type, ArrayRef<APFloat> values);
 
   /// Gets the float value of each of the dense elements.
@@ -712,6 +720,7 @@ class SparseElementsAttr
 public:
   using Base::Base;
 
+  /// 'type' must be a vector or tensor with static shape.
   static SparseElementsAttr get(ShapedType type, DenseIntElementsAttr indices,
                                 DenseElementsAttr values);
 
index ac1f180..e20d7b5 100644 (file)
@@ -322,6 +322,9 @@ ElementsAttr ElementsAttr::mapValues(
 SplatElementsAttr SplatElementsAttr::get(ShapedType type, Attribute elt) {
   assert(elt.getType() == type.getElementType() &&
          "value should be of the given element type");
+  assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
+         "type must be ranked tensor or vector");
+  assert(type.hasStaticShape() && "type must have static shape");
   return Base::get(type.getContext(), StandardAttributes::SplatElements, type,
                    elt);
 }
@@ -424,6 +427,9 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef<char> data) {
   assert((static_cast<uint64_t>(type.getSizeInBits()) <=
           data.size() * APInt::APINT_WORD_SIZE) &&
          "Input data bit size should be larger than that type requires");
+  assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
+         "type must be ranked tensor or vector");
+  assert(type.hasStaticShape() && "type must have static shape");
   switch (type.getElementType().getKind()) {
   case StandardTypes::BF16:
   case StandardTypes::F16:
@@ -797,6 +803,9 @@ SparseElementsAttr SparseElementsAttr::get(ShapedType type,
                                            DenseElementsAttr values) {
   assert(indices.getType().getElementType().isInteger(64) &&
          "expected sparse indices to be 64-bit integer values");
+  assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
+         "type must be ranked tensor or vector");
+  assert(type.hasStaticShape() && "type must have static shape");
   return Base::get(type.getContext(), StandardAttributes::SparseElements, type,
                    indices, values);
 }