From 010d69e2b74c8e3d8626b1f64c91de5935bd5eeb Mon Sep 17 00:00:00 2001
From: =?utf8?q?=D0=9F=D0=B0=D0=B2=D0=B5=D0=BB=20=D0=98=D0=BB=D1=8C=D1=8E?=
=?utf8?q?=D1=82=D1=87=D0=B5=D0=BD=D0=BA=D0=BE/AI=20Tools=20Lab=20/SRR/Eng?=
=?utf8?q?ineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?=
Date: Thu, 1 Aug 2019 17:01:45 +0300
Subject: [PATCH] Support mir::ReshapeOp transformation to loco::Reshape
(#5952)
* Added support to transformer of reshape operation
* Added test for Reshape
Signed-off-by: Pavel Iliutchenko
---
compiler/mir2loco/src/mir2loco.cpp | 15 ++++++++++++++-
compiler/mir2loco/src/mir2loco.test.cpp | 30 ++++++++++++++++++++++++++++++
2 files changed, 44 insertions(+), 1 deletion(-)
diff --git a/compiler/mir2loco/src/mir2loco.cpp b/compiler/mir2loco/src/mir2loco.cpp
index fe357b4..b2031b7 100644
--- a/compiler/mir2loco/src/mir2loco.cpp
+++ b/compiler/mir2loco/src/mir2loco.cpp
@@ -19,6 +19,7 @@
#include "mir/ops/ConcatOp.h"
#include "mir/ops/PoolOp.h"
#include "mir/ops/ReluOp.h"
+#include "mir/ops/ReshapeOp.h"
#include
@@ -263,7 +264,19 @@ void Transformer::visit(mir::ops::ReluOp &op)
_mir2loco_map.emplace(&op, relu_node);
}
-void Transformer::visit(mir::ops::ReshapeOp &op) { throw std::runtime_error("NYI"); }
+void Transformer::visit(mir::ops::ReshapeOp &op)
+{
+ auto reshape_node = _loco_graph->nodes()->create>();
+ // Set Input
+ auto loco_it = _mir2loco_map.find(op.getInput(0)->getProducer()->getNode());
+ assert(loco_it != _mir2loco_map.end()); // can't find the input
+ reshape_node->input(loco_it->second);
+ // Set Shape
+ auto &out_shape = op.getOutputShape(0);
+ setupShape(out_shape, reshape_node);
+ // Add to map
+ _mir2loco_map.emplace(&op, reshape_node);
+}
void Transformer::visit(mir::ops::ResizeOp &op) { throw std::runtime_error("NYI"); }
diff --git a/compiler/mir2loco/src/mir2loco.test.cpp b/compiler/mir2loco/src/mir2loco.test.cpp
index e2aa9cb..b274644 100644
--- a/compiler/mir2loco/src/mir2loco.test.cpp
+++ b/compiler/mir2loco/src/mir2loco.test.cpp
@@ -19,6 +19,7 @@
#include "mir/ops/ConcatOp.h"
#include "mir/ops/PoolOp.h"
#include "mir/ops/ReluOp.h"
+#include "mir/ops/ReshapeOp.h"
#include
@@ -218,3 +219,32 @@ TEST_F(TestTransformer_mir2loco, Concat_Test)
ASSERT_EQ(concat1_node->axis(), 2);
ASSERT_EQ(concat2_node->axis(), 2);
}
+
+TEST_F(TestTransformer_mir2loco, Reshape_Test)
+{
+ mir::Graph mir_graph;
+
+ auto *input = mir_graph.create("input", mir::Shape{7, 8, 9, 9});
+ auto *reshape =
+ mir_graph.create("reshape", input->getOutput(0), mir::Shape{7, 8, 81});
+ auto *output = mir_graph.create("output", reshape->getOutput(0));
+
+ mir2loco::Transformer transformer;
+ auto loco_graph = transformer.transform(&mir_graph);
+
+ loco::Pull *pull_node = dynamic_cast(loco_graph->nodes()->at(0));
+ loco::Reshape *reshape_node =
+ dynamic_cast *>(loco_graph->nodes()->at(1));
+ loco::Push *push_node = dynamic_cast(loco_graph->nodes()->at(2));
+
+ ASSERT_NE(pull_node, nullptr);
+ ASSERT_NE(reshape_node, nullptr);
+ ASSERT_NE(push_node, nullptr);
+ ASSERT_EQ(reshape_node->input(), pull_node);
+ ASSERT_EQ(push_node->from(), reshape_node);
+ // Check params
+ ASSERT_EQ(reshape_node->rank(), 3);
+ ASSERT_EQ(reshape_node->dim(0), 7);
+ ASSERT_EQ(reshape_node->dim(1), 8);
+ ASSERT_EQ(reshape_node->dim(2), 81);
+}
--
2.7.4