[exo-tflite] Exporting custom op (#6269)
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Tue, 6 Aug 2019 08:10:22 +0000 (17:10 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Tue, 6 Aug 2019 08:10:22 +0000 (17:10 +0900)
This commit enables exportation of custom op (Operator Codes, input, output, op with attr) into tflite file.

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
compiler/exo-tflite/src/OperationExporter.cpp

index 433a2ca..b39a2d6 100644 (file)
@@ -21,6 +21,9 @@
 
 #include <loco/IR/CanonicalNode.h>
 #include <loco/IR/CanonicalNodeVisitor.h>
+#include <locoex/COpCall.h>
+
+#include <flatbuffers/flexbuffers.h>
 
 using namespace flatbuffers;
 using namespace tflite;
@@ -57,6 +60,8 @@ public:
   void visit(loco::EltwiseSub *) final;
   void visit(loco::EltwiseDiv *) final;
 
+  void visit(locoex::COpCall *);
+
 private:
   FlatBufferBuilder &builder;
   SerializedModelData &gd;
@@ -388,6 +393,63 @@ void OperationExporter::visit(loco::EltwiseDiv *node)
   gd._operators.push_back(op_offset);
 }
 
+inline flatbuffers::Offset<flatbuffers::Vector<uint8_t>>
+CreateCOpCallOptions(flatbuffers::FlatBufferBuilder &fbb, locoex::COpCall *copCall)
+{
+  // read attrs in FlexBuffer format and pass them to FlatBuffer builder
+  flexbuffers::Builder flexbuf;
+  {
+    size_t map_start = flexbuf.StartMap();
+
+    // Note: among attrs of COpCall, 'op' and 'name' won't be included into tflite file
+    auto names = copCall->attr_names();
+    for (auto name : names)
+    {
+      if (auto int_val = copCall->attr<locoex::COpAttrType::Int>(name))
+        flexbuf.Int(name.c_str(), int_val->val());
+      else if (auto float_val = copCall->attr<locoex::COpAttrType::Float>(name))
+        flexbuf.Float(name.c_str(), float_val->val());
+      else
+        // TODO Support more attribute types
+        throw std::runtime_error("Not supported type while writing flexbuffer");
+    }
+
+    flexbuf.EndMap(map_start);
+    flexbuf.Finish();
+  }
+
+  auto offset = fbb.CreateVector(flexbuf.GetBuffer());
+
+  return offset;
+}
+
+void OperationExporter::visit(locoex::COpCall *call)
+{
+  // Registering this custom op name into tflite Operator Codes table
+  uint32_t op_idx = gd.registerCustomOpcode(call->op());
+
+  std::vector<int32_t> inputs_vec;
+  {
+    inputs_vec.resize(call->arity());
+    for (uint32_t i = 0; i < call->arity(); i++)
+      inputs_vec[i] = get_tensor_index(call->arg(i));
+  }
+
+  std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(call))};
+
+  auto inputs = builder.CreateVector(inputs_vec);
+  auto outputs = builder.CreateVector(outputs_vec);
+
+  auto custom_options = CreateCOpCallOptions(builder, call);
+  auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+                                  tflite::BuiltinOptions_NONE, // builtin_options_type
+                                  0,                           // built-in option
+                                  custom_options,              // custom options
+                                  tflite::CustomOptionsFormat_FLEXBUFFERS);
+
+  gd._operators.push_back(op_offset);
+}
+
 void exportNode(loco::Node *node, flatbuffers::FlatBufferBuilder &builder,
                 SerializedModelData &data)
 {
@@ -412,6 +474,11 @@ void exportNode(loco::Node *node, flatbuffers::FlatBufferBuilder &builder,
     OperationExporter exporter{builder, data};
     canonical_node->accept(&exporter);
   }
+  else if (dynamic_cast<locoex::COpNode *>(node))
+  {
+    OperationExporter exporter{builder, data};
+    exporter.visit(dynamic_cast<locoex::COpCall *>(node));
+  }
   else
   {
     assert(false && "unsupported node found");