From f6649afa2aa2ed22550898f083a40c16932a6c0d Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/On-Device=20Lab=28SR=29/Staff?= =?utf8?q?=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 3 Jul 2019 14:00:41 +0900 Subject: [PATCH] [loco] Support FeatureCodec clone (#4062) This commit introduce "clone" API to FeatureEncoder/Decoder interface. Signed-off-by: Jonghyun Park --- contrib/loco/CMakeLists.txt | 1 + contrib/loco/include/loco/IR/FeatureCodec.h | 6 ++++ contrib/loco/include/loco/IR/PermutingCodec.h | 16 +++++++++ contrib/loco/src/IR/PermutingCodec.cpp | 12 +++++++ contrib/loco/src/IR/PermutingCodec.test.cpp | 52 +++++++++++++++++++++++++++ 5 files changed, 87 insertions(+) diff --git a/contrib/loco/CMakeLists.txt b/contrib/loco/CMakeLists.txt index 9657764..8d1a857 100644 --- a/contrib/loco/CMakeLists.txt +++ b/contrib/loco/CMakeLists.txt @@ -6,6 +6,7 @@ add_library(loco SHARED ${SOURCES}) target_include_directories(loco PUBLIC include) # TODO Remove dependencies on angkor library target_link_libraries(loco PUBLIC angkor) +target_link_libraries(loco PRIVATE stdex) # Let's apply nncc common compile options # # NOTE This will enable strict compilation (warnings as error). diff --git a/contrib/loco/include/loco/IR/FeatureCodec.h b/contrib/loco/include/loco/IR/FeatureCodec.h index 12ceb63..93094e1 100644 --- a/contrib/loco/include/loco/IR/FeatureCodec.h +++ b/contrib/loco/include/loco/IR/FeatureCodec.h @@ -23,6 +23,8 @@ #include "loco/IR/TensorShape.h" #include "loco/IR/TensorIndex.h" +#include + namespace loco { @@ -43,6 +45,8 @@ struct FeatureEncoder virtual FeatureShape shape(const TensorShape &shape) const = 0; virtual TensorIndex value(const FeatureIndex &index) const = 0; + + virtual std::unique_ptr clone(void) const = 0; }; /** @@ -64,6 +68,8 @@ struct FeatureDecoder virtual TensorShape shape(const FeatureShape &) const = 0; virtual FeatureIndex value(const TensorIndex &) const = 0; + + virtual std::unique_ptr clone(void) const = 0; }; } // namespace loco diff --git a/contrib/loco/include/loco/IR/PermutingCodec.h b/contrib/loco/include/loco/IR/PermutingCodec.h index de83db6..15c2543 100644 --- a/contrib/loco/include/loco/IR/PermutingCodec.h +++ b/contrib/loco/include/loco/IR/PermutingCodec.h @@ -76,12 +76,20 @@ public: PermutingEncoder() = default; public: + PermutingEncoder(const Permutation &perm) : _perm{perm} + { + // DO NOTHING + } + +public: bool valid(void) const; public: FeatureShape shape(const TensorShape &tensor_shape) const override; TensorIndex value(const FeatureIndex &index) const override; + std::unique_ptr clone(void) const override; + public: const Permutation *perm(void) const { return &_perm; } Permutation *perm(void) { return &_perm; } @@ -97,12 +105,20 @@ public: PermutingDecoder() = default; public: + PermutingDecoder(const Permutation &perm) : _perm{perm} + { + // DO NOTHING + } + +public: bool valid(void) const; public: TensorShape shape(const FeatureShape &tensor_shape) const override; FeatureIndex value(const TensorIndex &index) const override; + std::unique_ptr clone(void) const override; + public: const Permutation *perm(void) const { return &_perm; } Permutation *perm(void) { return &_perm; } diff --git a/contrib/loco/src/IR/PermutingCodec.cpp b/contrib/loco/src/IR/PermutingCodec.cpp index bdd4f7a..48ad983 100644 --- a/contrib/loco/src/IR/PermutingCodec.cpp +++ b/contrib/loco/src/IR/PermutingCodec.cpp @@ -16,6 +16,8 @@ #include "loco/IR/PermutingCodec.h" +#include + #include #include #include @@ -135,6 +137,11 @@ TensorIndex PermutingEncoder::value(const FeatureIndex &in) con return out; } +std::unique_ptr PermutingEncoder::clone(void) const +{ + return stdex::make_unique>(_perm); +} + bool PermutingEncoder::valid(void) const { return ::valid(_perm); } // @@ -170,6 +177,11 @@ FeatureIndex PermutingDecoder::value(const TensorIndex &in) con return out; } +std::unique_ptr PermutingDecoder::clone(void) const +{ + return stdex::make_unique>(_perm); +} + bool PermutingDecoder::valid(void) const { return ::valid(_perm); } } // namespace loco diff --git a/contrib/loco/src/IR/PermutingCodec.test.cpp b/contrib/loco/src/IR/PermutingCodec.test.cpp index 83b133a..ad3ee88 100644 --- a/contrib/loco/src/IR/PermutingCodec.test.cpp +++ b/contrib/loco/src/IR/PermutingCodec.test.cpp @@ -139,6 +139,32 @@ TEST(PermutingEncoderTest, feature) ASSERT_EQ(tensor_index.at(3), 1); // CHANNEL(DEPTH) } +TEST(PermutingEncoderTest, feature_clone) +{ + PermutingEncoder src_enc; + + auto src_perm = src_enc.perm(); + + src_perm->axis(FeatureAxis::Count) = 0; + src_perm->axis(FeatureAxis::Depth) = 3; + src_perm->axis(FeatureAxis::Height) = 1; + src_perm->axis(FeatureAxis::Width) = 2; + + auto dst_enc = src_enc.clone(); + auto dst_perm = dynamic_cast *>(dst_enc.get())->perm(); + + EXPECT_EQ(dst_perm->axis(FeatureAxis::Count), src_perm->axis(FeatureAxis::Count)); + EXPECT_EQ(dst_perm->axis(FeatureAxis::Depth), src_perm->axis(FeatureAxis::Depth)); + EXPECT_EQ(dst_perm->axis(FeatureAxis::Height), src_perm->axis(FeatureAxis::Height)); + EXPECT_EQ(dst_perm->axis(FeatureAxis::Width), src_perm->axis(FeatureAxis::Width)); + + // Update on cloned encoder SHOULD NOT affect the original encoder + dst_perm->axis(FeatureAxis::Height) += 1; + + EXPECT_EQ(src_perm->axis(FeatureAxis::Height), 1); + EXPECT_EQ(dst_perm->axis(FeatureAxis::Height), 2); +} + TEST(PermutingEncoderTest, filter) { PermutingEncoder enc; @@ -261,3 +287,29 @@ TEST(PermutingDecoderTest, feature) ASSERT_EQ(feature_index.row(), 2); ASSERT_EQ(feature_index.column(), 3); } + +TEST(PermutingDecoderTest, feature_clone) +{ + PermutingDecoder src_enc; + + auto src_perm = src_enc.perm(); + + src_perm->axis(FeatureAxis::Count) = 0; + src_perm->axis(FeatureAxis::Depth) = 3; + src_perm->axis(FeatureAxis::Height) = 1; + src_perm->axis(FeatureAxis::Width) = 2; + + auto dst_enc = src_enc.clone(); + auto dst_perm = dynamic_cast *>(dst_enc.get())->perm(); + + EXPECT_EQ(dst_perm->axis(FeatureAxis::Count), src_perm->axis(FeatureAxis::Count)); + EXPECT_EQ(dst_perm->axis(FeatureAxis::Depth), src_perm->axis(FeatureAxis::Depth)); + EXPECT_EQ(dst_perm->axis(FeatureAxis::Height), src_perm->axis(FeatureAxis::Height)); + EXPECT_EQ(dst_perm->axis(FeatureAxis::Width), src_perm->axis(FeatureAxis::Width)); + + // Update on cloned decoder SHOULD NOT affect the original decoder + dst_perm->axis(FeatureAxis::Height) += 1; + + EXPECT_EQ(src_perm->axis(FeatureAxis::Height), 1); + EXPECT_EQ(dst_perm->axis(FeatureAxis::Height), 2); +} -- 2.7.4