[exo/tflite] Hide ShapeContext (#4187)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Thu, 11 Jul 2019 01:21:33 +0000 (10:21 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Thu, 11 Jul 2019 01:21:33 +0000 (10:21 +0900)
This commit makes ShapeContext visible only to ShapeInference module,
and revises all the other modules to use "ShapeInferece::get".

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
contrib/exo-tflite/src/ExporterUtils.h
contrib/exo-tflite/src/OperationExporter.cpp
contrib/exo-tflite/src/TensorExporter.cpp
contrib/exo-tflite/src/TypeInference.cpp

index 7ca2430..b69c4e9 100644 (file)
@@ -48,14 +48,6 @@ struct ShapeDescription
 };
 
 /**
- * @brief Record the (tensor) shape of each loco node
- */
-struct ShapeContext
-{
-  std::unordered_map<loco::Node *, ShapeDescription> _node_to_shape;
-};
-
-/**
  * @breif Record the information of T/F Lite SubGraph and its mapping to loco
  */
 struct SubGraphContext
@@ -67,7 +59,7 @@ struct SubGraphContext
 };
 
 // Prerequisites for tflite::Model object creation
-struct SerializedModelData final : public ShapeContext, public SubGraphContext
+struct SerializedModelData final : public SubGraphContext
 {
   SerializedModelData() = default;
   SerializedModelData(const SerializedModelData &) = delete;
index 17e34ac..4cbe51c 100644 (file)
@@ -16,6 +16,7 @@
 
 #include "OperationExporter.h"
 #include "ExporterUtils.h"
+#include "TypeInference.h"
 
 using namespace flatbuffers;
 using namespace tflite;
@@ -78,7 +79,7 @@ void exportConv2D(loco::Conv2D *node, FlatBufferBuilder &builder, SerializedMode
   // zero bias.
   auto *ker = dynamic_cast<loco::FilterEncode *>(node->ker());
   assert(ker);
-  int32_t bias_vec_size = gd._node_to_shape[ker]._dims[0]; // output kernel count
+  int32_t bias_vec_size = ShapeInference::get(ker)._dims[0]; // output kernel count
 
   auto bias_vec_shape_offset = builder.CreateVector(std::vector<int32_t>{bias_vec_size});
   size_t raw_bias_vec_size = bias_vec_size * sizeof(int32_t);
index 76f05b0..40f086a 100644 (file)
@@ -61,7 +61,6 @@ void exportOpDefinedTensor(NodeT *node, FlatBufferBuilder &builder, SerializedMo
 {
   // Create and register output tensor shape
   ShapeDescription shape_description = ShapeInference::get(node);
-  gd._node_to_shape[node] = shape_description;
   auto shape_offset = encodeShape(builder, shape_description);
 
   // encode and register output tensor type
index 904b600..1cb3da7 100644 (file)
@@ -221,6 +221,19 @@ tflite::TensorType TypeInference::get(loco::Node *node)
   return node->annot<TypeAnnotation>()->type();
 }
 
+namespace
+{
+
+/**
+ * @brief Record the (tensor) shape of each loco node
+ */
+struct ShapeContext
+{
+  std::unordered_map<loco::Node *, ShapeDescription> _node_to_shape;
+};
+
+} // namespace
+
 int32_t decodeShapeDimension(const loco::Dimension &dim)
 {
   if (!dim.known())