From 7a8892fe7c208cce76017acbabdd137c872ddadc Mon Sep 17 00:00:00 2001 From: Mohamed Nour Abouelseoud Date: Wed, 9 Jan 2019 14:19:58 +0000 Subject: [PATCH] IVGCVSW-2345 Add Rsqrt support in Tensorflow Parser Change-Id: I7c7b65bd77b06925efdaf2c9c98c30994a12de42 --- CMakeLists.txt | 1 + src/armnnTfParser/TfParser.cpp | 17 ++++++++++++ src/armnnTfParser/TfParser.hpp | 1 + src/armnnTfParser/test/Rsqrt.cpp | 59 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 78 insertions(+) create mode 100644 src/armnnTfParser/test/Rsqrt.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 9fedc25..407a51d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -417,6 +417,7 @@ if(BUILD_UNIT_TESTS) src/armnnTfParser/test/RealDiv.cpp src/armnnTfParser/test/Reshape.cpp src/armnnTfParser/test/ResizeBilinear.cpp + src/armnnTfParser/test/Rsqrt.cpp src/armnnTfParser/test/Shape.cpp src/armnnTfParser/test/Softmax.cpp src/armnnTfParser/test/TestDependencies.cpp diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index 3d0c72d..90bd992 100755 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -345,6 +345,7 @@ const std::map TfParser::ms_Ope { "Relu6", &TfParser::ParseRelu6 }, { "Reshape", &TfParser::ParseReshape }, { "ResizeBilinear", &TfParser::ParseResizeBilinear }, + { "Rsqrt", &TfParser::ParseRsqrt }, { "Shape", &TfParser::ParseShape }, { "Squeeze", &TfParser::ParseSqueeze }, { "Sigmoid", &TfParser::ParseSigmoid }, @@ -2445,6 +2446,22 @@ ParsedTfOperationPtr TfParser::ParseSigmoid(const tensorflow::NodeDef& nodeDef, return AddActivationLayer(nodeDef, activationDesc); } +ParsedTfOperationPtr TfParser::ParseRsqrt(const tensorflow::NodeDef &nodeDef, + const tensorflow::GraphDef &graphDef) +{ + boost::ignore_unused(graphDef); + + std::vector inputs = GetInputParsedTfOperationsChecked(nodeDef, 1); + + IConnectableLayer* const layer = m_Network->AddRsqrtLayer(nodeDef.name().c_str()); + + IOutputSlot& prevLayerOutputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index); + prevLayerOutputSlot.Connect(layer->GetInputSlot(0)); + layer->GetOutputSlot(0).SetTensorInfo(prevLayerOutputSlot.GetTensorInfo()); + + return std::make_unique(this, nodeDef, layer); +} + ParsedTfOperationPtr TfParser::ParseSoftmax(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef) { diff --git a/src/armnnTfParser/TfParser.hpp b/src/armnnTfParser/TfParser.hpp index b8fab41..4421768 100644 --- a/src/armnnTfParser/TfParser.hpp +++ b/src/armnnTfParser/TfParser.hpp @@ -147,6 +147,7 @@ private: ParsedTfOperationPtr ParseRelu6(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); ParsedTfOperationPtr ParseReshape(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); ParsedTfOperationPtr ParseResizeBilinear(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseRsqrt(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); ParsedTfOperationPtr ParseShape(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); ParsedTfOperationPtr ParseSqueeze(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); ParsedTfOperationPtr ParseSigmoid(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); diff --git a/src/armnnTfParser/test/Rsqrt.cpp b/src/armnnTfParser/test/Rsqrt.cpp new file mode 100644 index 0000000..6924c06 --- /dev/null +++ b/src/armnnTfParser/test/Rsqrt.cpp @@ -0,0 +1,59 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include +#include "armnnTfParser/ITfParser.hpp" +#include "ParserPrototxtFixture.hpp" + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + +struct RsqrtFixture : public armnnUtils::ParserPrototxtFixture +{ + RsqrtFixture() + { + m_Prototext = "node {\n" + " name: \"input\"\n" + " op: \"Placeholder\"\n" + " attr {\n" + " key: \"dtype\"\n" + " value {\n" + " type: DT_FLOAT\n" + " }\n" + " }\n" + " attr {\n" + " key: \"shape\"\n" + " value {\n" + " shape {\n" + " }\n" + " }\n" + " }\n" + "}\n" + "node {\n" + " name: \"Rsqrt\"\n" + " op: \"Rsqrt\"\n" + " input: \"input\"\n" + " attr {\n" + " key: \"T\"\n" + " value {\n" + " type: DT_FLOAT\n" + " }\n" + " }\n" + "}\n"; + + SetupSingleInputSingleOutput({ 2, 2 }, "input", "Rsqrt"); + } +}; + +BOOST_FIXTURE_TEST_CASE(ParseRsqrt, RsqrtFixture) +{ + RunTest<2>({ 1.f, 4.f, 16.f, 25.f }, { 1.f, 0.5f, 0.25f, 0.2f }); +} + +BOOST_FIXTURE_TEST_CASE(ParseRsqrtZeroNegative, RsqrtFixture) +{ + RunTest<2>({ 0.f, -0.f, -25.f, -16.f }, { INFINITY, -INFINITY, -NAN, -NAN }); +} + +BOOST_AUTO_TEST_SUITE_END() \ No newline at end of file -- 2.7.4