[mir_tflite] Replace BiasAdd and Scale with Elementwise equivalents (#6294)
authorСергей Баранников/AI Tools Lab /SRR/Engineer/삼성전자 <s.barannikov@samsung.com>
Tue, 6 Aug 2019 16:51:00 +0000 (19:51 +0300)
committerAlexander Efimov/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Tue, 6 Aug 2019 16:51:00 +0000 (19:51 +0300)
`BiasAdd` and `Scale` are restricted versions of equivalent Elementwise ops and are going to be removed.

Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
compiler/mir-tflite-importer/tflite_importer.cpp
compiler/mir-tflite-importer/tflite_op_creator.cpp
compiler/mir-tflite-importer/tflite_op_creator.h

index 9920e3f..b23b4ef 100644 (file)
@@ -391,7 +391,7 @@ void TfliteImporter::setIrNodeNames()
 {
   // Setting names of the nodes.
   // Note: we change the computation graph, (for example, TFLite Conv2D
-  // turns into IR Conv2D->BiasAdd->ReLU), so not all of the nodes will have names.
+  // turns into IR Conv2D->Add->ReLU), so not all of the nodes will have names.
   for (auto iter : _tensorMap)
   {
     const Tensor *tensor = (*_tensors)[iter.first];
index 1755f2c..622495b 100644 (file)
@@ -17,7 +17,6 @@
 #include "tflite_op_creator.h"
 #include "schema_generated.h"
 
-#include "mir/ops/BiasAddOp.h"
 #include "mir/ops/CappedReluOp.h"
 #include "mir/ops/ConcatOp.h"
 #include "mir/ops/ConstantOp.h"
@@ -124,9 +123,10 @@ TFLiteOpCreator::convertConv2D(const Conv2DOptions *opts,
   calculatePadding(opts->padding(), input_shape, window_shape, strides, padding_before,
                    padding_after);
 
-  auto result = createOp<ops::Conv2DOp>(input, kernel, strides, padding_before, padding_after);
-  result = createOp<ops::BiasAddOp>(result->getOutput(0), bias);
-  return {addFusedActivation(result->getOutput(0), opts->fused_activation_function())};
+  auto result =
+      createOp<ops::Conv2DOp>(input, kernel, strides, padding_before, padding_after)->getOutput(0);
+  result = createAdd(result, bias);
+  return {addFusedActivation(result, opts->fused_activation_function())};
 }
 
 void TFLiteOpCreator::checkDepthwiseConv2D(const DepthwiseConv2DOptions *opts,
@@ -159,9 +159,10 @@ TFLiteOpCreator::convertDepthwiseConv2D(const DepthwiseConv2DOptions *opts,
                    padding_after);
 
   auto result =
-      createOp<ops::DepthwiseConv2DOp>(input, kernel, strides, padding_before, padding_after);
-  result = createOp<ops::BiasAddOp>(result->getOutput(0), bias);
-  return {addFusedActivation(result->getOutput(0), opts->fused_activation_function())};
+      createOp<ops::DepthwiseConv2DOp>(input, kernel, strides, padding_before, padding_after)
+          ->getOutput(0);
+  result = createAdd(result, bias);
+  return {addFusedActivation(result, opts->fused_activation_function())};
 }
 
 void TFLiteOpCreator::checkConcatenation(const ConcatenationOptions *opts,
@@ -385,11 +386,9 @@ TFLiteOpCreator::convertMax(const std::vector<mir::Operation::Output *> &inputs)
 std::vector<mir::Operation::Output *>
 TFLiteOpCreator::convertSquaredDifference(const std::vector<mir::Operation::Output *> &inputs)
 {
-  auto result = createOp<ops::ElementwiseOp>(inputs, ops::ElementwiseOp::OpType::sub);
-  result = createOp<ops::ElementwiseOp>(
-      std::vector<mir::Operation::Output *>{result->getOutput(0), result->getOutput(0)},
-      ops::ElementwiseOp::OpType::mul);
-  return {result->getOutput(0)};
+  auto result = createOp<ops::ElementwiseOp>(inputs, ops::ElementwiseOp::OpType::sub)->getOutput(0);
+  result = createMul(result, result);
+  return {result};
 }
 
 std::vector<mir::Operation::Output *>
@@ -429,9 +428,9 @@ TFLiteOpCreator::convertFullyConnected(const ::tflite::FullyConnectedOptions *op
   const auto &weights_tensor = mir::transposeTensor<1, 0>(extractTensor(weights));
   weights = createOp<ops::ConstantOp>(weights_tensor)->getOutput(0);
 
-  auto result = createOp<ops::FullyConnectedOp>(flatten->getOutput(0), weights);
-  result = createOp<ops::BiasAddOp>(result->getOutput(0), bias);
-  return {addFusedActivation(result->getOutput(0), opts->fused_activation_function())};
+  auto result = createOp<ops::FullyConnectedOp>(flatten->getOutput(0), weights)->getOutput(0);
+  result = createAdd(result, bias);
+  return {addFusedActivation(result, opts->fused_activation_function())};
 }
 
 void TFLiteOpCreator::checkActivationType(ActivationFunctionType activation_type,
@@ -464,6 +463,22 @@ mir::Operation::Output *TFLiteOpCreator::addFusedActivation(mir::Operation::Outp
   }
 }
 
+mir::Operation::Output *TFLiteOpCreator::createAdd(mir::Operation::Output *arg1,
+                                                   mir::Operation::Output *arg2)
+{
+  std::vector<mir::Operation::Output *> inputs{arg1, arg2};
+  auto op = createOp<ops::ElementwiseOp>(inputs, ops::ElementwiseOp::OpType::add);
+  return op->getOutput(0);
+}
+
+mir::Operation::Output *TFLiteOpCreator::createMul(mir::Operation::Output *arg1,
+                                                   mir::Operation::Output *arg2)
+{
+  std::vector<mir::Operation::Output *> inputs{arg1, arg2};
+  auto op = createOp<ops::ElementwiseOp>(inputs, ops::ElementwiseOp::OpType::mul);
+  return op->getOutput(0);
+}
+
 std::vector<mir::Operation::Output *>
 TFLiteOpCreator::convertSqueeze(const ::tflite::SqueezeOptions *opts,
                                 const std::vector<mir::Operation::Output *> &inputs)
index 0632b68..7799be7 100644 (file)
@@ -183,6 +183,9 @@ private:
   mir::Operation::Output *addFusedActivation(mir::Operation::Output *input,
                                              ::tflite::ActivationFunctionType activation_type);
 
+  mir::Operation::Output *createAdd(mir::Operation::Output *arg1, mir::Operation::Output *arg2);
+  mir::Operation::Output *createMul(mir::Operation::Output *arg1, mir::Operation::Output *arg2);
+
   template <typename OpType, typename... Types> mir::Operation *createOp(Types &&... args);
 };