Github#433 Add HardSwish support to TfLiteParser
authorJan Eilers <jan.eilers@arm.com>
Tue, 28 Jul 2020 13:00:06 +0000 (14:00 +0100)
committerJan Eilers <jan.eilers@arm.com>
Tue, 28 Jul 2020 15:43:36 +0000 (16:43 +0100)
Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Change-Id: Ic476f8d80bba080ab459db9e6a59cbafd307d129

src/armnnTfLiteParser/TfLiteParser.cpp
src/armnnTfLiteParser/TfLiteParser.hpp
src/armnnTfLiteParser/test/Activations.cpp

index 6943013..1a44493 100644 (file)
@@ -538,6 +538,7 @@ TfLiteParser::TfLiteParser(const Optional<ITfLiteParser::TfLiteParserOptions>& o
     m_ParserFunctions[tflite::BuiltinOperator_DEQUANTIZE]              = &TfLiteParser::ParseDequantize;
     m_ParserFunctions[tflite::BuiltinOperator_EXP]                     = &TfLiteParser::ParseExp;
     m_ParserFunctions[tflite::BuiltinOperator_FULLY_CONNECTED]         = &TfLiteParser::ParseFullyConnected;
+    m_ParserFunctions[tflite::BuiltinOperator_HARD_SWISH]              = &TfLiteParser::ParseHardSwish;
     m_ParserFunctions[tflite::BuiltinOperator_LEAKY_RELU]              = &TfLiteParser::ParseLeakyRelu;
     m_ParserFunctions[tflite::BuiltinOperator_LOGISTIC]                = &TfLiteParser::ParseLogistic;
     m_ParserFunctions[tflite::BuiltinOperator_L2_NORMALIZATION]        = &TfLiteParser::ParseL2Normalization;
@@ -1992,7 +1993,7 @@ void TfLiteParser::ParseRelu6(size_t subgraphIndex, size_t operatorIndex)
 
 void TfLiteParser::ParseLeakyRelu(size_t subgraphIndex, size_t operatorIndex)
 {
-    ParseActivation(subgraphIndex,operatorIndex, ActivationFunction::LeakyReLu);
+    ParseActivation(subgraphIndex, operatorIndex, ActivationFunction::LeakyReLu);
 }
 
 void TfLiteParser::ParseLogistic(size_t subgraphIndex, size_t operatorIndex)
@@ -2005,6 +2006,10 @@ void TfLiteParser::ParseTanH(size_t subgraphIndex, size_t operatorIndex)
     ParseActivation(subgraphIndex,operatorIndex,ActivationFunction::TanH);
 }
 
+void TfLiteParser::ParseHardSwish(size_t subgraphIndex, size_t operatorIndex)
+{
+    ParseActivation(subgraphIndex, operatorIndex, ActivationFunction::HardSwish);
+}
 
 void TfLiteParser::ParseActivation(size_t subgraphIndex, size_t operatorIndex, ActivationFunction activationType)
 {
@@ -2055,6 +2060,9 @@ void TfLiteParser::ParseActivation(size_t subgraphIndex, size_t operatorIndex, A
             activationDesc.m_A = options->alpha;
             break;
         }
+        case ActivationFunction::HardSwish:
+            layerName += str(boost::format("HARDSWISH:%1%:%2%") % subgraphIndex % operatorIndex);
+            break;
         default:
         {
             throw ParseException(
index c252b0f..6a61150 100644 (file)
@@ -104,6 +104,7 @@ private:
     void ParseDetectionPostProcess(size_t subgraphIndex, size_t operatorIndex);
     void ParseExp(size_t subgraphIndex, size_t operatorIndex);
     void ParseFullyConnected(size_t subgraphIndex, size_t operatorIndex);
+    void ParseHardSwish(size_t subgraphIndex, size_t operatorIndex);
     void ParseLeakyRelu(size_t subgraphIndex, size_t operatorIndex);
     void ParseLogistic(size_t subgraphIndex, size_t operatorIndex);
     void ParseL2Normalization(size_t subgraphIndex, size_t operatorIndex);
index e8153a2..e57477e 100644 (file)
@@ -105,4 +105,16 @@ BOOST_FIXTURE_TEST_CASE(ParseTanH, TanHFixture)
         { -0.1f,       -0.2f,         -0.3f,       -0.4f,    0.1f,         0.2f,              0.3f },
         { -0.09966799f, -0.19737528f, -0.29131261f, -0.379949f, 0.09966799f, 0.19737528f, 0.29131261f });
 }
+
+struct HardSwishFixture : ActivationFixture
+{
+    HardSwishFixture() : ActivationFixture("HARD_SWISH", "FLOAT32") {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseHardSwish, HardSwishFixture)
+{
+    RunTest<2, armnn::DataType::Float32>(0,
+                                         { -4.0f, -3.0f,        -2.9f,  1.2f,        2.2f, 3.0f, 4.0f },
+                                         { -0.0f, -0.0f, -0.04833334f, 0.84f, 1.90666667f, 3.0f, 4.0f });
+}
 BOOST_AUTO_TEST_SUITE_END()