[exo-tflite] Adding TFLAveragePool2D into OperationExporter (#7315)
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Tue, 10 Sep 2019 07:00:41 +0000 (16:00 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Tue, 10 Sep 2019 07:00:41 +0000 (16:00 +0900)
TFLAveragePool2D was added into OperationExporter.

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
compiler/exo-tflite/src/OperationExporter.cpp

index e3f1cc4..fe6ba96 100644 (file)
@@ -22,6 +22,8 @@
 #include "Dialect/IR/TFLNodes.h"
 #include "Dialect/IR/TFLNodeVisitor.h"
 
+#include "Check.h"
+
 #include <loco/IR/CanonicalNode.h>
 #include <loco/IR/CanonicalNodeVisitor.h>
 #include <locoex/COpCall.h>
@@ -46,7 +48,7 @@ public:
 public:
   // FOR TFLNodes
   void visit(locoex::TFLAdd *) final;
-  // TODO TFLAveragePool2D
+  void visit(locoex::TFLAveragePool2D *) final;
   // TODO TFLConcatenation
   // TODO TFLConv2D
   // TODO TFLDepthwiseConv2D
@@ -109,7 +111,27 @@ void OperationExporter::visit(locoex::TFLAdd *node)
   gd._operators.push_back(op_offset);
 }
 
-// TODO TFLAveragePool2D
+void OperationExporter::visit(locoex::TFLAveragePool2D *node)
+{
+  EXO_ASSERT(node->padding() != locoex::Padding::UNDEFINED, "Padding is not set");
+  EXO_ASSERT(node->fusedActivationFunction() != locoex::FusedActFunc::UNDEFINED,
+             "fused activation function is not set");
+
+  uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_AVERAGE_POOL_2D);
+  std::vector<int32_t> inputs_vec{get_tensor_index(node->value())};
+  std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+  auto inputs = builder.CreateVector(inputs_vec);
+  auto outputs = builder.CreateVector(outputs_vec);
+
+  tflite::Padding padding =
+      node->padding() == locoex::Padding::VALID ? tflite::Padding_VALID : tflite::Padding_SAME;
+
+  auto options = CreatePool2DOptions(builder, padding, node->stride()->w(), node->stride()->h(),
+                                     node->filter()->w(), node->filter()->h());
+  auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+                                  tflite::BuiltinOptions_Pool2DOptions, options.Union());
+  gd._operators.push_back(op_offset);
+}
 
 // TODO TFLConcatenation