From a2f0eef6aa5dc94bb9a26425fbe57eff6de26f18 Mon Sep 17 00:00:00 2001 From: Alexandra Sidorova Date: Tue, 25 Aug 2020 10:19:06 +0300 Subject: [PATCH] [CPU] Added H-Swish activation (#1445) --- .../src/mkldnn_plugin/mkldnn_graph_optimizer.cpp | 12 ++-- inference-engine/src/mkldnn_plugin/mkldnn_node.cpp | 1 + .../src/mkldnn_plugin/mkldnn_plugin.cpp | 1 + .../mkldnn_plugin/nodes/mkldnn_activation_node.cpp | 5 ++ .../include/transformations/hswish_fusion.hpp | 11 ++++ .../src/transformations/hswish_fusion.cpp | 43 ++++++++++++++ .../transformations/hswish_fusion_test.cpp | 67 ++++++++++++++++++++++ inference-engine/thirdparty/mkl-dnn | 2 +- 8 files changed, 135 insertions(+), 7 deletions(-) diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp index d768902..31f0145 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp +++ b/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp @@ -705,7 +705,7 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndActivation(MKLDNNGraph &graph) { (activationNode->getAlgorithm() == eltwise_relu || (conv->getCnnLayer()->precision == Precision::FP32 && isOneOf(activationNode->getAlgorithm(), {eltwise_elu, eltwise_logistic, eltwise_bounded_relu, eltwise_clamp, - eltwise_swish, eltwise_mish}))); + eltwise_swish, eltwise_hswish, eltwise_mish}))); }; for (int i = 0; i < graphNodes.size(); i++) { @@ -1188,7 +1188,7 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndSimpleOperation(MKLDNNGraph &graph) THROW_IE_EXCEPTION << "Cannot get activation layer " << node->getName(); return isOneOf(activationNode->getAlgorithm(), {eltwise_relu, eltwise_elu, eltwise_logistic, eltwise_bounded_relu, - eltwise_clamp, eltwise_swish, eltwise_mish}); + eltwise_clamp, eltwise_swish, eltwise_hswish, eltwise_mish}); } return false; @@ -1433,7 +1433,7 @@ void MKLDNNGraphOptimizer::FuseConvolutionSumAndConvolutionSumActivation(MKLDNNG (activationNode->getAlgorithm() == eltwise_relu || (conv->getCnnLayer()->precision == Precision::FP32 && isOneOf(activationNode->getAlgorithm(), {eltwise_elu, eltwise_logistic, eltwise_bounded_relu, eltwise_clamp, - eltwise_swish, eltwise_mish}))); + eltwise_swish, eltwise_hswish, eltwise_mish}))); #else return false; #endif @@ -1783,8 +1783,8 @@ void MKLDNNGraphOptimizer::FuseNormalizeAndSimpleOperation(MKLDNNGraph &graph) { if (activationNode == nullptr) THROW_IE_EXCEPTION << "Cannot get activation layer " << node->getName(); return isOneOf(activationNode->getAlgorithm(), {eltwise_relu, eltwise_gelu, eltwise_elu, eltwise_logistic, - eltwise_bounded_relu, eltwise_clamp, eltwise_tanh, eltwise_swish, eltwise_mish, eltwise_linear, eltwise_abs, - eltwise_square, eltwise_sqrt}); + eltwise_bounded_relu, eltwise_clamp, eltwise_tanh, eltwise_swish, eltwise_hswish, eltwise_mish, eltwise_linear, + eltwise_abs, eltwise_square, eltwise_sqrt}); } return false; }; @@ -1895,7 +1895,7 @@ void MKLDNNGraphOptimizer::FuseEltwiseAndSimple(MKLDNNGraph &graph) { if (activationNode == nullptr) THROW_IE_EXCEPTION << "Cannot get activation layer " << node->getName(); return isOneOf(activationNode->getAlgorithm(), {eltwise_relu, eltwise_elu, eltwise_logistic, eltwise_bounded_relu, - eltwise_clamp, eltwise_swish, eltwise_mish}); + eltwise_clamp, eltwise_swish, eltwise_hswish, eltwise_mish}); } return false; diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp index 12d225f..bf1ec81 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp +++ b/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp @@ -75,6 +75,7 @@ static const InferenceEngine::details::caseless_unordered_map { "Activation", Activation }, { "Clamp", Activation }, { "Swish", Activation }, + { "HSwish", Activation }, { "Mish", Activation }, { "ScaleShift", Depthwise }, { "PReLU", Depthwise }, diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp index 707058d..0d36510 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp +++ b/inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp @@ -81,6 +81,7 @@ static void Transformation(ICNNNetwork::Ptr& clonedNetwork) { std::dynamic_pointer_cast(node) || std::dynamic_pointer_cast(node) || std::dynamic_pointer_cast(node) || + std::dynamic_pointer_cast(node) || std::dynamic_pointer_cast(node) || std::dynamic_pointer_cast(node) || std::dynamic_pointer_cast(node); diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_activation_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_activation_node.cpp index be4d81a..8727b6e 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_activation_node.cpp +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_activation_node.cpp @@ -96,6 +96,11 @@ caseless_map(); add_matcher(); add_matcher(); + add_matcher(); } }; @@ -61,3 +63,12 @@ public: public: HSwishFusionWithoutRelu(); }; + +/** + * @ingroup ie_transformation_common_api + * @brief HSwishFusion transformation replaces a sub-graph x * (Clamp(x + 3, 0, 6) * const(1/6)) with a HSwish op. + */ +class ngraph::pass::HSwishFusionWithClamp: public ngraph::pass::MatcherPass { +public: + HSwishFusionWithClamp(); +}; diff --git a/inference-engine/src/transformations/src/transformations/hswish_fusion.cpp b/inference-engine/src/transformations/src/transformations/hswish_fusion.cpp index 60af25e..bedc4f7 100644 --- a/inference-engine/src/transformations/src/transformations/hswish_fusion.cpp +++ b/inference-engine/src/transformations/src/transformations/hswish_fusion.cpp @@ -178,3 +178,46 @@ ngraph::pass::HSwishFusionWithoutRelu::HSwishFusionWithoutRelu() { auto m = std::make_shared(mul, "HSwishWithoutReluFusion"); register_matcher(m, callback); } + +ngraph::pass::HSwishFusionWithClamp::HSwishFusionWithClamp() { + // Replaces a sub-graph x * (Clamp(x + 3, 0, 6) * const(1/6)) with a HSwish op. + auto input = ngraph::pattern::any_input(); + auto add_constant = ngraph::pattern::wrap_type(); + auto add = std::make_shared(input, add_constant); + auto clamp = std::make_shared(add, 0.0f, 6.0f); + auto mul_constant = ngraph::pattern::wrap_type(); + auto mul_first = std::make_shared(clamp, mul_constant); + auto mul_second = std::make_shared(input, mul_first); + + ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) { + auto &pattern_to_output = m.get_pattern_value_map(); + auto x_output = pattern_to_output.at(input); + + auto add_const_value = std::dynamic_pointer_cast(pattern_to_output.at(add_constant).get_node_shared_ptr()); + auto mul_const_value = std::dynamic_pointer_cast(pattern_to_output.at(mul_constant).get_node_shared_ptr()); + + bool valid_constant_values = check_constant_value(add_const_value, 3.0) + && check_constant_value(mul_const_value, (1.0/6.0), 0.0001); + + if (!valid_constant_values) { + return false; + } + + auto hswish = std::make_shared(x_output); + + hswish->set_friendly_name(m.get_match_root()->get_friendly_name()); + ngraph::copy_runtime_info({ pattern_to_output.at(add_constant).get_node_shared_ptr(), + pattern_to_output.at(add).get_node_shared_ptr(), + pattern_to_output.at(clamp).get_node_shared_ptr(), + pattern_to_output.at(mul_constant).get_node_shared_ptr(), + pattern_to_output.at(mul_first).get_node_shared_ptr(), + pattern_to_output.at(mul_second).get_node_shared_ptr() + }, + hswish); + ngraph::replace_node(m.get_match_root(), hswish); + return true; + }; + + auto m = std::make_shared(mul_second, "HSwishWithClampFusion"); + register_matcher(m, callback); +} diff --git a/inference-engine/tests/functional/inference_engine/transformations/hswish_fusion_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/hswish_fusion_test.cpp index abe5053..fb34a24 100644 --- a/inference-engine/tests/functional/inference_engine/transformations/hswish_fusion_test.cpp +++ b/inference-engine/tests/functional/inference_engine/transformations/hswish_fusion_test.cpp @@ -151,6 +151,37 @@ TEST(TransformationTests, HSwishFusionWithoutRelu) { ASSERT_TRUE(res.first) << res.second; } +TEST(TransformationTests, HSwishFusionWithClamp) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input = std::make_shared(ngraph::element::f16, ngraph::PartialShape::dynamic(1)); + auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0}); + auto add = std::make_shared(input, add_constant); + auto clamp = std::make_shared(add, 0.0f, 6.0f); + auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {1.0 / 6.0}); + auto mul_first = std::make_shared(clamp, mul_constant); + auto mul_second = std::make_shared(input, mul_first); + + f = std::make_shared(ngraph::NodeVector{mul_second}, ngraph::ParameterVector{input}); + + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto input = std::make_shared(ngraph::element::f16, ngraph::PartialShape::dynamic(1)); + auto hswish = std::make_shared(input); + + f_ref = std::make_shared(ngraph::NodeVector{hswish}, ngraph::ParameterVector{input}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + TEST(TransformationTests, HSwishFusionWithReluMulWrongConstValue) { std::shared_ptr f(nullptr), f_ref(nullptr); { @@ -272,3 +303,39 @@ TEST(TransformationTests, HSwishFusionWithoutReluWrongConstValue) { auto res = compare_functions(f, f_ref); ASSERT_TRUE(res.first) << res.second; } + +TEST(TransformationTests, HSwishFusionWithClampWrongConstValue) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input = std::make_shared(ngraph::element::f16, ngraph::PartialShape::dynamic(1)); + auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.11}); + auto add = std::make_shared(input, add_constant); + auto clamp = std::make_shared(add, 0.11f, 6.02f); + auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.98 / 6.15}); + auto mul_first = std::make_shared(clamp, mul_constant); + auto mul_second = std::make_shared(input, mul_first); + + f = std::make_shared(ngraph::NodeVector{mul_second}, ngraph::ParameterVector{input}); + + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto input = std::make_shared(ngraph::element::f16, ngraph::PartialShape::dynamic(1)); + auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.11}); + auto add = std::make_shared(input, add_constant); + auto clamp = std::make_shared(add, 0.11f, 6.02f); + auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.98 / 6.15}); + auto mul_first = std::make_shared(clamp, mul_constant); + auto mul_second = std::make_shared(input, mul_first); + + f_ref = std::make_shared(ngraph::NodeVector{mul_second}, ngraph::ParameterVector{input}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} diff --git a/inference-engine/thirdparty/mkl-dnn b/inference-engine/thirdparty/mkl-dnn index 1f967a0..eb54063 160000 --- a/inference-engine/thirdparty/mkl-dnn +++ b/inference-engine/thirdparty/mkl-dnn @@ -1 +1 @@ -Subproject commit 1f967a094353b30d65d96a3fe1721d8dccf02278 +Subproject commit eb54063189a33a10c4aa90311788e6fbb4cdf2f6 -- 2.7.4