[exo-tflite] Adding visit(locoex::TFLRelu) into OperationExporter (#6988)
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Wed, 28 Aug 2019 07:12:41 +0000 (16:12 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Wed, 28 Aug 2019 07:12:41 +0000 (16:12 +0900)
* [exo-tflite] Making OperationExporter handle locoex::TFLRelu

This enables OperationExporter to export locoex::TFLRelu into tflite file.

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
* fix build error

compiler/exo-tflite/src/Dialect/IR/TFLNodeVisitor.h
compiler/exo-tflite/src/OperationExporter.cpp

index 93cb85f..038e600 100644 (file)
@@ -50,7 +50,7 @@ template <typename T> struct TFLNodeVisitor : public TFLNodeVisitorBase<T>
 #undef TFL_NODE
 
   /// @brief Default fallback
-  virtual T visit(const TFLNode *node) { assert(false); }
+  virtual T visit(const TFLNode *) { assert(false); }
 };
 
 /**
@@ -78,7 +78,7 @@ template <typename T> struct TFLNodeMutableVisitor : public TFLNodeMutableVisito
 #undef TFL_NODE
 
   /// @brief Default fallback
-  virtual T visit(TFLNode *node) { assert(false); }
+  virtual T visit(TFLNode *) { assert(false); }
 };
 
 } // namespace locoex
index c1092fe..d5b0c8a 100644 (file)
@@ -18,6 +18,9 @@
 #include "ExporterUtils.h"
 #include "ShapeInference.h"
 
+#include "Dialect/IR/TFLNodes.h"
+#include "Dialect/IR/TFLNodeVisitor.h"
+
 #include <loco/IR/CanonicalNode.h>
 #include <loco/IR/CanonicalNodeVisitor.h>
 #include <locoex/COpCall.h>
@@ -30,7 +33,8 @@ using namespace tflite;
 namespace
 {
 
-class OperationExporter final : public loco::CanonicalNodeMutableVisitor<void>
+class OperationExporter final : public locoex::TFLNodeMutableVisitor<void>,
+                                public loco::CanonicalNodeMutableVisitor<void>
 {
 public:
   OperationExporter(FlatBufferBuilder &fbb, SerializedModelData &ctx) : builder{fbb}, gd{ctx}
@@ -39,6 +43,10 @@ public:
   }
 
 public:
+  // FOR TFLNodes
+  void visit(locoex::TFLRelu *) final;
+
+  // FOR canonical nodes. These will be removed later
   void visit(loco::ReLU *) final;
   void visit(loco::ReLU6 *) final;
   void visit(loco::Tanh *) final;
@@ -72,6 +80,17 @@ private:
   SerializedModelData &gd;
 };
 
+void OperationExporter::visit(locoex::TFLRelu *node)
+{
+  uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_RELU);
+  std::vector<int32_t> inputs_vec{get_tensor_index(node->input())};
+  std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+  auto inputs = builder.CreateVector(inputs_vec);
+  auto outputs = builder.CreateVector(outputs_vec);
+  auto op_offset = CreateOperator(builder, op_idx, inputs, outputs);
+  gd._operators.push_back(op_offset);
+}
+
 void OperationExporter::visit(loco::ReLU *node)
 {
   uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_RELU);