[exo-tflite] Exporting Operator Code table including code of custom op (#6323)
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Wed, 7 Aug 2019 06:38:31 +0000 (15:38 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Wed, 7 Aug 2019 06:38:31 +0000 (15:38 +0900)
Now when tflite opcode table is created, name and code for custom op is included.

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
compiler/exo-tflite/src/TFLExporterImpl.cpp

index ad7b6f9..69d3ebe 100644 (file)
@@ -25,6 +25,7 @@
 #include <cassert>
 #include <unordered_map>
 #include <string>
+#include <stdexcept>
 
 namespace
 {
@@ -59,13 +60,27 @@ using namespace flatbuffers;
 TFLExporter::Impl::Impl(loco::Graph *graph) { exportGraph(graph); }
 
 Offset<Vector<Offset<OperatorCode>>>
-encodeOperatorCodes(FlatBufferBuilder &builder, std::unordered_map<OpCode, uint32_t> &opcodes)
+encodeOperatorCodes(FlatBufferBuilder &builder, std::unordered_map<OpCode, uint32_t> &opcodes,
+                    std::unordered_map<OpCode, std::string> &custom_opcodes)
 {
   std::vector<Offset<OperatorCode>> operator_codes_vec(opcodes.size());
   for (auto it : opcodes)
   {
     uint32_t idx = it.second;
-    operator_codes_vec[idx] = CreateOperatorCode(builder, it.first.opcode);
+    if (it.first.opcode != BuiltinOperator_CUSTOM)
+    {
+      operator_codes_vec[idx] = CreateOperatorCode(builder, it.first.opcode);
+    }
+    else // custom op
+    {
+      auto opCode = it.first;
+      auto custom_code = custom_opcodes.find(opCode);
+      if (custom_code == custom_opcodes.end())
+        throw std::runtime_error("Cannot find code for custom op");
+
+      operator_codes_vec[idx] =
+          CreateOperatorCode(builder, it.first.opcode, builder.CreateString(custom_code->second));
+    }
   }
   return builder.CreateVector(operator_codes_vec);
 }
@@ -109,7 +124,8 @@ void TFLExporter::Impl::exportGraph(loco::Graph *graph)
   exportNodes(graph, _builder, gd);
 
   // excode operator codes
-  auto operator_codes = encodeOperatorCodes(_builder, gd._operator_codes);
+  auto operator_codes =
+      encodeOperatorCodes(_builder, gd._operator_codes, gd._custom_operator_codes);
 
   // Subgraphs
   Offset<SubGraph> subgraph = exportSubgraph(gd);