[mir2loco] Add support for AvgPool2D and MaxPool2D (#7022)
authorСергей Баранников/AI Tools Lab /SRR/Engineer/삼성전자 <s.barannikov@samsung.com>
Thu, 29 Aug 2019 21:33:27 +0000 (06:33 +0900)
committerAlexander Efimov/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Thu, 29 Aug 2019 21:33:27 +0000 (00:33 +0300)
These operations are the future replacement of the `Pool` operation.

Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
compiler/mir2loco/include/mir2loco.h
compiler/mir2loco/src/mir2loco.cpp
compiler/mir2loco/src/mir2loco.test.cpp

index bb31dcb..56099f5 100644 (file)
@@ -27,11 +27,13 @@ public:
   ~Transformer() = default;
 
   void visit(mir::ops::AddOp &op) override;
+  void visit(mir::ops::AvgPool2DOp &op) override;
   void visit(mir::ops::ConcatOp &op) override;
   void visit(mir::ops::ConstantOp &op) override;
   void visit(mir::ops::Conv2DOp &op) override;
   void visit(mir::ops::DepthwiseConv2DOp &op) override;
   void visit(mir::ops::InputOp &op) override;
+  void visit(mir::ops::MaxPool2DOp &op) override;
   void visit(mir::ops::MulOp &op) override;
   void visit(mir::ops::OutputOp &op) override;
   void visit(mir::ops::PoolOp &op) override;
index a6ca8d1..f8bc0cc 100644 (file)
 #include "mir2loco.h"
 
 #include "mir/ops/AddOp.h"
+#include "mir/ops/AvgPool2DOp.h"
 #include "mir/ops/ConcatOp.h"
 #include "mir/ops/ConstantOp.h"
 #include "mir/ops/Conv2DOp.h"
 #include "mir/ops/DepthwiseConv2DOp.h"
+#include "mir/ops/MaxPool2DOp.h"
 #include "mir/ops/MulOp.h"
 #include "mir/ops/PoolOp.h"
 #include "mir/ops/ReluOp.h"
@@ -53,12 +55,10 @@ std::unique_ptr<loco::TensorShape> make_tensor_shape(const mir::Shape &shape)
   return std::move(res);
 }
 
