[llvm] Expose type and element count-related APIs on TensorSpec
authorMircea Trofin <mtrofin@google.com>
Tue, 4 Aug 2020 22:00:35 +0000 (15:00 -0700)
committerMircea Trofin <mtrofin@google.com>
Wed, 5 Aug 2020 00:32:16 +0000 (17:32 -0700)
Added a mechanism to check the element type, get the total element
count, and the size of an element.

Differential Revision: https://reviews.llvm.org/D85250

llvm/include/llvm/Analysis/Utils/TFUtils.h
llvm/lib/Analysis/TFUtils.cpp
llvm/unittests/Analysis/TFUtilsTest.cpp

index d445027..681560e 100644 (file)
@@ -66,10 +66,18 @@ public:
 
   bool operator!=(const TensorSpec &Other) const { return !(*this == Other); }
 
+  /// Get the number of elements in a tensor with this shape.
+  size_t getElementCount() const { return ElementCount; }
+  /// Get the size, in bytes, of one element.
+  size_t getElementByteSize() const;
+
+  /// Returns true iff the tensor's element type is T, i.e. iff the stored
+  /// TypeIndex matches the TF_DataType that getDataType<T>() maps T to.
+  template <typename T> bool isElementType() const {
+    return getDataType<T>() == TypeIndex;
+  }
+
 private:
   TensorSpec(const std::string &Name, int Port, int TypeIndex,
-             const std::vector<int64_t> &Shape)
-      : Name(Name), Port(Port), TypeIndex(TypeIndex), Shape(Shape) {}
+             const std::vector<int64_t> &Shape);
 
   template <typename T> static int getDataType() {
     llvm_unreachable("Undefined tensor type");
@@ -79,6 +87,7 @@ private:
   int Port = 0;
   int TypeIndex = 0;
   std::vector<int64_t> Shape;
+  size_t ElementCount = 0;
 };
 
 Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
index 8fd4011..b1be027 100644 (file)
@@ -24,6 +24,7 @@
 #include "tensorflow/c/c_api_experimental.h"
 
 #include <cassert>
+#include <numeric>
 
 using namespace llvm;
 
@@ -84,6 +85,16 @@ private:
   std::vector<TF_Tensor *> Output;
 };
 
+// The per-element size is defined by the TF framework for the underlying
+// element type; TypeIndex stores that TF_DataType value as an int.
+size_t TensorSpec::getElementByteSize() const {
+  return TF_DataTypeSize(static_cast<TF_DataType>(TypeIndex));
+}
+
+TensorSpec::TensorSpec(const std::string &Name, int Port, int TypeIndex,
+                       const std::vector<int64_t> &Shape)
+    : Name(Name), Port(Port), TypeIndex(TypeIndex), Shape(Shape),
+      ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1,
+                                   std::multiplies<int64_t>())) {}
+
 Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
                                            const json::Value &Value) {
   auto EmitError = [&](const llvm::Twine &Message) -> Optional<TensorSpec> {
index abdf2b2..9e4f2c7 100644 (file)
@@ -123,3 +123,18 @@ TEST(TFUtilsTest, JSONParsingInvalidTensorType) {
   auto Spec = getTensorSpecFromJSON(Ctx, *Value);
   EXPECT_FALSE(Spec.hasValue());
 }
+
+// Exercise the element-type check, the element count for 1/2/3-D shapes, and
+// the per-element byte size.
+TEST(TFUtilsTest, TensorSpecSizesAndTypes) {
+  auto Spec1D = TensorSpec::createSpec<int16_t>("Hi1", {1});
+  auto Spec2D = TensorSpec::createSpec<int16_t>("Hi2", {1, 1});
+  auto Spec1DLarge = TensorSpec::createSpec<float>("Hi3", {10});
+  // Was "Hi3" — a copy-paste duplicate of the previous spec's name.
+  auto Spec3DLarge = TensorSpec::createSpec<float>("Hi4", {2, 4, 10});
+  EXPECT_TRUE(Spec1D.isElementType<int16_t>());
+  EXPECT_FALSE(Spec3DLarge.isElementType<double>());
+  EXPECT_EQ(Spec1D.getElementCount(), 1);
+  EXPECT_EQ(Spec2D.getElementCount(), 1);
+  EXPECT_EQ(Spec1DLarge.getElementCount(), 10);
+  EXPECT_EQ(Spec3DLarge.getElementCount(), 80);
+  EXPECT_EQ(Spec3DLarge.getElementByteSize(), sizeof(float));
+  EXPECT_EQ(Spec1D.getElementByteSize(), sizeof(int16_t));
+}
\ No newline at end of file