#include "mir/ops/ReshapeOp.h"
#include "mir/ops/SoftmaxOp.h"
#include "mir/ops/SubOp.h"
+#include "mir/ops/TransposeOp.h"
#include "mir/ShapeRange.h"
_mir2loco_map.emplace(&op, result);
}
+void Transformer::visit(mir::ops::TransposeOp &op)
+{
+ const auto &axis_order = op.getAxisOrder();
+
+ auto transpose_node = _loco_graph->nodes()->create<loco::TensorTranspose>();
+ // Set Input
+ auto loco_it = _mir2loco_map.find(op.getInput(0)->getNode());
+ assert(loco_it != _mir2loco_map.end()); // can't find the input
+ transpose_node->input(loco_it->second);
+ // Set axis order
+ transpose_node->perm()->size(axis_order.size());
+ for (size_t i = 0; i < axis_order.size(); i++)
+ transpose_node->perm()->axis(i) = axis_order[i];
+ // Not set shape
+ // Add to map
+ _mir2loco_map.emplace(&op, transpose_node);
+}
+
void Transformer::visit_fallback(mir::Operation &op) { throw std::runtime_error("NYI operation"); }
std::unique_ptr<loco::Graph> Transformer::transform(mir::Graph *mir_graph)
#include "mir/ops/ReluOp.h"
#include "mir/ops/ReshapeOp.h"
#include "mir/ops/SoftmaxOp.h"
+#include "mir/ops/TransposeOp.h"
#include <gtest/gtest.h>
ASSERT_NE(push_node, nullptr);
ASSERT_EQ(push_node->from(), decode_node);
}
+
+TEST_F(TestTransformer_mir2loco, Transpose_Test)
+{
+ mir::Graph mir_graph;
+
+ mir::Shape input_shape = mir::Shape({2, 7, 9, 5});
+ auto *input = mir_graph.create<mir::ops::InputOp>(input_shape)->getOutput(0);
+ auto *transpose =
+ mir_graph.create<mir::ops::TransposeOp>(input, std::vector<std::size_t>{3, 0, 1, 2})
+ ->getOutput(0);
+ mir_graph.create<mir::ops::OutputOp>(transpose);
+ input->setName("x");
+ transpose->setName("y");
+
+ mir2loco::Transformer transformer;
+ auto loco_graph = transformer.transform(&mir_graph);
+ // Pull
+ auto inputs = loco_graph->inputs();
+ loco::Pull *pull_node = loco::pull_node(loco_graph.get(), 0);
+ ASSERT_NE(pull_node, nullptr);
+ // Transpose
+ auto pull_uses = loco::succs(pull_node);
+ ASSERT_EQ(pull_uses.size(), 1);
+ loco::TensorTranspose *transpose_node = dynamic_cast<loco::TensorTranspose *>(*pull_uses.begin());
+ ASSERT_NE(transpose_node, nullptr);
+ ASSERT_EQ(transpose_node->input(), pull_node);
+ // Push
+ auto transpose_uses = loco::succs(transpose_node);
+ ASSERT_EQ(transpose_uses.size(), 1);
+ loco::Push *push_node = dynamic_cast<loco::Push *>(*transpose_uses.begin());
+ ASSERT_NE(push_node, nullptr);
+ ASSERT_EQ(push_node->from(), transpose_node);
+ // Axis check
+ ASSERT_EQ(transpose_node->perm()->size(), 4);
+ ASSERT_EQ(transpose_node->perm()->axis(0), 3);
+ ASSERT_EQ(transpose_node->perm()->axis(1), 0);
+ ASSERT_EQ(transpose_node->perm()->axis(2), 1);
+ ASSERT_EQ(transpose_node->perm()->axis(3), 2);
+}