[mir2loco] Support Conv2D operation (#5899)
authorПавел Ильютченко/AI Tools Lab /SRR/Engineer/삼성전자 <p.iliutchenk@samsung.com>
Mon, 5 Aug 2019 13:49:31 +0000 (16:49 +0300)
committerEfimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Mon, 5 Aug 2019 13:49:31 +0000 (16:49 +0300)
* Implemented `mir::Conv2DOp` transformation to `loco::Conv2D`
* Added test for this

Signed-off-by: Pavel Iliutchenko <p.iliutchenk@samsung.com>
compiler/mir2loco/src/mir2loco.cpp
compiler/mir2loco/src/mir2loco.test.cpp

index c005b30..6f70dd7 100644 (file)
@@ -19,6 +19,7 @@
 #include "mir/ops/BiasAddOp.h"
 #include "mir/ops/ConcatOp.h"
 #include "mir/ops/ConstantOp.h"
+#include "mir/ops/Conv2DOp.h"
 #include "mir/ops/PoolOp.h"
 #include "mir/ops/ReluOp.h"
 #include "mir/ops/ReshapeOp.h"
@@ -200,7 +201,52 @@ void Transformer::visit(mir::ops::ConstantOp &op)
   _mir2loco_map.emplace(&op, const_node);
 }
 
-void Transformer::visit(mir::ops::Conv2DOp &op) { throw std::runtime_error("NYI"); }
+void Transformer::visit(mir::ops::Conv2DOp &op)
+{
+  auto input = op.getInput(0)->getProducer()->getNode();
+  auto kernel = op.getInput(1)->getProducer()->getNode();
+  // Get ConstantOp
+  auto const_node = _mir2loco_map.at(kernel);
+
+  auto filter_enc = _loco_graph->nodes()->create<loco::FilterEncode>();
+  {
+    auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Filter>>();
+
+    // mir using filter convention as TF
+    // In TensorFlow, conv2d filter is a 4-D tensor of following shape:
+    // [filter_height, filter_width, in_channels, out_channels] -> HWIO (HWCN)
+    enc->perm()->axis(loco::FilterAxis::Height) = 0;
+    enc->perm()->axis(loco::FilterAxis::Width) = 1;
+    enc->perm()->axis(loco::FilterAxis::Depth) = 2;
+    enc->perm()->axis(loco::FilterAxis::Count) = 3;
+
+    filter_enc->encoder(std::move(enc));
+  }
+  // Set filter input
+  filter_enc->input(const_node);
+  // Setting up conv2d
+
+  // FeatureEncode
+  auto encode_node = createNHWCFeatureEncode(_loco_graph.get());
+  // Set Input
+  auto loco_it = _mir2loco_map.find(input);
+  assert(loco_it != _mir2loco_map.end()); // can't find the input
+  encode_node->input(loco_it->second);
+  // Conv2D
+  auto conv2d_node = _loco_graph->nodes()->create<loco::Conv2D>();
+  setupStride(op.getStrides(), conv2d_node->stride());
+  setupPad(op.getPaddingBefore(), op.getPaddingAfter(), conv2d_node->pad());
+  // Set Input
+  conv2d_node->ifm(encode_node);
+  conv2d_node->ker(filter_enc);
+  // FeatureDecode
+  auto decode_node = createNHWCFeatureDecode(_loco_graph.get());
+  // Set Input
+  decode_node->input(conv2d_node);
+  // Not set Shape
+  // Add to map
+  _mir2loco_map.emplace(&op, decode_node);
+}
 
 void Transformer::visit(mir::ops::DeConv2DOp &op) { throw std::runtime_error("NYI"); }
 
index b8d6770..077ee80 100644 (file)
@@ -19,6 +19,7 @@
 #include "mir/ops/BiasAddOp.h"
 #include "mir/ops/ConcatOp.h"
 #include "mir/ops/ConstantOp.h"
+#include "mir/ops/Conv2DOp.h"
 #include "mir/ops/PoolOp.h"
 #include "mir/ops/ReluOp.h"
 #include "mir/ops/ReshapeOp.h"
@@ -323,3 +324,53 @@ TEST_F(TestTransformer_mir2loco, Bias_Add_Test)
 
   ASSERT_EQ(bias_add_node->axis(), 3);
 }
+
+TEST_F(TestTransformer_mir2loco, Conv2D_Test)
+{
+  mir::Graph mir_graph;
+
+  mir::Shape input_shape = mir::Shape({7, 7, 9, 1});
+  auto *input = mir_graph.create<mir::ops::InputOp>("input", input_shape);
+  mir::Shape shape = mir::Shape({2, 3, 1, 1});
+  const float data[] = {5.9, 6.7, 5.32, 54.11231, 43.2444, 3.409};
+  auto mir_tensor = mir::TensorVariant(mir::DTYPE::FLOAT32, shape, (const void *)data);
+  auto *constant = mir_graph.create<mir::ops::ConstantOp>("constant", mir_tensor);
+  auto *conv = mir_graph.create<mir::ops::Conv2DOp>(
+      "conv", input->getOutput(0), constant->getOutput(0), mir::Shape{2, 3},
+      std::vector<int32_t>{5, 9}, std::vector<int32_t>{7, 4});
+  auto *output = mir_graph.create<mir::ops::OutputOp>("output", conv->getOutput(0));
+
+  mir2loco::Transformer transformer;
+  auto loco_graph = transformer.transform(&mir_graph);
+
+  loco::Pull *pull_node = dynamic_cast<loco::Pull *>(loco_graph->nodes()->at(0));
+  loco::ConstGen *const_node = dynamic_cast<loco::ConstGen *>(loco_graph->nodes()->at(1));
+  loco::FilterEncode *filter_node = dynamic_cast<loco::FilterEncode *>(loco_graph->nodes()->at(2));
+  loco::FeatureEncode *encode_node =
+      dynamic_cast<loco::FeatureEncode *>(loco_graph->nodes()->at(3));
+  loco::Conv2D *conv_node = dynamic_cast<loco::Conv2D *>(loco_graph->nodes()->at(4));
+  loco::FeatureDecode *decode_node =
+      dynamic_cast<loco::FeatureDecode *>(loco_graph->nodes()->at(5));
+  loco::Push *push_node = dynamic_cast<loco::Push *>(loco_graph->nodes()->at(6));
+
+  ASSERT_NE(pull_node, nullptr);
+  ASSERT_NE(const_node, nullptr);
+  ASSERT_NE(filter_node, nullptr);
+  ASSERT_NE(encode_node, nullptr);
+  ASSERT_NE(conv_node, nullptr);
+  ASSERT_NE(decode_node, nullptr);
+  ASSERT_NE(push_node, nullptr);
+  ASSERT_EQ(encode_node->input(), pull_node);
+  ASSERT_EQ(filter_node->input(), const_node);
+  ASSERT_EQ(conv_node->ifm(), encode_node);
+  ASSERT_EQ(conv_node->ker(), filter_node);
+  ASSERT_EQ(decode_node->input(), conv_node);
+  ASSERT_EQ(push_node->from(), decode_node);
+  // Check params
+  ASSERT_EQ(conv_node->pad()->left(), 5);
+  ASSERT_EQ(conv_node->pad()->top(), 9);
+  ASSERT_EQ(conv_node->pad()->right(), 7);
+  ASSERT_EQ(conv_node->pad()->bottom(), 4);
+  ASSERT_EQ(conv_node->stride()->horizontal(), 2);
+  ASSERT_EQ(conv_node->stride()->vertical(), 3);
+}