#include <loco/IR/CanonicalNode.h>
#include <loco/IR/CanonicalNodeVisitor.h>
+#include <locoex/COpCall.h>
+
+#include <flatbuffers/flexbuffers.h>
using namespace flatbuffers;
using namespace tflite;
void visit(loco::EltwiseSub *) final;
void visit(loco::EltwiseDiv *) final;
+ void visit(locoex::COpCall *);
+
private:
FlatBufferBuilder &builder;
SerializedModelData &gd;
gd._operators.push_back(op_offset);
}
+inline flatbuffers::Offset<flatbuffers::Vector<uint8_t>>
+CreateCOpCallOptions(flatbuffers::FlatBufferBuilder &fbb, locoex::COpCall *copCall)
+{
+ // read attrs in FlexBuffer format and pass them to FlatBuffer builder
+ flexbuffers::Builder flexbuf;
+ {
+ size_t map_start = flexbuf.StartMap();
+
+ // Note: among attrs of COpCall, 'op' and 'name' won't be included into tflite file
+ auto names = copCall->attr_names();
+ for (auto name : names)
+ {
+ if (auto int_val = copCall->attr<locoex::COpAttrType::Int>(name))
+ flexbuf.Int(name.c_str(), int_val->val());
+ else if (auto float_val = copCall->attr<locoex::COpAttrType::Float>(name))
+ flexbuf.Float(name.c_str(), float_val->val());
+ else
+ // TODO Support more attribute types
+ throw std::runtime_error("Not supported type while writing flexbuffer");
+ }
+
+ flexbuf.EndMap(map_start);
+ flexbuf.Finish();
+ }
+
+ auto offset = fbb.CreateVector(flexbuf.GetBuffer());
+
+ return offset;
+}
+
+void OperationExporter::visit(locoex::COpCall *call)
+{
+ // Registering this custom op name into tflite Operator Codes table
+ uint32_t op_idx = gd.registerCustomOpcode(call->op());
+
+ std::vector<int32_t> inputs_vec;
+ {
+ inputs_vec.resize(call->arity());
+ for (uint32_t i = 0; i < call->arity(); i++)
+ inputs_vec[i] = get_tensor_index(call->arg(i));
+ }
+
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(call))};
+
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+
+ auto custom_options = CreateCOpCallOptions(builder, call);
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_NONE, // builtin_options_type
+ 0, // built-in option
+ custom_options, // custom options
+ tflite::CustomOptionsFormat_FLEXBUFFERS);
+
+ gd._operators.push_back(op_offset);
+}
+
void exportNode(loco::Node *node, flatbuffers::FlatBufferBuilder &builder,
SerializedModelData &data)
{
OperationExporter exporter{builder, data};
canonical_node->accept(&exporter);
}
+ else if (dynamic_cast<locoex::COpNode *>(node))
+ {
+ OperationExporter exporter{builder, data};
+ exporter.visit(dynamic_cast<locoex::COpCall *>(node));
+ }
else
{
assert(false && "unsupported node found");