From 1f4b709e0ced57de5d04deff658de2142755aa60 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/On-Device=20Lab=28SR=29/Staff?= =?utf8?q?=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Mon, 8 Jul 2019 16:21:11 +0900 Subject: [PATCH] [exo/tflite] Extract Shape Inference (#4138) This commit extracts shape inference phase from tensor export phase. Signed-off-by: Jonghyun Park --- contrib/exo-tflite/src/TFLExporterImpl.cpp | 4 ++ contrib/exo-tflite/src/TensorExporter.cpp | 2 +- contrib/exo-tflite/src/TypeInference.cpp | 77 ++++++++++++++++++++++++++++++ contrib/exo-tflite/src/TypeInference.h | 19 +++++++- 4 files changed, 100 insertions(+), 2 deletions(-) diff --git a/contrib/exo-tflite/src/TFLExporterImpl.cpp b/contrib/exo-tflite/src/TFLExporterImpl.cpp index 58f2ff8..37e36ed 100644 --- a/contrib/exo-tflite/src/TFLExporterImpl.cpp +++ b/contrib/exo-tflite/src/TFLExporterImpl.cpp @@ -59,6 +59,10 @@ void TFLExporter::Impl::exportGraph(loco::Graph *graph) TypeInference::run(graph); // TypeInference::get(node) now works + // Infer the shape of each node + ShapeInference::run(graph); + // ShapeInference::get(node) now works + _builder.Clear(); SerializedModelData gd; diff --git a/contrib/exo-tflite/src/TensorExporter.cpp b/contrib/exo-tflite/src/TensorExporter.cpp index eb5334d..9ad5afc 100644 --- a/contrib/exo-tflite/src/TensorExporter.cpp +++ b/contrib/exo-tflite/src/TensorExporter.cpp @@ -60,7 +60,7 @@ template void exportOpDefinedTensor(NodeT *node, FlatBufferBuilder &builder, SerializedModelData &gd) { // Create and register output tensor shape - ShapeDescription shape_description = getOpResultShape(node, gd); + ShapeDescription shape_description = ShapeInference::get(node); gd._node_to_shape[node] = shape_description; auto shape_offset = encodeShape(builder, shape_description); diff --git a/contrib/exo-tflite/src/TypeInference.cpp b/contrib/exo-tflite/src/TypeInference.cpp index 4d62aa7..f5bd4ab 100644 --- a/contrib/exo-tflite/src/TypeInference.cpp +++ b/contrib/exo-tflite/src/TypeInference.cpp @@ -238,6 +238,11 @@ ShapeDescription getOpResultShape(loco::Pull *node, ShapeContext &) return shape; } +ShapeDescription getOpResultShape(loco::Push *node, ShapeContext &gd) +{ + return gd._node_to_shape[node->from()]; +} + ShapeDescription getOpResultShape(loco::ConstGen *node, ShapeContext &) { ShapeDescription shape; @@ -546,3 +551,75 @@ ShapeDescription getOpResultShape(loco::BiasAdd *node, Sha return value_shape; } + +namespace +{ + +class ShapeAnnotation : public loco::NodeAnnotation +{ +public: + ShapeAnnotation(const ShapeDescription &shape) : _shape{shape} + { + // DO NOTHING + } + +public: + const ShapeDescription &shape(void) const { return _shape; } + +private: + ShapeDescription _shape; +}; + +class ShapeAnnotator final : public loco::CanonicalNodeMutableVisitor +{ +public: + ShapeAnnotator() = default; + +public: +#define NODE(NAME) \ + void visit(loco::NAME *node) final \ + { \ + auto s = getOpResultShape(node, _ctx); \ + node->annot(stdex::make_unique(s)); \ + _ctx._node_to_shape[node] = s; \ + } + NODE(ConstGen) + NODE(Pull) + NODE(Push) + NODE(FeatureEncode) + NODE(FeatureDecode) + NODE(FilterEncode) + NODE(MaxPool2D) + NODE(AvgPool2D) + NODE(Conv2D) + NODE(ReLU) + NODE(TensorConcat) + NODE(BiasEncode) + NODE(TensorBiasAdd) +#undef NODE + +private: + // TODO Remove this variable + ShapeContext _ctx; +}; + +} // namespace + +void ShapeInference::run(loco::Graph *g) +{ + ShapeAnnotator shape_annotator; + + for (auto node : loco::postorder_traversal(loco::output_nodes(g))) + { + if (auto canonical_node = dynamic_cast(node)) + { + canonical_node->accept(&shape_annotator); + } + } +} + +ShapeDescription ShapeInference::get(loco::Node *node) +{ + assert(node->annot() != nullptr); + return node->annot()->shape(); +} diff --git a/contrib/exo-tflite/src/TypeInference.h b/contrib/exo-tflite/src/TypeInference.h index 727a8b8..9dcf79d 100644 --- a/contrib/exo-tflite/src/TypeInference.h +++ b/contrib/exo-tflite/src/TypeInference.h @@ -66,7 +66,7 @@ struct TypeInference }; // Shape inference functions - +#if 0 ShapeDescription getOpResultShape(loco::Pull *node, ShapeContext &); ShapeDescription getOpResultShape(loco::ConstGen *node, ShapeContext &); @@ -90,5 +90,22 @@ ShapeDescription getOpResultShape(loco::TensorConcat *node, ShapeContext &gd); ShapeDescription getOpResultShape(loco::BiasEncode *node, ShapeContext &gd); ShapeDescription getOpResultShape(loco::BiasAdd *node, ShapeContext &gd); +#endif + +/** + * @brief Annotate the shape of each node as a node annotation + * + * HOW TO USE + * + * ShapeInference::run(g); + * + * ShapeInference::get(g->nodes()->at(..)); + */ +struct ShapeInference +{ + static void run(loco::Graph *g); + + static ShapeDescription get(loco::Node *node); +}; #endif // __TYPE_INFERENCE_H__ -- 2.7.4