[enco/frontend] Introduce TensorContext and TensorBags (#2188)
author박세희/동작제어Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Fri, 9 Nov 2018 05:36:51 +0000 (14:36 +0900)
committer박종현/동작제어Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Fri, 9 Nov 2018 05:36:51 +0000 (14:36 +0900)
* [enco/frontend] Introduce TensorContext and TensorBags

This will introduce TensorContext to extract and hold informations of
tensors such as shape and name and pre-creates coco::Bag for each tensors

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
* apply comments

contrib/enco/frontend/tflite/src/Frontend.cpp

index 7c31d15..81ba6b4 100644 (file)
@@ -18,6 +18,8 @@
 
 #include <nncc/core/ADT/tensor/Shape.h>
 
+#include <map>
+
 using namespace nncc::core::ADT;
 
 namespace tflimport
@@ -61,6 +63,57 @@ tensor::Shape as_tensor_shape(const flatbuffers::Vector<int32_t> *shape)
 }
 
 /**
+ * @brief Extracts and holds operand(tensor) information such as name and shape
+ */
+class TensorContext
+{
+public:
+  void prepare(const tflite::SubGraph *graph)
+  {
+    for (uint32_t tensor_id = 0; tensor_id < graph->tensors()->size(); ++tensor_id)
+    {
+      auto const tensor_info = graph->tensors()->Get(tensor_id);
+      auto const tensor_name = tensor_info->name()->str();
+      auto const tensor_shape = as_tensor_shape(tensor_info->shape());
+
+      _name_ctx[tensor_id] = tensor_name;
+      _shape_ctx[tensor_id] = tensor_shape;
+    }
+  }
+
+  const std::string &name(uint32_t tensor_id) { return _name_ctx[tensor_id]; }
+  const tensor::Shape &shape(uint32_t tensor_id) { return _shape_ctx[tensor_id]; }
+
+private:
+  std::map<uint32_t, std::string> _name_ctx;
+  std::map<uint32_t, tensor::Shape> _shape_ctx;
+};
+
+/**
+ * @brief Pre-creates coco:Bags for each operands(tensors)
+ */
+class TensorBags
+{
+public:
+  void prepare(const tflite::SubGraph *graph, std::unique_ptr<coco::Module> &m)
+  {
+    for (uint32_t tensor_id = 0; tensor_id < graph->tensors()->size(); ++tensor_id)
+    {
+      auto const tensor_info = graph->tensors()->Get(tensor_id);
+      auto const tensor_shape = as_tensor_shape(tensor_info->shape());
+      auto const tensor_bag = m->entity()->bag()->create(num_elements(tensor_shape));
+
+      _bag_ctx[tensor_id] = tensor_bag;
+    }
+  }
+
+  coco::Bag *bag(int32_t tensor_id) { return _bag_ctx[tensor_id]; }
+
+private:
+  std::map<uint32_t, coco::Bag *> _bag_ctx;
+};
+
+/**
  * @brief Class to store context to build IR from tflite
  */
 class GraphBuilderContext