From 347e1206d541e01fc66a176a8fc7740cd7e637aa Mon Sep 17 00:00:00 2001 From: Evgenya Stepyreva Date: Wed, 28 Oct 2020 22:49:12 +0300 Subject: [PATCH] setBatchSize: getting rid of ConstantFolding (#2842) * setBatchSize: getting rid of setBatchSize * Trigger CI * Feedback adressed * Trigger CI * f -> specialized_function --- .../smart_reshape/mimic_set_batch_size.hpp | 20 +--- .../smart_reshape/mimic_set_batch_size.cpp | 132 ++++++--------------- .../smart_reshape/set_batch_size.cpp | 6 - .../smart_reshape/strided_slice_squeeze.cpp | 73 ++++++------ 4 files changed, 79 insertions(+), 152 deletions(-) diff --git a/inference-engine/src/transformations/include/transformations/smart_reshape/mimic_set_batch_size.hpp b/inference-engine/src/transformations/include/transformations/smart_reshape/mimic_set_batch_size.hpp index 18d043e..8a0edab 100644 --- a/inference-engine/src/transformations/include/transformations/smart_reshape/mimic_set_batch_size.hpp +++ b/inference-engine/src/transformations/include/transformations/smart_reshape/mimic_set_batch_size.hpp @@ -23,8 +23,6 @@ namespace ngraph { namespace pass { class TRANSFORMATIONS_API MimicSetBatchSize; -class TRANSFORMATIONS_API DisableCFForPriorBoxes; -class TRANSFORMATIONS_API EnableCFForPriorBoxes; } // namespace pass } // namespace ngraph @@ -41,23 +39,7 @@ class TRANSFORMATIONS_API EnableCFForPriorBoxes; * This transformation should be executed only while setBatchSize method call */ -class ngraph::pass::MimicSetBatchSize: public ngraph::pass::MatcherPass { -public: - NGRAPH_RTTI_DECLARATION; - MimicSetBatchSize(); -}; - -/** - * @ingroup ie_transformation_common_api - * @brief DisableCFForPriorBoxes and EnableCFForPriorBoxes transformations are needed to avoid unnecessary PriorBox folding - */ -class ngraph::pass::DisableCFForPriorBoxes: public ngraph::pass::FunctionPass { -public: - NGRAPH_RTTI_DECLARATION; - bool run_on_function(std::shared_ptr f) override; -}; - -class ngraph::pass::EnableCFForPriorBoxes: public ngraph::pass::FunctionPass { +class ngraph::pass::MimicSetBatchSize : public ngraph::pass::FunctionPass { public: NGRAPH_RTTI_DECLARATION; bool run_on_function(std::shared_ptr f) override; diff --git a/inference-engine/src/transformations/src/transformations/smart_reshape/mimic_set_batch_size.cpp b/inference-engine/src/transformations/src/transformations/smart_reshape/mimic_set_batch_size.cpp index 83605ac..3fdba37 100644 --- a/inference-engine/src/transformations/src/transformations/smart_reshape/mimic_set_batch_size.cpp +++ b/inference-engine/src/transformations/src/transformations/smart_reshape/mimic_set_batch_size.cpp @@ -2,35 +2,38 @@ // SPDX-License-Identifier: Apache-2.0 // +#include #include NGRAPH_RTTI_DEFINITION(ngraph::pass::MimicSetBatchSize, "MimicSetBatchSize", 0); -ngraph::pass::MimicSetBatchSize::MimicSetBatchSize() { - auto reshape_label = ngraph::pattern::wrap_type({pattern::any_input(pattern::has_static_dim(0)), - ngraph::pattern::wrap_type()}, - [](const Output &output) { return output.get_partial_shape().rank().is_static() && output.get_partial_shape().rank().get_length() > 1; }); - - matcher_pass_callback callback = [=](pattern::Matcher &m) -> bool { - const auto & reshape = m.get_match_root(); - auto pattern = std::dynamic_pointer_cast(reshape->get_input_node_shared_ptr(1)); - if (!pattern) - return false; - - const auto & pattern_vector = pattern->cast_vector(); - if (pattern_vector.empty() || pattern_vector[0] < 1) - return false; - - // mimicking old setBatchSize style (copied): - // float diff = static_cast(dims.at(0)) / static_cast(originalBatchSize); - // dims.at(0) = static_cast(std::ceil(size * diff)); - - const auto & old_input_batch = static_cast(reshape->get_input_partial_shape(0)[0].get_length()); - const auto & old_output_batch = static_cast(pattern_vector[0]); - - const auto & scale = old_output_batch / old_input_batch; +bool ngraph::pass::MimicSetBatchSize::run_on_function(std::shared_ptr f) { + // extracting ratio of out to in 0-index dimension value from the folded function + auto specialized_function = ngraph::clone_function(*f); + ngraph::pass::Manager manager; + manager.register_pass(); + manager.run_passes(specialized_function); + + std::map scale; + for (const auto & node : specialized_function->get_ops()) { + if (const auto & reshape = std::dynamic_pointer_cast(node)) { + const auto in_pshape = reshape->get_input_partial_shape(0), out_pshape = reshape->get_output_partial_shape(0); + if (in_pshape.rank().is_dynamic() || in_pshape.rank().get_length() <= 1 || in_pshape[0].is_dynamic() || + out_pshape.rank().is_dynamic() || out_pshape.rank().get_length() <= 1 || out_pshape[0].is_dynamic()) + continue; + const auto & pattern = std::dynamic_pointer_cast(reshape->get_input_node_shared_ptr(1)); + if (pattern && pattern->cast_vector()[0] > 0) { + scale[reshape->get_friendly_name()] = static_cast(out_pshape[0].get_length()) / static_cast(in_pshape[0].get_length()); + } + } + } + // apply transformation to original function + bool transformed = false; + for (auto & reshape : f->get_ops()) { + if (!is_type(reshape) || !scale.count(reshape->get_friendly_name()) || reshape->get_output_partial_shape(0).rank().is_dynamic()) + continue; - const auto & shape_of = std::make_shared(reshape->get_input_source_output(0), pattern->get_element_type()); + const auto & shape_of = std::make_shared(reshape->get_input_source_output(0), reshape->get_input_element_type(1)); const auto & new_input_batch = std::make_shared( shape_of, ngraph::opset5::Constant::create(ngraph::element::i64, {1}, std::vector{0}), ngraph::opset5::Constant::create(ngraph::element::i64, {}, std::vector{0})); @@ -39,75 +42,18 @@ ngraph::pass::MimicSetBatchSize::MimicSetBatchSize() { std::make_shared( std::make_shared( std::make_shared(new_input_batch, element::f32), - opset5::Constant::create(element::f32, {1}, {scale}))), - pattern->get_element_type()); - - auto new_reshape_pattern = new_output_batch; - const auto rank = pattern_vector.size(); - if (rank > 1) { - std::vector non_batch_dims(rank - 1); - std::iota(non_batch_dims.begin(), non_batch_dims.end(), 1); - const auto & non_batch_dims_node = std::make_shared( - pattern, - ngraph::opset5::Constant::create(ngraph::element::i64, {non_batch_dims.size()}, non_batch_dims), - ngraph::opset5::Constant::create(ngraph::element::i64, {}, std::vector{0})); - new_reshape_pattern = std::make_shared(OutputVector{new_reshape_pattern, non_batch_dims_node}, 0); - } + opset5::Constant::create(element::f32, {1}, {scale[reshape->get_friendly_name()]}))), + reshape->get_input_element_type(1)); + + std::vector non_batch_dims(reshape->get_output_partial_shape(0).rank().get_length() - 1); + std::iota(non_batch_dims.begin(), non_batch_dims.end(), 1); + const auto & non_batch_dims_node = std::make_shared( + reshape->input_value(1), + ngraph::opset5::Constant::create(ngraph::element::i64, {non_batch_dims.size()}, non_batch_dims), + ngraph::opset5::Constant::create(ngraph::element::i64, {}, std::vector{0})); + auto new_reshape_pattern = std::make_shared(OutputVector{new_output_batch, non_batch_dims_node}, 0); reshape->input(1).replace_source_output(new_reshape_pattern->output(0)); - return true; - }; - auto m = std::make_shared(reshape_label, "MimicSetBatchSize"); - register_matcher(m, callback); -} - - -void set_folding_for_PriorBox(std::shared_ptr prior_box, bool flag) { - std::string rt_info_disable_cf = "DISABLED_CONSTANT_FOLDING"; - static std::unordered_set allowed_to_skip = { - ngraph::opset1::Convert::type_info, - ngraph::opset1::StridedSlice::type_info, - }; - static std::unordered_set types_to_find = { - ngraph::opset1::ShapeOf::type_info, - ngraph::opset3::ShapeOf::type_info, - }; - - std::deque> nodes; - nodes.push_back(prior_box->get_input_node_shared_ptr(0)); - nodes.push_back(prior_box->get_input_node_shared_ptr(1)); - - while (!nodes.empty()) { - auto curr_node = nodes.front(); - nodes.pop_front(); - if (allowed_to_skip.count(curr_node->get_type_info())) { - nodes.push_back(curr_node->get_input_node_shared_ptr(0)); - } else if (types_to_find.count(curr_node->get_type_info())) { - auto& rt_info = curr_node->get_rt_info(); - if (flag && rt_info.count(rt_info_disable_cf)) - rt_info.erase(rt_info_disable_cf); - if (!flag) - rt_info[rt_info_disable_cf]; - } + transformed = true; } + return transformed; } - -NGRAPH_RTTI_DEFINITION(ngraph::pass::DisableCFForPriorBoxes, "DisableCFForPriorBoxes", 0); - -bool ngraph::pass::DisableCFForPriorBoxes::run_on_function(std::shared_ptr f) { - for (const auto & node : f->get_ops()) - if (ngraph::is_type(node) || ngraph::is_type(node)) { - set_folding_for_PriorBox(node, false); - } - return false; -} - -NGRAPH_RTTI_DEFINITION(ngraph::pass::EnableCFForPriorBoxes, "EnableCFForPriorBoxes", 0); - -bool ngraph::pass::EnableCFForPriorBoxes::run_on_function(std::shared_ptr f) { - for (const auto & node : f->get_ops()) - if (ngraph::is_type(node) || ngraph::is_type(node)) { - set_folding_for_PriorBox(node, true); - } - return false; -} - diff --git a/inference-engine/src/transformations/src/transformations/smart_reshape/set_batch_size.cpp b/inference-engine/src/transformations/src/transformations/smart_reshape/set_batch_size.cpp index b048840..3d67cac 100644 --- a/inference-engine/src/transformations/src/transformations/smart_reshape/set_batch_size.cpp +++ b/inference-engine/src/transformations/src/transformations/smart_reshape/set_batch_size.cpp @@ -22,15 +22,9 @@ bool ngraph::pass::SetBatchSize::run_on_function(std::shared_ptr(); - - manager.register_pass(); - manager.register_pass(); - manager.register_pass(); manager.register_pass(); manager.register_pass(); manager.register_pass(); - manager.register_pass(); - manager.register_pass(); manager.run_passes(f); return true; diff --git a/inference-engine/src/transformations/src/transformations/smart_reshape/strided_slice_squeeze.cpp b/inference-engine/src/transformations/src/transformations/smart_reshape/strided_slice_squeeze.cpp index 6cdbf23..cdddd6e 100644 --- a/inference-engine/src/transformations/src/transformations/smart_reshape/strided_slice_squeeze.cpp +++ b/inference-engine/src/transformations/src/transformations/smart_reshape/strided_slice_squeeze.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 // -#include +#include #include #include @@ -10,9 +10,8 @@ #include #include #include -#include -NGRAPH_RTTI_DEFINITION(ngraph::pass::StridedSliceSqueeze, "StridedSliceSqueeze", 0); +NGRAPH_RTTI_DEFINITION(ngraph::pass::StridedSliceSqueeze, "ngraph::pass::StridedSliceSqueeze", 0); ngraph::pass::StridedSliceSqueeze::StridedSliceSqueeze() { auto ss_label = ngraph::pattern::wrap_type(pattern::consumers_count(1)); @@ -21,17 +20,10 @@ ngraph::pass::StridedSliceSqueeze::StridedSliceSqueeze() { matcher_pass_callback callback = [](pattern::Matcher &m) -> bool { const auto & squeeze = m.get_match_root(); const auto & const_axes = std::dynamic_pointer_cast(squeeze->get_input_node_shared_ptr(1)); - auto slice = std::dynamic_pointer_cast(squeeze->get_input_node_shared_ptr(0)); if (!const_axes || !slice) return false; - const auto & slice_plan = get_slice_plan(slice); - if (slice_plan.begins.empty() || slice_plan.reshape_in_shape != slice_plan.reshape_out_shape || !slice_plan.reverse_axes.empty()) - return false; - - const auto & axes = normalize_axes(squeeze->description(), const_axes->cast_vector(), squeeze->get_input_partial_shape(0).rank()); - auto begin = std::dynamic_pointer_cast(slice->input_value(1).get_node_shared_ptr()); auto end = std::dynamic_pointer_cast(slice->input_value(2).get_node_shared_ptr()); auto strides = std::dynamic_pointer_cast(slice->input_value(3).get_node_shared_ptr()); @@ -47,17 +39,28 @@ ngraph::pass::StridedSliceSqueeze::StridedSliceSqueeze() { auto shrink_axis_mask = slice->get_shrink_axis_mask().empty() ? std::vector(begin_mask.size(), 0) : slice->get_shrink_axis_mask(); auto ellipsis_mask = slice->get_ellipsis_mask().empty() ? std::vector(begin_mask.size(), 0) : slice->get_ellipsis_mask(); + auto is_zero_vec = [](const std::vector & mask){ return std::all_of(mask.begin(), mask.end(), [](const int64_t& i){ return i == 0; }); }; + if (!is_zero_vec(new_axis_mask) || !is_zero_vec(shrink_axis_mask) || !is_zero_vec(ellipsis_mask)) + return false; + if (!std::all_of(strides_vec.begin(), strides_vec.end(), [](const int64_t& i){ return i == 1; })) + return false; + + const auto & axes = normalize_axes(squeeze->description(), const_axes->cast_vector(), squeeze->get_input_partial_shape(0).rank()); for (const auto & axis : axes) { - if ((slice_plan.ends[axis] - slice_plan.begins[axis]) != 1 && slice_plan.strides[axis] == 1) - return false; - begin_vec[axis] = slice_plan.begins[axis]; - end_vec[axis] = slice_plan.ends[axis]; - strides_vec[axis] = 1; - begin_mask[axis] = 0; - end_mask[axis] = 0; - new_axis_mask[axis] = 0; + if (begin_mask[axis]) { // corresponding dimension of the begin input is ignored. starting from 0 + begin_vec[axis] = 0; + end_vec[axis] = 1; + begin_mask[axis] = 0; + end_mask[axis] = 0; + } else { // corresponding dimension of the begin input is used for slicing start + if (begin_vec[axis] == -1) { // slicing the latest slice + end_mask[axis] = 1; + } else { + end_vec[axis] = begin_vec[axis] + 1; + end_mask[axis] = 0; + } + } shrink_axis_mask[axis] = 1; - ellipsis_mask[axis] = 0; } auto new_slice = std::make_shared( @@ -72,10 +75,10 @@ ngraph::pass::StridedSliceSqueeze::StridedSliceSqueeze() { copy_runtime_info(slice, new_slice); return true; }; - auto m = std::make_shared(squeeze_label, "StridedSliceSqueeze"); + auto m = std::make_shared(squeeze_label, "ngraph::pass::StridedSliceSqueeze"); register_matcher(m, callback); } -NGRAPH_RTTI_DEFINITION(ngraph::pass::SqueezeStridedSlice, "SqueezeStridedSlice", 0); +NGRAPH_RTTI_DEFINITION(ngraph::pass::SqueezeStridedSlice, "ngraph::pass::SqueezeStridedSlice", 0); ngraph::pass::SqueezeStridedSlice::SqueezeStridedSlice() { auto squeeze_label = ngraph::pattern::wrap_type( @@ -89,12 +92,6 @@ ngraph::pass::SqueezeStridedSlice::SqueezeStridedSlice() { if (!const_axes || !slice) return false; - const auto & slice_plan = get_slice_plan(slice); - if (slice_plan.begins.empty() || slice_plan.reshape_in_shape != slice_plan.reshape_out_shape || !slice_plan.reverse_axes.empty()) - return false; - - auto axes = normalize_axes(squeeze->description(), const_axes->cast_vector(), squeeze->get_input_partial_shape(0).rank()); - std::sort(axes.begin(), axes.end()); auto begin = std::dynamic_pointer_cast(slice->input_value(1).get_node_shared_ptr()); auto end = std::dynamic_pointer_cast(slice->input_value(2).get_node_shared_ptr()); auto strides = std::dynamic_pointer_cast(slice->input_value(3).get_node_shared_ptr()); @@ -110,6 +107,14 @@ ngraph::pass::SqueezeStridedSlice::SqueezeStridedSlice() { auto shrink_axis_mask = slice->get_shrink_axis_mask().empty() ? std::vector(begin_mask.size(), 0) : slice->get_shrink_axis_mask(); auto ellipsis_mask = slice->get_ellipsis_mask().empty() ? std::vector(begin_mask.size(), 0) : slice->get_ellipsis_mask(); + auto is_zero_vec = [](const std::vector & mask){ return std::all_of(mask.begin(), mask.end(), [](const int64_t& i){ return i == 0; }); }; + if (!is_zero_vec(new_axis_mask) || !is_zero_vec(shrink_axis_mask) || !is_zero_vec(ellipsis_mask)) + return false; + if (!std::all_of(strides_vec.begin(), strides_vec.end(), [](const int64_t& i){ return i == 1; })) + return false; + + auto axes = normalize_axes(squeeze->description(), const_axes->cast_vector(), squeeze->get_input_partial_shape(0).rank()); + std::sort(axes.begin(), axes.end()); for (const auto & axis : axes) { begin_vec.insert(begin_vec.begin() + axis, 0); end_vec.insert(end_vec.begin() + axis, 1); @@ -133,13 +138,13 @@ ngraph::pass::SqueezeStridedSlice::SqueezeStridedSlice() { copy_runtime_info(slice, new_slice); return true; }; - auto m = std::make_shared(ss_label, "SqueezeStridedSlice"); + auto m = std::make_shared(ss_label, "ngraph::pass::SqueezeStridedSlice"); register_matcher(m, callback); } -NGRAPH_RTTI_DEFINITION(ngraph::pass::SharedSqueeze, "SharedSqueeze", 0); +NGRAPH_RTTI_DEFINITION(ngraph::pass::SharedSqueeze, "ngraph::pass::SharedSqueeze", 0); -bool squeezes_perform_the_same(std::shared_ptr lhs, std::shared_ptr rhs) { +bool squeezes_perform_the_same(std::shared_ptr lhs, std::shared_ptr rhs) { size_t l_input_size = lhs->inputs().size(), r_input_size = rhs->inputs().size(); if (l_input_size != r_input_size) return false; @@ -148,8 +153,8 @@ bool squeezes_perform_the_same(std::shared_ptr lhs, std const auto rank = lhs->get_input_partial_shape(0).rank(); if (rank.is_dynamic()) return false; - const auto l_axes = std::dynamic_pointer_cast(lhs->get_input_node_shared_ptr(1)); - const auto r_axes = std::dynamic_pointer_cast(rhs->get_input_node_shared_ptr(1)); + const auto l_axes = std::dynamic_pointer_cast(lhs->get_input_node_shared_ptr(1)); + const auto r_axes = std::dynamic_pointer_cast(rhs->get_input_node_shared_ptr(1)); if (l_axes && r_axes) return normalize_axes(lhs->description(), l_axes->cast_vector(), rank) == normalize_axes(rhs->description(), r_axes->cast_vector(), rank); @@ -161,7 +166,7 @@ bool ngraph::pass::SharedSqueeze::run_on_function(std::shared_ptr, std::vector>> source_to_squeeze; + std::map, std::vector>> source_to_squeeze; for (const auto & node : f->get_ordered_ops()) { // Recursively apply transformation for sub-graph based operations if (auto sub_graph_node = std::dynamic_pointer_cast(node)) { @@ -169,7 +174,7 @@ bool ngraph::pass::SharedSqueeze::run_on_function(std::shared_ptr(node)) { + if (auto squeeze = std::dynamic_pointer_cast(node)) { source_to_squeeze[squeeze->input_value(0)].push_back(squeeze); } } -- 2.7.4