[Tflite] Implement serializing opcode
authorJihoon Lee <jhoon.it.lee@samsung.com>
Wed, 14 Apr 2021 04:47:12 +0000 (13:47 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Mon, 17 May 2021 06:12:38 +0000 (15:12 +0900)
This patch implements implementing serialize opcode by
`buildOperatorCodes`

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
nntrainer/compiler/tflite_interpreter.cpp

index 0bd3547..fe1c875 100644 (file)
@@ -250,6 +250,13 @@ public:
     return index2data[index];
   }
 
+  /**
+   * @brief Get the Data object
+   *
+   * @return const std::vector<T>& underlying data
+   */
+  const std::vector<T> &getData() const { return index2data; }
+
 private:
   std::unordered_map<T, unsigned int> data2index; /**< data -> index map */
   std::vector<T> index2data;                      /**< index -> data map */
@@ -301,6 +308,10 @@ public:
     return std::get<BidirectionalIndexMap<T>>(maps);
   }
 
+  template <typename T> const BidirectionalIndexMap<T> &getIndexMap() const {
+    return std::get<BidirectionalIndexMap<T>>(maps);
+  }
+
 private:
   float empty_buffer[0]; /**< unintialized tensor points to this buffer */
 
@@ -331,10 +342,27 @@ buildBuffers(const TfOpIdxMap &map, flatbuffers::FlatBufferBuilder &fbb) {
 flatbuffers::Offset<
   flatbuffers::Vector<flatbuffers::Offset<tflite::OperatorCode>>>
 buildOperatorCodes(const TfOpIdxMap &map, flatbuffers::FlatBufferBuilder &fbb) {
-  /** NYI! */
-  return flatbuffers::Offset<
-    flatbuffers::Vector<flatbuffers::Offset<tflite::OperatorCode>>>();
-};
+  const auto &op_codes = map.getIndexMap<tflite::BuiltinOperator>().getData();
+
+  std::vector<flatbuffers::Offset<tflite::OperatorCode>> fb_op_codes;
+  fb_op_codes.reserve(op_codes.size());
+
+  auto create_op_offset = [&fbb](const tflite::BuiltinOperator &op,
+                                 int32_t version = 1) {
+    tflite::OperatorCodeBuilder builder(fbb);
+    builder.add_deprecated_builtin_code(static_cast<int8_t>(op));
+    /// @todo find reason why version field is not shown
+    /// on json when version is 1 (other versions are fine)
+    builder.add_version(version);
+    builder.add_builtin_code(op);
+    return builder.Finish();
+  };
+
+  std::transform(op_codes.begin(), op_codes.end(),
+                 std::back_inserter(fb_op_codes), create_op_offset);
+
+  return fbb.CreateVector(fb_op_codes);
+}
 
 flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::SubGraph>>>
 buildSubGraph(const TfOpNodes &nodes, const TfOpIdxMap &map,
@@ -355,14 +383,14 @@ void TfliteInterpreter::serialize(
   auto opNodes = buildOpNodes(representation);
   TfOpIdxMap map(opNodes); /// build TfOpIdxMap
 
-  auto UNUSED(opcodes) = buildOperatorCodes(map, fbb);
+  auto opcodes = buildOperatorCodes(map, fbb);
   auto UNUSED(buffers) = buildBuffers(map, fbb);
   auto UNUSED(subgraph) = buildSubGraph(opNodes, map, fbb);
   auto desc = fbb.CreateString("This file is generated from NNTrainer");
 
   tflite::ModelBuilder model_builder(fbb);
 
-  WILL(model_builder.add_operator_codes(opcode_offset));
+  model_builder.add_operator_codes(opcodes);
   WILL(model_builder.add_buffers(buffers));
   WILL(model_builder.add_subgraphs(subgraph));
   model_builder.add_version(3);