From 3c9a8fb4b23aa956b8aa5f459ba50ef14bc12f80 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Staff=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Fri, 7 Sep 2018 11:18:30 +0900 Subject: [PATCH] [enco] Introduce ConcatSpec class (#1403) This commit introduces ConcatSpec class which estimates the output shape of Concat layer. Signed-off-by: Jonghyun Park --- contrib/enco/frontend/caffe/src/ConcatSpec.cpp | 24 ++++++++++++++++++ contrib/enco/frontend/caffe/src/ConcatSpec.h | 29 ++++++++++++++++++++++ .../enco/frontend/caffe/src/ConcatSpec.test.cpp | 26 +++++++++++++++++++ 3 files changed, 79 insertions(+) create mode 100644 contrib/enco/frontend/caffe/src/ConcatSpec.cpp create mode 100644 contrib/enco/frontend/caffe/src/ConcatSpec.h create mode 100644 contrib/enco/frontend/caffe/src/ConcatSpec.test.cpp diff --git a/contrib/enco/frontend/caffe/src/ConcatSpec.cpp b/contrib/enco/frontend/caffe/src/ConcatSpec.cpp new file mode 100644 index 0000000..02eeaa4 --- /dev/null +++ b/contrib/enco/frontend/caffe/src/ConcatSpec.cpp @@ -0,0 +1,24 @@ +#include "ConcatSpec.h" + +#include + +using namespace nncc::core::ADT::tensor; + +nncc::core::ADT::tensor::Shape ConcatSpec::forward(const ShapeList &inputs) const +{ + assert(inputs.size() > 0); + + Shape output_shape = inputs.at(0); + + for (uint32_t n = 1; n < inputs.size(); ++n) + { + // The current implementation assumes that "inputs" is well-formed + // TODO Verify whether "inputs" is really well-formed + const auto &input_shape = inputs.at(n); + output_shape.dim(_axis) += input_shape.dim(_axis); + } + + return output_shape; +} + +ConcatSpec concat_spec(uint32_t axis) { return ConcatSpec{axis}; } diff --git a/contrib/enco/frontend/caffe/src/ConcatSpec.h b/contrib/enco/frontend/caffe/src/ConcatSpec.h new file mode 100644 index 0000000..0a831d9 --- /dev/null +++ b/contrib/enco/frontend/caffe/src/ConcatSpec.h @@ -0,0 +1,29 @@ +#ifndef __CONCAT_SPEC_H__ +#define __CONCAT_SPEC_H__ + +#include + +#include + +using ShapeList = std::vector; + +class ConcatSpec +{ +public: + explicit ConcatSpec(uint32_t axis) : _axis{axis} + { + // DO NOTHING + } + +public: + // @brief Return the output shape when inputs of given shape are + // concatenated along _axis + nncc::core::ADT::tensor::Shape forward(const ShapeList &) const; + +private: + uint32_t _axis; +}; + +ConcatSpec concat_spec(uint32_t axis); + +#endif // __CONCAT_SPEC_H__ diff --git a/contrib/enco/frontend/caffe/src/ConcatSpec.test.cpp b/contrib/enco/frontend/caffe/src/ConcatSpec.test.cpp new file mode 100644 index 0000000..5696990 --- /dev/null +++ b/contrib/enco/frontend/caffe/src/ConcatSpec.test.cpp @@ -0,0 +1,26 @@ +#include "ConcatSpec.h" + +#include + +using nncc::core::ADT::tensor::Shape; + +namespace +{ +class ConcatSpecTest : public ::testing::Test +{ + // FOR FUTURE USE +}; +} // namespace + +TEST_F(ConcatSpecTest, ifm_shape) +{ + const Shape in_1{1, 1, 4, 4}; + const Shape in_2{1, 2, 4, 4}; + const Shape in_3{1, 3, 4, 4}; + const Shape in_4{1, 4, 4, 4}; + + auto expected = Shape{1, 10, 4, 4}; + auto obtained = concat_spec(1).forward({in_1, in_2, in_3, in_4}); + + ASSERT_EQ(expected, obtained); +} -- 2.7.4