[exo/tflite] Use Visitor in Operation Exporter (#4300)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Wed, 17 Jul 2019 03:49:55 +0000 (12:49 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Wed, 17 Jul 2019 03:49:55 +0000 (12:49 +0900)
This commit revises the implementation of Operation Exporter using loco
visitor interface.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
contrib/exo-tflite/src/OperationExporter.cpp

index 5707e55..d88f1f5 100644 (file)
 #include "TypeInference.h"
 #include "ShapeInference.h"
 
+#include <loco/IR/CanonicalNode.h>
+#include <loco/IR/CanonicalNodeVisitor.h>
+
 using namespace flatbuffers;
 using namespace tflite;
 
 namespace
 {
 
-void exportRelu(loco::ReLU *node, FlatBufferBuilder &builder, SerializedModelData &gd)
+class OperationExporter final : public loco::CanonicalNodeMutableVisitor<void>
+{
+public:
+  OperationExporter(FlatBufferBuilder &fbb, SerializedModelData &ctx) : builder{fbb}, gd{ctx}
+  {
+    // DO NOTHING
+  }
+
+public:
+  void visit(loco::ReLU *) final;
+  void visit(loco::Push *) final { /* DO NOTHING */}
+  void visit(loco::Pull *) final { /* DO NOTHING */}
+  void visit(loco::FeatureEncode *) final;
+  void visit(loco::FeatureDecode *) final;
+  void visit(loco::FilterEncode *) final;
+  void visit(loco::ConstGen *) final { /* skip, everything is done in exportOpDefinedTensors */}
+  void visit(loco::MaxPool2D *) final;
+  void visit(loco::AvgPool2D *) final;
+  void visit(loco::Conv2D *) final;
+  void visit(loco::TensorConcat *) final;
+  void visit(loco::BiasEncode *) final;
+  void visit(loco::TensorBiasAdd *) final;
+  void visit(loco::FeatureBiasAdd *) final;
+
+private:
+  FlatBufferBuilder &builder;
+  SerializedModelData &gd;
+};
+
+void OperationExporter::visit(loco::ReLU *node)
 {
   uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_RELU);
   std::vector<int32_t> inputs_vec{get_tensor_index(node->input())};
@@ -36,7 +68,7 @@ void exportRelu(loco::ReLU *node, FlatBufferBuilder &builder, SerializedModelDat
   gd._operators.push_back(op_offset);
 }
 
-void exportMaxPool2D(loco::MaxPool2D *node, FlatBufferBuilder &builder, SerializedModelData &gd)
+void OperationExporter::visit(loco::MaxPool2D *node)
 {
   uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_MAX_POOL_2D);
   std::vector<int32_t> inputs_vec{get_tensor_index(node->ifm())};
@@ -52,7 +84,7 @@ void exportMaxPool2D(loco::MaxPool2D *node, FlatBufferBuilder &builder, Serializ
   gd._operators.push_back(op_offset);
 }
 
-void exportAvgPool2D(loco::AvgPool2D *node, FlatBufferBuilder &builder, SerializedModelData &gd)
+void OperationExporter::visit(loco::AvgPool2D *node)
 {
   // TFlite only support Valid convention of average pooling
   assert(node->convention() == loco::AvgPool2D::Convention::Valid);
@@ -71,7 +103,7 @@ void exportAvgPool2D(loco::AvgPool2D *node, FlatBufferBuilder &builder, Serializ
   gd._operators.push_back(op_offset);
 }
 
-void exportConv2D(loco::Conv2D *node, FlatBufferBuilder &builder, SerializedModelData &gd)
+void OperationExporter::visit(loco::Conv2D *node)
 {
   uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_CONV_2D);
 
@@ -178,8 +210,7 @@ void exportAsTranspose(loco::Node *node, FlatBufferBuilder &builder,
   gd._operators.push_back(transpose_offset);
 }
 
-void exportFeatureEncode(loco::FeatureEncode *node, FlatBufferBuilder &builder,
-                         SerializedModelData &gd)
+void OperationExporter::visit(loco::FeatureEncode *node)
 {
   auto encoder = dynamic_cast<loco::PermutingEncoder<loco::Domain::Feature> *>(node->encoder());
   auto perm = encoder->perm();
@@ -201,8 +232,7 @@ void exportFeatureEncode(loco::FeatureEncode *node, FlatBufferBuilder &builder,
   }
 }
 
-void exportFeatureDecode(loco::FeatureDecode *node, FlatBufferBuilder &builder,
-                         SerializedModelData &gd)
+void OperationExporter::visit(loco::FeatureDecode *node)
 {
   auto decoder = dynamic_cast<loco::PermutingDecoder<loco::Domain::Feature> *>(node->decoder());
   auto perm = decoder->perm();
@@ -224,8 +254,7 @@ void exportFeatureDecode(loco::FeatureDecode *node, FlatBufferBuilder &builder,
   }
 }
 
-void exportFilterEncode(loco::FilterEncode *node, FlatBufferBuilder &builder,
-                        SerializedModelData &gd)
+void OperationExporter::visit(loco::FilterEncode *node)
 {
   auto encoder = dynamic_cast<loco::PermutingEncoder<loco::Domain::Filter> *>(node->encoder());
   auto perm = encoder->perm();
@@ -248,8 +277,7 @@ void exportFilterEncode(loco::FilterEncode *node, FlatBufferBuilder &builder,
   }
 }
 
-void exportBiasAdd(loco::BiasAdd<loco::Domain::Tensor> *node, FlatBufferBuilder &builder,
-                   SerializedModelData &gd)
+void OperationExporter::visit(loco::BiasAdd<loco::Domain::Tensor> *node)
 {
   uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_ADD);
   std::vector<int32_t> inputs_vec{get_tensor_index(node->value()), get_tensor_index(node->bias())};
@@ -262,7 +290,7 @@ void exportBiasAdd(loco::BiasAdd<loco::Domain::Tensor> *node, FlatBufferBuilder
   gd._operators.push_back(op_offset);
 }
 
-void exportBiasAdd(loco::FeatureBiasAdd *node, FlatBufferBuilder &builder, SerializedModelData &gd)
+void OperationExporter::visit(loco::FeatureBiasAdd *node)
 {
   uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_ADD);
   std::vector<int32_t> inputs_vec{get_tensor_index(node->value()), get_tensor_index(node->bias())};
@@ -276,7 +304,7 @@ void exportBiasAdd(loco::FeatureBiasAdd *node, FlatBufferBuilder &builder, Seria
 }
 
 /// @brief Export CONCATENATION of **TWO** tensors only
-void exportConcat(loco::TensorConcat *node, FlatBufferBuilder &builder, SerializedModelData &gd)
+void OperationExporter::visit(loco::TensorConcat *node)
 {
   uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_CONCATENATION);
   std::vector<int32_t> inputs_vec{get_tensor_index(node->lhs()), get_tensor_index(node->rhs())};
@@ -290,6 +318,8 @@ void exportConcat(loco::TensorConcat *node, FlatBufferBuilder &builder, Serializ
   gd._operators.push_back(op_offset);
 }
 
+void OperationExporter::visit(loco::BiasEncode *encode) { exportIdentity(encode, builder, gd); }
+
 void exportNode(loco::Node *node, flatbuffers::FlatBufferBuilder &builder,
                 SerializedModelData &data)
 {
@@ -309,61 +339,10 @@ void exportNode(loco::Node *node, flatbuffers::FlatBufferBuilder &builder,
     return;
   }
 
-  if (auto *relu = dynamic_cast<loco::ReLU *>(node))
-  {
-    exportRelu(relu, builder, data);
-  }
-  else if (dynamic_cast<loco::Pull *>(node))
-  {
-    // DO NOTHING
-  }
-  else if (dynamic_cast<loco::Push *>(node))
-  {
-    // DO NOTHING
-  }
-  else if (auto *encode = dynamic_cast<loco::FeatureEncode *>(node))
-  {
-    exportFeatureEncode(encode, builder, data);
-  }
-  else if (auto *decode = dynamic_cast<loco::FeatureDecode *>(node))
-  {
-    exportFeatureDecode(decode, builder, data);
-  }
-  else if (auto *encode = dynamic_cast<loco::FilterEncode *>(node))
-  {
-    exportFilterEncode(encode, builder, data);
-  }
-  else if (dynamic_cast<loco::ConstGen *>(node))
-  {
-    // skip, everything is done in exportOpDefinedTensors
-  }
-  else if (auto *max_pool = dynamic_cast<loco::MaxPool2D *>(node))
-  {
-    exportMaxPool2D(max_pool, builder, data);
-  }
-  else if (auto *avg_pool = dynamic_cast<loco::AvgPool2D *>(node))
-  {
-    exportAvgPool2D(avg_pool, builder, data);
-  }
-  else if (auto *conv2d = dynamic_cast<loco::Conv2D *>(node))
-  {
-    exportConv2D(conv2d, builder, data);
-  }
-  else if (auto *tconcat = dynamic_cast<loco::TensorConcat *>(node))
-  {
-    exportConcat(tconcat, builder, data);
-  }
-  else if (auto *encode = dynamic_cast<loco::BiasEncode *>(node))
-  {
-    exportIdentity(encode, builder, data);
-  }
-  else if (auto *biasadd = dynamic_cast<loco::BiasAdd<loco::Domain::Tensor> *>(node))
-  {
-    exportBiasAdd(biasadd, builder, data);
-  }
-  else if (auto *biasadd = dynamic_cast<loco::FeatureBiasAdd *>(node))
+  if (auto canonical_node = dynamic_cast<loco::CanonicalNode *>(node))
   {
-    exportBiasAdd(biasadd, builder, data);
+    OperationExporter exporter{builder, data};
+    canonical_node->accept(&exporter);
   }
   else
   {