From 75a3ea38c7603bbbaaecfe5aca71d56b266e7b30 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/On-Device=20Lab=28SR=29/Staff?= =?utf8?q?=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 11 Jul 2019 10:21:33 +0900 Subject: [PATCH] [exo/tflite] Hide ShapeContext (#4187) This commit makes ShapeContext visible only to ShapeInference module, and revises all the other modules to use "ShapeInferece::get". Signed-off-by: Jonghyun Park --- contrib/exo-tflite/src/ExporterUtils.h | 10 +--------- contrib/exo-tflite/src/OperationExporter.cpp | 3 ++- contrib/exo-tflite/src/TensorExporter.cpp | 1 - contrib/exo-tflite/src/TypeInference.cpp | 13 +++++++++++++ 4 files changed, 16 insertions(+), 11 deletions(-) diff --git a/contrib/exo-tflite/src/ExporterUtils.h b/contrib/exo-tflite/src/ExporterUtils.h index 7ca2430..b69c4e9 100644 --- a/contrib/exo-tflite/src/ExporterUtils.h +++ b/contrib/exo-tflite/src/ExporterUtils.h @@ -48,14 +48,6 @@ struct ShapeDescription }; /** - * @brief Record the (tensor) shape of each loco node - */ -struct ShapeContext -{ - std::unordered_map _node_to_shape; -}; - -/** * @breif Record the information of T/F Lite SubGraph and its mapping to loco */ struct SubGraphContext @@ -67,7 +59,7 @@ struct SubGraphContext }; // Prerequisites for tflite::Model object creation -struct SerializedModelData final : public ShapeContext, public SubGraphContext +struct SerializedModelData final : public SubGraphContext { SerializedModelData() = default; SerializedModelData(const SerializedModelData &) = delete; diff --git a/contrib/exo-tflite/src/OperationExporter.cpp b/contrib/exo-tflite/src/OperationExporter.cpp index 17e34ac..4cbe51c 100644 --- a/contrib/exo-tflite/src/OperationExporter.cpp +++ b/contrib/exo-tflite/src/OperationExporter.cpp @@ -16,6 +16,7 @@ #include "OperationExporter.h" #include "ExporterUtils.h" +#include "TypeInference.h" using namespace flatbuffers; using namespace tflite; @@ -78,7 +79,7 @@ void exportConv2D(loco::Conv2D *node, FlatBufferBuilder &builder, SerializedMode // zero bias. auto *ker = dynamic_cast(node->ker()); assert(ker); - int32_t bias_vec_size = gd._node_to_shape[ker]._dims[0]; // output kernel count + int32_t bias_vec_size = ShapeInference::get(ker)._dims[0]; // output kernel count auto bias_vec_shape_offset = builder.CreateVector(std::vector{bias_vec_size}); size_t raw_bias_vec_size = bias_vec_size * sizeof(int32_t); diff --git a/contrib/exo-tflite/src/TensorExporter.cpp b/contrib/exo-tflite/src/TensorExporter.cpp index 76f05b0..40f086a 100644 --- a/contrib/exo-tflite/src/TensorExporter.cpp +++ b/contrib/exo-tflite/src/TensorExporter.cpp @@ -61,7 +61,6 @@ void exportOpDefinedTensor(NodeT *node, FlatBufferBuilder &builder, SerializedMo { // Create and register output tensor shape ShapeDescription shape_description = ShapeInference::get(node); - gd._node_to_shape[node] = shape_description; auto shape_offset = encodeShape(builder, shape_description); // encode and register output tensor type diff --git a/contrib/exo-tflite/src/TypeInference.cpp b/contrib/exo-tflite/src/TypeInference.cpp index 904b600..1cb3da7 100644 --- a/contrib/exo-tflite/src/TypeInference.cpp +++ b/contrib/exo-tflite/src/TypeInference.cpp @@ -221,6 +221,19 @@ tflite::TensorType TypeInference::get(loco::Node *node) return node->annot()->type(); } +namespace +{ + +/** + * @brief Record the (tensor) shape of each loco node + */ +struct ShapeContext +{ + std::unordered_map _node_to_shape; +}; + +} // namespace + int32_t decodeShapeDimension(const loco::Dimension &dim) { if (!dim.known()) -- 2.7.4