[exo.tflite] Extract Type Inference as a separate phase (#4126)
author Jonghyun Park/On-Device Lab(SR)/Staff Engineer/Samsung Electronics <jh1302.park@samsung.com>
Mon, 8 Jul 2019 01:50:37 +0000 (10:50 +0900)
committer GitHub Enterprise <noreply-CODE@samsung.com>
Mon, 8 Jul 2019 01:50:37 +0000 (10:50 +0900)
This commit extracts (Data) Type Inference from the exportOpDefinedTensor
phase.

This change allows users to infer types without exporting tensors.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
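
With this change, type inference runs as a standalone pass, so callers can query node types before (or without) exporting tensors. A minimal usage sketch, mirroring the HOW TO USE note added to TypeInference.h below ('g' is assumed to be a loco::Graph* already built by the caller):

  // Sketch only: run the pass once over the whole graph ...
  TypeInference::run(g); // annotates every canonical node with its tflite type
  // ... then query any node's inferred type
  tflite::TensorType t = TypeInference::get(g->nodes()->at(0));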
contrib/exo-tflite/src/TFLExporterImpl.cpp
contrib/exo-tflite/src/TensorExporter.cpp
contrib/exo-tflite/src/TypeInference.cpp
contrib/exo-tflite/src/TypeInference.h

index 5910eb6..58f2ff8 100644 (file)
@@ -16,6 +16,7 @@
 
 #include "TFLExporterImpl.h"
 
+#include "TypeInference.h"
 #include "TensorExporter.h"
 #include "OperationExporter.h"
 #include "ExporterUtils.h"
@@ -54,6 +55,10 @@ flatbuffers::Offset<tflite::SubGraph> TFLExporter::Impl::exportSubgraph(Serializ
 
 void TFLExporter::Impl::exportGraph(loco::Graph *graph)
 {
+  // Infer the type of each node
+  TypeInference::run(graph);
+  // TypeInference::get(node) now works
+
   _builder.Clear();
 
   SerializedModelData gd;
index 8314bf9..eb5334d 100644 (file)
@@ -65,8 +65,8 @@ void exportOpDefinedTensor(NodeT *node, FlatBufferBuilder &builder, SerializedMo
   auto shape_offset = encodeShape(builder, shape_description);
 
   // encode and register output tensor type
-  auto tensor_type = getOpResultType(node, gd);
-  gd._node_to_type[node] = tensor_type;
+  auto tensor_type = TypeInference::get(node);
+  // gd._node_to_type[node] = tensor_type;
 
   // encode and register output tensor buffer
   auto buffer = encodeOpBuffer(builder, node);
index 3c31480..4d62aa7 100644 (file)
 
 #include "schema_generated.h"
 
+#include <loco/IR/CanonicalNode.h>
+#include <loco/IR/CanonicalNodeVisitor.h>
+
+#include <stdex/Memory.h>
+
 #include <type_traits>
 
 namespace
@@ -69,6 +74,11 @@ tflite::TensorType getOpResultType(loco::Pull *node, TypeContext &)
   return translateLocoTypeToTFLite(node->dtype());
 }
 
+tflite::TensorType getOpResultType(loco::Push *node, TypeContext &)
+{
+  return TypeInference::get(node->from());
+}
+
 tflite::TensorType getOpResultType(loco::ReLU *node, TypeContext &gd)
 {
   return gd._node_to_type[node->input()];
@@ -131,6 +141,77 @@ tflite::TensorType getOpResultType(loco::BiasAdd<loco::Domain::Tensor> *node, Ty
   return value_type;
 }
 
+namespace
+{
+
+struct TypeAnnotation : public loco::NodeAnnotation
+{
+public:
+  TypeAnnotation(const tflite::TensorType &type) : _type{type}
+  {
+    // DO NOTHING
+  }
+
+public:
+  const tflite::TensorType &type(void) const { return _type; }
+
+private:
+  tflite::TensorType _type;
+};
+
+class TypeAnnotator final : public loco::CanonicalNodeMutableVisitor<void>
+{
+public:
+  TypeAnnotator() = default;
+
+public:
+#define NODE(NAME)                                      \
+  void visit(loco::NAME *node) final                    \
+  {                                                     \
+    auto t = getOpResultType(node, _ctx);               \
+    node->annot(stdex::make_unique<TypeAnnotation>(t)); \
+  }
+  NODE(ConstGen)
+  NODE(Pull)
+  NODE(Push)
+  NODE(FeatureEncode)
+  NODE(FeatureDecode)
+  NODE(FilterEncode)
+  NODE(MaxPool2D)
+  NODE(AvgPool2D)
+  NODE(Conv2D)
+  NODE(ReLU)
+  NODE(TensorConcat)
+  NODE(BiasEncode)
+  NODE(TensorBiasAdd)
+#undef NODE
+
+private:
+  // TODO Remove this variable
+  TypeContext _ctx;
+};
+
+} // namespace
+
+void TypeInference::run(loco::Graph *g)
+{
+  TypeAnnotator type_annotator;
+
+  for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
+  {
+    if (auto canonical_node = dynamic_cast<loco::CanonicalNode *>(node))
+    {
+      canonical_node->accept(&type_annotator);
+    }
+  }
+}
+
+tflite::TensorType TypeInference::get(loco::Node *node)
+{
+  assert(node->annot<TypeAnnotation>() != nullptr);
+  return node->annot<TypeAnnotation>()->type();
+}
+
 int32_t decodeShapeDimension(const loco::Dimension &dim)
 {
   if (!dim.known())
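
For reference, each NODE(NAME) entry in the TypeAnnotator above expands to one small visitor method; the sketch below is just the macro written out by hand for NODE(ReLU), not extra code in this commit:

  // Expansion of NODE(ReLU) inside TypeAnnotator
  void visit(loco::ReLU *node) final
  {
    // Compute the tflite tensor type for this node (_ctx is the TypeContext slated for removal)
    auto t = getOpResultType(node, _ctx);
    // Attach the result as a TypeAnnotation, later retrieved via node->annot<TypeAnnotation>()
    node->annot(stdex::make_unique<TypeAnnotation>(t));
  }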
index 7a5c187..727a8b8 100644 (file)
@@ -22,7 +22,7 @@
 #include <loco/IR/Nodes.h>
 
 // Tensor type inference functions
-
+#if 0
 tflite::TensorType getOpResultType(loco::ConstGen *node, TypeContext &);
 
 tflite::TensorType getOpResultType(loco::Pull *node, TypeContext &);
@@ -46,6 +46,24 @@ tflite::TensorType getOpResultType(loco::TensorConcat *node, TypeContext &gd);
 tflite::TensorType getOpResultType(loco::BiasEncode *node, TypeContext &gd);
 
 tflite::TensorType getOpResultType(loco::BiasAdd<loco::Domain::Tensor> *node, TypeContext &gd);
+#endif
+
+/**
+ * @brief Annotate the type of each node as NodeAnnotation
+ *
+ * HOW TO USE
+ *
+ *   TypeInference::run(g);
+ *
+ *   TypeInference::get(g->nodes()->at(0));
+ *   TypeInference::get(g->nodes()->at(...));
+ */
+struct TypeInference
+{
+  static void run(loco::Graph *g);
+
+  static tflite::TensorType get(loco::Node *node);
+};
 
 // Shape inference functions