From 68dec28da86c633c730e3b5d7c9bad54129f10ed Mon Sep 17 00:00:00 2001 From: =?utf8?q?=D0=A1=D0=B5=D1=80=D0=B3=D0=B5=D0=B9=20=D0=91=D0=B0=D1=80?= =?utf8?q?=D0=B0=D0=BD=D0=BD=D0=B8=D0=BA=D0=BE=D0=B2/AI=20Tools=20Lab=20/S?= =?utf8?q?RR/Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Fri, 9 Aug 2019 12:29:45 +0300 Subject: [PATCH] [mir2loco] Switch to binary elementwise operations (#6413) Switch to new binary elementwise operations. Signed-off-by: Sergei Barannikov --- compiler/mir2loco/include/mir2loco.h | 2 +- compiler/mir2loco/src/mir2loco.cpp | 47 +++++++++++---------------------- compiler/mir2loco/src/mir2loco.test.cpp | 6 ++--- 3 files changed, 19 insertions(+), 36 deletions(-) diff --git a/compiler/mir2loco/include/mir2loco.h b/compiler/mir2loco/include/mir2loco.h index 469de9f..79ad08f 100644 --- a/compiler/mir2loco/include/mir2loco.h +++ b/compiler/mir2loco/include/mir2loco.h @@ -26,6 +26,7 @@ public: Transformer() = default; ~Transformer() = default; + void visit(mir::ops::AddOp &op) override; void visit(mir::ops::BatchNormOp &op) override; void visit(mir::ops::CappedReluOp &op) override; void visit(mir::ops::ConcatOp &op) override; @@ -34,7 +35,6 @@ public: void visit(mir::ops::DeConv2DOp &op) override; void visit(mir::ops::DepthwiseConv2DOp &op) override; void visit(mir::ops::DropoutOp &op) override; - void visit(mir::ops::ElementwiseOp &op) override; void visit(mir::ops::EluOp &op) override; void visit(mir::ops::FullyConnectedOp &op) override; void visit(mir::ops::GatherOp &op) override; diff --git a/compiler/mir2loco/src/mir2loco.cpp b/compiler/mir2loco/src/mir2loco.cpp index 57b49f5..0dede32 100644 --- a/compiler/mir2loco/src/mir2loco.cpp +++ b/compiler/mir2loco/src/mir2loco.cpp @@ -16,10 +16,10 @@ #include "mir2loco.h" +#include "mir/ops/AddOp.h" #include "mir/ops/ConcatOp.h" #include "mir/ops/ConstantOp.h" #include "mir/ops/Conv2DOp.h" -#include "mir/ops/ElementwiseOp.h" #include "mir/ops/PoolOp.h" #include "mir/ops/ReluOp.h" #include "mir/ops/ReshapeOp.h" @@ -122,6 +122,21 @@ loco::DataType ConvertDataType(mir::DataType data_type) } } // namespace +void Transformer::visit(mir::ops::AddOp &op) +{ + // Get Input + auto lhs = _mir2loco_map.at(op.getInput(0)->getProducer()->getNode()); + auto rhs = _mir2loco_map.at(op.getInput(1)->getProducer()->getNode()); + + auto result = _loco_graph->nodes()->create(); + result->lhs(lhs); + result->rhs(rhs); + + // Not set Shape + // Add to map + _mir2loco_map.emplace(&op, result); +} + void Transformer::visit(mir::ops::BatchNormOp &op) { throw std::runtime_error("NYI"); } void Transformer::visit(mir::ops::CappedReluOp &op) { throw std::runtime_error("NYI"); } @@ -233,36 +248,6 @@ void Transformer::visit(mir::ops::DepthwiseConv2DOp &op) { throw std::runtime_er void Transformer::visit(mir::ops::DropoutOp &op) { throw std::runtime_error("NYI"); } -void Transformer::visit(mir::ops::ElementwiseOp &op) -{ - // TODO Currently, MIR supports arbitrary number of inputs (>= 2). - if (op.getNumInputs() != 2) - throw std::runtime_error("NYI"); - - // Get Input - auto lhs = _mir2loco_map.at(op.getInput(0)->getProducer()->getNode()); - auto rhs = _mir2loco_map.at(op.getInput(1)->getProducer()->getNode()); - loco::Node *result = nullptr; - switch (op.getOpType()) - { - case mir::ops::ElementwiseOp::OpType::add: - { - auto add_node = _loco_graph->nodes()->create(); - add_node->lhs(lhs); - add_node->rhs(rhs); - result = add_node; - break; - } - default: - { - throw std::runtime_error("NYI"); - } - } - // Not set Shape - // Add to map - _mir2loco_map.emplace(&op, result); -} - void Transformer::visit(mir::ops::EluOp &op) { throw std::runtime_error("NYI"); } void Transformer::visit(mir::ops::FullyConnectedOp &op) { throw std::runtime_error("NYI"); } diff --git a/compiler/mir2loco/src/mir2loco.test.cpp b/compiler/mir2loco/src/mir2loco.test.cpp index 824299f..5ed3e23 100644 --- a/compiler/mir2loco/src/mir2loco.test.cpp +++ b/compiler/mir2loco/src/mir2loco.test.cpp @@ -16,10 +16,10 @@ #include "mir2loco.h" +#include "mir/ops/AddOp.h" #include "mir/ops/ConcatOp.h" #include "mir/ops/ConstantOp.h" #include "mir/ops/Conv2DOp.h" -#include "mir/ops/ElementwiseOp.h" #include "mir/ops/PoolOp.h" #include "mir/ops/ReluOp.h" #include "mir/ops/ReshapeOp.h" @@ -288,9 +288,7 @@ TEST_F(TestTransformer_mir2loco, Add_Test) auto *input1 = mir_graph.create("input1", input_shape); auto *input2 = mir_graph.create("input2", input_shape); - auto *add = mir_graph.create( - "bias_add", std::vector{input1->getOutput(0), input2->getOutput(0)}, - mir::ops::ElementwiseOp::OpType::add); + auto *add = mir_graph.create("add", input1->getOutput(0), input2->getOutput(0)); auto *output = mir_graph.create("output", add->getOutput(0)); mir2loco::Transformer transformer; -- 2.7.4