From 90b9c49ca6477a85e69018967c0a4d4d38ee6e72 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin@google.com>
Date: Tue, 4 Aug 2020 15:00:35 -0700
Subject: [PATCH] [llvm] Expose type and element count-related APIs on
 TensorSpec

Added a mechanism to check the element type, get the total element
count, and the size of an element.

Differential Revision: https://reviews.llvm.org/D85250
---
 llvm/include/llvm/Analysis/Utils/TFUtils.h | 13 +++++++++++--
 llvm/lib/Analysis/TFUtils.cpp              | 11 +++++++++++
 llvm/unittests/Analysis/TFUtilsTest.cpp    | 15 +++++++++++++++
 3 files changed, 37 insertions(+), 2 deletions(-)

diff --git a/llvm/include/llvm/Analysis/Utils/TFUtils.h b/llvm/include/llvm/Analysis/Utils/TFUtils.h
index d445027..681560e 100644
--- a/llvm/include/llvm/Analysis/Utils/TFUtils.h
+++ b/llvm/include/llvm/Analysis/Utils/TFUtils.h
@@ -66,10 +66,18 @@ public:
 
   bool operator!=(const TensorSpec &Other) const { return !(*this == Other); }
 
+  /// Get the number of elements in a tensor with this shape.
+  size_t getElementCount() const { return ElementCount; }
+  /// Get the size, in bytes, of one element.
+  size_t getElementByteSize() const;
+
+  template <typename T> bool isElementType() const {
+    return getDataType<T>() == TypeIndex;
+  }
+
 private:
   TensorSpec(const std::string &Name, int Port, int TypeIndex,
-             const std::vector<int64_t> &Shape)
-      : Name(Name), Port(Port), TypeIndex(TypeIndex), Shape(Shape) {}
+             const std::vector<int64_t> &Shape);
 
   template <typename T> static int getDataType() {
     llvm_unreachable("Undefined tensor type");
@@ -79,6 +87,7 @@ private:
   int Port = 0;
   int TypeIndex = 0;
   std::vector<int64_t> Shape;
+  size_t ElementCount = 0;
 };
 
 Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
diff --git a/llvm/lib/Analysis/TFUtils.cpp b/llvm/lib/Analysis/TFUtils.cpp
index 8fd4011..b1be027 100644
--- a/llvm/lib/Analysis/TFUtils.cpp
+++ b/llvm/lib/Analysis/TFUtils.cpp
@@ -24,6 +24,7 @@
 #include "tensorflow/c/c_api_experimental.h"
 
 #include <cassert>
+#include <numeric>
 
 using namespace llvm;
 
@@ -84,6 +85,16 @@ private:
   std::vector<TF_Output> Output;
 };
 
+size_t TensorSpec::getElementByteSize() const {
+  return TF_DataTypeSize(static_cast<TF_DataType>(TypeIndex));
+}
+
+TensorSpec::TensorSpec(const std::string &Name, int Port, int TypeIndex,
+                       const std::vector<int64_t> &Shape)
+    : Name(Name), Port(Port), TypeIndex(TypeIndex), Shape(Shape),
+      ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1,
+                                   std::multiplies<int64_t>())) {}
+
 Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
                                            const json::Value &Value) {
   auto EmitError = [&](const llvm::Twine &Message) -> Optional<TensorSpec> {
diff --git a/llvm/unittests/Analysis/TFUtilsTest.cpp b/llvm/unittests/Analysis/TFUtilsTest.cpp
index abdf2b2..9e4f2c7 100644
--- a/llvm/unittests/Analysis/TFUtilsTest.cpp
+++ b/llvm/unittests/Analysis/TFUtilsTest.cpp
@@ -123,3 +123,18 @@ TEST(TFUtilsTest, JSONParsingInvalidTensorType) {
   auto Spec = getTensorSpecFromJSON(Ctx, *Value);
   EXPECT_FALSE(Spec.hasValue());
 }
+
+TEST(TFUtilsTest, TensorSpecSizesAndTypes) {
+  auto Spec1D = TensorSpec::createSpec<int16_t>("Hi1", {1});
+  auto Spec2D = TensorSpec::createSpec<int16_t>("Hi2", {1, 1});
+  auto Spec1DLarge = TensorSpec::createSpec<float>("Hi3", {10});
+  auto Spec3DLarge = TensorSpec::createSpec<float>("Hi3", {2, 4, 10});
+  EXPECT_TRUE(Spec1D.isElementType<int16_t>());
+  EXPECT_FALSE(Spec3DLarge.isElementType<int16_t>());
+  EXPECT_EQ(Spec1D.getElementCount(), 1);
+  EXPECT_EQ(Spec2D.getElementCount(), 1);
+  EXPECT_EQ(Spec1DLarge.getElementCount(), 10);
+  EXPECT_EQ(Spec3DLarge.getElementCount(), 80);
+  EXPECT_EQ(Spec3DLarge.getElementByteSize(), sizeof(float));
+  EXPECT_EQ(Spec1D.getElementByteSize(), sizeof(int16_t));
+}
\ No newline at end of file
-- 
2.7.4
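
Note (not part of the patch): a minimal usage sketch of the APIs this change exposes, assuming an in-tree build where llvm/Analysis/Utils/TFUtils.h and the TF C API are available. The helper names allocateTensorBuffer and example are illustrative, not existing LLVM functions; only createSpec, getElementCount, getElementByteSize, and isElementType come from the patch itself.

// Hypothetical client code exercising the new TensorSpec APIs.
#include "llvm/Analysis/Utils/TFUtils.h"

#include <cstdint>
#include <vector>

using llvm::TensorSpec;

// Size a raw byte buffer for a tensor described by a TensorSpec:
// total bytes = element count * size of one element.
std::vector<char> allocateTensorBuffer(const TensorSpec &Spec) {
  return std::vector<char>(Spec.getElementCount() * Spec.getElementByteSize());
}

void example() {
  // A 2x4x10 float tensor: 80 elements, 80 * sizeof(float) bytes.
  auto Spec = TensorSpec::createSpec<float>("logits", {2, 4, 10});
  auto Buffer = allocateTensorBuffer(Spec);
  (void)Buffer;
  // Element-type checks go through isElementType<T>().
  bool IsFloat = Spec.isElementType<float>();   // true
  bool IsInt16 = Spec.isElementType<int16_t>(); // false
  (void)IsFloat;
  (void)IsInt16;
}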