[mir2loco] Support Pooling operations (#5846)
authorПавел Ильютченко/AI Tools Lab /SRR/Engineer/삼성전자 <p.iliutchenk@samsung.com>
Fri, 26 Jul 2019 11:25:42 +0000 (14:25 +0300)
committerEfimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Fri, 26 Jul 2019 11:25:42 +0000 (20:25 +0900)
* Support Pooling operation and tests for 2 ops

Signed-off-by: Pavel Iliutchenko <p.iliutchenk@samsung.com>
compiler/mir2loco/CMakeLists.txt
compiler/mir2loco/include/mir2loco.h
compiler/mir2loco/src/mir2loco.cpp
compiler/mir2loco/src/mir2loco.test.cpp

index 9aaf723..b506559 100644 (file)
@@ -7,6 +7,7 @@ target_include_directories(mir2loco PRIVATE src)
 target_include_directories(mir2loco PUBLIC include)
 target_link_libraries(mir2loco PUBLIC mir)
 target_link_libraries(mir2loco PUBLIC loco)
+target_link_libraries(mir2loco PRIVATE stdex)
 
 nncc_find_package(GTest QUIET)
 
index aeda08c..795c424 100644 (file)
@@ -20,7 +20,7 @@
 namespace mir2loco
 {
 
-class Transformer : public mir::Visitor
+class Transformer final : public mir::Visitor
 {
 public:
   Transformer() = default;
index 07ca57c..9ae6a92 100644 (file)
 
 #include "mir2loco.h"
 
+#include "mir/ops/PoolOp.h"
+
+#include <stdex/Memory.h>
+
 namespace mir2loco
 {
-
+namespace
+{
 template <class NodeType> void setupShape(const mir::Shape &shape, NodeType *node)
 {
   node->rank(shape.rank());
@@ -28,6 +33,67 @@ template <class NodeType> void setupShape(const mir::Shape &shape, NodeType *nod
   }
 }
 
+void setupPad(const std::vector<int32_t> &padding_before, const std::vector<int32_t> &padding_after,
+              loco::Pad<2> *pad)
+{
+  if (padding_before.size() != 2 || padding_after.size() != 2)
+    throw std::runtime_error("Support only 2D paddings!");
+
+  pad->left(padding_before.at(0));
+  pad->top(padding_before.at(1));
+  pad->right(padding_after.at(0));
+  pad->bottom(padding_after.at(1));
+}
+
+void setupWindow(const mir::Shape &window_shape, loco::Window<2> *window)
+{
+  if (window_shape.rank() != 2)
+    throw std::runtime_error("Support only 2D window size!");
+
+  window->horizontal(window_shape.dim(0));
+  window->vertical(window_shape.dim(1));
+}
+
+void setupStride(const mir::Shape &stride_shape, loco::Stride<2> *stride)
+{
+  if (stride_shape.rank() != 2)
+    throw std::runtime_error("Support only 2D strides!");
+
+  stride->horizontal(stride_shape.dim(0));
+  stride->vertical(stride_shape.dim(1));
+}
+
+loco::FeatureEncode *createNHWCFeatureEncode(loco::Graph *graph)
+{
+  auto encode_node = graph->nodes()->create<loco::FeatureEncode>();
+
+  auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Feature>>();
+
+  enc->perm()->axis(loco::FeatureAxis::Count) = 0;
+  enc->perm()->axis(loco::FeatureAxis::Height) = 1;
+  enc->perm()->axis(loco::FeatureAxis::Width) = 2;
+  enc->perm()->axis(loco::FeatureAxis::Depth) = 3;
+
+  encode_node->encoder(std::move(enc));
+  return encode_node;
+}
+
+loco::FeatureDecode *createNHWCFeatureDecode(loco::Graph *graph)
+{
+  auto decode_node = graph->nodes()->create<loco::FeatureDecode>();
+
+  auto dec = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Feature>>();
+
+  dec->perm()->axis(loco::FeatureAxis::Count) = 0;
+  dec->perm()->axis(loco::FeatureAxis::Height) = 1;
+  dec->perm()->axis(loco::FeatureAxis::Width) = 2;
+  dec->perm()->axis(loco::FeatureAxis::Depth) = 3;
+
+  decode_node->decoder(std::move(dec));
+  return decode_node;
+}
+} // namespace
+
 void Transformer::visit(mir::ops::BatchNormOp &op) { throw std::runtime_error("NYI"); }
 
 void Transformer::visit(mir::ops::BiasAddOp &op) { throw std::runtime_error("NYI"); }
@@ -97,7 +163,59 @@ void Transformer::visit(mir::ops::OutputOp &op)
 
 void Transformer::visit(mir::ops::PadOp &op) { throw std::runtime_error("NYI"); }
 
-void Transformer::visit(mir::ops::PoolOp &op) { throw std::runtime_error("NYI"); }
+void Transformer::visit(mir::ops::PoolOp &op)
+{
+  loco::Node *pool_node;
+  // Get Input
+  auto loco_it = _mir2loco_map.find(op.getInput(0)->getProducer()->getNode());
+  assert(loco_it != _mir2loco_map.end()); // can't find the input
+  // FeatureEncode
+  auto encode_node = createNHWCFeatureEncode(_loco_graph.get());
+  encode_node->input(loco_it->second);
+
+  switch (op.getPoolingType())
+  {
+    case mir::ops::PoolOp::PoolingType::AVG:
+    {
+      auto avg_pool_node = _loco_graph->nodes()->create<loco::AvgPool2D>();
+      // Set Input
+      avg_pool_node->ifm(encode_node);
+      // Set convention (like tensorflow)
+      avg_pool_node->convention(loco::AvgPool2D::Convention::Valid);
+
+      setupWindow(op.getWindowShape(), avg_pool_node->window());
+      setupStride(op.getStrides(), avg_pool_node->stride());
+      setupPad(op.getPaddingBefore(), op.getPaddingAfter(), avg_pool_node->pad());
+
+      pool_node = avg_pool_node;
+      break;
+    }
+    case mir::ops::PoolOp::PoolingType::MAX:
+    {
+      auto max_pool_node = _loco_graph->nodes()->create<loco::MaxPool2D>();
+      // Set Input
+      max_pool_node->ifm(encode_node);
+
+      setupWindow(op.getWindowShape(), max_pool_node->window());
+      setupStride(op.getStrides(), max_pool_node->stride());
+      setupPad(op.getPaddingBefore(), op.getPaddingAfter(), max_pool_node->pad());
+
+      pool_node = max_pool_node;
+      break;
+    }
+    case mir::ops::PoolOp::PoolingType::MIN:
+      throw std::runtime_error("Min pooling not supported!");
+    default:
+      assert(false && "Unknown pooling type!");
+  }
+  // FeatureDecode
+  auto decode_node = createNHWCFeatureDecode(_loco_graph.get());
+  // Set Input
+  decode_node->input(pool_node);
+  // Not set Shape
+  // Add to map
+  _mir2loco_map.emplace(&op, decode_node);
+}
 
 void Transformer::visit(mir::ops::ReduceOp &op) { throw std::runtime_error("NYI"); }
 
index c6e7263..a0485d2 100644 (file)
@@ -16,6 +16,8 @@
 
 #include "mir2loco.h"
 
+#include "mir/ops/PoolOp.h"
+
 #include <gtest/gtest.h>
 
 class TestTransformer_mir2loco : public ::testing::Test
@@ -53,3 +55,90 @@ TEST_F(TestTransformer_mir2loco, Input_Output_Test)
   ASSERT_EQ(push_node->dim(2), 7);
   ASSERT_EQ(push_node->dim(3), 8);
 }
+
+TEST_F(TestTransformer_mir2loco, Avg_Pool_Test)
+{
+  mir::Graph mir_graph;
+
+  mir::Shape input_shape = mir::Shape({7, 7, 9, 9});
+  auto *input = mir_graph.create<mir::ops::InputOp>("input", input_shape);
+  auto *pool = mir_graph.create<mir::ops::PoolOp>(
+      "pool", input->getOutput(0), mir::ops::PoolOp::PoolingType::AVG, mir::Shape{2, 3},
+      mir::Shape{4, 5}, std::vector<int32_t>{5, 9}, std::vector<int32_t>{7, 4},
+      mir::ops::PoolOp::BorderType::EMPTY);
+  auto *output = mir_graph.create<mir::ops::OutputOp>("output", pool->getOutput(0));
+
+  mir2loco::Transformer transformer;
+  auto loco_graph = transformer.transform(&mir_graph);
+
+  loco::Pull *pull_node = dynamic_cast<loco::Pull *>(loco_graph->nodes()->at(0));
+  loco::FeatureEncode *encode_node =
+      dynamic_cast<loco::FeatureEncode *>(loco_graph->nodes()->at(1));
+  loco::AvgPool2D *pool_node = dynamic_cast<loco::AvgPool2D *>(loco_graph->nodes()->at(2));
+  loco::FeatureDecode *decode_node =
+      dynamic_cast<loco::FeatureDecode *>(loco_graph->nodes()->at(3));
+  loco::Push *push_node = dynamic_cast<loco::Push *>(loco_graph->nodes()->at(4));
+
+  ASSERT_NE(pull_node, nullptr);
+  ASSERT_NE(encode_node, nullptr);
+  ASSERT_NE(pool_node, nullptr);
+  ASSERT_NE(decode_node, nullptr);
+  ASSERT_NE(push_node, nullptr);
+  ASSERT_EQ(encode_node->input(), pull_node);
+  ASSERT_EQ(pool_node->ifm(), encode_node);
+  ASSERT_EQ(decode_node->input(), pool_node);
+  ASSERT_EQ(push_node->from(), decode_node);
+  // Check params
+  ASSERT_EQ(pool_node->convention(), loco::AvgPool2D::Convention::Valid);
+  ASSERT_EQ(pool_node->pad()->left(), 5);
+  ASSERT_EQ(pool_node->pad()->top(), 9);
+  ASSERT_EQ(pool_node->pad()->right(), 7);
+  ASSERT_EQ(pool_node->pad()->bottom(), 4);
+  ASSERT_EQ(pool_node->window()->horizontal(), 2);
+  ASSERT_EQ(pool_node->window()->vertical(), 3);
+  ASSERT_EQ(pool_node->stride()->horizontal(), 4);
+  ASSERT_EQ(pool_node->stride()->vertical(), 5);
+}
+
+TEST_F(TestTransformer_mir2loco, Max_Pool_Test)
+{
+  mir::Graph mir_graph;
+
+  mir::Shape input_shape = mir::Shape({7, 7, 9, 9});
+  auto *input = mir_graph.create<mir::ops::InputOp>("input", input_shape);
+  auto *pool = mir_graph.create<mir::ops::PoolOp>(
+      "pool", input->getOutput(0), mir::ops::PoolOp::PoolingType::MAX, mir::Shape{2, 3},
+      mir::Shape{4, 5}, std::vector<int32_t>{5, 9}, std::vector<int32_t>{7, 4},
+      mir::ops::PoolOp::BorderType::EMPTY);
+  auto *output = mir_graph.create<mir::ops::OutputOp>("output", pool->getOutput(0));
+
+  mir2loco::Transformer transformer;
+  auto loco_graph = transformer.transform(&mir_graph);
+
+  loco::Pull *pull_node = dynamic_cast<loco::Pull *>(loco_graph->nodes()->at(0));
+  loco::FeatureEncode *encode_node =
+      dynamic_cast<loco::FeatureEncode *>(loco_graph->nodes()->at(1));
+  loco::MaxPool2D *pool_node = dynamic_cast<loco::MaxPool2D *>(loco_graph->nodes()->at(2));
+  loco::FeatureDecode *decode_node =
+      dynamic_cast<loco::FeatureDecode *>(loco_graph->nodes()->at(3));
+  loco::Push *push_node = dynamic_cast<loco::Push *>(loco_graph->nodes()->at(4));
+
+  ASSERT_NE(pull_node, nullptr);
+  ASSERT_NE(encode_node, nullptr);
+  ASSERT_NE(pool_node, nullptr);
+  ASSERT_NE(decode_node, nullptr);
+  ASSERT_NE(push_node, nullptr);
+  ASSERT_EQ(encode_node->input(), pull_node);
+  ASSERT_EQ(pool_node->ifm(), encode_node);
+  ASSERT_EQ(decode_node->input(), pool_node);
+  ASSERT_EQ(push_node->from(), decode_node);
+  // Check params
+  ASSERT_EQ(pool_node->pad()->left(), 5);
+  ASSERT_EQ(pool_node->pad()->top(), 9);
+  ASSERT_EQ(pool_node->pad()->right(), 7);
+  ASSERT_EQ(pool_node->pad()->bottom(), 4);
+  ASSERT_EQ(pool_node->window()->horizontal(), 2);
+  ASSERT_EQ(pool_node->window()->vertical(), 3);
+  ASSERT_EQ(pool_node->stride()->horizontal(), 4);
+  ASSERT_EQ(pool_node->stride()->vertical(), 5);
+}