Enable swish (#1682)
author Evgeny Lazarev <evgeny.lazarev@intel.com>
Mon, 10 Aug 2020 12:51:21 +0000 (15:51 +0300)
committer GitHub <noreply@github.com>
Mon, 10 Aug 2020 12:51:21 +0000 (15:51 +0300)
* Draft version of the Swish nGraph operation and fusing transformations for different approaches to express the operation

* Swish fusing transformation refactoring

* Added Swish operation and extractor for TF. Removed unfolding transformation for the operation.

* Added SwishIE. Implemented transformation to convert Swish to SwishIE.

* Code style fixes

* Updated Swish reference implementation. Added tests for shape and value inference

* Fixed code style for Python API

* Fixed unit test

* Apply review comments

* Use matcher_pass_callback

* Make m_alpha attribute protected in the SwishIE operation

* Fixed Swish op PythonAPI test

29 files changed:
inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp
inference-engine/src/transformations/include/ngraph_ops/swish_ie.hpp [new file with mode: 0644]
inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_swish_to_swish_ie.hpp [new file with mode: 0644]
inference-engine/src/transformations/include/transformations/swish_fusion.hpp [new file with mode: 0644]
inference-engine/src/transformations/src/ngraph_ops/swish_ie.cpp [new file with mode: 0644]
inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_swish_to_swish_ie.cpp [new file with mode: 0644]
inference-engine/src/transformations/src/transformations/swish_fusion.cpp [new file with mode: 0644]
inference-engine/tests/functional/inference_engine/transformations/mish_fusion_test.cpp
inference-engine/tests/functional/inference_engine/transformations/swish_fusion_test.cpp [new file with mode: 0644]
model-optimizer/automation/package_BOM.txt
model-optimizer/extensions/front/tf/swish.py [deleted file]
model-optimizer/extensions/front/tf/swish_ext.py [new file with mode: 0644]
model-optimizer/extensions/front/tf/swish_test.py [deleted file]
model-optimizer/extensions/ops/activation_ops.py
ngraph/python/src/ngraph/__init__.py
ngraph/python/src/ngraph/opset4/__init__.py
ngraph/python/src/ngraph/opset4/ops.py
ngraph/python/tests/test_ngraph/test_swish.py [new file with mode: 0644]
ngraph/src/ngraph/CMakeLists.txt
ngraph/src/ngraph/op/swish.cpp [new file with mode: 0644]
ngraph/src/ngraph/op/swish.hpp [new file with mode: 0644]
ngraph/src/ngraph/ops.hpp
ngraph/src/ngraph/opsets/opset4_tbl.hpp
ngraph/src/ngraph/runtime/reference/swish.hpp [new file with mode: 0644]
ngraph/test/CMakeLists.txt
ngraph/test/op_eval/swish.cpp [new file with mode: 0644]
ngraph/test/type_prop/swish.cpp [new file with mode: 0644]

index 1e386b5..7c80fef 100644 (file)
@@ -496,6 +496,16 @@ InferenceEngine::details::CNNLayerCreator::CNNLayerCreator(const std::shared_ptr
 
     });
 
