From 46c007826e4eae4b1ab6db5414e41e9c5ecb349a Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/On-Device=20Lab=28SR=29/Staff?= =?utf8?q?=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Mon, 8 Jul 2019 10:50:37 +0900 Subject: [PATCH] [exo.tflite] Extract Type Inference as a separate phase (#4126) This commit extracts (Data) Type Inference from exportOpDefinedTensor phase. This chanage allows users to infer types without exporting tensors. Signed-off-by: Jonghyun Park --- contrib/exo-tflite/src/TFLExporterImpl.cpp | 5 ++ contrib/exo-tflite/src/TensorExporter.cpp | 4 +- contrib/exo-tflite/src/TypeInference.cpp | 81 ++++++++++++++++++++++++++++++ contrib/exo-tflite/src/TypeInference.h | 20 +++++++- 4 files changed, 107 insertions(+), 3 deletions(-) diff --git a/contrib/exo-tflite/src/TFLExporterImpl.cpp b/contrib/exo-tflite/src/TFLExporterImpl.cpp index 5910eb6..58f2ff8 100644 --- a/contrib/exo-tflite/src/TFLExporterImpl.cpp +++ b/contrib/exo-tflite/src/TFLExporterImpl.cpp @@ -16,6 +16,7 @@ #include "TFLExporterImpl.h" +#include "TypeInference.h" #include "TensorExporter.h" #include "OperationExporter.h" #include "ExporterUtils.h" @@ -54,6 +55,10 @@ flatbuffers::Offset TFLExporter::Impl::exportSubgraph(Serializ void TFLExporter::Impl::exportGraph(loco::Graph *graph) { + // Infer the type of each node + TypeInference::run(graph); + // TypeInference::get(node) now works + _builder.Clear(); SerializedModelData gd; diff --git a/contrib/exo-tflite/src/TensorExporter.cpp b/contrib/exo-tflite/src/TensorExporter.cpp index 8314bf9..eb5334d 100644 --- a/contrib/exo-tflite/src/TensorExporter.cpp +++ b/contrib/exo-tflite/src/TensorExporter.cpp @@ -65,8 +65,8 @@ void exportOpDefinedTensor(NodeT *node, FlatBufferBuilder &builder, SerializedMo auto shape_offset = encodeShape(builder, shape_description); // encode and register output tensor type - auto tensor_type = getOpResultType(node, gd); - gd._node_to_type[node] = tensor_type; + auto tensor_type = TypeInference::get(node); + // gd._node_to_type[node] = tensor_type; // encode and register output tensor buffer auto buffer = encodeOpBuffer(builder, node); diff --git a/contrib/exo-tflite/src/TypeInference.cpp b/contrib/exo-tflite/src/TypeInference.cpp index 3c31480..4d62aa7 100644 --- a/contrib/exo-tflite/src/TypeInference.cpp +++ b/contrib/exo-tflite/src/TypeInference.cpp @@ -18,6 +18,11 @@ #include "schema_generated.h" +#include +#include + +#include + #include namespace @@ -69,6 +74,11 @@ tflite::TensorType getOpResultType(loco::Pull *node, TypeContext &) return translateLocoTypeToTFLite(node->dtype()); } +tflite::TensorType getOpResultType(loco::Push *node, TypeContext &) +{ + return TypeInference::get(node->from()); +} + tflite::TensorType getOpResultType(loco::ReLU *node, TypeContext &gd) { return gd._node_to_type[node->input()]; @@ -131,6 +141,77 @@ tflite::TensorType getOpResultType(loco::BiasAdd *node, Ty return value_type; } +namespace +{ + +struct TypeAnnotation : public loco::NodeAnnotation +{ +public: + TypeAnnotation(const tflite::TensorType &type) : _type{type} + { + // DO NOTHING + } + +public: + const tflite::TensorType &type(void) const { return _type; } + +private: + tflite::TensorType _type; +}; + +class TypeAnnotator final : public loco::CanonicalNodeMutableVisitor +{ +public: + TypeAnnotator() = default; + +public: +#define NODE(NAME) \ + void visit(loco::NAME *node) final \ + { \ + auto t = getOpResultType(node, _ctx); \ + node->annot(stdex::make_unique(t)); \ + } + NODE(ConstGen) + NODE(Pull) + NODE(Push) + NODE(FeatureEncode) + NODE(FeatureDecode) + NODE(FilterEncode) + NODE(MaxPool2D) + NODE(AvgPool2D) + NODE(Conv2D) + NODE(ReLU) + NODE(TensorConcat) + NODE(BiasEncode) + NODE(TensorBiasAdd) +#undef NODE + +private: + // TODO Remove this variable + TypeContext _ctx; +}; + +} // namespace + +void TypeInference::run(loco::Graph *g) +{ + TypeAnnotator type_annotator; + + for (auto node : loco::postorder_traversal(loco::output_nodes(g))) + { + if (auto canonical_node = dynamic_cast(node)) + { + canonical_node->accept(&type_annotator); + } + } +} + +tflite::TensorType TypeInference::get(loco::Node *node) +{ + assert(node->annot() != nullptr); + return node->annot()->type(); +} + int32_t decodeShapeDimension(const loco::Dimension &dim) { if (!dim.known()) diff --git a/contrib/exo-tflite/src/TypeInference.h b/contrib/exo-tflite/src/TypeInference.h index 7a5c187..727a8b8 100644 --- a/contrib/exo-tflite/src/TypeInference.h +++ b/contrib/exo-tflite/src/TypeInference.h @@ -22,7 +22,7 @@ #include // Tensor type inference functions - +#if 0 tflite::TensorType getOpResultType(loco::ConstGen *node, TypeContext &); tflite::TensorType getOpResultType(loco::Pull *node, TypeContext &); @@ -46,6 +46,24 @@ tflite::TensorType getOpResultType(loco::TensorConcat *node, TypeContext &gd); tflite::TensorType getOpResultType(loco::BiasEncode *node, TypeContext &gd); tflite::TensorType getOpResultType(loco::BiasAdd *node, TypeContext &gd); +#endif + +/** + * @brief Annotate the type of each node as NodeAnnotation + * + * HOW TO USE + * + * TypeInference::run(g); + * + * TypeInference::get(g->nodes()->at(0)); + * TypeInference::get(g->nodes()->at(...)); + */ +struct TypeInference +{ + static void run(loco::Graph *g); + + static tflite::TensorType get(loco::Node *node); +}; // Shape inference functions -- 2.7.4