From 70a41cb3ba69c6340e691c6f7baf6c872e51f74c Mon Sep 17 00:00:00 2001 From: =?utf8?q?Dmitry=20Mozolev/AI=20Tools=20Lab=20/SRR/Engineer/=EC=82=BC?= =?utf8?q?=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Fri, 29 Jun 2018 13:21:52 +0300 Subject: [PATCH] Create Model IR operators in TFLite model visitor (#398) This commit adds the code that makes use of the TFLite-to-Model-IR operator creator to construct Model IR graph. Signed-off-by: Dmitry Mozolev --- .../frontend/tflite/include/tflite_ir_visitor.h | 2 + .../libs/frontend/tflite/src/tflite_ir_visitor.cpp | 44 +++++++++++++++++++++- 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/contrib/nnc/libs/frontend/tflite/include/tflite_ir_visitor.h b/contrib/nnc/libs/frontend/tflite/include/tflite_ir_visitor.h index 46ef5e1..c3dbdf6 100644 --- a/contrib/nnc/libs/frontend/tflite/include/tflite_ir_visitor.h +++ b/contrib/nnc/libs/frontend/tflite/include/tflite_ir_visitor.h @@ -13,6 +13,7 @@ #include "schema_v3.h" #include "tflite_visitor.h" +#include "tflite_op_creator.h" namespace nncc { @@ -47,6 +48,7 @@ public: private: Graph *graph = nullptr; + std::unique_ptr opCreator; const flatbuffers::Vector> *opcodes = nullptr; const flatbuffers::Vector> *tensors = nullptr; diff --git a/contrib/nnc/libs/frontend/tflite/src/tflite_ir_visitor.cpp b/contrib/nnc/libs/frontend/tflite/src/tflite_ir_visitor.cpp index 52f25db..1264260 100644 --- a/contrib/nnc/libs/frontend/tflite/src/tflite_ir_visitor.cpp +++ b/contrib/nnc/libs/frontend/tflite/src/tflite_ir_visitor.cpp @@ -10,6 +10,7 @@ #include "shape_helper.h" #include "tflite_ir_visitor.h" +#include "tflite_op_creator.h" namespace nncc { @@ -31,6 +32,7 @@ IrVisitor::IrVisitor() // TODO: make this a smart pointer. Note that it requires changing the NNImporter interface, // because currently it returns a void*. graph = new Graph(); + opCreator.reset(new OpCreator(graph)); } void IrVisitor::visit(const Model *m) @@ -62,7 +64,47 @@ void IrVisitor::visit(const SubGraph *s) void IrVisitor::visit(const Operator *op) { - throw std::runtime_error{"Not yet implemented"}; + auto inputs = createOpInputs(op); + auto params = createOpParams(op); + + std::vector outputs; + + unsigned int opcode = (*opcodes)[op->opcode_index()]->builtin_code(); + // TODO: support other NN operator types + switch (opcode) + { + case BuiltinOperator_CONV_2D: + outputs = opCreator->createConv2D(inputs, params, op->builtin_options_as()); + break; + case BuiltinOperator_DEPTHWISE_CONV_2D: + outputs = opCreator->createDepthConv2D(inputs, params, + op->builtin_options_as()); + break; + case BuiltinOperator_MAX_POOL_2D: + outputs = opCreator->createMaxPool(inputs, params, op->builtin_options_as()); + break; + case BuiltinOperator_AVERAGE_POOL_2D: + outputs = opCreator->createAvgPool(inputs, params, op->builtin_options_as()); + break; + case BuiltinOperator_CONCATENATION: + outputs = opCreator->createConcat(inputs, params, op->builtin_options_as()); + break; + case BuiltinOperator_RESHAPE: + outputs = opCreator->createReshape(inputs, params, op->builtin_options_as()); + break; + case BuiltinOperator_SOFTMAX: + outputs = opCreator->createSoftmax(inputs, params, op->builtin_options_as()); + break; + default: + throw PluginException( + std::string("Encountered unsupported TFLite operator: ") + + EnumNamesBuiltinOperator()[opcode]); + } + + for (int i = 0; i < op->outputs()->size(); ++i) + { + opsForTensorsTheyOutput[(*(op->outputs()))[i]] = outputs[i]; + } } void IrVisitor::visit(const Tensor *) {} -- 2.7.4