[core.ADT.tensor] Add 'num_elements' function (#217)
author박종현/동작제어Lab(SR)/Senior Engineer/삼성전자 <jh1302.park@samsung.com>
Wed, 16 May 2018 23:45:55 +0000 (08:45 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Wed, 16 May 2018 23:45:55 +0000 (08:45 +0900)
* [core.ADT.tensor] Add 'num_elements' function

This commit adds 'num_elemnets' function which takes a tensor shape and
returns the number of elements in a tensor of the given shape.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
* Return 0 when null dimension exists

libs/core/include/nncc/core/ADT/tensor/Shape.h
libs/core/src/nncc/core/ADT/tensor/Shape.cpp
libs/core/src/nncc/core/ADT/tensor/Shape.test.cpp

index 75cf47e..4831681 100644 (file)
@@ -37,6 +37,8 @@ private:
   std::vector<uint32_t> _dims;
 };
 
+uint64_t num_elements(const Shape &);
+
 Shape squeeze(const Shape &);
 
 bool operator==(const Shape &, const Shape &);
index c8d1caa..b3dd4ab 100644 (file)
@@ -28,6 +28,23 @@ Shape &Shape::squeeze(void)
   return *this;
 }
 
+uint64_t num_elements(const Shape &shape)
+{
+  if (shape.rank() == 0)
+  {
+    return 0;
+  }
+
+  uint64_t res = 1;
+
+  for (uint32_t axis = 0; axis < shape.rank(); ++axis)
+  {
+    res *= shape.dim(axis);
+  }
+
+  return res;
+}
+
 Shape squeeze(const Shape &shape)
 {
   Shape res{shape};
index 7f17441..ab2dc87 100644 (file)
@@ -58,6 +58,30 @@ TEST(ADT_TENSOR_SHAPE, copy)
   }
 }
 
+TEST(ADT_TENSOR_SHAPE, num_elements_zero)
+{
+  using nncc::core::ADT::tensor::Shape;
+  using nncc::core::ADT::tensor::num_elements;
+
+  ASSERT_EQ(num_elements(Shape{0, 0, 0, 0}), 0);
+}
+
+TEST(ADT_TENSOR_SHAPE, num_elements_nonzero)
+{
+  using nncc::core::ADT::tensor::Shape;
+  using nncc::core::ADT::tensor::num_elements;
+
+  ASSERT_EQ(num_elements(Shape{2, 3}), 6);
+}
+
+TEST(ADT_TENSOR_SHAPE, num_elements_nulldim)
+{
+  using nncc::core::ADT::tensor::Shape;
+  using nncc::core::ADT::tensor::num_elements;
+
+  ASSERT_EQ(num_elements(Shape{2, 0, 3}), 0);
+}
+
 TEST(ADT_TENSOR_SHAPE, squeeze_neg)
 {
   using nncc::core::ADT::tensor::Shape;