[exo-tflite] Introduce DepthwiseFilterEncode with exporting as reshape (#6524)
author채성우/On-Device Lab(SR)/Engineer/삼성전자 <sw4670.chae@samsung.com>
Tue, 13 Aug 2019 07:33:10 +0000 (16:33 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Tue, 13 Aug 2019 07:33:10 +0000 (16:33 +0900)
* [exo-tflite] Introduce DepthwiseFilterEncode with exporting as reshape

This commit introcude DepthwiseFilterEncode with exporting as reshape to
exo-tflite.

Signed-off-by: seongwoo <sw4670.chae@samsung.com>
* remove unused function.

* add some comment.

compiler/exo-tflite/src/OperationExporter.cpp

index b39a2d6..7f9c07f 100644 (file)
@@ -47,6 +47,7 @@ public:
   void visit(loco::FeatureEncode *) final;
   void visit(loco::FeatureDecode *) final;
   void visit(loco::FilterEncode *) final;
+  void visit(loco::DepthwiseFilterEncode *) final;
   void visit(loco::ConstGen *) final { /* skip, everything is done in exportOpDefinedTensors */}
   void visit(loco::MaxPool2D *) final;
   void visit(loco::AvgPool2D *) final;
@@ -298,6 +299,39 @@ void OperationExporter::visit(loco::FilterEncode *node)
   }
 }
 
+void exportAsReshape(loco::Node *node, FlatBufferBuilder &builder,
+                     std::vector<int32_t> &new_shape_vec, SerializedModelData &gd)
+{
+  uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_RESHAPE);
+
+  std::vector<int32_t> inputs_vec{get_tensor_index(node->arg(0))};
+  std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+  auto inputs = builder.CreateVector(inputs_vec);
+  auto outputs = builder.CreateVector(outputs_vec);
+
+  auto new_shape_vec_offset = builder.CreateVector(new_shape_vec);
+  auto options = CreateReshapeOptions(builder, new_shape_vec_offset);
+
+  auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+                                  tflite::BuiltinOptions_ReshapeOptions, options.Union());
+
+  gd._operators.push_back(op_offset);
+}
+
+void OperationExporter::visit(loco::DepthwiseFilterEncode *node)
+{
+  auto ker = node->input(); // [H, W, C, M]
+
+  // tflite represents filter as [1, H, W, C*M] where M is multiplier.
+  std::vector<int32_t> new_shape_vec(4);
+  new_shape_vec[0] = 1;
+  new_shape_vec[1] = ShapeInference::get(ker)._dims[0];
+  new_shape_vec[2] = ShapeInference::get(ker)._dims[1];
+  new_shape_vec[3] = ShapeInference::get(ker)._dims[2] * ShapeInference::get(ker)._dims[3];
+
+  exportAsReshape(node, builder, new_shape_vec, gd);
+}
+
 void OperationExporter::visit(loco::BiasAdd<loco::Domain::Tensor> *node)
 {
   uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_ADD);