[mir-tflite-importer] Perform operation sanity checks while converting it (#6069)
authorIvan Vagin/AI Tools Lab /SRR/Engineer/삼성전자 <ivan.vagin@samsung.com>
Wed, 7 Aug 2019 04:09:53 +0000 (07:09 +0300)
committer오형석/On-Device Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Wed, 7 Aug 2019 04:09:53 +0000 (13:09 +0900)
* [mir-tflite-importer] Perform operation sanity checks while creating

Perform operation sanity checks while creating it

Signed-off-by: Ivan Vagin <ivan.vagin@samsung.com>
compiler/mir-tflite-importer/tflite_importer.cpp
compiler/mir-tflite-importer/tflite_importer.h
compiler/mir-tflite-importer/tflite_op_creator.cpp
compiler/mir-tflite-importer/tflite_op_creator.h

index b23b4ef..284db5a 100644 (file)
@@ -67,90 +67,63 @@ void TfliteImporter::import()
   collectUnsupportedOps();
 }
 
+static const std::set<tflite::BuiltinOperator> supportedOperators = {
+    BuiltinOperator_ADD,
+    BuiltinOperator_AVERAGE_POOL_2D,
+    BuiltinOperator_CONCATENATION,
+    BuiltinOperator_CONV_2D,
+    BuiltinOperator_DEPTHWISE_CONV_2D,
+    BuiltinOperator_DIV,
+    BuiltinOperator_FULLY_CONNECTED,
+    BuiltinOperator_LEAKY_RELU,
+    BuiltinOperator_LOGISTIC,
+    BuiltinOperator_MAX_POOL_2D,
+    BuiltinOperator_MAXIMUM,
+    BuiltinOperator_MEAN,
+    BuiltinOperator_MUL,
+    BuiltinOperator_PAD,
+    BuiltinOperator_RELU,
+    BuiltinOperator_RELU6,
+    BuiltinOperator_RESHAPE,
+    BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
+    BuiltinOperator_SHAPE,
+    BuiltinOperator_SLICE,
+    BuiltinOperator_SOFTMAX,
+    BuiltinOperator_SQRT,
+    BuiltinOperator_SQUARED_DIFFERENCE,
+    BuiltinOperator_SQUEEZE,
+    BuiltinOperator_STRIDED_SLICE,
+    BuiltinOperator_SUB,
+    BuiltinOperator_TANH,
+    BuiltinOperator_TRANSPOSE,
+    BuiltinOperator_TRANSPOSE_CONV,
+};
+
 void TfliteImporter::collectUnsupportedOps()
 {
+  std::set<std::string> errors;
   for (auto sub_graph : *(_modelPacked->subgraphs()))
     for (auto op : *(sub_graph->operators()))
-      processUnsupportedOp(op);
+    {
+      BuiltinOperator opcode = (*_opcodes)[op->opcode_index()]->builtin_code();
+      if (supportedOperators.find(opcode) == supportedOperators.end())
+      {
+        if (opcode <= BuiltinOperator_MAX)
+          errors.insert(std::string(EnumNameBuiltinOperator(opcode)) + ": unsupported operator");
+        else
+          errors.insert(std::to_string(opcode) + ": unsuppored in tflite custom opcode");
+      }
+    }
 
-  if (!_problemsOpSet.empty())
+  if (!errors.empty())
   {
     std::string msg("NNC can't load model. Detected problems:");
-    for (const auto &problemStr : _problemsOpSet)
-      msg.append("\n  * " + problemStr);
+    for (const auto &e : errors)
+      msg.append("\n  * " + e);
     throw std::runtime_error(msg);
   }
 }
 
