[loco] Introducing PermutingDecoder for filter (#7801)
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Fri, 27 Sep 2019 08:22:28 +0000 (17:22 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Fri, 27 Sep 2019 08:22:28 +0000 (17:22 +0900)
* [loco] Introducing PermutingDecoder for filter

This adds PermutingDecoder for filter and related test case.

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
* move PermutingDecoder from DepthwiseFilter area to filter area

compiler/loco/include/loco/IR/PermutingCodec.h
compiler/loco/src/IR/PermutingCodec.cpp
compiler/loco/src/IR/PermutingCodec.test.cpp

index ea0ecaa..71a2a65 100644 (file)
@@ -191,6 +191,30 @@ private:
 };
 
 /**
+ * @brief Permutation-based Filter-to-Tensor converter
+ */
+template <> class PermutingDecoder<Domain::Filter> final : public FilterDecoder
+{
+public:
+  PermutingDecoder() = default;
+
+public:
+  bool valid(void) const;
+
+public:
+  TensorShape shape(const FilterShape &tensor_shape) const override;
+  FilterIndex value(const TensorIndex &index) const override;
+
+public:
+  const Permutation<Domain::Filter> *perm(void) const { return &_perm; }
+  Permutation<Domain::Filter> *perm(void) { return &_perm; }
+  void perm(const Permutation<Domain::Filter> &p) { _perm = p; }
+
+private:
+  Permutation<Domain::Filter> _perm;
+};
+
+/**
  * @brief Mapping between DepthwiseFilter/Tensor Axis
  */
 template <> class Permutation<Domain::DepthwiseFilter>
index ba76963..5d8156f 100644 (file)
@@ -303,6 +303,40 @@ TensorIndex PermutingEncoder<Domain::Filter>::value(const FilterIndex &in) const
 
 bool PermutingEncoder<Domain::Filter>::valid(void) const { return ::valid(_perm); }
 
+//
+// Permuting Decoder
+//
+TensorShape PermutingDecoder<Domain::Filter>::shape(const FilterShape &in) const
+{
+  assert(valid() && "invalid permutation");
+
+  TensorShape out;
+
+  out.rank(4);
+  out.dim(_perm[FilterAxis::Count]) = in.count();
+  out.dim(_perm[FilterAxis::Depth]) = in.depth();
+  out.dim(_perm[FilterAxis::Height]) = in.height();
+  out.dim(_perm[FilterAxis::Width]) = in.width();
+
+  return out;
+}
+
+FilterIndex PermutingDecoder<Domain::Filter>::value(const TensorIndex &in) const
+{
+  assert(valid() && "invalid permutation");
+
+  FilterIndex out;
+
+  out.nth() = in.at(_perm[FilterAxis::Count]);
+  out.channel() = in.at(_perm[FilterAxis::Depth]);
+  out.row() = in.at(_perm[FilterAxis::Height]);
+  out.column() = in.at(_perm[FilterAxis::Width]);
+
+  return out;
+}
+
+bool PermutingDecoder<Domain::Filter>::valid(void) const { return ::valid(_perm); }
+
 } // namespace loco
 
 /**
index eb296d2..93c6de3 100644 (file)
@@ -428,3 +428,66 @@ TEST(PermutingDecoderTest, feature_clone)
   EXPECT_EQ(src_perm->axis(FeatureAxis::Height), 1);
   EXPECT_EQ(dst_perm->axis(FeatureAxis::Height), 2);
 }
+
+TEST(PermutingDecoderTest, filter)
+{
+  PermutingDecoder<Domain::Filter> dec;
+
+  // Decoder is invalid at the beginning
+  ASSERT_FALSE(dec.valid());
+
+  // Set "invalid" mapping
+  dec.perm()->axis(FilterAxis::Count) = 0;
+  dec.perm()->axis(FilterAxis::Depth) = 6;
+  dec.perm()->axis(FilterAxis::Height) = 1;
+  dec.perm()->axis(FilterAxis::Width) = 2;
+
+  // Decoder is still invalid
+  ASSERT_FALSE(dec.valid());
+
+  // Set another "invalid" mapping
+  dec.perm()->axis(FilterAxis::Depth) = 1;
+
+  // Decoder is still invalid
+  ASSERT_FALSE(dec.valid());
+
+  // Set "valid" mapping
+  dec.perm()->axis(FilterAxis::Depth) = 3;
+
+  // Decoder is now valid
+  ASSERT_TRUE(dec.valid());
+
+  // Let's test with a small filter
+  FilterShape filter_shape;
+
+  filter_shape.count() = 10;
+  filter_shape.depth() = 3;
+  filter_shape.height() = 6;
+  filter_shape.width() = 8;
+
+  // Get the tensor shape corresponding to a given image
+  auto tensor_shape = dec.shape(filter_shape);
+
+  ASSERT_EQ(tensor_shape.rank(), 4);
+  ASSERT_EQ(tensor_shape.dim(0), 10); // COUNT
+  ASSERT_EQ(tensor_shape.dim(1), 6);  // HEIGHT
+  ASSERT_EQ(tensor_shape.dim(2), 8);  // WIDTH
+  ASSERT_EQ(tensor_shape.dim(3), 3);  // DEPTH
+
+  // Let's find a source filter index!
+  TensorIndex tensor_index;
+
+  tensor_index.resize(4);
+
+  tensor_index.at(0) = 0; // BATCH(COUNT)
+  tensor_index.at(3) = 1; // CHANNEL(DEPTH)
+  tensor_index.at(1) = 2; // ROW(HEIGHT)
+  tensor_index.at(2) = 3; // COLUMN(WIDTH)
+
+  auto filter_index = dec.value(tensor_index);
+
+  ASSERT_EQ(filter_index.nth(), 0);
+  ASSERT_EQ(filter_index.channel(), 1);
+  ASSERT_EQ(filter_index.row(), 2);
+  ASSERT_EQ(filter_index.column(), 3);
+}