From 365881f11fef1ab7ada05ec790504c9e63890c80 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=B1=84=EC=84=B1=EC=9A=B0/On-Device=20Lab=28SR=29/Enginee?= =?utf8?q?r/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 13 Aug 2019 16:33:10 +0900 Subject: [PATCH] [exo-tflite] Introduce DepthwiseFilterEncode with exporting as reshape (#6524) * [exo-tflite] Introduce DepthwiseFilterEncode with exporting as reshape This commit introcude DepthwiseFilterEncode with exporting as reshape to exo-tflite. Signed-off-by: seongwoo * remove unused function. * add some comment. --- compiler/exo-tflite/src/OperationExporter.cpp | 34 +++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/compiler/exo-tflite/src/OperationExporter.cpp b/compiler/exo-tflite/src/OperationExporter.cpp index b39a2d6..7f9c07f 100644 --- a/compiler/exo-tflite/src/OperationExporter.cpp +++ b/compiler/exo-tflite/src/OperationExporter.cpp @@ -47,6 +47,7 @@ public: void visit(loco::FeatureEncode *) final; void visit(loco::FeatureDecode *) final; void visit(loco::FilterEncode *) final; + void visit(loco::DepthwiseFilterEncode *) final; void visit(loco::ConstGen *) final { /* skip, everything is done in exportOpDefinedTensors */} void visit(loco::MaxPool2D *) final; void visit(loco::AvgPool2D *) final; @@ -298,6 +299,39 @@ void OperationExporter::visit(loco::FilterEncode *node) } } +void exportAsReshape(loco::Node *node, FlatBufferBuilder &builder, + std::vector &new_shape_vec, SerializedModelData &gd) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_RESHAPE); + + std::vector inputs_vec{get_tensor_index(node->arg(0))}; + std::vector outputs_vec{get_tensor_index(static_cast(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + + auto new_shape_vec_offset = builder.CreateVector(new_shape_vec); + auto options = CreateReshapeOptions(builder, new_shape_vec_offset); + + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_ReshapeOptions, options.Union()); + + gd._operators.push_back(op_offset); +} + +void OperationExporter::visit(loco::DepthwiseFilterEncode *node) +{ + auto ker = node->input(); // [H, W, C, M] + + // tflite represents filter as [1, H, W, C*M] where M is multiplier. + std::vector new_shape_vec(4); + new_shape_vec[0] = 1; + new_shape_vec[1] = ShapeInference::get(ker)._dims[0]; + new_shape_vec[2] = ShapeInference::get(ker)._dims[1]; + new_shape_vec[3] = ShapeInference::get(ker)._dims[2] * ShapeInference::get(ker)._dims[3]; + + exportAsReshape(node, builder, new_shape_vec, gd); +} + void OperationExporter::visit(loco::BiasAdd *node) { uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_ADD); -- 2.7.4