[enco/tfl/frontend] TensorContext returns tensor type (#2533)
author윤현식/동작제어Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Thu, 6 Dec 2018 06:48:14 +0000 (15:48 +0900)
committer박종현/동작제어Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Thu, 6 Dec 2018 06:48:14 +0000 (15:48 +0900)
This commit enables TensorContext to return tensor type.

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
contrib/enco/frontend/tflite/src/Context.cpp
contrib/enco/frontend/tflite/src/Context.h

index cfdbc06..ef030dc 100644 (file)
@@ -39,9 +39,11 @@ void TensorContext::prepare(const tflite::SubGraph *graph)
     auto const tensor_info = graph->tensors()->Get(tensor_id);
     auto const tensor_name = tensor_info->name()->str();
     auto const tensor_shape = as_tensor_shape(tensor_info->shape());
+    auto const tensor_type = tensor_info->type();
 
     _name_ctx[tensor_id] = tensor_name;
     _shape_ctx[tensor_id] = tensor_shape;
+    _type_ctx[tensor_id] = tensor_type;
   }
 }
 
index d126a02..f72385f 100644 (file)
@@ -33,7 +33,7 @@ namespace tflimport
 {
 
 /**
- * @brief Extracts and holds operand(tensor) information such as name and shape
+ * @brief Extracts and holds operand(tensor) information such as name, shape, and type
  */
 class TensorContext
 {
@@ -42,10 +42,12 @@ public:
 
   const std::string &name(uint32_t tensor_id) { return _name_ctx[tensor_id]; }
   const tensor::Shape &shape(uint32_t tensor_id) { return _shape_ctx[tensor_id]; }
+  const tflite::TensorType &type(uint32_t tensor_id) { return _type_ctx[tensor_id]; }
 
 private:
   std::map<uint32_t, std::string> _name_ctx;
   std::map<uint32_t, tensor::Shape> _shape_ctx;
+  std::map<uint32_t, tflite::TensorType> _type_ctx;
 };
 
 /**