+    addSpecificCreator({"SwishIE"}, [](const std::shared_ptr<::ngraph::Node>& node,
+        const std::map<std::string, std::string> params) -> CNNLayerPtr {
+        LayerParams attrs = {node->get_friendly_name(), "Swish",
+            details::convertPrecision(node->get_output_element_type(0))};
+        auto res = std::make_shared<InferenceEngine::CNNLayer>(attrs);
+        res->params = params;
+        return res;
+
+    });
+
     addSpecificCreator({"PriorBox"}, [](const std::shared_ptr<::ngraph::Node>& node,
                                        const std::map<std::string, std::string> params) -> CNNLayerPtr {
         THROW_IE_EXCEPTION << "PriorBox operation has a form that is not supported." << node->get_friendly_name()
diff --git a/inference-engine/src/transformations/include/ngraph_ops/swish_ie.hpp b/inference-engine/src/transformations/include/ngraph_ops/swish_ie.hpp
new file mode 100644 (file)
index 0000000..4434ad0
--- /dev/null
@@ -0,0 +1,32 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <memory>
+
+#include <transformations_visibility.hpp>
+
+#include "ngraph/op/op.hpp"
+
+namespace ngraph {
+namespace op {
+class TRANSFORMATIONS_API SwishIE : public Op {
+public:
+    static constexpr NodeTypeInfo type_info{"SwishIE", 1};
+    const NodeTypeInfo &get_type_info() const override { return type_info; }
+
+    explicit SwishIE(const Output<Node> &input, float alpha = 1.0);
+
+    void validate_and_infer_types() override;
+    bool visit_attributes(AttributeVisitor& visitor) override;
+    std::shared_ptr<Node> clone_with_new_inputs(const OutputVector &new_args) const override;
+
+    void set_alpha(float alpha);
+    float get_alpha() const;
+protected:
+    float m_alpha;
+};
+}  // namespace op
+}  // namespace ngraph
diff --git a/inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_swish_to_swish_ie.hpp b/inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_swish_to_swish_ie.hpp
new file mode 100644 (file)
index 0000000..2d60227
--- /dev/null
@@ -0,0 +1,24 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <vector>
+#include <memory>
+#include <string>
+
+#include <transformations_visibility.hpp>
+
+#include <ngraph/pass/graph_rewrite.hpp>
+
+namespace ngraph {
+namespace pass {
+    class TRANSFORMATIONS_API ConvertSwishToSwishIEMatcher;
+}  // namespace pass
+}  // namespace ngraph
+
+class ngraph::pass::ConvertSwishToSwishIEMatcher: public ngraph::pass::MatcherPass {
+public:
+    ConvertSwishToSwishIEMatcher();
+};
diff --git a/inference-engine/src/transformations/include/transformations/swish_fusion.hpp b/inference-engine/src/transformations/include/transformations/swish_fusion.hpp
new file mode 100644 (file)
index 0000000..d531e78
--- /dev/null
@@ -0,0 +1,73 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <memory>
+#include <utility>
+
+#include <transformations_visibility.hpp>
+#include <ngraph/pass/graph_rewrite.hpp>
+
+namespace ngraph {
+namespace pass {
+
+class TRANSFORMATIONS_API SwishFusion;
+class TRANSFORMATIONS_API SwishFusionWithSigmoid;
+class TRANSFORMATIONS_API SwishFusionWithSigmoidWithBeta;
+class TRANSFORMATIONS_API SwishFusionWithBeta;
+class TRANSFORMATIONS_API SwishFusionWithoutBeta;
+
+}  // namespace pass
+}  // namespace ngraph
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief SwishFusion transformation replaces various sub-graphs with a Swish op.
+ */
+class ngraph::pass::SwishFusion: public ngraph::pass::GraphRewrite {
+public:
+    SwishFusion() {
+        add_matcher<ngraph::pass::SwishFusionWithSigmoid>();
+        add_matcher<ngraph::pass::SwishFusionWithSigmoidWithBeta>();
+        add_matcher<ngraph::pass::SwishFusionWithBeta>();
+        add_matcher<ngraph::pass::SwishFusionWithoutBeta>();
+    }
+};
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief SwishFusionWithSigmoid replaces the sub-graph x * Sigmoid(x) with a Swish op.
+ */
+class ngraph::pass::SwishFusionWithSigmoid: public ngraph::pass::MatcherPass {
+public:
+    SwishFusionWithSigmoid();
+};
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief SwishFusionWithSigmoidWithBeta replaces the sub-graph x * Sigmoid(x * beta) with a Swish op.
+ */
+class ngraph::pass::SwishFusionWithSigmoidWithBeta: public ngraph::pass::MatcherPass {
+public:
+    SwishFusionWithSigmoidWithBeta();
+};
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief SwishFusionWithBeta replaces the sub-graph x / (1.0 + exp(-x * beta)) with a Swish op.
+ */
+class ngraph::pass::SwishFusionWithBeta: public ngraph::pass::MatcherPass {
+public:
+    SwishFusionWithBeta();
+};
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief SwishFusionWithoutBeta replaces the sub-graph x / (1.0 + exp(-x)) with a Swish op.
+ */
+class ngraph::pass::SwishFusionWithoutBeta: public ngraph::pass::MatcherPass {
+public:
+    SwishFusionWithoutBeta();
+};
diff --git a/inference-engine/src/transformations/src/ngraph_ops/swish_ie.cpp b/inference-engine/src/transformations/src/ngraph_ops/swish_ie.cpp
new file mode 100644 (file)
index 0000000..cd04251
--- /dev/null
@@ -0,0 +1,44 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "ngraph_ops/swish_ie.hpp"
+
+#include <algorithm>
+#include <memory>
+
+#include "ngraph/util.hpp"
+#include "ngraph/validation_util.hpp"
+
+using namespace std;
+using namespace ngraph;
+
+constexpr NodeTypeInfo op::SwishIE::type_info;
+
+op::SwishIE::SwishIE(const Output<Node> & input, const float alpha)
+        : Op({input}), m_alpha(alpha) {
+    constructor_validate_and_infer_types();
+}
+
+std::shared_ptr<Node> op::SwishIE::clone_with_new_inputs(const OutputVector& new_args) const {
+    check_new_args_count(this, new_args);
+    return make_shared<SwishIE>(new_args.at(0), m_alpha);
+}
+
+bool op::SwishIE::visit_attributes(AttributeVisitor& visitor) {
+    visitor.on_attribute("alpha", m_alpha);
+    return true;
+}
+
+void op::SwishIE::validate_and_infer_types() {
+    set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
+}
+
+void op::SwishIE::set_alpha(float alpha)  {
+    m_alpha = alpha;
+}
+
+float op::SwishIE::get_alpha() const {
+    return m_alpha;
+}
+
index 78cee29..eb5ee57 100644 (file)
@@ -12,6 +12,7 @@
 #include "transformations/init_node_info.hpp"
 #include "transformations/itt.hpp"
 #include "transformations/mish_fusion.hpp"
+#include "transformations/swish_fusion.hpp"
 
 #include <ngraph/pass/manager.hpp>
 #include <ngraph/pass/nop_elimination.hpp>
@@ -34,6 +35,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
     manager.register_pass<ngraph::pass::ConvertScatterElementsToScatter>(); // partially depends on CF
     manager.register_pass<ngraph::pass::DepthToSpaceFusion>();
     manager.register_pass<ngraph::pass::MishFusion>();
+    manager.register_pass<ngraph::pass::SwishFusion>();
 
     manager.set_callback(m_transformation_callback);
     manager.run_passes(f);
index 12a7766..c8226c8 100644 (file)
@@ -32,6 +32,7 @@
 #include <transformations/convert_opset1_to_legacy/convert_strided_slice_to_crop.hpp>
 #include <transformations/convert_subtract.hpp>
 #include <transformations/convert_opset1_to_legacy/convert_selu_to_selu_ie.hpp>
+#include <transformations/convert_opset1_to_legacy/convert_swish_to_swish_ie.hpp>
 #include <transformations/convert_opset1_to_legacy/convert_tile_to_ie_tile.hpp>
 #include <transformations/convert_opset1_to_legacy/convert_topk_to_topk_ie.hpp>
 #include <transformations/convert_depth_to_space.hpp>
@@ -129,6 +130,7 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
     anchor->add_matcher<ngraph::pass::ConvertPReLUToReLUIE>();
     anchor->add_matcher<ngraph::pass::ConvertGatherToGatherIEMatcher>();
     anchor->add_matcher<ngraph::pass::ConvertSeluToSeluIEMatcher>();
+    anchor->add_matcher<ngraph::pass::ConvertSwishToSwishIEMatcher>();
     anchor->add_matcher<ngraph::pass::ConvertOneHotToOneHotIEMatcher>()->detect_output_type(f);
     anchor->add_matcher<ngraph::pass::ConvertGatherTreeToGatherTreeIEMatcher>();
     anchor->add_matcher<ngraph::pass::ConvertTopKToTopKIEMatcher>();
diff --git a/inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_swish_to_swish_ie.cpp b/inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_swish_to_swish_ie.cpp
new file mode 100644 (file)
index 0000000..1a3f538
--- /dev/null
@@ -0,0 +1,46 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/convert_opset1_to_legacy/convert_swish_to_swish_ie.hpp"
+
+#include <memory>
+
+#include <ngraph/opsets/opset4.hpp>
+
+#include <ngraph_ops/swish_ie.hpp>
+#include <transformations/utils/utils.hpp>
+#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
+
+ngraph::pass::ConvertSwishToSwishIEMatcher::ConvertSwishToSwishIEMatcher() {
+    auto swish = ngraph::pattern::wrap_type<ngraph::opset4::Swish>();
+
+    ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
+        auto swish = std::dynamic_pointer_cast<ngraph::opset4::Swish> (m.get_match_root());
+        if (!swish) {
+            return false;
+        }
+        float beta_value = 1.0;
+        if (swish->input_values().size() == 2) {
+            auto beta_node = swish->input_value(1).get_node_shared_ptr();
+            auto beta_const = std::dynamic_pointer_cast<ngraph::opset4::Constant>(beta_node);
+
+            if (!beta_const) {
+                return false;
+            }
+            if (!ngraph::op::util::get_single_value(beta_const, beta_value)) {
+                return false;
+            }
+        }
+
+        auto swish_ie = std::make_shared<ngraph::op::SwishIE>(swish->input(0).get_source_output(), beta_value);
+        swish_ie->set_friendly_name(swish->get_friendly_name());
+        ngraph::copy_runtime_info(swish, swish_ie);
+        ngraph::replace_node(swish, swish_ie);
+        return true;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(swish, "ConvertSwishToSwishIE");
+    this->register_matcher(m, callback);
+}
\ No newline at end of file
diff --git a/inference-engine/src/transformations/src/transformations/swish_fusion.cpp b/inference-engine/src/transformations/src/transformations/swish_fusion.cpp
new file mode 100644 (file)
index 0000000..dcad9d7
--- /dev/null
@@ -0,0 +1,183 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/swish_fusion.hpp"
+
+#include <memory>
+
+#include <ngraph/opsets/opset4.hpp>
+#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
+
+bool check_constant_value(const std::shared_ptr<ngraph::opset4::Constant>& constant) {
+    if (!constant) {
+        return false;
+    }
+    if (constant->get_element_type() == ngraph::element::f32 || constant->get_element_type() == ngraph::element::f16) {
+        auto data = constant->cast_vector<float>();
+        if (data.size() != 1 || data[0] != 1.0) {
+            return false;
+        }
+    } else {
+        return false;
+    }
+    return true;
+}
+
+bool check_beta_value(const std::shared_ptr<ngraph::opset4::Constant>& constant) {
+    // check that the constant for beta contains only one distinct element
+    if (!constant) {
+        return false;
+    }
+    if (constant->get_element_type() == ngraph::element::f32 || constant->get_element_type() == ngraph::element::f16) {
+        auto data = constant->cast_vector<float>();
+        if (!std::equal(data.begin() + 1, data.end(), data.begin())) {
+            return false;
+        }
+    } else {
+        return false;
+    }
+    return true;
+}
+
+ngraph::pass::SwishFusionWithSigmoid::SwishFusionWithSigmoid() {
+    // replaces the sub-graph x * Sigmoid(x) with a Swish op.
+    auto input = ngraph::pattern::any_input();
+    auto sigmoid = std::make_shared<ngraph::opset4::Sigmoid>(input);
+    auto mul = std::make_shared<ngraph::opset4::Multiply>(input, sigmoid);
+
+    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
+        auto &pattern_to_output = m.get_pattern_value_map();
+        auto exp_input = pattern_to_output.at(input);
+
+        auto swish = std::make_shared<ngraph::opset4::Swish>(exp_input);
+
+        swish->set_friendly_name(m.get_match_root()->get_friendly_name());
+        ngraph::copy_runtime_info({pattern_to_output.at(sigmoid).get_node_shared_ptr(),
+                                   pattern_to_output.at(mul).get_node_shared_ptr()},
+                                  swish);
+        ngraph::replace_node(m.get_match_root(), swish);
+        return true;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "SwishWithSigmoidFusion");
+    register_matcher(m, callback);
+}
+
+ngraph::pass::SwishFusionWithSigmoidWithBeta::SwishFusionWithSigmoidWithBeta() {
+    // replaces the sub-graph x * Sigmoid(x * beta) with a Swish op.
+    auto input = ngraph::pattern::any_input();
+    auto beta = ngraph::pattern::any_input();
+    auto mul_beta = std::make_shared<ngraph::opset4::Multiply>(input, beta);
+    auto sigmoid = std::make_shared<ngraph::opset4::Sigmoid>(mul_beta);
+    auto mul = std::make_shared<ngraph::opset4::Multiply>(input, sigmoid);
+
+    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
+        auto &pattern_to_output = m.get_pattern_value_map();
+        auto exp_input = pattern_to_output.at(input);
+        auto beta_input = pattern_to_output.at(beta);
+
+        auto beta_constant = std::dynamic_pointer_cast<ngraph::opset4::Constant>(beta_input.get_node_shared_ptr());
+        Output<Node> new_beta;
+        if (beta_constant) {
+            if (check_beta_value(beta_constant)) {
+                new_beta = opset4::Constant::create(beta_input.get_element_type(), Shape{}, {beta_constant->cast_vector<float>()[0]});
+            } else {
+                return false;
+            }
+        } else {
+            // if the input is not constant and number of elements is not equal to 1 then we cannot perform fusing
+            if (beta_input.get_partial_shape().is_dynamic() || ngraph::shape_size(beta_input.get_shape()) != 1) {
+                return false;
+            }
+            new_beta = beta_input;
+        }
+
+        auto swish = std::make_shared<ngraph::opset4::Swish>(exp_input, new_beta);
+
+        swish->set_friendly_name(m.get_match_root()->get_friendly_name());
+        ngraph::copy_runtime_info({pattern_to_output.at(sigmoid).get_node_shared_ptr(),
+                                   pattern_to_output.at(mul).get_node_shared_ptr()},
+                                  swish);
+        ngraph::replace_node(m.get_match_root(), swish);
+        return true;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "SwishWithSigmoidWithBetaFusion");
+    register_matcher(m, callback);
+}
+
+ngraph::pass::SwishFusionWithBeta::SwishFusionWithBeta() {
+    // replaces the sub-graph x / (1.0 + exp(-x * beta)) with a Swish op.
+    auto input = ngraph::pattern::any_input();
+    auto beta = ngraph::pattern::any_input();
+    auto mul = std::make_shared<ngraph::opset4::Multiply>(input, beta);
+    auto neg = std::make_shared<ngraph::opset4::Negative>(mul);
+    auto exp = std::make_shared<ngraph::opset4::Exp>(neg);
+    auto add_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
+    auto add = std::make_shared<ngraph::opset4::Add>(exp, add_constant);
+    auto div = std::make_shared<ngraph::opset4::Divide>(input, add);
+
+    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
+        auto &pattern_to_output = m.get_pattern_value_map();
+        auto exp_input = pattern_to_output.at(input);
+
+        auto constant = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
+        if (!check_constant_value(constant)) {
+            return false;
+        }
+
+        auto swish = std::make_shared<ngraph::opset4::Swish>(exp_input, pattern_to_output.at(beta));
+
+        swish->set_friendly_name(m.get_match_root()->get_friendly_name());
+        ngraph::copy_runtime_info({pattern_to_output.at(beta).get_node_shared_ptr(),
+                                   pattern_to_output.at(mul).get_node_shared_ptr(),
+                                   pattern_to_output.at(neg).get_node_shared_ptr(),
+                                   pattern_to_output.at(exp).get_node_shared_ptr(),
+                                   pattern_to_output.at(add_constant).get_node_shared_ptr(),
+                                   pattern_to_output.at(add).get_node_shared_ptr(),
+                                   pattern_to_output.at(div).get_node_shared_ptr()},
+                                  swish);
+        ngraph::replace_node(m.get_match_root(), swish);
+        return true;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(div, "SwishWithBetaFusion");
+    register_matcher(m, callback);
+}
+
+ngraph::pass::SwishFusionWithoutBeta::SwishFusionWithoutBeta() {
+    // replaces the sub-graph x / (1.0 + exp(-x)) with a Swish op.
+    auto input = ngraph::pattern::any_input();
+    auto neg = std::make_shared<ngraph::opset4::Negative>(input);
+    auto exp = std::make_shared<ngraph::opset4::Exp>(neg);
+    auto add_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
+    auto add = std::make_shared<ngraph::opset4::Add>(exp, add_constant);
+    auto div = std::make_shared<ngraph::opset4::Divide>(input, add);
+
+    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
+        auto & pattern_to_output = m.get_pattern_value_map();
+        auto exp_input = pattern_to_output.at(input);
+
+        auto constant = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
+        if (!check_constant_value(constant)) {
+            return false;
+        }
+
+        auto swish = std::make_shared<ngraph::opset4::Swish>(exp_input);
+
+        swish->set_friendly_name(m.get_match_root()->get_friendly_name());
+        ngraph::copy_runtime_info({pattern_to_output.at(neg).get_node_shared_ptr(),
+                                   pattern_to_output.at(exp).get_node_shared_ptr(),
+                                   pattern_to_output.at(add_constant).get_node_shared_ptr(),
+                                   pattern_to_output.at(add).get_node_shared_ptr(),
+                                   pattern_to_output.at(div).get_node_shared_ptr()},
+                                   swish);
+        ngraph::replace_node(m.get_match_root(), swish);
+        return true;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(div, "SwishWithoutBetaFusion");
+    register_matcher(m, callback);
+}
index a540cec..b349378 100644 (file)
@@ -6,12 +6,10 @@
 
 #include <string>
 #include <memory>
