[exo-tflite] Registering custom op before exporting op code table (#6235)
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Mon, 5 Aug 2019 23:21:46 +0000 (08:21 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Mon, 5 Aug 2019 23:21:45 +0000 (08:21 +0900)
This adds custom op before exporting op code table.

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
compiler/exo-tflite/CMakeLists.txt
compiler/exo-tflite/requires.cmake
compiler/exo-tflite/src/ExporterUtils.cpp
compiler/exo-tflite/src/ExporterUtils.h

index 5e23ac7..61c84a3 100644 (file)
@@ -33,6 +33,8 @@ target_link_libraries(exo_tflite PUBLIC exo_tflite_fbs)
 target_link_libraries(exo_tflite PUBLIC loco)
 target_link_libraries(exo_tflite PRIVATE stdex)
 target_link_libraries(exo_tflite PRIVATE pepper_strcast)
+target_link_libraries(exo_tflite PRIVATE locoex_customop)
+
 # Let's apply nncc common compile options
 #
 # NOTE This will enable strict compilation (warnings as error).
index 567383d..9815422 100644 (file)
@@ -1,2 +1,3 @@
 require("stdex")
 require("loco")
+require("locoex-customop")
index daa09c9..891e413 100644 (file)
@@ -58,6 +58,14 @@ uint32_t SerializedModelData::registerBuiltinOpcode(tflite::BuiltinOperator buil
   return idx;
 }
 
+uint32_t SerializedModelData::registerCustomOpcode(const std::string &custom_op)
+{
+  tflite::BuiltinOperator custom_code = tflite::BuiltinOperator_CUSTOM;
+  auto idx = registerBuiltinOpcode(custom_code);
+  _custom_operator_codes.emplace(OpCode{custom_code}, custom_op);
+  return idx;
+}
+
 tflite::Padding getOpPadding(const loco::Pad<2> *pad)
 {
   // VALID padding
index e37eb33..4854fca 100644 (file)
@@ -69,6 +69,7 @@ struct SerializedModelData final : public SubGraphContext
   SerializedModelData(const SerializedModelData &) = delete;
 
   std::unordered_map<OpCode, uint32_t> _operator_codes;
+  std::unordered_map<OpCode, std::string> _custom_operator_codes;
   std::vector<flatbuffers::Offset<tflite::Operator>> _operators;
   std::vector<flatbuffers::Offset<tflite::Tensor>> _tensors;
   std::vector<flatbuffers::Offset<tflite::Buffer>> _buffers;
@@ -83,6 +84,7 @@ struct SerializedModelData final : public SubGraphContext
    * @return idx of opcode in table of opcodes (see schema)
    */
   uint32_t registerBuiltinOpcode(tflite::BuiltinOperator builtin_code);
+  uint32_t registerCustomOpcode(const std::string &custom_op);
 };
 
 template <typename Permutation> inline bool isNHWC(Permutation *perm);