From 8894b497d514647baf51015aa8d03620c4b8e5b1 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Denis=20Maksimenko/AI=20Tools=20Lab=20/SRR/Assistant=20Engi?= =?utf8?q?neer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 15 Aug 2018 15:13:51 +0300 Subject: [PATCH] [nnc] TFlite import: Implement FullyConnected (#1012) TFlite importer can now handle FullyConnected op. Signed-off-by: Denis Maksimenko --- .../nnc/libs/frontend/tflite/include/tflite_op_creator.h | 2 ++ contrib/nnc/libs/frontend/tflite/src/tflite_ir_visitor.cpp | 7 +++++++ contrib/nnc/libs/frontend/tflite/src/tflite_op_creator.cpp | 13 +++++++++++++ 3 files changed, 22 insertions(+) diff --git a/contrib/nnc/libs/frontend/tflite/include/tflite_op_creator.h b/contrib/nnc/libs/frontend/tflite/include/tflite_op_creator.h index 20dd975..c671a27 100644 --- a/contrib/nnc/libs/frontend/tflite/include/tflite_op_creator.h +++ b/contrib/nnc/libs/frontend/tflite/include/tflite_op_creator.h @@ -56,6 +56,8 @@ public: const SoftmaxOptions *opts); std::vector createReshape(InputOps inputs, InputParams params, const ReshapeOptions *opts); + std::vector createFullyConnected(InputOps inputs, InputParams params, + const FullyConnectedOptions *opts); private: Graph *graph = nullptr; diff --git a/contrib/nnc/libs/frontend/tflite/src/tflite_ir_visitor.cpp b/contrib/nnc/libs/frontend/tflite/src/tflite_ir_visitor.cpp index f260352..5da0232 100644 --- a/contrib/nnc/libs/frontend/tflite/src/tflite_ir_visitor.cpp +++ b/contrib/nnc/libs/frontend/tflite/src/tflite_ir_visitor.cpp @@ -94,6 +94,9 @@ void IrVisitor::visit(const Operator *op) case BuiltinOperator_RESHAPE: outputs = opCreator->createReshape(inputs, params, op->builtin_options_as()); break; + case BuiltinOperator_FULLY_CONNECTED: + outputs = opCreator->createFullyConnected(inputs, params, op->builtin_options_as()); + break; case BuiltinOperator_SOFTMAX: outputs = opCreator->createSoftmax(inputs, params, op->builtin_options_as()); break; @@ -163,6 +166,10 @@ std::vector> IrVisitor::createOpParams(const Operator // don't forget to change this if tensor shape processing architecture changes. paramsForOp.emplace_back(transposeTensor<1, 2, 3, 0>(tensor)); } + else if (opcode == BuiltinOperator_FULLY_CONNECTED && t->shape()->size() == 2) + { + paramsForOp.emplace_back(transposeTensor<1, 0>(tensor)); + } else { paramsForOp.push_back(tensor); diff --git a/contrib/nnc/libs/frontend/tflite/src/tflite_op_creator.cpp b/contrib/nnc/libs/frontend/tflite/src/tflite_op_creator.cpp index 9e87d20..629aa9a 100644 --- a/contrib/nnc/libs/frontend/tflite/src/tflite_op_creator.cpp +++ b/contrib/nnc/libs/frontend/tflite/src/tflite_op_creator.cpp @@ -3,6 +3,7 @@ #include "nnc/core/IR/model/operations/concat_op.h" #include "nnc/core/IR/model/operations/conv_2d_op.h" #include "nnc/core/IR/model/operations/depthwise_conv2d_op.h" +#include "nnc/core/IR/model/operations/fully_connected_op.h" #include "nnc/core/IR/model/operations/relu_op.h" #include "nnc/core/IR/model/operations/capped_relu_op.h" #include "nnc/core/IR/model/operations/softmax_op.h" @@ -92,6 +93,18 @@ std::vector OpCreator::createReshape(InputOps inputs, InputParams pa return outputs; } +std::vector OpCreator::createFullyConnected(InputOps &inputs, InputParams ¶ms, + const FullyConnectedOptions *opts) +{ + // Add Reshape operation to make sure the input for FC operation has shape [1, fcInputSize] + auto outputs = createOp(inputs, ActivationFunctionType_NONE); + uint32_t fcInputSize = params[0]->getShape().dim(0); + outputs[0]->getOperation()->setOutputShape(0, {1, fcInputSize}); + + auto fcOutputs = createOp(outputs, ActivationFunctionType_NONE, std::move(*params[0])); + return createOp(fcOutputs, opts->fused_activation_function(), std::move(*params[1])); +} + INode::Ref OpCreator::addFusedActivation(INode::Ref input, ActivationFunctionType activationType) { INode::Ref activation; -- 2.7.4