From 010d69e2b74c8e3d8626b1f64c91de5935bd5eeb Mon Sep 17 00:00:00 2001 From: =?utf8?q?=D0=9F=D0=B0=D0=B2=D0=B5=D0=BB=20=D0=98=D0=BB=D1=8C=D1=8E?= =?utf8?q?=D1=82=D1=87=D0=B5=D0=BD=D0=BA=D0=BE/AI=20Tools=20Lab=20/SRR/Eng?= =?utf8?q?ineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 1 Aug 2019 17:01:45 +0300 Subject: [PATCH] Support mir::ReshapeOp transformation to loco::Reshape (#5952) * Added support to transformer of reshape operation * Added test for Reshape Signed-off-by: Pavel Iliutchenko --- compiler/mir2loco/src/mir2loco.cpp | 15 ++++++++++++++- compiler/mir2loco/src/mir2loco.test.cpp | 30 ++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/compiler/mir2loco/src/mir2loco.cpp b/compiler/mir2loco/src/mir2loco.cpp index fe357b4..b2031b7 100644 --- a/compiler/mir2loco/src/mir2loco.cpp +++ b/compiler/mir2loco/src/mir2loco.cpp @@ -19,6 +19,7 @@ #include "mir/ops/ConcatOp.h" #include "mir/ops/PoolOp.h" #include "mir/ops/ReluOp.h" +#include "mir/ops/ReshapeOp.h" #include @@ -263,7 +264,19 @@ void Transformer::visit(mir::ops::ReluOp &op) _mir2loco_map.emplace(&op, relu_node); } -void Transformer::visit(mir::ops::ReshapeOp &op) { throw std::runtime_error("NYI"); } +void Transformer::visit(mir::ops::ReshapeOp &op) +{ + auto reshape_node = _loco_graph->nodes()->create>(); + // Set Input + auto loco_it = _mir2loco_map.find(op.getInput(0)->getProducer()->getNode()); + assert(loco_it != _mir2loco_map.end()); // can't find the input + reshape_node->input(loco_it->second); + // Set Shape + auto &out_shape = op.getOutputShape(0); + setupShape(out_shape, reshape_node); + // Add to map + _mir2loco_map.emplace(&op, reshape_node); +} void Transformer::visit(mir::ops::ResizeOp &op) { throw std::runtime_error("NYI"); } diff --git a/compiler/mir2loco/src/mir2loco.test.cpp b/compiler/mir2loco/src/mir2loco.test.cpp index e2aa9cb..b274644 100644 --- a/compiler/mir2loco/src/mir2loco.test.cpp +++ b/compiler/mir2loco/src/mir2loco.test.cpp @@ -19,6 +19,7 @@ #include "mir/ops/ConcatOp.h" #include "mir/ops/PoolOp.h" #include "mir/ops/ReluOp.h" +#include "mir/ops/ReshapeOp.h" #include @@ -218,3 +219,32 @@ TEST_F(TestTransformer_mir2loco, Concat_Test) ASSERT_EQ(concat1_node->axis(), 2); ASSERT_EQ(concat2_node->axis(), 2); } + +TEST_F(TestTransformer_mir2loco, Reshape_Test) +{ + mir::Graph mir_graph; + + auto *input = mir_graph.create("input", mir::Shape{7, 8, 9, 9}); + auto *reshape = + mir_graph.create("reshape", input->getOutput(0), mir::Shape{7, 8, 81}); + auto *output = mir_graph.create("output", reshape->getOutput(0)); + + mir2loco::Transformer transformer; + auto loco_graph = transformer.transform(&mir_graph); + + loco::Pull *pull_node = dynamic_cast(loco_graph->nodes()->at(0)); + loco::Reshape *reshape_node = + dynamic_cast *>(loco_graph->nodes()->at(1)); + loco::Push *push_node = dynamic_cast(loco_graph->nodes()->at(2)); + + ASSERT_NE(pull_node, nullptr); + ASSERT_NE(reshape_node, nullptr); + ASSERT_NE(push_node, nullptr); + ASSERT_EQ(reshape_node->input(), pull_node); + ASSERT_EQ(push_node->from(), reshape_node); + // Check params + ASSERT_EQ(reshape_node->rank(), 3); + ASSERT_EQ(reshape_node->dim(0), 7); + ASSERT_EQ(reshape_node->dim(1), 8); + ASSERT_EQ(reshape_node->dim(2), 81); +} -- 2.7.4