[mir_tflite] Switch to binary elementwise operations (#6411)
author Сергей Баранников/AI Tools Lab /SRR/Engineer/삼성전자 <s.barannikov@samsung.com>
Fri, 9 Aug 2019 09:21:20 +0000 (12:21 +0300)
committer Alexander Efimov/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Fri, 9 Aug 2019 09:21:20 +0000 (12:21 +0300)
Switch to new binary elementwise operations.

Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
compiler/mir-tflite-importer/tflite_op_creator.cpp
compiler/mir-tflite-importer/tflite_op_creator.h

index 49b46ee..d3ff586 100644 (file)
 #include "tflite_op_creator.h"
 #include "schema_generated.h"
 
+#include "mir/ops/AddOp.h"
 #include "mir/ops/CappedReluOp.h"
 #include "mir/ops/ConcatOp.h"
 #include "mir/ops/ConstantOp.h"
 #include "mir/ops/Conv2DOp.h"
 #include "mir/ops/Deconv2DOp.h"
 #include "mir/ops/DepthwiseConv2DOp.h"
-#include "mir/ops/ElementwiseOp.h"
+#include "mir/ops/DivOp.h"
 #include "mir/ops/FullyConnectedOp.h"
 #include "mir/ops/LeakyReluOp.h"
+#include "mir/ops/MaxOp.h"
+#include "mir/ops/MulOp.h"
 #include "mir/ops/PadOp.h"
 #include "mir/ops/PoolOp.h"
 #include "mir/ops/ReduceOp.h"
@@ -37,6 +40,7 @@
 #include "mir/ops/SoftmaxOp.h"
 #include "mir/ops/SqrtOp.h"
 #include "mir/ops/SqueezeOp.h"
+#include "mir/ops/SubOp.h"
 #include "mir/ops/TanhOp.h"
 #include "mir/ops/TransposeOp.h"
 
@@ -119,7 +123,7 @@ TFLiteOpCreator::convertConv2D(const Conv2DOptions *opts,
 
   auto result =
       createOp<ops::Conv2DOp>(input, kernel, strides, padding_before, padding_after)->getOutput(0);
-  result = createAdd(result, bias);
+  result = createOp<ops::AddOp>(result, bias)->getOutput(0);
   return {addFusedActivation(result, opts->fused_activation_function())};
 }
 
@@ -149,7 +153,7 @@ TFLiteOpCreator::convertDepthwiseConv2D(const DepthwiseConv2DOptions *opts,
   auto result =
       createOp<ops::DepthwiseConv2DOp>(input, kernel, strides, padding_before, padding_after)
           ->getOutput(0);
-  result = createAdd(result, bias);
+  result = createOp<ops::AddOp>(result, bias)->getOutput(0);
   return {addFusedActivation(result, opts->fused_activation_function())};
 }
 
@@ -290,24 +294,27 @@ std::vector<mir::Operation::Output *>
 TFLiteOpCreator::convertAdd(const ::tflite::AddOptions *opts,
                             const std::vector<mir::Operation::Output *> &inputs)
 {
-  auto result = createOp<ops::ElementwiseOp>(inputs, ops::ElementwiseOp::OpType::add);
-  return {addFusedActivation(result->getOutput(0), opts->fused_activation_function())};
+  assert(inputs.size() == 2);
+  auto result = createOp<ops::AddOp>(inputs[0], inputs[1])->getOutput(0);
+  return {addFusedActivation(result, opts->fused_activation_function())};
 }
 
 std::vector<mir::Operation::Output *>
 TFLiteOpCreator::convertSub(const ::tflite::SubOptions *opts,
                             const std::vector<mir::Operation::Output *> &inputs)
 {
-  auto result = createOp<ops::ElementwiseOp>(inputs, ops::ElementwiseOp::OpType::sub);
-  return {addFusedActivation(result->getOutput(0), opts->fused_activation_function())};
+  assert(inputs.size() == 2);
+  auto result = createOp<ops::SubOp>(inputs[0], inputs[1])->getOutput(0);
+  return {addFusedActivation(result, opts->fused_activation_function())};
 }
 
 std::vector<mir::Operation::Output *>
 TFLiteOpCreator::convertMul(const ::tflite::MulOptions *opts,
                             const std::vector<mir::Operation::Output *> &inputs)
 {
+  assert(inputs.size() == 2);
   // Try to constant fold the operation in some cases.
-  if (inputs.size() == 2 && inputs[0]->getShape() == inputs[1]->getShape() &&
+  if (inputs[0]->getShape() == inputs[1]->getShape() &&
       opts->fused_activation_function() == ActivationFunctionType_NONE)
   {
     auto constant1_op = dynamic_cast<const ops::ConstantOp *>(inputs[0]->getNode());
@@ -336,30 +343,33 @@ TFLiteOpCreator::convertMul(const ::tflite::MulOptions *opts,
     }
   }
 
-  auto result = createOp<ops::ElementwiseOp>(inputs, ops::ElementwiseOp::OpType::mul);
-  return {addFusedActivation(result->getOutput(0), opts->fused_activation_function())};
+  auto result = createOp<ops::MulOp>(inputs[0], inputs[1])->getOutput(0);
+  return {addFusedActivation(result, opts->fused_activation_function())};
 }
 
 std::vector<mir::Operation::Output *>
 TFLiteOpCreator::convertDiv(const ::tflite::DivOptions *opts,
                             const std::vector<mir::Operation::Output *> &inputs)
 {
-  auto result = createOp<ops::ElementwiseOp>(inputs, ops::ElementwiseOp::OpType::div);
-  return {addFusedActivation(result->getOutput(0), opts->fused_activation_function())};
+  assert(inputs.size() == 2);
+  auto result = createOp<ops::DivOp>(inputs[0], inputs[1])->getOutput(0);
+  return {addFusedActivation(result, opts->fused_activation_function())};
 }
 
 std::vector<mir::Operation::Output *>
 TFLiteOpCreator::convertMax(const std::vector<mir::Operation::Output *> &inputs)
 {
-  auto result = createOp<ops::ElementwiseOp>(inputs, ops::ElementwiseOp::OpType::max);
-  return {result->getOutput(0)};
+  assert(inputs.size() == 2);
+  auto result = createOp<ops::MaxOp>(inputs[0], inputs[1])->getOutput(0);
+  return {result};
 }
 
 std::vector<mir::Operation::Output *>
 TFLiteOpCreator::convertSquaredDifference(const std::vector<mir::Operation::Output *> &inputs)
 {
-  auto result = createOp<ops::ElementwiseOp>(inputs, ops::ElementwiseOp::OpType::sub)->getOutput(0);
-  result = createMul(result, result);
+  assert(inputs.size() == 2);
+  auto result = createOp<ops::SubOp>(inputs[0], inputs[1])->getOutput(0);
+  result = createOp<ops::MulOp>(result, result)->getOutput(0);
   return {result};
 }
 
@@ -395,7 +405,7 @@ TFLiteOpCreator::convertFullyConnected(const ::tflite::FullyConnectedOptions *op
   weights = createOp<ops::ConstantOp>(weights_tensor)->getOutput(0);
 
   auto result = createOp<ops::FullyConnectedOp>(flatten->getOutput(0), weights)->getOutput(0);
-  result = createAdd(result, bias);
+  result = createOp<ops::AddOp>(result, bias)->getOutput(0);
   return {addFusedActivation(result, opts->fused_activation_function())};
 }
 
@@ -418,22 +428,6 @@ mir::Operation::Output *TFLiteOpCreator::addFusedActivation(mir::Operation::Outp
   }
 }
 
-mir::Operation::Output *TFLiteOpCreator::createAdd(mir::Operation::Output *arg1,
-                                                   mir::Operation::Output *arg2)
-{
-  std::vector<mir::Operation::Output *> inputs{arg1, arg2};
-  auto op = createOp<ops::ElementwiseOp>(inputs, ops::ElementwiseOp::OpType::add);
-  return op->getOutput(0);
-}
-
-mir::Operation::Output *TFLiteOpCreator::createMul(mir::Operation::Output *arg1,
-                                                   mir::Operation::Output *arg2)
-{
-  std::vector<mir::Operation::Output *> inputs{arg1, arg2};
-  auto op = createOp<ops::ElementwiseOp>(inputs, ops::ElementwiseOp::OpType::mul);
-  return op->getOutput(0);
-}
-
 std::vector<mir::Operation::Output *>
 TFLiteOpCreator::convertSqueeze(const ::tflite::SqueezeOptions *opts,
                                 const std::vector<mir::Operation::Output *> &inputs)
index 8998129..3d25921 100644 (file)
@@ -21,7 +21,6 @@
 
 #include "mir/ops/CommonProps.h"
 #include "mir/ops/ReduceOp.h"
-#include "mir/ops/ElementwiseOp.h"
 #include "mir/Graph.h"
 #include "mir/TensorVariant.h"
 #include "mir/Scalar.h"
@@ -159,9 +158,6 @@ private:
   mir::Operation::Output *addFusedActivation(mir::Operation::Output *input,
                                              ::tflite::ActivationFunctionType activation_type);
 
-  mir::Operation::Output *createAdd(mir::Operation::Output *arg1, mir::Operation::Output *arg2);
-  mir::Operation::Output *createMul(mir::Operation::Output *arg1, mir::Operation::Output *arg2);
-
   template <typename OpType, typename... Types> mir::Operation *createOp(Types &&... args);
 };