From 3d7eb47cb8fcce7f11eb6a8bc3106f264bcf64b0 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=82=A8=EA=B6=81=EC=84=9D/On-Device=20Lab=28SR=29/Enginee?= =?utf8?q?r/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 17 Oct 2019 15:49:26 +0900 Subject: [PATCH] [exo] Introduce TFLConcatenation exporter (#8259) This commit will introduce `TFLConcatenation` exporter in `exo` Signed-off-by: Seok NamKoong --- compiler/exo/src/TFLite/TFLOperationExporter.cpp | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/compiler/exo/src/TFLite/TFLOperationExporter.cpp b/compiler/exo/src/TFLite/TFLOperationExporter.cpp index 23b2695..3024a79 100644 --- a/compiler/exo/src/TFLite/TFLOperationExporter.cpp +++ b/compiler/exo/src/TFLite/TFLOperationExporter.cpp @@ -53,7 +53,7 @@ public: // FOR TFLNodes void visit(locoex::TFLAdd *) final; void visit(locoex::TFLAveragePool2D *) final; - // TODO TFLConcatenation + void visit(locoex::TFLConcatenation *) final; void visit(locoex::TFLConst *) final{/* skip, everything is done in exportOpDefinedTensors */}; void visit(locoex::TFLConv2D *) final; // TODO TFLDepthwiseConv2D @@ -134,7 +134,23 @@ void OperationExporter::visit(locoex::TFLAveragePool2D *node) export_pool_2d(node, tflite::BuiltinOperator_AVERAGE_POOL_2D); } -// TODO TFLConcatenation +void OperationExporter::visit(locoex::TFLConcatenation *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_CONCATENATION); + std::vector inputs_vec; + std::vector outputs_vec{get_tensor_index(static_cast(node))}; + + for (uint32_t i = 0; i < node->numValues(); ++i) + inputs_vec.push_back(get_tensor_index(node->values(i))); + + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto options = CreateConcatenationOptions(builder, node->axis(), + to_tflite_actfunc(node->fusedActivationFunction())); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_ConcatenationOptions, options.Union()); + gd._operators.push_back(op_offset); +} void OperationExporter::visit(locoex::TFLConv2D *node) { -- 2.7.4