From e5c415af4963081ddb905827135152b1899978c9 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=D0=9F=D0=B0=D0=B2=D0=B5=D0=BB=20=D0=98=D0=BB=D1=8C=D1=8E?= =?utf8?q?=D1=82=D1=87=D0=B5=D0=BD=D0=BA=D0=BE/AI=20Tools=20Lab=20/SRR/Eng?= =?utf8?q?ineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 25 Jul 2019 06:50:09 +0300 Subject: [PATCH] [mir2loco] Introduce MIR to Loco transformer (#5751) * Implemented core structure * Supported 2 ops: Input and Output * Unit test for input and output operations Signed-off-by: Pavel Iliutchenko --- compiler/mir2loco/.FORMATCHECKED | 0 compiler/mir2loco/CMakeLists.txt | 18 ++++ compiler/mir2loco/include/mir2loco.h | 68 +++++++++++++++ compiler/mir2loco/requires.cmake | 2 + compiler/mir2loco/src/mir2loco.cpp | 141 ++++++++++++++++++++++++++++++++ compiler/mir2loco/src/mir2loco.test.cpp | 55 +++++++++++++ 6 files changed, 284 insertions(+) create mode 100644 compiler/mir2loco/.FORMATCHECKED create mode 100644 compiler/mir2loco/CMakeLists.txt create mode 100644 compiler/mir2loco/include/mir2loco.h create mode 100644 compiler/mir2loco/requires.cmake create mode 100644 compiler/mir2loco/src/mir2loco.cpp create mode 100644 compiler/mir2loco/src/mir2loco.test.cpp diff --git a/compiler/mir2loco/.FORMATCHECKED b/compiler/mir2loco/.FORMATCHECKED new file mode 100644 index 0000000..e69de29 diff --git a/compiler/mir2loco/CMakeLists.txt b/compiler/mir2loco/CMakeLists.txt new file mode 100644 index 0000000..9aaf723 --- /dev/null +++ b/compiler/mir2loco/CMakeLists.txt @@ -0,0 +1,18 @@ +file(GLOB_RECURSE SOURCES "src/*.cpp") +file(GLOB_RECURSE TESTS "src/*.test.cpp") +list(REMOVE_ITEM SOURCES ${TESTS}) + +add_library(mir2loco STATIC ${SOURCES}) +target_include_directories(mir2loco PRIVATE src) +target_include_directories(mir2loco PUBLIC include) +target_link_libraries(mir2loco PUBLIC mir) +target_link_libraries(mir2loco PUBLIC loco) + +nncc_find_package(GTest QUIET) + +if(NOT GTest_FOUND) + return() +endif(NOT GTest_FOUND) + +GTest_AddTest(mir2loco_test ${TESTS}) +target_link_libraries(mir2loco_test mir2loco) diff --git a/compiler/mir2loco/include/mir2loco.h b/compiler/mir2loco/include/mir2loco.h new file mode 100644 index 0000000..aeda08c --- /dev/null +++ b/compiler/mir2loco/include/mir2loco.h @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mir/Graph.h" +#include "loco.h" + +namespace mir2loco +{ + +class Transformer : public mir::Visitor +{ +public: + Transformer() = default; + ~Transformer() = default; + + void visit(mir::ops::BatchNormOp &op) override; + void visit(mir::ops::BiasAddOp &op) override; + void visit(mir::ops::CappedReluOp &op) override; + void visit(mir::ops::ConcatOp &op) override; + void visit(mir::ops::ConstantOp &op) override; + void visit(mir::ops::Conv2DOp &op) override; + void visit(mir::ops::DeConv2DOp &op) override; + void visit(mir::ops::DepthwiseConv2DOp &op) override; + void visit(mir::ops::DropoutOp &op) override; + void visit(mir::ops::ElementwiseOp &op) override; + void visit(mir::ops::EluOp &op) override; + void visit(mir::ops::FullyConnectedOp &op) override; + void visit(mir::ops::GatherOp &op) override; + void visit(mir::ops::GemmOp &op) override; + void visit(mir::ops::InputOp &op) override; + void visit(mir::ops::LeakyReluOp &op) override; + void visit(mir::ops::OutputOp &op) override; + void visit(mir::ops::PadOp &op) override; + void visit(mir::ops::PoolOp &op) override; + void visit(mir::ops::ReduceOp &op) override; + void visit(mir::ops::ReluOp &op) override; + void visit(mir::ops::ReshapeOp &op) override; + void visit(mir::ops::ResizeOp &op) override; + void visit(mir::ops::ScaleOp &op) override; + void visit(mir::ops::SigmoidOp &op) override; + void visit(mir::ops::SliceOp &op) override; + void visit(mir::ops::SoftmaxOp &op) override; + void visit(mir::ops::SqrtOp &op) override; + void visit(mir::ops::SqueezeOp &op) override; + void visit(mir::ops::TanhOp &op) override; + void visit(mir::ops::TransposeOp &op) override; + + std::unique_ptr transform(mir::Graph *mir_graph); + +private: + std::unique_ptr _loco_graph; + std::unordered_map _mir2loco_map; +}; + +} // namespace mir2loco diff --git a/compiler/mir2loco/requires.cmake b/compiler/mir2loco/requires.cmake new file mode 100644 index 0000000..3648221 --- /dev/null +++ b/compiler/mir2loco/requires.cmake @@ -0,0 +1,2 @@ +require("loco") +require("mir") diff --git a/compiler/mir2loco/src/mir2loco.cpp b/compiler/mir2loco/src/mir2loco.cpp new file mode 100644 index 0000000..07ca57c --- /dev/null +++ b/compiler/mir2loco/src/mir2loco.cpp @@ -0,0 +1,141 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mir2loco.h" + +namespace mir2loco +{ + +template void setupShape(const mir::Shape &shape, NodeType *node) +{ + node->rank(shape.rank()); + for (int32_t i = 0; i < shape.rank(); i++) + { + node->dim(i) = static_cast(shape.dim(i)); + } +} + +void Transformer::visit(mir::ops::BatchNormOp &op) { throw std::runtime_error("NYI"); } + +void Transformer::visit(mir::ops::BiasAddOp &op) { throw std::runtime_error("NYI"); } + +void Transformer::visit(mir::ops::CappedReluOp &op) { throw std::runtime_error("NYI"); } + +void Transformer::visit(mir::ops::ConcatOp &op) { throw std::runtime_error("NYI"); } + +void Transformer::visit(mir::ops::ConstantOp &op) { throw std::runtime_error("NYI"); } + +void Transformer::visit(mir::ops::Conv2DOp &op) { throw std::runtime_error("NYI"); } + +void Transformer::visit(mir::ops::DeConv2DOp &op) { throw std::runtime_error("NYI"); } + +void Transformer::visit(mir::ops::DepthwiseConv2DOp &op) { throw std::runtime_error("NYI"); } + +void Transformer::visit(mir::ops::DropoutOp &op) { throw std::runtime_error("NYI"); } + +void Transformer::visit(mir::ops::ElementwiseOp &op) { throw std::runtime_error("NYI"); } + +void Transformer::visit(mir::ops::EluOp &op) { throw std::runtime_error("NYI"); } + +void Transformer::visit(mir::ops::FullyConnectedOp &op) { throw std::runtime_error("NYI"); } + +void Transformer::visit(mir::ops::GatherOp &op) { throw std::runtime_error("NYI"); } + +void Transformer::visit(mir::ops::GemmOp &op) { throw std::runtime_error("NYI"); } + +void Transformer::visit(mir::ops::InputOp &op) +{ + auto pull_node = _loco_graph->nodes()->create(); + // Set Type + pull_node->dtype(loco::DataType::FLOAT32); // TODO Support other types + // Set Shape + auto &out_shape = op.getOutputShape(0); + setupShape(out_shape, pull_node); + // Set graph input + auto graph_input = _loco_graph->inputs()->create(); + graph_input->name(op.getName()); + graph_input->dtype(loco::DataType::FLOAT32); // TODO Support other types + graph_input->node(pull_node); + // Not set inputs + // Add to map + _mir2loco_map.emplace(&op, pull_node); +} + +void Transformer::visit(mir::ops::LeakyReluOp &op) { throw std::runtime_error("NYI"); } + +void Transformer::visit(mir::ops::OutputOp &op) +{ + auto push_node = _loco_graph->nodes()->create(); + // Set Input + auto loco_it = _mir2loco_map.find(op.getInput(0)->getProducer()->getNode()); + assert(loco_it != _mir2loco_map.end()); // can't find the input + push_node->from(loco_it->second); + // Set Shape + auto &out_shape = op.getOutputShape(0); + setupShape(out_shape, push_node); + // Set graph output + auto graph_output = _loco_graph->outputs()->create(); + graph_output->name(op.getName()); + graph_output->dtype(loco::DataType::FLOAT32); // TODO Support other types + graph_output->node(push_node); + // Add to map + _mir2loco_map.emplace(&op, push_node); +} + +void Transformer::visit(mir::ops::PadOp &op) { throw std::runtime_error("NYI"); } + +void Transformer::visit(mir::ops::PoolOp &op) { throw std::runtime_error("NYI"); } + +void Transformer::visit(mir::ops::ReduceOp &op) { throw std::runtime_error("NYI"); } + +void Transformer::visit(mir::ops::ReluOp &op) { throw std::runtime_error("NYI"); } + +void Transformer::visit(mir::ops::ReshapeOp &op) { throw std::runtime_error("NYI"); } + +void Transformer::visit(mir::ops::ResizeOp &op) { throw std::runtime_error("NYI"); } + +void Transformer::visit(mir::ops::ScaleOp &op) { throw std::runtime_error("NYI"); } + +void Transformer::visit(mir::ops::SigmoidOp &op) { throw std::runtime_error("NYI"); } + +void Transformer::visit(mir::ops::SliceOp &op) { throw std::runtime_error("NYI"); } + +void Transformer::visit(mir::ops::SoftmaxOp &op) { throw std::runtime_error("NYI"); } + +void Transformer::visit(mir::ops::SqrtOp &op) { throw std::runtime_error("NYI"); } + +void Transformer::visit(mir::ops::SqueezeOp &op) { throw std::runtime_error("NYI"); } + +void Transformer::visit(mir::ops::TanhOp &op) { throw std::runtime_error("NYI"); } + +void Transformer::visit(mir::ops::TransposeOp &op) { throw std::runtime_error("NYI"); } + +std::unique_ptr Transformer::transform(mir::Graph *mir_graph) +{ + _mir2loco_map.clear(); + _loco_graph.release(); + _loco_graph = loco::make_graph(); + + // Transform Nodes + mir_graph->accept(this); + + // validate graph + assert(loco::valid(_loco_graph.get())); + + return std::move(_loco_graph); +} + +} // namespace mir2loco diff --git a/compiler/mir2loco/src/mir2loco.test.cpp b/compiler/mir2loco/src/mir2loco.test.cpp new file mode 100644 index 0000000..c6e7263 --- /dev/null +++ b/compiler/mir2loco/src/mir2loco.test.cpp @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mir2loco.h" + +#include + +class TestTransformer_mir2loco : public ::testing::Test +{ +}; + +TEST_F(TestTransformer_mir2loco, Input_Output_Test) +{ + mir::Graph mir_graph; + + mir::Shape input_shape = mir::Shape({5, 6, 7, 8}); + auto *input = mir_graph.create("input", input_shape); + auto *output = mir_graph.create("output", input->getOutput(0)); + output->getOutput(0)->setShape(input_shape); + + mir2loco::Transformer transformer; + auto loco_graph = transformer.transform(&mir_graph); + + loco::Pull *pull_node = dynamic_cast(loco_graph->nodes()->at(0)); + loco::Push *push_node = dynamic_cast(loco_graph->nodes()->at(1)); + + ASSERT_NE(pull_node, nullptr); + ASSERT_NE(push_node, nullptr); + ASSERT_EQ(push_node->from(), pull_node); + // Shape check + ASSERT_EQ(pull_node->rank(), 4); + ASSERT_EQ(pull_node->dim(0), 5); + ASSERT_EQ(pull_node->dim(1), 6); + ASSERT_EQ(pull_node->dim(2), 7); + ASSERT_EQ(pull_node->dim(3), 8); + + ASSERT_EQ(push_node->rank(), 4); + ASSERT_EQ(push_node->dim(0), 5); + ASSERT_EQ(push_node->dim(1), 6); + ASSERT_EQ(push_node->dim(2), 7); + ASSERT_EQ(push_node->dim(3), 8); +} -- 2.7.4