From daefaded4a3d30643e44c2ccc2f24adb3a125d75 Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Fri, 31 May 2019 17:18:59 -0700 Subject: [PATCH] Consistently use int64_t for shape-related values in shaped types We want to support 64-bit shapes (even when the compiler is on a 32-bit architecture). Using int64_t consistently allows us to sidestep the bugginess of unsigned arithmetic. Still unsigned: kind, memory space, and bit width. The first two are basically enums. We could have a discussion about the last one, but it's basically just a very large enum as well and we're not doing any math on it, I think. -- PiperOrigin-RevId: 250985791 --- mlir/include/mlir/IR/StandardTypes.h | 6 +++--- mlir/lib/IR/Attributes.cpp | 6 +++--- mlir/lib/IR/StandardTypes.cpp | 11 +++++++---- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h index 730ceed..9c44225 100644 --- a/mlir/include/mlir/IR/StandardTypes.h +++ b/mlir/include/mlir/IR/StandardTypes.h @@ -197,7 +197,7 @@ public: unsigned getElementTypeBitWidth() const; /// If this is a ranked type, return the number of elements. Otherwise, abort. - unsigned getNumElements() const; + int64_t getNumElements() const; /// If this is a ranked type, return the rank. Otherwise, abort. int64_t getRank() const; @@ -216,11 +216,11 @@ public: /// If this is a ranked type, return the number of dimensions with dynamic /// size. Otherwise, abort. - unsigned getNumDynamicDims() const; + int64_t getNumDynamicDims() const; /// If this is ranked type, return the size of the specified dimension. /// Otherwise, abort. - int64_t getDimSize(unsigned i) const; + int64_t getDimSize(int64_t i) const; /// Get the total amount of bits occupied by a value of this type. This does /// not take into account any memory layout or widening constraints, e.g. a diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index 225fe41..0c4bb50 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -541,7 +541,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef values) { assert(type.getElementType().isIntOrFloat() && "expected int or float element type"); - assert(values.size() == type.getNumElements() && + assert(static_cast(values.size()) == type.getNumElements() && "expected 'values' to contain the same number of elements as 'type'"); // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored @@ -588,7 +588,7 @@ Attribute DenseElementsAttr::getValue(ArrayRef index) const { // Verify that the rank of the indices matches the held type. auto rank = type.getRank(); - if (static_cast(rank) != index.size()) + if (rank != static_cast(index.size())) return Attribute(); // Verify that all of the indices are within the shape dimensions. @@ -673,7 +673,7 @@ ArrayRef DenseElementsAttr::getRawData() const { // of 'type'. DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef values) { - assert(values.size() == type.getNumElements() && + assert(static_cast(values.size()) == type.getNumElements() && "expected 'values' to contain the same number of elements as 'type'"); size_t bitWidth = getDenseElementBitwidth(type.getElementType()); diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp index c8e7e69..965d67e 100644 --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -115,10 +115,10 @@ unsigned ShapedType::getElementTypeBitWidth() const { return getElementType().getIntOrFloatBitWidth(); } -unsigned ShapedType::getNumElements() const { +int64_t ShapedType::getNumElements() const { assert(hasStaticShape() && "cannot get element count of dynamic shaped type"); auto shape = getShape(); - unsigned num = 1; + int64_t num = 1; for (auto dim : shape) num *= dim; return num; @@ -128,7 +128,10 @@ int64_t ShapedType::getRank() const { return getShape().size(); } bool ShapedType::hasRank() const { return !isa(); } -int64_t ShapedType::getDimSize(unsigned i) const { return getShape()[i]; } +int64_t ShapedType::getDimSize(int64_t i) const { + assert(i >= 0 && i < getRank() && "invalid index for shaped type"); + return getShape()[i]; +} /// Get the number of bits require to store a value of the given shaped type. /// Compute the value recursively since tensors are allowed to have vectors as @@ -162,7 +165,7 @@ ArrayRef ShapedType::getShape() const { } } -unsigned ShapedType::getNumDynamicDims() const { +int64_t ShapedType::getNumDynamicDims() const { return llvm::count_if(getShape(), isDynamic); } -- 2.7.4