Make onnx parser to support TanH / Sigmoid / LeakyRelu layers
authorTee Jung <tee.ty.jung@openedges.com>
Fri, 1 Nov 2019 07:04:42 +0000 (07:04 +0000)
committerMatteo Martincigh <matteo.martincigh@arm.com>
Mon, 4 Nov 2019 09:12:46 +0000 (09:12 +0000)
Signed-off-by: Jung Tae-young tee.ty.jung@openedges.com
Change-Id: I44d24b525b78b8d3fee0197abda7bd667eb04d83

src/armnnOnnxParser/OnnxParser.cpp
src/armnnOnnxParser/OnnxParser.hpp

index 9d374ae..0d0cc25 100644 (file)
@@ -337,7 +337,10 @@ const std::map<std::string, OnnxParser::OperationParsingFunction> OnnxParser::m_
     { "Constant",              &OnnxParser::ParseConstant },
     { "MaxPool",               &OnnxParser::ParseMaxPool },
     { "Reshape",               &OnnxParser::ParseReshape },
+    { "Sigmoid",               &OnnxParser::ParseSigmoid },
+    { "Tanh",                  &OnnxParser::ParseTanh },
     { "Relu",                  &OnnxParser::ParseRelu },
+    { "LeakyRelu",             &OnnxParser::ParseLeakyRelu },
     { "Conv",                  &OnnxParser::ParseConv },
     { "Add",                   &OnnxParser::ParseAdd },
 };
@@ -1083,7 +1086,7 @@ void OnnxParser::ParseReshape(const onnx::NodeProto& node)
     }
 }
 
-void OnnxParser::ParseRelu(const onnx::NodeProto& node)
+void OnnxParser::ParseActivation(const onnx::NodeProto& node, const armnn::ActivationFunction func)
 {
     CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1);
     CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
@@ -1091,7 +1094,7 @@ void OnnxParser::ParseRelu(const onnx::NodeProto& node)
     VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));
 
     ActivationDescriptor desc;
-    desc.m_Function = ActivationFunction::ReLu;
+    desc.m_Function = func;
 
     IConnectableLayer* const layer = m_Network->AddActivationLayer(desc, node.name().c_str());
     BOOST_ASSERT(layer != nullptr);
@@ -1107,6 +1110,25 @@ void OnnxParser::ParseRelu(const onnx::NodeProto& node)
     RegisterOutputSlots(layer, {node.output(0)});
 }
 
+void OnnxParser::ParseSigmoid(const onnx::NodeProto& node)
+{
+    ParseActivation(node, ActivationFunction::Sigmoid);
+}
+
+void OnnxParser::ParseTanh(const onnx::NodeProto& node)
+{
+    ParseActivation(node, ActivationFunction::TanH);
+}
+
+void OnnxParser::ParseRelu(const onnx::NodeProto& node)
+{
+    ParseActivation(node, ActivationFunction::ReLu);
+}
+
+void OnnxParser::ParseLeakyRelu(const onnx::NodeProto& node)
+{
+    ParseActivation(node, ActivationFunction::LeakyReLu);
+}
 
 void OnnxParser::AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node, const Convolution2dDescriptor& convDesc)
 {
index 91927c2..a467180 100644 (file)
@@ -14,6 +14,7 @@
 namespace armnn
 {
 class TensorInfo;
+enum class ActivationFunction;
 }
 
 namespace armnnOnnxParser
@@ -103,7 +104,12 @@ private:
     void AddPoolingLayer(const onnx::NodeProto& nodeProto, armnn::Pooling2dDescriptor& desc);
 
     void ParseReshape(const onnx::NodeProto& nodeProto);
+
+    void ParseActivation(const onnx::NodeProto& nodeProto, const armnn::ActivationFunction func);
+    void ParseSigmoid(const onnx::NodeProto& nodeProto);
+    void ParseTanh(const onnx::NodeProto& nodeProto);
     void ParseRelu(const onnx::NodeProto& nodeProto);
+    void ParseLeakyRelu(const onnx::NodeProto& nodeProto);
 
     void AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node, const armnn::Convolution2dDescriptor& convDesc);
     void ParseConv(const onnx::NodeProto& nodeProto);