From 0c1b2f836b7ee7f657dc16e59b171a040c649af1 Mon Sep 17 00:00:00 2001 From: iliya mironov Date: Fri, 4 Sep 2020 11:07:37 +0300 Subject: [PATCH] Add Mish with SoftPlus transformation (#1815) * Add Mish with SoftPlus transformation * Refactoring accrding code review * Add softplus to mish pass registration * Add checks customer count for SoftPlus and Tanh ops --- .../transformations/softplus_to_mish_fusion.hpp | 32 +++++++++++++++++++ .../common_optimizations/common_optimizations.cpp | 2 ++ .../transformations/softplus_to_mish_fusion.cpp | 36 ++++++++++++++++++++++ .../transformations/mish_fusion_test.cpp | 30 ++++++++++++++++++ 4 files changed, 100 insertions(+) create mode 100644 inference-engine/src/transformations/include/transformations/softplus_to_mish_fusion.hpp create mode 100644 inference-engine/src/transformations/src/transformations/softplus_to_mish_fusion.cpp diff --git a/inference-engine/src/transformations/include/transformations/softplus_to_mish_fusion.hpp b/inference-engine/src/transformations/include/transformations/softplus_to_mish_fusion.hpp new file mode 100644 index 0000000..4806cea --- /dev/null +++ b/inference-engine/src/transformations/include/transformations/softplus_to_mish_fusion.hpp @@ -0,0 +1,32 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include + +#include +#include +#include "ngraph/pattern/matcher.hpp" + +namespace ngraph { +namespace pass { + +class TRANSFORMATIONS_API SoftPlusToMishFusion; + +} // namespace pass +} // namespace ngraph + +/** + * @ingroup ie_transformation_common_api + * @brief SoftPlusToMishFusion transformation replaces group of + * operations: x * tanh(softplus(x)) to Mish op. + */ +class ngraph::pass::SoftPlusToMishFusion: public ngraph::pass::MatcherPass { +public: + SoftPlusToMishFusion(); +}; diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp index ff23052..f238328 100644 --- a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp +++ b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp @@ -17,6 +17,7 @@ #include "transformations/itt.hpp" #include "transformations/mish_fusion.hpp" #include "transformations/softplus_fusion.hpp" +#include "transformations/softplus_to_mish_fusion.hpp" #include "transformations/swish_fusion.hpp" #include "transformations/hswish_fusion.hpp" #include "transformations/normalize_l2_fusion.hpp" @@ -45,6 +46,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr(); manager.register_pass(); manager.register_pass(); + manager.register_pass(); manager.register_pass(); manager.register_pass(); manager.register_pass(); diff --git a/inference-engine/src/transformations/src/transformations/softplus_to_mish_fusion.cpp b/inference-engine/src/transformations/src/transformations/softplus_to_mish_fusion.cpp new file mode 100644 index 0000000..fca94c5 --- /dev/null +++ b/inference-engine/src/transformations/src/transformations/softplus_to_mish_fusion.cpp @@ -0,0 +1,36 @@ +// Copyright (C) 2018-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/softplus_to_mish_fusion.hpp" + +#include +#include + +#include +#include +#include + +ngraph::pass::SoftPlusToMishFusion::SoftPlusToMishFusion() { + auto input = ngraph::pattern::any_input(); + auto softplus = ngraph::pattern::wrap_type({input}, pattern::consumers_count(1)); + auto tanh = ngraph::pattern::wrap_type({softplus}, pattern::consumers_count(1)); + auto mul = std::make_shared(input, tanh); + + ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) { + auto & pattern_to_output = m.get_pattern_value_map(); + auto exp_input = pattern_to_output.at(input); + + auto mish = std::make_shared(exp_input); + + mish->set_friendly_name(m.get_match_root()->get_friendly_name()); + ngraph::copy_runtime_info({pattern_to_output.at(mul).get_node_shared_ptr(), + pattern_to_output.at(tanh).get_node_shared_ptr(), + pattern_to_output.at(softplus).get_node_shared_ptr()}, mish); + ngraph::replace_node(m.get_match_root(), mish); + return true; + }; + + auto m = std::make_shared(mul, "SoftPlusToMishFusion"); + register_matcher(m, callback); +} diff --git a/inference-engine/tests/functional/inference_engine/transformations/mish_fusion_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/mish_fusion_test.cpp index b349378..1921148 100644 --- a/inference-engine/tests/functional/inference_engine/transformations/mish_fusion_test.cpp +++ b/inference-engine/tests/functional/inference_engine/transformations/mish_fusion_test.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -48,3 +49,32 @@ TEST(TransformationTests, MishFusing) { auto res = compare_functions(f, f_ref); ASSERT_TRUE(res.first) << res.second; } + + +TEST(TransformationTests, MishWithSoftPlusFusing) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input0 = std::make_shared(ngraph::element::f64, ngraph::Shape{3, 1, 2}); + auto softplus = std::make_shared(input0); + auto tanh = std::make_shared(softplus); + auto mul = std::make_shared(input0, tanh); + + f = std::make_shared(ngraph::NodeVector{mul}, ngraph::ParameterVector{input0}); + + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto data = std::make_shared(ngraph::element::f32, ngraph::Shape{3, 1, 2}); + auto mish = std::make_shared(data); + + f_ref = std::make_shared(ngraph::NodeVector{mish}, ngraph::ParameterVector{data}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} -- 2.7.4