[mir2loco] Introduce Transpose op transformer (#8871)
authorPavel Iliutchenko/AI Tools Lab /SRR/Engineer/Samsung Electronics <p.iliutchenk@samsung.com>
Tue, 12 Nov 2019 14:53:17 +0000 (17:53 +0300)
committerAlexander Efimov/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Tue, 12 Nov 2019 14:53:17 +0000 (17:53 +0300)
* Implemented TransposeOp transformer and test for it

Signed-off-by: Pavel Iliutchenko <p.iliutchenk@samsung.com>
compiler/mir2loco/include/mir2loco.h
compiler/mir2loco/src/mir2loco.cpp
compiler/mir2loco/src/mir2loco.test.cpp

index 71a58c7..43adaef 100644 (file)
@@ -44,6 +44,7 @@ public:
   void visit(mir::ops::ReshapeOp &op) override;
   void visit(mir::ops::SoftmaxOp &op) override;
   void visit(mir::ops::SubOp &op) override;
+  void visit(mir::ops::TransposeOp &op) override;
 
   void visit_fallback(mir::Operation &op) override;
 
index f7473e6..c5c08c6 100644 (file)
@@ -31,6 +31,7 @@
 #include "mir/ops/ReshapeOp.h"
 #include "mir/ops/SoftmaxOp.h"
 #include "mir/ops/SubOp.h"
+#include "mir/ops/TransposeOp.h"
 
 #include "mir/ShapeRange.h"
 
@@ -710,6 +711,24 @@ void Transformer::visit(mir::ops::SubOp &op)
   _mir2loco_map.emplace(&op, result);
 }
 
+void Transformer::visit(mir::ops::TransposeOp &op)
+{
+  const auto &axis_order = op.getAxisOrder();
+
+  auto transpose_node = _loco_graph->nodes()->create<loco::TensorTranspose>();
+  // Set Input
+  auto loco_it = _mir2loco_map.find(op.getInput(0)->getNode());
+  assert(loco_it != _mir2loco_map.end()); // can't find the input
+  transpose_node->input(loco_it->second);
+  // Set axis order
+  transpose_node->perm()->size(axis_order.size());
+  for (size_t i = 0; i < axis_order.size(); i++)
+    transpose_node->perm()->axis(i) = axis_order[i];
+  // Not set shape
+  // Add to map
+  _mir2loco_map.emplace(&op, transpose_node);
+}
+
 void Transformer::visit_fallback(mir::Operation &op) { throw std::runtime_error("NYI operation"); }
 
 std::unique_ptr<loco::Graph> Transformer::transform(mir::Graph *mir_graph)
index 2641c2f..38d3f63 100644 (file)
@@ -29,6 +29,7 @@
 #include "mir/ops/ReluOp.h"
 #include "mir/ops/ReshapeOp.h"
 #include "mir/ops/SoftmaxOp.h"
+#include "mir/ops/TransposeOp.h"
 
 #include <gtest/gtest.h>
 
@@ -674,3 +675,42 @@ TEST_F(TestTransformer_mir2loco, FullyConnected_Test)
   ASSERT_NE(push_node, nullptr);
   ASSERT_EQ(push_node->from(), decode_node);
 }
+
+TEST_F(TestTransformer_mir2loco, Transpose_Test)
+{
+  mir::Graph mir_graph;
+
+  mir::Shape input_shape = mir::Shape({2, 7, 9, 5});
+  auto *input = mir_graph.create<mir::ops::InputOp>(input_shape)->getOutput(0);
+  auto *transpose =
+      mir_graph.create<mir::ops::TransposeOp>(input, std::vector<std::size_t>{3, 0, 1, 2})
+          ->getOutput(0);
+  mir_graph.create<mir::ops::OutputOp>(transpose);
+  input->setName("x");
+  transpose->setName("y");
+
+  mir2loco::Transformer transformer;
+  auto loco_graph = transformer.transform(&mir_graph);
+  // Pull
+  auto inputs = loco_graph->inputs();
+  loco::Pull *pull_node = loco::pull_node(loco_graph.get(), 0);
+  ASSERT_NE(pull_node, nullptr);
+  // Transpose
+  auto pull_uses = loco::succs(pull_node);
+  ASSERT_EQ(pull_uses.size(), 1);
+  loco::TensorTranspose *transpose_node = dynamic_cast<loco::TensorTranspose *>(*pull_uses.begin());
+  ASSERT_NE(transpose_node, nullptr);
+  ASSERT_EQ(transpose_node->input(), pull_node);
+  // Push
+  auto transpose_uses = loco::succs(transpose_node);
+  ASSERT_EQ(transpose_uses.size(), 1);
+  loco::Push *push_node = dynamic_cast<loco::Push *>(*transpose_uses.begin());
+  ASSERT_NE(push_node, nullptr);
+  ASSERT_EQ(push_node->from(), transpose_node);
+  // Axis check
+  ASSERT_EQ(transpose_node->perm()->size(), 4);
+  ASSERT_EQ(transpose_node->perm()->axis(0), 3);
+  ASSERT_EQ(transpose_node->perm()->axis(1), 0);
+  ASSERT_EQ(transpose_node->perm()->axis(2), 1);
+  ASSERT_EQ(transpose_node->perm()->axis(3), 2);
+}