From 557333218940ddada1f620596bfaeb21cac84fc8 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 18 Sep 2019 13:12:22 +0900 Subject: [PATCH] [exo-tflite] Introduce EltwiseBinaryConverter and refactor (#7533) This will introduce EltwiseBinaryConverter and refactor EltwiseAddConverter and EltwiseMulConverter using this Signed-off-by: SaeHie Park --- .../src/Conversion/EltwiseAddConverter.cpp | 57 +------------- .../src/Conversion/EltwiseBinaryConverter.h | 87 ++++++++++++++++++++++ .../src/Conversion/EltwiseMulConverter.cpp | 57 +------------- 3 files changed, 91 insertions(+), 110 deletions(-) create mode 100644 compiler/exo-tflite/src/Conversion/EltwiseBinaryConverter.h diff --git a/compiler/exo-tflite/src/Conversion/EltwiseAddConverter.cpp b/compiler/exo-tflite/src/Conversion/EltwiseAddConverter.cpp index d3d078e..557f479 100644 --- a/compiler/exo-tflite/src/Conversion/EltwiseAddConverter.cpp +++ b/compiler/exo-tflite/src/Conversion/EltwiseAddConverter.cpp @@ -16,67 +16,14 @@ #include "EltwiseAddConverter.h" -#include "GraphBlock.h" -#include "Check.h" - -#include "Dialect/IR/TFLNodes.h" - -#include +#include "EltwiseBinaryConverter.h" namespace exo { bool EltwiseAddConverter::convert(loco::EltwiseAdd *origin) { - if (!loco::shape_known(origin)) - { - return false; - } - - if (loco::shape_get(origin).domain() == loco::Domain::Tensor) - { - auto tfl_add = origin->graph()->nodes()->create(); - tfl_add->x(origin->lhs()); - tfl_add->y(origin->rhs()); - - loco::replace(origin).with(tfl_add); - origin->lhs(nullptr); - origin->rhs(nullptr); - - return true; - } - else if (loco::shape_get(origin).domain() == loco::Domain::Feature) - { - /* - if EltwiseAdd's domain is Feature, EltwiseAdd is replaced with - FeatureDecoder-TFLAdd-FeatureEncoder. - - Before : - A (output: feature) -- loco::EltwiseAdd --- B (input:feature) - - After : - A -- loco::FeatureDecode -- locoex::TFLAdd -- loco::FeatureEncode --- B - - loco::EltwiseAdd (dead node) - */ - auto graph = origin->graph(); - auto dec_l = make_feature_decode(origin->lhs()); - auto dec_r = make_feature_decode(origin->rhs()); - auto tfl_add = graph->nodes()->create(); - { - tfl_add->x(dec_l); - tfl_add->y(dec_r); - } - auto enc = make_feature_encode(tfl_add); - - loco::replace(origin).with(enc); - origin->lhs(nullptr); - origin->rhs(nullptr); - - return true; - } - else - EXO_THROW("Not yet supported loco::Domain"); + return EltwiseBinaryConvert(origin); } } // namespace exo diff --git a/compiler/exo-tflite/src/Conversion/EltwiseBinaryConverter.h b/compiler/exo-tflite/src/Conversion/EltwiseBinaryConverter.h new file mode 100644 index 0000000..e83dfc6 --- /dev/null +++ b/compiler/exo-tflite/src/Conversion/EltwiseBinaryConverter.h @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __CONVERSION_ELTWISEBINARY_CONVERTER_H__ +#define __CONVERSION_ELTWISEBINARY_CONVERTER_H__ + +#include "GraphBlock.h" +#include "Check.h" + +#include "Dialect/IR/TFLNodes.h" + +#include + +#include + +namespace exo +{ + +template bool EltwiseBinaryConvert(ELTWISEBIN *origin) +{ + if (!loco::shape_known(origin)) + { + return false; + } + + if (loco::shape_get(origin).domain() == loco::Domain::Tensor) + { + auto tfl_bin = origin->graph()->nodes()->template create(); + tfl_bin->x(origin->lhs()); + tfl_bin->y(origin->rhs()); + + loco::replace(origin).with(tfl_bin); + origin->lhs(nullptr); + origin->rhs(nullptr); + + return true; + } + else if (loco::shape_get(origin).domain() == loco::Domain::Feature) + { + /* + if ELTWISEBIN's domain is Feature, ELTWISEBIN is replaced with + FeatureDecoder -- TFLBIN -- FeatureEncoder. + + Before : + A (output: feature) -- loco::ELTWISEBIN -- B (input:feature) + + After : + A -- loco::FeatureDecode -- locoex::TFLBIN -- loco::FeatureEncode -- B + + loco::EltwiseBin (dead node) + */ + auto graph = origin->graph(); + auto dec_l = make_feature_decode(origin->lhs()); + auto dec_r = make_feature_decode(origin->rhs()); + auto tfl_new = graph->nodes()->template create(); + { + tfl_new->x(dec_l); + tfl_new->y(dec_r); + } + auto enc = make_feature_encode(tfl_new); + + loco::replace(origin).with(enc); + origin->lhs(nullptr); + origin->rhs(nullptr); + + return true; + } + else + EXO_THROW("Not yet supported loco::Domain"); +} + +} // namespace exo + +#endif // __CONVERSION_ELTWISEBINARY_CONVERTER_H__ diff --git a/compiler/exo-tflite/src/Conversion/EltwiseMulConverter.cpp b/compiler/exo-tflite/src/Conversion/EltwiseMulConverter.cpp index 5b48838..f7a4b82 100644 --- a/compiler/exo-tflite/src/Conversion/EltwiseMulConverter.cpp +++ b/compiler/exo-tflite/src/Conversion/EltwiseMulConverter.cpp @@ -16,67 +16,14 @@ #include "EltwiseMulConverter.h" -#include "GraphBlock.h" -#include "Check.h" - -#include "Dialect/IR/TFLNodes.h" - -#include +#include "EltwiseBinaryConverter.h" namespace exo { bool EltwiseMulConverter::convert(loco::EltwiseMul *origin) { - if (!loco::shape_known(origin)) - { - return false; - } - - if (loco::shape_get(origin).domain() == loco::Domain::Tensor) - { - auto tfl_mul = origin->graph()->nodes()->create(); - tfl_mul->x(origin->lhs()); - tfl_mul->y(origin->rhs()); - - loco::replace(origin).with(tfl_mul); - origin->lhs(nullptr); - origin->rhs(nullptr); - - return true; - } - else if (loco::shape_get(origin).domain() == loco::Domain::Feature) - { - /* - if EltwiseMul's domain is Feature, EltwiseMul is replaced with - FeatureDecoder-TFLMul-FeatureEncoder. - - Before : - A (output: feature) -- loco::EltwiseMul --- B (input:feature) - - After : - A -- loco::FeatureDecode -- locoex::TFLMul -- loco::FeatureEncode --- B - - loco::EltwiseMul (dead node) - */ - auto graph = origin->graph(); - auto dec_l = make_feature_decode(origin->lhs()); - auto dec_r = make_feature_decode(origin->rhs()); - auto tfl_mul = graph->nodes()->create(); - { - tfl_mul->x(dec_l); - tfl_mul->y(dec_r); - } - auto enc = make_feature_encode(tfl_mul); - - loco::replace(origin).with(enc); - origin->lhs(nullptr); - origin->rhs(nullptr); - - return true; - } - else - EXO_THROW("Not yet supported loco::Domain"); + return EltwiseBinaryConvert(origin); } } // namespace exo -- 2.7.4