[loco] PemutingDecoder for DepthwiseFilter (#8331)
author박천교/On-Device Lab(SR)/Engineer/삼성전자 <ch.bahk@samsung.com>
Mon, 21 Oct 2019 05:26:24 +0000 (14:26 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Mon, 21 Oct 2019 05:26:24 +0000 (14:26 +0900)
* [loco] PemutingDecoder for DepthwiseFilter

This commit introduces PemutingDecoder for DepthwiseFilter

Signed-off-by: Cheongyo Bahk <ch.bahk@samsung.com>
* Remove unrelated

compiler/loco/include/loco/IR/PermutingCodec.h
compiler/loco/src/IR/PermutingCodec.cpp
compiler/loco/src/IR/PermutingCodec.test.cpp

index 71a2a65..16be919 100644 (file)
@@ -280,6 +280,36 @@ private:
 };
 
 /**
+ * @brief Permutation-based DepthwiseFilter-to-Tensor converter
+ */
+template <> class PermutingDecoder<Domain::DepthwiseFilter> final : public DepthwiseFilterDecoder
+{
+public:
+  PermutingDecoder() = default;
+
+public:
+  PermutingDecoder(const Permutation<Domain::DepthwiseFilter> &perm) : _perm{perm}
+  {
+    // DO NOTHING
+  }
+
+public:
+  bool valid(void) const;
+
+public:
+  TensorShape shape(const DepthwiseFilterShape &shape) const override;
+  DepthwiseFilterIndex value(const TensorIndex &index) const override;
+
+public:
+  const Permutation<Domain::DepthwiseFilter> *perm(void) const { return &_perm; }
+  Permutation<Domain::DepthwiseFilter> *perm(void) { return &_perm; }
+  void perm(const Permutation<Domain::DepthwiseFilter> &p) { _perm = p; }
+
+private:
+  Permutation<Domain::DepthwiseFilter> _perm;
+};
+
+/**
  * @brief Mapping between Matrix/Tensor Axis
  */
 template <> class Permutation<Domain::Matrix>
index 5d8156f..2857e5e 100644 (file)
@@ -456,6 +456,41 @@ TensorIndex PermutingEncoder<Domain::DepthwiseFilter>::value(const DepthwiseFilt
 
 bool PermutingEncoder<Domain::DepthwiseFilter>::valid(void) const { return ::valid(_perm); }
 
+//
+// Permuting Decoder
+//
+TensorShape PermutingDecoder<Domain::DepthwiseFilter>::shape(const DepthwiseFilterShape &in) const
+{
+  assert(valid() && "invalid permutation");
+
+  TensorShape out;
+  out.rank(4);
+
+  out.dim(_perm[DepthwiseFilterAxis::Depth]) = in.depth();
+  out.dim(_perm[DepthwiseFilterAxis::Multiplier]) = in.multiplier();
+  out.dim(_perm[DepthwiseFilterAxis::Height]) = in.height();
+  out.dim(_perm[DepthwiseFilterAxis::Width]) = in.width();
+
+  return out;
+}
+
+DepthwiseFilterIndex PermutingDecoder<Domain::DepthwiseFilter>::value(const TensorIndex &in) const
+{
+  assert(valid() && "invalid permutation");
+  assert(in.rank() == 4);
+
+  DepthwiseFilterIndex out;
+
+  out.channel() = in.at(_perm[DepthwiseFilterAxis::Depth]);
+  out.nth() = in.at(_perm[DepthwiseFilterAxis::Multiplier]);
+  out.row() = in.at(_perm[DepthwiseFilterAxis::Height]);
+  out.column() = in.at(_perm[DepthwiseFilterAxis::Width]);
+
+  return out;
+}
+
+bool PermutingDecoder<Domain::DepthwiseFilter>::valid(void) const { return ::valid(_perm); }
+
 } // namespace loco
 
 /**
index 93c6de3..4e090c3 100644 (file)
@@ -491,3 +491,63 @@ TEST(PermutingDecoderTest, filter)
   ASSERT_EQ(filter_index.row(), 2);
   ASSERT_EQ(filter_index.column(), 3);
 }
+
+TEST(PermutingDecoderTest, depthwise_filter)
+{
+  PermutingDecoder<Domain::DepthwiseFilter> dec;
+
+  // Decoder is invalid at the beginning
+  ASSERT_FALSE(dec.valid());
+
+  // Set "invalid" mapping
+  dec.perm()->axis(DepthwiseFilterAxis::Depth) = 0;
+  dec.perm()->axis(DepthwiseFilterAxis::Multiplier) = 6;
+  dec.perm()->axis(DepthwiseFilterAxis::Height) = 1;
+  dec.perm()->axis(DepthwiseFilterAxis::Width) = 2;
+
+  // Decoder is still invalid
+  ASSERT_FALSE(dec.valid());
+
+  // Set another "invalid" mapping
+  dec.perm()->axis(DepthwiseFilterAxis::Multiplier) = 1;
+
+  // Decoder is still invalid
+  ASSERT_FALSE(dec.valid());
+
+  // Set "valid" mapping
+  dec.perm()->axis(DepthwiseFilterAxis::Multiplier) = 3;
+
+  // Decoder is now valid
+  ASSERT_TRUE(dec.valid());
+
+  DepthwiseFilterShape dw_filter_shape;
+
+  dw_filter_shape.depth() = 8;
+  dw_filter_shape.multiplier() = 1;
+  dw_filter_shape.height() = 7;
+  dw_filter_shape.width() = 4;
+
+  // Get the corresponding depthwise filter shape
+  auto tensor_shape = dec.shape(dw_filter_shape);
+
+  ASSERT_EQ(tensor_shape.dim(0).value(), 8);
+  ASSERT_EQ(tensor_shape.dim(1).value(), 7);
+  ASSERT_EQ(tensor_shape.dim(2).value(), 4);
+  ASSERT_EQ(tensor_shape.dim(3).value(), 1);
+
+  // Let's find a source tensor index!
+  TensorIndex tensor_index;
+  tensor_index.resize(4);
+
+  tensor_index.at(0) = 4;
+  tensor_index.at(1) = 2;
+  tensor_index.at(2) = 1;
+  tensor_index.at(3) = 0;
+
+  auto dw_filter_index = dec.value(tensor_index);
+
+  ASSERT_EQ(dw_filter_index.channel(), 4);
+  ASSERT_EQ(dw_filter_index.nth(), 0);
+  ASSERT_EQ(dw_filter_index.row(), 2);
+  ASSERT_EQ(dw_filter_index.column(), 1);
+}