Enable HSwish - ngraph op, fusion/decomposition and reference implementation (#1770)
authorKatarzyna Mitrus <katarzyna.mitrus@intel.com>
Wed, 19 Aug 2020 05:04:00 +0000 (07:04 +0200)
committerGitHub <noreply@github.com>
Wed, 19 Aug 2020 05:04:00 +0000 (08:04 +0300)
* Add HSwish operator to nGraph

* Add HSwishFusion transformation

* Update check_constant function

* Add reference implementation for HSwish

* Enable reference implementation in HSwish evaluate

* Add op_eval test

* HSwish fusion transformation test

* Add HSwishFusionWithoutRelu transformation

* Add more hswish fusion tests

* Register HSwishFusion pass in common_optimizations

* Update HSwish reference implementation

* Add HSwishFusion with Relu and Multiply

* Add HSwishDecomposition transformation pass

* Add HSwishDecomposition test

* Add HSwish op to ngraph python API

* Update HSwish fusion transformations

* Remove HSwishFusion from common optimizations

* Update hswish python API

* Add bf16 to evaluate hswish

* Update hswish python API

* Move hswish reference implementation

* UnaryElementwiseArithmetic inheritance

* Enable HSwish callback for clDNN

* Register HSwishDecomposition pass in ConvertOpSet1ToLegacy

* Enable HSwishFusion pass in common optimizations

* Use NGRAPH_RTTI_DECLARATION

* Moved python hswish test to the test_ops_unary

21 files changed:
inference-engine/src/cldnn_engine/cldnn_engine.cpp
inference-engine/src/transformations/include/transformations/hswish_decomposition.hpp [new file with mode: 0644]
inference-engine/src/transformations/include/transformations/hswish_fusion.hpp [new file with mode: 0644]
inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.cpp
inference-engine/src/transformations/src/transformations/hswish_decomposition.cpp [new file with mode: 0644]
inference-engine/src/transformations/src/transformations/hswish_fusion.cpp [new file with mode: 0644]
inference-engine/tests/functional/inference_engine/transformations/hswish_decomposition_test.cpp [new file with mode: 0644]
inference-engine/tests/functional/inference_engine/transformations/hswish_fusion_test.cpp [new file with mode: 0644]
ngraph/core/include/ngraph/op/hswish.hpp [new file with mode: 0644]
ngraph/core/include/ngraph/ops.hpp
ngraph/core/include/ngraph/opsets/opset4_tbl.hpp
ngraph/core/reference/include/ngraph/runtime/reference/hswish.hpp [new file with mode: 0644]
ngraph/core/src/op/hswish.cpp [new file with mode: 0644]
ngraph/python/src/ngraph/__init__.py
ngraph/python/src/ngraph/opset4/__init__.py
ngraph/python/src/ngraph/opset4/ops.py
ngraph/python/tests/test_ngraph/test_ops_unary.py
ngraph/test/CMakeLists.txt
ngraph/test/op_eval/hswish.cpp [new file with mode: 0644]
ngraph/test/type_prop/hswish.cpp [new file with mode: 0644]

index df81cba..1d3dcf5 100644 (file)
@@ -91,6 +91,7 @@ InferenceEngine::ICNNNetwork::Ptr clDNNEngine::CloneAndTransformNetwork(const In
                    std::dynamic_pointer_cast<const ::ngraph::opset3::ShuffleChannels>(node) ||
                    std::dynamic_pointer_cast<const ::ngraph::opset2::BatchToSpace>(node) ||
                    std::dynamic_pointer_cast<const ::ngraph::opset2::SpaceToBatch>(node) ||
+                   std::dynamic_pointer_cast<const ::ngraph::opset4::HSwish>(node) ||
                    std::dynamic_pointer_cast<const ::ngraph::opset4::ReduceL1>(node) ||
                    std::dynamic_pointer_cast<const ::ngraph::opset4::ReduceL2>(node);
         };
diff --git a/inference-engine/src/transformations/include/transformations/hswish_decomposition.hpp b/inference-engine/src/transformations/include/transformations/hswish_decomposition.hpp
new file mode 100644 (file)
index 0000000..aab0799
--- /dev/null
@@ -0,0 +1,25 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <transformations_visibility.hpp>
+#include <ngraph/pass/graph_rewrite.hpp>
+
+namespace ngraph {
+namespace pass {
+
+class TRANSFORMATIONS_API HSwishDecomposition;
+
+}  // namespace pass
+}  // namespace ngraph
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief HSwishDecomposition transformation decomposes the HSwish op into the sub-graph x * min(Relu(x + 3), 6) * const(1/6).
+ */
+class ngraph::pass::HSwishDecomposition: public ngraph::pass::MatcherPass {
+public:
+    HSwishDecomposition();
+};
diff --git a/inference-engine/src/transformations/include/transformations/hswish_fusion.hpp b/inference-engine/src/transformations/include/transformations/hswish_fusion.hpp
new file mode 100644 (file)
index 0000000..5debcab
--- /dev/null
@@ -0,0 +1,63 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <memory>
+#include <utility>
+
+#include <transformations_visibility.hpp>
+#include <ngraph/pass/graph_rewrite.hpp>
+
+namespace ngraph {
+namespace pass {
+
+class TRANSFORMATIONS_API HSwishFusion;
+class TRANSFORMATIONS_API HSwishFusionWithReluDiv;
+class TRANSFORMATIONS_API HSwishFusionWithReluMul;
+class TRANSFORMATIONS_API HSwishFusionWithoutRelu;
+
+
+}  // namespace pass
+}  // namespace ngraph
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief HSwishFusion transformation replaces various sub-graphs with a HSwish op.
+ */
+class ngraph::pass::HSwishFusion: public ngraph::pass::GraphRewrite {
+public:
+    HSwishFusion() {
+        add_matcher<ngraph::pass::HSwishFusionWithReluDiv>();
+        add_matcher<ngraph::pass::HSwishFusionWithReluMul>();
+        add_matcher<ngraph::pass::HSwishFusionWithoutRelu>();
+    }
+};
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief HSwishFusion transformation replaces a sub-graph (x * (min(Relu(x + 3), 6))) / 6 with a HSwish op.
+ */
+ class ngraph::pass::HSwishFusionWithReluDiv: public ngraph::pass::MatcherPass {
+public:
+    HSwishFusionWithReluDiv();
+};
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief HSwishFusion transformation replaces a sub-graph x * min(Relu(x + 3), 6) * const(1/6) with a HSwish op.
+ */
+ class ngraph::pass::HSwishFusionWithReluMul: public ngraph::pass::MatcherPass {
+public:
+    HSwishFusionWithReluMul();
+};
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief HSwishFusion transformation replaces a sub-graph x * (min(max(x + 3, 0), 6) / 6) with a HSwish op.
+ */
+ class ngraph::pass::HSwishFusionWithoutRelu: public ngraph::pass::MatcherPass {
+public:
+    HSwishFusionWithoutRelu();
+};
index 5fc53d7..dbb283b 100644 (file)
@@ -14,6 +14,7 @@
 #include "transformations/itt.hpp"
 #include "transformations/mish_fusion.hpp"
 #include "transformations/swish_fusion.hpp"
+#include "transformations/hswish_fusion.hpp"
 
 #include <ngraph/pass/manager.hpp>
 #include <ngraph/pass/nop_elimination.hpp>
@@ -37,6 +38,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
     manager.register_pass<ngraph::pass::DepthToSpaceFusion>();
     manager.register_pass<ngraph::pass::MishFusion>();
     manager.register_pass<ngraph::pass::SwishFusion>();
+    manager.register_pass<ngraph::pass::HSwishFusion>();
     manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution>();
 
     manager.set_callback(m_transformation_callback);
index e390b4a..1d6c5dd 100644 (file)
@@ -47,6 +47,7 @@
 #include <transformations/convert_opset1_to_legacy/convert_hard_sigmoid_to_hard_sigmoid_ie.hpp>
 #include <transformations/lin_op_sequence_fusoin.hpp>
 #include <transformations/common_optimizations/conv_mul_fusion.hpp>
+#include <transformations/hswish_decomposition.hpp>
 #include <transformations/reduce_l1_decomposition.hpp>
 #include <transformations/reduce_l2_decomposition.hpp>
 
@@ -71,6 +72,10 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
     manager.register_pass<ngraph::pass::ReduceL1Decomposition>();
     manager.register_pass<ngraph::pass::ReduceL2Decomposition>();
 
+    // HSwishDecomposition produces Minimum, Relu and Multiply operations,
+    // so it must be executed before the decomposition/conversion passes below
+    manager.register_pass<ngraph::pass::HSwishDecomposition>();
+
     // List if Decomposition and Conversion transformations that can be
     // applied simultaneously in a single graph traversal
     auto decomp = manager.register_pass<ngraph::pass::GraphRewrite>();
diff --git a/inference-engine/src/transformations/src/transformations/hswish_decomposition.cpp b/inference-engine/src/transformations/src/transformations/hswish_decomposition.cpp
new file mode 100644 (file)
index 0000000..e1200ea
--- /dev/null
@@ -0,0 +1,44 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/hswish_decomposition.hpp"
+
+#include <memory>
+
+#include <ngraph/opsets/opset4.hpp>
+#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
+
+ngraph::pass::HSwishDecomposition::HSwishDecomposition() {
+    // Decomposes HSwish(x) op into sub-graph x * min(Relu(x + 3), 6) * const(1/6)
+    auto hswish = ngraph::pattern::wrap_type<opset4::HSwish>();
+
+    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
+        auto &pattern_to_output = m.get_pattern_value_map();
+        auto hswish_node = pattern_to_output.at(hswish).get_node_shared_ptr();
+
+        if (m_transformation_callback(hswish_node)) {
+            return false;
+        }
+
+        auto input_type = hswish_node->input_value(0).get_element_type();
+        auto add_constant = ngraph::opset4::Constant::create(input_type, ngraph::Shape{}, {3.0});
+        auto add = std::make_shared<ngraph::opset4::Add>(hswish_node->input_value(0), add_constant);
+        auto relu = std::make_shared<ngraph::opset4::Relu>(add);
+        auto min_constant = ngraph::opset4::Constant::create(input_type, ngraph::Shape{}, {6.0});
+        auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
+        auto mul_first = std::make_shared<ngraph::opset4::Multiply>(hswish_node->input_value(0), min);
+        auto mul_constant = ngraph::opset4::Constant::create(input_type, ngraph::Shape{}, {(1.0/6.0)});  // const(1/6)
+        auto mul_second = std::make_shared<ngraph::opset4::Multiply>(mul_first, mul_constant);
+
+        mul_second->set_friendly_name(m.get_match_root()->get_friendly_name());
+        ngraph::copy_runtime_info(hswish_node,
+                                  {add_constant, add, relu, min_constant, min, mul_first, mul_constant, mul_second});
+        ngraph::replace_node(m.get_match_root(), mul_second);
+        return true;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(hswish, "HSwishDecomposition");
+    register_matcher(m, callback);
+}
diff --git a/inference-engine/src/transformations/src/transformations/hswish_fusion.cpp b/inference-engine/src/transformations/src/transformations/hswish_fusion.cpp
new file mode 100644 (file)
index 0000000..60af25e
--- /dev/null
@@ -0,0 +1,180 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/hswish_fusion.hpp"
+
+#include <memory>
+
+#include <ngraph/opsets/opset4.hpp>
+#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
+
+bool check_constant_value(const std::shared_ptr<ngraph::opset4::Constant>& constant,
+                          const float value,
+                          float epsilon = std::numeric_limits<float>::epsilon()) {
+    if (!constant) {
+        return false;
+    }
+    if (constant->get_element_type() == ngraph::element::f32 || constant->get_element_type() == ngraph::element::f16) {
+        auto data = constant->cast_vector<float>();
+        if (data.size() != 1 || std::fabs(data[0] - value) > epsilon) {
+            return false;
+        }
+    } else {
+        return false;
+    }
+    return true;
+}
+
+ngraph::pass::HSwishFusionWithReluDiv::HSwishFusionWithReluDiv() {
+    // Replaces a sub-graph (x * min(Relu(x + 3), 6)) / 6 with a HSwish op.
+    auto input = ngraph::pattern::any_input();
+    auto add_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
+    auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
+    auto relu = std::make_shared<ngraph::opset4::Relu>(add);
+    auto min_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
+    auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
+    auto mul = std::make_shared<ngraph::opset4::Multiply>(input, min);
+    auto div_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
+    auto div = std::make_shared<ngraph::opset4::Divide>(mul, div_constant);
+
+    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
+        auto &pattern_to_output = m.get_pattern_value_map();
+        auto x_output = pattern_to_output.at(input);
+
+        auto add_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
+        auto min_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(min_constant).get_node_shared_ptr());
+        auto div_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(div_constant).get_node_shared_ptr());
+
+        bool valid_constant_values = check_constant_value(add_const_value, 3.0)
+                                        && check_constant_value(min_const_value, 6.0)
+                                        && check_constant_value(div_const_value, 6.0);
+
+        if (!valid_constant_values) {
+            return false;
+        }
+
+        auto hswish = std::make_shared<ngraph::opset4::HSwish>(x_output);
+
+        hswish->set_friendly_name(m.get_match_root()->get_friendly_name());
+        ngraph::copy_runtime_info({ pattern_to_output.at(add_constant).get_node_shared_ptr(),
+                                    pattern_to_output.at(add).get_node_shared_ptr(),
+                                    pattern_to_output.at(relu).get_node_shared_ptr(),
+                                    pattern_to_output.at(min_constant).get_node_shared_ptr(),
+                                    pattern_to_output.at(min).get_node_shared_ptr(),
+                                    pattern_to_output.at(mul).get_node_shared_ptr(),
+                                    pattern_to_output.at(div_constant).get_node_shared_ptr(),
+                                    pattern_to_output.at(div).get_node_shared_ptr(),
+                                   },
+                                  hswish);
+        ngraph::replace_node(m.get_match_root(), hswish);
+        return true;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(div, "HSwishWithReluDivFusion");
+    register_matcher(m, callback);
+}
+
+ngraph::pass::HSwishFusionWithReluMul::HSwishFusionWithReluMul() {
+    // Replaces a sub-graph x * min(Relu(x + 3), 6) * const(1/6) with a HSwish op.
+    auto input = ngraph::pattern::any_input();
+    auto add_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
+    auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
+    auto relu = std::make_shared<ngraph::opset4::Relu>(add);
+    auto min_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
+    auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
+    auto mul_first = std::make_shared<ngraph::opset4::Multiply>(input, min);
+    auto mul_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
+    auto mul_second = std::make_shared<ngraph::opset4::Multiply>(mul_first, mul_constant);
+
+    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
+        auto &pattern_to_output = m.get_pattern_value_map();
+        auto x_output = pattern_to_output.at(input);
+
+        auto add_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
+        auto min_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(min_constant).get_node_shared_ptr());
+        auto mul_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(mul_constant).get_node_shared_ptr());
+
+        bool valid_constant_values = check_constant_value(add_const_value, 3.0)
+                                        && check_constant_value(min_const_value, 6.0)
+                                        && check_constant_value(mul_const_value, (1.0/6.0), 0.0001);
+
+        if (!valid_constant_values) {
+            return false;
+        }
+
+        auto hswish = std::make_shared<ngraph::opset4::HSwish>(x_output);
+
+        hswish->set_friendly_name(m.get_match_root()->get_friendly_name());
+        ngraph::copy_runtime_info({ pattern_to_output.at(add_constant).get_node_shared_ptr(),
+                                    pattern_to_output.at(add).get_node_shared_ptr(),
+                                    pattern_to_output.at(relu).get_node_shared_ptr(),
+                                    pattern_to_output.at(min_constant).get_node_shared_ptr(),
+                                    pattern_to_output.at(min).get_node_shared_ptr(),
+                                    pattern_to_output.at(mul_first).get_node_shared_ptr(),
+                                    pattern_to_output.at(mul_constant).get_node_shared_ptr(),
+                                    pattern_to_output.at(mul_second).get_node_shared_ptr()
+                                   },
+                                  hswish);
+        ngraph::replace_node(m.get_match_root(), hswish);
+        return true;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(mul_second, "HSwishWithReluMulFusion");
+    register_matcher(m, callback);
+}
+
+
+ngraph::pass::HSwishFusionWithoutRelu::HSwishFusionWithoutRelu() {
+    // Replaces a sub-graph x * (min(max(x + 3, 0), 6) / 6) with a HSwish op.
+    auto input = ngraph::pattern::any_input();
+    auto add_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
+    auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
+    auto max_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
+    auto max = std::make_shared<ngraph::opset4::Maximum>(add, max_constant);
+    auto min_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
+    auto min = std::make_shared<ngraph::opset4::Minimum>(max, min_constant);
+    auto div_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
+    auto div = std::make_shared<ngraph::opset4::Divide>(min, div_constant);
+    auto mul = std::make_shared<ngraph::opset4::Multiply>(input, div);
+
+    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
+        auto &pattern_to_output = m.get_pattern_value_map();
+        auto x_output = pattern_to_output.at(input);
+
+        auto add_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
+        auto max_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(max_constant).get_node_shared_ptr());
+        auto min_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(min_constant).get_node_shared_ptr());
+        auto div_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(div_constant).get_node_shared_ptr());
+
+        bool valid_constant_values = check_constant_value(add_const_value, 3.0)
+                                        && check_constant_value(max_const_value, 0.0)
+                                        && check_constant_value(min_const_value, 6.0)
+                                        && check_constant_value(div_const_value, 6.0);
+
+        if (!valid_constant_values) {
+            return false;
+        }
+
+        auto hswish = std::make_shared<ngraph::opset4::HSwish>(x_output);
+
+        hswish->set_friendly_name(m.get_match_root()->get_friendly_name());
+        ngraph::copy_runtime_info({ pattern_to_output.at(add_constant).get_node_shared_ptr(),
+                                    pattern_to_output.at(add).get_node_shared_ptr(),
+                                    pattern_to_output.at(max_constant).get_node_shared_ptr(),
+                                    pattern_to_output.at(max).get_node_shared_ptr(),
+                                    pattern_to_output.at(min_constant).get_node_shared_ptr(),
+                                    pattern_to_output.at(min).get_node_shared_ptr(),
+                                    pattern_to_output.at(div_constant).get_node_shared_ptr(),
+                                    pattern_to_output.at(div).get_node_shared_ptr(),
+                                    pattern_to_output.at(mul).get_node_shared_ptr()
+                                   },
+                                  hswish);
+        ngraph::replace_node(m.get_match_root(), hswish);
+        return true;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "HSwishWithoutReluFusion");
+    register_matcher(m, callback);
+}
diff --git a/inference-engine/tests/functional/inference_engine/transformations/hswish_decomposition_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/hswish_decomposition_test.cpp
new file mode 100644 (file)
index 0000000..7356a79
--- /dev/null
@@ -0,0 +1,52 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include <string>
+#include <memory>
+
+#include <ngraph/function.hpp>
+#include <ngraph/opsets/opset4.hpp>
+#include <ngraph/pass/manager.hpp>
+#include <transformations/hswish_decomposition.hpp>
+#include <transformations/init_node_info.hpp>
+#include <transformations/utils/utils.hpp>
+
+#include "common_test_utils/ngraph_test_utils.hpp"
+
+using namespace testing;
+
+TEST(TransformationTests, HSwishDecompositionTest) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(1));
+        auto hswish = std::make_shared<ngraph::opset4::HSwish>(input);
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{hswish}, ngraph::ParameterVector{input});
+
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::HSwishDecomposition>();
+        manager.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
+        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
+        auto relu = std::make_shared<ngraph::opset4::Relu>(add);
+        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
+        auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
+        auto mul_first = std::make_shared<ngraph::opset4::Multiply>(input, min);
+        auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.1666666716});
+        auto mul_second = std::make_shared<ngraph::opset4::Multiply>(mul_first, mul_constant);
+
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul_second}, ngraph::ParameterVector{input});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
diff --git a/inference-engine/tests/functional/inference_engine/transformations/hswish_fusion_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/hswish_fusion_test.cpp
new file mode 100644 (file)
index 0000000..abe5053
--- /dev/null
@@ -0,0 +1,274 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include <string>
+#include <memory>
+
+#include <ngraph/function.hpp>
+#include <ngraph/opsets/opset4.hpp>
+#include <ngraph/pass/manager.hpp>
+#include <transformations/hswish_fusion.hpp>
+#include <transformations/init_node_info.hpp>
+#include <transformations/utils/utils.hpp>
+
+#include "common_test_utils/ngraph_test_utils.hpp"
+
+using namespace testing;
+
+TEST(TransformationTests, HSwishFusionWithReluDivF16) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
+        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
+        auto relu = std::make_shared<ngraph::opset4::Relu>(add);
+        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
+        auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
+        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, min);
+        auto div_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
+        auto div = std::make_shared<ngraph::opset4::Divide>(mul, div_constant);
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});
+
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::HSwishFusionWithReluDiv>();
+        manager.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+        auto hswish = std::make_shared<ngraph::opset4::HSwish>(input);
+
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{hswish}, ngraph::ParameterVector{input});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, HSwishFusionWithReluDivF32) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{});
+        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{}, {3.0});
+        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
+        auto relu = std::make_shared<ngraph::opset4::Relu>(add);
+        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{}, {6.0});
+        auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
+        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, min);
+        auto div_constant = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{}, {6.0});
+        auto div = std::make_shared<ngraph::opset4::Divide>(mul, div_constant);
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});
+
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::HSwishFusionWithReluDiv>();
+        manager.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{});
+        auto hswish = std::make_shared<ngraph::opset4::HSwish>(input);
+
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{hswish}, ngraph::ParameterVector{input});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, HSwishFusionWithReluMul) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
+        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
+        auto relu = std::make_shared<ngraph::opset4::Relu>(add);
+        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
+        auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
+        auto mul_first = std::make_shared<ngraph::opset4::Multiply>(input, min);
+        auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.1666666716});
+        auto mul_second = std::make_shared<ngraph::opset4::Multiply>(mul_first, mul_constant);
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul_second}, ngraph::ParameterVector{input});
+
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::HSwishFusionWithReluMul>();
+        manager.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+        auto hswish = std::make_shared<ngraph::opset4::HSwish>(input);
+
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{hswish}, ngraph::ParameterVector{input});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, HSwishFusionWithoutRelu) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
+        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
+        auto max_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.0});
+        auto max = std::make_shared<ngraph::opset4::Maximum>(add, max_constant);
+        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
+        auto min = std::make_shared<ngraph::opset4::Minimum>(max, min_constant);
+        auto div_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
+        auto div = std::make_shared<ngraph::opset4::Divide>(min, div_constant);
+        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, div);
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});
+
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::HSwishFusionWithoutRelu>();
+        manager.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+        auto hswish = std::make_shared<ngraph::opset4::HSwish>(input);
+
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{hswish}, ngraph::ParameterVector{input});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, HSwishFusionWithReluMulWrongConstValue) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {  // graph with a multiply constant that is not 1/6 -- fusion must NOT apply
+        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
+        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
+        auto relu = std::make_shared<ngraph::opset4::Relu>(add);
+        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
+        auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
+        auto mul_first = std::make_shared<ngraph::opset4::Multiply>(input, min);
+        auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.167});  // intentionally not 1/6 (~0.16667), so the HSwish pattern should not match
+        auto mul_second = std::make_shared<ngraph::opset4::Multiply>(mul_first, mul_constant);
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul_second}, ngraph::ParameterVector{input});
+
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::HSwishFusionWithReluMul>();
+        manager.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));  // rt info must stay consistent even when no fusion occurs
+    }
+
+    {  // reference: the original graph unchanged, since fusion is expected not to trigger
+        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
+        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
+        auto relu = std::make_shared<ngraph::opset4::Relu>(add);
+        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
+        auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
+        auto mul_first = std::make_shared<ngraph::opset4::Multiply>(input, min);
+        auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.167});
+        auto mul_second = std::make_shared<ngraph::opset4::Multiply>(mul_first, mul_constant);
+
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul_second}, ngraph::ParameterVector{input});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, HSwishFusionWithReluDivWrongConstValue) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {  // constants deviate from the HSwish pattern (3.01, 6.002, divide by 0.0) -- fusion must NOT apply
+        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{});
+        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.01});  // not the required +3
+        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
+        auto relu = std::make_shared<ngraph::opset4::Relu>(add);
+        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.002});  // not the required 6
+        auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
+        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, min);
+        auto div_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.0});  // zero divisor: clearly not the /6 of HSwish
+        auto div = std::make_shared<ngraph::opset4::Divide>(mul, div_constant);
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});
+
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::HSwishFusionWithReluDiv>();
+        manager.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));  // rt info must stay consistent even when no fusion occurs
+    }
+
+    {  // reference: the original graph unchanged, since fusion is expected not to trigger
+        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{});
+        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.01});
+        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
+        auto relu = std::make_shared<ngraph::opset4::Relu>(add);
+        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.002});
+        auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
+        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, min);
+        auto div_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.0});
+        auto div = std::make_shared<ngraph::opset4::Divide>(mul, div_constant);
+
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, HSwishFusionWithoutReluWrongConstValue) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {  // all pattern constants are off (3.11, 0.22, 6.01, 6.002) -- fusion must NOT apply
+        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.11});  // not the required +3
+        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
+        auto max_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.22});  // not the required 0
+        auto max = std::make_shared<ngraph::opset4::Maximum>(add, max_constant);
+        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.01});  // not the required 6
+        auto min = std::make_shared<ngraph::opset4::Minimum>(max, min_constant);
+        auto div_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.002});  // not the required /6
+        auto div = std::make_shared<ngraph::opset4::Divide>(min, div_constant);
+        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, div);
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});
+
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::HSwishFusionWithoutRelu>();
+        manager.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));  // rt info must stay consistent even when no fusion occurs
+    }
+
+    {  // reference: the original graph unchanged, since fusion is expected not to trigger
+        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.11});
+        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
+        auto max_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.22});
+        auto max = std::make_shared<ngraph::opset4::Maximum>(add, max_constant);
+        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.01});
+        auto min = std::make_shared<ngraph::opset4::Minimum>(max, min_constant);
+        auto div_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.002});
+        auto div = std::make_shared<ngraph::opset4::Divide>(min, div_constant);
+        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, div);
+
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
diff --git a/ngraph/core/include/ngraph/op/hswish.hpp b/ngraph/core/include/ngraph/op/hswish.hpp
new file mode 100644 (file)
index 0000000..f7d4261
--- /dev/null
@@ -0,0 +1,53 @@
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#pragma once
+
+#include "ngraph/node.hpp"
+#include "ngraph/op/op.hpp"
+#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
+
+namespace ngraph
+{
+    namespace op
+    {
+        namespace v4
+        {
+            /// \brief A HSwish Activation Function
+            /// f(x) = x * min(max(x + 3, 0), 6) / 6 or
+            /// f(x) = x * min(ReLU(x + 3), 6) / 6
+            ///
+            class NGRAPH_API HSwish : public ngraph::op::util::UnaryElementwiseArithmetic
+            {
+            public:
+                NGRAPH_RTTI_DECLARATION;  // registers type info {"HSwish", 4}
+                HSwish() = default;
+
+                /// \brief Constructs a HSwish (hard version of Swish) operation.
+                ///
+                /// \param data Input tensor
+                HSwish(const Output<Node>& arg);
+
+                bool visit_attributes(AttributeVisitor& visitor) override;  // HSwish has no attributes
+
+                virtual std::shared_ptr<Node>
+                    clone_with_new_inputs(const OutputVector& new_args) const override;
+                bool evaluate(const HostTensorVector& outputs,
+                              const HostTensorVector& inputs) const override;  // computes HSwish on host tensors via the reference implementation
+            };
+        }
+    }
+}
index af0be2a..44a5dd9 100644 (file)
@@ -75,6 +75,7 @@
 #include "ngraph/op/group_conv.hpp"
 #include "ngraph/op/gru_cell.hpp"
 #include "ngraph/op/hard_sigmoid.hpp"
