[mir2loco] Remove Push node's shape setter (#6487)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Mon, 12 Aug 2019 17:23:51 +0000 (02:23 +0900)
committerAlexander Efimov/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Mon, 12 Aug 2019 17:23:51 +0000 (20:23 +0300)
This commit revises all the code that invokes deprecated shape setter.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
compiler/mir2loco/src/mir2loco.cpp
compiler/mir2loco/src/mir2loco.test.cpp

index b9e9f32..b1c7de5 100644 (file)
@@ -43,6 +43,13 @@ template <class NodeType> void setupShape(const mir::Shape &shape, NodeType *nod
   }
 }
 
+std::unique_ptr<loco::TensorShape> make_tensor_shape(const mir::Shape &shape)
+{
+  auto res = stdex::make_unique<loco::TensorShape>();
+  setupShape(shape, res.get());
+  return std::move(res);
+}
+
 void setupPad(const std::vector<int32_t> &padding_before, const std::vector<int32_t> &padding_after,
               loco::Pad<2> *pad)
 {
@@ -315,13 +322,13 @@ void Transformer::visit(mir::ops::OutputOp &op)
   auto loco_it = _mir2loco_map.find(op.getInput(0)->getProducer()->getNode());
   assert(loco_it != _mir2loco_map.end()); // can't find the input
   push_node->from(loco_it->second);
-  // Set Shape
-  const auto &out_shape = op.getInput(0)->getProducer()->getShape();
-  setupShape(out_shape, push_node);
   // Set graph output
   auto graph_output = _loco_graph->outputs()->create();
   graph_output->name(op.getName());
   graph_output->dtype(loco::DataType::FLOAT32); // TODO Support other types
+  // Set graph output shape
+  const auto &out_shape = op.getInput(0)->getProducer()->getShape();
+  graph_output->shape(make_tensor_shape(out_shape));
   loco::link(graph_output, push_node);
   // Add to map
   _mir2loco_map.emplace(&op, push_node);
index 5ed3e23..0c79ecc 100644 (file)
@@ -54,11 +54,17 @@ TEST_F(TestTransformer_mir2loco, Input_Output_Test)
   ASSERT_EQ(pull_node->dim(2), 7);
   ASSERT_EQ(pull_node->dim(3), 8);
 
-  ASSERT_EQ(push_node->rank(), 4);
-  ASSERT_EQ(push_node->dim(0), 5);
-  ASSERT_EQ(push_node->dim(1), 6);
-  ASSERT_EQ(push_node->dim(2), 7);
-  ASSERT_EQ(push_node->dim(3), 8);
+  ASSERT_TRUE(push_node->indexed());
+  ASSERT_EQ(push_node->index(), 0);
+
+  // Check Graph-level properties
+  ASSERT_EQ(loco_graph->outputs()->size(), 1);
+  ASSERT_NE(loco_graph->outputs()->at(0)->shape(), nullptr);
+  ASSERT_EQ(loco_graph->outputs()->at(0)->shape()->rank(), 4);
+  ASSERT_EQ(loco_graph->outputs()->at(0)->shape()->dim(0), 5);
+  ASSERT_EQ(loco_graph->outputs()->at(0)->shape()->dim(1), 6);
+  ASSERT_EQ(loco_graph->outputs()->at(0)->shape()->dim(2), 7);
+  ASSERT_EQ(loco_graph->outputs()->at(0)->shape()->dim(3), 8);
 }
 
 TEST_F(TestTransformer_mir2loco, Relu_Test)
@@ -89,11 +95,17 @@ TEST_F(TestTransformer_mir2loco, Relu_Test)
   ASSERT_EQ(pull_node->dim(2), 9);
   ASSERT_EQ(pull_node->dim(3), 9);
 
-  ASSERT_EQ(push_node->rank(), 4);
-  ASSERT_EQ(push_node->dim(0), 7);
-  ASSERT_EQ(push_node->dim(1), 7);
-  ASSERT_EQ(push_node->dim(2), 9);
-  ASSERT_EQ(push_node->dim(3), 9);
+  ASSERT_TRUE(push_node->indexed());
+  ASSERT_EQ(push_node->index(), 0);
+
+  // Check Graph-level properties
+  ASSERT_EQ(loco_graph->outputs()->size(), 1);
+  ASSERT_NE(loco_graph->outputs()->at(0)->shape(), nullptr);
+  ASSERT_EQ(loco_graph->outputs()->at(0)->shape()->rank(), 4);
+  ASSERT_EQ(loco_graph->outputs()->at(0)->shape()->dim(0), 7);
+  ASSERT_EQ(loco_graph->outputs()->at(0)->shape()->dim(1), 7);
+  ASSERT_EQ(loco_graph->outputs()->at(0)->shape()->dim(2), 9);
+  ASSERT_EQ(loco_graph->outputs()->at(0)->shape()->dim(3), 9);
 }
 
 TEST_F(TestTransformer_mir2loco, Avg_Pool_Test)