From 58a766705bcf38e7d4d132979b7402789376cc6d Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Fri, 24 May 2019 17:44:21 -0700 Subject: [PATCH] Limit the number of places where shaped type has to explicitly reference its base classes. Introduces a hasRank() method to make checking for rank a bit easier. This is partially to make it easier to make MemRef subclass ShapedType -- PiperOrigin-RevId: 249927442 --- mlir/include/mlir/IR/StandardTypes.h | 32 +++++++++++------------ mlir/lib/IR/StandardTypes.cpp | 49 +++++++++++------------------------- 2 files changed, 31 insertions(+), 50 deletions(-) diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h index de11f87..867fb8d 100644 --- a/mlir/include/mlir/IR/StandardTypes.h +++ b/mlir/include/mlir/IR/StandardTypes.h @@ -192,38 +192,38 @@ public: /// Return the element type. Type getElementType() const; - /// If an element type is an integer or a float, return its width. Abort - /// otherwise. + /// If an element type is an integer or a float, return its width. Otherwise, + /// abort. unsigned getElementTypeBitWidth() const; - /// If this is ranked tensor or vector type, return the number of elements. If - /// it is an unranked tensor, abort. + /// If this is a ranked type, return the number of elements. Otherwise, abort. unsigned getNumElements() const; - /// If this is ranked tensor or vector type, return the rank. If it is an - /// unranked tensor, return -1. + /// If this is a ranked type, return the rank. Otherwise, return -1. int64_t getRank() const; - /// If this is ranked tensor or vector type, return the shape. If it is an - /// unranked tensor, abort. + /// Whether or not this is a ranked type. Vector and ranked tensors have a + /// rank, while unranked tensors do not. + bool hasRank() const; + + /// If this is a ranked type, return the shape. Otherwise, abort. ArrayRef getShape() const; - /// If this is unranked tensor or any dimension has unknown size (<0), - /// it doesn't have static shape. If all dimensions have known size (>= 0), - /// it has static shape. + /// If this is unranked type or any dimension has unknown size (<0), it + /// doesn't have static shape. If all dimensions have known size (>= 0), it + /// has static shape. bool hasStaticShape() const; - /// If this is ranked tensor or vector type, return the size of the specified - /// dimension. It aborts if the tensor is unranked (this can be checked by - /// the getRank call method). + /// If this is ranked type, return the size of the specified dimension. + /// Otherwise, abort. int64_t getDimSize(unsigned i) const; /// Get the total amount of bits occupied by a value of this type. This does /// not take into account any memory layout or widening constraints, e.g. a /// vector<3xi57> is reported to occupy 3x57=171 bit, even though in practice /// it will likely be stored as in a 4xi64 vector register. Fail an assertion - /// if the size cannot be computed statically, i.e. if the tensor has a - /// dynamic shape or if its elemental type does not have a known bit width. + /// if the size cannot be computed statically, i.e. if the type has a dynamic + /// shape or if its elemental type does not have a known bit width. int64_t getSizeInBits() const; /// Methods for support type inquiry through isa, cast, and dyn_cast. diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp index 1958a43..03bce65 100644 --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -116,48 +116,29 @@ unsigned ShapedType::getElementTypeBitWidth() const { } unsigned ShapedType::getNumElements() const { - switch (getKind()) { - case StandardTypes::Vector: - case StandardTypes::RankedTensor: { - assert(hasStaticShape() && "expected type to have static shape"); - auto shape = getShape(); - unsigned num = 1; - for (auto dim : shape) - num *= dim; - return num; - } - default: - llvm_unreachable("not a ShapedType or not ranked"); - } + assert(hasStaticShape() && "expected type to have static shape"); + auto shape = getShape(); + unsigned num = 1; + for (auto dim : shape) + num *= dim; + return num; } -/// If this is ranked tensor or vector type, return the rank. If it is an -/// unranked tensor, return -1. int64_t ShapedType::getRank() const { - switch (getKind()) { - case StandardTypes::Vector: - case StandardTypes::RankedTensor: - return getShape().size(); - case StandardTypes::UnrankedTensor: - return -1; - default: - llvm_unreachable("not a ShapedType"); - } + return hasRank() ? getShape().size() : -1; } +bool ShapedType::hasRank() const { return !isa(); } + int64_t ShapedType::getDimSize(unsigned i) const { - switch (getKind()) { - case StandardTypes::Vector: - case StandardTypes::RankedTensor: + if (hasRank()) return getShape()[i]; - default: - llvm_unreachable("not a ShapedType or not ranked"); - } + llvm_unreachable("not a ShapedType or not ranked"); } -// Get the number of number of bits require to store a value of the given vector -// or tensor types. Compute the value recursively since tensors are allowed to -// have vectors as elements. +/// Get the number of bits require to store a value of the given shaped type. +/// Compute the value recursively since tensors are allowed to have vectors as +/// elements. int64_t ShapedType::getSizeInBits() const { assert(hasStaticShape() && "cannot get the bit size of an aggregate with a dynamic shape"); @@ -185,7 +166,7 @@ ArrayRef ShapedType::getShape() const { } bool ShapedType::hasStaticShape() const { - if (isa()) + if (!hasRank()) return false; return llvm::none_of(getShape(), [](int64_t i) { return i < 0; }); } -- 2.7.4