Consistently use int64_t for shape-related values in shaped types
authorGeoffrey Martin-Noble <gcmn@google.com>
Sat, 1 Jun 2019 00:18:59 +0000 (17:18 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 03:13:58 +0000 (20:13 -0700)
    We want to support 64-bit shapes (even when the compiler is on a 32-bit architecture). Using int64_t consistently allows us to sidestep the bugginess of unsigned arithmetic.

    Still unsigned: kind, memory space, and bit width. The first two are basically enums. We could have a discussion about the last one, but it's basically just a very large enum as well and we're not doing any math on it, I think.

--

PiperOrigin-RevId: 250985791

mlir/include/mlir/IR/StandardTypes.h
mlir/lib/IR/Attributes.cpp
mlir/lib/IR/StandardTypes.cpp

index 730ceed..9c44225 100644 (file)
@@ -197,7 +197,7 @@ public:
   unsigned getElementTypeBitWidth() const;
 
   /// If this is a ranked type, return the number of elements. Otherwise, abort.
-  unsigned getNumElements() const;
+  int64_t getNumElements() const;
 
   /// If this is a ranked type, return the rank. Otherwise, abort.
   int64_t getRank() const;
@@ -216,11 +216,11 @@ public:
 
   /// If this is a ranked type, return the number of dimensions with dynamic
   /// size. Otherwise, abort.
-  unsigned getNumDynamicDims() const;
+  int64_t getNumDynamicDims() const;
 
   /// If this is ranked type, return the size of the specified dimension.
   /// Otherwise, abort.
-  int64_t getDimSize(unsigned i) const;
+  int64_t getDimSize(int64_t i) const;
 
   /// Get the total amount of bits occupied by a value of this type.  This does
   /// not take into account any memory layout or widening constraints, e.g. a
index 225fe41..0c4bb50 100644 (file)
@@ -541,7 +541,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                          ArrayRef<Attribute> values) {
   assert(type.getElementType().isIntOrFloat() &&
          "expected int or float element type");
-  assert(values.size() == type.getNumElements() &&
+  assert(static_cast<int64_t>(values.size()) == type.getNumElements() &&
          "expected 'values' to contain the same number of elements as 'type'");
 
   // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
@@ -588,7 +588,7 @@ Attribute DenseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
 
   // Verify that the rank of the indices matches the held type.
   auto rank = type.getRank();
-  if (static_cast<size_t>(rank) != index.size())
+  if (rank != static_cast<int64_t>(index.size()))
     return Attribute();
 
   // Verify that all of the indices are within the shape dimensions.
@@ -673,7 +673,7 @@ ArrayRef<char> DenseElementsAttr::getRawData() const {
 // of 'type'.
 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                          ArrayRef<APInt> values) {
-  assert(values.size() == type.getNumElements() &&
+  assert(static_cast<int64_t>(values.size()) == type.getNumElements() &&
          "expected 'values' to contain the same number of elements as 'type'");
 
   size_t bitWidth = getDenseElementBitwidth(type.getElementType());
index c8e7e69..965d67e 100644 (file)
@@ -115,10 +115,10 @@ unsigned ShapedType::getElementTypeBitWidth() const {
   return getElementType().getIntOrFloatBitWidth();
 }
 
-unsigned ShapedType::getNumElements() const {
+int64_t ShapedType::getNumElements() const {
   assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
   auto shape = getShape();
-  unsigned num = 1;
+  int64_t num = 1;
   for (auto dim : shape)
     num *= dim;
   return num;
@@ -128,7 +128,10 @@ int64_t ShapedType::getRank() const { return getShape().size(); }
 
 bool ShapedType::hasRank() const { return !isa<UnrankedTensorType>(); }
 
-int64_t ShapedType::getDimSize(unsigned i) const { return getShape()[i]; }
+int64_t ShapedType::getDimSize(int64_t i) const {
+  assert(i >= 0 && i < getRank() && "invalid index for shaped type");
+  return getShape()[i];
+}
 
 /// Get the number of bits require to store a value of the given shaped type.
 /// Compute the value recursively since tensors are allowed to have vectors as
@@ -162,7 +165,7 @@ ArrayRef<int64_t> ShapedType::getShape() const {
   }
 }
 
-unsigned ShapedType::getNumDynamicDims() const {
+int64_t ShapedType::getNumDynamicDims() const {
   return llvm::count_if(getShape(), isDynamic);
 }