From e5c415af4963081ddb905827135152b1899978c9 Mon Sep 17 00:00:00 2001
From: =?utf8?q?=D0=9F=D0=B0=D0=B2=D0=B5=D0=BB=20=D0=98=D0=BB=D1=8C=D1=8E?=
=?utf8?q?=D1=82=D1=87=D0=B5=D0=BD=D0=BA=D0=BE/AI=20Tools=20Lab=20/SRR/Eng?=
=?utf8?q?ineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?=
Date: Thu, 25 Jul 2019 06:50:09 +0300
Subject: [PATCH] [mir2loco] Introduce MIR to Loco transformer (#5751)
* Implemented core structure
* Supported 2 ops: Input and Output
* Unit test for input and output operations
Signed-off-by: Pavel Iliutchenko
---
compiler/mir2loco/.FORMATCHECKED | 0
compiler/mir2loco/CMakeLists.txt | 18 ++++
compiler/mir2loco/include/mir2loco.h | 68 +++++++++++++++
compiler/mir2loco/requires.cmake | 2 +
compiler/mir2loco/src/mir2loco.cpp | 141 ++++++++++++++++++++++++++++++++
compiler/mir2loco/src/mir2loco.test.cpp | 55 +++++++++++++
6 files changed, 284 insertions(+)
create mode 100644 compiler/mir2loco/.FORMATCHECKED
create mode 100644 compiler/mir2loco/CMakeLists.txt
create mode 100644 compiler/mir2loco/include/mir2loco.h
create mode 100644 compiler/mir2loco/requires.cmake
create mode 100644 compiler/mir2loco/src/mir2loco.cpp
create mode 100644 compiler/mir2loco/src/mir2loco.test.cpp
diff --git a/compiler/mir2loco/.FORMATCHECKED b/compiler/mir2loco/.FORMATCHECKED
new file mode 100644
index 0000000..e69de29
diff --git a/compiler/mir2loco/CMakeLists.txt b/compiler/mir2loco/CMakeLists.txt
new file mode 100644
index 0000000..9aaf723
--- /dev/null
+++ b/compiler/mir2loco/CMakeLists.txt
@@ -0,0 +1,18 @@
+file(GLOB_RECURSE SOURCES "src/*.cpp")
+file(GLOB_RECURSE TESTS "src/*.test.cpp")
+list(REMOVE_ITEM SOURCES ${TESTS})
+
+add_library(mir2loco STATIC ${SOURCES})
+target_include_directories(mir2loco PRIVATE src)
+target_include_directories(mir2loco PUBLIC include)
+target_link_libraries(mir2loco PUBLIC mir)
+target_link_libraries(mir2loco PUBLIC loco)
+
+nncc_find_package(GTest QUIET)
+
+if(NOT GTest_FOUND)
+ return()
+endif(NOT GTest_FOUND)
+
+GTest_AddTest(mir2loco_test ${TESTS})
+target_link_libraries(mir2loco_test mir2loco)
diff --git a/compiler/mir2loco/include/mir2loco.h b/compiler/mir2loco/include/mir2loco.h
new file mode 100644
index 0000000..aeda08c
--- /dev/null
+++ b/compiler/mir2loco/include/mir2loco.h
@@ -0,0 +1,68 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "mir/Graph.h"
+#include "loco.h"
+
+namespace mir2loco
+{
+
+class Transformer : public mir::Visitor
+{
+public:
+ Transformer() = default;
+ ~Transformer() = default;
+
+ void visit(mir::ops::BatchNormOp &op) override;
+ void visit(mir::ops::BiasAddOp &op) override;
+ void visit(mir::ops::CappedReluOp &op) override;
+ void visit(mir::ops::ConcatOp &op) override;
+ void visit(mir::ops::ConstantOp &op) override;
+ void visit(mir::ops::Conv2DOp &op) override;
+ void visit(mir::ops::DeConv2DOp &op) override;
+ void visit(mir::ops::DepthwiseConv2DOp &op) override;
+ void visit(mir::ops::DropoutOp &op) override;
+ void visit(mir::ops::ElementwiseOp &op) override;
+ void visit(mir::ops::EluOp &op) override;
+ void visit(mir::ops::FullyConnectedOp &op) override;
+ void visit(mir::ops::GatherOp &op) override;
+ void visit(mir::ops::GemmOp &op) override;
+ void visit(mir::ops::InputOp &op) override;
+ void visit(mir::ops::LeakyReluOp &op) override;
+ void visit(mir::ops::OutputOp &op) override;
+ void visit(mir::ops::PadOp &op) override;
+ void visit(mir::ops::PoolOp &op) override;
+ void visit(mir::ops::ReduceOp &op) override;
+ void visit(mir::ops::ReluOp &op) override;
+ void visit(mir::ops::ReshapeOp &op) override;
+ void visit(mir::ops::ResizeOp &op) override;
+ void visit(mir::ops::ScaleOp &op) override;
+ void visit(mir::ops::SigmoidOp &op) override;
+ void visit(mir::ops::SliceOp &op) override;
+ void visit(mir::ops::SoftmaxOp &op) override;
+ void visit(mir::ops::SqrtOp &op) override;
+ void visit(mir::ops::SqueezeOp &op) override;
+ void visit(mir::ops::TanhOp &op) override;
+ void visit(mir::ops::TransposeOp &op) override;
+
+ std::unique_ptr transform(mir::Graph *mir_graph);
+
+private:
+ std::unique_ptr _loco_graph;
+ std::unordered_map _mir2loco_map;
+};
+
+} // namespace mir2loco
diff --git a/compiler/mir2loco/requires.cmake b/compiler/mir2loco/requires.cmake
new file mode 100644
index 0000000..3648221
--- /dev/null
+++ b/compiler/mir2loco/requires.cmake
@@ -0,0 +1,2 @@
+require("loco")
+require("mir")
diff --git a/compiler/mir2loco/src/mir2loco.cpp b/compiler/mir2loco/src/mir2loco.cpp
new file mode 100644
index 0000000..07ca57c
--- /dev/null
+++ b/compiler/mir2loco/src/mir2loco.cpp
@@ -0,0 +1,141 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "mir2loco.h"
+
+namespace mir2loco
+{
+
+template void setupShape(const mir::Shape &shape, NodeType *node)
+{
+ node->rank(shape.rank());
+ for (int32_t i = 0; i < shape.rank(); i++)
+ {
+ node->dim(i) = static_cast(shape.dim(i));
+ }
+}
+
+void Transformer::visit(mir::ops::BatchNormOp &op) { throw std::runtime_error("NYI"); }
+
+void Transformer::visit(mir::ops::BiasAddOp &op) { throw std::runtime_error("NYI"); }
+
+void Transformer::visit(mir::ops::CappedReluOp &op) { throw std::runtime_error("NYI"); }
+
+void Transformer::visit(mir::ops::ConcatOp &op) { throw std::runtime_error("NYI"); }
+
+void Transformer::visit(mir::ops::ConstantOp &op) { throw std::runtime_error("NYI"); }
+
+void Transformer::visit(mir::ops::Conv2DOp &op) { throw std::runtime_error("NYI"); }
+
+void Transformer::visit(mir::ops::DeConv2DOp &op) { throw std::runtime_error("NYI"); }
+
+void Transformer::visit(mir::ops::DepthwiseConv2DOp &op) { throw std::runtime_error("NYI"); }
+
+void Transformer::visit(mir::ops::DropoutOp &op) { throw std::runtime_error("NYI"); }
+
+void Transformer::visit(mir::ops::ElementwiseOp &op) { throw std::runtime_error("NYI"); }
+
+void Transformer::visit(mir::ops::EluOp &op) { throw std::runtime_error("NYI"); }
+
+void Transformer::visit(mir::ops::FullyConnectedOp &op) { throw std::runtime_error("NYI"); }
+
+void Transformer::visit(mir::ops::GatherOp &op) { throw std::runtime_error("NYI"); }
+
+void Transformer::visit(mir::ops::GemmOp &op) { throw std::runtime_error("NYI"); }
+
+void Transformer::visit(mir::ops::InputOp &op)
+{
+ auto pull_node = _loco_graph->nodes()->create();
+ // Set Type
+ pull_node->dtype(loco::DataType::FLOAT32); // TODO Support other types
+ // Set Shape
+ auto &out_shape = op.getOutputShape(0);
+ setupShape(out_shape, pull_node);
+ // Set graph input
+ auto graph_input = _loco_graph->inputs()->create();
+ graph_input->name(op.getName());
+ graph_input->dtype(loco::DataType::FLOAT32); // TODO Support other types
+ graph_input->node(pull_node);
+ // Not set inputs
+ // Add to map
+ _mir2loco_map.emplace(&op, pull_node);
+}
+
+void Transformer::visit(mir::ops::LeakyReluOp &op) { throw std::runtime_error("NYI"); }
+
+void Transformer::visit(mir::ops::OutputOp &op)
+{
+ auto push_node = _loco_graph->nodes()->create();
+ // Set Input
+ auto loco_it = _mir2loco_map.find(op.getInput(0)->getProducer()->getNode());
+ assert(loco_it != _mir2loco_map.end()); // can't find the input
+ push_node->from(loco_it->second);
+ // Set Shape
+ auto &out_shape = op.getOutputShape(0);
+ setupShape(out_shape, push_node);
+ // Set graph output
+ auto graph_output = _loco_graph->outputs()->create();
+ graph_output->name(op.getName());
+ graph_output->dtype(loco::DataType::FLOAT32); // TODO Support other types
+ graph_output->node(push_node);
+ // Add to map
+ _mir2loco_map.emplace(&op, push_node);
+}
+
+void Transformer::visit(mir::ops::PadOp &op) { throw std::runtime_error("NYI"); }
+
+void Transformer::visit(mir::ops::PoolOp &op) { throw std::runtime_error("NYI"); }
+
+void Transformer::visit(mir::ops::ReduceOp &op) { throw std::runtime_error("NYI"); }
+
+void Transformer::visit(mir::ops::ReluOp &op) { throw std::runtime_error("NYI"); }
+
+void Transformer::visit(mir::ops::ReshapeOp &op) { throw std::runtime_error("NYI"); }
+
+void Transformer::visit(mir::ops::ResizeOp &op) { throw std::runtime_error("NYI"); }
+
+void Transformer::visit(mir::ops::ScaleOp &op) { throw std::runtime_error("NYI"); }
+
+void Transformer::visit(mir::ops::SigmoidOp &op) { throw std::runtime_error("NYI"); }
+
+void Transformer::visit(mir::ops::SliceOp &op) { throw std::runtime_error("NYI"); }
+
+void Transformer::visit(mir::ops::SoftmaxOp &op) { throw std::runtime_error("NYI"); }
+
+void Transformer::visit(mir::ops::SqrtOp &op) { throw std::runtime_error("NYI"); }
+
+void Transformer::visit(mir::ops::SqueezeOp &op) { throw std::runtime_error("NYI"); }
+
+void Transformer::visit(mir::ops::TanhOp &op) { throw std::runtime_error("NYI"); }
+
+void Transformer::visit(mir::ops::TransposeOp &op) { throw std::runtime_error("NYI"); }
+
+std::unique_ptr Transformer::transform(mir::Graph *mir_graph)
+{
+ _mir2loco_map.clear();
+ _loco_graph.release();
+ _loco_graph = loco::make_graph();
+
+ // Transform Nodes
+ mir_graph->accept(this);
+
+ // validate graph
+ assert(loco::valid(_loco_graph.get()));
+
+ return std::move(_loco_graph);
+}
+
+} // namespace mir2loco
diff --git a/compiler/mir2loco/src/mir2loco.test.cpp b/compiler/mir2loco/src/mir2loco.test.cpp
new file mode 100644
index 0000000..c6e7263
--- /dev/null
+++ b/compiler/mir2loco/src/mir2loco.test.cpp
@@ -0,0 +1,55 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "mir2loco.h"
+
+#include
+
+class TestTransformer_mir2loco : public ::testing::Test
+{
+};
+
+TEST_F(TestTransformer_mir2loco, Input_Output_Test)
+{
+ mir::Graph mir_graph;
+
+ mir::Shape input_shape = mir::Shape({5, 6, 7, 8});
+ auto *input = mir_graph.create("input", input_shape);
+ auto *output = mir_graph.create("output", input->getOutput(0));
+ output->getOutput(0)->setShape(input_shape);
+
+ mir2loco::Transformer transformer;
+ auto loco_graph = transformer.transform(&mir_graph);
+
+ loco::Pull *pull_node = dynamic_cast(loco_graph->nodes()->at(0));
+ loco::Push *push_node = dynamic_cast(loco_graph->nodes()->at(1));
+
+ ASSERT_NE(pull_node, nullptr);
+ ASSERT_NE(push_node, nullptr);
+ ASSERT_EQ(push_node->from(), pull_node);
+ // Shape check
+ ASSERT_EQ(pull_node->rank(), 4);
+ ASSERT_EQ(pull_node->dim(0), 5);
+ ASSERT_EQ(pull_node->dim(1), 6);
+ ASSERT_EQ(pull_node->dim(2), 7);
+ ASSERT_EQ(pull_node->dim(3), 8);
+
+ ASSERT_EQ(push_node->rank(), 4);
+ ASSERT_EQ(push_node->dim(0), 5);
+ ASSERT_EQ(push_node->dim(1), 6);
+ ASSERT_EQ(push_node->dim(2), 7);
+ ASSERT_EQ(push_node->dim(3), 8);
+}
--
2.7.4