From: 박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 Date: Thu, 25 Jul 2019 07:23:11 +0000 (+0900) Subject: [exo-tflite] Support ReLU6 (#5884) X-Git-Tag: submit/tizen/20190809.050447~399 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=887cdbeb63b0378dc828fa2b789ea9104df6319f;p=platform%2Fcore%2Fml%2Fnnfw.git [exo-tflite] Support ReLU6 (#5884) TFLExporter is now able to export ReLU6 node as T/F Lite RELU6 operation. Signed-off-by: Jonghyun Park --- diff --git a/compiler/exo-tflite/src/OperationExporter.cpp b/compiler/exo-tflite/src/OperationExporter.cpp index d88f1f5..22f00f9 100644 --- a/compiler/exo-tflite/src/OperationExporter.cpp +++ b/compiler/exo-tflite/src/OperationExporter.cpp @@ -38,6 +38,7 @@ public: public: void visit(loco::ReLU *) final; + void visit(loco::ReLU6 *) final; void visit(loco::Push *) final { /* DO NOTHING */} void visit(loco::Pull *) final { /* DO NOTHING */} void visit(loco::FeatureEncode *) final; @@ -68,6 +69,17 @@ void OperationExporter::visit(loco::ReLU *node) gd._operators.push_back(op_offset); } +void OperationExporter::visit(loco::ReLU6 *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_RELU6); + std::vector inputs_vec{get_tensor_index(node->input())}; + std::vector outputs_vec{get_tensor_index(static_cast(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs); + gd._operators.push_back(op_offset); +} + void OperationExporter::visit(loco::MaxPool2D *node) { uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_MAX_POOL_2D); diff --git a/compiler/exo-tflite/src/ShapeInference.cpp b/compiler/exo-tflite/src/ShapeInference.cpp index 968aa13..3f91531 100644 --- a/compiler/exo-tflite/src/ShapeInference.cpp +++ b/compiler/exo-tflite/src/ShapeInference.cpp @@ -81,6 +81,8 @@ public: NODE(FeatureBiasAdd) #undef NODE // TODO Put all the visit method implementations inside this class declaration + ShapeDescription visit(loco::ReLU6 *node) { return gd._node_to_shape[node->input()]; } + private: ShapeContext &gd; }; diff --git a/compiler/exo-tflite/src/TFLExporterImpl.test.cpp b/compiler/exo-tflite/src/TFLExporterImpl.test.cpp index 08fd90d..3d9b921 100644 --- a/compiler/exo-tflite/src/TFLExporterImpl.test.cpp +++ b/compiler/exo-tflite/src/TFLExporterImpl.test.cpp @@ -78,6 +78,39 @@ template <> loco::FeatureDecode *TFLExporterImplTests::make_node(void) } // namespace +TEST_F(TFLExporterImplTests, Relu6) +{ + auto pull = make_node(); + { + pull->dtype(loco::DataType::FLOAT32); + pull->shape({1, 8, 8, 3}); + } + auto relu6 = make_node(); + { + relu6->input(pull); + } + auto push = make_node(); + { + push->from(relu6); + } + + auto input = graph()->inputs()->create(); + { + input->name("input"); + input->node(pull); + } + auto output = graph()->outputs()->create(); + { + output->name("output"); + output->node(push); + } + + exo::TFLExporter::Impl exporter{graph()}; + + // TODO Add more checks + SUCCEED(); +} + /** * What happens when there is a mismatch between generation and execution order!? */