[nnc] Support for TensorFlow Lite SHAPE operator (#3025)
author: Сергей Баранников / AI Tools Lab / SRR / Engineer / 삼성전자 <s.barannikov@samsung.com>
Tue, 12 Feb 2019 17:49:19 +0000 (20:49 +0300)
committer: Efimov Alexander / AI Tools Lab / Samsung Electronics <a.efimov@samsung.com>
Tue, 12 Feb 2019 17:49:19 +0000 (20:49 +0300)
Initial short-term support for the TensorFlow Lite SHAPE operator.

Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
contrib/nnc/passes/tflite_frontend/tflite_importer.cpp
contrib/nnc/passes/tflite_frontend/tflite_op_creator.cpp
contrib/nnc/passes/tflite_frontend/tflite_op_creator.h

index 72ab0cf..926a07c 100644 (file)
@@ -96,6 +96,10 @@ void TfliteImporter::processUnsupportedOp(const Operator* op) {
     case BuiltinOperator_STRIDED_SLICE:
       _opCreator->checkStridedSlice(op->builtin_options_as<StridedSliceOptions>(),
                                     _problemsOpSet);
+      break;
+    case BuiltinOperator_SHAPE:
+      _opCreator->checkShape(op->builtin_options_as<ShapeOptions>(), _problemsOpSet);
+      break;
     case BuiltinOperator_SOFTMAX:
     case BuiltinOperator_SLICE:
     case BuiltinOperator_RESHAPE:
@@ -271,6 +275,9 @@ void TfliteImporter::walkOperator(const Operator* op) {
     case BuiltinOperator_LEAKY_RELU:
       outputs = _opCreator->convertLeakyReLU(op->builtin_options_as<LeakyReluOptions>(), inputs);
       break;
+    case BuiltinOperator_SHAPE:
+      outputs = _opCreator->convertShape(op->builtin_options_as<ShapeOptions>(), inputs);
+      break;
     default:
       assert(false && "All unsupported types should have been found before this pass.");
   }
index b09f8d5..f7eb1e4 100644 (file)
@@ -294,11 +294,30 @@ TFLiteOpCreator::convertResizeNearestNeighbor(const ::tflite::ResizeNearestNeigh
   return {result->getOutput(0)};
 }
 
+mir::Operation::Output* TFLiteOpCreator::tryConvertToFloatTensor(mir::Operation::Output* arg) {
+  auto constant_op = dynamic_cast<mir::ops::ConstantOp*>(arg->getNode());
+  if (constant_op != nullptr && constant_op->getValue().getDataType() == mir::DTYPE::INT32) {
+    const mir::TensorVariant& int_tensor = constant_op->getValue();
+    mir::TensorVariant float_tensor(mir::DTYPE::FLOAT32, int_tensor.getShape());
+    mir::Tensor<int32_t> int_tensor_accessor(int_tensor);
+    mir::Tensor<float> float_tensor_accessor(float_tensor);
+    for (const auto& index : mir::ShapeRange(int_tensor.getShape()))
+      float_tensor_accessor.at(index) = static_cast<float>(int_tensor_accessor.at(index));
+    return createOp<ops::ConstantOp>(float_tensor)->getOutput(0);
+  } else {
+    return arg;
+  }
+}
+
 std::vector<mir::Operation::Output*>
 TFLiteOpCreator::createElementwise(ops::ElementwiseOp::OpType op_type,
                                    ::tflite::ActivationFunctionType activation,
                                    const std::vector<mir::Operation::Output*>& inputs) {
-  auto result = createOp<ops::ElementwiseOp>(inputs, op_type);
+  std::vector<mir::Operation::Output*> float_inputs;
+  for (auto* input : inputs)
+    float_inputs.push_back(tryConvertToFloatTensor(input));
+
+  auto result = createOp<ops::ElementwiseOp>(float_inputs, op_type);
   return {addFusedActivation(result->getOutput(0), activation)};
 }
 
@@ -515,6 +534,7 @@ TFLiteOpCreator::convertStridedSlice(const ::tflite::StridedSliceOptions* opts,
       squeeze_dims.push_back(axis);
   }
 
+  input = tryConvertToFloatTensor(input);
   auto result = createOp<ops::SliceOp>(input, start, size);
   result = createOp<ops::SqueezeOp>(result->getOutput(0), squeeze_dims);
   return {result->getOutput(0)};
@@ -529,4 +549,26 @@ TFLiteOpCreator::convertLeakyReLU(const ::tflite::LeakyReluOptions* opts,
   return {result->getOutput(0)};
 }
 
+void TFLiteOpCreator::checkShape(const ::tflite::ShapeOptions* opts,
+                                 std::set<std::string>& problem_ops_set) {
+  if (opts->out_type() != TensorType_INT32) {
+    problem_ops_set.insert(std::string("SHAPE: Unsupported tensor type: ") +
+                           EnumNameTensorType(opts->out_type()));
+  }
+}
+
+std::vector<mir::Operation::Output*>
+TFLiteOpCreator::convertShape(const ::tflite::ShapeOptions* opts,
+                              const std::vector<mir::Operation::Output*>& inputs) {
+  const auto& input_shape = inputs[0]->getShape();
+  int32_t rank = input_shape.rank();
+  Shape output_shape{rank};
+  std::vector<int32_t> data;
+  for (int32_t i = 0; i < rank; i++)
+    data.emplace_back(input_shape.dim(i));
+  mir::TensorVariant tensor(mir::DTYPE::INT32, output_shape, data.data());
+  auto result = createOp<ops::ConstantOp>(tensor);
+  return {result->getOutput(0)};
+}
+
 } // namespace nnc
index cfb954a..3d406b0 100644 (file)
@@ -136,6 +136,10 @@ public:
   convertLeakyReLU(const ::tflite::LeakyReluOptions* opts,
                    const std::vector<mir::Operation::Output*>& inputs);
 
+  std::vector<mir::Operation::Output*>
+  convertShape(const ::tflite::ShapeOptions* opts,
+               const std::vector<mir::Operation::Output*>& inputs);
+
   void checkPool2D(const ::tflite::Pool2DOptions* opts,
                    std::set<std::string>& problem_ops_set);
 
@@ -156,6 +160,9 @@ public:
 
   void checkStridedSlice(const ::tflite::StridedSliceOptions* opts,
                          std::set<std::string>& problem_ops_set);
+
+  void checkShape(const ::tflite::ShapeOptions* opts,
+                  std::set<std::string>& problem_ops_set);
 private:
   Graph* _graph;
 
@@ -170,6 +177,9 @@ private:
 
   template<typename OpType, typename... Types>
   mir::Operation* createOp(Types&&... args);
+
+  // FIXME This is a temporary hack needed to support SHAPE operator in short term.
+  mir::Operation::Output* tryConvertToFloatTensor(mir::Operation::Output* arg);
 };
 
 template<typename OpType, typename... Types>