From f1cb3741daeca5694eacc46ace3ead0262f4d4f7 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9C=A4=ED=98=84=EC=8B=9D/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 10 Sep 2019 16:00:41 +0900 Subject: [PATCH] [exo-tflite] Adding TFLAveragePool2D into OperationExporter (#7315) TFLAveragePool2D was added into OperationExporter. Signed-off-by: Hyun Sik Yoon --- compiler/exo-tflite/src/OperationExporter.cpp | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/compiler/exo-tflite/src/OperationExporter.cpp b/compiler/exo-tflite/src/OperationExporter.cpp index e3f1cc4..fe6ba96 100644 --- a/compiler/exo-tflite/src/OperationExporter.cpp +++ b/compiler/exo-tflite/src/OperationExporter.cpp @@ -22,6 +22,8 @@ #include "Dialect/IR/TFLNodes.h" #include "Dialect/IR/TFLNodeVisitor.h" +#include "Check.h" + #include #include #include @@ -46,7 +48,7 @@ public: public: // FOR TFLNodes void visit(locoex::TFLAdd *) final; - // TODO TFLAveragePool2D + void visit(locoex::TFLAveragePool2D *) final; // TODO TFLConcatenation // TODO TFLConv2D // TODO TFLDepthwiseConv2D @@ -109,7 +111,27 @@ void OperationExporter::visit(locoex::TFLAdd *node) gd._operators.push_back(op_offset); } -// TODO TFLAveragePool2D +void OperationExporter::visit(locoex::TFLAveragePool2D *node) +{ + EXO_ASSERT(node->padding() != locoex::Padding::UNDEFINED, "Padding is not set"); + EXO_ASSERT(node->fusedActivationFunction() != locoex::FusedActFunc::UNDEFINED, + "fused activation function is not set"); + + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_AVERAGE_POOL_2D); + std::vector inputs_vec{get_tensor_index(node->value())}; + std::vector outputs_vec{get_tensor_index(static_cast(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + + tflite::Padding padding = + node->padding() == locoex::Padding::VALID ? tflite::Padding_VALID : tflite::Padding_SAME; + + auto options = CreatePool2DOptions(builder, padding, node->stride()->w(), node->stride()->h(), + node->filter()->w(), node->filter()->h()); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_Pool2DOptions, options.Union()); + gd._operators.push_back(op_offset); +} // TODO TFLConcatenation -- 2.7.4