From: Evgeny Lazarev
Date: Mon, 10 Aug 2020 12:51:21 +0000 (+0300)
Subject: Enable swish (#1682)
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=318d38770b7fa2e0bb9ae9be231ce6bd94f5ca12;p=platform%2Fupstream%2Fdldt.git

Enable swish (#1682)

* Draft version of the Swish nGraph operation and fusing transformations for different approaches to express the operation
* Swish fusing transformation refactoring
* Added Swish operation and extractor for TF. Removed unfolding transformation for the operation.
* Added SwishIE. Implemented transformation to convert Swish to SwishIE.
* Code style fixes
* Updated Swish reference implementation. Added tests for shape and value inference
* Fixed code style for Python API
* Fixed unit test
* Apply review comments
* Use matcher_pass_callback
* Make m_alpha attribute protected in the SwishIE operation
* Fixed Swish op PythonAPI test
---

diff --git a/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp b/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp
index 1e386b5..7c80fef 100644
--- a/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp
+++ b/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp
@@ -496,6 +496,16 @@ InferenceEngine::details::CNNLayerCreator::CNNLayerCreator(const std::shared_ptr
         });
 
+    addSpecificCreator({"SwishIE"}, [](const std::shared_ptr<::ngraph::Node>& node,
+                                       const std::map<std::string, std::string> params) -> CNNLayerPtr {
+        LayerParams attrs = {node->get_friendly_name(), "Swish",
+                             details::convertPrecision(node->get_output_element_type(0))};
+        auto res = std::make_shared<InferenceEngine::CNNLayer>(attrs);
+        res->params = params;
+        return res;
+
+    });
+
     addSpecificCreator({"PriorBox"}, [](const std::shared_ptr<::ngraph::Node>& node,
                                         const std::map<std::string, std::string> params) -> CNNLayerPtr {
         THROW_IE_EXCEPTION << "PriorBox operation has a form that is not supported."
                            << node->get_friendly_name()
diff --git a/inference-engine/src/transformations/include/ngraph_ops/swish_ie.hpp b/inference-engine/src/transformations/include/ngraph_ops/swish_ie.hpp
new file mode 100644
index 0000000..4434ad0
--- /dev/null
+++ b/inference-engine/src/transformations/include/ngraph_ops/swish_ie.hpp
@@ -0,0 +1,32 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include
+
+#include
+
+#include "ngraph/op/op.hpp"
+
+namespace ngraph {
+namespace op {
+class TRANSFORMATIONS_API SwishIE : public Op {
+public:
+    static constexpr NodeTypeInfo type_info{"SwishIE", 1};
+    const NodeTypeInfo &get_type_info() const override { return type_info; }
+
+    explicit SwishIE(const Output<Node> &input, float alpha = 1.0);
+
+    void validate_and_infer_types() override;
+    bool visit_attributes(AttributeVisitor& visitor) override;
+    std::shared_ptr<Node> clone_with_new_inputs(const OutputVector &new_args) const override;
+
+    void set_alpha(float alpha);
+    float get_alpha() const;
+protected:
+    float m_alpha;
+};
+} // namespace op
+} // namespace ngraph
diff --git a/inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_swish_to_swish_ie.hpp b/inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_swish_to_swish_ie.hpp
new file mode 100644
index 0000000..2d60227
--- /dev/null
+++ b/inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_swish_to_swish_ie.hpp
@@ -0,0 +1,24 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include
+#include
+#include
+
+#include
+
+#include
+
+namespace ngraph {
+namespace pass {
+    class TRANSFORMATIONS_API ConvertSwishToSwishIEMatcher;
+} // namespace pass
+} // namespace ngraph
+
+class ngraph::pass::ConvertSwishToSwishIEMatcher: public ngraph::pass::MatcherPass {
+public:
+    ConvertSwishToSwishIEMatcher();
+};
diff --git a/inference-engine/src/transformations/include/transformations/swish_fusion.hpp b/inference-engine/src/transformations/include/transformations/swish_fusion.hpp
new file mode 100644
index 0000000..d531e78
--- /dev/null
+++ b/inference-engine/src/transformations/include/transformations/swish_fusion.hpp
@@ -0,0 +1,73 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include
+#include
+
+#include
+#include
+
+namespace ngraph {
+namespace pass {
+
+class TRANSFORMATIONS_API SwishFusion;
+class TRANSFORMATIONS_API SwishFusionWithSigmoid;
+class TRANSFORMATIONS_API SwishFusionWithSigmoidWithBeta;
+class TRANSFORMATIONS_API SwishFusionWithBeta;
+class TRANSFORMATIONS_API SwishFusionWithoutBeta;
+
+} // namespace pass
+} // namespace ngraph
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief SwishFusion transformation replaces various sub-graphs with a Swish op.
+ */
+class ngraph::pass::SwishFusion: public ngraph::pass::GraphRewrite {
+public:
+    SwishFusion() {
+        add_matcher<ngraph::pass::SwishFusionWithSigmoid>();
+        add_matcher<ngraph::pass::SwishFusionWithSigmoidWithBeta>();
+        add_matcher<ngraph::pass::SwishFusionWithBeta>();
+        add_matcher<ngraph::pass::SwishFusionWithoutBeta>();
+    }
+};
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief SwishFusionWithSigmoid replaces the sub-graph x * Sigmoid(x) with a Swish op.
+ */
+class ngraph::pass::SwishFusionWithSigmoid: public ngraph::pass::MatcherPass {
+public:
+    SwishFusionWithSigmoid();
+};
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief SwishFusionWithSigmoidWithBeta replaces the sub-graph x * Sigmoid(x * beta) with a Swish op.
+ */ +class ngraph::pass::SwishFusionWithSigmoidWithBeta: public ngraph::pass::MatcherPass { +public: + SwishFusionWithSigmoidWithBeta(); +}; + +/** + * @ingroup ie_transformation_common_api + * @brief SwishFusionWithSigmoid replaces a sub-graphs x / (1.0 + exp(-x * beta)) with a Swish op. + */ +class ngraph::pass::SwishFusionWithBeta: public ngraph::pass::MatcherPass { +public: + SwishFusionWithBeta(); +}; + +/** + * @ingroup ie_transformation_common_api + * @brief SwishFusionWithSigmoid replaces a sub-graphs x / (1.0 + exp(-x)) with a Swish op. + */ +class ngraph::pass::SwishFusionWithoutBeta: public ngraph::pass::MatcherPass { +public: + SwishFusionWithoutBeta(); +}; diff --git a/inference-engine/src/transformations/src/ngraph_ops/swish_ie.cpp b/inference-engine/src/transformations/src/ngraph_ops/swish_ie.cpp new file mode 100644 index 0000000..cd04251 --- /dev/null +++ b/inference-engine/src/transformations/src/ngraph_ops/swish_ie.cpp @@ -0,0 +1,44 @@ +// Copyright (C) 2018-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "ngraph_ops/swish_ie.hpp" + +#include +#include + +#include "ngraph/util.hpp" +#include "ngraph/validation_util.hpp" + +using namespace std; +using namespace ngraph; + +constexpr NodeTypeInfo op::SwishIE::type_info; + +op::SwishIE::SwishIE(const Output & input, const float alpha) + : Op({input}), m_alpha(alpha) { + constructor_validate_and_infer_types(); +} + +std::shared_ptr op::SwishIE::clone_with_new_inputs(const OutputVector& new_args) const { + check_new_args_count(this, new_args); + return make_shared(new_args.at(0), m_alpha); +} + +bool op::SwishIE::visit_attributes(AttributeVisitor& visitor) { + visitor.on_attribute("alpha", m_alpha); + return true; +} + +void op::SwishIE::validate_and_infer_types() { + set_output_type(0, get_input_element_type(0), get_input_partial_shape(0)); +} + +void op::SwishIE::set_alpha(float alpha) { + m_alpha = alpha; +} + +float op::SwishIE::get_alpha() const { + return m_alpha; +} + diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp index 78cee29..eb5ee57 100644 --- a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp +++ b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp @@ -12,6 +12,7 @@ #include "transformations/init_node_info.hpp" #include "transformations/itt.hpp" #include "transformations/mish_fusion.hpp" +#include "transformations/swish_fusion.hpp" #include #include @@ -34,6 +35,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr(); // partially depends on CF manager.register_pass(); manager.register_pass(); + manager.register_pass(); manager.set_callback(m_transformation_callback); manager.run_passes(f); diff --git a/inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.cpp b/inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.cpp index 12a7766..c8226c8 100644 --- a/inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.cpp +++ b/inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.cpp @@ -32,6 +32,7 @@ #include #include #include +#include #include #include #include @@ -129,6 +130,7 @@ bool 
ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptradd_matcher(); anchor->add_matcher(); anchor->add_matcher(); + anchor->add_matcher(); anchor->add_matcher()->detect_output_type(f); anchor->add_matcher(); anchor->add_matcher(); diff --git a/inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_swish_to_swish_ie.cpp b/inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_swish_to_swish_ie.cpp new file mode 100644 index 0000000..1a3f538 --- /dev/null +++ b/inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_swish_to_swish_ie.cpp @@ -0,0 +1,46 @@ +// Copyright (C) 2018-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/convert_opset1_to_legacy/convert_swish_to_swish_ie.hpp" + +#include + +#include + +#include +#include +#include +#include + +ngraph::pass::ConvertSwishToSwishIEMatcher::ConvertSwishToSwishIEMatcher() { + auto swish = ngraph::pattern::wrap_type(); + + ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) { + auto swish = std::dynamic_pointer_cast (m.get_match_root()); + if (!swish) { + return false; + } + float beta_value = 1.0; + if (swish->input_values().size() == 2) { + auto beta_node = swish->input_value(1).get_node_shared_ptr(); + auto beta_const = std::dynamic_pointer_cast(beta_node); + + if (!beta_const) { + return false; + } + if (!ngraph::op::util::get_single_value(beta_const, beta_value)) { + return false; + } + } + + auto swish_ie = std::make_shared(swish->input(0).get_source_output(), beta_value); + swish_ie->set_friendly_name(swish->get_friendly_name()); + ngraph::copy_runtime_info(swish, swish_ie); + ngraph::replace_node(swish, swish_ie); + return true; + }; + + auto m = std::make_shared(swish, "ConvertSwishToSwishIE"); + this->register_matcher(m, callback); +} \ No newline at end of file diff --git a/inference-engine/src/transformations/src/transformations/swish_fusion.cpp b/inference-engine/src/transformations/src/transformations/swish_fusion.cpp new file mode 100644 index 0000000..dcad9d7 --- /dev/null +++ b/inference-engine/src/transformations/src/transformations/swish_fusion.cpp @@ -0,0 +1,183 @@ +// Copyright (C) 2018-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/swish_fusion.hpp" + +#include + +#include +#include +#include + +bool check_constant_value(const std::shared_ptr& constant) { + if (!constant) { + return false; + } + if (constant->get_element_type() == ngraph::element::f32 || constant->get_element_type() == ngraph::element::f16) { + auto data = constant->cast_vector(); + if (data.size() != 1 || data[0] != 1.0) { + return false; + } + } else { + return false; + } + return true; +} + +bool check_beta_value(const std::shared_ptr& constant) { + // check that the constant for beta contains only one distinct element + if (!constant) { + return false; + } + if (constant->get_element_type() == ngraph::element::f32 || constant->get_element_type() == ngraph::element::f16) { + auto data = constant->cast_vector(); + if (!std::equal(data.begin() + 1, data.end(), data.begin())) { + return false; + } + } else { + return false; + } + return true; +} + +ngraph::pass::SwishFusionWithSigmoid::SwishFusionWithSigmoid() { + // replaces a sub-graphs x * Sigmoid(x) with a Swish op. 
+ auto input = ngraph::pattern::any_input(); + auto sigmoid = std::make_shared(input); + auto mul = std::make_shared(input, sigmoid); + + ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) { + auto &pattern_to_output = m.get_pattern_value_map(); + auto exp_input = pattern_to_output.at(input); + + auto swish = std::make_shared(exp_input); + + swish->set_friendly_name(m.get_match_root()->get_friendly_name()); + ngraph::copy_runtime_info({pattern_to_output.at(sigmoid).get_node_shared_ptr(), + pattern_to_output.at(mul).get_node_shared_ptr()}, + swish); + ngraph::replace_node(m.get_match_root(), swish); + return true; + }; + + auto m = std::make_shared(mul, "SwishWithSigmoidFusion"); + register_matcher(m, callback); +} + +ngraph::pass::SwishFusionWithSigmoidWithBeta::SwishFusionWithSigmoidWithBeta() { + // replaces a sub-graphs x * Sigmoid(x * beta) with a Swish op. + auto input = ngraph::pattern::any_input(); + auto beta = ngraph::pattern::any_input(); + auto mul_beta = std::make_shared(input, beta); + auto sigmoid = std::make_shared(mul_beta); + auto mul = std::make_shared(input, sigmoid); + + ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) { + auto &pattern_to_output = m.get_pattern_value_map(); + auto exp_input = pattern_to_output.at(input); + auto beta_input = pattern_to_output.at(beta); + + auto beta_constant = std::dynamic_pointer_cast(beta_input.get_node_shared_ptr()); + Output new_beta; + if (beta_constant) { + if (check_beta_value(beta_constant)) { + new_beta = opset4::Constant::create(beta_input.get_element_type(), Shape{}, {beta_constant->cast_vector()[0]}); + } else { + return false; + } + } else { + // if the input is not constant and number of elements is not equal to 1 then we cannot perform fusing + if (beta_input.get_partial_shape().is_dynamic() || ngraph::shape_size(beta_input.get_shape()) != 1) { + return false; + } + new_beta = beta_input; + } + + auto swish = std::make_shared(exp_input, new_beta); + + swish->set_friendly_name(m.get_match_root()->get_friendly_name()); + ngraph::copy_runtime_info({pattern_to_output.at(sigmoid).get_node_shared_ptr(), + pattern_to_output.at(mul).get_node_shared_ptr()}, + swish); + ngraph::replace_node(m.get_match_root(), swish); + return true; + }; + + auto m = std::make_shared(mul, "SwishWithSigmoidWithBetaFusion"); + register_matcher(m, callback); +} + +ngraph::pass::SwishFusionWithBeta::SwishFusionWithBeta() { + // replaces a sub-graphs x / (1.0 + exp(-x * beta)) with a Swish op. 
+ auto input = ngraph::pattern::any_input(); + auto beta = ngraph::pattern::any_input(); + auto mul = std::make_shared(input, beta); + auto neg = std::make_shared(mul); + auto exp = std::make_shared(neg); + auto add_constant = ngraph::pattern::wrap_type(); + auto add = std::make_shared(exp, add_constant); + auto div = std::make_shared(input, add); + + ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) { + auto &pattern_to_output = m.get_pattern_value_map(); + auto exp_input = pattern_to_output.at(input); + + auto constant = std::dynamic_pointer_cast(pattern_to_output.at(add_constant).get_node_shared_ptr()); + if (!check_constant_value(constant)) { + return false; + } + + auto swish = std::make_shared(exp_input, pattern_to_output.at(beta)); + + swish->set_friendly_name(m.get_match_root()->get_friendly_name()); + ngraph::copy_runtime_info({pattern_to_output.at(beta).get_node_shared_ptr(), + pattern_to_output.at(mul).get_node_shared_ptr(), + pattern_to_output.at(neg).get_node_shared_ptr(), + pattern_to_output.at(exp).get_node_shared_ptr(), + pattern_to_output.at(add_constant).get_node_shared_ptr(), + pattern_to_output.at(add).get_node_shared_ptr(), + pattern_to_output.at(div).get_node_shared_ptr()}, + swish); + ngraph::replace_node(m.get_match_root(), swish); + return true; + }; + + auto m = std::make_shared(div, "SwishWithBetaFusion"); + register_matcher(m, callback); +} + +ngraph::pass::SwishFusionWithoutBeta::SwishFusionWithoutBeta() { + // replaces a sub-graphs x / (1.0 + exp(-x)) with a Swish op. + auto input = ngraph::pattern::any_input(); + auto neg = std::make_shared(input); + auto exp = std::make_shared(neg); + auto add_constant = ngraph::pattern::wrap_type(); + auto add = std::make_shared(exp, add_constant); + auto div = std::make_shared(input, add); + + ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) { + auto & pattern_to_output = m.get_pattern_value_map(); + auto exp_input = pattern_to_output.at(input); + + auto constant = std::dynamic_pointer_cast(pattern_to_output.at(add_constant).get_node_shared_ptr()); + if (!check_constant_value(constant)) { + return false; + } + + auto swish = std::make_shared(exp_input); + + swish->set_friendly_name(m.get_match_root()->get_friendly_name()); + ngraph::copy_runtime_info({pattern_to_output.at(neg).get_node_shared_ptr(), + pattern_to_output.at(exp).get_node_shared_ptr(), + pattern_to_output.at(add_constant).get_node_shared_ptr(), + pattern_to_output.at(add).get_node_shared_ptr(), + pattern_to_output.at(div).get_node_shared_ptr()}, + swish); + ngraph::replace_node(m.get_match_root(), swish); + return true; + }; + + auto m = std::make_shared(div, "SwishWithoutBetaFusion"); + register_matcher(m, callback); +} diff --git a/inference-engine/tests/functional/inference_engine/transformations/mish_fusion_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/mish_fusion_test.cpp index a540cec..b349378 100644 --- a/inference-engine/tests/functional/inference_engine/transformations/mish_fusion_test.cpp +++ b/inference-engine/tests/functional/inference_engine/transformations/mish_fusion_test.cpp @@ -6,12 +6,10 @@ #include #include -#include #include #include #include -#include #include #include #include diff --git a/inference-engine/tests/functional/inference_engine/transformations/swish_fusion_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/swish_fusion_test.cpp new file mode 100644 index 0000000..e6125ae --- /dev/null +++ 
b/inference-engine/tests/functional/inference_engine/transformations/swish_fusion_test.cpp @@ -0,0 +1,206 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "common_test_utils/ngraph_test_utils.hpp" + +using namespace testing; + +TEST(TransformationTests, SwishFusionWithBeta) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input = std::make_shared(ngraph::element::f32, ngraph::PartialShape::dynamic(1)); + auto beta = std::make_shared(ngraph::element::f32, ngraph::Shape{}); + auto mul = std::make_shared(input, beta); + auto neg = std::make_shared(mul); + auto exp = std::make_shared(neg); + auto constant = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1.0}); + auto add = std::make_shared(exp, constant); + auto div = std::make_shared(input, add); + + f = std::make_shared(ngraph::NodeVector{div}, ngraph::ParameterVector{input, beta}); + + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto input = std::make_shared(ngraph::element::f32, ngraph::PartialShape::dynamic(1)); + auto beta = std::make_shared(ngraph::element::f32, ngraph::Shape{}); + auto swish = std::make_shared(input, beta); + + f_ref = std::make_shared(ngraph::NodeVector{swish}, ngraph::ParameterVector{input, beta}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, SwishFusionWithoutBeta) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input = std::make_shared(ngraph::element::f16, ngraph::PartialShape::dynamic(1)); + auto neg = std::make_shared(input); + auto exp = std::make_shared(neg); + auto constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {1.0}); + auto add = std::make_shared(exp, constant); + auto div = std::make_shared(input, add); + + f = std::make_shared(ngraph::NodeVector{div}, ngraph::ParameterVector{input}); + + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto input = std::make_shared(ngraph::element::f16, ngraph::PartialShape::dynamic(1)); + auto swish = std::make_shared(input); + + f_ref = std::make_shared(ngraph::NodeVector{swish}, ngraph::ParameterVector{input}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, SwishFusionWithoutBetaNonOneAddConstant) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input = std::make_shared(ngraph::element::f16, ngraph::PartialShape::dynamic(1)); + auto neg = std::make_shared(input); + auto exp = std::make_shared(neg); + auto constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {1.1}); + auto add = std::make_shared(exp, constant); + auto div = std::make_shared(input, add); + + f = std::make_shared(ngraph::NodeVector{div}, ngraph::ParameterVector{input}); + + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto input = std::make_shared(ngraph::element::f16, ngraph::PartialShape::dynamic(1)); + auto neg = std::make_shared(input); + auto exp = std::make_shared(neg); + auto constant = ngraph::opset4::Constant::create(ngraph::element::f16, 
ngraph::Shape{}, {1.1}); + auto add = std::make_shared(exp, constant); + auto div = std::make_shared(input, add); + + f = std::make_shared(ngraph::NodeVector{div}, ngraph::ParameterVector{input}); + + f_ref = std::make_shared(ngraph::NodeVector{div}, ngraph::ParameterVector{input}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, SwishFusionWithSigmoid) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input = std::make_shared(ngraph::element::f16, ngraph::PartialShape::dynamic(1)); + auto sig = std::make_shared(input); + auto mul = std::make_shared(input, sig); + + f = std::make_shared(ngraph::NodeVector{mul}, ngraph::ParameterVector{input}); + + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto input = std::make_shared(ngraph::element::f16, ngraph::PartialShape::dynamic(1)); + auto swish = std::make_shared(input); + + f_ref = std::make_shared(ngraph::NodeVector{swish}, ngraph::ParameterVector{input}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, SwishFusionWithSigmoidWithBeta) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input = std::make_shared(ngraph::element::f16, ngraph::PartialShape::dynamic(1)); + auto beta = std::make_shared(ngraph::element::f16, ngraph::Shape{}); + auto mul_beta = std::make_shared(input, beta); + auto sig = std::make_shared(mul_beta); + auto mul = std::make_shared(input, sig); + + f = std::make_shared(ngraph::NodeVector{mul}, ngraph::ParameterVector{input, beta}); + + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto input = std::make_shared(ngraph::element::f16, ngraph::PartialShape::dynamic(1)); + auto beta = std::make_shared(ngraph::element::f16, ngraph::Shape{}); + auto swish = std::make_shared(input, beta); + + f_ref = std::make_shared(ngraph::NodeVector{swish}, ngraph::ParameterVector{input, beta}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, SwishFusionWithSigmoidWithBetaConstant) { + // test where the beta constant has multiple but the same value + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input = std::make_shared(ngraph::element::f16, ngraph::PartialShape::dynamic(1)); + auto beta = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{3}, {2.0, 2.0, 2.0}); + auto mul_beta = std::make_shared(input, beta); + auto sig = std::make_shared(mul_beta); + auto mul = std::make_shared(input, sig); + + f = std::make_shared(ngraph::NodeVector{mul}, ngraph::ParameterVector{input}); + + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto input = std::make_shared(ngraph::element::f16, ngraph::PartialShape::dynamic(1)); + auto beta = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {2.0}); + auto swish = std::make_shared(input, beta); + + f_ref = std::make_shared(ngraph::NodeVector{swish}, ngraph::ParameterVector{input}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} diff --git a/model-optimizer/automation/package_BOM.txt b/model-optimizer/automation/package_BOM.txt index e20c667..5674b91 100644 --- 
a/model-optimizer/automation/package_BOM.txt +++ b/model-optimizer/automation/package_BOM.txt @@ -444,7 +444,7 @@ extensions/front/tf/ssd_toolbox_multihead_detection_output.json extensions/front/tf/ssd_v2_support.json extensions/front/tf/SSDToolboxDetectionOutput.py extensions/front/tf/swap_deconv_inputs.py -extensions/front/tf/swish.py +extensions/front/tf/swish_ext.py extensions/front/tf/SwitchMergeOptimization.py extensions/front/tf/TensorArrayExtractors.py extensions/front/tf/TensorArrayGatherV3.py diff --git a/model-optimizer/extensions/front/tf/swish.py b/model-optimizer/extensions/front/tf/swish.py deleted file mode 100644 index a77b463..0000000 --- a/model-optimizer/extensions/front/tf/swish.py +++ /dev/null @@ -1,37 +0,0 @@ -""" - Copyright (C) 2018-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -""" - -from extensions.ops.activation_ops import Sigmoid -from extensions.ops.elementwise import Mul -from mo.front.common.replacement import FrontReplacementOp -from mo.graph.graph import Node, Graph - - -class Swish(FrontReplacementOp): - op = "swish_f32" - enabled = True - - def replace_op(self, graph: Graph, node: Node): - mul_node = Mul(graph, {'name': node.name + '/mul_'}).create_node() - sigmoid_node = Sigmoid(graph, {'name': node.name + '/sigmoid_'}).create_node() - - # Connect nodes - node.in_port(0).get_connection().get_source().connect(mul_node.in_port(0)) - node.in_port(0).get_connection().get_source().connect(sigmoid_node.in_port(0)) - sigmoid_node.out_port(0).connect(mul_node.in_port(1)) - - # The "explicit" version of the return value is: [(out_node.id, 0)]) - return [mul_node.id] diff --git a/model-optimizer/extensions/front/tf/swish_ext.py b/model-optimizer/extensions/front/tf/swish_ext.py new file mode 100644 index 0000000..9700877 --- /dev/null +++ b/model-optimizer/extensions/front/tf/swish_ext.py @@ -0,0 +1,29 @@ +""" + Copyright (C) 2018-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+""" +from extensions.ops.activation_ops import Swish +from mo.front.extractor import FrontExtractorOp +from mo.graph.graph import Node + + +class SwishExtractor(FrontExtractorOp): + op = 'swish_f32' + enabled = True + + @classmethod + def extract(cls, node: Node): + Swish.update_node_stat(node, {}) + return cls.enabled + diff --git a/model-optimizer/extensions/front/tf/swish_test.py b/model-optimizer/extensions/front/tf/swish_test.py deleted file mode 100644 index 211e042..0000000 --- a/model-optimizer/extensions/front/tf/swish_test.py +++ /dev/null @@ -1,57 +0,0 @@ -""" - Copyright (C) 2018-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -""" - -import unittest - -import numpy as np - -from extensions.front.tf.swish import Swish -from mo.utils.ir_engine.compare_graphs import compare_graphs -from mo.utils.unittest.graph import build_graph - -nodes_attributes = { - 'placeholder_1': {'shape': np.array([1, 227, 227, 3]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, - 'placeholder_2': {'shape': np.array([1, 227, 227, 3]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, - # swish operation - 'swish': {'kind': 'op', 'op': 'swish_f32'}, - # Test operation - 'last': {'type': None, 'value': None, 'kind': 'op', 'op': None}, - # Add and Mul operations - 'mul': {'type': 'Multiply', 'kind': 'op', 'op': 'Mul'}, - 'sigmoid': {'value': None, 'type': 'Sigmoid', 'kind': 'op', 'op': 'Sigmoid'}, -} - - -class TestSwish(unittest.TestCase): - def test_swish_test_1(self): - # Test with two different inputs from two placeholders - graph = build_graph(nodes_attributes, - [('placeholder_1', 'swish'), - ('swish', 'last') - ], nodes_with_edges_only=True) - - graph_ref = build_graph(nodes_attributes, - [('placeholder_1', 'sigmoid', {'out': 0}), - ('placeholder_1', 'mul', {'in': 0, 'out': 0}), - ('sigmoid', 'mul', {'in': 1}), - ('mul', 'last'), - ], nodes_with_edges_only=True) - - graph.stage = 'front' - Swish().find_and_replace_pattern(graph) - - (flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True) - self.assertTrue(flag, resp) diff --git a/model-optimizer/extensions/ops/activation_ops.py b/model-optimizer/extensions/ops/activation_ops.py index ebcd3a7..a05dba3 100644 --- a/model-optimizer/extensions/ops/activation_ops.py +++ b/model-optimizer/extensions/ops/activation_ops.py @@ -198,9 +198,9 @@ class LeakyReLU(Op): def __init__(self, graph: Graph, attrs: dict): super().__init__(graph, { - 'type': __class__.op, - 'op': __class__.op, - 'infer': __class__.infer, + 'type': self.op, + 'op': self.op, + 'infer': self.infer, 'in_ports_count': 1, 'out_ports_count': 1, }, attrs) @@ -265,3 +265,36 @@ class Mish(Activation): sp_attrs = {'version': 'opset4'} sp_attrs.update(attrs) super().__init__(graph, sp_attrs) + + +class Swish(Op): + op = 'Swish' + + def __init__(self, graph: Graph, attrs: dict): + mandatory_props = { + 'op': self.op, + 'type': self.op, + 'version': 'opset4', + + 'infer': self.infer, + + 'in_ports_count': 2, + 'out_ports_count': 1, + } + super().__init__(graph, 
mandatory_props, attrs) + + @staticmethod + def infer(node: Node): + node_name = node.soft_get('name', node.id) + node.out_port(0).data.set_shape(node.in_port(0).data.get_shape()) + + beta = 1.0 + if node.is_in_port_connected(1): + beta = node.in_port(1).data.get_value() + if beta is not None: + assert beta.ndim == 0, 'The "beta" value for node {} must be a scalar'.format(node_name) + beta = beta.item() + + input_value = node.in_port(1).data.get_value() + if input_value is not None and beta is not None: + node.out_port(0).data.set_value(input_value / (1.0 + np.exp(-input_value * beta))) diff --git a/ngraph/python/src/ngraph/__init__.py b/ngraph/python/src/ngraph/__init__.py index 08cf718..b3cae2d 100644 --- a/ngraph/python/src/ngraph/__init__.py +++ b/ngraph/python/src/ngraph/__init__.py @@ -150,6 +150,7 @@ from ngraph.opset4 import squared_difference from ngraph.opset4 import squeeze from ngraph.opset4 import strided_slice from ngraph.opset4 import subtract +from ngraph.opset4 import swish from ngraph.opset4 import tan from ngraph.opset4 import tanh from ngraph.opset4 import tensor_iterator diff --git a/ngraph/python/src/ngraph/opset4/__init__.py b/ngraph/python/src/ngraph/opset4/__init__.py index eac33dd..8dbbf16 100644 --- a/ngraph/python/src/ngraph/opset4/__init__.py +++ b/ngraph/python/src/ngraph/opset4/__init__.py @@ -139,6 +139,7 @@ from ngraph.opset1.ops import squared_difference from ngraph.opset1.ops import squeeze from ngraph.opset1.ops import strided_slice from ngraph.opset1.ops import subtract +from ngraph.opset4.ops import swish from ngraph.opset1.ops import tan from ngraph.opset1.ops import tanh from ngraph.opset1.ops import tensor_iterator diff --git a/ngraph/python/src/ngraph/opset4/ops.py b/ngraph/python/src/ngraph/opset4/ops.py index b91f4e7..1366360 100644 --- a/ngraph/python/src/ngraph/opset4/ops.py +++ b/ngraph/python/src/ngraph/opset4/ops.py @@ -147,3 +147,19 @@ def mish(data: NodeInput, name: Optional[str] = None,) -> Node: :return: The new node which performs Mish """ return _get_node_factory_opset4().create("Mish", as_nodes(data), {}) + + +@nameable_op +def swish( + data: NodeInput, + beta: Optional[NodeInput] = None, + name: Optional[str] = None, +) -> Node: + """Return a node which performing Swish activation function Swish(x, beta=1.0) = x * sigmoid(x * beta)). + + :param data: Tensor with input data floating point type. + :return: The new node which performs Swish + """ + if beta is None: + beta = make_constant_node(1.0, np.float32) + return _get_node_factory_opset4().create("Swish", as_nodes(data, beta), {}) diff --git a/ngraph/python/tests/test_ngraph/test_swish.py b/ngraph/python/tests/test_ngraph/test_swish.py new file mode 100644 index 0000000..e4917e8 --- /dev/null +++ b/ngraph/python/tests/test_ngraph/test_swish.py @@ -0,0 +1,41 @@ +# ****************************************************************************** +# Copyright 2017-2020 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ****************************************************************************** +import numpy as np +import ngraph as ng +from ngraph.impl import Shape, Type + + +def test_swish_props_with_beta(): + float_dtype = np.float32 + data = ng.parameter(Shape([3, 10]), dtype=float_dtype, name="data") + beta = ng.parameter(Shape([]), dtype=float_dtype, name="beta") + + node = ng.swish(data, beta) + assert node.get_type_name() == "Swish" + assert node.get_output_size() == 1 + assert list(node.get_output_shape(0)) == [3, 10] + assert node.get_output_element_type(0) == Type.f32 + + +def test_swish_props_without_beta(): + float_dtype = np.float32 + data = ng.parameter(Shape([3, 10]), dtype=float_dtype, name="data") + + node = ng.swish(data) + assert node.get_type_name() == "Swish" + assert node.get_output_size() == 1 + assert list(node.get_output_shape(0)) == [3, 10] + assert node.get_output_element_type(0) == Type.f32 diff --git a/ngraph/src/ngraph/CMakeLists.txt b/ngraph/src/ngraph/CMakeLists.txt index 9bd860c..cd6f00a 100644 --- a/ngraph/src/ngraph/CMakeLists.txt +++ b/ngraph/src/ngraph/CMakeLists.txt @@ -332,6 +332,8 @@ set (SRC op/subtract.hpp op/sum.cpp op/sum.hpp + op/swish.cpp + op/swish.hpp op/variadic_split.cpp op/variadic_split.hpp op/tan.cpp diff --git a/ngraph/src/ngraph/op/swish.cpp b/ngraph/src/ngraph/op/swish.cpp new file mode 100644 index 0000000..e1a8347 --- /dev/null +++ b/ngraph/src/ngraph/op/swish.cpp @@ -0,0 +1,140 @@ +//***************************************************************************** +// Copyright 2017-2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//***************************************************************************** + +#include "ngraph/op/swish.hpp" +#include "ngraph/attribute_visitor.hpp" +#include "ngraph/op/constant.hpp" + +#include "ngraph/runtime/host_tensor.hpp" +#include "ngraph/runtime/reference/swish.hpp" + +using namespace std; +using namespace ngraph; + +constexpr NodeTypeInfo op::v4::Swish::type_info; + +op::v4::Swish::Swish(const Output& arg) + : Op({arg}) +{ + constructor_validate_and_infer_types(); +} + +op::v4::Swish::Swish(const Output& arg, const Output& beta) + : Op({arg, beta}) +{ + constructor_validate_and_infer_types(); +} + +bool op::v4::Swish::visit_attributes(AttributeVisitor& visitor) +{ + return true; +} + +void op::v4::Swish::validate_and_infer_types() +{ + auto inputs_count = input_values().size(); + NODE_VALIDATION_CHECK(this, + inputs_count == 1 || inputs_count == 2, + "Swish must have 1 or 2 inputs, but it has: ", + inputs_count); + + if (inputs_count == 2) + { + NODE_VALIDATION_CHECK(this, + input_value(0).get_element_type() == + input_value(1).get_element_type(), + "Swish inputs must have the same type but they are: ", + input_value(0).get_element_type(), + " and ", + input_value(1).get_element_type()); + if (get_input_partial_shape(1).rank().is_static()) + { + auto beta_rank = get_input_partial_shape(1).rank().get_length(); + NODE_VALIDATION_CHECK(this, + beta_rank == 0, + "Swish input with beta must be scalar but it has rank: ", + beta_rank); + } + } + set_output_size(1); + set_output_type(0, get_input_element_type(0), get_input_partial_shape(0)); +} + +shared_ptr op::v4::Swish::clone_with_new_inputs(const OutputVector& new_args) const +{ + if (new_args.size() == 1) + { + return make_shared(new_args.at(0)); + } + else + { + return make_shared(new_args.at(0), new_args.at(1)); + } +} + +namespace +{ + template + inline bool evaluate(const HostTensorPtr& arg0, + const HostTensorPtr& arg1, + const HostTensorPtr& out, + const size_t count) + { + using T = typename element_type_traits::value_type; + if (arg1 != nullptr) + { + runtime::reference::swish( + arg0->get_data_ptr(), arg1->get_data_ptr(), out->get_data_ptr(), count); + } + else + { + runtime::reference::swish( + arg0->get_data_ptr(), nullptr, out->get_data_ptr(), count); + } + return true; + } + + bool evaluate_swish(const HostTensorPtr& arg0, + const HostTensorPtr& arg1, + const HostTensorPtr& out, + const size_t count) + { + bool rc = true; + out->set_unary(arg0); + + switch (arg0->get_element_type()) + { + TYPE_CASE(f16)(arg0, arg1, out, count); + break; + TYPE_CASE(f32)(arg0, arg1, out, count); + break; + default: rc = false; break; + } + return rc; + } +} + +bool op::v4::Swish::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const +{ + if (inputs.size() == 2) + { + return evaluate_swish(inputs[0], inputs[1], outputs[0], shape_size(get_output_shape(0))); + } + else + { + return evaluate_swish(inputs[0], nullptr, outputs[0], shape_size(get_output_shape(0))); + } +} diff --git a/ngraph/src/ngraph/op/swish.hpp b/ngraph/src/ngraph/op/swish.hpp new file mode 100644 index 0000000..1c0b6ed --- /dev/null +++ b/ngraph/src/ngraph/op/swish.hpp @@ -0,0 +1,57 @@ +//***************************************************************************** +// Copyright 2017-2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#pragma once + +#include "ngraph/node.hpp" +#include "ngraph/op/op.hpp" + +namespace ngraph +{ + namespace op + { + namespace v4 + { + /// \brief A Swish Activation Function + /// f(x) = x / (1.0 + exp(-beta * x)) or + /// f(x) = x * sigmoid(beta * x) + /// + class NGRAPH_API Swish : public ngraph::op::Op + { + public: + static constexpr NodeTypeInfo type_info{"Swish", 4}; + const NodeTypeInfo& get_type_info() const override { return type_info; } + Swish() = default; + + /// \brief Constructs an Swish operation. + /// + /// \param data Input tensor + /// \param beta Scalar with beta value. If the argument is not specified then use + /// the default value 1.0 + Swish(const Output& arg, const Output& beta); + explicit Swish(const Output& arg); + + bool visit_attributes(AttributeVisitor& visitor) override; + void validate_and_infer_types() override; + + virtual std::shared_ptr + clone_with_new_inputs(const OutputVector& new_args) const override; + bool evaluate(const HostTensorVector& outputs, + const HostTensorVector& inputs) const override; + }; + } + } +} diff --git a/ngraph/src/ngraph/ops.hpp b/ngraph/src/ngraph/ops.hpp index 8cb45e3..ca5d940 100644 --- a/ngraph/src/ngraph/ops.hpp +++ b/ngraph/src/ngraph/ops.hpp @@ -157,6 +157,7 @@ #include "ngraph/op/strided_slice.hpp" #include "ngraph/op/subtract.hpp" #include "ngraph/op/sum.hpp" +#include "ngraph/op/swish.hpp" #include "ngraph/op/tan.hpp" #include "ngraph/op/tanh.hpp" #include "ngraph/op/tensor_iterator.hpp" diff --git a/ngraph/src/ngraph/opsets/opset4_tbl.hpp b/ngraph/src/ngraph/opsets/opset4_tbl.hpp index 975567b..4b0dd22 100644 --- a/ngraph/src/ngraph/opsets/opset4_tbl.hpp +++ b/ngraph/src/ngraph/opsets/opset4_tbl.hpp @@ -155,6 +155,7 @@ NGRAPH_OP(TopK, ngraph::op::v3) NGRAPH_OP(Acosh, ngraph::op::v3) NGRAPH_OP(Asinh, ngraph::op::v3) NGRAPH_OP(Atanh, ngraph::op::v3) +NGRAPH_OP(CTCLoss, ngraph::op::v4) NGRAPH_OP(NonMaxSuppression, ngraph::op::v4) NGRAPH_OP(Mish, ngraph::op::v4) -NGRAPH_OP(CTCLoss, ngraph::op::v4) +NGRAPH_OP(Swish, ngraph::op::v4) diff --git a/ngraph/src/ngraph/runtime/reference/swish.hpp b/ngraph/src/ngraph/runtime/reference/swish.hpp new file mode 100644 index 0000000..14bb42e --- /dev/null +++ b/ngraph/src/ngraph/runtime/reference/swish.hpp @@ -0,0 +1,43 @@ +//***************************************************************************** +// Copyright 2017-2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//***************************************************************************** + +#pragma once + +#include +#include + +namespace ngraph +{ + namespace runtime + { + namespace reference + { + template + void swish(const T* arg, const T* beta, T* out, size_t count) + { + T beta_value = static_cast(1.0); + if (beta != nullptr) + { + beta_value = beta[0]; + } + for (size_t i = 0; i < count; i++) + { + out[i] = arg[i] / (1.0 + std::exp(-arg[i] * beta_value)); + } + } + } + } +} diff --git a/ngraph/test/CMakeLists.txt b/ngraph/test/CMakeLists.txt index 3730265..2447040 100644 --- a/ngraph/test/CMakeLists.txt +++ b/ngraph/test/CMakeLists.txt @@ -77,6 +77,7 @@ set(SRC op_eval/non_zero.cpp op_eval/split.cpp op_eval/strided_slice.cpp + op_eval/swish.cpp op_is.cpp opset1.cpp partial_shape.cpp @@ -165,6 +166,7 @@ set(SRC type_prop/squared_difference.cpp type_prop/squeeze.cpp type_prop/sum.cpp + type_prop/swish.cpp type_prop/reduce_prod.cpp type_prop/reduce_sum.cpp type_prop/tile.cpp diff --git a/ngraph/test/op_eval/swish.cpp b/ngraph/test/op_eval/swish.cpp new file mode 100644 index 0000000..26997df --- /dev/null +++ b/ngraph/test/op_eval/swish.cpp @@ -0,0 +1,90 @@ +//***************************************************************************** +// Copyright 2017-2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//***************************************************************************** + +#include +#include + +#include "gtest/gtest.h" + +#include "ngraph/op/swish.hpp" +#include "ngraph/runtime/host_tensor.hpp" +#include "ngraph/validation_util.hpp" +#include "runtime/backend.hpp" +#include "util/test_tools.hpp" + +using namespace std; +using namespace ngraph; + +TEST(op_eval, swish_with_beta1) +{ + auto p = make_shared(element::f32, Shape{3}); + auto beta = make_shared(element::f32, Shape{}); + auto swish = make_shared(p, beta); + auto fun = make_shared(OutputVector{swish}, ParameterVector{p, beta}); + + std::vector inputs{-0.5, 0.0, 0.5}; + std::vector expected_result{-0.18877034, 0.0, 0.31122968}; + + auto result = make_shared(); + ASSERT_TRUE(fun->evaluate({result}, + {make_host_tensor(Shape{3}, inputs), + make_host_tensor(Shape{}, {1.0})})); + EXPECT_EQ(result->get_element_type(), element::f32); + EXPECT_EQ(result->get_shape(), Shape{3}); + auto result_data = read_vector(result); + for (auto i = 0; i < inputs.size(); i++) + EXPECT_NEAR(result_data[i], expected_result[i], 0.000001); +} + +TEST(op_eval, swish_with_beta0_75) +{ + auto p = make_shared(element::f32, Shape{3}); + auto beta = make_shared(element::f32, Shape{}); + auto swish = make_shared(p, beta); + auto fun = make_shared(OutputVector{swish}, ParameterVector{p, beta}); + + std::vector inputs{-0.5, 0.0, 0.5}; + std::vector expected_result{-0.2036667, 0.0, 0.2963333}; + + auto result = make_shared(); + ASSERT_TRUE(fun->evaluate({result}, + {make_host_tensor(Shape{3}, inputs), + make_host_tensor(Shape{}, {0.75})})); + EXPECT_EQ(result->get_element_type(), element::f32); + EXPECT_EQ(result->get_shape(), Shape{3}); + auto result_data = read_vector(result); + for (auto i = 0; i < inputs.size(); i++) + EXPECT_NEAR(result_data[i], expected_result[i], 0.000001); +} + +TEST(op_eval, swish_without_beta) +{ + auto p = make_shared(element::f32, Shape{3}); + auto swish = make_shared(p); + auto fun = make_shared(OutputVector{swish}, ParameterVector{p}); + + std::vector inputs{-0.5, 0.0, 0.5}; + std::vector expected_result{-0.18877034, 0.0, 0.31122968}; + + auto result = make_shared(); + ASSERT_TRUE( + fun->evaluate({result}, {make_host_tensor(Shape{3}, inputs)})); + EXPECT_EQ(result->get_element_type(), element::f32); + EXPECT_EQ(result->get_shape(), Shape{3}); + auto result_data = read_vector(result); + for (auto i = 0; i < inputs.size(); i++) + EXPECT_NEAR(result_data[i], expected_result[i], 0.000001); +} diff --git a/ngraph/test/type_prop/swish.cpp b/ngraph/test/type_prop/swish.cpp new file mode 100644 index 0000000..6611009 --- /dev/null +++ b/ngraph/test/type_prop/swish.cpp @@ -0,0 +1,95 @@ +//***************************************************************************** +// Copyright 2017-2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//***************************************************************************** + +#include "gtest/gtest.h" +#include "ngraph/ngraph.hpp" +#include "util/type_prop.hpp" + +using namespace std; +using namespace ngraph; + +TEST(type_prop, swish) +{ + auto data = make_shared(element::f32, Shape{1, 3, 6}); + auto swish_func = make_shared(data); + EXPECT_EQ(swish_func->get_element_type(), element::f32); + EXPECT_EQ(swish_func->get_shape(), data->get_output_shape(0)); +} + +TEST(type_prop, swish_partial) +{ + auto data = make_shared(element::f32, PartialShape{1, Dimension::dynamic(), 6}); + auto swish_func = make_shared(data); + EXPECT_EQ(swish_func->get_element_type(), element::f32); + ASSERT_TRUE( + swish_func->get_output_partial_shape(0).same_scheme(data->get_output_partial_shape(0))); + + // rank unknown + auto swish_partial = make_shared( + make_shared(element::f32, PartialShape::dynamic())); + ASSERT_TRUE(swish_partial->get_output_partial_shape(0).same_scheme(PartialShape::dynamic())); +} + +TEST(type_prop, swish_partial_static_rank) +{ + auto data = make_shared(element::f32, PartialShape{1, Dimension::dynamic(), 6}); + auto swish_func = make_shared(data); + EXPECT_EQ(swish_func->get_element_type(), element::f32); + ASSERT_TRUE( + swish_func->get_output_partial_shape(0).same_scheme(data->get_output_partial_shape(0))); + ASSERT_TRUE(swish_func->get_output_partial_shape(0).rank().is_static()); +} + +TEST(type_prop, swish_incompatible_types) +{ + auto data = make_shared(element::f32, Shape{1, 3, 6}); + auto beta = make_shared(element::f16, Shape{}); + try + { + const auto swish_func = make_shared(data, beta); + FAIL() << "swish_func node was created with incompatible input data types."; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), std::string("Swish inputs must have the same type")); + } +} + +TEST(type_prop, swish_beta_not_scalar) +{ + auto data = make_shared(element::f32, Shape{1, 3, 6}); + auto beta = make_shared(element::f32, Shape{1}); + try + { + const auto swish_func = make_shared(data, beta); + FAIL() << "swish_func node was created with scalar beta value."; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), std::string("Swish input with beta must be scalar")); + } +} + +TEST(type_prop, swish_2_inputs) +{ + auto data = make_shared(element::f32, Shape{1, 3, 6}); + auto beta = make_shared(element::f32, Shape{}); + const auto swish_func = make_shared(data, beta); + + EXPECT_EQ(swish_func->get_element_type(), element::f32); + ASSERT_TRUE(swish_func->get_output_partial_shape(0).same_scheme(data->get_output_shape(0))); + ASSERT_TRUE(swish_func->get_output_partial_shape(0).rank().is_static()); +}
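
For reference, the Swish definition used consistently across this commit (swish.hpp, runtime/reference/swish.hpp, and the MO infer function) is f(x) = x * sigmoid(beta * x) = x / (1 + exp(-beta * x)). Below is a minimal NumPy sketch, not part of the change set and with an illustrative helper name (swish_reference), that reproduces the expected values hard-coded in the op_eval tests above:

import numpy as np

def swish_reference(x, beta=1.0):
    # Swish(x, beta) = x * sigmoid(beta * x) = x / (1 + exp(-beta * x)),
    # the same formula as the nGraph reference kernel and the MO infer function.
    x = np.asarray(x, dtype=np.float32)
    return x / (1.0 + np.exp(-x * beta))

# Expected values from ngraph/test/op_eval/swish.cpp:
print(swish_reference([-0.5, 0.0, 0.5]))        # approx. [-0.18877034, 0.0, 0.31122968]
print(swish_reference([-0.5, 0.0, 0.5], 0.75))  # approx. [-0.2036667, 0.0, 0.2963333]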