From d4785f7ba934e906ad1135dcd717cbc66063e34a Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=B2=9C=EA=B5=90/On-Device=20Lab=28SR=29/Enginee?= =?utf8?q?r/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Mon, 21 Oct 2019 14:26:24 +0900 Subject: [PATCH] [loco] PemutingDecoder for DepthwiseFilter (#8331) * [loco] PemutingDecoder for DepthwiseFilter This commit introduces PemutingDecoder for DepthwiseFilter Signed-off-by: Cheongyo Bahk * Remove unrelated --- compiler/loco/include/loco/IR/PermutingCodec.h | 30 +++++++++++++ compiler/loco/src/IR/PermutingCodec.cpp | 35 +++++++++++++++ compiler/loco/src/IR/PermutingCodec.test.cpp | 60 ++++++++++++++++++++++++++ 3 files changed, 125 insertions(+) diff --git a/compiler/loco/include/loco/IR/PermutingCodec.h b/compiler/loco/include/loco/IR/PermutingCodec.h index 71a2a65..16be919 100644 --- a/compiler/loco/include/loco/IR/PermutingCodec.h +++ b/compiler/loco/include/loco/IR/PermutingCodec.h @@ -280,6 +280,36 @@ private: }; /** + * @brief Permutation-based DepthwiseFilter-to-Tensor converter + */ +template <> class PermutingDecoder final : public DepthwiseFilterDecoder +{ +public: + PermutingDecoder() = default; + +public: + PermutingDecoder(const Permutation &perm) : _perm{perm} + { + // DO NOTHING + } + +public: + bool valid(void) const; + +public: + TensorShape shape(const DepthwiseFilterShape &shape) const override; + DepthwiseFilterIndex value(const TensorIndex &index) const override; + +public: + const Permutation *perm(void) const { return &_perm; } + Permutation *perm(void) { return &_perm; } + void perm(const Permutation &p) { _perm = p; } + +private: + Permutation _perm; +}; + +/** * @brief Mapping between Matrix/Tensor Axis */ template <> class Permutation diff --git a/compiler/loco/src/IR/PermutingCodec.cpp b/compiler/loco/src/IR/PermutingCodec.cpp index 5d8156f..2857e5e 100644 --- a/compiler/loco/src/IR/PermutingCodec.cpp +++ b/compiler/loco/src/IR/PermutingCodec.cpp @@ -456,6 +456,41 @@ TensorIndex PermutingEncoder::value(const DepthwiseFilt bool PermutingEncoder::valid(void) const { return ::valid(_perm); } +// +// Permuting Decoder +// +TensorShape PermutingDecoder::shape(const DepthwiseFilterShape &in) const +{ + assert(valid() && "invalid permutation"); + + TensorShape out; + out.rank(4); + + out.dim(_perm[DepthwiseFilterAxis::Depth]) = in.depth(); + out.dim(_perm[DepthwiseFilterAxis::Multiplier]) = in.multiplier(); + out.dim(_perm[DepthwiseFilterAxis::Height]) = in.height(); + out.dim(_perm[DepthwiseFilterAxis::Width]) = in.width(); + + return out; +} + +DepthwiseFilterIndex PermutingDecoder::value(const TensorIndex &in) const +{ + assert(valid() && "invalid permutation"); + assert(in.rank() == 4); + + DepthwiseFilterIndex out; + + out.channel() = in.at(_perm[DepthwiseFilterAxis::Depth]); + out.nth() = in.at(_perm[DepthwiseFilterAxis::Multiplier]); + out.row() = in.at(_perm[DepthwiseFilterAxis::Height]); + out.column() = in.at(_perm[DepthwiseFilterAxis::Width]); + + return out; +} + +bool PermutingDecoder::valid(void) const { return ::valid(_perm); } + } // namespace loco /** diff --git a/compiler/loco/src/IR/PermutingCodec.test.cpp b/compiler/loco/src/IR/PermutingCodec.test.cpp index 93c6de3..4e090c3 100644 --- a/compiler/loco/src/IR/PermutingCodec.test.cpp +++ b/compiler/loco/src/IR/PermutingCodec.test.cpp @@ -491,3 +491,63 @@ TEST(PermutingDecoderTest, filter) ASSERT_EQ(filter_index.row(), 2); ASSERT_EQ(filter_index.column(), 3); } + +TEST(PermutingDecoderTest, depthwise_filter) +{ + PermutingDecoder dec; + + // Decoder is invalid at the beginning + ASSERT_FALSE(dec.valid()); + + // Set "invalid" mapping + dec.perm()->axis(DepthwiseFilterAxis::Depth) = 0; + dec.perm()->axis(DepthwiseFilterAxis::Multiplier) = 6; + dec.perm()->axis(DepthwiseFilterAxis::Height) = 1; + dec.perm()->axis(DepthwiseFilterAxis::Width) = 2; + + // Decoder is still invalid + ASSERT_FALSE(dec.valid()); + + // Set another "invalid" mapping + dec.perm()->axis(DepthwiseFilterAxis::Multiplier) = 1; + + // Decoder is still invalid + ASSERT_FALSE(dec.valid()); + + // Set "valid" mapping + dec.perm()->axis(DepthwiseFilterAxis::Multiplier) = 3; + + // Decoder is now valid + ASSERT_TRUE(dec.valid()); + + DepthwiseFilterShape dw_filter_shape; + + dw_filter_shape.depth() = 8; + dw_filter_shape.multiplier() = 1; + dw_filter_shape.height() = 7; + dw_filter_shape.width() = 4; + + // Get the corresponding depthwise filter shape + auto tensor_shape = dec.shape(dw_filter_shape); + + ASSERT_EQ(tensor_shape.dim(0).value(), 8); + ASSERT_EQ(tensor_shape.dim(1).value(), 7); + ASSERT_EQ(tensor_shape.dim(2).value(), 4); + ASSERT_EQ(tensor_shape.dim(3).value(), 1); + + // Let's find a source tensor index! + TensorIndex tensor_index; + tensor_index.resize(4); + + tensor_index.at(0) = 4; + tensor_index.at(1) = 2; + tensor_index.at(2) = 1; + tensor_index.at(3) = 0; + + auto dw_filter_index = dec.value(tensor_index); + + ASSERT_EQ(dw_filter_index.channel(), 4); + ASSERT_EQ(dw_filter_index.nth(), 0); + ASSERT_EQ(dw_filter_index.row(), 2); + ASSERT_EQ(dw_filter_index.column(), 1); +} -- 2.7.4