-#include <queue>
 
 #include <ngraph/function.hpp>
 #include <ngraph/opsets/opset4.hpp>
 #include <ngraph/pass/manager.hpp>
-#include <ngraph/pass/visualize_tree.hpp>
 #include <transformations/mish_fusion.hpp>
 #include <transformations/init_node_info.hpp>
 #include <transformations/utils/utils.hpp>
diff --git a/inference-engine/tests/functional/inference_engine/transformations/swish_fusion_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/swish_fusion_test.cpp
new file mode 100644 (file)
index 0000000..e6125ae
--- /dev/null
@@ -0,0 +1,206 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include <string>
+#include <memory>
+
+#include <ngraph/function.hpp>
+#include <ngraph/opsets/opset4.hpp>
+#include <ngraph/pass/manager.hpp>
+#include <transformations/swish_fusion.hpp>
+#include <transformations/init_node_info.hpp>
+#include <transformations/utils/utils.hpp>
+
+#include "common_test_utils/ngraph_test_utils.hpp"
+
+using namespace testing;
+
+TEST(TransformationTests, SwishFusionWithBeta) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(1));
+        auto beta = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{});
+        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, beta);
+        auto neg = std::make_shared<ngraph::opset4::Negative>(mul);
+        auto exp = std::make_shared<ngraph::opset4::Exp>(neg);
+        auto constant = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1.0});
+        auto add = std::make_shared<ngraph::opset4::Add>(exp, constant);
+        auto div = std::make_shared<ngraph::opset4::Divide>(input, add);
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input, beta});
+
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::SwishFusion>();
+        manager.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(1));
+        auto beta = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{});
+        auto swish = std::make_shared<ngraph::opset4::Swish>(input, beta);
+
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{swish}, ngraph::ParameterVector{input, beta});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, SwishFusionWithoutBeta) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+        auto neg = std::make_shared<ngraph::opset4::Negative>(input);
+        auto exp = std::make_shared<ngraph::opset4::Exp>(neg);
+        auto constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {1.0});
+        auto add = std::make_shared<ngraph::opset4::Add>(exp, constant);
+        auto div = std::make_shared<ngraph::opset4::Divide>(input, add);
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});
+
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::SwishFusion>();
+        manager.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+        auto swish = std::make_shared<ngraph::opset4::Swish>(input);
+
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{swish}, ngraph::ParameterVector{input});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, SwishFusionWithoutBetaNonOneAddConstant) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+        auto neg = std::make_shared<ngraph::opset4::Negative>(input);
+        auto exp = std::make_shared<ngraph::opset4::Exp>(neg);
+        auto constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {1.1});
+        auto add = std::make_shared<ngraph::opset4::Add>(exp, constant);
+        auto div = std::make_shared<ngraph::opset4::Divide>(input, add);
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});
+
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::SwishFusion>();
+        manager.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+        auto neg = std::make_shared<ngraph::opset4::Negative>(input);
+        auto exp = std::make_shared<ngraph::opset4::Exp>(neg);
+        auto constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {1.1});
+        auto add = std::make_shared<ngraph::opset4::Add>(exp, constant);
+        auto div = std::make_shared<ngraph::opset4::Divide>(input, add);
+
+        // NOTE: 'f' must NOT be reassigned here — it holds the transformed function from the first block; reassigning it made the comparison below vacuous.
+
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, SwishFusionWithSigmoid) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+        auto sig = std::make_shared<ngraph::opset4::Sigmoid>(input);
+        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, sig);
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});
+
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::SwishFusion>();
+        manager.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+        auto swish = std::make_shared<ngraph::opset4::Swish>(input);
+
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{swish}, ngraph::ParameterVector{input});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, SwishFusionWithSigmoidWithBeta) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+        auto beta = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{});
+        auto mul_beta = std::make_shared<ngraph::opset4::Multiply>(input, beta);
+        auto sig = std::make_shared<ngraph::opset4::Sigmoid>(mul_beta);
+        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, sig);
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input, beta});
+
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::SwishFusion>();
+        manager.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+        auto beta = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{});
+        auto swish = std::make_shared<ngraph::opset4::Swish>(input, beta);
+
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{swish}, ngraph::ParameterVector{input, beta});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, SwishFusionWithSigmoidWithBetaConstant) {
+    // test where the beta constant has multiple but the same value
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+        auto beta = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{3}, {2.0, 2.0, 2.0});
+        auto mul_beta = std::make_shared<ngraph::opset4::Multiply>(input, beta);
+        auto sig = std::make_shared<ngraph::opset4::Sigmoid>(mul_beta);
+        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, sig);
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});
+
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::SwishFusion>();
+        manager.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+        auto beta = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {2.0});
+        auto swish = std::make_shared<ngraph::opset4::Swish>(input, beta);
+
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{swish}, ngraph::ParameterVector{input});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
index e20c667..5674b91 100644 (file)
@@ -444,7 +444,7 @@ extensions/front/tf/ssd_toolbox_multihead_detection_output.json
 extensions/front/tf/ssd_v2_support.json
 extensions/front/tf/SSDToolboxDetectionOutput.py
 extensions/front/tf/swap_deconv_inputs.py
