From c9afa0f25de0ea91a436954af198ed70c5f4c49c Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/On-Device=20Lab=28SR=29/Staff?= =?utf8?q?=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 5 Jun 2019 09:28:50 +0900 Subject: [PATCH] [loco] Introduce Filter Permuting Encoder (#3684) This commit introduces a permuting encoder for filter values. Signed-off-by: Jonghyun Park --- contrib/loco/include/loco/IR/PermutingCodec.h | 25 +++++++++++ contrib/loco/src/IR/PermutingCodec.cpp | 35 ++++++++++++++++ contrib/loco/src/IR/PermutingCodec.test.cpp | 60 +++++++++++++++++++++++++++ 3 files changed, 120 insertions(+) diff --git a/contrib/loco/include/loco/IR/PermutingCodec.h b/contrib/loco/include/loco/IR/PermutingCodec.h index 8250604..de83db6 100644 --- a/contrib/loco/include/loco/IR/PermutingCodec.h +++ b/contrib/loco/include/loco/IR/PermutingCodec.h @@ -22,6 +22,7 @@ #include "loco/IR/FeatureAxis.h" #include "loco/IR/FeatureCodec.h" #include "loco/IR/FilterAxis.h" +#include "loco/IR/FilterCodec.h" #include @@ -146,6 +147,30 @@ private: std::map _map; }; +/** + * @brief Permutation-based Tensor-to-Filter converter + */ +template <> class PermutingEncoder final : public FilterEncoder +{ +public: + PermutingEncoder() = default; + +public: + bool valid(void) const; + +public: + FilterShape shape(const TensorShape &tensor_shape) const override; + TensorIndex value(const FilterIndex &index) const override; + +public: + const Permutation *perm(void) const { return &_perm; } + Permutation *perm(void) { return &_perm; } + void perm(const Permutation &p) { _perm = p; } + +private: + Permutation _perm; +}; + } // namespace loco #endif // __LOCO_IR_PERMUTING_CODEC_H__ diff --git a/contrib/loco/src/IR/PermutingCodec.cpp b/contrib/loco/src/IR/PermutingCodec.cpp index fab91ea..bdd4f7a 100644 --- a/contrib/loco/src/IR/PermutingCodec.cpp +++ b/contrib/loco/src/IR/PermutingCodec.cpp @@ -256,4 +256,39 @@ uint32_t &Permutation::axis(const FilterAxis &axis_f) return _map[axis_f]; } +// +// Permuting Encoder +// +FilterShape PermutingEncoder::shape(const TensorShape &in) const +{ + assert(valid() && "invalid permutation"); + + FilterShape out; + + out.count() = in.dim(_perm[FilterAxis::Count]); + out.depth() = in.dim(_perm[FilterAxis::Depth]); + out.height() = in.dim(_perm[FilterAxis::Height]); + out.width() = in.dim(_perm[FilterAxis::Width]); + + return out; +} + +TensorIndex PermutingEncoder::value(const FilterIndex &in) const +{ + assert(valid() && "invalid permutation"); + + TensorIndex out; + + out.resize(4); + + out.at(_perm[FilterAxis::Count]) = in.nth(); + out.at(_perm[FilterAxis::Depth]) = in.channel(); + out.at(_perm[FilterAxis::Height]) = in.row(); + out.at(_perm[FilterAxis::Width]) = in.column(); + + return out; +} + +bool PermutingEncoder::valid(void) const { return ::valid(_perm); } + } // namespace loco diff --git a/contrib/loco/src/IR/PermutingCodec.test.cpp b/contrib/loco/src/IR/PermutingCodec.test.cpp index 3dffdd7..83b133a 100644 --- a/contrib/loco/src/IR/PermutingCodec.test.cpp +++ b/contrib/loco/src/IR/PermutingCodec.test.cpp @@ -139,6 +139,66 @@ TEST(PermutingEncoderTest, feature) ASSERT_EQ(tensor_index.at(3), 1); // CHANNEL(DEPTH) } +TEST(PermutingEncoderTest, filter) +{ + PermutingEncoder enc; + + // Encoder is invalid at the beginning + ASSERT_FALSE(enc.valid()); + + // Set "invalid" mapping + enc.perm()->axis(FilterAxis::Count) = 0; + enc.perm()->axis(FilterAxis::Depth) = 6; + enc.perm()->axis(FilterAxis::Height) = 1; + enc.perm()->axis(FilterAxis::Width) = 2; + + // Encoder is still invalid + ASSERT_FALSE(enc.valid()); + + // Set another "invalid" mapping + enc.perm()->axis(FilterAxis::Depth) = 1; + + // Encoder is still invalid + ASSERT_FALSE(enc.valid()); + + // Set "valid" mapping + enc.perm()->axis(FilterAxis::Depth) = 3; + + // Encoder is now valid + ASSERT_TRUE(enc.valid()); + + TensorShape tensor_shape; + + tensor_shape.rank(4); + tensor_shape.dim(0) = 8; // COUNT + tensor_shape.dim(1) = 1; // HEIGHT + tensor_shape.dim(2) = 7; // WIDTH + tensor_shape.dim(3) = 4; // DEPTH + + // Get the corresponding filter shape + auto filter_shape = enc.shape(tensor_shape); + + ASSERT_EQ(filter_shape.count(), 8); + ASSERT_EQ(filter_shape.depth(), 4); + ASSERT_EQ(filter_shape.height(), 1); + ASSERT_EQ(filter_shape.width(), 7); + + // Let's find a source tensor index! + FilterIndex filter_index; + + filter_index.nth() = 1; + filter_index.channel() = 2; + filter_index.row() = 0; + filter_index.column() = 3; + + auto tensor_index = enc.value(filter_index); + + ASSERT_EQ(tensor_index.at(0), 1); // NTH(COUNT) + ASSERT_EQ(tensor_index.at(1), 0); // ROW(HEIGHT) + ASSERT_EQ(tensor_index.at(2), 3); // COLUMN(WIDTH) + ASSERT_EQ(tensor_index.at(3), 2); // CHANNEL(DEPTH) +} + TEST(PermutingDecoderTest, feature) { PermutingDecoder dec; -- 2.7.4