From 3a17fef39d0aac6f61632222bd08ec4191c4357a Mon Sep 17 00:00:00 2001 From: =?utf8?q?=D0=A1=D0=B5=D1=80=D0=B3=D0=B5=D0=B9=20=D0=91=D0=B0=D1=80?= =?utf8?q?=D0=B0=D0=BD=D0=BD=D0=B8=D0=BA=D0=BE=D0=B2/AI=20Tools=20Lab=20/S?= =?utf8?q?RR/Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 6 Aug 2019 19:51:00 +0300 Subject: [PATCH] [mir_tflite] Replace BiasAdd and Scale with Elementwise equivalents (#6294) `BiasAdd` and `Scale` are restricted versions of equivalent Elementwise ops and are going to be removed. Signed-off-by: Sergei Barannikov --- compiler/mir-tflite-importer/tflite_importer.cpp | 2 +- compiler/mir-tflite-importer/tflite_op_creator.cpp | 45 ++++++++++++++-------- compiler/mir-tflite-importer/tflite_op_creator.h | 3 ++ 3 files changed, 34 insertions(+), 16 deletions(-) diff --git a/compiler/mir-tflite-importer/tflite_importer.cpp b/compiler/mir-tflite-importer/tflite_importer.cpp index 9920e3f..b23b4ef 100644 --- a/compiler/mir-tflite-importer/tflite_importer.cpp +++ b/compiler/mir-tflite-importer/tflite_importer.cpp @@ -391,7 +391,7 @@ void TfliteImporter::setIrNodeNames() { // Setting names of the nodes. // Note: we change the computation graph, (for example, TFLite Conv2D - // turns into IR Conv2D->BiasAdd->ReLU), so not all of the nodes will have names. + // turns into IR Conv2D->Add->ReLU), so not all of the nodes will have names. for (auto iter : _tensorMap) { const Tensor *tensor = (*_tensors)[iter.first]; diff --git a/compiler/mir-tflite-importer/tflite_op_creator.cpp b/compiler/mir-tflite-importer/tflite_op_creator.cpp index 1755f2c..622495b 100644 --- a/compiler/mir-tflite-importer/tflite_op_creator.cpp +++ b/compiler/mir-tflite-importer/tflite_op_creator.cpp @@ -17,7 +17,6 @@ #include "tflite_op_creator.h" #include "schema_generated.h" -#include "mir/ops/BiasAddOp.h" #include "mir/ops/CappedReluOp.h" #include "mir/ops/ConcatOp.h" #include "mir/ops/ConstantOp.h" @@ -124,9 +123,10 @@ TFLiteOpCreator::convertConv2D(const Conv2DOptions *opts, calculatePadding(opts->padding(), input_shape, window_shape, strides, padding_before, padding_after); - auto result = createOp(input, kernel, strides, padding_before, padding_after); - result = createOp(result->getOutput(0), bias); - return {addFusedActivation(result->getOutput(0), opts->fused_activation_function())}; + auto result = + createOp(input, kernel, strides, padding_before, padding_after)->getOutput(0); + result = createAdd(result, bias); + return {addFusedActivation(result, opts->fused_activation_function())}; } void TFLiteOpCreator::checkDepthwiseConv2D(const DepthwiseConv2DOptions *opts, @@ -159,9 +159,10 @@ TFLiteOpCreator::convertDepthwiseConv2D(const DepthwiseConv2DOptions *opts, padding_after); auto result = - createOp(input, kernel, strides, padding_before, padding_after); - result = createOp(result->getOutput(0), bias); - return {addFusedActivation(result->getOutput(0), opts->fused_activation_function())}; + createOp(input, kernel, strides, padding_before, padding_after) + ->getOutput(0); + result = createAdd(result, bias); + return {addFusedActivation(result, opts->fused_activation_function())}; } void TFLiteOpCreator::checkConcatenation(const ConcatenationOptions *opts, @@ -385,11 +386,9 @@ TFLiteOpCreator::convertMax(const std::vector &inputs) std::vector TFLiteOpCreator::convertSquaredDifference(const std::vector &inputs) { - auto result = createOp(inputs, ops::ElementwiseOp::OpType::sub); - result = createOp( - std::vector{result->getOutput(0), result->getOutput(0)}, - ops::ElementwiseOp::OpType::mul); - return {result->getOutput(0)}; + auto result = createOp(inputs, ops::ElementwiseOp::OpType::sub)->getOutput(0); + result = createMul(result, result); + return {result}; } std::vector @@ -429,9 +428,9 @@ TFLiteOpCreator::convertFullyConnected(const ::tflite::FullyConnectedOptions *op const auto &weights_tensor = mir::transposeTensor<1, 0>(extractTensor(weights)); weights = createOp(weights_tensor)->getOutput(0); - auto result = createOp(flatten->getOutput(0), weights); - result = createOp(result->getOutput(0), bias); - return {addFusedActivation(result->getOutput(0), opts->fused_activation_function())}; + auto result = createOp(flatten->getOutput(0), weights)->getOutput(0); + result = createAdd(result, bias); + return {addFusedActivation(result, opts->fused_activation_function())}; } void TFLiteOpCreator::checkActivationType(ActivationFunctionType activation_type, @@ -464,6 +463,22 @@ mir::Operation::Output *TFLiteOpCreator::addFusedActivation(mir::Operation::Outp } } +mir::Operation::Output *TFLiteOpCreator::createAdd(mir::Operation::Output *arg1, + mir::Operation::Output *arg2) +{ + std::vector inputs{arg1, arg2}; + auto op = createOp(inputs, ops::ElementwiseOp::OpType::add); + return op->getOutput(0); +} + +mir::Operation::Output *TFLiteOpCreator::createMul(mir::Operation::Output *arg1, + mir::Operation::Output *arg2) +{ + std::vector inputs{arg1, arg2}; + auto op = createOp(inputs, ops::ElementwiseOp::OpType::mul); + return op->getOutput(0); +} + std::vector TFLiteOpCreator::convertSqueeze(const ::tflite::SqueezeOptions *opts, const std::vector &inputs) diff --git a/compiler/mir-tflite-importer/tflite_op_creator.h b/compiler/mir-tflite-importer/tflite_op_creator.h index 0632b68..7799be7 100644 --- a/compiler/mir-tflite-importer/tflite_op_creator.h +++ b/compiler/mir-tflite-importer/tflite_op_creator.h @@ -183,6 +183,9 @@ private: mir::Operation::Output *addFusedActivation(mir::Operation::Output *input, ::tflite::ActivationFunctionType activation_type); + mir::Operation::Output *createAdd(mir::Operation::Output *arg1, mir::Operation::Output *arg2); + mir::Operation::Output *createMul(mir::Operation::Output *arg1, mir::Operation::Output *arg2); + template mir::Operation *createOp(Types &&... args); }; -- 2.7.4