-extensions/front/tf/swish.py
+extensions/front/tf/swish_ext.py
 extensions/front/tf/SwitchMergeOptimization.py
 extensions/front/tf/TensorArrayExtractors.py
 extensions/front/tf/TensorArrayGatherV3.py
diff --git a/model-optimizer/extensions/front/tf/swish.py b/model-optimizer/extensions/front/tf/swish.py
deleted file mode 100644 (file)
index a77b463..0000000
+++ /dev/null
@@ -1,37 +0,0 @@
-"""
- Copyright (C) 2018-2020 Intel Corporation
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
-      http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-"""
-
-from extensions.ops.activation_ops import Sigmoid
-from extensions.ops.elementwise import Mul
-from mo.front.common.replacement import FrontReplacementOp
-from mo.graph.graph import Node, Graph
-
-
-class Swish(FrontReplacementOp):
-    op = "swish_f32"
-    enabled = True
-
-    def replace_op(self, graph: Graph, node: Node):
-        mul_node = Mul(graph, {'name': node.name + '/mul_'}).create_node()
-        sigmoid_node = Sigmoid(graph, {'name': node.name + '/sigmoid_'}).create_node()
-
-        # Connect nodes
-        node.in_port(0).get_connection().get_source().connect(mul_node.in_port(0))
-        node.in_port(0).get_connection().get_source().connect(sigmoid_node.in_port(0))
-        sigmoid_node.out_port(0).connect(mul_node.in_port(1))
-
-        # The "explicit" version of the return value is: [(out_node.id, 0)])
-        return [mul_node.id]
diff --git a/model-optimizer/extensions/front/tf/swish_ext.py b/model-optimizer/extensions/front/tf/swish_ext.py
new file mode 100644 (file)
index 0000000..9700877
--- /dev/null
@@ -0,0 +1,29 @@
+"""
+ Copyright (C) 2018-2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+from extensions.ops.activation_ops import Swish
+from mo.front.extractor import FrontExtractorOp
+from mo.graph.graph import Node
+
+
+class SwishExtractor(FrontExtractorOp):
+    # Extractor for the TF "swish_f32" operation: maps it directly to the MO Swish op.
+    op = 'swish_f32'
+    enabled = True
+
+    @classmethod
+    def extract(cls, node: Node):
+        # TF swish_f32 has no beta input, so the Swish default beta = 1.0 applies.
+        Swish.update_node_stat(node, {})
+        return cls.enabled
+
diff --git a/model-optimizer/extensions/front/tf/swish_test.py b/model-optimizer/extensions/front/tf/swish_test.py
deleted file mode 100644 (file)
index 211e042..0000000
+++ /dev/null
@@ -1,57 +0,0 @@
-"""
- Copyright (C) 2018-2020 Intel Corporation
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
-      http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-"""
-
-import unittest
-
-import numpy as np
-
-from extensions.front.tf.swish import Swish
-from mo.utils.ir_engine.compare_graphs import compare_graphs
-from mo.utils.unittest.graph import build_graph
-
-nodes_attributes = {
-    'placeholder_1': {'shape': np.array([1, 227, 227, 3]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
-    'placeholder_2': {'shape': np.array([1, 227, 227, 3]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
-    # swish operation
-    'swish': {'kind': 'op', 'op': 'swish_f32'},
-    # Test operation
-    'last': {'type': None, 'value': None, 'kind': 'op', 'op': None},
-    # Add and Mul operations
-    'mul': {'type': 'Multiply', 'kind': 'op', 'op': 'Mul'},
-    'sigmoid': {'value': None, 'type': 'Sigmoid', 'kind': 'op', 'op': 'Sigmoid'},
-}
-
-
-class TestSwish(unittest.TestCase):
-    def test_swish_test_1(self):
-        # Test with two different inputs from two placeholders
-        graph = build_graph(nodes_attributes,
-                            [('placeholder_1', 'swish'),
-                             ('swish', 'last')
-                             ], nodes_with_edges_only=True)
-
-        graph_ref = build_graph(nodes_attributes,
-                                [('placeholder_1', 'sigmoid', {'out': 0}),
-                                 ('placeholder_1', 'mul', {'in': 0, 'out': 0}),
-                                 ('sigmoid', 'mul', {'in': 1}),
-                                 ('mul', 'last'),
-                                 ], nodes_with_edges_only=True)
-
-        graph.stage = 'front'
-        Swish().find_and_replace_pattern(graph)
-
-        (flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
-        self.assertTrue(flag, resp)
index ebcd3a7..a05dba3 100644 (file)
@@ -198,9 +198,9 @@ class LeakyReLU(Op):
 
     def __init__(self, graph: Graph, attrs: dict):
         super().__init__(graph, {
-            'type': __class__.op,
-            'op': __class__.op,
-            'infer': __class__.infer,
+            'type': self.op,
+            'op': self.op,
+            'infer': self.infer,
             'in_ports_count': 1,
             'out_ports_count': 1,
         }, attrs)
@@ -265,3 +265,36 @@ class Mish(Activation):
         sp_attrs = {'version': 'opset4'}
         sp_attrs.update(attrs)
         super().__init__(graph, sp_attrs)
+
+
+class Swish(Op):
+    op = 'Swish'
+
+    def __init__(self, graph: Graph, attrs: dict):
+        mandatory_props = {
+            'op': self.op,
+            'type': self.op,
+            'version': 'opset4',
+
+            'infer': self.infer,
+
+            'in_ports_count': 2,
+            'out_ports_count': 1,
+        }
+        super().__init__(graph, mandatory_props, attrs)
+
+    @staticmethod
+    def infer(node: Node):
+        node_name = node.soft_get('name', node.id)
+        node.out_port(0).data.set_shape(node.in_port(0).data.get_shape())
+
+        beta = 1.0
+        if node.is_in_port_connected(1):
+            beta = node.in_port(1).data.get_value()
+            if beta is not None:
+                assert beta.ndim == 0, 'The "beta" value for node {} must be a scalar'.format(node_name)
+                beta = beta.item()
+
+        input_value = node.in_port(1).data.get_value()
+        if input_value is not None and beta is not None:
+            node.out_port(0).data.set_value(input_value / (1.0 + np.exp(-input_value * beta)))
index 08cf718..b3cae2d 100644 (file)
@@ -150,6 +150,7 @@ from ngraph.opset4 import squared_difference
 from ngraph.opset4 import squeeze
 from ngraph.opset4 import strided_slice
 from ngraph.opset4 import subtract
+from ngraph.opset4 import swish
 from ngraph.opset4 import tan
 from ngraph.opset4 import tanh
 from ngraph.opset4 import tensor_iterator
index eac33dd..8dbbf16 100644 (file)
@@ -139,6 +139,7 @@ from ngraph.opset1.ops import squared_difference
 from ngraph.opset1.ops import squeeze
 from ngraph.opset1.ops import strided_slice
 from ngraph.opset1.ops import subtract
+from ngraph.opset4.ops import swish
 from ngraph.opset1.ops import tan
 from ngraph.opset1.ops import tanh
 from ngraph.opset1.ops import tensor_iterator
index b91f4e7..1366360 100644 (file)
@@ -147,3 +147,19 @@ def mish(data: NodeInput, name: Optional[str] = None,) -> Node:
     :return: The new node which performs Mish
     """
     return _get_node_factory_opset4().create("Mish", as_nodes(data), {})