-void TfliteImporter::processUnsupportedOp(const Operator *op)
-{
-  BuiltinOperator opcode = (*_opcodes)[op->opcode_index()]->builtin_code();
-  switch (opcode)
-  {
-    case BuiltinOperator_MAX_POOL_2D:
-    case BuiltinOperator_AVERAGE_POOL_2D:
-      _opCreator->checkPool2D(op->builtin_options_as<Pool2DOptions>(), _problemsOpSet);
-      break;
-    case BuiltinOperator_CONCATENATION:
-      _opCreator->checkConcatenation(op->builtin_options_as<ConcatenationOptions>(),
-                                     _problemsOpSet);
-      break;
-    case BuiltinOperator_CONV_2D:
-      _opCreator->checkConv2D(op->builtin_options_as<Conv2DOptions>(), _problemsOpSet);
-      break;
-    case BuiltinOperator_DEPTHWISE_CONV_2D:
-      _opCreator->checkDepthwiseConv2D(op->builtin_options_as<DepthwiseConv2DOptions>(),
-                                       _problemsOpSet);
-      break;
-    case BuiltinOperator_FULLY_CONNECTED:
-      _opCreator->checkFullyConnected(op->builtin_options_as<FullyConnectedOptions>(),
-                                      _problemsOpSet);
-      break;
-    case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
-      _opCreator->checkResizeNearestNeighbor(op->builtin_options_as<ResizeNearestNeighborOptions>(),
-                                             _problemsOpSet);
-      break;
-    case BuiltinOperator_STRIDED_SLICE:
-      _opCreator->checkStridedSlice(op->builtin_options_as<StridedSliceOptions>(), _problemsOpSet);
-      break;
-    case BuiltinOperator_SHAPE:
-      _opCreator->checkShape(op->builtin_options_as<ShapeOptions>(), _problemsOpSet);
-      break;
-    case BuiltinOperator_SOFTMAX:
-    case BuiltinOperator_SLICE:
-    case BuiltinOperator_RESHAPE:
-    case BuiltinOperator_SQUEEZE:
-    case BuiltinOperator_LOGISTIC:
-    case BuiltinOperator_SQRT:
-    case BuiltinOperator_PAD:
-    case BuiltinOperator_ADD:
-    case BuiltinOperator_SUB:
-    case BuiltinOperator_SQUARED_DIFFERENCE:
-    case BuiltinOperator_MUL:
-    case BuiltinOperator_MEAN:
-    case BuiltinOperator_MAXIMUM:
-    case BuiltinOperator_DIV:
-    case BuiltinOperator_TRANSPOSE_CONV:
-    case BuiltinOperator_TANH:
-    case BuiltinOperator_RELU:
-    case BuiltinOperator_RELU6:
-    case BuiltinOperator_TRANSPOSE:
-    case BuiltinOperator_LEAKY_RELU:
-      // No checks
-      break;
-    default:
-      if (opcode <= BuiltinOperator_MAX)
-      {
-        _problemsOpSet.insert(std::string(EnumNameBuiltinOperator(opcode)) +
-                              ": unsupported operator");
-      }
-      else
-      {
-        _problemsOpSet.insert(std::to_string(opcode) + ": unsuppored in tflite custom opcode");
-      }
-  }
-}
-
 std::unique_ptr<Graph> TfliteImporter::createIR()
 {
   walkGraphAndCreateMIR();
index 6b866c3..2df5e5f 100644 (file)
@@ -59,9 +59,6 @@ private:
   // Maps TFLite tensors indices to corresponding MIR operation outputs.
   std::map<int, mir::Operation::Output *> _tensorMap;
 
-  // set of strings describing incorrect parts of network and parts of network unsupported by NNC
-  std::set<std::string> _problemsOpSet;
-
   void import();
   std::unique_ptr<mir::Graph> createIR();
 
@@ -82,8 +79,6 @@ private:
    */
   void collectUnsupportedOps();
 
-  void processUnsupportedOp(const ::tflite::Operator *op);
-
   /**
    * @brief Mark output MIR nodes
    */
index 622495b..49b46ee 100644 (file)
@@ -97,12 +97,6 @@ static const mir::TensorVariant &extractTensor(const mir::Operation::Output *out
   return constant_op->getValue();
 }
 
-void TFLiteOpCreator::checkConv2D(const Conv2DOptions *opts,
-                                  std::set<std::string> &problems_ops_set)
-{
-  checkActivationType(opts->fused_activation_function(), problems_ops_set);
-}
-
 std::vector<mir::Operation::Output *>
 TFLiteOpCreator::convertConv2D(const Conv2DOptions *opts,
                                const std::vector<mir::Operation::Output *> &inputs)
@@ -129,12 +123,6 @@ TFLiteOpCreator::convertConv2D(const Conv2DOptions *opts,
   return {addFusedActivation(result, opts->fused_activation_function())};
 }
 
-void TFLiteOpCreator::checkDepthwiseConv2D(const DepthwiseConv2DOptions *opts,
-                                           std::set<std::string> &problems_ops_set)
-{
-  checkActivationType(opts->fused_activation_function(), problems_ops_set);
-}
-
 std::vector<mir::Operation::Output *>
 TFLiteOpCreator::convertDepthwiseConv2D(const DepthwiseConv2DOptions *opts,
                                         const std::vector<mir::Operation::Output *> &inputs)
@@ -165,12 +153,6 @@ TFLiteOpCreator::convertDepthwiseConv2D(const DepthwiseConv2DOptions *opts,
   return {addFusedActivation(result, opts->fused_activation_function())};
 }
 
-void TFLiteOpCreator::checkConcatenation(const ConcatenationOptions *opts,
-                                         std::set<std::string> &problems_ops_set)
-{
-  checkActivationType(opts->fused_activation_function(), problems_ops_set);
-}
-
 std::vector<mir::Operation::Output *>
 TFLiteOpCreator::convertConcatenation(const ::tflite::ConcatenationOptions *opts,
                                       const std::vector<mir::Operation::Output *> &inputs)
@@ -179,12 +161,6 @@ TFLiteOpCreator::convertConcatenation(const ::tflite::ConcatenationOptions *opts
   return {addFusedActivation(result->getOutput(0), opts->fused_activation_function())};
 }
 
-void TFLiteOpCreator::checkPool2D(const Pool2DOptions *opts,
-                                  std::set<std::string> &problems_ops_set)
-{
-  checkActivationType(opts->fused_activation_function(), problems_ops_set);
-}
-
 std::vector<mir::Operation::Output *>
 TFLiteOpCreator::convertMaxPool2D(const ::tflite::Pool2DOptions *opts,
                                   const std::vector<mir::Operation::Output *> &inputs)
@@ -292,17 +268,13 @@ TFLiteOpCreator::convertTransposeConv(const ::tflite::TransposeConvOptions *opts
   return {result->getOutput(0)};
 }
 
-void TFLiteOpCreator::checkResizeNearestNeighbor(const ::tflite::ResizeNearestNeighborOptions *opts,
-                                                 std::set<std::string> &problems_ops_set)
+std::vector<mir::Operation::Output *>
+TFLiteOpCreator::convertResizeNearestNeighbor(const ::tflite::ResizeNearestNeighborOptions *opts,
+                                              const std::vector<mir::Operation::Output *> &inputs)
 {
   if (opts->align_corners())
-    problems_ops_set.insert("'align_corners' is not currently supported");
-}
+    throw std::runtime_error("'align_corners' is not currently supported");
 
-std::vector<mir::Operation::Output *> TFLiteOpCreator::convertResizeNearestNeighbor(
-    const ::tflite::ResizeNearestNeighborOptions * /*opts*/,
-    const std::vector<mir::Operation::Output *> &inputs)
-{
   auto input = inputs.at(0);
   mir::Tensor<int32_t> size_tensor(extractTensor(inputs.at(1)));
 
@@ -404,12 +376,6 @@ TFLiteOpCreator::convertMean(const ::tflite::ReducerOptions *opts,
   return {result->getOutput(0)};
 }
 
-void TFLiteOpCreator::checkFullyConnected(const FullyConnectedOptions *opts,
-                                          std::set<std::string> &problems_ops_set)
-{
-  checkActivationType(opts->fused_activation_function(), problems_ops_set);
-}
-
 std::vector<mir::Operation::Output *>
 TFLiteOpCreator::convertFullyConnected(const ::tflite::FullyConnectedOptions *opts,
                                        const std::vector<mir::Operation::Output *> &inputs)
@@ -433,21 +399,9 @@ TFLiteOpCreator::convertFullyConnected(const ::tflite::FullyConnectedOptions *op
   return {addFusedActivation(result, opts->fused_activation_function())};
 }
 
-void TFLiteOpCreator::checkActivationType(ActivationFunctionType activation_type,
-                                          std::set<std::string> &problems_ops_set)
-{
-  if (activation_type != ActivationFunctionType_NONE &&
-      activation_type != ActivationFunctionType_RELU &&
-      activation_type != ActivationFunctionType_RELU6 &&
-      activation_type != ActivationFunctionType_TANH)
-    problems_ops_set.insert(std::string("Unsupported activation type: ") +
-                            EnumNameActivationFunctionType(activation_type));
-}
-
 mir::Operation::Output *TFLiteOpCreator::addFusedActivation(mir::Operation::Output *input,
                                                             ActivationFunctionType activation_type)
 {
-  // TODO Support other activation function types.
   switch (activation_type)
   {
     case ActivationFunctionType_NONE:
@@ -459,7 +413,8 @@ mir::Operation::Output *TFLiteOpCreator::addFusedActivation(mir::Operation::Outp
     case ActivationFunctionType_TANH:
       return createOp<ops::TanhOp>(input)->getOutput(0);
     default:
-      assert(false && "Unsupported activation types must be detected before this pass");
+      throw std::runtime_error(std::string("Unsupported activation type: ") +
+                               tflite::EnumNameActivationFunctionType(activation_type));
   }
 }
 
@@ -574,20 +529,16 @@ TFLiteOpCreator::convertTranspose(const ::tflite::TransposeOptions * /*opts*/,
   return {result->getOutput(0)};
 }
 
-void TFLiteOpCreator::checkStridedSlice(const ::tflite::StridedSliceOptions *opts,
-                                        std::set<std::string> &problems_ops_set)
+std::vector<mir::Operation::Output *>
+TFLiteOpCreator::convertStridedSlice(const ::tflite::StridedSliceOptions *opts,
+                                     const std::vector<mir::Operation::Output *> &inputs)
 {
   if (opts->ellipsis_mask() != 0)
-    problems_ops_set.insert("StridedSlice: parameter 'ellipsis_mask' is not supported.");
+    throw std::runtime_error("StridedSlice: parameter 'ellipsis_mask' is not supported.");
 
   if (opts->new_axis_mask() != 0)
-    problems_ops_set.insert("StridedSlice: parameter 'new_axis_mask' is not supported.");
-}
+    throw std::runtime_error("StridedSlice: parameter 'new_axis_mask' is not supported.");
 
-std::vector<mir::Operation::Output *>
-TFLiteOpCreator::convertStridedSlice(const ::tflite::StridedSliceOptions *opts,
-                                     const std::vector<mir::Operation::Output *> &inputs)
-{
   auto input = inputs.at(0);
   mir::Tensor<int32_t> begin_tensor(extractTensor(inputs.at(1)));
   mir::Tensor<int32_t> end_tensor(extractTensor(inputs.at(2)));
@@ -685,20 +636,16 @@ TFLiteOpCreator::convertLeakyReLU(const ::tflite::LeakyReluOptions *opts,
   return {result->getOutput(0)};
 }
 
-void TFLiteOpCreator::checkShape(const ::tflite::ShapeOptions *opts,
-                                 std::set<std::string> &problems_ops_set)
+std::vector<mir::Operation::Output *>
+TFLiteOpCreator::convertShape(const ::tflite::ShapeOptions *opts,
+                              const std::vector<mir::Operation::Output *> &inputs)
 {
   if (opts->out_type() != TensorType_INT32)
   {
-    problems_ops_set.insert(std::string("SHAPE: Unsupported tensor type: ") +
-                            EnumNameTensorType(opts->out_type()));
+    throw std::runtime_error(std::string("SHAPE: Unsupported tensor type: ") +
+                             EnumNameTensorType(opts->out_type()));
   }
-}
 
-std::vector<mir::Operation::Output *>
-TFLiteOpCreator::convertShape(const ::tflite::ShapeOptions * /*opts*/,
-                              const std::vector<mir::Operation::Output *> &inputs)
-{
   const auto &input_shape = inputs[0]->getShape();
   int32_t rank = input_shape.rank();
   Shape output_shape{rank};
index 7799be7..8998129 100644 (file)
@@ -149,27 +149,6 @@ public:
   convertShape(const ::tflite::ShapeOptions *opts,
                const std::vector<mir::Operation::Output *> &inputs);
 
-  void checkPool2D(const ::tflite::Pool2DOptions *opts, std::set<std::string> &problems_ops_set);
-
-  void checkConcatenation(const ::tflite::ConcatenationOptions *opts,
-                          std::set<std::string> &problems_ops_set);
-
-  void checkConv2D(const ::tflite::Conv2DOptions *opts, std::set<std::string> &problems_ops_set);
-
-  void checkDepthwiseConv2D(const ::tflite::DepthwiseConv2DOptions *opts,
-                            std::set<std::string> &problems_ops_set);
-
-  void checkFullyConnected(const ::tflite::FullyConnectedOptions *opts,
-                           std::set<std::string> &problems_ops_set);
-
-  void checkResizeNearestNeighbor(const ::tflite::ResizeNearestNeighborOptions *opts,
-                                  std::set<std::string> &problems_ops_set);
-
-  void checkStridedSlice(const ::tflite::StridedSliceOptions *opts,
-                         std::set<std::string> &problems_ops_set);
-
-  void checkShape(const ::tflite::ShapeOptions *opts, std::set<std::string> &problems_ops_set);
-
 private:
   Graph *_graph;
 
@@ -177,9 +156,6 @@ private:
       {::tflite::Padding_SAME, ops::PaddingType::Same},
       {::tflite::Padding_VALID, ops::PaddingType::Valid}};
 
-  void checkActivationType(::tflite::ActivationFunctionType activation_type,
-                           std::set<std::string> &problems_ops_set);
-
   mir::Operation::Output *addFusedActivation(mir::Operation::Output *input,
                                              ::tflite::ActivationFunctionType activation_type);