+#include "ngraph/op/hswish.hpp"
 #include "ngraph/op/interpolate.hpp"
 #include "ngraph/op/less.hpp"
 #include "ngraph/op/less_eq.hpp"
index 7e4fcce..61aae11 100644 (file)
@@ -156,6 +156,7 @@ NGRAPH_OP(Acosh, ngraph::op::v3)
 NGRAPH_OP(Asinh, ngraph::op::v3)
 NGRAPH_OP(Atanh, ngraph::op::v3)
 NGRAPH_OP(CTCLoss, ngraph::op::v4)
+NGRAPH_OP(HSwish, ngraph::op::v4)
 NGRAPH_OP(NonMaxSuppression, ngraph::op::v4)
 NGRAPH_OP(Mish, ngraph::op::v4)
 NGRAPH_OP(ReduceL1, ngraph::op::v4)
diff --git a/ngraph/core/reference/include/ngraph/runtime/reference/hswish.hpp b/ngraph/core/reference/include/ngraph/runtime/reference/hswish.hpp
new file mode 100644 (file)
index 0000000..e26ceb4
--- /dev/null
@@ -0,0 +1,38 @@
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#pragma once
+
+#include <cmath>
+#include <cstddef>
+
+namespace ngraph
+{
+    namespace runtime
+    {
+        namespace reference
+        {
+            template <typename T>
+            void hswish(const T* arg, T* out, size_t count)
+            {
+                for (size_t i = 0; i < count; i++)
+                {
+                    out[i] = arg[i] * std::min<T>(std::max<T>(arg[i] + 3.0f, 0.0f), 6.0f) / 6.0f;
+                }
+            }
+        }
+    }
+}
diff --git a/ngraph/core/src/op/hswish.cpp b/ngraph/core/src/op/hswish.cpp
new file mode 100644 (file)
index 0000000..5dd0794
--- /dev/null
@@ -0,0 +1,78 @@
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#include "ngraph/op/hswish.hpp"
+#include "ngraph/attribute_visitor.hpp"
+#include "ngraph/op/constant.hpp"
+
+#include "ngraph/runtime/host_tensor.hpp"
+#include "ngraph/runtime/reference/hswish.hpp"
+
+using namespace std;
+using namespace ngraph;
+
+NGRAPH_RTTI_DEFINITION(op::v4::HSwish, "HSwish", 4);
+
+op::v4::HSwish::HSwish(const Output<Node>& arg)
+    : UnaryElementwiseArithmetic(arg)  // unary elementwise op: output type/shape follow the input
+{
+    constructor_validate_and_infer_types();
+}
+
+bool op::v4::HSwish::visit_attributes(AttributeVisitor& visitor)
+{
+    return true;  // HSwish has no attributes to (de)serialize
+}
+
+shared_ptr<Node> op::v4::HSwish::clone_with_new_inputs(const OutputVector& new_args) const
+{
+    return make_shared<op::v4::HSwish>(new_args.at(0));  // NOTE(review): no check_new_args_count(this, new_args) here -- confirm against the convention used by sibling ops
+}
+
+namespace
+{
+    template <element::Type_t ET>
+    inline bool evaluate(const HostTensorPtr& arg, const HostTensorPtr& out, const size_t count)
+    {
+        using T = typename element_type_traits<ET>::value_type;
+
+        runtime::reference::hswish<T>(arg->get_data_ptr<ET>(), out->get_data_ptr<ET>(), count);  // element-wise x * min(max(x + 3, 0), 6) / 6
+        return true;
+    }
+
+    bool evaluate_hswish(const HostTensorPtr& arg, const HostTensorPtr& out, const size_t count)
+    {
+        bool rc = true;
+        out->set_unary(arg);  // output tensor mirrors the input's element type and shape
+
+        switch (arg->get_element_type())
+        {
+            TYPE_CASE(bf16)(arg, out, count);  // supported precisions: bf16, f16, f32
+            break;
+            TYPE_CASE(f16)(arg, out, count);
+            break;
+            TYPE_CASE(f32)(arg, out, count);
+            break;
+        default: rc = false; break;  // any other element type is unsupported
+        }
+        return rc;
+    }
+}
+
+bool op::v4::HSwish::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const
+{
+    return evaluate_hswish(inputs[0], outputs[0], shape_size(get_output_shape(0)));  // NOTE(review): get_output_shape(0) needs a static output shape -- verify behavior when evaluate is called with a dynamic shape (inputs[0]->get_shape() may be safer)
+}
index ceb9557..19e5b36 100644 (file)
@@ -84,6 +84,7 @@ from ngraph.opset4 import group_convolution
 from ngraph.opset4 import group_convolution_backprop_data
 from ngraph.opset4 import gru_cell
 from ngraph.opset4 import hard_sigmoid