+
+
+@nameable_op
+def swish(
+    data: NodeInput,
+    beta: Optional[NodeInput] = None,
+    name: Optional[str] = None,
+) -> Node:
+    """Return a node which performs the Swish activation: Swish(x, beta=1.0) = x * sigmoid(x * beta).
+
+    :param data: Tensor with input data floating point type.
+    :param beta: Scalar value for the beta multiplier; defaults to a constant 1.0 when omitted.
+    :param name: Optional name for the output node.
+    :return: The new node which performs Swish
+    """
+    if beta is None:
+        beta = make_constant_node(1.0, np.float32)
+    return _get_node_factory_opset4().create("Swish", as_nodes(data, beta), {})
diff --git a/ngraph/python/tests/test_ngraph/test_swish.py b/ngraph/python/tests/test_ngraph/test_swish.py
new file mode 100644 (file)
index 0000000..e4917e8
--- /dev/null
@@ -0,0 +1,41 @@
+# ******************************************************************************
+# Copyright 2017-2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ******************************************************************************
+import numpy as np
+import ngraph as ng
+from ngraph.impl import Shape, Type
+
+
+def test_swish_props_with_beta():
+    float_dtype = np.float32
+    data = ng.parameter(Shape([3, 10]), dtype=float_dtype, name="data")
+    beta = ng.parameter(Shape([]), dtype=float_dtype, name="beta")
+
+    node = ng.swish(data, beta)
+    assert node.get_type_name() == "Swish"
+    assert node.get_output_size() == 1
+    assert list(node.get_output_shape(0)) == [3, 10]
+    assert node.get_output_element_type(0) == Type.f32
+
+
+def test_swish_props_without_beta():
+    float_dtype = np.float32
+    data = ng.parameter(Shape([3, 10]), dtype=float_dtype, name="data")
+
+    node = ng.swish(data)
+    assert node.get_type_name() == "Swish"
+    assert node.get_output_size() == 1
+    assert list(node.get_output_shape(0)) == [3, 10]
+    assert node.get_output_element_type(0) == Type.f32
index 9bd860c..cd6f00a 100644 (file)
@@ -332,6 +332,8 @@ set (SRC
     op/subtract.hpp
     op/sum.cpp
     op/sum.hpp
+    op/swish.cpp
+    op/swish.hpp
     op/variadic_split.cpp
     op/variadic_split.hpp
     op/tan.cpp
diff --git a/ngraph/src/ngraph/op/swish.cpp b/ngraph/src/ngraph/op/swish.cpp
new file mode 100644 (file)
index 0000000..e1a8347
--- /dev/null
@@ -0,0 +1,140 @@
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#include "ngraph/op/swish.hpp"
+#include "ngraph/attribute_visitor.hpp"
+#include "ngraph/op/constant.hpp"
+
+#include "ngraph/runtime/host_tensor.hpp"
+#include "ngraph/runtime/reference/swish.hpp"
+
+using namespace std;
+using namespace ngraph;
+
+constexpr NodeTypeInfo op::v4::Swish::type_info;
+
+// Single-input constructor: beta is absent and defaults to 1.0 in the reference.
+op::v4::Swish::Swish(const Output<Node>& arg)
+    : Op({arg})
+{
+    constructor_validate_and_infer_types();
+}
+
+// Two-input constructor: the second input carries the scalar beta value.
+op::v4::Swish::Swish(const Output<Node>& arg, const Output<Node>& beta)
+    : Op({arg, beta})
+{
+    constructor_validate_and_infer_types();
+}
+
+// Swish has no attributes; beta travels as an optional input instead.
+bool op::v4::Swish::visit_attributes(AttributeVisitor& visitor)
+{
+    return true;
+}
+
+// Checks that Swish has one or two inputs; when beta is present it must have
+// the same element type as the data input and be a scalar (rank 0).
+// The output type and shape always mirror the data input (input 0).
+void op::v4::Swish::validate_and_infer_types()
+{
+    auto inputs_count = input_values().size();
+    NODE_VALIDATION_CHECK(this,
+                          inputs_count == 1 || inputs_count == 2,
+                          "Swish must have 1 or 2 inputs, but it has: ",
+                          inputs_count);
+
+    if (inputs_count == 2)
+    {
+        NODE_VALIDATION_CHECK(this,
+                              input_value(0).get_element_type() ==
+                                  input_value(1).get_element_type(),
+                              "Swish inputs must have the same type but they are: ",
+                              input_value(0).get_element_type(),
+                              " and ",
+                              input_value(1).get_element_type());
+        // The rank check only applies when beta's rank is static; dynamic rank passes.
+        if (get_input_partial_shape(1).rank().is_static())
+        {
+            auto beta_rank = get_input_partial_shape(1).rank().get_length();
+            NODE_VALIDATION_CHECK(this,
+                                  beta_rank == 0,
+                                  "Swish input with beta must be scalar but it has rank: ",
+                                  beta_rank);
+        }
+    }
+    set_output_size(1);
+    set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
+}
+
+// Clones the node preserving its arity (1 or 2 inputs).
+shared_ptr<Node> op::v4::Swish::clone_with_new_inputs(const OutputVector& new_args) const
+{
+    // Standard ngraph guard: the clone must receive the same number of inputs
+    // as the original node; other ops call this and Swish should be consistent.
+    check_new_args_count(this, new_args);
+    if (new_args.size() == 1)
+    {
+        return make_shared<op::v4::Swish>(new_args.at(0));
+    }
+    else
+    {
+        return make_shared<op::v4::Swish>(new_args.at(0), new_args.at(1));
+    }
+}
+
+namespace
+{
+    // Invokes the typed reference kernel; arg1 (beta) may be null, in which
+    // case the reference falls back to the default beta = 1.0.
+    template <element::Type_t ET>
+    inline bool evaluate(const HostTensorPtr& arg0,
+                         const HostTensorPtr& arg1,
+                         const HostTensorPtr& out,
+                         const size_t count)
+    {
+        using T = typename element_type_traits<ET>::value_type;
+        if (arg1 != nullptr)
+        {
+            runtime::reference::swish<T>(
+                arg0->get_data_ptr<ET>(), arg1->get_data_ptr<ET>(), out->get_data_ptr<ET>(), count);
+        }
+        else
+        {
+            runtime::reference::swish<T>(
+                arg0->get_data_ptr<ET>(), nullptr, out->get_data_ptr<ET>(), count);
+        }
+        return true;
+    }
+
+    // Dispatches on the data element type; only f16 and f32 are handled here,
+    // any other type makes the evaluation report failure (rc = false).
+    bool evaluate_swish(const HostTensorPtr& arg0,
+                        const HostTensorPtr& arg1,
+                        const HostTensorPtr& out,
+                        const size_t count)
+    {
+        bool rc = true;
+        out->set_unary(arg0);
+
+        switch (arg0->get_element_type())
+        {
+            TYPE_CASE(f16)(arg0, arg1, out, count);
+            break;
+            TYPE_CASE(f32)(arg0, arg1, out, count);
+            break;
+        default: rc = false; break;
+        }
+        return rc;
+    }
+}
+
+// Evaluate entry point (used e.g. for constant folding). Passes nullptr for
+// beta when the node was constructed without the optional second input.
+// NOTE(review): shape_size(get_output_shape(0)) requires the output shape to
+// be static at evaluation time — confirm callers guarantee this.
+bool op::v4::Swish::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const
+{
+    if (inputs.size() == 2)
+    {
+        return evaluate_swish(inputs[0], inputs[1], outputs[0], shape_size(get_output_shape(0)));
+    }
+    else
+    {
+        return evaluate_swish(inputs[0], nullptr, outputs[0], shape_size(get_output_shape(0)));
+    }
+}
diff --git a/ngraph/src/ngraph/op/swish.hpp b/ngraph/src/ngraph/op/swish.hpp
new file mode 100644 (file)
index 0000000..1c0b6ed
--- /dev/null
@@ -0,0 +1,57 @@
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#pragma once
+
+#include "ngraph/node.hpp"
+#include "ngraph/op/op.hpp"
+
+namespace ngraph
+{
+    namespace op
+    {
+        namespace v4
+        {
+            /// \brief A Swish Activation Function
+            /// f(x) =  x / (1.0 + exp(-beta * x)) or
+            /// f(x) = x * sigmoid(beta * x)
+            ///
+            class NGRAPH_API Swish : public ngraph::op::Op
+            {
+            public:
+                static constexpr NodeTypeInfo type_info{"Swish", 4};
+                const NodeTypeInfo& get_type_info() const override { return type_info; }
+                Swish() = default;
+
+                /// \brief Constructs an Swish operation.
+                ///
+                /// \param data Input tensor
+                /// \param beta Scalar with beta value. If the argument is not specified then use
+                /// the default value 1.0
+                Swish(const Output<Node>& arg, const Output<Node>& beta);
+                /// \brief Constructs a Swish operation with the default beta (1.0).
+                explicit Swish(const Output<Node>& arg);
+
+                bool visit_attributes(AttributeVisitor& visitor) override;
+                void validate_and_infer_types() override;
+
+                virtual std::shared_ptr<Node>
+                    clone_with_new_inputs(const OutputVector& new_args) const override;
+                /// \brief Computes Swish on host tensors (used during evaluation/folding).
+                bool evaluate(const HostTensorVector& outputs,
+                              const HostTensorVector& inputs) const override;
+            };
+        }
+    }
+}
index 8cb45e3..ca5d940 100644 (file)
 #include "ngraph/op/strided_slice.hpp"
 #include "ngraph/op/subtract.hpp"
 #include "ngraph/op/sum.hpp"
+#include "ngraph/op/swish.hpp"
 #include "ngraph/op/tan.hpp"
 #include "ngraph/op/tanh.hpp"
 #include "ngraph/op/tensor_iterator.hpp"
index 975567b..4b0dd22 100644 (file)
@@ -155,6 +155,7 @@ NGRAPH_OP(TopK, ngraph::op::v3)
 NGRAPH_OP(Acosh, ngraph::op::v3)
 NGRAPH_OP(Asinh, ngraph::op::v3)
 NGRAPH_OP(Atanh, ngraph::op::v3)
+NGRAPH_OP(CTCLoss, ngraph::op::v4)
 NGRAPH_OP(NonMaxSuppression, ngraph::op::v4)
 NGRAPH_OP(Mish, ngraph::op::v4)
-NGRAPH_OP(CTCLoss, ngraph::op::v4)
+NGRAPH_OP(Swish, ngraph::op::v4)
diff --git a/ngraph/src/ngraph/runtime/reference/swish.hpp b/ngraph/src/ngraph/runtime/reference/swish.hpp
new file mode 100644 (file)
index 0000000..14bb42e
--- /dev/null
@@ -0,0 +1,43 @@
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#pragma once
+
+#include <cmath>
+#include <cstddef>
+
+namespace ngraph
+{
+    namespace runtime
+    {
+        namespace reference
+        {
+            // Reference Swish: out[i] = arg[i] / (1 + exp(-beta * arg[i])).
+            // A null beta pointer selects the default beta = 1.0.
+            // NOTE(review): the "1.0" literal promotes the intermediate math to
+            // double before narrowing back to T — confirm this is intentional
+            // (it improves f16 accuracy but may trigger narrowing warnings).
+            template <typename T>
+            void swish(const T* arg, const T* beta, T* out, size_t count)
+            {
+                T beta_value = static_cast<T>(1.0);
+                if (beta != nullptr)
+                {
+                    beta_value = beta[0];
+                }
+                for (size_t i = 0; i < count; i++)
+                {
+                    out[i] = arg[i] / (1.0 + std::exp(-arg[i] * beta_value));
+                }
+            }
+        }
+    }
+}
index 3730265..2447040 100644 (file)
@@ -77,6 +77,7 @@ set(SRC
     op_eval/non_zero.cpp
     op_eval/split.cpp
     op_eval/strided_slice.cpp
+    op_eval/swish.cpp
     op_is.cpp
     opset1.cpp
     partial_shape.cpp
@@ -165,6 +166,7 @@ set(SRC
     type_prop/squared_difference.cpp
     type_prop/squeeze.cpp
     type_prop/sum.cpp
+    type_prop/swish.cpp
     type_prop/reduce_prod.cpp
     type_prop/reduce_sum.cpp
     type_prop/tile.cpp
diff --git a/ngraph/test/op_eval/swish.cpp b/ngraph/test/op_eval/swish.cpp
new file mode 100644 (file)
index 0000000..26997df
--- /dev/null
@@ -0,0 +1,90 @@
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "ngraph/op/swish.hpp"
+#include "ngraph/runtime/host_tensor.hpp"
+#include "ngraph/validation_util.hpp"
+#include "runtime/backend.hpp"
+#include "util/test_tools.hpp"
+
+using namespace std;
+using namespace ngraph;
+
+TEST(op_eval, swish_with_beta1)
+{
+    auto p = make_shared<op::Parameter>(element::f32, Shape{3});
+    auto beta = make_shared<op::Parameter>(element::f32, Shape{});
+    auto swish = make_shared<op::v4::Swish>(p, beta);
+    auto fun = make_shared<Function>(OutputVector{swish}, ParameterVector{p, beta});
+
+    std::vector<float> inputs{-0.5, 0.0, 0.5};
+    std::vector<float> expected_result{-0.18877034, 0.0, 0.31122968};
+
+    auto result = make_shared<HostTensor>();
+    ASSERT_TRUE(fun->evaluate({result},
+                              {make_host_tensor<element::Type_t::f32>(Shape{3}, inputs),
+                               make_host_tensor<element::Type_t::f32>(Shape{}, {1.0})}));
+    EXPECT_EQ(result->get_element_type(), element::f32);
+    EXPECT_EQ(result->get_shape(), Shape{3});
+    auto result_data = read_vector<float>(result);
+    for (auto i = 0; i < inputs.size(); i++)
+        EXPECT_NEAR(result_data[i], expected_result[i], 0.000001);
+}
+
+TEST(op_eval, swish_with_beta0_75)
+{
+    auto p = make_shared<op::Parameter>(element::f32, Shape{3});
+    auto beta = make_shared<op::Parameter>(element::f32, Shape{});
+    auto swish = make_shared<op::v4::Swish>(p, beta);
+    auto fun = make_shared<Function>(OutputVector{swish}, ParameterVector{p, beta});
+
+    std::vector<float> inputs{-0.5, 0.0, 0.5};
+    std::vector<float> expected_result{-0.2036667, 0.0, 0.2963333};
+
+    auto result = make_shared<HostTensor>();
+    ASSERT_TRUE(fun->evaluate({result},
+                              {make_host_tensor<element::Type_t::f32>(Shape{3}, inputs),
+                               make_host_tensor<element::Type_t::f32>(Shape{}, {0.75})}));
+    EXPECT_EQ(result->get_element_type(), element::f32);
+    EXPECT_EQ(result->get_shape(), Shape{3});
+    auto result_data = read_vector<float>(result);
+    for (auto i = 0; i < inputs.size(); i++)
+        EXPECT_NEAR(result_data[i], expected_result[i], 0.000001);
+}
+
+TEST(op_eval, swish_without_beta)
+{
+    auto p = make_shared<op::Parameter>(element::f32, Shape{3});
+    auto swish = make_shared<op::v4::Swish>(p);
+    auto fun = make_shared<Function>(OutputVector{swish}, ParameterVector{p});
+
+    std::vector<float> inputs{-0.5, 0.0, 0.5};
+    std::vector<float> expected_result{-0.18877034, 0.0, 0.31122968};
+
+    auto result = make_shared<HostTensor>();
+    ASSERT_TRUE(
+        fun->evaluate({result}, {make_host_tensor<element::Type_t::f32>(Shape{3}, inputs)}));
+    EXPECT_EQ(result->get_element_type(), element::f32);
+    EXPECT_EQ(result->get_shape(), Shape{3});
+    auto result_data = read_vector<float>(result);
+    for (auto i = 0; i < inputs.size(); i++)
+        EXPECT_NEAR(result_data[i], expected_result[i], 0.000001);
+}
diff --git a/ngraph/test/type_prop/swish.cpp b/ngraph/test/type_prop/swish.cpp
new file mode 100644 (file)
index 0000000..6611009
--- /dev/null
@@ -0,0 +1,95 @@
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#include "gtest/gtest.h"
+#include "ngraph/ngraph.hpp"
+#include "util/type_prop.hpp"
+
+using namespace std;
+using namespace ngraph;
+
+// Static shape: output element type and shape must match the data input exactly.
+TEST(type_prop, swish)
+{
+    auto data = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6});
+    auto swish_func = make_shared<op::v4::Swish>(data);
+    EXPECT_EQ(swish_func->get_element_type(), element::f32);
+    EXPECT_EQ(swish_func->get_shape(), data->get_output_shape(0));
+}
+
+// Dynamic dimensions and fully dynamic rank must propagate through unchanged.
+TEST(type_prop, swish_partial)
+{
+    auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 6});
+    auto swish_func = make_shared<op::v4::Swish>(data);
+    EXPECT_EQ(swish_func->get_element_type(), element::f32);
+    ASSERT_TRUE(
+        swish_func->get_output_partial_shape(0).same_scheme(data->get_output_partial_shape(0)));
+
+    // rank unknown
+    auto swish_partial = make_shared<op::v4::Swish>(
+        make_shared<op::Parameter>(element::f32, PartialShape::dynamic()));
+    ASSERT_TRUE(swish_partial->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
+}
+
+// Static rank with a dynamic dimension: rank-staticness must be preserved.
+// NOTE(review): overlaps heavily with swish_partial above — consider merging.
+TEST(type_prop, swish_partial_static_rank)
+{
+    auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 6});
+    auto swish_func = make_shared<op::v4::Swish>(data);
+    EXPECT_EQ(swish_func->get_element_type(), element::f32);
+    ASSERT_TRUE(
+        swish_func->get_output_partial_shape(0).same_scheme(data->get_output_partial_shape(0)));
+    ASSERT_TRUE(swish_func->get_output_partial_shape(0).rank().is_static());
+}
+
+// Mismatched element types between data (f32) and beta (f16) must be rejected.
+TEST(type_prop, swish_incompatible_types)
+{
+    auto data = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6});
+    auto beta = make_shared<op::Parameter>(element::f16, Shape{});
+    try
+    {
+        const auto swish_func = make_shared<op::v4::Swish>(data, beta);
+        FAIL() << "swish_func node was created with incompatible input data types.";
+    }
+    catch (const NodeValidationFailure& error)
+    {
+        EXPECT_HAS_SUBSTRING(error.what(), std::string("Swish inputs must have the same type"));
+    }
+}
+
+TEST(type_prop, swish_beta_not_scalar)
+{
+    auto data = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6});
+    auto beta = make_shared<op::Parameter>(element::f32, Shape{1});
+    try
+    {
+        const auto swish_func = make_shared<op::v4::Swish>(data, beta);
+        FAIL() << "swish_func node was created with scalar beta value.";
+    }
+    catch (const NodeValidationFailure& error)
+    {
+        EXPECT_HAS_SUBSTRING(error.what(), std::string("Swish input with beta must be scalar"));
+    }
+}
+
+// Valid two-input form: a scalar beta must be accepted and the output must
+// keep the data input's type, shape, and static rank.
+TEST(type_prop, swish_2_inputs)
+{
+    auto data = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6});
+    auto beta = make_shared<op::Parameter>(element::f32, Shape{});
+    const auto swish_func = make_shared<op::v4::Swish>(data, beta);
+
+    EXPECT_EQ(swish_func->get_element_type(), element::f32);
+    ASSERT_TRUE(swish_func->get_output_partial_shape(0).same_scheme(data->get_output_shape(0)));
+    ASSERT_TRUE(swish_func->get_output_partial_shape(0).rank().is_static());
+}