From 321a796bae9baecf182a57f65f276ab8866a611b Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Senior=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Thu, 17 May 2018 08:45:55 +0900 Subject: [PATCH] [core.ADT.tensor] Add 'num_elements' function (#217) * [core.ADT.tensor] Add 'num_elements' function This commit adds 'num_elemnets' function which takes a tensor shape and returns the number of elements in a tensor of the given shape. Signed-off-by: Jonghyun Park * Return 0 when null dimension exists --- libs/core/include/nncc/core/ADT/tensor/Shape.h | 2 ++ libs/core/src/nncc/core/ADT/tensor/Shape.cpp | 17 ++++++++++++++++ libs/core/src/nncc/core/ADT/tensor/Shape.test.cpp | 24 +++++++++++++++++++++++ 3 files changed, 43 insertions(+) diff --git a/libs/core/include/nncc/core/ADT/tensor/Shape.h b/libs/core/include/nncc/core/ADT/tensor/Shape.h index 75cf47e..4831681 100644 --- a/libs/core/include/nncc/core/ADT/tensor/Shape.h +++ b/libs/core/include/nncc/core/ADT/tensor/Shape.h @@ -37,6 +37,8 @@ private: std::vector _dims; }; +uint64_t num_elements(const Shape &); + Shape squeeze(const Shape &); bool operator==(const Shape &, const Shape &); diff --git a/libs/core/src/nncc/core/ADT/tensor/Shape.cpp b/libs/core/src/nncc/core/ADT/tensor/Shape.cpp index c8d1caa..b3dd4ab 100644 --- a/libs/core/src/nncc/core/ADT/tensor/Shape.cpp +++ b/libs/core/src/nncc/core/ADT/tensor/Shape.cpp @@ -28,6 +28,23 @@ Shape &Shape::squeeze(void) return *this; } +uint64_t num_elements(const Shape &shape) +{ + if (shape.rank() == 0) + { + return 0; + } + + uint64_t res = 1; + + for (uint32_t axis = 0; axis < shape.rank(); ++axis) + { + res *= shape.dim(axis); + } + + return res; +} + Shape squeeze(const Shape &shape) { Shape res{shape}; diff --git a/libs/core/src/nncc/core/ADT/tensor/Shape.test.cpp b/libs/core/src/nncc/core/ADT/tensor/Shape.test.cpp index 7f17441..ab2dc87 100644 --- a/libs/core/src/nncc/core/ADT/tensor/Shape.test.cpp +++ b/libs/core/src/nncc/core/ADT/tensor/Shape.test.cpp @@ -58,6 +58,30 @@ TEST(ADT_TENSOR_SHAPE, copy) } } +TEST(ADT_TENSOR_SHAPE, num_elements_zero) +{ + using nncc::core::ADT::tensor::Shape; + using nncc::core::ADT::tensor::num_elements; + + ASSERT_EQ(num_elements(Shape{0, 0, 0, 0}), 0); +} + +TEST(ADT_TENSOR_SHAPE, num_elements_nonzero) +{ + using nncc::core::ADT::tensor::Shape; + using nncc::core::ADT::tensor::num_elements; + + ASSERT_EQ(num_elements(Shape{2, 3}), 6); +} + +TEST(ADT_TENSOR_SHAPE, num_elements_nulldim) +{ + using nncc::core::ADT::tensor::Shape; + using nncc::core::ADT::tensor::num_elements; + + ASSERT_EQ(num_elements(Shape{2, 0, 3}), 0); +} + TEST(ADT_TENSOR_SHAPE, squeeze_neg) { using nncc::core::ADT::tensor::Shape; -- 2.7.4