-void setupPad(const std::vector<int32_t> &padding_before, const std::vector<int32_t> &padding_after,
-              loco::Pad<2> *pad)
+void setupPad(const std::vector<std::int32_t> &padding_before,
+              const std::vector<std::int32_t> &padding_after, loco::Pad<2> *pad)
 {
-  if (padding_before.size() != 2 || padding_after.size() != 2)
-    throw std::runtime_error("Support only 2D paddings!");
-
+  assert(padding_before.size() == 2 && padding_after.size() == 2);
   pad->left(padding_before.at(0));
   pad->top(padding_before.at(1));
   pad->right(padding_after.at(0));
@@ -67,22 +67,32 @@ void setupPad(const std::vector<int32_t> &padding_before, const std::vector<int3
 
 void setupWindow(const mir::Shape &window_shape, loco::Window<2> *window)
 {
-  if (window_shape.rank() != 2)
-    throw std::runtime_error("Support only 2D window size!");
-
+  assert(window_shape.rank() == 2);
   window->horizontal(window_shape.dim(0));
   window->vertical(window_shape.dim(1));
 }
 
-void setupStride(const mir::Shape &stride_shape, loco::Stride<2> *stride)
+void setupWindow(const std::vector<std::int32_t> &window_size, loco::Window<2> *window)
 {
-  if (stride_shape.rank() != 2)
-    throw std::runtime_error("Support only 2D strides!");
+  assert(window_size.size() == 2);
+  window->horizontal(window_size[0]);
+  window->vertical(window_size[1]);
+}
 
+void setupStride(const mir::Shape &stride_shape, loco::Stride<2> *stride)
+{
+  assert(stride_shape.rank() == 2);
   stride->horizontal(stride_shape.dim(0));
   stride->vertical(stride_shape.dim(1));
 }
 
+void setupStride(const std::vector<std::int32_t> &strides, loco::Stride<2> *stride)
+{
+  assert(strides.size() == 2);
+  stride->horizontal(strides[0]);
+  stride->vertical(strides[1]);
+}
+
 loco::FeatureEncode *createNHWCFeatureEncode(loco::Graph *graph)
 {
   auto encode_node = graph->nodes()->create<loco::FeatureEncode>();
@@ -165,6 +175,35 @@ void Transformer::visit(mir::ops::AddOp &op)
   _mir2loco_map.emplace(&op, result);
 }
 
+void Transformer::visit(mir::ops::AvgPool2DOp &op)
+{
+  // Get Input
+  auto loco_it = _mir2loco_map.find(op.getInput(0)->getProducer()->getNode());
+  assert(loco_it != _mir2loco_map.end()); // can't find the input
+  // FeatureEncode
+  auto encode_node = createNHWCFeatureEncode(_loco_graph.get());
+  encode_node->input(loco_it->second);
+
+  auto avg_pool_node = _loco_graph->nodes()->create<loco::AvgPool2D>();
+  // Set Input
+  avg_pool_node->ifm(encode_node);
+  // Set convention (like tensorflow)
+  avg_pool_node->convention(op.getIncludePad() ? loco::AvgPool2D::Convention::Full
+                                               : loco::AvgPool2D::Convention::Valid);
+
+  setupWindow(op.getWindowSize(), avg_pool_node->window());
+  setupStride(op.getStrides(), avg_pool_node->stride());
+  setupPad(op.getPaddingBefore(), op.getPaddingAfter(), avg_pool_node->pad());
+
+  // FeatureDecode
+  auto decode_node = createNHWCFeatureDecode(_loco_graph.get());
+  // Set Input
+  decode_node->input(avg_pool_node);
+  // Not set Shape
+  // Add to map
+  _mir2loco_map.emplace(&op, decode_node);
+}
+
 void Transformer::visit(mir::ops::ConcatOp &op)
 {
   if (op.getNumInputs() < 2)
@@ -370,6 +409,32 @@ void Transformer::visit(mir::ops::InputOp &op)
   _mir2loco_map.emplace(&op, pull_node);
 }
 
+void Transformer::visit(mir::ops::MaxPool2DOp &op)
+{
+  // Get Input
+  auto loco_it = _mir2loco_map.find(op.getInput(0)->getProducer()->getNode());
+  assert(loco_it != _mir2loco_map.end()); // can't find the input
+  // FeatureEncode
+  auto encode_node = createNHWCFeatureEncode(_loco_graph.get());
+  encode_node->input(loco_it->second);
+
+  auto max_pool_node = _loco_graph->nodes()->create<loco::MaxPool2D>();
+  // Set Input
+  max_pool_node->ifm(encode_node);
+
+  setupWindow(op.getWindowSize(), max_pool_node->window());
+  setupStride(op.getStrides(), max_pool_node->stride());
+  setupPad(op.getPaddingBefore(), op.getPaddingAfter(), max_pool_node->pad());
+
+  // FeatureDecode
+  auto decode_node = createNHWCFeatureDecode(_loco_graph.get());
+  // Set Input
+  decode_node->input(max_pool_node);
+  // Not set Shape
+  // Add to map
+  _mir2loco_map.emplace(&op, decode_node);
+}
+
 void Transformer::visit(mir::ops::MulOp &op)
 {
   // Get Input
index 6725d0c..cb6af05 100644 (file)
 #include "mir2loco.h"
 
 #include "mir/ops/AddOp.h"
+#include "mir/ops/AvgPool2DOp.h"
 #include "mir/ops/ConcatOp.h"
 #include "mir/ops/ConstantOp.h"
 #include "mir/ops/Conv2DOp.h"
 #include "mir/ops/DepthwiseConv2DOp.h"
+#include "mir/ops/MaxPool2DOp.h"
 #include "mir/ops/MulOp.h"
 #include "mir/ops/PoolOp.h"
 #include "mir/ops/ReluOp.h"
@@ -121,10 +123,10 @@ TEST_F(TestTransformer_mir2loco, Avg_Pool_Test)
   mir::Shape input_shape = mir::Shape({7, 7, 9, 9});
   auto *input = mir_graph.create<mir::ops::InputOp>(input_shape)->getOutput(0);
   auto *pool = mir_graph
-                   .create<mir::ops::PoolOp>(
-                       input, mir::ops::PoolOp::PoolingType::AVG, mir::Shape{2, 3},
-                       mir::Shape{4, 5}, std::vector<std::int32_t>{5, 9},
-                       std::vector<std::int32_t>{7, 4}, mir::ops::PoolOp::BorderType::EMPTY)
+                   .create<mir::ops::AvgPool2DOp>(
+                       input, std::vector<std::int32_t>{2, 3}, std::vector<std::int32_t>{4, 5},
+                       std::vector<std::int32_t>{5, 9}, std::vector<std::int32_t>{7, 4}, true,
+                       mir::DataFormat::NHWC)
                    ->getOutput(0);
   mir_graph.create<mir::ops::OutputOp>(pool);
   input->setName("x");
@@ -151,7 +153,7 @@ TEST_F(TestTransformer_mir2loco, Avg_Pool_Test)
   ASSERT_EQ(decode_node->input(), pool_node);
   ASSERT_EQ(push_node->from(), decode_node);
   // Check params
-  ASSERT_EQ(pool_node->convention(), loco::AvgPool2D::Convention::Valid);
+  ASSERT_EQ(pool_node->convention(), loco::AvgPool2D::Convention::Full);
   ASSERT_EQ(pool_node->pad()->left(), 5);
   ASSERT_EQ(pool_node->pad()->top(), 9);
   ASSERT_EQ(pool_node->pad()->right(), 7);
@@ -169,10 +171,10 @@ TEST_F(TestTransformer_mir2loco, Max_Pool_Test)
   mir::Shape input_shape = mir::Shape({7, 7, 9, 9});
   auto *input = mir_graph.create<mir::ops::InputOp>(input_shape)->getOutput(0);
   auto *pool = mir_graph
-                   .create<mir::ops::PoolOp>(
-                       input, mir::ops::PoolOp::PoolingType::MAX, mir::Shape{2, 3},
-                       mir::Shape{4, 5}, std::vector<std::int32_t>{5, 9},
-                       std::vector<std::int32_t>{7, 4}, mir::ops::PoolOp::BorderType::EMPTY)
+                   .create<mir::ops::MaxPool2DOp>(
+                       input, std::vector<std::int32_t>{2, 3}, std::vector<std::int32_t>{4, 5},
+                       std::vector<std::int32_t>{5, 9}, std::vector<std::int32_t>{7, 4},
+                       mir::DataFormat::NHWC)
                    ->getOutput(0);
   mir_graph.create<mir::ops::OutputOp>(pool);
   input->setName("x");