From 68a80c937392f77cc148f1f67e2f8fe3a3f44b04 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/On-Device=20Lab=28SR=29/Staff?= =?utf8?q?=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 13 Aug 2019 02:23:51 +0900 Subject: [PATCH] [mir2loco] Remove Push node's shape setter (#6487) This commit revises all the code that invokes deprecated shape setter. Signed-off-by: Jonghyun Park --- compiler/mir2loco/src/mir2loco.cpp | 13 ++++++++++--- compiler/mir2loco/src/mir2loco.test.cpp | 32 ++++++++++++++++++++++---------- 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/compiler/mir2loco/src/mir2loco.cpp b/compiler/mir2loco/src/mir2loco.cpp index b9e9f32..b1c7de5 100644 --- a/compiler/mir2loco/src/mir2loco.cpp +++ b/compiler/mir2loco/src/mir2loco.cpp @@ -43,6 +43,13 @@ template void setupShape(const mir::Shape &shape, NodeType *nod } } +std::unique_ptr make_tensor_shape(const mir::Shape &shape) +{ + auto res = stdex::make_unique(); + setupShape(shape, res.get()); + return std::move(res); +} + void setupPad(const std::vector &padding_before, const std::vector &padding_after, loco::Pad<2> *pad) { @@ -315,13 +322,13 @@ void Transformer::visit(mir::ops::OutputOp &op) auto loco_it = _mir2loco_map.find(op.getInput(0)->getProducer()->getNode()); assert(loco_it != _mir2loco_map.end()); // can't find the input push_node->from(loco_it->second); - // Set Shape - const auto &out_shape = op.getInput(0)->getProducer()->getShape(); - setupShape(out_shape, push_node); // Set graph output auto graph_output = _loco_graph->outputs()->create(); graph_output->name(op.getName()); graph_output->dtype(loco::DataType::FLOAT32); // TODO Support other types + // Set graph output shape + const auto &out_shape = op.getInput(0)->getProducer()->getShape(); + graph_output->shape(make_tensor_shape(out_shape)); loco::link(graph_output, push_node); // Add to map _mir2loco_map.emplace(&op, push_node); diff --git a/compiler/mir2loco/src/mir2loco.test.cpp b/compiler/mir2loco/src/mir2loco.test.cpp index 5ed3e23..0c79ecc 100644 --- a/compiler/mir2loco/src/mir2loco.test.cpp +++ b/compiler/mir2loco/src/mir2loco.test.cpp @@ -54,11 +54,17 @@ TEST_F(TestTransformer_mir2loco, Input_Output_Test) ASSERT_EQ(pull_node->dim(2), 7); ASSERT_EQ(pull_node->dim(3), 8); - ASSERT_EQ(push_node->rank(), 4); - ASSERT_EQ(push_node->dim(0), 5); - ASSERT_EQ(push_node->dim(1), 6); - ASSERT_EQ(push_node->dim(2), 7); - ASSERT_EQ(push_node->dim(3), 8); + ASSERT_TRUE(push_node->indexed()); + ASSERT_EQ(push_node->index(), 0); + + // Check Graph-level properties + ASSERT_EQ(loco_graph->outputs()->size(), 1); + ASSERT_NE(loco_graph->outputs()->at(0)->shape(), nullptr); + ASSERT_EQ(loco_graph->outputs()->at(0)->shape()->rank(), 4); + ASSERT_EQ(loco_graph->outputs()->at(0)->shape()->dim(0), 5); + ASSERT_EQ(loco_graph->outputs()->at(0)->shape()->dim(1), 6); + ASSERT_EQ(loco_graph->outputs()->at(0)->shape()->dim(2), 7); + ASSERT_EQ(loco_graph->outputs()->at(0)->shape()->dim(3), 8); } TEST_F(TestTransformer_mir2loco, Relu_Test) @@ -89,11 +95,17 @@ TEST_F(TestTransformer_mir2loco, Relu_Test) ASSERT_EQ(pull_node->dim(2), 9); ASSERT_EQ(pull_node->dim(3), 9); - ASSERT_EQ(push_node->rank(), 4); - ASSERT_EQ(push_node->dim(0), 7); - ASSERT_EQ(push_node->dim(1), 7); - ASSERT_EQ(push_node->dim(2), 9); - ASSERT_EQ(push_node->dim(3), 9); + ASSERT_TRUE(push_node->indexed()); + ASSERT_EQ(push_node->index(), 0); + + // Check Graph-level properties + ASSERT_EQ(loco_graph->outputs()->size(), 1); + ASSERT_NE(loco_graph->outputs()->at(0)->shape(), nullptr); + ASSERT_EQ(loco_graph->outputs()->at(0)->shape()->rank(), 4); + ASSERT_EQ(loco_graph->outputs()->at(0)->shape()->dim(0), 7); + ASSERT_EQ(loco_graph->outputs()->at(0)->shape()->dim(1), 7); + ASSERT_EQ(loco_graph->outputs()->at(0)->shape()->dim(2), 9); + ASSERT_EQ(loco_graph->outputs()->at(0)->shape()->dim(3), 9); } TEST_F(TestTransformer_mir2loco, Avg_Pool_Test) -- 2.7.4