Support Swish and Mish activations
authorLiubov Batanina <piccione-mail@yandex.ru>
Fri, 6 Dec 2019 08:27:59 +0000 (11:27 +0300)
committerLiubov Batanina <piccione-mail@yandex.ru>
Fri, 6 Dec 2019 08:27:59 +0000 (11:27 +0300)
modules/dnn/src/layers/elementwise_layers.cpp

index 8a0ddcd..3459734 100644 (file)
@@ -579,7 +579,7 @@ struct SwishFunctor
     bool supportBackend(int backendId, int)
     {
         return backendId == DNN_BACKEND_OPENCV ||
-               backendId == DNN_BACKEND_HALIDE;
+               backendId == DNN_BACKEND_HALIDE || backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH;;
     }
 
     void apply(const float* srcptr, float* dstptr, int len, size_t planeSize, int cn0, int cn1) const
@@ -640,7 +640,8 @@ struct SwishFunctor
 #ifdef HAVE_DNN_NGRAPH
     std::shared_ptr<ngraph::Node> initNgraphAPI(const std::shared_ptr<ngraph::Node>& node)
     {
-        CV_Error(Error::StsNotImplemented, "");
+        auto sigmoid = std::make_shared<ngraph::op::Sigmoid>(node);
+        return std::make_shared<ngraph::op::v1::Multiply>(node, sigmoid);
     }
 #endif  // HAVE_DNN_NGRAPH
 
@@ -659,7 +660,7 @@ struct MishFunctor
     bool supportBackend(int backendId, int)
     {
         return backendId == DNN_BACKEND_OPENCV ||
-               backendId == DNN_BACKEND_HALIDE;
+               backendId == DNN_BACKEND_HALIDE || backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH;
     }
 
     void apply(const float* srcptr, float* dstptr, int len, size_t planeSize, int cn0, int cn1) const
@@ -720,7 +721,13 @@ struct MishFunctor
 #ifdef HAVE_DNN_NGRAPH
     std::shared_ptr<ngraph::Node> initNgraphAPI(const std::shared_ptr<ngraph::Node>& node)
     {
-        CV_Error(Error::StsNotImplemented, "");
+        float one = 1.0f;
+        auto constant = std::make_shared<ngraph::op::Constant>(ngraph::element::f32, ngraph::Shape{1}, &one);
+        auto exp_node = std::make_shared<ngraph::op::v0::Exp>(node);
+        auto sum = std::make_shared<ngraph::op::v1::Add>(constant, exp_node, ngraph::op::AutoBroadcastType::NUMPY);
+        auto log_node = std::make_shared<ngraph::op::v0::Log>(sum);
+        auto tanh_node = std::make_shared<ngraph::op::Tanh>(log_node);
+        return std::make_shared<ngraph::op::v1::Multiply>(node, tanh_node);
     }
 #endif  // HAVE_DNN_NGRAPH