+from ngraph.opset4 import hswish
 from ngraph.opset4 import interpolate
 from ngraph.opset4 import less
 from ngraph.opset4 import less_equal
index 3175363..2980f88 100644 (file)
@@ -72,6 +72,7 @@ from ngraph.opset1.ops import group_convolution
 from ngraph.opset1.ops import group_convolution_backprop_data
 from ngraph.opset3.ops import gru_cell
 from ngraph.opset1.ops import hard_sigmoid
+from ngraph.opset4.ops import hswish
 from ngraph.opset1.ops import interpolate
 from ngraph.opset1.ops import less
 from ngraph.opset1.ops import less_equal
index 23a84e3..00e31b0 100644 (file)
@@ -150,6 +150,16 @@ def mish(data: NodeInput, name: Optional[str] = None,) -> Node:
 
 
+@nameable_op
+def hswish(data: NodeInput, name: Optional[str] = None,) -> Node:
+    """Return a node which performs HSwish (hard version of Swish).
+
+    :param data: Tensor with input data of floating point type.
+    :return: The new node which performs HSwish: x * min(max(x + 3, 0), 6) / 6
+    """
+    return _get_node_factory_opset4().create("HSwish", as_nodes(data), {})
+
+
+@nameable_op
 def swish(
     data: NodeInput,
     beta: Optional[NodeInput] = None,
index 6d5ccec..92c3415 100644 (file)
@@ -17,6 +17,7 @@ import numpy as np
 import pytest
 
 import ngraph as ng
+from ngraph.impl import Shape, Type
 from tests.test_ngraph.util import run_op_node, run_op_numeric_data
 from tests import xfail_issue_35929, xfail_issue_34323
 
@@ -148,3 +149,14 @@ def test_erf():
 
     result = run_op_numeric_data(input_tensor, ng.erf)
     assert np.allclose(result, expected)
+
+
+def test_hswish():
+    float_dtype = np.float32
+    data = ng.parameter(Shape([3, 10]), dtype=float_dtype, name="data")
+
+    node = ng.hswish(data)
+    assert node.get_type_name() == "HSwish"
+    assert node.get_output_size() == 1  # unary op: single output
+    assert list(node.get_output_shape(0)) == [3, 10]  # shape propagated from the input
+    assert node.get_output_element_type(0) == Type.f32  # element type propagated from the input
index 4f69cae..dc3e444 100644 (file)
@@ -70,6 +70,7 @@ set(SRC
     node_input_output.cpp
     nop_elimination.cpp
     op.cpp
+    op_eval/hswish.cpp
     op_eval/matmul.cpp
     op_eval/mish.cpp
     op_eval/non_zero.cpp
@@ -127,6 +128,7 @@ set(SRC
     type_prop/group_convolution_backprop_data.cpp
     type_prop/gru_cell.cpp
     type_prop/hard_sigmoid.cpp
+    type_prop/hswish.cpp
     type_prop/lrn.cpp
     type_prop/lstm_cell.cpp
     type_prop/lstm_sequence.cpp
diff --git a/ngraph/test/op_eval/hswish.cpp b/ngraph/test/op_eval/hswish.cpp
new file mode 100644 (file)
index 0000000..1de6087
--- /dev/null
@@ -0,0 +1,48 @@
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "ngraph/op/hswish.hpp"
+#include "ngraph/runtime/host_tensor.hpp"
+#include "ngraph/validation_util.hpp"
+#include "runtime/backend.hpp"
+#include "util/test_tools.hpp"
+
+using namespace std;
+using namespace ngraph;
+
+TEST(op_eval, hswish)
+{
+    auto p = make_shared<op::Parameter>(element::f32, Shape{3});
+    auto swish = make_shared<op::v4::HSwish>(p);
+    auto fun = make_shared<Function>(OutputVector{swish}, ParameterVector{p});
+
+    std::vector<float> inputs{-0.5f, 0.0f, 0.5f};
+    std::vector<float> expected_result{-0.208333f, 0.0f, 0.29166667f};
+
+    auto result = make_shared<HostTensor>();
+    ASSERT_TRUE(
+        fun->evaluate({result}, {make_host_tensor<element::Type_t::f32>(Shape{3}, inputs)}));
+    EXPECT_EQ(result->get_element_type(), element::f32);
+    EXPECT_EQ(result->get_shape(), Shape{3});
+    auto result_data = read_vector<float>(result);
+    for (auto i = 0; i < inputs.size(); i++)
+        EXPECT_NEAR(result_data[i], expected_result[i], 0.000001);
+}
diff --git a/ngraph/test/type_prop/hswish.cpp b/ngraph/test/type_prop/hswish.cpp
new file mode 100644 (file)
index 0000000..9df6d19
--- /dev/null
@@ -0,0 +1,54 @@
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#include "gtest/gtest.h"
+#include "ngraph/ngraph.hpp"
+#include "util/type_prop.hpp"
+
+using namespace std;
+using namespace ngraph;
+
+TEST(type_prop, hswish)
+{
+    auto data = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6});
+    auto hswish_func = make_shared<op::v4::HSwish>(data);
+    EXPECT_EQ(hswish_func->get_element_type(), element::f32);  // element type preserved
+    EXPECT_EQ(hswish_func->get_shape(), data->get_output_shape(0));  // static shape preserved
+}
+
+TEST(type_prop, hswish_partial)
+{
+    auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 6});
+    auto hswish_func = make_shared<op::v4::HSwish>(data);
+    EXPECT_EQ(hswish_func->get_element_type(), element::f32);
+    ASSERT_TRUE(
+        hswish_func->get_output_partial_shape(0).same_scheme(data->get_output_partial_shape(0)));  // dynamic dims propagated as-is
+
+    // rank unknown
+    auto hswish_partial = make_shared<op::v4::HSwish>(
+        make_shared<op::Parameter>(element::f32, PartialShape::dynamic()));
+    ASSERT_TRUE(hswish_partial->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));  // fully dynamic input -> fully dynamic output
+}
+
+TEST(type_prop, hswish_partial_static_rank)
+{
+    auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 6});
+    auto hswish_func = make_shared<op::v4::HSwish>(data);
+    EXPECT_EQ(hswish_func->get_element_type(), element::f32);
+    ASSERT_TRUE(
+        hswish_func->get_output_partial_shape(0).same_scheme(data->get_output_partial_shape(0)));
+    ASSERT_TRUE(hswish_func->get_output_partial_shape(0).rank().is_static());  // rank stays static even with dynamic dims
+}