* Draft version of the Swish nGraph operation and fusing transformations for different approaches to express the operation
* Swish fusing transformation refactoring
* Added Swish operation and extractor for TF. Removed unfolding transformation for the operation.
* Added SwishIE. Implemented transformation to convert Swish to SwishIE.
* Code style fixes
* Updated Swish reference implementation. Added tests for shape and value inference
* Fixed code style for Python API
* Fixed unit test
* Apply review comments
* Use matcher_pass_callback
* Make m_alpha attribute protected in the SwishIE operation
* Fixed Swish op PythonAPI test
});
+ addSpecificCreator({"SwishIE"}, [](const std::shared_ptr<::ngraph::Node>& node,
+ const std::map<std::string, std::string> params) -> CNNLayerPtr {
+ LayerParams attrs = {node->get_friendly_name(), "Swish",
+ details::convertPrecision(node->get_output_element_type(0))};
+ auto res = std::make_shared<InferenceEngine::CNNLayer>(attrs);
+ res->params = params;
+ return res;
+
+ });
+
addSpecificCreator({"PriorBox"}, [](const std::shared_ptr<::ngraph::Node>& node,
const std::map<std::string, std::string> params) -> CNNLayerPtr {
THROW_IE_EXCEPTION << "PriorBox operation has a form that is not supported." << node->get_friendly_name()
--- /dev/null
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <memory>
+
+#include <transformations_visibility.hpp>
+
+#include "ngraph/op/op.hpp"
+
+namespace ngraph {
+namespace op {
+class TRANSFORMATIONS_API SwishIE : public Op {
+public:
+ static constexpr NodeTypeInfo type_info{"SwishIE", 1};
+ const NodeTypeInfo &get_type_info() const override { return type_info; }
+
+ explicit SwishIE(const Output<Node> &input, float alpha = 1.0);
+
+ void validate_and_infer_types() override;
+ bool visit_attributes(AttributeVisitor& visitor) override;
+ std::shared_ptr<Node> clone_with_new_inputs(const OutputVector &new_args) const override;
+
+ void set_alpha(float alpha);
+ float get_alpha() const;
+protected:
+ float m_alpha;
+};
+} // namespace op
+} // namespace ngraph
--- /dev/null
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <vector>
+#include <memory>
+#include <string>
+
+#include <transformations_visibility.hpp>
+
+#include <ngraph/pass/graph_rewrite.hpp>
+
+namespace ngraph {
+namespace pass {
+ class TRANSFORMATIONS_API ConvertSwishToSwishIEMatcher;
+} // namespace pass
+} // namespace ngraph
+
+class ngraph::pass::ConvertSwishToSwishIEMatcher: public ngraph::pass::MatcherPass {
+public:
+ ConvertSwishToSwishIEMatcher();
+};
--- /dev/null
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <memory>
+#include <utility>
+
+#include <transformations_visibility.hpp>
+#include <ngraph/pass/graph_rewrite.hpp>
+
+namespace ngraph {
+namespace pass {
+
+class TRANSFORMATIONS_API SwishFusion;
+class TRANSFORMATIONS_API SwishFusionWithSigmoid;
+class TRANSFORMATIONS_API SwishFusionWithSigmoidWithBeta;
+class TRANSFORMATIONS_API SwishFusionWithBeta;
+class TRANSFORMATIONS_API SwishFusionWithoutBeta;
+
+} // namespace pass
+} // namespace ngraph
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief SwishFusion transformation replaces various sub-graphs with a Swish op.
+ */
+class ngraph::pass::SwishFusion: public ngraph::pass::GraphRewrite {
+public:
+ SwishFusion() {
+ add_matcher<ngraph::pass::SwishFusionWithSigmoid>();
+ add_matcher<ngraph::pass::SwishFusionWithSigmoidWithBeta>();
+ add_matcher<ngraph::pass::SwishFusionWithBeta>();
+ add_matcher<ngraph::pass::SwishFusionWithoutBeta>();
+ }
+};
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief SwishFusionWithSigmoid replaces a sub-graphs x * Sigmoid(x) with a Swish op.
+ */
+ class ngraph::pass::SwishFusionWithSigmoid: public ngraph::pass::MatcherPass {
+public:
+ SwishFusionWithSigmoid();
+};
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief SwishFusionWithSigmoid replaces a sub-graphs x * Sigmoid(x * beta) with a Swish op.
+ */
+class ngraph::pass::SwishFusionWithSigmoidWithBeta: public ngraph::pass::MatcherPass {
+public:
+ SwishFusionWithSigmoidWithBeta();
+};
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief SwishFusionWithSigmoid replaces a sub-graphs x / (1.0 + exp(-x * beta)) with a Swish op.
+ */
+class ngraph::pass::SwishFusionWithBeta: public ngraph::pass::MatcherPass {
+public:
+ SwishFusionWithBeta();
+};
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief SwishFusionWithSigmoid replaces a sub-graphs x / (1.0 + exp(-x)) with a Swish op.
+ */
+class ngraph::pass::SwishFusionWithoutBeta: public ngraph::pass::MatcherPass {
+public:
+ SwishFusionWithoutBeta();
+};
--- /dev/null
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "ngraph_ops/swish_ie.hpp"
+
+#include <algorithm>
+#include <memory>
+
+#include "ngraph/util.hpp"
+#include "ngraph/validation_util.hpp"
+
+using namespace std;
+using namespace ngraph;
+
+constexpr NodeTypeInfo op::SwishIE::type_info;
+
+op::SwishIE::SwishIE(const Output<Node> & input, const float alpha)
+ : Op({input}), m_alpha(alpha) {
+ constructor_validate_and_infer_types();
+}
+
+std::shared_ptr<Node> op::SwishIE::clone_with_new_inputs(const OutputVector& new_args) const {
+ check_new_args_count(this, new_args);
+ return make_shared<SwishIE>(new_args.at(0), m_alpha);
+}
+
+bool op::SwishIE::visit_attributes(AttributeVisitor& visitor) {
+ visitor.on_attribute("alpha", m_alpha);
+ return true;
+}
+
+void op::SwishIE::validate_and_infer_types() {
+ set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
+}
+
+void op::SwishIE::set_alpha(float alpha) {
+ m_alpha = alpha;
+}
+
+float op::SwishIE::get_alpha() const {
+ return m_alpha;
+}
+
#include "transformations/init_node_info.hpp"
#include "transformations/itt.hpp"
#include "transformations/mish_fusion.hpp"
+#include "transformations/swish_fusion.hpp"
#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/nop_elimination.hpp>
manager.register_pass<ngraph::pass::ConvertScatterElementsToScatter>(); // partially depends on CF
manager.register_pass<ngraph::pass::DepthToSpaceFusion>();
manager.register_pass<ngraph::pass::MishFusion>();
+ manager.register_pass<ngraph::pass::SwishFusion>();
manager.set_callback(m_transformation_callback);
manager.run_passes(f);
#include <transformations/convert_opset1_to_legacy/convert_strided_slice_to_crop.hpp>
#include <transformations/convert_subtract.hpp>
#include <transformations/convert_opset1_to_legacy/convert_selu_to_selu_ie.hpp>
+#include <transformations/convert_opset1_to_legacy/convert_swish_to_swish_ie.hpp>
#include <transformations/convert_opset1_to_legacy/convert_tile_to_ie_tile.hpp>
#include <transformations/convert_opset1_to_legacy/convert_topk_to_topk_ie.hpp>
#include <transformations/convert_depth_to_space.hpp>
anchor->add_matcher<ngraph::pass::ConvertPReLUToReLUIE>();
anchor->add_matcher<ngraph::pass::ConvertGatherToGatherIEMatcher>();
anchor->add_matcher<ngraph::pass::ConvertSeluToSeluIEMatcher>();
+ anchor->add_matcher<ngraph::pass::ConvertSwishToSwishIEMatcher>();
anchor->add_matcher<ngraph::pass::ConvertOneHotToOneHotIEMatcher>()->detect_output_type(f);
anchor->add_matcher<ngraph::pass::ConvertGatherTreeToGatherTreeIEMatcher>();
anchor->add_matcher<ngraph::pass::ConvertTopKToTopKIEMatcher>();
--- /dev/null
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/convert_opset1_to_legacy/convert_swish_to_swish_ie.hpp"
+
+#include <memory>
+
+#include <ngraph/opsets/opset4.hpp>
+
+#include <ngraph_ops/swish_ie.hpp>
+#include <transformations/utils/utils.hpp>
+#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
+
+ngraph::pass::ConvertSwishToSwishIEMatcher::ConvertSwishToSwishIEMatcher() {
+ auto swish = ngraph::pattern::wrap_type<ngraph::opset4::Swish>();
+
+ ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
+ auto swish = std::dynamic_pointer_cast<ngraph::opset4::Swish> (m.get_match_root());
+ if (!swish) {
+ return false;
+ }
+ float beta_value = 1.0;
+ if (swish->input_values().size() == 2) {
+ auto beta_node = swish->input_value(1).get_node_shared_ptr();
+ auto beta_const = std::dynamic_pointer_cast<ngraph::opset4::Constant>(beta_node);
+
+ if (!beta_const) {
+ return false;
+ }
+ if (!ngraph::op::util::get_single_value(beta_const, beta_value)) {
+ return false;
+ }
+ }
+
+ auto swish_ie = std::make_shared<ngraph::op::SwishIE>(swish->input(0).get_source_output(), beta_value);
+ swish_ie->set_friendly_name(swish->get_friendly_name());
+ ngraph::copy_runtime_info(swish, swish_ie);
+ ngraph::replace_node(swish, swish_ie);
+ return true;
+ };
+
+ auto m = std::make_shared<ngraph::pattern::Matcher>(swish, "ConvertSwishToSwishIE");
+ this->register_matcher(m, callback);
+}
\ No newline at end of file
--- /dev/null
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/swish_fusion.hpp"
+
+#include <memory>
+
+#include <ngraph/opsets/opset4.hpp>
+#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
+
+bool check_constant_value(const std::shared_ptr<ngraph::opset4::Constant>& constant) {
+ if (!constant) {
+ return false;
+ }
+ if (constant->get_element_type() == ngraph::element::f32 || constant->get_element_type() == ngraph::element::f16) {
+ auto data = constant->cast_vector<float>();
+ if (data.size() != 1 || data[0] != 1.0) {
+ return false;
+ }
+ } else {
+ return false;
+ }
+ return true;
+}
+
+bool check_beta_value(const std::shared_ptr<ngraph::opset4::Constant>& constant) {
+ // check that the constant for beta contains only one distinct element
+ if (!constant) {
+ return false;
+ }
+ if (constant->get_element_type() == ngraph::element::f32 || constant->get_element_type() == ngraph::element::f16) {
+ auto data = constant->cast_vector<float>();
+ if (!std::equal(data.begin() + 1, data.end(), data.begin())) {
+ return false;
+ }
+ } else {
+ return false;
+ }
+ return true;
+}
+
+ngraph::pass::SwishFusionWithSigmoid::SwishFusionWithSigmoid() {
+ // replaces a sub-graphs x * Sigmoid(x) with a Swish op.
+ auto input = ngraph::pattern::any_input();
+ auto sigmoid = std::make_shared<ngraph::opset4::Sigmoid>(input);
+ auto mul = std::make_shared<ngraph::opset4::Multiply>(input, sigmoid);
+
+ ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
+ auto &pattern_to_output = m.get_pattern_value_map();
+ auto exp_input = pattern_to_output.at(input);
+
+ auto swish = std::make_shared<ngraph::opset4::Swish>(exp_input);
+
+ swish->set_friendly_name(m.get_match_root()->get_friendly_name());
+ ngraph::copy_runtime_info({pattern_to_output.at(sigmoid).get_node_shared_ptr(),
+ pattern_to_output.at(mul).get_node_shared_ptr()},
+ swish);
+ ngraph::replace_node(m.get_match_root(), swish);
+ return true;
+ };
+
+ auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "SwishWithSigmoidFusion");
+ register_matcher(m, callback);
+}
+
+ngraph::pass::SwishFusionWithSigmoidWithBeta::SwishFusionWithSigmoidWithBeta() {
+ // replaces a sub-graphs x * Sigmoid(x * beta) with a Swish op.
+ auto input = ngraph::pattern::any_input();
+ auto beta = ngraph::pattern::any_input();
+ auto mul_beta = std::make_shared<ngraph::opset4::Multiply>(input, beta);
+ auto sigmoid = std::make_shared<ngraph::opset4::Sigmoid>(mul_beta);
+ auto mul = std::make_shared<ngraph::opset4::Multiply>(input, sigmoid);
+
+ ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
+ auto &pattern_to_output = m.get_pattern_value_map();
+ auto exp_input = pattern_to_output.at(input);
+ auto beta_input = pattern_to_output.at(beta);
+
+ auto beta_constant = std::dynamic_pointer_cast<ngraph::opset4::Constant>(beta_input.get_node_shared_ptr());
+ Output<Node> new_beta;
+ if (beta_constant) {
+ if (check_beta_value(beta_constant)) {
+ new_beta = opset4::Constant::create(beta_input.get_element_type(), Shape{}, {beta_constant->cast_vector<float>()[0]});
+ } else {
+ return false;
+ }
+ } else {
+ // if the input is not constant and number of elements is not equal to 1 then we cannot perform fusing
+ if (beta_input.get_partial_shape().is_dynamic() || ngraph::shape_size(beta_input.get_shape()) != 1) {
+ return false;
+ }
+ new_beta = beta_input;
+ }
+
+ auto swish = std::make_shared<ngraph::opset4::Swish>(exp_input, new_beta);
+
+ swish->set_friendly_name(m.get_match_root()->get_friendly_name());
+ ngraph::copy_runtime_info({pattern_to_output.at(sigmoid).get_node_shared_ptr(),
+ pattern_to_output.at(mul).get_node_shared_ptr()},
+ swish);
+ ngraph::replace_node(m.get_match_root(), swish);
+ return true;
+ };
+
+ auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "SwishWithSigmoidWithBetaFusion");
+ register_matcher(m, callback);
+}
+
+ngraph::pass::SwishFusionWithBeta::SwishFusionWithBeta() {
+ // replaces a sub-graphs x / (1.0 + exp(-x * beta)) with a Swish op.
+ auto input = ngraph::pattern::any_input();
+ auto beta = ngraph::pattern::any_input();
+ auto mul = std::make_shared<ngraph::opset4::Multiply>(input, beta);
+ auto neg = std::make_shared<ngraph::opset4::Negative>(mul);
+ auto exp = std::make_shared<ngraph::opset4::Exp>(neg);
+ auto add_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
+ auto add = std::make_shared<ngraph::opset4::Add>(exp, add_constant);
+ auto div = std::make_shared<ngraph::opset4::Divide>(input, add);
+
+ ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
+ auto &pattern_to_output = m.get_pattern_value_map();
+ auto exp_input = pattern_to_output.at(input);
+
+ auto constant = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
+ if (!check_constant_value(constant)) {
+ return false;
+ }
+
+ auto swish = std::make_shared<ngraph::opset4::Swish>(exp_input, pattern_to_output.at(beta));
+
+ swish->set_friendly_name(m.get_match_root()->get_friendly_name());
+ ngraph::copy_runtime_info({pattern_to_output.at(beta).get_node_shared_ptr(),
+ pattern_to_output.at(mul).get_node_shared_ptr(),
+ pattern_to_output.at(neg).get_node_shared_ptr(),
+ pattern_to_output.at(exp).get_node_shared_ptr(),
+ pattern_to_output.at(add_constant).get_node_shared_ptr(),
+ pattern_to_output.at(add).get_node_shared_ptr(),
+ pattern_to_output.at(div).get_node_shared_ptr()},
+ swish);
+ ngraph::replace_node(m.get_match_root(), swish);
+ return true;
+ };
+
+ auto m = std::make_shared<ngraph::pattern::Matcher>(div, "SwishWithBetaFusion");
+ register_matcher(m, callback);
+}
+
+ngraph::pass::SwishFusionWithoutBeta::SwishFusionWithoutBeta() {
+ // replaces a sub-graphs x / (1.0 + exp(-x)) with a Swish op.
+ auto input = ngraph::pattern::any_input();
+ auto neg = std::make_shared<ngraph::opset4::Negative>(input);
+ auto exp = std::make_shared<ngraph::opset4::Exp>(neg);
+ auto add_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
+ auto add = std::make_shared<ngraph::opset4::Add>(exp, add_constant);
+ auto div = std::make_shared<ngraph::opset4::Divide>(input, add);
+
+ ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
+ auto & pattern_to_output = m.get_pattern_value_map();
+ auto exp_input = pattern_to_output.at(input);
+
+ auto constant = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
+ if (!check_constant_value(constant)) {
+ return false;
+ }
+
+ auto swish = std::make_shared<ngraph::opset4::Swish>(exp_input);
+
+ swish->set_friendly_name(m.get_match_root()->get_friendly_name());
+ ngraph::copy_runtime_info({pattern_to_output.at(neg).get_node_shared_ptr(),
+ pattern_to_output.at(exp).get_node_shared_ptr(),
+ pattern_to_output.at(add_constant).get_node_shared_ptr(),
+ pattern_to_output.at(add).get_node_shared_ptr(),
+ pattern_to_output.at(div).get_node_shared_ptr()},
+ swish);
+ ngraph::replace_node(m.get_match_root(), swish);
+ return true;
+ };
+
+ auto m = std::make_shared<ngraph::pattern::Matcher>(div, "SwishWithoutBetaFusion");
+ register_matcher(m, callback);
+}
#include <string>
#include <memory>
-#include <queue>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/pass/manager.hpp>
-#include <ngraph/pass/visualize_tree.hpp>
#include <transformations/mish_fusion.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
--- /dev/null
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include <string>
+#include <memory>
+
+#include <ngraph/function.hpp>
+#include <ngraph/opsets/opset4.hpp>
+#include <ngraph/pass/manager.hpp>
+#include <transformations/swish_fusion.hpp>
+#include <transformations/init_node_info.hpp>
+#include <transformations/utils/utils.hpp>
+
+#include "common_test_utils/ngraph_test_utils.hpp"
+
+using namespace testing;
+
+TEST(TransformationTests, SwishFusionWithBeta) {
+ std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+ {
+ auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(1));
+ auto beta = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{});
+ auto mul = std::make_shared<ngraph::opset4::Multiply>(input, beta);
+ auto neg = std::make_shared<ngraph::opset4::Negative>(mul);
+ auto exp = std::make_shared<ngraph::opset4::Exp>(neg);
+ auto constant = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1.0});
+ auto add = std::make_shared<ngraph::opset4::Add>(exp, constant);
+ auto div = std::make_shared<ngraph::opset4::Divide>(input, add);
+
+ f = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input, beta});
+
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::InitNodeInfo>();
+ manager.register_pass<ngraph::pass::SwishFusion>();
+ manager.run_passes(f);
+ ASSERT_NO_THROW(check_rt_info(f));
+ }
+
+ {
+ auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(1));
+ auto beta = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{});
+ auto swish = std::make_shared<ngraph::opset4::Swish>(input, beta);
+
+ f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{swish}, ngraph::ParameterVector{input, beta});
+ }
+
+ auto res = compare_functions(f, f_ref);
+ ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, SwishFusionWithoutBeta) {
+ std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+ {
+ auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+ auto neg = std::make_shared<ngraph::opset4::Negative>(input);
+ auto exp = std::make_shared<ngraph::opset4::Exp>(neg);
+ auto constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {1.0});
+ auto add = std::make_shared<ngraph::opset4::Add>(exp, constant);
+ auto div = std::make_shared<ngraph::opset4::Divide>(input, add);
+
+ f = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});
+
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::InitNodeInfo>();
+ manager.register_pass<ngraph::pass::SwishFusion>();
+ manager.run_passes(f);
+ ASSERT_NO_THROW(check_rt_info(f));
+ }
+
+ {
+ auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+ auto swish = std::make_shared<ngraph::opset4::Swish>(input);
+
+ f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{swish}, ngraph::ParameterVector{input});
+ }
+
+ auto res = compare_functions(f, f_ref);
+ ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, SwishFusionWithoutBetaNonOneAddConstant) {
+ std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+ {
+ auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+ auto neg = std::make_shared<ngraph::opset4::Negative>(input);
+ auto exp = std::make_shared<ngraph::opset4::Exp>(neg);
+ auto constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {1.1});
+ auto add = std::make_shared<ngraph::opset4::Add>(exp, constant);
+ auto div = std::make_shared<ngraph::opset4::Divide>(input, add);
+
+ f = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});
+
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::InitNodeInfo>();
+ manager.register_pass<ngraph::pass::SwishFusion>();
+ manager.run_passes(f);
+ ASSERT_NO_THROW(check_rt_info(f));
+ }
+
+ {
+ auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+ auto neg = std::make_shared<ngraph::opset4::Negative>(input);
+ auto exp = std::make_shared<ngraph::opset4::Exp>(neg);
+ auto constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {1.1});
+ auto add = std::make_shared<ngraph::opset4::Add>(exp, constant);
+ auto div = std::make_shared<ngraph::opset4::Divide>(input, add);
+
+ f = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});
+
+ f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});
+ }
+
+ auto res = compare_functions(f, f_ref);
+ ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, SwishFusionWithSigmoid) {
+ std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+ {
+ auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+ auto sig = std::make_shared<ngraph::opset4::Sigmoid>(input);
+ auto mul = std::make_shared<ngraph::opset4::Multiply>(input, sig);
+
+ f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});
+
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::InitNodeInfo>();
+ manager.register_pass<ngraph::pass::SwishFusion>();
+ manager.run_passes(f);
+ ASSERT_NO_THROW(check_rt_info(f));
+ }
+
+ {
+ auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+ auto swish = std::make_shared<ngraph::opset4::Swish>(input);
+
+ f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{swish}, ngraph::ParameterVector{input});
+ }
+
+ auto res = compare_functions(f, f_ref);
+ ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, SwishFusionWithSigmoidWithBeta) {
+ std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+ {
+ auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+ auto beta = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{});
+ auto mul_beta = std::make_shared<ngraph::opset4::Multiply>(input, beta);
+ auto sig = std::make_shared<ngraph::opset4::Sigmoid>(mul_beta);
+ auto mul = std::make_shared<ngraph::opset4::Multiply>(input, sig);
+
+ f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input, beta});
+
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::InitNodeInfo>();
+ manager.register_pass<ngraph::pass::SwishFusion>();
+ manager.run_passes(f);
+ ASSERT_NO_THROW(check_rt_info(f));
+ }
+
+ {
+ auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+ auto beta = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{});
+ auto swish = std::make_shared<ngraph::opset4::Swish>(input, beta);
+
+ f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{swish}, ngraph::ParameterVector{input, beta});
+ }
+
+ auto res = compare_functions(f, f_ref);
+ ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, SwishFusionWithSigmoidWithBetaConstant) {
+ // test where the beta constant has multiple but the same value
+ std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+ {
+ auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+ auto beta = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{3}, {2.0, 2.0, 2.0});
+ auto mul_beta = std::make_shared<ngraph::opset4::Multiply>(input, beta);
+ auto sig = std::make_shared<ngraph::opset4::Sigmoid>(mul_beta);
+ auto mul = std::make_shared<ngraph::opset4::Multiply>(input, sig);
+
+ f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});
+
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::InitNodeInfo>();
+ manager.register_pass<ngraph::pass::SwishFusion>();
+ manager.run_passes(f);
+ ASSERT_NO_THROW(check_rt_info(f));
+ }
+
+ {
+ auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
+ auto beta = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {2.0});
+ auto swish = std::make_shared<ngraph::opset4::Swish>(input, beta);
+
+ f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{swish}, ngraph::ParameterVector{input});
+ }
+
+ auto res = compare_functions(f, f_ref);
+ ASSERT_TRUE(res.first) << res.second;
+}
extensions/front/tf/ssd_v2_support.json
extensions/front/tf/SSDToolboxDetectionOutput.py
extensions/front/tf/swap_deconv_inputs.py
-extensions/front/tf/swish.py
+extensions/front/tf/swish_ext.py
extensions/front/tf/SwitchMergeOptimization.py
extensions/front/tf/TensorArrayExtractors.py
extensions/front/tf/TensorArrayGatherV3.py
+++ /dev/null
-"""
- Copyright (C) 2018-2020 Intel Corporation
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-"""
-
-from extensions.ops.activation_ops import Sigmoid
-from extensions.ops.elementwise import Mul
-from mo.front.common.replacement import FrontReplacementOp
-from mo.graph.graph import Node, Graph
-
-
-class Swish(FrontReplacementOp):
- op = "swish_f32"
- enabled = True
-
- def replace_op(self, graph: Graph, node: Node):
- mul_node = Mul(graph, {'name': node.name + '/mul_'}).create_node()
- sigmoid_node = Sigmoid(graph, {'name': node.name + '/sigmoid_'}).create_node()
-
- # Connect nodes
- node.in_port(0).get_connection().get_source().connect(mul_node.in_port(0))
- node.in_port(0).get_connection().get_source().connect(sigmoid_node.in_port(0))
- sigmoid_node.out_port(0).connect(mul_node.in_port(1))
-
- # The "explicit" version of the return value is: [(out_node.id, 0)])
- return [mul_node.id]
--- /dev/null
+"""
+ Copyright (C) 2018-2020 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+from extensions.ops.activation_ops import Swish
+from mo.front.extractor import FrontExtractorOp
+from mo.graph.graph import Node
+
+
+class SwishExtractor(FrontExtractorOp):
+ op = 'swish_f32'
+ enabled = True
+
+ @classmethod
+ def extract(cls, node: Node):
+ Swish.update_node_stat(node, {})
+ return cls.enabled
+
+++ /dev/null
-"""
- Copyright (C) 2018-2020 Intel Corporation
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-"""
-
-import unittest
-
-import numpy as np
-
-from extensions.front.tf.swish import Swish
-from mo.utils.ir_engine.compare_graphs import compare_graphs
-from mo.utils.unittest.graph import build_graph
-
-nodes_attributes = {
- 'placeholder_1': {'shape': np.array([1, 227, 227, 3]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
- 'placeholder_2': {'shape': np.array([1, 227, 227, 3]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
- # swish operation
- 'swish': {'kind': 'op', 'op': 'swish_f32'},
- # Test operation
- 'last': {'type': None, 'value': None, 'kind': 'op', 'op': None},
- # Add and Mul operations
- 'mul': {'type': 'Multiply', 'kind': 'op', 'op': 'Mul'},
- 'sigmoid': {'value': None, 'type': 'Sigmoid', 'kind': 'op', 'op': 'Sigmoid'},
-}
-
-
-class TestSwish(unittest.TestCase):
- def test_swish_test_1(self):
- # Test with two different inputs from two placeholders
- graph = build_graph(nodes_attributes,
- [('placeholder_1', 'swish'),
- ('swish', 'last')
- ], nodes_with_edges_only=True)
-
- graph_ref = build_graph(nodes_attributes,
- [('placeholder_1', 'sigmoid', {'out': 0}),
- ('placeholder_1', 'mul', {'in': 0, 'out': 0}),
- ('sigmoid', 'mul', {'in': 1}),
- ('mul', 'last'),
- ], nodes_with_edges_only=True)
-
- graph.stage = 'front'
- Swish().find_and_replace_pattern(graph)
-
- (flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
- self.assertTrue(flag, resp)
def __init__(self, graph: Graph, attrs: dict):
super().__init__(graph, {
- 'type': __class__.op,
- 'op': __class__.op,
- 'infer': __class__.infer,
+ 'type': self.op,
+ 'op': self.op,
+ 'infer': self.infer,
'in_ports_count': 1,
'out_ports_count': 1,
}, attrs)
sp_attrs = {'version': 'opset4'}
sp_attrs.update(attrs)
super().__init__(graph, sp_attrs)
+
+
+class Swish(Op):
+ op = 'Swish'
+
+ def __init__(self, graph: Graph, attrs: dict):
+ mandatory_props = {
+ 'op': self.op,
+ 'type': self.op,
+ 'version': 'opset4',
+
+ 'infer': self.infer,
+
+ 'in_ports_count': 2,
+ 'out_ports_count': 1,
+ }
+ super().__init__(graph, mandatory_props, attrs)
+
+ @staticmethod
+ def infer(node: Node):
+ node_name = node.soft_get('name', node.id)
+ node.out_port(0).data.set_shape(node.in_port(0).data.get_shape())
+
+ beta = 1.0
+ if node.is_in_port_connected(1):
+ beta = node.in_port(1).data.get_value()
+ if beta is not None:
+ assert beta.ndim == 0, 'The "beta" value for node {} must be a scalar'.format(node_name)
+ beta = beta.item()
+
+ input_value = node.in_port(1).data.get_value()
+ if input_value is not None and beta is not None:
+ node.out_port(0).data.set_value(input_value / (1.0 + np.exp(-input_value * beta)))
from ngraph.opset4 import squeeze
from ngraph.opset4 import strided_slice
from ngraph.opset4 import subtract
+from ngraph.opset4 import swish
from ngraph.opset4 import tan
from ngraph.opset4 import tanh
from ngraph.opset4 import tensor_iterator
from ngraph.opset1.ops import squeeze
from ngraph.opset1.ops import strided_slice
from ngraph.opset1.ops import subtract
+from ngraph.opset4.ops import swish
from ngraph.opset1.ops import tan
from ngraph.opset1.ops import tanh
from ngraph.opset1.ops import tensor_iterator
:return: The new node which performs Mish
"""
return _get_node_factory_opset4().create("Mish", as_nodes(data), {})
+
+
+@nameable_op
+def swish(
+ data: NodeInput,
+ beta: Optional[NodeInput] = None,
+ name: Optional[str] = None,
+) -> Node:
+ """Return a node which performing Swish activation function Swish(x, beta=1.0) = x * sigmoid(x * beta)).
+
+ :param data: Tensor with input data floating point type.
+ :return: The new node which performs Swish
+ """
+ if beta is None:
+ beta = make_constant_node(1.0, np.float32)
+ return _get_node_factory_opset4().create("Swish", as_nodes(data, beta), {})
--- /dev/null
+# ******************************************************************************
+# Copyright 2017-2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ******************************************************************************
+import numpy as np
+import ngraph as ng
+from ngraph.impl import Shape, Type
+
+
+def test_swish_props_with_beta():
+ float_dtype = np.float32
+ data = ng.parameter(Shape([3, 10]), dtype=float_dtype, name="data")
+ beta = ng.parameter(Shape([]), dtype=float_dtype, name="beta")
+
+ node = ng.swish(data, beta)
+ assert node.get_type_name() == "Swish"
+ assert node.get_output_size() == 1
+ assert list(node.get_output_shape(0)) == [3, 10]
+ assert node.get_output_element_type(0) == Type.f32
+
+
+def test_swish_props_without_beta():
+ float_dtype = np.float32
+ data = ng.parameter(Shape([3, 10]), dtype=float_dtype, name="data")
+
+ node = ng.swish(data)
+ assert node.get_type_name() == "Swish"
+ assert node.get_output_size() == 1
+ assert list(node.get_output_shape(0)) == [3, 10]
+ assert node.get_output_element_type(0) == Type.f32
op/subtract.hpp
op/sum.cpp
op/sum.hpp
+ op/swish.cpp
+ op/swish.hpp
op/variadic_split.cpp
op/variadic_split.hpp
op/tan.cpp
--- /dev/null
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#include "ngraph/op/swish.hpp"
+#include "ngraph/attribute_visitor.hpp"
+#include "ngraph/op/constant.hpp"
+
+#include "ngraph/runtime/host_tensor.hpp"
+#include "ngraph/runtime/reference/swish.hpp"
+
+using namespace std;
+using namespace ngraph;
+
+constexpr NodeTypeInfo op::v4::Swish::type_info;
+
+op::v4::Swish::Swish(const Output<Node>& arg)
+ : Op({arg})
+{
+ constructor_validate_and_infer_types();
+}
+
+op::v4::Swish::Swish(const Output<Node>& arg, const Output<Node>& beta)
+ : Op({arg, beta})
+{
+ constructor_validate_and_infer_types();
+}
+
+bool op::v4::Swish::visit_attributes(AttributeVisitor& visitor)
+{
+ return true;
+}
+
+void op::v4::Swish::validate_and_infer_types()
+{
+ auto inputs_count = input_values().size();
+ NODE_VALIDATION_CHECK(this,
+ inputs_count == 1 || inputs_count == 2,
+ "Swish must have 1 or 2 inputs, but it has: ",
+ inputs_count);
+
+ if (inputs_count == 2)
+ {
+ NODE_VALIDATION_CHECK(this,
+ input_value(0).get_element_type() ==
+ input_value(1).get_element_type(),
+ "Swish inputs must have the same type but they are: ",
+ input_value(0).get_element_type(),
+ " and ",
+ input_value(1).get_element_type());
+ if (get_input_partial_shape(1).rank().is_static())
+ {
+ auto beta_rank = get_input_partial_shape(1).rank().get_length();
+ NODE_VALIDATION_CHECK(this,
+ beta_rank == 0,
+ "Swish input with beta must be scalar but it has rank: ",
+ beta_rank);
+ }
+ }
+ set_output_size(1);
+ set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
+}
+
+shared_ptr<Node> op::v4::Swish::clone_with_new_inputs(const OutputVector& new_args) const
+{
+ if (new_args.size() == 1)
+ {
+ return make_shared<op::v4::Swish>(new_args.at(0));
+ }
+ else
+ {
+ return make_shared<op::v4::Swish>(new_args.at(0), new_args.at(1));
+ }
+}
+
+namespace
+{
+ template <element::Type_t ET>
+ inline bool evaluate(const HostTensorPtr& arg0,
+ const HostTensorPtr& arg1,
+ const HostTensorPtr& out,
+ const size_t count)
+ {
+ using T = typename element_type_traits<ET>::value_type;
+ if (arg1 != nullptr)
+ {
+ runtime::reference::swish<T>(
+ arg0->get_data_ptr<ET>(), arg1->get_data_ptr<ET>(), out->get_data_ptr<ET>(), count);
+ }
+ else
+ {
+ runtime::reference::swish<T>(
+ arg0->get_data_ptr<ET>(), nullptr, out->get_data_ptr<ET>(), count);
+ }
+ return true;
+ }
+
+ bool evaluate_swish(const HostTensorPtr& arg0,
+ const HostTensorPtr& arg1,
+ const HostTensorPtr& out,
+ const size_t count)
+ {
+ bool rc = true;
+ out->set_unary(arg0);
+
+ switch (arg0->get_element_type())
+ {
+ TYPE_CASE(f16)(arg0, arg1, out, count);
+ break;
+ TYPE_CASE(f32)(arg0, arg1, out, count);
+ break;
+ default: rc = false; break;
+ }
+ return rc;
+ }
+}
+
+bool op::v4::Swish::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const
+{
+ if (inputs.size() == 2)
+ {
+ return evaluate_swish(inputs[0], inputs[1], outputs[0], shape_size(get_output_shape(0)));
+ }
+ else
+ {
+ return evaluate_swish(inputs[0], nullptr, outputs[0], shape_size(get_output_shape(0)));
+ }
+}
--- /dev/null
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#pragma once
+
+#include "ngraph/node.hpp"
+#include "ngraph/op/op.hpp"
+
+namespace ngraph
+{
+ namespace op
+ {
+ namespace v4
+ {
+ /// \brief A Swish Activation Function
+ /// f(x) = x / (1.0 + exp(-beta * x)) or
+ /// f(x) = x * sigmoid(beta * x)
+ ///
+ class NGRAPH_API Swish : public ngraph::op::Op
+ {
+ public:
+ static constexpr NodeTypeInfo type_info{"Swish", 4};
+ const NodeTypeInfo& get_type_info() const override { return type_info; }
+ Swish() = default;
+
+ /// \brief Constructs an Swish operation.
+ ///
+ /// \param data Input tensor
+ /// \param beta Scalar with beta value. If the argument is not specified then use
+ /// the default value 1.0
+ Swish(const Output<Node>& arg, const Output<Node>& beta);
+ explicit Swish(const Output<Node>& arg);
+
+ bool visit_attributes(AttributeVisitor& visitor) override;
+ void validate_and_infer_types() override;
+
+ virtual std::shared_ptr<Node>
+ clone_with_new_inputs(const OutputVector& new_args) const override;
+ bool evaluate(const HostTensorVector& outputs,
+ const HostTensorVector& inputs) const override;
+ };
+ }
+ }
+}
#include "ngraph/op/strided_slice.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
+#include "ngraph/op/swish.hpp"
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/tensor_iterator.hpp"
NGRAPH_OP(Acosh, ngraph::op::v3)
NGRAPH_OP(Asinh, ngraph::op::v3)
NGRAPH_OP(Atanh, ngraph::op::v3)
+NGRAPH_OP(CTCLoss, ngraph::op::v4)
NGRAPH_OP(NonMaxSuppression, ngraph::op::v4)
NGRAPH_OP(Mish, ngraph::op::v4)
-NGRAPH_OP(CTCLoss, ngraph::op::v4)
+NGRAPH_OP(Swish, ngraph::op::v4)
--- /dev/null
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#pragma once
+
+#include <cmath>
+#include <cstddef>
+
+namespace ngraph
+{
+ namespace runtime
+ {
+ namespace reference
+ {
+ template <typename T>
+ void swish(const T* arg, const T* beta, T* out, size_t count)
+ {
+ T beta_value = static_cast<T>(1.0);
+ if (beta != nullptr)
+ {
+ beta_value = beta[0];
+ }
+ for (size_t i = 0; i < count; i++)
+ {
+ out[i] = arg[i] / (1.0 + std::exp(-arg[i] * beta_value));
+ }
+ }
+ }
+ }
+}
op_eval/non_zero.cpp
op_eval/split.cpp
op_eval/strided_slice.cpp
+ op_eval/swish.cpp
op_is.cpp
opset1.cpp
partial_shape.cpp
type_prop/squared_difference.cpp
type_prop/squeeze.cpp
type_prop/sum.cpp
+ type_prop/swish.cpp
type_prop/reduce_prod.cpp
type_prop/reduce_sum.cpp
type_prop/tile.cpp
--- /dev/null
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "ngraph/op/swish.hpp"
+#include "ngraph/runtime/host_tensor.hpp"
+#include "ngraph/validation_util.hpp"
+#include "runtime/backend.hpp"
+#include "util/test_tools.hpp"
+
+using namespace std;
+using namespace ngraph;
+
+TEST(op_eval, swish_with_beta1)
+{
+ auto p = make_shared<op::Parameter>(element::f32, Shape{3});
+ auto beta = make_shared<op::Parameter>(element::f32, Shape{});
+ auto swish = make_shared<op::v4::Swish>(p, beta);
+ auto fun = make_shared<Function>(OutputVector{swish}, ParameterVector{p, beta});
+
+ std::vector<float> inputs{-0.5, 0.0, 0.5};
+ std::vector<float> expected_result{-0.18877034, 0.0, 0.31122968};
+
+ auto result = make_shared<HostTensor>();
+ ASSERT_TRUE(fun->evaluate({result},
+ {make_host_tensor<element::Type_t::f32>(Shape{3}, inputs),
+ make_host_tensor<element::Type_t::f32>(Shape{}, {1.0})}));
+ EXPECT_EQ(result->get_element_type(), element::f32);
+ EXPECT_EQ(result->get_shape(), Shape{3});
+ auto result_data = read_vector<float>(result);
+ for (auto i = 0; i < inputs.size(); i++)
+ EXPECT_NEAR(result_data[i], expected_result[i], 0.000001);
+}
+
+TEST(op_eval, swish_with_beta0_75)
+{
+ auto p = make_shared<op::Parameter>(element::f32, Shape{3});
+ auto beta = make_shared<op::Parameter>(element::f32, Shape{});
+ auto swish = make_shared<op::v4::Swish>(p, beta);
+ auto fun = make_shared<Function>(OutputVector{swish}, ParameterVector{p, beta});
+
+ std::vector<float> inputs{-0.5, 0.0, 0.5};
+ std::vector<float> expected_result{-0.2036667, 0.0, 0.2963333};
+
+ auto result = make_shared<HostTensor>();
+ ASSERT_TRUE(fun->evaluate({result},
+ {make_host_tensor<element::Type_t::f32>(Shape{3}, inputs),
+ make_host_tensor<element::Type_t::f32>(Shape{}, {0.75})}));
+ EXPECT_EQ(result->get_element_type(), element::f32);
+ EXPECT_EQ(result->get_shape(), Shape{3});
+ auto result_data = read_vector<float>(result);
+ for (auto i = 0; i < inputs.size(); i++)
+ EXPECT_NEAR(result_data[i], expected_result[i], 0.000001);
+}
+
+TEST(op_eval, swish_without_beta)
+{
+ auto p = make_shared<op::Parameter>(element::f32, Shape{3});
+ auto swish = make_shared<op::v4::Swish>(p);
+ auto fun = make_shared<Function>(OutputVector{swish}, ParameterVector{p});
+
+ std::vector<float> inputs{-0.5, 0.0, 0.5};
+ std::vector<float> expected_result{-0.18877034, 0.0, 0.31122968};
+
+ auto result = make_shared<HostTensor>();
+ ASSERT_TRUE(
+ fun->evaluate({result}, {make_host_tensor<element::Type_t::f32>(Shape{3}, inputs)}));
+ EXPECT_EQ(result->get_element_type(), element::f32);
+ EXPECT_EQ(result->get_shape(), Shape{3});
+ auto result_data = read_vector<float>(result);
+ for (auto i = 0; i < inputs.size(); i++)
+ EXPECT_NEAR(result_data[i], expected_result[i], 0.000001);
+}
--- /dev/null
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#include "gtest/gtest.h"
+#include "ngraph/ngraph.hpp"
+#include "util/type_prop.hpp"
+
+using namespace std;
+using namespace ngraph;
+
+TEST(type_prop, swish)
+{
+ auto data = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6});
+ auto swish_func = make_shared<op::v4::Swish>(data);
+ EXPECT_EQ(swish_func->get_element_type(), element::f32);
+ EXPECT_EQ(swish_func->get_shape(), data->get_output_shape(0));
+}
+
+TEST(type_prop, swish_partial)
+{
+ auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 6});
+ auto swish_func = make_shared<op::v4::Swish>(data);
+ EXPECT_EQ(swish_func->get_element_type(), element::f32);
+ ASSERT_TRUE(
+ swish_func->get_output_partial_shape(0).same_scheme(data->get_output_partial_shape(0)));
+
+ // rank unknown
+ auto swish_partial = make_shared<op::v4::Swish>(
+ make_shared<op::Parameter>(element::f32, PartialShape::dynamic()));
+ ASSERT_TRUE(swish_partial->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
+}
+
+TEST(type_prop, swish_partial_static_rank)
+{
+ auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 6});
+ auto swish_func = make_shared<op::v4::Swish>(data);
+ EXPECT_EQ(swish_func->get_element_type(), element::f32);
+ ASSERT_TRUE(
+ swish_func->get_output_partial_shape(0).same_scheme(data->get_output_partial_shape(0)));
+ ASSERT_TRUE(swish_func->get_output_partial_shape(0).rank().is_static());
+}
+
+TEST(type_prop, swish_incompatible_types)
+{
+ auto data = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6});
+ auto beta = make_shared<op::Parameter>(element::f16, Shape{});
+ try
+ {
+ const auto swish_func = make_shared<op::v4::Swish>(data, beta);
+ FAIL() << "swish_func node was created with incompatible input data types.";
+ }
+ catch (const NodeValidationFailure& error)
+ {
+ EXPECT_HAS_SUBSTRING(error.what(), std::string("Swish inputs must have the same type"));
+ }
+}
+
+TEST(type_prop, swish_beta_not_scalar)
+{
+ auto data = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6});
+ auto beta = make_shared<op::Parameter>(element::f32, Shape{1});
+ try
+ {
+ const auto swish_func = make_shared<op::v4::Swish>(data, beta);
+ FAIL() << "swish_func node was created with scalar beta value.";
+ }
+ catch (const NodeValidationFailure& error)
+ {
+ EXPECT_HAS_SUBSTRING(error.what(), std::string("Swish input with beta must be scalar"));
+ }
+}
+
+TEST(type_prop, swish_2_inputs)
+{
+ auto data = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6});
+ auto beta = make_shared<op::Parameter>(element::f32, Shape{});
+ const auto swish_func = make_shared<op::v4::Swish>(data, beta);
+
+ EXPECT_EQ(swish_func->get_element_type(), element::f32);
+ ASSERT_TRUE(swish_func->get_output_partial_shape(0).same_scheme(data->get_output_shape(0)));
+ ASSERT_TRUE(swish_func->get_output_partial_shape(0).rank().is_static());
+}