From b2a6a3d252fbb370218c33572887a2f6fc837f65 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9C=A4=ED=98=84=EC=8B=9D/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 6 Aug 2019 17:10:22 +0900 Subject: [PATCH] [exo-tflite] Exporting custom op (#6269) This commit enables exportation of custom op (Operator Codes, input, output, op with attr) into tflite file. Signed-off-by: Hyun Sik Yoon --- compiler/exo-tflite/src/OperationExporter.cpp | 67 +++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/compiler/exo-tflite/src/OperationExporter.cpp b/compiler/exo-tflite/src/OperationExporter.cpp index 433a2ca..b39a2d6 100644 --- a/compiler/exo-tflite/src/OperationExporter.cpp +++ b/compiler/exo-tflite/src/OperationExporter.cpp @@ -21,6 +21,9 @@ #include #include +#include + +#include using namespace flatbuffers; using namespace tflite; @@ -57,6 +60,8 @@ public: void visit(loco::EltwiseSub *) final; void visit(loco::EltwiseDiv *) final; + void visit(locoex::COpCall *); + private: FlatBufferBuilder &builder; SerializedModelData &gd; @@ -388,6 +393,63 @@ void OperationExporter::visit(loco::EltwiseDiv *node) gd._operators.push_back(op_offset); } +inline flatbuffers::Offset> +CreateCOpCallOptions(flatbuffers::FlatBufferBuilder &fbb, locoex::COpCall *copCall) +{ + // read attrs in FlexBuffer format and pass them to FlatBuffer builder + flexbuffers::Builder flexbuf; + { + size_t map_start = flexbuf.StartMap(); + + // Note: among attrs of COpCall, 'op' and 'name' won't be included into tflite file + auto names = copCall->attr_names(); + for (auto name : names) + { + if (auto int_val = copCall->attr(name)) + flexbuf.Int(name.c_str(), int_val->val()); + else if (auto float_val = copCall->attr(name)) + flexbuf.Float(name.c_str(), float_val->val()); + else + // TODO Support more attribute types + throw std::runtime_error("Not supported type while writing flexbuffer"); + } + + flexbuf.EndMap(map_start); + flexbuf.Finish(); + } + + auto offset = fbb.CreateVector(flexbuf.GetBuffer()); + + return offset; +} + +void OperationExporter::visit(locoex::COpCall *call) +{ + // Registering this custom op name into tflite Operator Codes table + uint32_t op_idx = gd.registerCustomOpcode(call->op()); + + std::vector inputs_vec; + { + inputs_vec.resize(call->arity()); + for (uint32_t i = 0; i < call->arity(); i++) + inputs_vec[i] = get_tensor_index(call->arg(i)); + } + + std::vector outputs_vec{get_tensor_index(static_cast(call))}; + + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + + auto custom_options = CreateCOpCallOptions(builder, call); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_NONE, // builtin_options_type + 0, // built-in option + custom_options, // custom options + tflite::CustomOptionsFormat_FLEXBUFFERS); + + gd._operators.push_back(op_offset); +} + void exportNode(loco::Node *node, flatbuffers::FlatBufferBuilder &builder, SerializedModelData &data) { @@ -412,6 +474,11 @@ void exportNode(loco::Node *node, flatbuffers::FlatBufferBuilder &builder, OperationExporter exporter{builder, data}; canonical_node->accept(&exporter); } + else if (dynamic_cast(node)) + { + OperationExporter exporter{builder, data}; + exporter.visit(dynamic_cast(node)); + } else { assert(false && "unsupported node found"); -- 2.7.4