nGraph Transformations refactoring (#931)
authorGleb Kazantaev <gleb.kazantaev@intel.com>
Mon, 27 Jul 2020 16:47:37 +0000 (19:47 +0300)
committerGitHub <noreply@github.com>
Mon, 27 Jul 2020 16:47:37 +0000 (19:47 +0300)
This PR introduces next changes:
1. Transformations *_tbl.hpp files were replaced with direct registration in cpp files.
2. Plugins use pass::Manager to call conversion passes.
3. Transformations callback was moved to PassBase class as there is no more need to keep it in separate class
4. All pattern based transformations must be inherited from MatcherPass class. GraphRewrite class will be used only for matchers registration and execution on function.
MatcherPass class adds new features to pattern-based transformations approach:
* Allows to run matcher pass on a single node.
* Operations that were created inside transformation callback can be added to execution list to be available for pattern matching within single GraphRewrite.
5. GraphRewrite MatchClosure was replaced with MatcherPass. So all matchers will be registered as a MatcherPass.
6. Added pass::Manager::clear_state() method to avoid dependency with nodes that no longer belongs to function after replacement.
7.  Some representative transformations were updated to use MatcherPass as an example.
8.  Mul->Add sequence fusion transformation was replaced with LinOpSequenceFusion.
9. Pattern and callback registration code was moved to class c-tors (will be finished for remaining passes in other PR) .
10. Updated pass::Manager to get pass names only when NGRAPH_PROFILE_PASS_ENABLE enabled.
11. Moving towards removing PassProperty.
12. Added ngraph::pattern::wrap_type<T>(inputs, pred) to simplify pattern creation.
13. GraphRewrite was updated to execute MatcherPass more efficient.

154 files changed:
docs/IE_PLUGIN_DG/NewTransformation.md
docs/examples/example_ngraph_utils.cpp
docs/template_plugin/src/template_executable_network.cpp
docs/template_plugin/src/template_function_transformation.cpp
docs/template_plugin/src/template_function_transformation.hpp
docs/template_plugin/src/template_pattern_transformation.cpp
docs/template_plugin/src/template_pattern_transformation.hpp
inference-engine/src/cldnn_engine/cldnn_engine.cpp
inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp
inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp
inference-engine/src/transformations/include/transformations/batch_norm_decomposition.hpp
inference-engine/src/transformations/include/transformations/common_optimizations/common_optimizations.hpp
inference-engine/src/transformations/include/transformations/constant_eltwise_reduction.hpp [deleted file]
inference-engine/src/transformations/include/transformations/convert_batch_to_space.hpp
inference-engine/src/transformations/include/transformations/convert_broadcast_to_tiles.hpp
inference-engine/src/transformations/include/transformations/convert_depth_to_space.hpp
inference-engine/src/transformations/include/transformations/convert_divide.hpp
inference-engine/src/transformations/include/transformations/convert_gelu.hpp
inference-engine/src/transformations/include/transformations/convert_minimum_to_power_and_max.hpp
inference-engine/src/transformations/include/transformations/convert_mod.hpp
inference-engine/src/transformations/include/transformations/convert_negative.hpp
inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/conv_bias_fusion.hpp
inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_cells_to_cells_ie.hpp
inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_convolutions.hpp
inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_gather_to_gather_ie.hpp
inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_gathertree_to_gathertree_ie.hpp
inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_hard_sigmoid_to_hard_sigmoid_ie.hpp
inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_interpolate_to_interp_or_resample.hpp
inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_lrn_to_lrn_ie.hpp
inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_matmul_to_fc_or_gemm.hpp
inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_nms_4_to_legacy.hpp
inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_nms_to_nms_ie.hpp
inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_normalizel2_to_normalize_ie.hpp
inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_one_hot_to_one_hot_ie.hpp
inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.hpp
inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_opset1_to_legacy_tbl.hpp [deleted file]
inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_pad_to_pad_ie.hpp
inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_power_to_power_ie.hpp
inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_prelu_to_relu_ie.hpp
inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_proposal_to_proposal_ie.hpp
inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_selu_to_selu_ie.hpp
inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_sqrt_to_power_ie.hpp
inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_strided_slice_to_crop.hpp
inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_tile_to_ie_tile.hpp
inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_topk_to_topk_ie.hpp
inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/fc_bias_fusion.hpp
inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/reshape_1d_ops.hpp
inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/reshape_fully_connected.hpp
inference-engine/src/transformations/include/transformations/convert_opset2_to_opset1/convert_opset2_to_opset1.hpp
inference-engine/src/transformations/include/transformations/convert_opset2_to_opset1/convert_opset2_to_opset1_tbl.hpp [deleted file]
inference-engine/src/transformations/include/transformations/convert_opset3_to_opset2/convert_broadcast3.hpp
inference-engine/src/transformations/include/transformations/convert_opset3_to_opset2/convert_nms3.hpp
inference-engine/src/transformations/include/transformations/convert_opset3_to_opset2/convert_opset3_to_opset2.hpp
inference-engine/src/transformations/include/transformations/convert_opset3_to_opset2/convert_opset3_to_opset2_tbl.hpp [deleted file]
inference-engine/src/transformations/include/transformations/convert_opset3_to_opset2/convert_shuffle_channels3.hpp
inference-engine/src/transformations/include/transformations/convert_opset3_to_opset2/convert_topk3.hpp
inference-engine/src/transformations/include/transformations/convert_reduce_to_pooling.hpp
inference-engine/src/transformations/include/transformations/convert_space_to_batch.hpp
inference-engine/src/transformations/include/transformations/convert_space_to_depth.hpp
inference-engine/src/transformations/include/transformations/convert_subtract.hpp
inference-engine/src/transformations/include/transformations/depth_to_space_fusion.hpp
inference-engine/src/transformations/include/transformations/lin_op_sequence_fusoin.hpp [new file with mode: 0644]
inference-engine/src/transformations/include/transformations/pull_transpose_through_fq.hpp
inference-engine/src/transformations/include/transformations/utils/pass_param.hpp [deleted file]
inference-engine/src/transformations/src/transformations/batch_norm_decomposition.cpp
inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp
inference-engine/src/transformations/src/transformations/constant_eltwise_reduction.cpp [deleted file]
inference-engine/src/transformations/src/transformations/convert_batch_to_space.cpp
inference-engine/src/transformations/src/transformations/convert_broadcast_to_tiles.cpp
inference-engine/src/transformations/src/transformations/convert_depth_to_space.cpp
inference-engine/src/transformations/src/transformations/convert_divide.cpp
inference-engine/src/transformations/src/transformations/convert_minimum_to_power_and_max.cpp
inference-engine/src/transformations/src/transformations/convert_mod.cpp
inference-engine/src/transformations/src/transformations/convert_negative.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/conv_bias_fusion.cpp [new file with mode: 0644]
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_cells_to_cells_ie.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_convolutions.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_gather_to_gather_ie.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_gathertree_to_gathertree_ie.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_gelu.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_hard_sigmoid_to_hard_sigmoid_ie.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_interpolate_to_interp_or_resample.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_lrn_to_lrn_ie.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_matmul_to_fc_or_gemm.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_nms_4_to_legacy.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_nms_to_nms_ie.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_normalizel2_to_normalize_ie.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_one_hot_to_one_hot_ie.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_pad_to_pad_ie.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_power_to_power_ie.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_prelu_to_relu_ie.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_proposal_to_proposal_ie.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_selu_to_selu_ie.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_sqrt_to_power_ie.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_strided_slice_to_crop.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_tile_to_ie_tile.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_topk_to_topk_ie.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/fc_bias_fusion.cpp [new file with mode: 0644]
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/reshape_1d_ops.cpp
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/reshape_fully_connected.cpp
inference-engine/src/transformations/src/transformations/convert_opset2_to_opset1/convert_opset2_to_opset1.cpp
inference-engine/src/transformations/src/transformations/convert_opset3_to_opset2/convert_opset3_to_opset2.cpp
inference-engine/src/transformations/src/transformations/convert_opset3_to_opset2/convert_shuffle_channels3.cpp
inference-engine/src/transformations/src/transformations/convert_space_to_batch.cpp
inference-engine/src/transformations/src/transformations/convert_space_to_depth.cpp
inference-engine/src/transformations/src/transformations/convert_subtract.cpp
inference-engine/src/transformations/src/transformations/depth_to_space_fusion.cpp
inference-engine/src/transformations/src/transformations/lin_op_sequence_fusion.cpp [new file with mode: 0644]
inference-engine/src/transformations/src/transformations/pull_transpose_through_fq.cpp
inference-engine/src/vpu/graph_transformer/src/frontend/frontend.cpp
inference-engine/tests/functional/inference_engine/transformations/convert_cells_to_cells_ie_test.cpp
inference-engine/tests/functional/inference_engine/transformations/convert_divide.cpp
inference-engine/tests/functional/inference_engine/transformations/convert_gather_to_gather_ie.cpp
inference-engine/tests/functional/inference_engine/transformations/convert_matmul_test.cpp
inference-engine/tests/functional/inference_engine/transformations/convert_nms4_test.cpp
inference-engine/tests/functional/inference_engine/transformations/convert_nms_to_nms_ie_test.cpp
inference-engine/tests/functional/inference_engine/transformations/convert_strided_slice_to_crop_test.cpp
inference-engine/tests/functional/inference_engine/transformations/convert_topk_test.cpp
inference-engine/tests/functional/inference_engine/transformations/depth_to_space_fusion_test.cpp
inference-engine/tests/functional/inference_engine/transformations/fc_bias_fusion_test.cpp
inference-engine/tests/functional/inference_engine/transformations/lin_op_sequence_fusion_test.cpp [new file with mode: 0644]
inference-engine/tests/functional/inference_engine/transformations/ngraph_depth_to_space_transform_test.cpp
inference-engine/tests/functional/inference_engine/transformations/ngraph_fq_transpose_test.cpp
inference-engine/tests/functional/inference_engine/transformations/ngraph_mode_decomposition_test.cpp
inference-engine/tests/functional/plugin/cpu/shared_tests_instances/low_precision_transformations/layer_transformation.cpp
inference-engine/tests/functional/plugin/gpu/shared_tests_instances/low_precision_transformations/layer_transformation.cpp
inference-engine/tests/ie_test_utils/common_test_utils/ngraph_test_utils.hpp
inference-engine/tests/unit/inference_engine/CMakeLists.txt
ngraph/src/ngraph/CMakeLists.txt
ngraph/src/ngraph/node.cpp
ngraph/src/ngraph/pass/constant_folding.cpp
ngraph/src/ngraph/pass/constant_folding_split.cpp
ngraph/src/ngraph/pass/constant_folding_variadic_split.cpp
ngraph/src/ngraph/pass/graph_rewrite.cpp
ngraph/src/ngraph/pass/graph_rewrite.hpp
ngraph/src/ngraph/pass/manager.cpp
ngraph/src/ngraph/pass/manager.hpp
ngraph/src/ngraph/pass/pass.cpp
ngraph/src/ngraph/pass/pass.hpp
ngraph/src/ngraph/pass/visualize_tree.cpp
ngraph/src/ngraph/pattern/matcher.cpp
ngraph/src/ngraph/pattern/matcher.hpp
ngraph/src/ngraph/pattern/op/label.cpp
ngraph/src/ngraph/pattern/op/label.hpp
ngraph/src/ngraph/pattern/op/pattern.cpp
ngraph/src/ngraph/pattern/op/pattern.hpp
ngraph/src/ngraph/pattern/op/wrap_type.cpp [new file with mode: 0644]
ngraph/src/ngraph/pattern/op/wrap_type.hpp [new file with mode: 0644]
ngraph/test/CMakeLists.txt
ngraph/test/constant_folding.cpp
ngraph/test/graph_rewrite.cpp [new file with mode: 0644]
ngraph/test/matcher_pass.cpp [new file with mode: 0644]
ngraph/test/pattern.cpp

index 896150c..55074e0 100644 (file)
@@ -2,11 +2,15 @@
 
 This guide contains all necessary information that could help you to start writing nGraph transformations.
 
-First of all before writing transformation make sure that there is no transformation with the same functionality in [Transformation Library](group__ie__transformation__api.html).
-To start writing transformation it's good to know how [Transformation Library](group__ie__transformation__api.html) is structured, how transformations are organized and where to put your transformation code.
+First of all before writing transformation make sure that there is no transformation with the same functionality
+in [Transformation Library](group__ie__transformation__api.html). To start writing transformation it's good to know
+how [Transformation Library](group__ie__transformation__api.html) is structured, how transformations are organized
+and where to put your transformation code.
 
 Let's start from reviewing transformations library structure.
-Transformations library is independent from InferenceEngine target library named as `inference_engine_transformations` and located in `inference-engine/src/transformations` directory.
+Transformations library is independent from InferenceEngine target library named as `inference_engine_transformations`
+and located in `inference-engine/src/transformations` directory.
+
 Transformations root directory contains two folders:
 1. ngraph_ops - legacy opset operations needed for nGraph to CNNNetwork conversion.
 > **Note**: this operation are prohibited to use inside new plugins until they are not moved to separate directory with allowed operations.
@@ -14,9 +18,11 @@ Transformations root directory contains two folders:
 > **Note**: do not use transformation that belongs to `ngraph::pass::ConvertOpSet1ToLegacy` transformations until they are not moved to separate directory with allowed transformations.
 
 Transformation flow in transformation library has several layers:
-1. Pass managers - executes list of transformations using `*_tbl.hpp` file. For example conversion form OpSetX to OpSetY.
-2. Transformations - performs particular transformation algorithm on `ngraph::Function`. Find more about transformations in [Transformations types](#transformations_types).
-3. Low level functions that takes set of nodes and performs some transformation action. They are not mandatory and all transformation code can be located inside transformation. But if some transformation parts can potentially be reused in other transformations we suggest to keep them as a separate functions.
+1. Pass managers - executes any type of transformations and provides additional debug capabilities.
+2. Transformations - performs particular transformation algorithm on `ngraph::Function`.
+3. Low level functions that takes set of nodes and performs some transformation action. 
+They are not mandatory and all transformation code can be located inside transformation.
+But if some transformation parts can potentially be reused in other transformations we suggest to keep them as a separate functions.
 
 To decide where to store your transformation code please follow these rules:
 1. If it's plugin specific transformation and can't be reused by other plugins keep source code inside plugin.
@@ -26,16 +32,19 @@ After you decided where to store your transformation code you can start develop
 
 ## Table of Contents:
 
-1. [`ngraph::Function` and graph representation](#ngraph_function) 
-2. [Transformations types](#transformations_types)
-3. [Pattern matching](#pattern_matching)
-4. [Working with ngraph::Function](#working_with_ngraph_function)
-5. [Transformation writing essentials](#transformation_writing_essentials)
-6. [Common mistakes in transformations](#common_mistakes)
-7. [Using pass manager](#using_pass_manager)
-8. [How to debug transformations](#how_to_debug_transformations)
-9. [Disabling/Enabling specific transformations for plugin X](#disabling_transformation)
-10. [Transformations testing](#transformations_testing)
+### 1. [`ngraph::Function` and graph representation](#ngraph_function) 
+### 2. [Transformations types](#transformations_types)
+### 2.1 [Function pass](#function_pass)
+### 2.2 [Matcher pass](#matcher_pass)
+### 2.3 [GraphRewrite pass](#graph_rewrite_pass) 
+### 3. [Pattern matching](#pattern_matching)
+### 4. [Working with ngraph::Function](#working_with_ngraph_function)
+### 5. [Transformation writing essentials](#transformation_writing_essentials)
+### 6. [Common mistakes in transformations](#common_mistakes)
+### 7. [Using pass manager](#using_pass_manager)
+### 8. [How to debug transformations](#how_to_debug_transformations)
+### 9. [Disabling/Enabling specific transformations for plugin X](#disabling_transformation)
+### 10. [Transformations testing](#transformations_testing)
 
 ## ngraph::Function and graph representation <a name="ngraph_function"></a>
 
@@ -51,11 +60,12 @@ Below you can find examples how `ngraph::Function` can be created:
 
 ## Transformations types <a name="transformations_types"></a>
 
-There are two main transformation types:
+nGraph has tree main transformation types: `ngraph::pass::FunctionPass` - strait forward way to work with `ngraph::Function` directly;
+`ngraph::pass::MatcherPass` - pattern based transformation approach; `ngraph::pass::GraphRewrite` - container for matcher passes.
 
-###1. ngraph::pass::FunctionalPass
+###1. ngraph::pass::FunctionPass <a name="function_pass"></a>
 
-ngraph::pass::FunctionalPass is used for transformations that take entire `ngraph::Function` as input and process it.
+`ngraph::pass::FunctionPass` is used for transformations that take entire `ngraph::Function` as input and process it.
 
 Template for FunctionPass transformation class
 
@@ -63,65 +73,92 @@ Template for FunctionPass transformation class
 
 @snippet src/template_function_transformation.cpp function_pass:template_transformation_cpp
 
-Using `ngraph::FunctionPass` you need to override `run_on_function` method where you will write transformation code. Return value must be `true` if original function has changed during transformation (new operation were added or operations replacement was made or node attributes were changed) otherwise it must be `false`. For transformation API please follow [working with ngraph::Function](#working_with_ngraph_function) section.
+Using `ngraph::FunctionPass` you need to override `run_on_function` method where you will write transformation code.
+Return value must be `true` if original function has changed during transformation (new operation were added or operations replacement was made or node attributes were changed) otherwise it must be `false`.
+For transformation API please follow [working with ngraph::Function](#working_with_ngraph_function) section.
+Also `ngraph::FunctionPass` based transformations can be executed via `pass::Manager`. See examples in [Using pass manager](#using_pass_manager) section.
 
-###2. ngraph::pass::GraphRewrite
+###2. ngraph::pass::MatcherPass <a name="matcher_pass"></a>
 
-`ngraph::pass::GraphRewrite` is used for pattern based transformations.
+`ngraph::pass::MatcherPass` is used for pattern based transformations.
 
-Template for GraphRewrite transformation class
+Template for MatcherPass transformation class
 @snippet src/template_pattern_transformation.hpp graph_rewrite:template_transformation_hpp
 
 @snippet src/template_pattern_transformation.cpp graph_rewrite:template_transformation_cpp
 
-Using `ngraph::GraphRewrite` you need to complete three steps:
-1. Create pattern using nGraph operations.
-2. Implement callback. 
-3. Register pattern and Matcher.
+Using `ngraph::pass::MatcherPass` you need to complete these steps:
+1. Create pattern
+2. Implement callback 
+3. Register pattern and Matcher
+4. MatcherPass execution
 
 So let's go though each of this steps.
 
-Pattern is a single root `ngraph::Function`. But the only difference is that you don't need to create function object, you just create and connect nGraph operations then take the last one and put it as a root of the pattern.
+### Create pattern
+Pattern is a single root `ngraph::Function`. But the only difference is that you don't need to create function object, you just create and connect nGraph or special pattern operations.
+And then take the last created operation and put it as a root of the pattern. This root node will be used as a root node in pattern matching.
+> **Note**: any nodes in pattern that have no consumers and not registered as root won't be used in pattern matching. 
 
 @snippet example_ngraph_utils.cpp pattern:simple_example
 
-You may have noticed that `Parameter` operation in example has type and shape specified. These attributes are needed only to create Parameter operation class and not used in pattern matching. 
-But what if we want to match pattern where `ShapeOf` takes any operation as input? To find an answer to this question please follow [pattern matching](#pattern_matching) section.
+You may have noticed that `Parameter` operation in example has type and shape specified. These attributes are needed only to create Parameter operation class and won't be used in pattern matching. 
+
+But what if we want to match pattern where `ShapeOf` takes any operation as input? To find an answer please follow [pattern matching](#pattern_matching) section.
 
-What is callback? Callback is an action applied to every pattern entrance. In general callback is lambda function that takes Matcher object with detected sub-graph.
+### Implement callback
+Callback is an action applied to every pattern entrance. In general callback is lambda function that takes Matcher object with detected sub-graph.
 
 @snippet example_ngraph_utils.cpp pattern:callback_example
 
 Example above shows callback structure and how Matcher can be used for accessing nodes detected by pattern.
-Callback return value must be `true` if root node was replaced and next pattern can't be applied to the same root node otherwise it must be `false`.
+Callback return value must be `true` if root node was replaced and another pattern can't be applied to the same root node otherwise it must be `false`.
+> **Note**: it's not recommended to manipulate with nodes that are under root node. This may affect GraphRewrite execution as it's expected that all nodes that comes after root node in topological order are valid and can be used in pattern matching. 
+
+MatcherPass also provides functionality that allows to report which newly created nodes can be used in additional pattern matching.
+If MatcherPass was registered in `pass::Manager` or `pass::GraphRewrite` then this registered nodes will be added for additional pattern matching.
+That means that matcher passes registered in `pass::GraphRewrite` will be applied to this nodes.
+
+Example below shows how single MatcherPass can fuse sequence of operations using `register_new_node` method.
+
+@snippet src/template_pattern_transformation.cpp matcher_pass:relu_fusion
+
+> **Note**: if you register multiple nodes please add them in topological order. We do not topologically sort this nodes as it's time consuming operation.
 
-And the last step is to register Matcher and callback inside GraphRewrite pass. And to do this you need to call `add_matcher` method. 
+### Register pattern and Matcher
+The last step is to register Matcher and callback inside MatcherPass pass. And to do this you need to call `register_matcher` method.
+> **Note**: Only one matcher can be registered for single MatcherPass class.
 
 ```cpp
 // Register matcher and callback
-this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+this->register_matcher(m, callback);
 ```
+### Matcher pass execution
+MatcherPass has multiple ways to be executed:
+1. Run on a single node - it can be useful if you want to run MatcherPass inside another transformation.
+@snippet src/template_pattern_transformation.cpp matcher_pass:run_on_node
+2. Run on `ngraph::Function` using GraphRewrite - this approach gives ability to run MatcherPass on whole `ngraph::Functoin`. Moreover multiple MatcherPass transformation can be registered in a single GraphRewite to be executed in a single graph traversal.
+@snippet src/template_pattern_transformation.cpp matcher_pass:graph_rewrite
+3. Run on `ngraph::Function` using `pass::Manager` - this approach helps you to register MatcherPass for execution on `ngraph::Function` as another transformation types.
+@snippet src/template_pattern_transformation.cpp matcher_pass:manager
 
-Also you can have multiple matchers and callbacks and they can be registered in single Graphrewrite pass. In this case all registered patterns will be applied in a singe graph traversal. 
 
-```cpp
-// Multiple matchers example
-this->add_matcher(m1, callback1, PassProperty::CHANGE_DYNAMIC_STATE);
-this->add_matcher(m2, callback2, PassProperty::CHANGE_DYNAMIC_STATE);
-```
+###3. ngraph::pass::GraphRewrite <a name="graph_rewrite_pass"></a>
 
-The last argument `PassProperty::CHANGE_DYNAMIC_STATE` says that callback can be applied for ngraph::Function with dynamic shapes. In case if callback does not support dynamic shapes `PassProperty::REQUIRE_STATIC_SHAPE` can be used.
-> **Note**: property mechanism will be deprecated soon and PassProperty::CHANGE_DYNAMIC_STATE is suggested to be used by default.
+GraphRewrite pass serves for running multiple matcher passes on `ngraph::Function` in a single graph traversal. 
+Example:
+
+@snippet src/template_pattern_transformation.cpp matcher_pass:graph_rewrite
+
+In addition GraphRewrite handles nodes that were registered by MatcherPasses during their execution. This nodes will be added to the beginning of sequence with nodes for pattern matching.
+
+> **Note**: when using `pass::Manager` temporary GraphRewrite is used to execute single MatcherPass. 
 
-To run any transformation you need to call `un_on_function(f)` method where `f` is `ngraph::Function`.
-```cpp
-ngraph::pass::MyTransformationClass().run_on_function(f);
-``` 
-  
 ## Pattern matching <a name="pattern_matching"></a>
 
 Sometimes patterns can't be expressed via regular nGraph operations. For example if you want to detect Convolution->Add sub-graph without specifying particular input type for Convolution operation or you want to create pattern where some of operations can have different types.
 And for these cases nGraph provides additional helpers to construct patterns for GraphRewrite transformations. 
+
 There are two main helpers:
 1. `ngraph::pattern::op::Label` - helps to express inputs if their type is undefined.
 2. `ngraph::pattern::op::Any` - helps to express intermediate nodes of pattern if their type is unknown.
@@ -139,11 +176,11 @@ This example show how we can construct pattern when operation has arbitrary numb
 
 @snippet example_ngraph_utils.cpp pattern:concat_example
 
-This example shows how to use predicate to construct pattern where operation has two different types.
+This example shows how to use predicate to construct pattern where operation has two different types. Also it shows how to match pattern manually on given node.
 
 @snippet example_ngraph_utils.cpp pattern:predicate_example
 
-TODO: add examples for ngraph::pattern::op::Any
+> **Note**: be careful with manual matching because Matcher object holds matched nodes. To clear match use m->clear_state() method.
 
 ## Working with ngraph::Function <a name="working_with_ngraph_function"></a>
 
@@ -306,22 +343,21 @@ In transformation development process
 * If you replace node with another node that produce different shape you need to remember that new shape won't be propagated until first `validate_nodes_and_infer_types` call for `ngraph::Function`. If you are using `pass::Manager` it will automatically call this method after each transformation execution.
 * Do not forget to call `ngraph::ConstantFolding` pass if your transformation creates constant sub-graphs.
 * Use latest OpSet if you are not developing downgrade transformation pass.
+* When developing callback for `ngraph::pass::MatcherPass` do not change nodes that comes after root node in topological order. 
 
 ## Using pass manager <a name="using_pass_manager"></a>
 
 `ngraph::pass::Manager` is a container class that can store list of transformations and execute them. The main idea of this class is to have high-level representation for grouped list of transformations.
-For example `ngraph::pass::CommonOptimizations` pass manager register list of transformation related to common optimizations. Also `ngraph::pass::Manager` after each transformation executes `f->validate_nodes_and_infer_types()` that help to keep function synchronized.
+It can register and apply any [transformation types](#transformations_types) on function.
 In addition `ngraph::pass::Manager` has extended debug capabilities (find more information in [how to debug transformations](#how_to_debug_transformations) section). 
 
 Example below shows basic usage of `ngraph::pass::Manager`
-```cpp
-ngraph::pass::Manager pass_manager;
-pass_manager.register_pass<pass::MyTransformationA>();
-pass_manager.register_pass<pass::MyTransformationB>();
-pass_manager.run_passes(f);
-```
 
-TODO: Advanced pass manager usage.
+@snippet src/template_pattern_transformation.cpp matcher_pass:manager3
+
+Another example how multiple matcher passes can be united into single GraphRewrite.
+
+@snippet src/template_pattern_transformation.cpp matcher_pass:manager2
 
 ## How to debug transformations <a name="how_to_debug_transformations"></a>
 
@@ -353,26 +389,31 @@ NGRAPH_ENABLE_VISUALIZE_TRACING=1 -  enables visualization after each transforma
 
 This topic mostly related to conversion to legacy opset and plugins that based on CNNNetwork but still this mechanism can be applied for other cases.
 Let's suppose that plugin X enabled `opset3::StridedSlice` operation support and you want to disable `ngraph::pass::ConvertStridedSliceToCrop` transformation for plugin X.
-To do this you need to extend transformation class with `ngraph::pass::PassParam` class. This class extends transformations class with `transformation_callback` that can be set by plugin that uses legacy conversion. 
+To do this you need to create callback on plugin side and pass it to transformation. And also you need to update particular transformation to use this callback.  
 
 ```cpp
-// Extend transformation class with PassParam
-class ngraph::pass::ConvertStridedSliceToCrop: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam {
-    ...
-}
-
-// Update callback to be able to use transformation_callback if this transformation based on GraphRewrite.
+// Update callback to be able to use m_transformation_callback if this transformation based on GraphRewrite.
 ngraph::graph_rewrite_callback callback = [this](pattern::Matcher &m) {
     ...
 }
 
-// Use transformation_callback not to execute transformation
-if (transformation_callback(node)) {
+// Use transformation_callback not to execute transformation if callback returns true for given node
+if (m_transformation_callback(node)) {
     return false;
 }
-```
 
-TODO: link to existing example
+// Implement transformation callback and pass it directly to transformation or pass::Manager
+const auto transformations_callback = [](const std::shared_ptr<const ::ngraph::Node> &node) -> bool {
+    return std::dynamic_pointer_cast<const ::ngraph::opset3::StridedSlice>(node) != nullptr;
+};
+
+// Register transformation and pass callback to pass::Manager
+ngraph::pass::Manager manager;
+manager.register_pass<ngraph::pass::ConvertStridedSliceToCrop>();
+// pass::Manager will set callback to all reistered transformations automatically
+manager.set_callback(transformations_callback);
+manager.run_passes(f);
+```
 
 ## Transformations testing <a name="transformations_testing"></a>
 
@@ -384,7 +425,6 @@ The basic transformation test looks like this:
 
 @snippet tests/functional/transformations/template_transformations_test.cpp transformation:test
 
-TODO: insert advanced transformation tests
 
 [ngraph_replace_node]: ../images/ngraph_replace_node.png
 [ngraph_insert_node]: ../images/ngraph_insert_node.png
\ No newline at end of file
index 93d6b7b..e46b040 100644 (file)
@@ -60,7 +60,7 @@ std::shared_ptr<ngraph::Function> create_advanced_function() {
 }
 // ! [ngraph_utils:advanced_function]
 
-void pattern_matcher_examples() {
+void pattern_matcher_examples(std::shared_ptr<Node> node) {
 {
 // ! [pattern:simple_example]
 // Pattern example
@@ -110,9 +110,13 @@ auto m = std::make_shared<ngraph::pattern::Matcher>(concat, "ConcatMatcher");
 auto lin_op = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32, ngraph::Shape{},
               [](const std::shared_ptr<ngraph::Node> & node) -> bool {
                     return std::dynamic_pointer_cast<ngraph::opset3::Multiply>(node) ||
-                    std::dynamic_pointer_cast<ngraph::opset3::Add>(node);
+                           std::dynamic_pointer_cast<ngraph::opset3::Add>(node);
               });
 auto m = std::make_shared<ngraph::pattern::Matcher>(lin_op, "MultiplyOrAddMatcher");
+// Matcher can be used to match pattern manually on given node
+if (m->match(node->output(0))) {
+    // Successfully matched
+}
 // ! [pattern:predicate_example]
 }
 
@@ -232,15 +236,17 @@ bool success = replace_output_update_name(node->output(0), node->input_value(0))
 
 // ! [ngraph:visualize]
 void visualization_example(std::shared_ptr<ngraph::Function> f) {
-    std::vector<std::shared_ptr<ngraph::Function> > g{f};
+    ngraph::pass::Manager manager;
 
     // Serialize ngraph::Function to before.svg file before transformation
-    ngraph::pass::VisualizeTree("/path/to/file/before.svg").run_on_module(g);
+    manager.register_pass<ngraph::pass::VisualizeTree>("/path/to/file/before.svg");
 
     // Run your transformation
-    // ngraph::pass::MyTransformation().run_on_function();
+    // manager.register_pass<ngraph::pass::MyTransformation>();
 
     // Serialize ngraph::Function to after.svg file after transformation
-    ngraph::pass::VisualizeTree("/path/to/file/after.svg").run_on_module(g);
+    manager.register_pass<ngraph::pass::VisualizeTree>("/path/to/file/after.svg");
+
+    manager.run_passes(f);
 }
 // ! [ngraph:visualize]
index d0acc26..db3da88 100644 (file)
@@ -86,7 +86,8 @@ void TemplatePlugin::ExecutableNetwork::CompileGraph(const std::shared_ptr<const
     // Example: register CommonOptimizations transformation from transformations library
     passManager.register_pass<ngraph::pass::CommonOptimizations>();
     // Example: register plugin specific transformation
-    passManager.register_pass<ngraph::pass::MyPatternBasedTransformation>();
+    passManager.register_pass<ngraph::pass::DecomposeDivideMatcher>();
+    passManager.register_pass<ngraph::pass::ReluReluFusionMatcher>();
     // Register any other transformations
     // ..
 
index 16db736..671c76d 100644 (file)
@@ -10,7 +10,7 @@ using namespace ngraph;
 
 // ! [function_pass:template_transformation_cpp]
 // template_function_transformation.cpp
-bool MyFunctionTransformation::run_on_function(std::shared_ptr<ngraph::Function> f) {
+bool pass::MyFunctionTransformation::run_on_function(std::shared_ptr<ngraph::Function> f) {
     // Example transformation code
     std::vector<std::shared_ptr<Node> > nodes;
 
index c8cf88a..398aa3f 100644 (file)
@@ -9,9 +9,17 @@
 
 #include <ngraph/ngraph.hpp>
 
+namespace ngraph {
+namespace pass {
+
+class MyFunctionTransformation;
+
+}  // namespace pass
+}  // namespace ngraph
+
 // ! [function_pass:template_transformation_hpp]
 // template_function_transformation.hpp
-class MyFunctionTransformation: public ngraph::pass::FunctionPass {
+class ngraph::pass::MyFunctionTransformation: public ngraph::pass::FunctionPass {
 public:
     MyFunctionTransformation() : FunctionPass() {}
 
index 6ac0579..1e2060f 100644 (file)
@@ -3,21 +3,23 @@
 //
 
 #include "template_pattern_transformation.hpp"
+#include "template_function_transformation.hpp"
 
-#include <ngraph/opsets/opset3.hpp>
 #include <ngraph/ngraph.hpp>
+#include <ngraph/opsets/opset3.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
 
 using namespace ngraph;
 
 // ! [graph_rewrite:template_transformation_cpp]
 // template_pattern_transformation.cpp
-void ngraph::pass::MyPatternBasedTransformation::transform() {
+ngraph::pass::DecomposeDivideMatcher::DecomposeDivideMatcher() {
     // Pattern example
-    auto input0 = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1, 1});
-    auto input1 = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1, 1});
+    auto input0 = std::make_shared<pattern::op::Label>(element::f32, Shape{});
+    auto input1 = std::make_shared<pattern::op::Label>(element::f32, Shape{});
     auto div = std::make_shared<ngraph::opset3::Divide>(input0, input1);
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
         auto div = std::dynamic_pointer_cast<ngraph::opset3::Divide> (m.get_match_root());
         // We can not apply this transformation in case with integer input data type
         if (!div || div->input(0).get_element_type().is_integral()) {
@@ -43,9 +45,94 @@ void ngraph::pass::MyPatternBasedTransformation::transform() {
         return true;
     };
 
-    // Register pattern with divide operaiton as a pattern root node
+    // Register pattern with Divide operation as a pattern root node
     auto m = std::make_shared<ngraph::pattern::Matcher>(div, "ConvertDivide");
     // Register Matcher
-    this->add_matcher(m, callback, ngraph::pass::PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
 // ! [graph_rewrite:template_transformation_cpp]
+
+// ! [matcher_pass:relu_fusion]
+ngraph::pass::ReluReluFusionMatcher::ReluReluFusionMatcher() {
+    auto m_relu1 = ngraph::pattern::wrap_type<ngraph::opset3::Relu>(pattern::consumers_count(1));
+    auto m_relu2 = ngraph::pattern::wrap_type<ngraph::opset3::Relu>({m_relu1});
+
+    ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
+        // Map that helps to connect labels with matched outputs
+        auto& node_to_output = m.get_pattern_value_map();
+
+        // Create new Relu operation and add register it for additional execution
+        auto new_relu = register_new_node<ngraph::opset3::Relu>(
+            node_to_output.at(m_relu1).get_node_shared_ptr()->input_value(0));
+
+        // Copy runtime info attributes to newly created operation
+        ngraph::copy_runtime_info(m.get_matched_nodes(), new_relu);
+
+        // Save last Relu name to new Relu operation
+        new_relu->set_friendly_name(m.get_match_root()->get_friendly_name());
+
+        // Replace Relu->Relu with Relu
+        ngraph::replace_node(m.get_match_root(), new_relu);
+
+        // Return true as the root node was changed
+        return true;
+    };
+
+    // Register pattern with Relu operation as a pattern root node
+    auto m = std::make_shared<ngraph::pattern::Matcher>(m_relu2, "ReluReluFusion");
+    // Register Matcher
+    this->register_matcher(m, callback);
+}
+// ! [matcher_pass:relu_fusion]
+
+void run_matcher_on_node(std::shared_ptr<ngraph::Node> node) {
+// ! [matcher_pass:run_on_node]
+if (ngraph::pass::DecomposeDivideMatcher().apply(node)) {
+    // successful execution (root node was replaced)
+}
+// ! [matcher_pass:run_on_node]
+}
+
+void run_matcher_with_manager(std::shared_ptr<ngraph::Function> f) {
+// ! [matcher_pass:manager]
+// Two matchers will run independently (two independent graph traversals)
+// pass::Manager automatically creates GraphRewrite container for each MatcherPass
+pass::Manager manager;
+manager.register_pass<ngraph::pass::DecomposeDivideMatcher>();
+manager.register_pass<ngraph::pass::ReluReluFusionMatcher>();
+manager.run_passes(f);
+// ! [matcher_pass:manager]
+}
+
+void run_matcher_with_manager2(std::shared_ptr<ngraph::Function> f) {
+// ! [matcher_pass:manager2]
+// Register anchor GraphRewrite pass inside manager that will execute two matchers simultaneously
+pass::Manager manager;
+auto anchor = manager.register_pass<ngraph::pass::GraphRewrite>();
+anchor->add_matcher<ngraph::pass::DecomposeDivideMatcher>();
+anchor->add_matcher<ngraph::pass::ReluReluFusionMatcher>();
+manager.run_passes(f);
+// ! [matcher_pass:manager2]
+}
+
+void run_matcher_with_manager3(std::shared_ptr<ngraph::Function> f) {
+// ! [matcher_pass:manager3]
+pass::Manager manager;
+manager.register_pass<ngraph::pass::MyFunctionTransformation>();
+// Two matchers will run independently (two independent graph traversals)
+// pass::Manager automatically creates GraphRewrite container for each MatcherPass
+manager.register_pass<ngraph::pass::DecomposeDivideMatcher>();
+manager.register_pass<ngraph::pass::ReluReluFusionMatcher>();
+manager.run_passes(f);
+// ! [matcher_pass:manager3]
+}
+
+void run_matcher_with_gr(std::shared_ptr<ngraph::Function> f) {
+// ! [matcher_pass:graph_rewrite]
+// Two matcher passes will run simultaneously in a single graph traversal
+ngraph::pass::GraphRewrite pass;
+pass.add_matcher<ngraph::pass::DecomposeDivideMatcher>();
+pass.add_matcher<ngraph::pass::ReluReluFusionMatcher>();
+pass.run_on_function(f);
+// ! [matcher_pass:graph_rewrite]
+}
index 9c9bc6f..0220a6d 100644 (file)
 namespace ngraph {
 namespace pass {
 
-class MyPatternBasedTransformation;
+class DecomposeDivideMatcher;
+class ReluReluFusionMatcher;
 
 }  // namespace pass
 }  // namespace ngraph
 
 // ! [graph_rewrite:template_transformation_hpp]
 // template_pattern_transformation.hpp
-class ngraph::pass::MyPatternBasedTransformation: public ngraph::pass::GraphRewrite {
+class ngraph::pass::DecomposeDivideMatcher: public ngraph::pass::MatcherPass {
 public:
-    MyPatternBasedTransformation() : GraphRewrite() {
-        transform();
-    }
-
-private:
-    void transform();
+    DecomposeDivideMatcher();
 };
 // ! [graph_rewrite:template_transformation_hpp]
+
+class ngraph::pass::ReluReluFusionMatcher: public ngraph::pass::MatcherPass {
+public:
+    ReluReluFusionMatcher();
+};
index a189b3c..0ecdc5b 100644 (file)
@@ -25,6 +25,7 @@
 #include <ngraph/opsets/opset2.hpp>
 #include <ngraph/opsets/opset3.hpp>
 #include <ngraph/op/fused/gelu.hpp>
+#include <ngraph/pass/manager.hpp>
 #include <generic_ie.hpp>
 #include <transformations/common_optimizations/common_optimizations.hpp>
 #include <transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.hpp>
@@ -95,10 +96,14 @@ InferenceEngine::ICNNNetwork::Ptr clDNNEngine::CloneNetwork(const InferenceEngin
         ::ngraph::op::GenericIE::DisableReshape noReshape(nGraphFunc);
 
         // Note: instead of running all Conversion Transformations you can make up your own transformation pipeline
-        ngraph::pass::CommonOptimizations(transformations_callback).run_on_function(nGraphFunc);
-        ngraph::pass::ConvertOpSet3ToOpSet2(transformations_callback).run_on_function(nGraphFunc);
-        ngraph::pass::ConvertOpSet2ToOpSet1(transformations_callback).run_on_function(nGraphFunc);
-        ngraph::pass::ConvertOpSet1ToLegacy(transformations_callback).run_on_function(nGraphFunc);
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::CommonOptimizations>();
+        manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();
+        manager.register_pass<ngraph::pass::ConvertOpSet2ToOpSet1>();
+        manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
+
+        manager.set_callback(transformations_callback);
+        manager.run_passes(nGraphFunc);
         clonedNetwork = InferenceEngine::details::convertFunctionToICNNNetwork(nGraphFunc, *clonedNetwork);
     }
 
index f93ed91..7b1a6b5 100644 (file)
@@ -15,6 +15,7 @@
 #include <vector>
 #include <unordered_set>
 #include <ngraph/ngraph.hpp>
+#include <ngraph/pass/manager.hpp>
 #include <ngraph/pass/get_output_element_elimination.hpp>
 #include <set>
 #include <string>
@@ -307,7 +308,9 @@ CNNNetworkNGraphImpl::reshape(const std::map<std::string, std::vector<size_t>>&
             // Call this transformation because OneHot IE and nGraph have different output precisions
             {
                 IE_PROFILING_AUTO_SCOPE(ConvertOneHot);
-                ::ngraph::pass::ConvertOneHotToOneHotIE().run_on_function(specialized_ngraph_function);
+                ::ngraph::pass::Manager manager;
+                manager.register_pass<::ngraph::pass::ConvertOneHotToOneHotIEMatcher>()->detect_output_type(specialized_ngraph_function);
+                manager.run_passes(specialized_ngraph_function);
             }
             specialized_ngraph_function->validate_nodes_and_infer_types();
 
index 6d29e13..f702f7b 100644 (file)
@@ -27,6 +27,7 @@
 #include <ngraph/opsets/opset3.hpp>
 #include <ngraph/op/fused/gelu.hpp>
 #include <ngraph/op/util/op_types.hpp>
+#include <ngraph/pass/manager.hpp>
 #include "ngraph_ops/fully_connected.hpp"
 
 #if !defined(__arm__) && !defined(_M_ARM) && !defined(__aarch64__) && !defined(_M_ARM64)
@@ -77,10 +78,15 @@ static void Transformation(ICNNNetwork::Ptr& clonedNetwork) {
     ::ngraph::op::GenericIE::DisableReshape noReshape(nGraphFunc);
 
     // Note: instead of running all Conversion Transformations you can make up your own transformation pipeline
-    ngraph::pass::CommonOptimizations(transformations_callback).run_on_function(nGraphFunc);
-    ngraph::pass::ConvertOpSet3ToOpSet2(transformations_callback).run_on_function(nGraphFunc);
-    ngraph::pass::ConvertOpSet2ToOpSet1(transformations_callback).run_on_function(nGraphFunc);
-    ngraph::pass::ConvertOpSet1ToLegacy(transformations_callback).run_on_function(nGraphFunc);
+    ngraph::pass::Manager manager;
+    manager.register_pass<ngraph::pass::CommonOptimizations>();
+    manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();
+    manager.register_pass<ngraph::pass::ConvertOpSet2ToOpSet1>();
+    manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
+
+    manager.set_callback(transformations_callback);
+    manager.run_passes(nGraphFunc);
+
     clonedNetwork = InferenceEngine::details::convertFunctionToICNNNetwork(nGraphFunc, *clonedNetwork);
 }
 
index a37d9cb..15d48b5 100644 (file)
@@ -22,12 +22,7 @@ class TRANSFORMATIONS_API BatchNormDecomposition;
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::BatchNormDecomposition: public ngraph::pass::GraphRewrite {
+class ngraph::pass::BatchNormDecomposition: public ngraph::pass::MatcherPass {
 public:
-    BatchNormDecomposition() : GraphRewrite() {
-        batch_norm_decomposition();
-    }
-
-private:
-    void batch_norm_decomposition();
+    BatchNormDecomposition();
 };
index cfbabbe..764e8c6 100644 (file)
@@ -11,7 +11,6 @@
 
 #include <ngraph/pass/graph_rewrite.hpp>
 
-#include "transformations/utils/pass_param.hpp"
 
 namespace ngraph {
 namespace pass {
@@ -21,10 +20,7 @@ class TRANSFORMATIONS_API CommonOptimizations;
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::CommonOptimizations: public ngraph::pass::FunctionPass, public ngraph::pass::PassParam {
+class ngraph::pass::CommonOptimizations: public ngraph::pass::FunctionPass {
 public:
-    explicit CommonOptimizations(const PassParam::param_callback & callback = PassParam::getDefaultCallback())
-            : FunctionPass(), PassParam(callback) {}
-
     bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
 };
diff --git a/inference-engine/src/transformations/include/transformations/constant_eltwise_reduction.hpp b/inference-engine/src/transformations/include/transformations/constant_eltwise_reduction.hpp
deleted file mode 100644 (file)
index f334f69..0000000
+++ /dev/null
@@ -1,29 +0,0 @@
-// Copyright (C) 2018-2020 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#pragma once
-
-#include <transformations_visibility.hpp>
-
-#include <ngraph/pass/graph_rewrite.hpp>
-
-namespace ngraph {
-namespace pass {
-
-class TRANSFORMATIONS_API ConstantEltwiseReduction;
-
-}  // namespace pass
-}  // namespace ngraph
-
-class ngraph::pass::ConstantEltwiseReduction: public ngraph::pass::GraphRewrite {
-public:
-    ConstantEltwiseReduction() : GraphRewrite() {
-        constant_multiply_reduction();
-        constant_add_reduction();
-    }
-
-private:
-    void constant_multiply_reduction();
-    void constant_add_reduction();
-};
index 3044246..0ba5ef7 100644 (file)
@@ -11,7 +11,6 @@
 
 #include <ngraph/ops.hpp>
 #include <ngraph/pass/graph_rewrite.hpp>
-#include "transformations/utils/pass_param.hpp"
 
 namespace ngraph {
 namespace pass {
@@ -21,9 +20,9 @@ class TRANSFORMATIONS_API ConvertBatchToSpace;
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertBatchToSpace: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam  {
+class ngraph::pass::ConvertBatchToSpace: public ngraph::pass::GraphRewrite {
 public:
-    ConvertBatchToSpace() : GraphRewrite(), PassParam() {
+    ConvertBatchToSpace() : GraphRewrite() {
         // convert_batch_to_space();
         convert_batch_to_space_ie_side();
     }
index 156512c..15e313d 100644 (file)
@@ -19,12 +19,7 @@ class TRANSFORMATIONS_API ConvertBroadcastToTiles;
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertBroadcastToTiles: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertBroadcastToTiles: public ngraph::pass::MatcherPass {
 public:
-    ConvertBroadcastToTiles() : GraphRewrite() {
-        convert_broadcast_to_tiles();
-    }
-
-private:
-    void convert_broadcast_to_tiles();
+    ConvertBroadcastToTiles();
 };
index d12af18..1cdc993 100644 (file)
@@ -10,7 +10,6 @@
 #include <transformations_visibility.hpp>
 
 #include <ngraph/pass/graph_rewrite.hpp>
-#include "transformations/utils/pass_param.hpp"
 
 namespace ngraph {
 namespace pass {
@@ -20,12 +19,7 @@ class TRANSFORMATIONS_API ConvertDepthToSpace;
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertDepthToSpace: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam {
+class ngraph::pass::ConvertDepthToSpace: public ngraph::pass::MatcherPass {
 public:
-    ConvertDepthToSpace() : GraphRewrite(), PassParam() {
-        convert_depth_to_space();
-    }
-
-private:
-    void convert_depth_to_space();
+    ConvertDepthToSpace();
 };
index d4fc07f..5eeddf8 100644 (file)
@@ -19,12 +19,7 @@ class TRANSFORMATIONS_API ConvertDivide;
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertDivide: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertDivide: public ngraph::pass::MatcherPass {
 public:
-    ConvertDivide() : GraphRewrite() {
-        convert_divide();
-    }
-
-private:
-    void convert_divide();
+    ConvertDivide();
 };
index cbc3add..67f68cb 100644 (file)
@@ -12,7 +12,6 @@
 #include <ngraph/pass/graph_rewrite.hpp>
 
 #include "ngraph/op/fused/gelu.hpp"
-#include "transformations/utils/pass_param.hpp"
 
 namespace ngraph {
 namespace pass {
@@ -22,9 +21,9 @@ class TRANSFORMATIONS_API ConvertGELU;
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertGELU: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam {
+class ngraph::pass::ConvertGELU: public ngraph::pass::GraphRewrite {
 public:
-    ConvertGELU() : GraphRewrite(), PassParam() {
+    ConvertGELU() : GraphRewrite() {
         convert_gelu();
     }
 
index 4fb9bc8..1caa819 100644 (file)
@@ -19,12 +19,7 @@ class TRANSFORMATIONS_API ConvertMinimum;
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertMinimum: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertMinimum: public ngraph::pass::MatcherPass {
 public:
-    ConvertMinimum() : GraphRewrite() {
-        convert_minimum();
-    }
-
-private:
-    void convert_minimum();
+    ConvertMinimum();
 };
index a2d49ae..5010fa5 100644 (file)
@@ -19,12 +19,7 @@ class TRANSFORMATIONS_API ConvertMod;
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertMod: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertMod: public ngraph::pass::MatcherPass {
 public:
-    ConvertMod() : GraphRewrite() {
-        convert_mod();
-    }
-
-private:
-    void convert_mod();
+    ConvertMod();
 };
index 8d0ab32..ccc5b58 100644 (file)
@@ -19,12 +19,7 @@ class TRANSFORMATIONS_API ConvertNegative;
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertNegative: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertNegative: public ngraph::pass::MatcherPass {
 public:
-    ConvertNegative() : GraphRewrite() {
-        convert_negative();
-    }
-
-private:
-    void convert_negative();
+    ConvertNegative();
 };
index 663202c..7e0047e 100644 (file)
@@ -29,6 +29,9 @@ namespace ngraph {
 namespace pass {
 
 class TRANSFORMATIONS_API ConvFusion;
+class TRANSFORMATIONS_API ConvAddFusion;
+class TRANSFORMATIONS_API ConvMultiplyFusion;
+class TRANSFORMATIONS_API DeconvAddFusion;
 
 }  // namespace pass
 }  // namespace ngraph
@@ -36,112 +39,23 @@ class TRANSFORMATIONS_API ConvFusion;
 class ngraph::pass::ConvFusion: public ngraph::pass::GraphRewrite {
 public:
     ConvFusion() : GraphRewrite() {
-        fuse_convolution_with<op::ConvolutionIE,   opset1::Multiply>();
-        fuse_convolution_with<op::ConvolutionIE,   opset1::Add>();
-        fuse_convolution_with<op::DeconvolutionIE, opset1::Add>();
+        add_matcher<ngraph::pass::ConvAddFusion>();
+        add_matcher<ngraph::pass::ConvMultiplyFusion>();
+        add_matcher<ngraph::pass::DeconvAddFusion>();
     }
-
-private:
-    template <class Conv, class Eltwise>
-    void fuse_convolution_with();
-
-    template <class Conv>
-    ngraph::graph_rewrite_callback get_callback();
 };
 
-template <class Conv, class Eltwise>
-void ngraph::pass::ConvFusion::fuse_convolution_with() {
-    static_assert(std::is_same<Eltwise, ngraph::opset1::Multiply>() || std::is_same<Eltwise, ngraph::opset1::Add>(),
-                  "This transformation works only with ngraph::opset1::Add and ngraph::opset1::Multiply");
-
-    static_assert(std::is_same<Conv, ngraph::op::ConvolutionIE>() || std::is_same<Conv, ngraph::op::DeconvolutionIE>(),
-                  "This transformation works only with ngraph::op::ConvolutionIE and ngraph::op::DeconvolutionIE");
-
-    auto conv = std::make_shared<pattern::op::Label>(element::f32, Shape{},
-            [](const std::shared_ptr<Node> & node) -> bool {
-                return std::dynamic_pointer_cast<ngraph::op::ConvolutionIE>(node) ||
-                       std::dynamic_pointer_cast<ngraph::op::DeconvolutionIE>(node);
-    });
-
-    auto last = std::make_shared<Eltwise>(conv, std::make_shared<pattern::op::Label>(element::f32, Shape{1}));
-
-    auto m = std::make_shared<ngraph::pattern::Matcher>(last, "ConvFusion");
-    this->add_matcher(m, get_callback<Conv>(), PassProperty::CHANGE_DYNAMIC_STATE);
-}
-
-template <class Conv>
-ngraph::graph_rewrite_callback ngraph::pass::ConvFusion::get_callback() {
-    ngraph::graph_rewrite_callback callback = [](ngraph::pattern::Matcher &m) {
-        auto eltwise = m.get_match_root();
-
-        std::shared_ptr<op::Constant> m_const;
-        std::shared_ptr<Conv> m_conv;
-        // FIXME: use auto [m_conv, m_const] when C++17 is available
-        std::tie(m_conv, m_const) = parse_eltwise_inputs<Conv, op::Constant>(eltwise);
-        if (!m_conv || !m_const) {
-            return false;
-        }
-
-        // TODO: check that constant can be scalar and do not match [1, C, 1, 1] layout
-        const auto constant_shape = m_const->get_shape();
-        const auto output_pshape = m_conv->get_output_partial_shape(0);
-
-        if (output_pshape.rank().is_dynamic() || output_pshape[1].is_dynamic()) {
-            return false;
-        }
-
-        const auto channel_dim = output_pshape[1].get_length();
-
-        size_t constant_size = std::accumulate(constant_shape.begin(), constant_shape.end(), 1, std::multiplies<size_t>());
-        if (constant_size != channel_dim) {
-            return false;
-        }
-
-        Output<Node> constant(m_const);
-
-        if (constant_shape.size() > 1) {
-            constant = std::make_shared<opset1::Reshape>(constant, op::Constant::create(element::i64, Shape{1}, {channel_dim}), true);
-        }
-
-        if (m_conv->output(0).get_target_inputs().size() != 1) {
-            return false;
-        }
-
-        Output<Node> new_conv, new_weights, new_bias;
-        if (std::dynamic_pointer_cast<opset1::Add>(eltwise)) {
-            // Fuse: ConvolutionIE/DeconvolutionIE->Add
-            if (m_conv->inputs().size() == 2) {
-                new_bias = constant;
-            } else {
-                new_bias = std::make_shared<opset1::Add>(constant, m_conv->input_value(2));
-            }
-            new_conv = m_conv->clone_with_new_inputs({m_conv->input_value(0), m_conv->input_value(1), new_bias});
-        } else if (std::is_same<Conv, op::ConvolutionIE>() && std::dynamic_pointer_cast<opset1::Multiply>(eltwise)) {
-            // Fuse: ConvolutionIE->Mul
-            auto weights_shape = m_conv->input(1).get_shape();
-
-            Shape const_shape(weights_shape.size(), 1);
-            const_shape[0] = weights_shape[0];
-
-            auto const_reshape = std::make_shared<opset1::Reshape>(constant,
-                                                                   op::Constant::create(element::i64, Shape{const_shape.size()}, const_shape), true);
-            new_weights = std::make_shared<opset1::Multiply> (m_conv->input_value(1), const_reshape);
-            if (m_conv->inputs().size() == 2) {
-                new_conv = m_conv->clone_with_new_inputs({m_conv->input_value(0), new_weights});
-            } else {
-                auto bias_reshape = std::make_shared<opset1::Reshape>(constant, op::Constant::create(element::i64, Shape{1}, {weights_shape[0]}), true);
-                new_bias = std::make_shared<opset1::Multiply>(bias_reshape, constant);
-                new_conv = m_conv->clone_with_new_inputs({m_conv->input_value(0), new_weights, new_bias});
-            }
-        } else {
-            return false;
-        }
+class ngraph::pass::ConvAddFusion: public ngraph::pass::MatcherPass {
+public:
+    ConvAddFusion();
+};
 
-        ngraph::copy_runtime_info({m_conv, eltwise}, new_conv.get_node_shared_ptr());
-        new_conv.get_node_shared_ptr()->set_friendly_name(m.get_match_root()->get_friendly_name());
-        ngraph::replace_node(m.get_match_root(), new_conv.get_node_shared_ptr());
-        return true;
-    };
-    return callback;
-}
+class ngraph::pass::ConvMultiplyFusion: public ngraph::pass::MatcherPass {
+public:
+    ConvMultiplyFusion();
+};
 
+class ngraph::pass::DeconvAddFusion: public ngraph::pass::MatcherPass {
+public:
+    DeconvAddFusion();
+};
\ No newline at end of file
index fb0d777..ffd9a40 100644 (file)
 namespace ngraph {
 namespace pass {
 
-class TRANSFORMATIONS_API ConvertCellsToCellsIE;
+class TRANSFORMATIONS_API ConvertLSTMCellMatcher;
+class TRANSFORMATIONS_API ConvertGRUCellMatcher;
+class TRANSFORMATIONS_API ConvertRNNCellMatcher;
 
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertCellsToCellsIE: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertLSTMCellMatcher : public ngraph::pass::MatcherPass {
 public:
-    ConvertCellsToCellsIE() : GraphRewrite() {
-        convert_lstm_cell();
-        convert_gru_cell();
-        convert_rnn_cell();
-    }
-
-private:
-    void convert_lstm_cell();
-    void convert_gru_cell();
-    void convert_rnn_cell();
+    ConvertLSTMCellMatcher();
+};
+
+class ngraph::pass::ConvertGRUCellMatcher : public ngraph::pass::MatcherPass {
+public:
+    ConvertGRUCellMatcher();
+};
+
+class ngraph::pass::ConvertRNNCellMatcher : public ngraph::pass::MatcherPass {
+public:
+    ConvertRNNCellMatcher();
 };
index 5ef8f8c..11d7533 100644 (file)
@@ -16,24 +16,40 @@ namespace pass {
 
 class TRANSFORMATIONS_API ConvertConvolutions;
 
+class TRANSFORMATIONS_API ConvertConvolution;
+class TRANSFORMATIONS_API ConvertGroupConvolution;
+class TRANSFORMATIONS_API ConvertDeconvolution;
+class TRANSFORMATIONS_API ConvertGroupDeconvolution;
+
 }  // namespace pass
 }  // namespace ngraph
 
 class ngraph::pass::ConvertConvolutions: public ngraph::pass::GraphRewrite {
 public:
-    ConvertConvolutions() : GraphRewrite() {
-        convert_convolution();
-        convert_group_convolution();
-        convert_convolution_backprop_data();
-        convert_group_convolution_backprop_data();
+    ConvertConvolutions() {
+        add_matcher<ngraph::pass::ConvertConvolution>();
+        add_matcher<ngraph::pass::ConvertGroupConvolution>();
+        add_matcher<ngraph::pass::ConvertDeconvolution>();
+        add_matcher<ngraph::pass::ConvertGroupDeconvolution>();
     }
+};
 
-private:
-    void convert_convolution();
-
-    void convert_group_convolution();
+class ngraph::pass::ConvertConvolution: public ngraph::pass::MatcherPass {
+public:
+    ConvertConvolution();
+};
 
-    void convert_convolution_backprop_data();
+class ngraph::pass::ConvertGroupConvolution: public ngraph::pass::MatcherPass {
+public:
+    ConvertGroupConvolution();
+};
 
-    void convert_group_convolution_backprop_data();
+class ngraph::pass::ConvertDeconvolution: public ngraph::pass::MatcherPass {
+public:
+    ConvertDeconvolution();
 };
+
+class ngraph::pass::ConvertGroupDeconvolution: public ngraph::pass::MatcherPass {
+public:
+    ConvertGroupDeconvolution();
+};
\ No newline at end of file
index 6dad3a4..129f580 100644 (file)
@@ -22,7 +22,7 @@
 namespace ngraph {
 namespace pass {
 
-class TRANSFORMATIONS_API ConvertGatherToGatherIE;
+class TRANSFORMATIONS_API ConvertGatherToGatherIEMatcher;
 
 }  // namespace pass
 }  // namespace ngraph
@@ -34,12 +34,7 @@ class TRANSFORMATIONS_API ConvertGatherToGatherIE;
  *     we unsqueeze indices input and squeeze GatherIE output.
  */
 
-class ngraph::pass::ConvertGatherToGatherIE : public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertGatherToGatherIEMatcher : public ngraph::pass::MatcherPass {
 public:
-    ConvertGatherToGatherIE() : GraphRewrite() {
-        convert_gather_to_gather_ie();
-    }
-
-private:
-    void convert_gather_to_gather_ie();
+    ConvertGatherToGatherIEMatcher();
 };
index 8cccbef..448f677 100644 (file)
 namespace ngraph {
 namespace pass {
 
-class TRANSFORMATIONS_API ConvertGatherTreeToGatherTreeIE;
+class TRANSFORMATIONS_API ConvertGatherTreeToGatherTreeIEMatcher;
 
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertGatherTreeToGatherTreeIE: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertGatherTreeToGatherTreeIEMatcher: public ngraph::pass::MatcherPass {
 public:
-    ConvertGatherTreeToGatherTreeIE() : GraphRewrite() {
-        convert();
-    }
-
-private:
-    void convert();
+    ConvertGatherTreeToGatherTreeIEMatcher();
 };
index 72046b2..edc7980 100644 (file)
 namespace ngraph {
 namespace pass {
 
-class TRANSFORMATIONS_API ConvertHardSigmoidToHardSigmoidIE;
+class TRANSFORMATIONS_API ConvertHardSigmoidToLegacyMatcher;
 
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertHardSigmoidToHardSigmoidIE : public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertHardSigmoidToLegacyMatcher : public ngraph::pass::MatcherPass {
 public:
-    ConvertHardSigmoidToHardSigmoidIE() : GraphRewrite() {
-        convert_hard_sigmoid();
-    }
-
-private:
-    void convert_hard_sigmoid();
+    ConvertHardSigmoidToLegacyMatcher();
 };
index 771c21f..cecf293 100644 (file)
 namespace ngraph {
 namespace pass {
 
-class TRANSFORMATIONS_API ConvertInterpolateToInterpOrResample;
+class TRANSFORMATIONS_API ConvertInterpolateToInterpOrResampleMatcher;
 
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertInterpolateToInterpOrResample: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertInterpolateToInterpOrResampleMatcher: public ngraph::pass::MatcherPass {
 public:
-    ConvertInterpolateToInterpOrResample() : GraphRewrite() {
-        convert_interpolate_to_interp_or_resample();
-    }
-
-private:
-    void convert_interpolate_to_interp_or_resample();
+    ConvertInterpolateToInterpOrResampleMatcher();
 };
index 46dd002..195f931 100644 (file)
 namespace ngraph {
 namespace pass {
 
-class TRANSFORMATIONS_API ConvertLRNToLRNIE;
+class TRANSFORMATIONS_API ConvertLRNToLegacyMatcher;
 
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertLRNToLRNIE: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertLRNToLegacyMatcher: public ngraph::pass::MatcherPass {
 public:
-    ConvertLRNToLRNIE() : GraphRewrite() {
-        convert_lrn();
-    }
-
-private:
-    void convert_lrn();
+    ConvertLRNToLegacyMatcher();
 };
index f8c9530..e9a1c6a 100644 (file)
@@ -23,12 +23,7 @@ class TRANSFORMATIONS_API ConvertMatMulToFCorGemm;
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertMatMulToFCorGemm: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertMatMulToFCorGemm: public ngraph::pass::MatcherPass {
 public:
-    ConvertMatMulToFCorGemm(): GraphRewrite() {
-        convert_matmul();
-    }
-
-private:
-    void convert_matmul();
+    ConvertMatMulToFCorGemm();
 };
index e9bccaa..e3d83db 100644 (file)
@@ -14,7 +14,7 @@
 namespace ngraph {
 namespace pass {
 
-    class TRANSFORMATIONS_API ConvertNMS4ToLegacy;
+    class TRANSFORMATIONS_API ConvertNMS4ToLegacyMatcher;
 
 }  // namespace pass
 }  // namespace ngraph
@@ -25,12 +25,8 @@ namespace pass {
  */
 
 
-class ngraph::pass::ConvertNMS4ToLegacy: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertNMS4ToLegacyMatcher: public ngraph::pass::MatcherPass {
 public:
-    ConvertNMS4ToLegacy() : GraphRewrite() {
-        convert_nms4_to_legacy();
-    }
-private:
-    void convert_nms4_to_legacy();
+    ConvertNMS4ToLegacyMatcher();
 };
 
index 8eacc44..c11711d 100644 (file)
@@ -15,7 +15,7 @@
 namespace ngraph {
 namespace pass {
 
-class TRANSFORMATIONS_API ConvertNMSToNMSIE;
+class TRANSFORMATIONS_API ConvertNMSToNMSIEMatcher;
 
 }  // namespace pass
 }  // namespace ngraph
@@ -28,12 +28,7 @@ class TRANSFORMATIONS_API ConvertNMSToNMSIE;
  *     we insert Unsqueeze operations.
  */
 
-class ngraph::pass::ConvertNMSToNMSIE : public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertNMSToNMSIEMatcher : public ngraph::pass::MatcherPass {
 public:
-    ConvertNMSToNMSIE() : GraphRewrite() {
-        convert_nms_to_nms_ie();
-    }
-
-private:
-    void convert_nms_to_nms_ie();
+    ConvertNMSToNMSIEMatcher();
 };
index 6dcb86c..9796cc5 100644 (file)
@@ -16,27 +16,17 @@ namespace ngraph {
 namespace pass {
 
 class TRANSFORMATIONS_API ConvertNormalizeL2WithMulToNormalizeIE;
-class TRANSFORMATIONS_API ConvertNormalizeL2ToNormalizeIE;
+class TRANSFORMATIONS_API ConvertNormalizeL2ToLegacyMatcher;
 
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertNormalizeL2WithMulToNormalizeIE: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertNormalizeL2WithMulToNormalizeIE: public ngraph::pass::MatcherPass {
 public:
-    ConvertNormalizeL2WithMulToNormalizeIE() : GraphRewrite() {
-        convert_normalize_l2_with_mul();
-    }
-
-private:
-    void convert_normalize_l2_with_mul();
+    ConvertNormalizeL2WithMulToNormalizeIE();
 };
 
-class ngraph::pass::ConvertNormalizeL2ToNormalizeIE: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertNormalizeL2ToLegacyMatcher: public ngraph::pass::MatcherPass {
 public:
-    ConvertNormalizeL2ToNormalizeIE() : GraphRewrite() {
-        convert_normalize_l2();
-    }
-
-private:
-    void convert_normalize_l2();
+    ConvertNormalizeL2ToLegacyMatcher();
 };
index 1a68761..5bdec78 100644 (file)
 namespace ngraph {
 namespace pass {
 
-class TRANSFORMATIONS_API ConvertOneHotToOneHotIE;
+class TRANSFORMATIONS_API ConvertOneHotToOneHotIEMatcher;
 
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertOneHotToOneHotIE: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertOneHotToOneHotIEMatcher: public ngraph::pass::MatcherPass {
 public:
-    ConvertOneHotToOneHotIE() : GraphRewrite(), is_f16(false) {
-        convert_one_hot();
-    }
+    ConvertOneHotToOneHotIEMatcher();
 
-    bool run_on_function(std::shared_ptr<ngraph::Function> f) final;
+    void detect_output_type(const std::shared_ptr<Function> & f);
 
 private:
-    void convert_one_hot();
-    bool is_f16;
-};
+    element::Type m_output_type = element::Type_t::f32;
+};
\ No newline at end of file
index 11ba9d5..3ef021d 100644 (file)
@@ -11,7 +11,6 @@
 
 #include <ngraph/pass/graph_rewrite.hpp>
 
-#include "transformations/utils/pass_param.hpp"
 
 namespace ngraph {
 namespace pass {
@@ -21,10 +20,7 @@ class TRANSFORMATIONS_API ConvertOpSet1ToLegacy;
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertOpSet1ToLegacy: public ngraph::pass::FunctionPass, public ngraph::pass::PassParam {
+class ngraph::pass::ConvertOpSet1ToLegacy: public ngraph::pass::FunctionPass {
 public:
-    explicit ConvertOpSet1ToLegacy(const PassParam::param_callback & callback = PassParam::getDefaultCallback())
-             : FunctionPass(), PassParam(callback) {}
-
     bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
 };
diff --git a/inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_opset1_to_legacy_tbl.hpp b/inference-engine/src/transformations/include/transformations/convert_opset1_to_legacy/convert_opset1_to_legacy_tbl.hpp
deleted file mode 100644 (file)
index be03dba..0000000
+++ /dev/null
@@ -1,66 +0,0 @@
-// Copyright (C) 2020 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#ifndef NGRAPH_PASS
-#warning "NGRAPH_PASS is not defined"
-#define NGRAPH_PASS(A, B)
-#endif
-
-// To register new pass you need to define NGRAPH_PASS
-// Usage example:
-//   ngraph::pass:Manager pm;
-//   #define NGRAPH_PASS(NAME, NAMESPACE)   pm.register_pass<NAMESPACE::NAME>();
-//   #include <transformations/transformations_tbl.hpp>
-//   #undef NGRAPH_PASS
-
-NGRAPH_PASS(ConstantFolding, ::ngraph::pass)
-NGRAPH_PASS(ConvertReduceToPooling, ::ngraph::pass)
-NGRAPH_PASS(ConvertMod, ::ngraph::pass)
-NGRAPH_PASS(ConvertMinimum, ::ngraph::pass)
-NGRAPH_PASS(ConvertSubtract, ::ngraph::pass)
-NGRAPH_PASS(ConvertDivide, ::ngraph::pass)
-NGRAPH_PASS(ConvertNegative, ::ngraph::pass)
-NGRAPH_PASS(ConvertDepthToSpace, ::ngraph::pass)
-NGRAPH_PASS(ConvertSpaceToDepth, ::ngraph::pass)
-NGRAPH_PASS(ConvertConvolutions, ::ngraph::pass)
-NGRAPH_PASS(BatchNormDecomposition, ::ngraph::pass)
-NGRAPH_PASS(ConstantFolding, ::ngraph::pass)
-NGRAPH_PASS(MulAddVerification, ::ngraph::pass)
-NGRAPH_PASS(MulAddFusion, ::ngraph::pass)
-NGRAPH_PASS(ConstantFolding, ::ngraph::pass)
-NGRAPH_PASS(ConvertMatMulToFCorGemm, ::ngraph::pass)
-NGRAPH_PASS(PullTransposeThroughFQUp, ::ngraph::pass)
-NGRAPH_PASS(ConstantFolding, ::ngraph::pass)
-NGRAPH_PASS(ConvFusion, ::ngraph::pass)
-NGRAPH_PASS(FullyConnectedBiasFusion, ::ngraph::pass)
-NGRAPH_PASS(ConstantFolding, ::ngraph::pass)
-NGRAPH_PASS(ReshapeFullyConnected, ::ngraph::pass)
-NGRAPH_PASS(ReshapeFullyConnectedFusion, ::ngraph::pass)
-NGRAPH_PASS(Reshape1DOps, ::ngraph::pass)
-NGRAPH_PASS(ConvertNormalizeL2WithMulToNormalizeIE, ::ngraph::pass)
-NGRAPH_PASS(ConvertNormalizeL2ToNormalizeIE, ::ngraph::pass)
-NGRAPH_PASS(ConstantEltwiseReduction, ::ngraph::pass)
-NGRAPH_PASS(ConvertMulAddToScaleShiftOrPower, ::ngraph::pass)
-NGRAPH_PASS(ConvertMulOrAddFinally, ::ngraph::pass)
-NGRAPH_PASS(ConstantFolding, ::ngraph::pass)
-NGRAPH_PASS(ConvertBroadcastToTiles, ::ngraph::pass)
-NGRAPH_PASS(ConvertTileToIETile, ::ngraph::pass)
-NGRAPH_PASS(ConvertProposalToProposalIE, ::ngraph::pass)
-NGRAPH_PASS(ConvertLRNToLRNIE, ::ngraph::pass)
-NGRAPH_PASS(ConvertPadToPadIE, ::ngraph::pass)
-NGRAPH_PASS(ConvertHardSigmoidToHardSigmoidIE, ::ngraph::pass)
-NGRAPH_PASS(ConvertCellsToCellsIE, ::ngraph::pass)
-NGRAPH_PASS(ConvertInterpolateToInterpOrResample, ::ngraph::pass)
-NGRAPH_PASS(ConvertStridedSliceToCrop, ::ngraph::pass)
-NGRAPH_PASS(ConvertPowerToPowerIE, ::ngraph::pass)
-NGRAPH_PASS(ConvertSqrtToPowerIE, ::ngraph::pass)
-NGRAPH_PASS(ConvertPReLUToReLUIE, ::ngraph::pass)
-NGRAPH_PASS(ConvertGatherToGatherIE, ::ngraph::pass)
-NGRAPH_PASS(ConvertSeluToSeluIE, ::ngraph::pass)
-NGRAPH_PASS(ConvertOneHotToOneHotIE, ::ngraph::pass)
-NGRAPH_PASS(ConvertGatherTreeToGatherTreeIE, ::ngraph::pass)
-NGRAPH_PASS(ConvertTopKToTopKIE, ::ngraph::pass)
-NGRAPH_PASS(ConvertNMSToNMSIE, ::ngraph::pass)
-NGRAPH_PASS(ConstantFolding, ::ngraph::pass)
-NGRAPH_PASS(ConvertNMS4ToLegacy, ::ngraph::pass)
index 7bc622a..cbc373f 100644 (file)
 namespace ngraph {
 namespace pass {
 
-class TRANSFORMATIONS_API ConvertPadToPadIE;
+class TRANSFORMATIONS_API ConvertPadToLegacyMatcher;
 
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertPadToPadIE: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertPadToLegacyMatcher: public ngraph::pass::MatcherPass {
 public:
-    ConvertPadToPadIE() : GraphRewrite() {
-        convert_pad();
-    }
-
-private:
-    void convert_pad();
+    ConvertPadToLegacyMatcher();
 };
index 826c1a3..b090375 100644 (file)
 namespace ngraph {
 namespace pass {
 
-class TRANSFORMATIONS_API ConvertPowerToPowerIE;
+class TRANSFORMATIONS_API ConvertPowerToPowerIEMatcher;
 
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertPowerToPowerIE: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertPowerToPowerIEMatcher: public ngraph::pass::MatcherPass {
 public:
-    ConvertPowerToPowerIE() : GraphRewrite() {
-        convert_power();
-    }
-
-private:
-    void convert_power();
+    ConvertPowerToPowerIEMatcher();
 };
index 42e06ab..f1668b7 100644 (file)
@@ -20,12 +20,7 @@ class TRANSFORMATIONS_API ConvertPReLUToReLUIE;
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertPReLUToReLUIE: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertPReLUToReLUIE: public ngraph::pass::MatcherPass {
 public:
-    ConvertPReLUToReLUIE() : GraphRewrite() {
-        convert_prelu();
-    }
-
-private:
-    void convert_prelu();
+    ConvertPReLUToReLUIE();
 };
index 3eea831..835cbe5 100644 (file)
 namespace ngraph {
 namespace pass {
 
-class TRANSFORMATIONS_API ConvertProposalToProposalIE;
+class TRANSFORMATIONS_API ConvertProposalToLegacyMatcher;
 
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertProposalToProposalIE: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertProposalToLegacyMatcher: public ngraph::pass::MatcherPass {
 public:
-    ConvertProposalToProposalIE() : GraphRewrite() {
-        convert_proposal();
-    }
-
-private:
-    void convert_proposal();
+    ConvertProposalToLegacyMatcher();
 };
index 1649765..3d191a2 100644 (file)
 namespace ngraph {
 namespace pass {
 
-class TRANSFORMATIONS_API ConvertSeluToSeluIE;
+class TRANSFORMATIONS_API ConvertSeluToSeluIEMatcher;
 
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertSeluToSeluIE: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertSeluToSeluIEMatcher: public ngraph::pass::MatcherPass {
 public:
-    ConvertSeluToSeluIE() : GraphRewrite() {
-        convert_selu();
-    }
-
-private:
-    void convert_selu();
+    ConvertSeluToSeluIEMatcher();
 };
index 017ce77..38d1acb 100644 (file)
 namespace ngraph {
 namespace pass {
 
-class TRANSFORMATIONS_API ConvertSqrtToPowerIE;
+class TRANSFORMATIONS_API ConvertSqrtToPowerIEMatcher;
 
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertSqrtToPowerIE: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertSqrtToPowerIEMatcher: public ngraph::pass::MatcherPass {
 public:
-    ConvertSqrtToPowerIE() : GraphRewrite() {
-        convert_sqrt();
-    }
-
-private:
-    void convert_sqrt();
+    ConvertSqrtToPowerIEMatcher();
 };
 
index 08bfabf..8c19740 100644 (file)
 namespace ngraph {
 namespace pass {
 
-class TRANSFORMATIONS_API ConvertStridedSliceToCrop;
+class TRANSFORMATIONS_API ConvertStridedSliceToCropMatcher;
 
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertStridedSliceToCrop: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertStridedSliceToCropMatcher: public ngraph::pass::MatcherPass {
 public:
-    ConvertStridedSliceToCrop() : GraphRewrite() {
-        convert_strided_slice_to_crop();
-    }
-
-private:
-    void convert_strided_slice_to_crop();
+    ConvertStridedSliceToCropMatcher();
 };
index ca76848..69b3f26 100644 (file)
 namespace ngraph {
 namespace pass {
 
-class TRANSFORMATIONS_API ConvertTileToIETile;
+class TRANSFORMATIONS_API ConvertTileToLegacyMatcher;
 
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertTileToIETile: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertTileToLegacyMatcher: public ngraph::pass::MatcherPass {
 public:
-    ConvertTileToIETile() : GraphRewrite() {
-        convert_tile();
-    }
-
-private:
-    void convert_tile();
+    ConvertTileToLegacyMatcher();
 };
index b374985..7d78ab0 100644 (file)
 namespace ngraph {
 namespace pass {
 
-class TRANSFORMATIONS_API ConvertTopKToTopKIE;
+class TRANSFORMATIONS_API ConvertTopKToTopKIEMatcher;
 
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertTopKToTopKIE : public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertTopKToTopKIEMatcher : public ngraph::pass::MatcherPass {
 public:
-    ConvertTopKToTopKIE() : GraphRewrite() {
-        convert_topk_to_topk_ie();
-    }
-
-private:
-    void convert_topk_to_topk_ie();
+    ConvertTopKToTopKIEMatcher();
 };
index 9312a25..36132cb 100644 (file)
@@ -31,80 +31,7 @@ class TRANSFORMATIONS_API FullyConnectedBiasFusion;
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::FullyConnectedBiasFusion : public ngraph::pass::GraphRewrite {
+class ngraph::pass::FullyConnectedBiasFusion : public ngraph::pass::MatcherPass {
 public:
-    FullyConnectedBiasFusion() : GraphRewrite() {
-        construct_fcbias();
-    }
-
-private:
-    void construct_fcbias() {
-        Shape shape_w{2, 4};
-        Shape shape_x{2, 4};
-        Shape shape_b{2, 2};
-        auto input = std::make_shared<pattern::op::Label>(element::f32, shape_w);
-        auto weights = std::make_shared<pattern::op::Label>(element::f32, shape_x);
-        auto fc_bias = std::make_shared<pattern::op::Label>(element::f32, shape_b);
-        auto bias = std::make_shared<pattern::op::Label>(element::f32, shape_b);
-
-        auto fc = std::make_shared<op::FullyConnected>(input, weights, fc_bias, Shape{1, 2});
-        auto add = std::make_shared<opset1::Add>(fc, bias);
-
-        ngraph::graph_rewrite_callback callback = [](pattern::Matcher &m) {
-            auto add = m.get_match_root();
-            auto add_input_0 = add->input(0).get_source_output().get_node_shared_ptr();
-            auto add_input_1 = add->input(1).get_source_output().get_node_shared_ptr();
-
-            auto m_fc = std::dynamic_pointer_cast<op::FullyConnected>(add_input_0);
-            auto m_bias = add_input_1;
-
-            if (m_fc == nullptr) {
-                m_fc = std::dynamic_pointer_cast<op::FullyConnected>(add_input_1);
-                m_bias = add_input_0;
-            }
-
-            if (auto bcast_m = std::dynamic_pointer_cast<opset1::Broadcast>(m_bias)) {
-                m_bias = bcast_m->input(0).get_source_output().get_node_shared_ptr();
-            }
-
-            if (!std::dynamic_pointer_cast<opset1::Constant>(m_bias)) {
-                return false;
-            }
-            Shape bias_shape(m_bias->get_shape());
-
-            if (m_fc->output(0).get_target_inputs().size() != 1) {
-                return false;
-            }
-
-            Shape output_shape(m_fc->get_shape());
-            size_t bias_size = std::accumulate(bias_shape.begin(), bias_shape.end(), 1, std::multiplies<int64_t>());
-            if (bias_shape.empty() || bias_shape.back() != output_shape.back() || bias_shape.back() != bias_size) {
-                return false;
-            }
-
-            NodeVector new_ops;
-
-            auto new_bias = std::make_shared<opset1::Add>(m_fc->input(2).get_source_output(), m_bias);
-            new_ops.push_back(new_bias);
-            std::shared_ptr<Node> final_bias = new_bias;
-            if (new_bias->get_shape().size() >= 2) {
-                final_bias = std::make_shared<opset1::Reshape>(final_bias, opset1::Constant::create(element::i64, Shape{1}, {-1}), true);
-                new_ops.push_back(final_bias);
-            }
-
-            auto new_fc = std::make_shared<op::FullyConnected>(m_fc->input(0).get_source_output(),
-                                                               m_fc->input(1).get_source_output(),
-                                                               final_bias,
-                                                               m_fc->get_shape());
-            new_ops.push_back(new_fc);
-
-            new_fc->set_friendly_name(add->get_friendly_name());
-            ngraph::copy_runtime_info({m_fc, add}, new_ops);
-            ngraph::replace_node(add, new_fc);
-            return true;
-        };
-
-        auto m = std::make_shared<ngraph::pattern::Matcher>(add, "FullyConnectedBiasFusion");
-        this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
-    }
+    FullyConnectedBiasFusion();
 };
index b5803c1..6f99217 100644 (file)
@@ -15,6 +15,9 @@ namespace ngraph {
 namespace pass {
 
 class TRANSFORMATIONS_API Reshape1DOps;
+class TRANSFORMATIONS_API Reshape1DConvolution;
+class TRANSFORMATIONS_API Reshape1DAvgPool;
+class TRANSFORMATIONS_API Reshape1DMaxPool;
 
 }  // namespace pass
 }  // namespace ngraph
@@ -22,9 +25,23 @@ class TRANSFORMATIONS_API Reshape1DOps;
 class ngraph::pass::Reshape1DOps: public ngraph::pass::GraphRewrite {
 public:
     Reshape1DOps() : GraphRewrite() {
-        reshape_ops();
+        add_matcher<ngraph::pass::Reshape1DConvolution>();
+        add_matcher<ngraph::pass::Reshape1DAvgPool>();
+        add_matcher<ngraph::pass::Reshape1DMaxPool>();
     }
+};
 
-private:
-    void reshape_ops();
+class ngraph::pass::Reshape1DConvolution: public ngraph::pass::MatcherPass {
+public:
+    Reshape1DConvolution();
 };
+
+class ngraph::pass::Reshape1DAvgPool: public ngraph::pass::MatcherPass {
+public:
+    Reshape1DAvgPool();
+};
+
+class ngraph::pass::Reshape1DMaxPool: public ngraph::pass::MatcherPass {
+public:
+    Reshape1DMaxPool();
+};
\ No newline at end of file
index 52bdcef..2a60caa 100644 (file)
@@ -11,7 +11,6 @@
 
 #include <ngraph/pass/graph_rewrite.hpp>
 
-#include "transformations/utils/pass_param.hpp"
 
 namespace ngraph {
 namespace pass {
@@ -42,18 +41,9 @@ class TRANSFORMATIONS_API ReshapeFullyConnected;
  *         }
  *     };
  *
- *     auto p = ngraph::pass::ReshapeFullyConnected();
- *     p.setCallback(callback);
- *     p.run_on_function(f);
- *
  */
 
-class ngraph::pass::ReshapeFullyConnected: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam {
+class ngraph::pass::ReshapeFullyConnected: public ngraph::pass::MatcherPass {
 public:
-    ReshapeFullyConnected() : GraphRewrite(), PassParam() {
-        reshape_fully_connected();
-    }
-
-private:
-    void reshape_fully_connected();
+    ReshapeFullyConnected();
 };
index 6c3a3f5..647d34c 100644 (file)
@@ -7,7 +7,6 @@
 #include <memory>
 #include <transformations_visibility.hpp>
 #include <ngraph/pass/graph_rewrite.hpp>
-#include "transformations/utils/pass_param.hpp"
 
 namespace ngraph {
 namespace pass {
@@ -17,10 +16,7 @@ class TRANSFORMATIONS_API ConvertOpSet2ToOpSet1;
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertOpSet2ToOpSet1: public ngraph::pass::FunctionPass, public ngraph::pass::PassParam {
+class ngraph::pass::ConvertOpSet2ToOpSet1: public ngraph::pass::FunctionPass {
 public:
-    explicit ConvertOpSet2ToOpSet1(const PassParam::param_callback & callback = PassParam::getDefaultCallback())
-             : FunctionPass(), PassParam(callback) {}
-
     bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
 };
diff --git a/inference-engine/src/transformations/include/transformations/convert_opset2_to_opset1/convert_opset2_to_opset1_tbl.hpp b/inference-engine/src/transformations/include/transformations/convert_opset2_to_opset1/convert_opset2_to_opset1_tbl.hpp
deleted file mode 100644 (file)
index d67b909..0000000
+++ /dev/null
@@ -1,19 +0,0 @@
-// Copyright (C) 2020 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#ifndef NGRAPH_PASS
-#warning "NGRAPH_PASS is not defined"
-#define NGRAPH_PASS(A, B)
-#endif
-
-// To register new pass you need to define NGRAPH_PASS
-// Usage example:
-//   ngraph::pass:Manager pm;
-//   #define NGRAPH_PASS(NAME, NAMESPACE)   pm.register_pass<NAMESPACE::NAME>();
-//   #include <transformations/transformations_tbl.hpp>
-//   #undef NGRAPH_PASS
-
-NGRAPH_PASS(ConvertGELU, ::ngraph::pass)
-NGRAPH_PASS(ConvertSpaceToBatch, ::ngraph::pass)
-NGRAPH_PASS(ConvertBatchToSpace, ::ngraph::pass)
index 481b3a8..f72261d 100644 (file)
@@ -10,7 +10,6 @@
 #include <transformations_visibility.hpp>
 
 #include <ngraph/pass/graph_rewrite.hpp>
-#include "transformations/utils/pass_param.hpp"
 
 namespace ngraph {
 namespace pass {
@@ -20,7 +19,7 @@ class TRANSFORMATIONS_API ConvertBroadcast3;
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertBroadcast3: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam {
+class ngraph::pass::ConvertBroadcast3: public ngraph::pass::GraphRewrite {
 public:
     ConvertBroadcast3() : GraphRewrite() {
         convert_broadcast3();
index 120ac80..9a62c2f 100644 (file)
@@ -10,7 +10,6 @@
 #include <transformations_visibility.hpp>
 
 #include <ngraph/pass/graph_rewrite.hpp>
-#include "transformations/utils/pass_param.hpp"
 
 namespace ngraph {
 namespace pass {
@@ -20,7 +19,7 @@ class TRANSFORMATIONS_API ConvertNMS3;
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertNMS3: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam {
+class ngraph::pass::ConvertNMS3: public ngraph::pass::GraphRewrite {
 public:
     ConvertNMS3() : GraphRewrite() {
         convert_nms3();
index 1eafe00..f92d433 100644 (file)
@@ -7,7 +7,6 @@
 #include <memory>
 #include <transformations_visibility.hpp>
 #include <ngraph/pass/graph_rewrite.hpp>
-#include "transformations/utils/pass_param.hpp"
 
 namespace ngraph {
 namespace pass {
@@ -17,10 +16,7 @@ class TRANSFORMATIONS_API ConvertOpSet3ToOpSet2;
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertOpSet3ToOpSet2: public ngraph::pass::FunctionPass, public ngraph::pass::PassParam {
+class ngraph::pass::ConvertOpSet3ToOpSet2: public ngraph::pass::FunctionPass {
 public:
-    explicit ConvertOpSet3ToOpSet2(const PassParam::param_callback & callback = PassParam::getDefaultCallback())
-            : FunctionPass(), PassParam(callback) {}
-
     bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
 };
diff --git a/inference-engine/src/transformations/include/transformations/convert_opset3_to_opset2/convert_opset3_to_opset2_tbl.hpp b/inference-engine/src/transformations/include/transformations/convert_opset3_to_opset2/convert_opset3_to_opset2_tbl.hpp
deleted file mode 100644 (file)
index 271c0e9..0000000
+++ /dev/null
@@ -1,21 +0,0 @@
-// Copyright (C) 2020 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#ifndef NGRAPH_PASS
-#warning "NGRAPH_PASS is not defined"
-#define NGRAPH_PASS(A, B)
-#endif
-
-// To register new pass you need to define NGRAPH_PASS
-// Usage example:
-//   ngraph::pass:Manager pm;
-//   #define NGRAPH_PASS(NAME, NAMESPACE)   pm.register_pass<NAMESPACE::NAME>();
-//   #include <transformations/transformations_tbl.hpp>
-//   #undef NGRAPH_PASS
-
-NGRAPH_PASS(ConvertBroadcast3, ::ngraph::pass)
-NGRAPH_PASS(ConvertNMS3, ::ngraph::pass)
-NGRAPH_PASS(ConvertShapeOf3, ::ngraph::pass)
-NGRAPH_PASS(ConvertShuffleChannels3, ::ngraph::pass)
-NGRAPH_PASS(ConvertTopK3, ::ngraph::pass)
index 1438736..369d405 100644 (file)
@@ -10,7 +10,6 @@
 #include <transformations_visibility.hpp>
 
 #include <ngraph/pass/graph_rewrite.hpp>
-#include "transformations/utils/pass_param.hpp"
 
 namespace ngraph {
 namespace pass {
@@ -20,9 +19,9 @@ class TRANSFORMATIONS_API ConvertShuffleChannels3;
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertShuffleChannels3: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam {
+class ngraph::pass::ConvertShuffleChannels3: public ngraph::pass::GraphRewrite {
 public:
-    ConvertShuffleChannels3() : GraphRewrite(), PassParam() {
+    ConvertShuffleChannels3() : GraphRewrite() {
         convert_shuffle_channels3();
     }
 
index 145bfc5..36b6fb0 100644 (file)
@@ -10,7 +10,6 @@
 #include <transformations_visibility.hpp>
 
 #include <ngraph/pass/graph_rewrite.hpp>
-#include "transformations/utils/pass_param.hpp"
 
 namespace ngraph {
 namespace pass {
@@ -20,7 +19,7 @@ class TRANSFORMATIONS_API ConvertTopK3;
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertTopK3: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam {
+class ngraph::pass::ConvertTopK3: public ngraph::pass::GraphRewrite {
 public:
     ConvertTopK3() : GraphRewrite() {
         convert_topk3();
index e58e25c..d52484d 100644 (file)
 #include <ngraph/opsets/opset1.hpp>
 #include <ngraph/validation_util.hpp>
 #include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
 
 
 namespace ngraph {
 namespace pass {
 
 class TRANSFORMATIONS_API ConvertReduceToPooling;
+class TRANSFORMATIONS_API ConvertReduceMeanToPooling;
+class TRANSFORMATIONS_API ConvertReduceMaxToPooling;
+class TRANSFORMATIONS_API ConvertReduceSumToPooling;
 
 }  // namespace pass
 }  // namespace ngraph
 
 class ngraph::pass::ConvertReduceToPooling: public ngraph::pass::GraphRewrite {
 public:
-    ConvertReduceToPooling() : GraphRewrite() {
-        convert_reduce_to_pooling<ngraph::opset1::ReduceMean>();
-        convert_reduce_to_pooling<ngraph::opset1::ReduceMax>();
-        convert_reduce_to_pooling<ngraph::opset1::ReduceSum>();
+    ConvertReduceToPooling() {
+        add_matcher<ConvertReduceMeanToPooling>();
+        add_matcher<ConvertReduceMaxToPooling>();
+        add_matcher<ConvertReduceSumToPooling>();
     }
-
-private:
-    template <class T>
-    void convert_reduce_to_pooling();
 };
 
 template <class T>
-void ngraph::pass::ConvertReduceToPooling::convert_reduce_to_pooling() {
-    {
-        static_assert(std::is_same<T, ngraph::opset1::ReduceMean>() ||
-                      std::is_same<T, ngraph::opset1::ReduceMax>()  ||
-                      std::is_same<T, ngraph::opset1::ReduceSum>(),
-                      "This callback works only with ngraph::opset1::ReduceMean/Max/Sum");
-
-        auto data = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
-        auto axes = std::make_shared<pattern::op::Label>(element::i64, Shape{4});
-        auto reduce = std::make_shared<T>(data, axes);
-
-        ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
-            auto reduce = std::dynamic_pointer_cast<T>(m.get_match_root());
-            if (!reduce) {
-                return false;
-            }
+ngraph::matcher_pass_callback convert_reduce_to_pooling();
 
-            auto input = reduce->input_value(0);
+class ngraph::pass::ConvertReduceMeanToPooling: public ngraph::pass::MatcherPass {
+public:
+    ConvertReduceMeanToPooling() {
+        auto m = std::make_shared<ngraph::pattern::Matcher>(ngraph::pattern::wrap_type<opset1::ReduceMean>(), "ConvertReduceMean");
+        register_matcher(m, convert_reduce_to_pooling<opset1::ReduceMean>());
+    }
+};
 
-            auto axes_node = reduce->input_value(1).get_node_shared_ptr();
-            if (!ngraph::op::is_constant(axes_node)) {
-                return false;
-            }
+class ngraph::pass::ConvertReduceMaxToPooling: public ngraph::pass::MatcherPass {
+public:
+    ConvertReduceMaxToPooling() {
+        auto m = std::make_shared<ngraph::pattern::Matcher>(ngraph::pattern::wrap_type<opset1::ReduceMax>(), "ConvertReduceMax");
+        register_matcher(m, convert_reduce_to_pooling<opset1::ReduceMax>());
+    }
+};
 
-            auto axes_vector = std::dynamic_pointer_cast<ngraph::opset1::Constant>(axes_node)->template cast_vector<int64_t>();
-            const auto input_rank = input.get_partial_shape().rank().get_length();
-            // Transform negative axes into non-negative ones
-            for (size_t i = 0; i < axes_vector.size(); ++i) {
-                if (axes_vector[i] < 0) {
-                    axes_vector[i] += input_rank;
-                }
-            }
-            std::sort(axes_vector.begin(), axes_vector.end());
+class ngraph::pass::ConvertReduceSumToPooling: public ngraph::pass::MatcherPass {
+public:
+    ConvertReduceSumToPooling() {
+        auto m = std::make_shared<ngraph::pattern::Matcher>(ngraph::pattern::wrap_type<opset1::ReduceSum>(), "ConvertReduceSum");
+        register_matcher(m, convert_reduce_to_pooling<opset1::ReduceSum>());
+    }
+};
 
-            // If axes are empty we just remove Reduction operation
-            if (axes_vector.empty()) {
-                return replace_output_update_name(reduce->output(0), input);
+template <class T>
+ngraph::matcher_pass_callback convert_reduce_to_pooling() {
+    return [](ngraph::pattern::Matcher& m) {
+        auto reduce = std::dynamic_pointer_cast<T>(m.get_match_root());
+        if (!reduce) {
+            return false;
+        }
+
+        auto input = reduce->input_value(0);
+
+        auto axes_node = reduce->input_value(1).get_node_shared_ptr();
+        if (!ngraph::op::is_constant(axes_node)) {
+            return false;
+        }
+
+        auto axes_vector = std::dynamic_pointer_cast<ngraph::opset1::Constant>(axes_node)->template cast_vector<int64_t>();
+        const auto input_rank = input.get_partial_shape().rank().get_length();
+        // Transform negative axes into non-negative ones
+        for (size_t i = 0; i < axes_vector.size(); ++i) {
+            if (axes_vector[i] < 0) {
+                axes_vector[i] += input_rank;
             }
+        }
+        std::sort(axes_vector.begin(), axes_vector.end());
+
+        // If axes are empty we just remove Reduction operation
+        if (axes_vector.empty()) {
+            return replace_output_update_name(reduce->output(0), input);
+        }
+
+        // As this transformation requires static input shape we should guaranty it
+        if (input.get_partial_shape().is_dynamic()) {
+            return false;
+        }
+        auto input_shape = input.get_shape();
+
+        // If Reduce op reduces only 1 dims we replace it with Reshape
+        if (std::all_of(axes_vector.begin(), axes_vector.end(),
+                [&input_shape](const int64_t & axis) { return input_shape[axis] == 1; })) {
+            const auto reshape_shape = reduce->output(0).get_shape();
+            auto reshape = std::make_shared<ngraph::opset1::Reshape>(input,
+                    ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{reshape_shape.size()}, reshape_shape), true);
+
+            reshape->set_friendly_name(reduce->get_friendly_name());
+            copy_runtime_info(reduce, reshape);
+            replace_node(reduce, reshape);
+            return true;
+        }
 
-            // As this transformation requires static input shape we should guaranty it
-            if (input.get_partial_shape().is_dynamic()) {
+        // Check that axes are consecutive otherwise this transformation is not applicable
+        for (size_t i = 1; i < axes_vector.size(); ++i) {
+            if (axes_vector[i] - axes_vector[i-1] != 1) {
                 return false;
             }
-            auto input_shape = input.get_shape();
-
-            // If Reduce op reduces only 1 dims we replace it with Reshape
-            if (std::all_of(axes_vector.begin(), axes_vector.end(),
-                    [&input_shape](const int64_t & axis) { return input_shape[axis] == 1; })) {
-                const auto reshape_shape = reduce->output(0).get_shape();
-                auto reshape = std::make_shared<ngraph::opset1::Reshape>(input,
-                        ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{reshape_shape.size()}, reshape_shape), true);
-
-                reshape->set_friendly_name(reduce->get_friendly_name());
-                copy_runtime_info(reduce, reshape);
-                replace_node(reduce, reshape);
-                return true;
+        }
+
+        // Check either reduction applies to spatial dimensions or not
+        bool spatial_dims_reduction(true);
+        size_t reduction_dims_count = 1;
+        for (auto& axis : axes_vector) {
+            reduction_dims_count *= input_shape[axis];
+            if (axis <= 1) {
+                spatial_dims_reduction = false;
             }
-
-            // Check that axes are consecutive otherwise this transformation is not applicable
-            for (size_t i = 1; i < axes_vector.size(); ++i) {
-                if (axes_vector[i] - axes_vector[i-1] != 1) {
-                    return false;
+        }
+
+        /*
+         * Prepare default attributes for Pooling operation
+         *      pads_begin/pads_end - should be zeros as we don't need any padding
+         *      stride - should be filled with ones
+         *      kernel - depends on Reduction operation axes
+         *
+         * Also here we decide should we use Reshapes before and after Pooling
+         *      shape_begin - if not empty indicates that we need a Reshape before Pooling
+         *      shape_end   - if not empty indicates that we need a Reshape after Pooling
+         */
+
+        ngraph::Strides strides;
+        ngraph::Shape pads_begin, pads_end, kernel, shape_begin, shape_end;
+
+        if (!spatial_dims_reduction || input_shape.size() != 4) {
+            // In case if reduction applies not to spatial dimensions
+            // we have to fit it into 4D Pooling
+            size_t dims_prod = 1, dims_begin = 1, dims_end = 1;
+            for (size_t i = 0; i < input_shape.size(); ++i) {
+                if (i < *axes_vector.begin()) {
+                    dims_begin *= input_shape[i];
+                } else if (i >= axes_vector.front() && i <= axes_vector.back()) {
+                    dims_prod *= input_shape[i];
+                } else {
+                    dims_end *= input_shape[i];
                 }
             }
-
-            // Check either reduction applies to spatial dimensions or not
-            bool spatial_dims_reduction(true);
-            size_t reduction_dims_count = 1;
+            // The batch dimenstion is repositioned in the shape
+            // only in case of batch dimension reduction
+            shape_begin.assign({dims_begin, 1, dims_prod, dims_end});
+            shape_end = reduce->output(0).get_shape();
+            strides.assign({1, 1});
+            pads_begin.assign({0, 0});
+            pads_end.assign({0, 0});
+            kernel.assign({dims_prod, 1});
+        } else {
+            for (size_t i = 0; i < input_shape.size() - 2; ++i) {
+                strides.push_back(1);
+                pads_begin.push_back(0);
+                pads_end.push_back(0);
+                kernel.push_back(1);
+            }
             for (auto& axis : axes_vector) {
-                reduction_dims_count *= input_shape[axis];
-                if (axis <= 1) {
-                    spatial_dims_reduction = false;
-                }
+                kernel[axis-2] = input_shape[axis];
             }
-
-            /*
-             * Prepare default attributes for Pooling operation
-             *      pads_begin/pads_end - should be zeros as we don't need any padding
-             *      stride - should be filled with ones
-             *      kernel - depends on Reduction operation axes
-             *
-             * Also here we decide should we use Reshapes before and after Pooling
-             *      shape_begin - if not empty indicates that we need a Reshape before Pooling
-             *      shape_end   - if not empty indicates that we need a Reshape after Pooling
-             */
-
-            Strides strides;
-            Shape pads_begin, pads_end, kernel, shape_begin, shape_end;
-
-            if (!spatial_dims_reduction || input_shape.size() != 4) {
-                // In case if reduction applies not to spatial dimensions
-                // we have to fit it into 4D Pooling
-                size_t dims_prod = 1, dims_begin = 1, dims_end = 1;
-                for (size_t i = 0; i < input_shape.size(); ++i) {
-                    if (i < *axes_vector.begin()) {
-                        dims_begin *= input_shape[i];
-                    } else if (i >= axes_vector.front() && i <= axes_vector.back()) {
-                        dims_prod *= input_shape[i];
-                    } else {
-                        dims_end *= input_shape[i];
-                    }
-                }
-                // The batch dimenstion is repositioned in the shape
-                // only in case of batch dimension reduction
-                shape_begin.assign({dims_begin, 1, dims_prod, dims_end});
+            if (!reduce->get_keep_dims()) {
                 shape_end = reduce->output(0).get_shape();
-                strides.assign({1, 1});
-                pads_begin.assign({0, 0});
-                pads_end.assign({0, 0});
-                kernel.assign({dims_prod, 1});
-            } else {
-                for (size_t i = 0; i < input_shape.size() - 2; ++i) {
-                    strides.push_back(1);
-                    pads_begin.push_back(0);
-                    pads_end.push_back(0);
-                    kernel.push_back(1);
-                }
-                for (auto& axis : axes_vector) {
-                    kernel[axis-2] = input_shape[axis];
-                }
-                if (!reduce->get_keep_dims()) {
-                    shape_end = reduce->output(0).get_shape();
-                }
-            }
-
-            /*
-             *  ReduceMean => AvgPool
-             *                AvgPool->Reshape (in case if keep_dims=False)
-             *                Reshape->AvgPool->Reshape (in case if axes doesn't match spatial dims)
-
-             *  ReduceMax  => MaxPool
-             *                MaxPool->Reshape (in case if keep_dims=False)
-             *                Reshape->MaxPool->Reshape (in case if axes doesn't match spatial dims)
-             *
-             *  ReduceSum  => AvgPool->Multiply
-             *                AvgPool->Multiply->Reshape (in case if keep_dims=False)
-             *                Reshape->AvgPool->Multiply->Reshape (in case if axes doesn't match spatial dims)
-             *
-             *  Note: some of reshape nodes can be optimized if they do nothing.
-             */
-            NodeVector new_ops;
-
-            if (!shape_begin.empty() && shape_begin != input.get_shape()) {
-                input = std::make_shared<ngraph::opset1::Reshape>(input, opset1::Constant::create(element::i64, Shape{shape_begin.size()}, shape_begin), true);
-                input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/reshape_begin");
-                new_ops.push_back(input.get_node_shared_ptr());
             }
-
-            if (std::is_same<T, ngraph::opset1::ReduceMean>()) {
-                input = std::make_shared<ngraph::opset1::AvgPool>(input,
-                                                                  strides,
-                                                                  pads_begin,
-                                                                  pads_end,
-                                                                  kernel,
-                                                                  true,
-                                                                  op::RoundingType::FLOOR);
-
-                input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/pool");
-                new_ops.push_back(input.get_node_shared_ptr());
-            } else if (std::is_same<T, ngraph::opset1::ReduceMax>()) {
-                input = std::make_shared<ngraph::opset1::MaxPool>(input,
-                                                                  strides,
-                                                                  pads_begin,
-                                                                  pads_end,
-                                                                  kernel,
-                                                                  op::RoundingType::FLOOR);
-
-                input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/pool");
-                new_ops.push_back(input.get_node_shared_ptr());
-            } else if (std::is_same<T, ngraph::opset1::ReduceSum>()) {
-                input = std::make_shared<ngraph::opset1::AvgPool>(input,
-                                                                  strides,
-                                                                  pads_begin,
-                                                                  pads_end,
-                                                                  kernel,
-                                                                  true,
-                                                                  op::RoundingType::FLOOR);
-
-                input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/pool");
-                new_ops.push_back(input.get_node_shared_ptr());
-
-                input = std::make_shared<ngraph::opset1::Multiply>(input,
-                        opset1::Constant::create(reduce->input(0).get_element_type(), Shape{1}, {reduction_dims_count}));
-                input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/mul");
-                new_ops.push_back(input.get_node_shared_ptr());
-            } else {
-                return false;
-            }
-
-            if (!shape_end.empty() && shape_end != input.get_shape()) {
-                input = std::make_shared<ngraph::opset1::Reshape>(input, opset1::Constant::create(element::i64, Shape{shape_end.size()}, shape_end), true);
-                new_ops.push_back(input.get_node_shared_ptr());
-            }
-            input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name());
-            copy_runtime_info(reduce, new_ops);
-            reduce->output(0).replace(input);
-            return true;
-        };
-
-        auto m = std::make_shared<ngraph::pattern::Matcher>(reduce, "ConvertReduceToPooling");
-        this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
-    }
+        }
+
+        /*
+         *  ReduceMean => AvgPool
+         *                AvgPool->Reshape (in case if keep_dims=False)
+         *                Reshape->AvgPool->Reshape (in case if axes doesn't match spatial dims)
+
+         *  ReduceMax  => MaxPool
+         *                MaxPool->Reshape (in case if keep_dims=False)
+         *                Reshape->MaxPool->Reshape (in case if axes doesn't match spatial dims)
+         *
+         *  ReduceSum  => AvgPool->Multiply
+         *                AvgPool->Multiply->Reshape (in case if keep_dims=False)
+         *                Reshape->AvgPool->Multiply->Reshape (in case if axes doesn't match spatial dims)
+         *
+         *  Note: some of reshape nodes can be optimized if they do nothing.
+         */
+        ngraph::NodeVector new_ops;
+
+        if (!shape_begin.empty() && shape_begin != input.get_shape()) {
+            input = std::make_shared<ngraph::opset1::Reshape>(input,
+                    ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{shape_begin.size()}, shape_begin), true);
+            input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/reshape_begin");
+            new_ops.push_back(input.get_node_shared_ptr());
+        }
+
+        if (std::is_same<T, ngraph::opset1::ReduceMean>()) {
+            input = std::make_shared<ngraph::opset1::AvgPool>(input,
+                                                              strides,
+                                                              pads_begin,
+                                                              pads_end,
+                                                              kernel,
+                                                              true,
+                                                              ngraph::op::RoundingType::FLOOR);
+
+            input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/pool");
+            new_ops.push_back(input.get_node_shared_ptr());
+        } else if (std::is_same<T, ngraph::opset1::ReduceMax>()) {
+            input = std::make_shared<ngraph::opset1::MaxPool>(input,
+                                                              strides,
+                                                              pads_begin,
+                                                              pads_end,
+                                                              kernel,
+                                                              ngraph::op::RoundingType::FLOOR);
+
+            input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/pool");
+            new_ops.push_back(input.get_node_shared_ptr());
+        } else if (std::is_same<T, ngraph::opset1::ReduceSum>()) {
+            input = std::make_shared<ngraph::opset1::AvgPool>(input,
+                                                              strides,
+                                                              pads_begin,
+                                                              pads_end,
+                                                              kernel,
+                                                              true,
+                                                              ngraph::op::RoundingType::FLOOR);
+
+            input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/pool");
+            new_ops.push_back(input.get_node_shared_ptr());
+
+            input = std::make_shared<ngraph::opset1::Multiply>(input,
+                    ngraph::opset1::Constant::create(reduce->input(0).get_element_type(), ngraph::Shape{1}, {reduction_dims_count}));
+            input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/mul");
+            new_ops.push_back(input.get_node_shared_ptr());
+        } else {
+            return false;
+        }
+
+        if (!shape_end.empty() && shape_end != input.get_shape()) {
+            input = std::make_shared<ngraph::opset1::Reshape>(input,
+                    ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{shape_end.size()}, shape_end), true);
+            new_ops.push_back(input.get_node_shared_ptr());
+        }
+        input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name());
+        copy_runtime_info(reduce, new_ops);
+        reduce->output(0).replace(input);
+        return true;
+    };
 }
index 6c335c6..8549068 100644 (file)
@@ -10,7 +10,6 @@
 #include <transformations_visibility.hpp>
 
 #include <ngraph/pass/graph_rewrite.hpp>
-#include "transformations/utils/pass_param.hpp"
 
 namespace ngraph {
 namespace pass {
@@ -20,9 +19,9 @@ class TRANSFORMATIONS_API ConvertSpaceToBatch;
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertSpaceToBatch: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam  {
+class ngraph::pass::ConvertSpaceToBatch: public ngraph::pass::GraphRewrite {
 public:
-    ConvertSpaceToBatch() : GraphRewrite(), PassParam() {
+    ConvertSpaceToBatch() : GraphRewrite() {
         // convert_space_to_batch();
         convert_space_to_batch_by_elements();
     }
index 2c8a791..91a19eb 100644 (file)
@@ -10,7 +10,6 @@
 #include <transformations_visibility.hpp>
 
 #include <ngraph/pass/graph_rewrite.hpp>
-#include "transformations/utils/pass_param.hpp"
 
 namespace ngraph {
 namespace pass {
@@ -20,12 +19,7 @@ class TRANSFORMATIONS_API ConvertSpaceToDepth;
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertSpaceToDepth: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam  {
+class ngraph::pass::ConvertSpaceToDepth: public ngraph::pass::MatcherPass {
 public:
-    ConvertSpaceToDepth() : GraphRewrite(), PassParam() {
-        convert();
-    }
-
-private:
-    void convert();
+    ConvertSpaceToDepth();
 };
index 3b204f6..46bb6b7 100644 (file)
@@ -19,12 +19,7 @@ class TRANSFORMATIONS_API ConvertSubtract;
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertSubtract: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertSubtract: public ngraph::pass::MatcherPass {
 public:
-    ConvertSubtract() : GraphRewrite() {
-        convert_subtract();
-    }
-
-private:
-    void convert_subtract();
+    ConvertSubtract();
 };
index f0c30b2..8526fc4 100644 (file)
@@ -10,7 +10,6 @@
 #include <transformations_visibility.hpp>
 
 #include <ngraph/pass/graph_rewrite.hpp>
-#include "transformations/utils/pass_param.hpp"
 
 namespace ngraph {
 namespace pass {
@@ -41,9 +40,10 @@ namespace pass {
  *     p.run_on_function(f);
  *
  */
-class ngraph::pass::DepthToSpaceFusion: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam {
+
+class ngraph::pass::DepthToSpaceFusion: public ngraph::pass::GraphRewrite {
 public:
-    DepthToSpaceFusion() : GraphRewrite(), PassParam() {
+    DepthToSpaceFusion() : GraphRewrite() {
         depth_to_space_fusion();
     }
 
diff --git a/inference-engine/src/transformations/include/transformations/lin_op_sequence_fusoin.hpp b/inference-engine/src/transformations/include/transformations/lin_op_sequence_fusoin.hpp
new file mode 100644 (file)
index 0000000..5451747
--- /dev/null
@@ -0,0 +1,46 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <memory>
+#include <utility>
+
+#include <transformations_visibility.hpp>
+#include <ngraph/pass/graph_rewrite.hpp>
+
+namespace ngraph {
+namespace pass {
+
+class TRANSFORMATIONS_API LinOpSequenceFusion;
+class TRANSFORMATIONS_API AddMultiplyFusion;
+class TRANSFORMATIONS_API AddAddFusion;
+class TRANSFORMATIONS_API MultiplyMultiplyFusion;
+
+}  // namespace pass
+}  // namespace ngraph
+
+class ngraph::pass::LinOpSequenceFusion: public ngraph::pass::GraphRewrite {
+public:
+    LinOpSequenceFusion() {
+        add_matcher<ngraph::pass::AddMultiplyFusion>();
+        add_matcher<ngraph::pass::AddAddFusion>();
+        add_matcher<ngraph::pass::MultiplyMultiplyFusion>();
+    }
+};
+
+class ngraph::pass::AddMultiplyFusion: public ngraph::pass::MatcherPass {
+public:
+    AddMultiplyFusion();
+};
+
+class ngraph::pass::AddAddFusion: public ngraph::pass::MatcherPass {
+public:
+    AddAddFusion();
+};
+
+class ngraph::pass::MultiplyMultiplyFusion: public ngraph::pass::MatcherPass {
+public:
+    MultiplyMultiplyFusion();
+};
index a98921a..40030a8 100644 (file)
@@ -19,12 +19,7 @@ class TRANSFORMATIONS_API PullTransposeThroughFQUp;
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::PullTransposeThroughFQUp: public ngraph::pass::GraphRewrite {
+class ngraph::pass::PullTransposeThroughFQUp: public ngraph::pass::MatcherPass {
 public:
-    PullTransposeThroughFQUp() : GraphRewrite() {
-        pull_transpose_through_fq();
-    }
-
-private:
-    void pull_transpose_through_fq();
+    PullTransposeThroughFQUp();
 };
diff --git a/inference-engine/src/transformations/include/transformations/utils/pass_param.hpp b/inference-engine/src/transformations/include/transformations/utils/pass_param.hpp
deleted file mode 100644 (file)
index 6b2854d..0000000
+++ /dev/null
@@ -1,37 +0,0 @@
-// Copyright (C) 2020 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#pragma once
-
-#include <memory>
-#include <transformations_visibility.hpp>
-#include <ngraph/pass/graph_rewrite.hpp>
-
-namespace ngraph {
-namespace pass {
-
-class TRANSFORMATIONS_API PassParam;
-
-}  // namespace pass
-}  // namespace ngraph
-
-class ngraph::pass::PassParam {
-public:
-    using param_callback = std::function<bool(const std::shared_ptr<const ::ngraph::Node>)>;
-
-    explicit PassParam(const param_callback & callback = getDefaultCallback()) : transformation_callback(callback) {}
-
-    void setCallback(const param_callback & callback) {
-        transformation_callback = callback;
-    }
-
-    static param_callback getDefaultCallback() {
-        return [](const std::shared_ptr<const Node> &) -> bool {
-            return false;
-        };
-    }
-
-protected:
-    param_callback transformation_callback;
-};
index c5b7b96..a627fa2 100644 (file)
@@ -10,7 +10,7 @@
 #include <ngraph/opsets/opset1.hpp>
 #include <ngraph/rt_info.hpp>
 
-void ngraph::pass::BatchNormDecomposition::batch_norm_decomposition() {
+ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() {
     Shape shape{2, 2, 1, 1};
     auto input = make_shared<pattern::op::Label>(element::f32, shape);
     auto mean_shape = Shape{2};
@@ -85,5 +85,5 @@ void ngraph::pass::BatchNormDecomposition::batch_norm_decomposition() {
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(bn, "BatchNormDecomposition");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
\ No newline at end of file
index ef73766..e236bfc 100644 (file)
 
 
 bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::Function> f) {
-    ngraph::pass::Manager CommonOptimizations;
-    std::vector<std::shared_ptr<ngraph::pass::PassBase> > transforms;
+    ngraph::pass::Manager manager;
 
-#define NGRAPH_PASS(NAME, NAMESPACE) transforms.push_back(CommonOptimizations.register_pass<NAMESPACE::NAME>());
-#include <transformations/common_optimizations/common_optimizations_tbl.hpp>
-#undef NGRAPH_PASS
+    // This pass must be called first in pipeline
+    manager.register_pass<ngraph::pass::InitNodeInfo>();
+    manager.register_pass<ngraph::pass::RemoveFilteringBoxesBySize>(); // Resolves dynamism (replaces NonZero), CF needed
+    manager.register_pass<ngraph::pass::ConstantFolding>();
+    manager.register_pass<ngraph::pass::StridedSliceOptimization>(); // depends on CF
+    manager.register_pass<ngraph::pass::NopElimination>(); // may introduce fake dynamism
+    manager.register_pass<ngraph::pass::AlgebraicSimplification>(); // may introduce fake dynamism
+    manager.register_pass<ngraph::pass::ConstantFolding>();
+    manager.register_pass<ngraph::pass::ConvertScatterElementsToScatter>(); // partially depends on CF
+    manager.register_pass<ngraph::pass::DepthToSpaceFusion>();
 
-    for (auto & t : transforms) {
-        if (auto t_param = std::dynamic_pointer_cast<PassParam>(t)) {
-            t_param->setCallback(transformation_callback);
-        }
-    }
-    CommonOptimizations.run_passes(f);
+    manager.set_callback(m_transformation_callback);
+    manager.run_passes(f);
     return true;
 }
diff --git a/inference-engine/src/transformations/src/transformations/constant_eltwise_reduction.cpp b/inference-engine/src/transformations/src/transformations/constant_eltwise_reduction.cpp
deleted file mode 100644 (file)
index fa27c45..0000000
+++ /dev/null
@@ -1,82 +0,0 @@
-// Copyright (C) 2018-2020 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#include "transformations/constant_eltwise_reduction.hpp"
-
-#include <memory>
-#include <vector>
-
-#include <ngraph/opsets/opset1.hpp>
-
-template <typename T>
-std::shared_ptr<ngraph::op::Constant> constant_reduction(const std::shared_ptr<ngraph::op::Constant>& const_node) {
-    std::vector<T> data = const_node->get_vector<T>();
-    // TODO: implement this function after eltwise broadcast support will be added
-    return nullptr;
-}
-
-ngraph::graph_rewrite_callback callback = [](ngraph::pattern::Matcher& m) {
-    // Check that eltwise operation either Add or Multiply
-    auto eltwise_node = m.get_match_root();
-    if (!std::dynamic_pointer_cast<ngraph::opset1::Add>(eltwise_node) &&
-        !std::dynamic_pointer_cast<ngraph::opset1::Multiply>(eltwise_node)) {
-        return false;
-    }
-
-    for (const auto& input : eltwise_node->inputs()) {
-        const auto& inputLayer = input.get_source_output().get_node_shared_ptr();
-        auto const_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(inputLayer);
-        if (!const_node) continue;
-
-        std::shared_ptr<ngraph::opset1::Constant> result = nullptr;
-
-        ngraph::element::Type elem_type = const_node->get_element_type();
-        switch (elem_type) {
-        case ngraph::element::Type_t::f32:
-            result = constant_reduction<float>(const_node);
-            break;
-        case ngraph::element::Type_t::i16:
-        case ngraph::element::Type_t::f16:
-            result = constant_reduction<short>(const_node);
-            break;
-        case ngraph::element::Type_t::u8:
-            result = constant_reduction<uint8_t>(const_node);
-            break;
-        case ngraph::element::Type_t::i8:
-            result = constant_reduction<int8_t>(const_node);
-            break;
-        case ngraph::element::Type_t::i32:
-            result = constant_reduction<int32_t>(const_node);
-            break;
-        default:
-            return false;
-        }
-        if (result) {
-            ngraph::replace_node(inputLayer, std::dynamic_pointer_cast<ngraph::Node>(result));
-            std::cout << "Successful constant_eltwise reduction" << std::endl;
-        }
-    }
-
-    return true;
-};
-
-void ngraph::pass::ConstantEltwiseReduction::constant_multiply_reduction() {
-    Shape shape {2, 2, 1, 1};
-    auto constant1 = std::make_shared<pattern::op::Label>(element::f32, shape);
-    auto constant2 = std::make_shared<pattern::op::Label>(element::f32, shape);
-    auto eltwise = std::make_shared<ngraph::opset1::Multiply>(constant1, constant2);
-
-    auto m = std::make_shared<ngraph::pattern::Matcher>(eltwise, "CPUFusion.ConstantMultiplyReduction");
-    this->add_matcher(m, callback);
-}
-
-void ngraph::pass::ConstantEltwiseReduction::constant_add_reduction() {
-    Shape shape {2, 2, 1, 1};
-    auto constant1 = std::make_shared<pattern::op::Label>(element::f32, shape);
-    auto constant2 = std::make_shared<pattern::op::Label>(element::f32, shape);
-    auto eltwise = std::make_shared<ngraph::opset1::Add>(constant1, constant2);
-
-    auto m = std::make_shared<ngraph::pattern::Matcher>(eltwise, "CPUFusion.ConstantAddReduction");
-    this->add_matcher(m, callback);
-}
index 9fc17ce..b8f1dfd 100644 (file)
@@ -48,7 +48,7 @@ void ngraph::pass::ConvertBatchToSpace::convert_batch_to_space_ie_side() {
         auto data = batch_to_space->input_value(0);
         auto data_shape = data.get_shape();
 
-        if (transformation_callback(batch_to_space) && (data_shape.size() == 4 || data_shape.size() == 5)) {
+        if (m_transformation_callback(batch_to_space) && (data_shape.size() == 4 || data_shape.size() == 5)) {
             return false;
         }
         auto block = batch_to_space->input_value(1);
index 5d263d6..a06be51 100644 (file)
@@ -9,12 +9,10 @@
 
 #include <ngraph/opsets/opset1.hpp>
 #include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
 
-void ngraph::pass::ConvertBroadcastToTiles::convert_broadcast_to_tiles() {
-    auto weights = std::make_shared<pattern::op::Label>(element::f32, Shape {1});
-    auto shp = std::make_shared<pattern::op::Label>(element::i64, Shape {1});
-    auto axs = std::make_shared<pattern::op::Label>(element::i64, Shape {1});
-    auto broadcast = std::make_shared<ngraph::opset1::Broadcast>(weights, shp, axs);
+ngraph::pass::ConvertBroadcastToTiles::ConvertBroadcastToTiles() {
+    auto broadcast = ngraph::pattern::wrap_type<ngraph::opset1::Broadcast>();
 
     ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
         auto broadcast = std::dynamic_pointer_cast<ngraph::opset1::Broadcast>(m.get_match_root());
@@ -94,5 +92,5 @@ void ngraph::pass::ConvertBroadcastToTiles::convert_broadcast_to_tiles() {
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(broadcast, "ConvertBroadcastToTile");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
index 959b459..cf875d0 100644 (file)
@@ -9,14 +9,14 @@
 
 #include <ngraph/opsets/opset1.hpp>
 #include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
 
-void ngraph::pass::ConvertDepthToSpace::convert_depth_to_space() {
-    auto input0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
-    auto dts_node = std::make_shared<ngraph::opset1::DepthToSpace>(input0, ngraph::op::DepthToSpace::DepthToSpaceMode::DEPTH_FIRST);
+ngraph::pass::ConvertDepthToSpace::ConvertDepthToSpace() {
+    auto dts_node = ngraph::pattern::wrap_type<ngraph::opset1::DepthToSpace>();
 
     ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
         auto dts_node = std::dynamic_pointer_cast<ngraph::opset1::DepthToSpace> (m.get_match_root());
-        if (!dts_node || transformation_callback(dts_node)) {
+        if (!dts_node || m_transformation_callback(dts_node)) {
             return false;
         }
 
@@ -99,5 +99,5 @@ void ngraph::pass::ConvertDepthToSpace::convert_depth_to_space() {
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(dts_node, "ConvertDepthToSpace");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
index 22e7b64..04f55b6 100644 (file)
@@ -9,11 +9,10 @@
 
 #include <ngraph/opsets/opset1.hpp>
 #include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
 
-void ngraph::pass::ConvertDivide::convert_divide() {
-    auto input0 = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1, 1});
-    auto input1 = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1, 1});
-    auto div = std::make_shared<ngraph::opset1::Divide>(input0, input1);
+ngraph::pass::ConvertDivide::ConvertDivide() {
+    auto div = ngraph::pattern::wrap_type<ngraph::opset1::Divide>();
 
     ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
         auto div = std::dynamic_pointer_cast<ngraph::opset1::Divide> (m.get_match_root());
@@ -34,5 +33,5 @@ void ngraph::pass::ConvertDivide::convert_divide() {
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(div, "ConvertDivide");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
\ No newline at end of file
index 5e82d76..65e0d67 100644 (file)
@@ -9,11 +9,10 @@
 
 #include <ngraph/opsets/opset1.hpp>
 #include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
 
-void ngraph::pass::ConvertMinimum::convert_minimum() {
-    auto input0 = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1, 1});
-    auto input1 = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1, 1});
-    auto minimum = std::make_shared<ngraph::opset1::Minimum>(input0, input1);
+ngraph::pass::ConvertMinimum::ConvertMinimum() {
+    auto minimum = ngraph::pattern::wrap_type<opset1::Minimum>();
 
     ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
         auto minimum = std::dynamic_pointer_cast<ngraph::opset1::Minimum> (m.get_match_root());
@@ -43,5 +42,5 @@ void ngraph::pass::ConvertMinimum::convert_minimum() {
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(minimum, "ConvertMinimum");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
\ No newline at end of file
index fb0d3f0..02570eb 100644 (file)
@@ -9,13 +9,12 @@
 
 #include <ngraph/opsets/opset1.hpp>
 #include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
 
-void ngraph::pass::ConvertMod::convert_mod() {
-    auto input0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
-    auto input1 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
-    auto mod = std::make_shared<ngraph::opset1::Mod>(input0, input1);
+ngraph::pass::ConvertMod::ConvertMod() {
+    auto mod = ngraph::pattern::wrap_type<opset1::Mod>();
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
         auto mod = std::dynamic_pointer_cast<ngraph::opset1::Mod> (m.get_match_root());
         if (!mod) {
             return false;
@@ -27,13 +26,13 @@ void ngraph::pass::ConvertMod::convert_mod() {
         const auto divisor = std::make_shared<opset1::Abs>(mod->input_value(1));
 
         // truncated(a / b)
-        auto div = std::make_shared<opset1::Divide>(dividend, divisor);
+        auto div = register_new_node<opset1::Divide>(dividend, divisor);
         auto convert_to_i64 = std::make_shared<opset1::Convert>(div, ngraph::element::i64);
         auto convert = std::make_shared<opset1::Convert>(convert_to_i64, dividend_et);
         // truncated(a / b) * b
         auto multiplication = std::make_shared<opset1::Multiply>(convert, divisor);
         // a mod b = a - truncated(a / b) * b
-        auto sub = std::make_shared<opset1::Subtract>(dividend, multiplication);
+        auto sub = register_new_node<opset1::Subtract>(dividend, multiplication);
 
         // apply sign of dividend
         auto mul = std::make_shared<opset1::Multiply>(dividend_sign, sub);
@@ -45,5 +44,5 @@ void ngraph::pass::ConvertMod::convert_mod() {
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(mod, "ConvertMod");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
 }
\ No newline at end of file
index 9eb4956..3a55de1 100644 (file)
@@ -9,10 +9,10 @@
 
 #include <ngraph/opsets/opset1.hpp>
 #include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
 
-void ngraph::pass::ConvertNegative::convert_negative() {
-    auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{1});
-    auto neg = std::make_shared<ngraph::opset1::Negative>(input);
+ngraph::pass::ConvertNegative::ConvertNegative() {
+    auto neg = ngraph::pattern::wrap_type<ngraph::opset1::Negative>();
 
     ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
         auto neg = std::dynamic_pointer_cast<ngraph::opset1::Negative> (m.get_match_root());
@@ -29,5 +29,5 @@ void ngraph::pass::ConvertNegative::convert_negative() {
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(neg, "ConvertNegative");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
\ No newline at end of file
diff --git a/inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/conv_bias_fusion.cpp b/inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/conv_bias_fusion.cpp
new file mode 100644 (file)
index 0000000..4a67b6a
--- /dev/null
@@ -0,0 +1,123 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/convert_opset1_to_legacy/conv_bias_fusion.hpp"
+
+#include <memory>
+#include <vector>
+
+#include <ngraph/opsets/opset1.hpp>
+#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
+
+#include <ngraph_ops/convolution_ie.hpp>
+#include <ngraph_ops/deconvolution_ie.hpp>
+
+template <class Conv>
+ngraph::graph_rewrite_callback get_callback() {
+    ngraph::graph_rewrite_callback callback = [](ngraph::pattern::Matcher &m) {
+        auto eltwise = m.get_match_root();
+
+        std::shared_ptr<ngraph::opset1::Constant> m_const;
+        std::shared_ptr<Conv> m_conv;
+        // FIXME: use auto [m_conv, m_const] when C++17 is available
+        std::tie(m_conv, m_const) = parse_eltwise_inputs<Conv, ngraph::opset1::Constant>(eltwise);
+        if (!m_conv || !m_const) {
+            return false;
+        }
+
+        // TODO: check that constant can be scalar and do not match [1, C, 1, 1] layout
+        const auto constant_shape = m_const->get_shape();
+        const auto output_pshape = m_conv->get_output_partial_shape(0);
+
+        if (output_pshape.rank().is_dynamic() || output_pshape[1].is_dynamic()) {
+            return false;
+        }
+
+        const auto channel_dim = output_pshape[1].get_length();
+
+        size_t constant_size = std::accumulate(constant_shape.begin(), constant_shape.end(), 1, std::multiplies<size_t>());
+        if (constant_size != channel_dim) {
+            return false;
+        }
+
+        ngraph::Output<ngraph::Node> constant(m_const);
+
+        if (constant_shape.size() > 1) {
+            constant = std::make_shared<ngraph::opset1::Reshape>(constant,
+                    ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {channel_dim}), true);
+        }
+
+        if (m_conv->output(0).get_target_inputs().size() != 1) {
+            return false;
+        }
+
+        ngraph::Output<ngraph::Node> new_conv, new_weights, new_bias;
+        if (std::dynamic_pointer_cast<ngraph::opset1::Add>(eltwise)) {
+            // Fuse: ConvolutionIE/DeconvolutionIE->Add
+            if (m_conv->inputs().size() == 2) {
+                new_bias = constant;
+            } else {
+                new_bias = std::make_shared<ngraph::opset1::Add>(constant, m_conv->input_value(2));
+            }
+            new_conv = m_conv->clone_with_new_inputs({m_conv->input_value(0), m_conv->input_value(1), new_bias});
+        } else if (std::is_same<Conv, ngraph::op::ConvolutionIE>() && std::dynamic_pointer_cast<ngraph::opset1::Multiply>(eltwise)) {
+            // Fuse: ConvolutionIE->Mul
+            auto weights_shape = m_conv->input(1).get_shape();
+
+            ngraph::Shape const_shape(weights_shape.size(), 1);
+            const_shape[0] = weights_shape[0];
+
+            auto const_reshape = std::make_shared<ngraph::opset1::Reshape>(constant,
+                    ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{const_shape.size()}, const_shape), true);
+            new_weights = std::make_shared<ngraph::opset1::Multiply> (m_conv->input_value(1), const_reshape);
+            if (m_conv->inputs().size() == 2) {
+                new_conv = m_conv->clone_with_new_inputs({m_conv->input_value(0), new_weights});
+            } else {
+                auto bias_reshape = std::make_shared<ngraph::opset1::Reshape>(constant,
+                        ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {weights_shape[0]}), true);
+                new_bias = std::make_shared<ngraph::opset1::Multiply>(bias_reshape, constant);
+                new_conv = m_conv->clone_with_new_inputs({m_conv->input_value(0), new_weights, new_bias});
+            }
+        } else {
+            return false;
+        }
+
+        ngraph::copy_runtime_info({m_conv, eltwise}, new_conv.get_node_shared_ptr());
+        new_conv.get_node_shared_ptr()->set_friendly_name(m.get_match_root()->get_friendly_name());
+        ngraph::replace_node(m.get_match_root(), new_conv.get_node_shared_ptr());
+        return true;
+    };
+    return callback;
+}
+
+ngraph::pass::ConvAddFusion::ConvAddFusion() {
+    auto conv = ngraph::pattern::wrap_type<op::ConvolutionIE>();
+    auto add = ngraph::pattern::wrap_type<opset1::Add>({conv, std::make_shared<pattern::op::Label>()});
+
+    matcher_pass_callback callback = get_callback<op::ConvolutionIE>();
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(add, "ConvAddFusion");
+    register_matcher(m, callback);
+}
+
+ngraph::pass::ConvMultiplyFusion::ConvMultiplyFusion() {
+    auto conv = ngraph::pattern::wrap_type<op::ConvolutionIE>();
+    auto add = ngraph::pattern::wrap_type<opset1::Multiply>({conv, std::make_shared<pattern::op::Label>()});
+
+    matcher_pass_callback callback = get_callback<op::ConvolutionIE>();
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(add, "ConvMultiplyFusion");
+    register_matcher(m, callback);
+}
+
+ngraph::pass::DeconvAddFusion::DeconvAddFusion() {
+    auto conv = ngraph::pattern::wrap_type<op::DeconvolutionIE>();
+    auto add = ngraph::pattern::wrap_type<opset1::Add>({conv, std::make_shared<pattern::op::Label>()});
+
+    matcher_pass_callback callback = get_callback<op::DeconvolutionIE>();
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(add, "DeconvAddFusion");
+    register_matcher(m, callback);
+}
\ No newline at end of file
index 415b992..91960c9 100644 (file)
 #include <ngraph/opsets/opset1.hpp>
 #include <ngraph/opsets/opset3.hpp>
 #include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
 
 #include <ngraph_ops/lstm_cell_ie.hpp>
 #include <ngraph_ops/gru_cell_ie.hpp>
 #include <ngraph_ops/rnn_cell_ie.hpp>
 
-void ngraph::pass::ConvertCellsToCellsIE::convert_lstm_cell() {
-    // placeholders
-    auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});  // X
-    auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1});  // initial_hidden_state
-    auto input_2 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1});  // initial_cell_state
-    auto input_3 = std::make_shared<pattern::op::Label>(element::f32, Shape{4, 1});  // W
-    auto input_4 = std::make_shared<pattern::op::Label>(element::f32, Shape{4, 1});  // R
-    auto input_5 = std::make_shared<pattern::op::Label>(element::f32, Shape{4});     // B
+ngraph::pass::ConvertLSTMCellMatcher::ConvertLSTMCellMatcher() {
+    auto lstm_cell_ngraph = ngraph::pattern::wrap_type<ngraph::opset1::LSTMCell>();
 
-    auto lstm_cell_ngraph = std::make_shared<ngraph::opset1::LSTMCell>(input_0, input_1, input_2, input_3, input_4, input_5, 1);
-
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
         auto lstm_cell = std::dynamic_pointer_cast<ngraph::opset1::LSTMCell> (m.get_match_root());
         if (!lstm_cell) {
             return false;
@@ -61,20 +54,13 @@ void ngraph::pass::ConvertCellsToCellsIE::convert_lstm_cell() {
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(lstm_cell_ngraph, "ConvertLSTMCellToLSTMCellIE");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
 
-void ngraph::pass::ConvertCellsToCellsIE::convert_gru_cell() {
-    // placeholders
-    auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1});  // X
-    auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1});  // initial_hidden_state
-    auto input_2 = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 1});  // W
-    auto input_3 = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 1});  // R
-    auto input_4 = std::make_shared<pattern::op::Label>(element::f32, Shape{3});     // B
-
-    auto gru_cell_ngraph = std::make_shared<ngraph::opset3::GRUCell>(input_0, input_1, input_2, input_3, input_4, 1);
+ngraph::pass::ConvertGRUCellMatcher::ConvertGRUCellMatcher() {
+    auto gru_cell_ngraph = ngraph::pattern::wrap_type<ngraph::opset3::GRUCell>();
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
         auto gru_cell = std::dynamic_pointer_cast<ngraph::opset3::GRUCell> (m.get_match_root());
         if (!gru_cell) {
             return false;
@@ -109,20 +95,13 @@ void ngraph::pass::ConvertCellsToCellsIE::convert_gru_cell() {
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(gru_cell_ngraph, "ConvertGRUCellToGRUCellIE");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
 
-void ngraph::pass::ConvertCellsToCellsIE::convert_rnn_cell() {
-    // placeholders
-    auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1});  // X
-    auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1});  // initial_hidden_state
-    auto input_2 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1});  // W
-    auto input_3 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1});  // R
-    auto input_4 = std::make_shared<pattern::op::Label>(element::f32, Shape{1});     // B
-
-    auto rnn_cell_ngraph = std::make_shared<ngraph::opset3::RNNCell>(input_0, input_1, input_2, input_3, input_4, 1);
+ngraph::pass::ConvertRNNCellMatcher::ConvertRNNCellMatcher() {
+    auto rnn_cell_ngraph = ngraph::pattern::wrap_type<ngraph::opset3::RNNCell>();
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
         auto rnn_cell = std::dynamic_pointer_cast<ngraph::opset3::RNNCell> (m.get_match_root());
         if (!rnn_cell) {
             return false;
@@ -156,5 +135,5 @@ void ngraph::pass::ConvertCellsToCellsIE::convert_rnn_cell() {
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(rnn_cell_ngraph, "ConvertRNNCellToRNNCellIE");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
\ No newline at end of file
index 1d853aa..fda158c 100644 (file)
 #include <ngraph_ops/convolution_ie.hpp>
 #include <ngraph_ops/deconvolution_ie.hpp>
 
-void ngraph::pass::ConvertConvolutions::convert_convolution() {
-    auto conv = std::make_shared<pattern::op::Label>(element::f32, Shape{},
-            pattern::has_class<opset1::Convolution>());
+#include <ngraph/pattern/op/wrap_type.hpp>
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+ngraph::pass::ConvertConvolution::ConvertConvolution() {
+    auto conv = ngraph::pattern::wrap_type<opset1::Convolution>();
+
+    ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
         auto conv = std::dynamic_pointer_cast<ngraph::opset1::Convolution> (m.get_match_root());
         if (!conv) {
             return false;
@@ -38,14 +39,13 @@ void ngraph::pass::ConvertConvolutions::convert_convolution() {
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(conv, "ConvertConvolution");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
 
-void ngraph::pass::ConvertConvolutions::convert_group_convolution() {
-    auto gconv = std::make_shared<pattern::op::Label>(element::f32, Shape{},
-            pattern::has_class<opset1::GroupConvolution>());
+ngraph::pass::ConvertGroupConvolution::ConvertGroupConvolution() {
+    auto gconv = ngraph::pattern::wrap_type<opset1::GroupConvolution>();
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
         auto gconv = std::dynamic_pointer_cast<ngraph::opset1::GroupConvolution> (m.get_match_root());
         if (!gconv) {
             return false;
@@ -81,14 +81,13 @@ void ngraph::pass::ConvertConvolutions::convert_group_convolution() {
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(gconv, "ConvertGroupConvolution");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
 
-void ngraph::pass::ConvertConvolutions::convert_convolution_backprop_data() {
-    auto conv = std::make_shared<pattern::op::Label>(element::f32, Shape{},
-            pattern::has_class<opset1::ConvolutionBackpropData>());
+ngraph::pass::ConvertDeconvolution::ConvertDeconvolution() {
+    auto conv = ngraph::pattern::wrap_type<opset1::ConvolutionBackpropData>();
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
         auto deconv = std::dynamic_pointer_cast<ngraph::opset1::ConvolutionBackpropData> (m.get_match_root());
         if (!deconv) {
             return false;
@@ -112,14 +111,13 @@ void ngraph::pass::ConvertConvolutions::convert_convolution_backprop_data() {
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(conv, "ConvertConvolutionBackpropData");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
 
-void ngraph::pass::ConvertConvolutions::convert_group_convolution_backprop_data() {
-    auto gconv = std::make_shared<pattern::op::Label>(element::f32, Shape{},
-            pattern::has_class<opset1::GroupConvolutionBackpropData>());
+ngraph::pass::ConvertGroupDeconvolution::ConvertGroupDeconvolution() {
+    auto gconv = ngraph::pattern::wrap_type<opset1::GroupConvolutionBackpropData>();
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
         auto gconv = std::dynamic_pointer_cast<ngraph::opset1::GroupConvolutionBackpropData> (m.get_match_root());
         if (!gconv) {
             return false;
@@ -154,5 +152,5 @@ void ngraph::pass::ConvertConvolutions::convert_group_convolution_backprop_data(
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(gconv, "ConvertGroupConvolutionBackpropData");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
index 9ee227a..2d8ccc6 100644 (file)
@@ -9,11 +9,12 @@
 
 #include <ngraph/opsets/opset1.hpp>
 #include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
 
-void ngraph::pass::ConvertGatherToGatherIE::convert_gather_to_gather_ie() {
-    auto gather = std::make_shared<pattern::op::Label>(element::f32, Shape{1}, pattern::has_class<opset1::Gather>());
+ngraph::pass::ConvertGatherToGatherIEMatcher::ConvertGatherToGatherIEMatcher() {
+    auto gather = ngraph::pattern::wrap_type<opset1::Gather>();
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher &m) {
+    ngraph::matcher_pass_callback callback = [](pattern::Matcher &m) {
         auto gather = std::dynamic_pointer_cast<ngraph::opset1::Gather>(m.get_match_root());
         if (!gather) {
             return false;
@@ -63,5 +64,5 @@ void ngraph::pass::ConvertGatherToGatherIE::convert_gather_to_gather_ie() {
     };
 
     auto m1 = std::make_shared<ngraph::pattern::Matcher>(gather, "ConvertGatherToGatherIE");
-    this->add_matcher(m1, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m1, callback);
 }
index a8db453..0f81993 100644 (file)
 #include <ngraph/opsets/opset1.hpp>
 #include <ngraph/rt_info.hpp>
 
-void ngraph::pass::ConvertGatherTreeToGatherTreeIE::convert() {
+ngraph::pass::ConvertGatherTreeToGatherTreeIEMatcher::ConvertGatherTreeToGatherTreeIEMatcher() {
     auto input0 = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1});
     auto input1 = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1});
     auto input2 = std::make_shared<pattern::op::Label>(element::i64, Shape{1});
     auto input3 = std::make_shared<pattern::op::Label>(element::i64, Shape{});
     auto gt = std::make_shared<ngraph::opset1::GatherTree>(input0, input1, input2, input3);
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
         auto gt = std::dynamic_pointer_cast<ngraph::opset1::GatherTree> (m.get_match_root());
         if (!gt) {
             return false;
@@ -34,5 +34,5 @@ void ngraph::pass::ConvertGatherTreeToGatherTreeIE::convert() {
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(gt, "ConvertGatherTreeToGatherTreeIE");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
\ No newline at end of file
index dc2b9f2..4414136 100644 (file)
@@ -16,7 +16,7 @@ void ngraph::pass::ConvertGELU::convert_gelu() {
 
     ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
         auto gelu = std::dynamic_pointer_cast<ngraph::opset2::Gelu>(m.get_match_root());
-        if (!gelu || transformation_callback(gelu))
+        if (!gelu || m_transformation_callback(gelu))
             return false;
         auto input = gelu->input_value(0);
         auto input_type = input.get_element_type();
index 726498a..a585422 100644 (file)
 #include <ngraph_ops/hard_sigmoid_ie.hpp>
 
 
-void ngraph::pass::ConvertHardSigmoidToHardSigmoidIE::convert_hard_sigmoid() {
+ngraph::pass::ConvertHardSigmoidToLegacyMatcher::ConvertHardSigmoidToLegacyMatcher() {
     auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
     auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{});
     auto input_2 = std::make_shared<pattern::op::Label>(element::f32, Shape{});
     auto node = std::make_shared<ngraph::opset1::HardSigmoid>(input_0, input_1, input_2);
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
         auto hard_sigmoid = std::dynamic_pointer_cast<ngraph::opset1::HardSigmoid> (m.get_match_root());
         if (!hard_sigmoid) {
             return false;
@@ -51,6 +51,6 @@ void ngraph::pass::ConvertHardSigmoidToHardSigmoidIE::convert_hard_sigmoid() {
         return true;
     };
 
-    auto m = std::make_shared<ngraph::pattern::Matcher>(node, "ConvertHardSigmoidToHardSigmoidIE");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    auto m = std::make_shared<ngraph::pattern::Matcher>(node, "ConvertHardSigmoidToLegacy");
+    this->register_matcher(m, callback);
 }
\ No newline at end of file
index 767c434..f0ea3d1 100644 (file)
 
 #include <ngraph_ops/interp.hpp>
 
-void ngraph::pass::ConvertInterpolateToInterpOrResample::convert_interpolate_to_interp_or_resample() {
+ngraph::pass::ConvertInterpolateToInterpOrResampleMatcher::ConvertInterpolateToInterpOrResampleMatcher() {
     auto data = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
     auto shp = std::make_shared<pattern::op::Label>(element::i64, Shape{2});
     auto interpolate = std::make_shared<ngraph::opset1::Interpolate>(data, shp, ngraph::op::InterpolateAttrs());
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
         auto interpolate = std::dynamic_pointer_cast<ngraph::opset1::Interpolate> (m.get_match_root());
         if (!interpolate)
             return false;
@@ -159,5 +159,5 @@ void ngraph::pass::ConvertInterpolateToInterpOrResample::convert_interpolate_to_
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(interpolate, "ConvertInterpolateToInterpOrResample");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
\ No newline at end of file
index 9583dd2..e12597a 100644 (file)
 
 #include <ngraph_ops/lrn_ie.hpp>
 
-void ngraph::pass::ConvertLRNToLRNIE::convert_lrn() {
+ngraph::pass::ConvertLRNToLegacyMatcher::ConvertLRNToLegacyMatcher() {
     auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
     auto input_1 = std::make_shared<pattern::op::Label>(element::i64, Shape{1});
     auto lrn = std::make_shared<ngraph::opset1::LRN>(input_0, input_1, 1, 1, 1, 1);
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
         auto lrn = std::dynamic_pointer_cast<ngraph::opset1::LRN> (m.get_match_root());
         if (!lrn) {
             return false;
@@ -62,6 +62,6 @@ void ngraph::pass::ConvertLRNToLRNIE::convert_lrn() {
         return true;
     };
 
-    auto m = std::make_shared<ngraph::pattern::Matcher>(lrn, "ConvertLRNToLRNIE");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    auto m = std::make_shared<ngraph::pattern::Matcher>(lrn, "ConvertLRNToLegacy");
+    this->register_matcher(m, callback);
 }
index f4d7d9d..1ad6616 100644 (file)
 #include <transformations/utils/utils.hpp>
 
 
-void ngraph::pass::ConvertMatMulToFCorGemm::convert_matmul() {
+ngraph::pass::ConvertMatMulToFCorGemm::ConvertMatMulToFCorGemm() {
     auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape {1, 1});
     auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape {1, 1});
     auto matmul = std::make_shared<ngraph::opset1::MatMul>(input_0, input_1);
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
         auto matmul = std::dynamic_pointer_cast<ngraph::opset1::MatMul>(m.get_match_root());
         if (!matmul) {
             return false;
@@ -79,14 +79,14 @@ void ngraph::pass::ConvertMatMulToFCorGemm::convert_matmul() {
          *  order will be [0, 1, 3, 2] that emulates transpose_a or transpose_b attribute.
          */
 
-        auto create_transpose = [](Output<Node> node, const std::string& transpose_name) -> std::shared_ptr<Node> {
+        auto create_transpose = [this](Output<Node> node, const std::string& transpose_name) -> std::shared_ptr<Node> {
             Shape output_shape = node.get_node_shared_ptr()->get_shape();
 
             std::vector<size_t> transpose_order(output_shape.size());
             std::iota(transpose_order.begin(), transpose_order.end(), 0);
             std::swap(*(transpose_order.end() - 1), *(transpose_order.end() - 2));
 
-            auto transpose = std::make_shared<ngraph::opset1::Transpose>(
+            auto transpose = register_new_node<ngraph::opset1::Transpose>(
                     node, opset1::Constant::create(element::i64, Shape {transpose_order.size()}, transpose_order));
             transpose->set_friendly_name(transpose_name);
             return transpose;
@@ -199,5 +199,5 @@ void ngraph::pass::ConvertMatMulToFCorGemm::convert_matmul() {
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(matmul, "ConvertMatMulToFCorGemm");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
index 65cf67e..d422b10 100644 (file)
@@ -14,7 +14,7 @@
 
 #include "transformations/convert_opset1_to_legacy/convert_nms_4_to_legacy.hpp"
 
-void ngraph::pass::ConvertNMS4ToLegacy::convert_nms4_to_legacy() {
+ngraph::pass::ConvertNMS4ToLegacyMatcher::ConvertNMS4ToLegacyMatcher() {
     auto boxes = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1000, 4});
     auto scores = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1000});
     auto max_output_boxes_per_class = ngraph::opset4::Constant::create(element::i64, Shape{}, {10});
@@ -23,7 +23,7 @@ void ngraph::pass::ConvertNMS4ToLegacy::convert_nms4_to_legacy() {
     auto nms = std::make_shared<ngraph::opset4::NonMaxSuppression>(boxes, scores, max_output_boxes_per_class,
                                                                    iou_threshold, score_threshold);
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher &m) {
+    ngraph::matcher_pass_callback callback = [](pattern::Matcher &m) {
         auto nms_4 = std::dynamic_pointer_cast<ngraph::opset4::NonMaxSuppression>(m.get_match_root());
         if (!nms_4) {
             return false;
@@ -115,5 +115,5 @@ void ngraph::pass::ConvertNMS4ToLegacy::convert_nms4_to_legacy() {
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(nms, "ConvertNMS4ToNMSLegacy");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
index 5b2e755..588bcb9 100644 (file)
 
 #include <ngraph_ops/nms_ie.hpp>
 #include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
 
-void ngraph::pass::ConvertNMSToNMSIE::convert_nms_to_nms_ie() {
-    auto nms = std::make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<opset1::NonMaxSuppression>());
+ngraph::pass::ConvertNMSToNMSIEMatcher::ConvertNMSToNMSIEMatcher() {
+    auto nms = ngraph::pattern::wrap_type<opset1::NonMaxSuppression>();
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher &m) {
+    ngraph::matcher_pass_callback callback = [](pattern::Matcher &m) {
         auto nms = std::dynamic_pointer_cast<opset1::NonMaxSuppression>(m.get_match_root());
         if (!nms) {
             return false;
@@ -95,5 +96,5 @@ void ngraph::pass::ConvertNMSToNMSIE::convert_nms_to_nms_ie() {
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(nms, "ConvertNMSToNMSIE");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
\ No newline at end of file
index 0a71480..cf2511a 100644 (file)
@@ -12,7 +12,7 @@
 
 #include "ngraph_ops/normalize_ie.hpp"
 
-void ngraph::pass::ConvertNormalizeL2WithMulToNormalizeIE::convert_normalize_l2_with_mul() {
+ngraph::pass::ConvertNormalizeL2WithMulToNormalizeIE::ConvertNormalizeL2WithMulToNormalizeIE() {
     auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
     auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
     auto axis = std::make_shared<ngraph::opset1::Constant>(element::i64, Shape{1}, std::vector<int64_t>{0});
@@ -69,15 +69,15 @@ void ngraph::pass::ConvertNormalizeL2WithMulToNormalizeIE::convert_normalize_l2_
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "CPUFusion.ConvertNormalizeL2WithMulToNormalizeIE");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
 
-void ngraph::pass::ConvertNormalizeL2ToNormalizeIE::convert_normalize_l2() {
+ngraph::pass::ConvertNormalizeL2ToLegacyMatcher::ConvertNormalizeL2ToLegacyMatcher() {
     auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
     auto axis = std::make_shared<ngraph::opset1::Constant>(element::i64, Shape{1}, std::vector<int64_t>{0});
     auto normalize = std::make_shared<ngraph::op::NormalizeL2>(input_0, axis, 0, ngraph::op::EpsMode::ADD);
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
         auto normalize = std::dynamic_pointer_cast<ngraph::op::NormalizeL2> (m.get_match_root());
         if (!normalize) return false;
 
@@ -103,6 +103,6 @@ void ngraph::pass::ConvertNormalizeL2ToNormalizeIE::convert_normalize_l2() {
         return true;
     };
 
-    auto m = std::make_shared<ngraph::pattern::Matcher>(normalize, "CPUFusion.ConvertNormalizeL2ToNormalizeIE");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    auto m = std::make_shared<ngraph::pattern::Matcher>(normalize, "ConvertNormalizeL2ToNormalizeIE");
+    this->register_matcher(m, callback);
 }
index 79bdb1e..6d6a4c0 100644 (file)
 #include <transformations/utils/utils.hpp>
 #include <ngraph/rt_info.hpp>
 
-void ngraph::pass::ConvertOneHotToOneHotIE::convert_one_hot() {
+ngraph::pass::ConvertOneHotToOneHotIEMatcher::ConvertOneHotToOneHotIEMatcher() {
     auto input = std::make_shared<pattern::op::Label>(element::i32, Shape{1, 1, 1, 1});
     auto depth = std::make_shared<pattern::op::Label>(element::i64, Shape{});
     auto on_value = std::make_shared<pattern::op::Label>(element::f32, Shape{});
     auto off_value = std::make_shared<pattern::op::Label>(element::f32, Shape{});
     auto one_hot = std::make_shared<ngraph::opset1::OneHot>(input, depth, on_value, off_value, 1);
 
-    ngraph::graph_rewrite_callback callback = [=](pattern::Matcher& m) {
+    ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
         auto one_hot = std::dynamic_pointer_cast<ngraph::opset1::OneHot> (m.get_match_root());
         if (!one_hot) {
             return false;
         }
 
-        element::Type output_type = is_f16 ? element::f16 : element::f32;
-
-        const auto depth_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(one_hot->input(1).get_source_output().get_node_shared_ptr());
-        const auto on_value_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(one_hot->input(2).get_source_output().get_node_shared_ptr());
-        const auto off_value_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(one_hot->input(3).get_source_output().get_node_shared_ptr());
+        const auto depth_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(one_hot->input_value(1).get_node_shared_ptr());
+        const auto on_value_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(one_hot->input_value(2).get_node_shared_ptr());
+        const auto off_value_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(one_hot->input_value(3).get_node_shared_ptr());
 
         // can be converted iff inputs with depth, on/off values are constants
         if (depth_node == nullptr || on_value_node == nullptr || off_value_node == nullptr) return false;
@@ -40,11 +38,11 @@ void ngraph::pass::ConvertOneHotToOneHotIE::convert_one_hot() {
         auto off_value = std::stof(off_value_node->convert_value_to_string(0));
 
         auto one_hot_ie = std::make_shared<ngraph::op::OneHotIE>(one_hot->input_value(0),
-                                                                 one_hot->get_axis(), depth_value, on_value, off_value, output_type);
+                                                                 one_hot->get_axis(), depth_value, on_value, off_value, m_output_type);
         one_hot_ie->set_friendly_name(one_hot->get_friendly_name());
 
         // insert Convert layer to cast output to a correct data type defined by the on/off values
-        if (on_value_node->get_element_type() != output_type) {
+        if (on_value_node->get_element_type() != m_output_type) {
             auto convert = std::make_shared<ngraph::opset1::Convert>(one_hot_ie, on_value_node->get_element_type());
             convert->set_friendly_name(one_hot->get_friendly_name() + "/Convert");
             ngraph::copy_runtime_info(one_hot, {one_hot_ie, convert});
@@ -58,10 +56,9 @@ void ngraph::pass::ConvertOneHotToOneHotIE::convert_one_hot() {
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(one_hot, "ConvertOneHotToOneHotIE");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
 
-bool ngraph::pass::ConvertOneHotToOneHotIE::run_on_function(std::shared_ptr<ngraph::Function> f) {
-    is_f16 = ngraph::op::util::has_f16_constants(f);
-    return GraphRewrite::run_on_function(f);
-}
+void ngraph::pass::ConvertOneHotToOneHotIEMatcher::detect_output_type(const std::shared_ptr<ngraph::Function> &f) {
+    m_output_type  = ngraph::op::util::has_f16_constants(f) ? element::Type_t::f16 : element::Type_t::f32;
+}
\ No newline at end of file
index 76bb223..505fea0 100644 (file)
@@ -4,7 +4,6 @@
 
 #include "transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.hpp"
 
-#include <transformations/constant_eltwise_reduction.hpp>
 #include <transformations/convert_broadcast_to_tiles.hpp>
 #include <transformations/convert_opset1_to_legacy/convert_convolutions.hpp>
 #include <transformations/convert_divide.hpp>
 #include <transformations/convert_opset1_to_legacy/reshape_fully_connected.hpp>
 #include <transformations/pull_transpose_through_fq.hpp>
 #include <transformations/convert_opset1_to_legacy/convert_hard_sigmoid_to_hard_sigmoid_ie.hpp>
+#include <transformations/lin_op_sequence_fusoin.hpp>
 
 #include <ngraph/pass/constant_folding.hpp>
 #include <ngraph/pass/manager.hpp>
 
+#include <ngraph/pass/graph_rewrite.hpp>
+
 #include <memory>
 #include <vector>
 
 bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph::Function> f) {
-    ngraph::pass::Manager OpSet1ToLegacy;
+    ngraph::pass::Manager manager;
     std::vector<std::shared_ptr<ngraph::pass::PassBase> > transforms;
 
-#define NGRAPH_PASS(NAME, NAMESPACE) transforms.push_back(OpSet1ToLegacy.register_pass<NAMESPACE::NAME>());
-#include <transformations/convert_opset1_to_legacy/convert_opset1_to_legacy_tbl.hpp>
-#undef NGRAPH_PASS
+    manager.register_pass<ngraph::pass::ConstantFolding>();
+
+    // List if Decomposition and Conversion transformations that can be
+    // applied simultaneously in a single graph traversal
+    auto decomp = manager.register_pass<ngraph::pass::GraphRewrite>();
+    decomp->add_matcher<ngraph::pass::ConvertBroadcastToTiles>();
+    decomp->add_matcher<ngraph::pass::ConvertReduceMeanToPooling>();
+    decomp->add_matcher<ngraph::pass::ConvertReduceMaxToPooling>();
+    decomp->add_matcher<ngraph::pass::ConvertReduceSumToPooling>();
+    decomp->add_matcher<ngraph::pass::ConvertMod>();
+    decomp->add_matcher<ngraph::pass::ConvertMinimum>();
+    decomp->add_matcher<ngraph::pass::ConvertSubtract>();
+    decomp->add_matcher<ngraph::pass::ConvertDivide>();
+    decomp->add_matcher<ngraph::pass::ConvertNegative>();
+    decomp->add_matcher<ngraph::pass::ConvertDepthToSpace>();
+    decomp->add_matcher<ngraph::pass::ConvertSpaceToDepth>();
+    decomp->add_matcher<ngraph::pass::ConvertConvolution>();
+    decomp->add_matcher<ngraph::pass::ConvertGroupConvolution>();
+    decomp->add_matcher<ngraph::pass::ConvertDeconvolution>();
+    decomp->add_matcher<ngraph::pass::ConvertGroupDeconvolution>();
+    decomp->add_matcher<ngraph::pass::BatchNormDecomposition>();
+    decomp->add_matcher<ngraph::pass::ConvertMatMulToFCorGemm>();
+    decomp->add_matcher<ngraph::pass::PullTransposeThroughFQUp>();
+    decomp->set_name("ngraph::pass::Decompositions");
+
+    // CF is required after all decompositions
+    manager.register_pass<ngraph::pass::ConstantFolding>();
+
+    // LinOpSequenceFusion must be executed after all decompositions
+    manager.register_pass<ngraph::pass::LinOpSequenceFusion>();
+
+    // Convolution/Deconvolution/FullyConnected fusions
+    auto fusion = manager.register_pass<ngraph::pass::GraphRewrite>();
+    fusion->add_matcher<ngraph::pass::ConvAddFusion>();
+    fusion->add_matcher<ngraph::pass::ConvMultiplyFusion>();
+    fusion->add_matcher<ngraph::pass::DeconvAddFusion>();
+    fusion->add_matcher<ngraph::pass::FullyConnectedBiasFusion>();
+    fusion->set_name("ngraph::pass::Fusions");
+
+    // CF is required after fusions
+    manager.register_pass<ngraph::pass::ConstantFolding>();
+
+    // List of passes that convert opset1 operations to legacy
+    // plus transformations that are required by InferenceEngine
+    // All this transformations can be executed simultaneously
+    auto anchor = manager.register_pass<ngraph::pass::GraphRewrite>();
+    anchor->add_matcher<ngraph::pass::ReshapeFullyConnected>();
+    anchor->add_matcher<ngraph::pass::Reshape1DConvolution>();
+    anchor->add_matcher<ngraph::pass::Reshape1DAvgPool>();
+    anchor->add_matcher<ngraph::pass::Reshape1DMaxPool>();
+    anchor->add_matcher<ngraph::pass::ConvertNormalizeL2WithMulToNormalizeIE>();
+    anchor->add_matcher<ngraph::pass::ConvertHardSigmoidToLegacyMatcher>();
+    anchor->add_matcher<ngraph::pass::ConvertProposalToLegacyMatcher>();
+    anchor->add_matcher<ngraph::pass::ConvertTileToLegacyMatcher>();
+    anchor->add_matcher<ngraph::pass::ConvertLRNToLegacyMatcher>();
+    anchor->add_matcher<ngraph::pass::ConvertPadToLegacyMatcher>();
+    anchor->add_matcher<ngraph::pass::ConvertLSTMCellMatcher>();
+    anchor->add_matcher<ngraph::pass::ConvertRNNCellMatcher>();
+    anchor->add_matcher<ngraph::pass::ConvertGRUCellMatcher>();
+    anchor->add_matcher<ngraph::pass::ConvertInterpolateToInterpOrResampleMatcher>();
+    anchor->add_matcher<ngraph::pass::ConvertStridedSliceToCropMatcher>();
+    anchor->add_matcher<ngraph::pass::ConvertPowerToPowerIEMatcher>();
+    anchor->add_matcher<ngraph::pass::ConvertSqrtToPowerIEMatcher>();
+    anchor->add_matcher<ngraph::pass::ConvertPReLUToReLUIE>();
+    anchor->add_matcher<ngraph::pass::ConvertGatherToGatherIEMatcher>();
+    anchor->add_matcher<ngraph::pass::ConvertSeluToSeluIEMatcher>();
+    anchor->add_matcher<ngraph::pass::ConvertOneHotToOneHotIEMatcher>()->detect_output_type(f);
+    anchor->add_matcher<ngraph::pass::ConvertGatherTreeToGatherTreeIEMatcher>();
+    anchor->add_matcher<ngraph::pass::ConvertTopKToTopKIEMatcher>();
+    anchor->add_matcher<ngraph::pass::ConvertNMSToNMSIEMatcher>();
+    anchor->add_matcher<ngraph::pass::ConvertNMS4ToLegacyMatcher>();
+    anchor->set_name("ngraph::pass::ConvertOpSet1ToLegacy");
+
+    // List of final conversion transformations that must to be executed
+    // after previous group of transformations
+    manager.register_pass<ngraph::pass::ReshapeFullyConnectedFusion>();
+    manager.register_pass<ngraph::pass::ConvertNormalizeL2ToLegacyMatcher>();
+    manager.register_pass<ngraph::pass::ConvertMulAddToScaleShiftOrPower>();
+    manager.register_pass<ngraph::pass::ConvertMulOrAddFinally>();
+
+    manager.register_pass<ngraph::pass::ConstantFolding>();
 
-    for (auto & t : transforms) {
-        if (auto t_param = std::dynamic_pointer_cast<PassParam>(t)) {
-            t_param->setCallback(transformation_callback);
-        }
-    }
-    OpSet1ToLegacy.run_passes(f);
+    manager.set_callback(m_transformation_callback);
+    manager.run_passes(f);
     return true;
 }
\ No newline at end of file
index b2a4915..8a865d5 100644 (file)
@@ -9,34 +9,24 @@
 
 #include <ngraph/opsets/opset1.hpp>
 #include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
 
-void ngraph::pass::ConvertPadToPadIE::convert_pad() {
-    auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
-    auto input_1 = std::make_shared<pattern::op::Label>(element::i64, Shape{4});
-    auto input_2 = std::make_shared<pattern::op::Label>(element::i64, Shape{4});
-    auto input_3 = std::make_shared<pattern::op::Label>(element::f32, Shape{});
-    auto pad_1 = std::make_shared<ngraph::opset1::Pad>(input_0, input_1, input_2, op::PadMode::SYMMETRIC);
-    auto pad_2 = std::make_shared<ngraph::opset1::Pad>(input_0, input_1, input_2, input_3, op::PadMode::CONSTANT);
+ngraph::pass::ConvertPadToLegacyMatcher::ConvertPadToLegacyMatcher() {
+    auto m_pad = ngraph::pattern::wrap_type<ngraph::opset1::Pad>();
 
-
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
         auto pad = std::dynamic_pointer_cast<ngraph::opset1::Pad> (m.get_match_root());
         if (!pad) {
             return false;
         }
 
         auto pad_ie = std::make_shared<ngraph::op::PadIE>(pad);
-        if (pad_ie == nullptr)
-            return false;
         pad_ie->set_friendly_name(pad->get_friendly_name());
         ngraph::copy_runtime_info(pad, pad_ie);
         ngraph::replace_node(pad, pad_ie);
         return true;
     };
 
-    auto m1 = std::make_shared<ngraph::pattern::Matcher>(pad_1, "ConvertPadToPadIE");
-    this->add_matcher(m1, callback, PassProperty::CHANGE_DYNAMIC_STATE);
-
-    auto m2 = std::make_shared<ngraph::pattern::Matcher>(pad_2, "ConvertPadToPadIE");
-    this->add_matcher(m2, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    auto m = std::make_shared<ngraph::pattern::Matcher>(m_pad, "ConvertPadToLegacy");
+    this->register_matcher(m, callback);
 }
\ No newline at end of file
index 5f98b4d..6ab97dc 100644 (file)
 #include <transformations/utils/utils.hpp>
 #include <ngraph/rt_info.hpp>
 
-void ngraph::pass::ConvertPowerToPowerIE::convert_power() {
+ngraph::pass::ConvertPowerToPowerIEMatcher::ConvertPowerToPowerIEMatcher() {
     auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1});
     auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{1});
     auto power = std::make_shared<ngraph::opset1::Power>(input_0, input_1);
 
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
         auto power = std::dynamic_pointer_cast<ngraph::opset1::Power> (m.get_match_root());
         if (!power) {
             return false;
@@ -41,5 +41,5 @@ void ngraph::pass::ConvertPowerToPowerIE::convert_power() {
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(power, "ConvertPowerToPowerIE");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
\ No newline at end of file
index b50dd36..036cdb1 100644 (file)
 #include <transformations/utils/utils.hpp>
 #include <ngraph/rt_info.hpp>
 
-void ngraph::pass::ConvertPReLUToReLUIE::convert_prelu() {
+ngraph::pass::ConvertPReLUToReLUIE::ConvertPReLUToReLUIE() {
     auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1});
     auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{1});
     auto prelu = std::make_shared<ngraph::opset1::PRelu>(input_0, input_1);
 
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
         auto prelu = std::dynamic_pointer_cast<ngraph::opset1::PRelu> (m.get_match_root());
         if (!prelu) {
             return false;
@@ -41,5 +41,5 @@ void ngraph::pass::ConvertPReLUToReLUIE::convert_prelu() {
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(prelu, "ConvertPReLUToReLUIE");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
\ No newline at end of file
index 72de4e4..531f87e 100644 (file)
@@ -12,7 +12,7 @@
 #include <ngraph_ops/proposal_ie.hpp>
 #include <ngraph/rt_info.hpp>
 
-void ngraph::pass::ConvertProposalToProposalIE::convert_proposal() {
+ngraph::pass::ConvertProposalToLegacyMatcher::ConvertProposalToLegacyMatcher() {
     auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
     auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
     auto input_2 = std::make_shared<pattern::op::Label>(element::f32, Shape{3});
@@ -21,7 +21,7 @@ void ngraph::pass::ConvertProposalToProposalIE::convert_proposal() {
 
     auto proposal = std::make_shared<ngraph::opset1::Proposal>(input_0, input_1, input_2, attr);
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
         auto proposal = std::dynamic_pointer_cast<ngraph::opset1::Proposal> (m.get_match_root());
 
         if (!proposal) {
@@ -60,6 +60,6 @@ void ngraph::pass::ConvertProposalToProposalIE::convert_proposal() {
         return true;
     };
 
-    auto m = std::make_shared<ngraph::pattern::Matcher>(proposal, "CPUFusion.ConvertProposalToProposalIE");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    auto m = std::make_shared<ngraph::pattern::Matcher>(proposal, "ConvertProposalToProposalIE");
+    this->register_matcher(m, callback);
 }
\ No newline at end of file
index 189a61a..0b85555 100644 (file)
 #include <transformations/utils/utils.hpp>
 #include <ngraph/rt_info.hpp>
 
-void ngraph::pass::ConvertSeluToSeluIE::convert_selu() {
+ngraph::pass::ConvertSeluToSeluIEMatcher::ConvertSeluToSeluIEMatcher() {
     auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1});
     auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{1});
     auto input_2 = std::make_shared<pattern::op::Label>(element::f32, Shape{1});
     auto selu = std::make_shared<ngraph::opset1::Selu>(input_0, input_1, input_2);
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
         auto selu = std::dynamic_pointer_cast<ngraph::opset1::Selu> (m.get_match_root());
         if (!selu) {
             return false;
@@ -48,5 +48,5 @@ void ngraph::pass::ConvertSeluToSeluIE::convert_selu() {
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(selu, "ConvertSeluToSeluIE");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
\ No newline at end of file
index c83b15f..4f552e4 100644 (file)
 #include <transformations/utils/utils.hpp>
 #include <ngraph/rt_info.hpp>
 
-void ngraph::pass::ConvertSqrtToPowerIE::convert_sqrt() {
+ngraph::pass::ConvertSqrtToPowerIEMatcher::ConvertSqrtToPowerIEMatcher() {
     auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1});
     auto sqrt = std::make_shared<ngraph::opset1::Sqrt>(input_0);
 
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
         auto sqrt = std::dynamic_pointer_cast<ngraph::opset1::Sqrt>(m.get_match_root());
         if (!sqrt) {
             return false;
@@ -31,6 +31,6 @@ void ngraph::pass::ConvertSqrtToPowerIE::convert_sqrt() {
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(sqrt, "ConvertPowerToPowerIE");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
 
index 59b1593..69ee25e 100644 (file)
@@ -13,7 +13,7 @@
 #include <ngraph_ops/crop_ie.hpp>
 #include <ngraph/rt_info.hpp>
 
-void ngraph::pass::ConvertStridedSliceToCrop::convert_strided_slice_to_crop() {
+ngraph::pass::ConvertStridedSliceToCropMatcher::ConvertStridedSliceToCropMatcher() {
     auto data = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
     auto m_begin = std::make_shared<pattern::op::Label>(element::i64, Shape{2});
     auto m_end = std::make_shared<pattern::op::Label>(element::i64, Shape{2});
@@ -22,7 +22,7 @@ void ngraph::pass::ConvertStridedSliceToCrop::convert_strided_slice_to_crop() {
     std::vector<int64_t> end_mask = {0, 0, 0, 0};
     auto m_slice = std::make_shared<ngraph::opset1::StridedSlice>(data, m_begin, m_end, m_stride, begin_mask, end_mask);
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
         auto slice = std::dynamic_pointer_cast<ngraph::opset1::StridedSlice> (m.get_match_root());
         if (!slice) {
             return false;
@@ -227,5 +227,5 @@ void ngraph::pass::ConvertStridedSliceToCrop::convert_strided_slice_to_crop() {
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(m_slice, "ConvertStridedSliceToCrop");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
index a9863ed..712d90d 100644 (file)
 #include <ngraph_ops/tile_ie.hpp>
 #include <ngraph/rt_info.hpp>
 
-void ngraph::pass::ConvertTileToIETile::convert_tile() {
+ngraph::pass::ConvertTileToLegacyMatcher::ConvertTileToLegacyMatcher() {
     auto data = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
     auto shp = std::make_shared<pattern::op::Label>(element::i64, Shape{4});
     auto tile = std::make_shared<ngraph::opset1::Tile>(data, shp);
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
         auto tile = std::dynamic_pointer_cast<ngraph::opset1::Tile> (m.get_match_root());
         if (!tile) {
             return false;
@@ -90,6 +90,6 @@ void ngraph::pass::ConvertTileToIETile::convert_tile() {
         return true;
     };
 
-    auto m = std::make_shared<ngraph::pattern::Matcher>(tile, "CPUFusion.ConvertTileToIETiles");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    auto m = std::make_shared<ngraph::pattern::Matcher>(tile, "ConvertTileToIETiles");
+    this->register_matcher(m, callback);
 }
index 323478d..4beadff 100644 (file)
 
 #include <ngraph_ops/topk_ie.hpp>
 #include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
 
-void ngraph::pass::ConvertTopKToTopKIE::convert_topk_to_topk_ie() {
-    auto topk = std::make_shared<pattern::op::Label>(element::f32, Shape{1}, pattern::has_class<opset1::TopK>());
+ngraph::pass::ConvertTopKToTopKIEMatcher::ConvertTopKToTopKIEMatcher() {
+    auto topk = ngraph::pattern::wrap_type<opset1::TopK>();
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher &m) {
+    ngraph::matcher_pass_callback callback = [](pattern::Matcher &m) {
         auto topk = std::dynamic_pointer_cast<opset1::TopK>(m.get_match_root());
         if (!topk || topk->input(1).get_partial_shape().rank().is_dynamic()) {
             return false;
@@ -74,5 +75,5 @@ void ngraph::pass::ConvertTopKToTopKIE::convert_topk_to_topk_ie() {
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(topk, "ConvertTopKToTopKIE");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
diff --git a/inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/fc_bias_fusion.cpp b/inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/fc_bias_fusion.cpp
new file mode 100644 (file)
index 0000000..9d7bda6
--- /dev/null
@@ -0,0 +1,74 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/convert_opset1_to_legacy/fc_bias_fusion.hpp"
+
+#include <memory>
+#include <vector>
+
+#include <ngraph/opsets/opset1.hpp>
+#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
+
+ngraph::pass::FullyConnectedBiasFusion::FullyConnectedBiasFusion() {
+    auto fc = ngraph::pattern::wrap_type<op::FullyConnected>();
+    auto add = ngraph::pattern::wrap_type<opset1::Add>({fc, std::make_shared<pattern::op::Label>()});
+
+    ngraph::graph_rewrite_callback callback = [](pattern::Matcher &m) {
+        auto add = m.get_match_root();
+        auto add_input_0 = add->input(0).get_source_output().get_node_shared_ptr();
+        auto add_input_1 = add->input(1).get_source_output().get_node_shared_ptr();
+
+        auto m_fc = std::dynamic_pointer_cast<op::FullyConnected>(add_input_0);
+        auto m_bias = add_input_1;
+
+        if (m_fc == nullptr) {
+            m_fc = std::dynamic_pointer_cast<op::FullyConnected>(add_input_1);
+            m_bias = add_input_0;
+        }
+
+        if (auto bcast_m = std::dynamic_pointer_cast<opset1::Broadcast>(m_bias)) {
+            m_bias = bcast_m->input(0).get_source_output().get_node_shared_ptr();
+        }
+
+        if (!std::dynamic_pointer_cast<opset1::Constant>(m_bias)) {
+            return false;
+        }
+        Shape bias_shape(m_bias->get_shape());
+
+        if (m_fc->output(0).get_target_inputs().size() != 1) {
+            return false;
+        }
+
+        Shape output_shape(m_fc->get_shape());
+        size_t bias_size = std::accumulate(bias_shape.begin(), bias_shape.end(), 1, std::multiplies<int64_t>());
+        if (bias_shape.empty() || bias_shape.back() != output_shape.back() || bias_shape.back() != bias_size) {
+            return false;
+        }
+
+        NodeVector new_ops;
+
+        auto new_bias = std::make_shared<opset1::Add>(m_fc->input(2).get_source_output(), m_bias);
+        new_ops.push_back(new_bias);
+        std::shared_ptr<Node> final_bias = new_bias;
+        if (new_bias->get_shape().size() >= 2) {
+            final_bias = std::make_shared<opset1::Reshape>(final_bias, opset1::Constant::create(element::i64, Shape{1}, {-1}), true);
+            new_ops.push_back(final_bias);
+        }
+
+        auto new_fc = std::make_shared<op::FullyConnected>(m_fc->input(0).get_source_output(),
+                                                           m_fc->input(1).get_source_output(),
+                                                           final_bias,
+                                                           m_fc->get_shape());
+        new_ops.push_back(new_fc);
+
+        new_fc->set_friendly_name(add->get_friendly_name());
+        ngraph::copy_runtime_info({m_fc, add}, new_ops);
+        ngraph::replace_node(add, new_fc);
+        return true;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(add, "FullyConnectedBiasFusion");
+    this->register_matcher(m, callback);
+}
\ No newline at end of file
index 7860a98..d628088 100644 (file)
@@ -9,6 +9,7 @@
 
 #include <ngraph/opsets/opset1.hpp>
 #include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
 
 #include "ngraph_ops/convolution_ie.hpp"
 #include "transformations/utils/utils.hpp"
@@ -103,14 +104,8 @@ std::shared_ptr<Node> convert(const Output<Node> & data, std::shared_ptr<opset1:
                                              node->get_auto_pad());
 }
 
-void ngraph::pass::Reshape1DOps::reshape_ops() {
-    auto node = std::make_shared<pattern::op::Label>(element::f32, Shape{},
-            [](const std::shared_ptr<Node> & node) {
-                return std::dynamic_pointer_cast<op::ConvolutionIE>(node) ||
-                       std::dynamic_pointer_cast<opset1::MaxPool>(node) ||
-                       std::dynamic_pointer_cast<opset1::AvgPool>(node);});
-
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+matcher_pass_callback get_callback() {
+    return [](pattern::Matcher& m) {
         auto node = m.get_match_root();
         if (!node || node->input(0).get_partial_shape().rank().get_length() != 3) {
             return false;
@@ -150,7 +145,22 @@ void ngraph::pass::Reshape1DOps::reshape_ops() {
         node->output(0).replace(last);
         return true;
     };
+}
+
+ngraph::pass::Reshape1DConvolution::Reshape1DConvolution() {
+    auto conv = ngraph::pattern::wrap_type<op::ConvolutionIE>();
+    auto m = std::make_shared<ngraph::pattern::Matcher>(conv, "Reshape1DConvolution");
+    this->register_matcher(m, get_callback());
+}
+
+ngraph::pass::Reshape1DAvgPool::Reshape1DAvgPool() {
+    auto pool = ngraph::pattern::wrap_type<opset1::AvgPool>();
+    auto m = std::make_shared<ngraph::pattern::Matcher>(pool, "Reshape1DAvgPool");
+    this->register_matcher(m, get_callback());
+}
 
-    auto m = std::make_shared<ngraph::pattern::Matcher>(node, "Reshape1DOps");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ngraph::pass::Reshape1DMaxPool::Reshape1DMaxPool() {
+    auto pool = ngraph::pattern::wrap_type<opset1::MaxPool>();
+    auto m = std::make_shared<ngraph::pattern::Matcher>(pool, "Reshape1DMaxPool");
+    this->register_matcher(m, get_callback());
 }
\ No newline at end of file
index c43fa65..e598666 100644 (file)
@@ -13,7 +13,7 @@
 #include "ngraph_ops/fully_connected.hpp"
 #include "transformations/utils/utils.hpp"
 
-void ngraph::pass::ReshapeFullyConnected::reshape_fully_connected() {
+ngraph::pass::ReshapeFullyConnected::ReshapeFullyConnected() {
     auto input0 = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1});
     auto input1 = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1});
     auto input2 = std::make_shared<pattern::op::Label>(element::i64, Shape{1});
@@ -21,7 +21,7 @@ void ngraph::pass::ReshapeFullyConnected::reshape_fully_connected() {
 
     ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
         auto fc = std::dynamic_pointer_cast<ngraph::op::FullyConnected> (m.get_match_root());
-        if (!fc || transformation_callback(fc)) {
+        if (!fc || m_transformation_callback(fc)) {
             return false;
         }
 
@@ -70,5 +70,5 @@ void ngraph::pass::ReshapeFullyConnected::reshape_fully_connected() {
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(fc, "ReshapeFullyConnected");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
\ No newline at end of file
index 72b7668..51415bc 100644 (file)
 #include <ngraph/pass/manager.hpp>
 
 bool ngraph::pass::ConvertOpSet2ToOpSet1::run_on_function(std::shared_ptr<ngraph::Function> f) {
-    ngraph::pass::Manager OpSet2ToOpSet1;
-    std::vector<std::shared_ptr<ngraph::pass::PassBase> > transforms;
+    ngraph::pass::Manager manager;
 
-#define NGRAPH_PASS(NAME, NAMESPACE) transforms.push_back(OpSet2ToOpSet1.register_pass<NAMESPACE::NAME>());
-#include <transformations/convert_opset2_to_opset1/convert_opset2_to_opset1_tbl.hpp>
-#undef NGRAPH_PASS
+    manager.register_pass<ngraph::pass::ConvertGELU>();
+    manager.register_pass<ngraph::pass::ConvertSpaceToBatch>();
+    manager.register_pass<ngraph::pass::ConvertBatchToSpace>();
 
-    for (auto & t : transforms) {
-        if (auto t_param = std::dynamic_pointer_cast<PassParam>(t)) {
-            t_param->setCallback(transformation_callback);
-        }
-    }
-    OpSet2ToOpSet1.run_passes(f);
+    manager.set_callback(m_transformation_callback);
+    manager.run_passes(f);
     return true;
 }
\ No newline at end of file
index 5a27e03..e0cd4a2 100644 (file)
 #include <ngraph/pass/manager.hpp>
 
 bool ngraph::pass::ConvertOpSet3ToOpSet2::run_on_function(std::shared_ptr<ngraph::Function> f) {
-    ngraph::pass::Manager OpSet3ToOpSet2;
-    std::vector<std::shared_ptr<ngraph::pass::PassBase> > transforms;
+    ngraph::pass::Manager manager;
 
-#define NGRAPH_PASS(NAME, NAMESPACE) transforms.push_back(OpSet3ToOpSet2.register_pass<NAMESPACE::NAME>());
-#include <transformations/convert_opset3_to_opset2/convert_opset3_to_opset2_tbl.hpp>
-#undef NGRAPH_PASS
+    manager.register_pass<ngraph::pass::ConvertBroadcast3>();
+    manager.register_pass<ngraph::pass::ConvertNMS3>();
+    manager.register_pass<ngraph::pass::ConvertShapeOf3>();
+    manager.register_pass<ngraph::pass::ConvertShuffleChannels3>();
+    manager.register_pass<ngraph::pass::ConvertTopK3>();
 
-    for (auto & t : transforms) {
-        if (auto t_param = std::dynamic_pointer_cast<PassParam>(t)) {
-            t_param->setCallback(transformation_callback);
-        }
-    }
-    OpSet3ToOpSet2.run_passes(f);
+    manager.set_callback(m_transformation_callback);
+    manager.run_passes(f);
     return true;
 }
index b333471..9150ede 100644 (file)
@@ -19,7 +19,7 @@ void ngraph::pass::ConvertShuffleChannels3::convert_shuffle_channels3() {
 
     ngraph::graph_rewrite_callback callback = [this](pattern::Matcher &m) {
         auto shuffle_channels = std::dynamic_pointer_cast<::opset3::ShuffleChannels>(m.get_match_root());
-        if (!shuffle_channels || transformation_callback(shuffle_channels)) {
+        if (!shuffle_channels || m_transformation_callback(shuffle_channels)) {
             return false;
         }
         if (shuffle_channels->input_value(0).get_partial_shape().rank().is_dynamic()) {
index 0c42fe6..4d4ed66 100644 (file)
@@ -48,7 +48,7 @@ void ngraph::pass::ConvertSpaceToBatch::convert_space_to_batch_by_elements() {
         auto data = space_to_batch->input_value(0);
         const auto& data_shape = data.get_shape();
 
-        if (transformation_callback(space_to_batch) && (data_shape.size() == 4 || data_shape.size() == 5)) {
+        if (m_transformation_callback(space_to_batch) && (data_shape.size() == 4 || data_shape.size() == 5)) {
             return false;
         }
 
index 5eb1655..a732b38 100644 (file)
@@ -9,14 +9,14 @@
 
 #include <ngraph/opsets/opset1.hpp>
 #include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
 
-void ngraph::pass::ConvertSpaceToDepth::convert() {
-    auto input0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
-    auto dts = std::make_shared<ngraph::opset1::SpaceToDepth>(input0, ngraph::opset1::SpaceToDepth::SpaceToDepthMode::DEPTH_FIRST);
+ngraph::pass::ConvertSpaceToDepth::ConvertSpaceToDepth() {
+    auto dts = ngraph::pattern::wrap_type<ngraph::opset1::SpaceToDepth>();
 
     ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
         auto std_node = std::dynamic_pointer_cast<ngraph::opset1::SpaceToDepth> (m.get_match_root());
-        if (!std_node || transformation_callback(std_node)) {
+        if (!std_node || m_transformation_callback(std_node)) {
             return false;
         }
 
@@ -88,5 +88,5 @@ void ngraph::pass::ConvertSpaceToDepth::convert() {
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(dts, "ConvertSpaceToDepth");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
\ No newline at end of file
index 427f85d..5e57c47 100644 (file)
@@ -9,11 +9,10 @@
 
 #include <ngraph/opsets/opset1.hpp>
 #include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
 
-void ngraph::pass::ConvertSubtract::convert_subtract() {
-    auto input0 = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1, 1});
-    auto input1 = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1, 1});
-    auto sub = std::make_shared<ngraph::opset1::Subtract>(input0, input1);
+ngraph::pass::ConvertSubtract::ConvertSubtract() {
+    auto sub = ngraph::pattern::wrap_type<ngraph::opset1::Subtract>();
 
     ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
         auto sub = std::dynamic_pointer_cast<ngraph::opset1::Subtract> (m.get_match_root());
@@ -33,5 +32,5 @@ void ngraph::pass::ConvertSubtract::convert_subtract() {
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(sub, "ConvertSubtract");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
\ No newline at end of file
index da0b41a..3a7d39b 100644 (file)
@@ -154,7 +154,7 @@ void ngraph::pass::DepthToSpaceFusion::depth_to_space_fusion() {
         depth_to_space->set_friendly_name(reshape_after->get_friendly_name());
         ngraph::copy_runtime_info({reshape_before, permute, reshape_after}, depth_to_space);
 
-        if (!transformation_callback(depth_to_space)) {
+        if (!m_transformation_callback(depth_to_space)) {
             return false;
         }
 
diff --git a/inference-engine/src/transformations/src/transformations/lin_op_sequence_fusion.cpp b/inference-engine/src/transformations/src/transformations/lin_op_sequence_fusion.cpp
new file mode 100644 (file)
index 0000000..0c7996d
--- /dev/null
@@ -0,0 +1,128 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/lin_op_sequence_fusoin.hpp"
+#include "transformations/mul_add_squence_fusion.hpp"
+
+#include <memory>
+#include <vector>
+
+#include <ngraph/opsets/opset3.hpp>
+#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
+
+using namespace ngraph;
+
+template <class T>
+Output<Node> eltwise_fold(const Output<Node> & input0, const Output<Node> & input1) {
+    auto eltwise = std::make_shared<T>(input0, input1);
+    OutputVector output(eltwise->get_output_size());
+    if (!eltwise->constant_fold(output, {input0, input1})) {
+        throw ngraph_error("Can not constant fold eltwise node");
+    }
+    if (output.size() != 1) {
+        throw ngraph_error("Eltwise constant fold has unexpected number of outputs: " + std::to_string(output.size()));
+    }
+    return output[0];
+}
+
+ngraph::pass::AddMultiplyFusion::AddMultiplyFusion() {
+    // Create Add->Multiply pattern where Add has exactly one consumer
+    auto m_data = ngraph::pattern::any_input();
+    auto m_add_constant = ngraph::pattern::wrap_type<opset3::Constant>();
+    auto m_add = ngraph::pattern::wrap_type<opset3::Add>({m_data, m_add_constant}, pattern::consumers_count(1));
+    auto m_mul_constant = ngraph::pattern::wrap_type<opset3::Constant>();
+    auto m_mul = ngraph::pattern::wrap_type<opset3::Multiply>({m_add, m_mul_constant});
+
+    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher & m) -> bool {
+        auto & label_to_output = m.get_pattern_value_map();
+
+        auto mul = label_to_output[m_mul].get_node_shared_ptr();
+        auto add = label_to_output[m_add].get_node_shared_ptr();
+
+        Output<Node> input = label_to_output[m_data];
+        Output<Node> mul_const = label_to_output[m_mul_constant];
+        Output<Node> add_const = label_to_output[m_add_constant];
+
+        // Replace Add->Multiply with Multiply->Add
+        // As new Multiply can be fused with operation above it we add this Multiply
+        // to the list of operations that will be used in additional matching.
+        auto new_mul = register_new_node<opset3::Multiply>(input, mul_const);
+
+        // Add two constants using opset3::Add constant folding and create new Add operation
+        auto new_add = std::make_shared<opset3::Add>(new_mul, eltwise_fold<opset3::Multiply>(add_const, mul_const));
+
+        copy_runtime_info({add, mul}, {new_mul, new_add});
+        new_add->set_friendly_name(mul->get_friendly_name());
+        replace_node(mul, new_add);
+        return true;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(m_mul, "AddMultiplyFusion");
+    this->register_matcher(m, callback);
+}
+
+ngraph::pass::AddAddFusion::AddAddFusion() {
+    // Create Add->Add pattern where first Add has exactly one consumer
+    auto m_data = ngraph::pattern::any_input();
+    auto m_add1_constant = ngraph::pattern::wrap_type<opset3::Constant>();
+    auto m_add1 = ngraph::pattern::wrap_type<opset3::Add>({m_data, m_add1_constant}, pattern::consumers_count(1));
+    auto m_add2_constant = ngraph::pattern::wrap_type<opset3::Constant>();
+    auto m_add2 = ngraph::pattern::wrap_type<opset3::Add>({m_add1, m_add2_constant});
+
+    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher & m) -> bool {
+        auto & label_to_output = m.get_pattern_value_map();
+
+        auto add1 = label_to_output[m_add1].get_node_shared_ptr();
+        auto add2 = label_to_output[m_add2].get_node_shared_ptr();
+
+        Output<Node> input = label_to_output[m_data];
+        Output<Node> add1_const = label_to_output[m_add1_constant];
+        Output<Node> add2_const = label_to_output[m_add2_constant];
+
+        // Replace Add->Add with single Add
+        // Add operation will be added to the list of ops requested for pattern matching
+        auto new_add = register_new_node<opset3::Add>(input, eltwise_fold<opset3::Add>(add1_const, add2_const));
+
+        copy_runtime_info({add1, add2}, new_add);
+        new_add->set_friendly_name(add2->get_friendly_name());
+        replace_node(add2, new_add);
+        return true;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(m_add2, "AddAddFusion");
+    this->register_matcher(m, callback);
+}
+
+ngraph::pass::MultiplyMultiplyFusion::MultiplyMultiplyFusion() {
+    // Create Multiply->Multiply pattern where first Multiply has exactly one consumer
+    auto m_data = ngraph::pattern::any_input();
+    auto m_mul1_constant = ngraph::pattern::wrap_type<opset3::Constant>();
+    auto m_mul1 = ngraph::pattern::wrap_type<opset3::Multiply>({m_data, m_mul1_constant}, pattern::consumers_count(1));
+    auto m_mul2_constant = ngraph::pattern::wrap_type<opset3::Constant>();
+    auto m_mul2 = ngraph::pattern::wrap_type<ngraph::opset3::Multiply>({m_mul1, m_mul2_constant});
+
+    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher & m) -> bool {
+        auto & label_to_output = m.get_pattern_value_map();
+
+        auto mul1 = label_to_output[m_mul1].get_node_shared_ptr();
+        auto mul2 = label_to_output[m_mul2].get_node_shared_ptr();
+
+        Output<Node> input = label_to_output[m_data];
+        Output<Node> mul1_const = label_to_output[m_mul1_constant];
+        Output<Node> mul2_const = label_to_output[m_mul2_constant];
+
+        // Replace Multiply->Multiply with single Multiply
+        // Multiply operation will be added to the list of ops requested for pattern matching
+        auto new_mul = register_new_node<opset3::Multiply>(input, eltwise_fold<opset3::Multiply>(mul1_const, mul2_const));
+
+        copy_runtime_info({mul1, mul2}, new_mul);
+        new_mul->set_friendly_name(mul2->get_friendly_name());
+        replace_node(mul2, new_mul);
+        return true;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(m_mul2, "MultiplyMultiplyFusion");
+    this->register_matcher(m, callback);
+}
index ea5efa1..787f0a4 100644 (file)
@@ -10,7 +10,7 @@
 #include <ngraph/opsets/opset1.hpp>
 #include <ngraph/rt_info.hpp>
 
-void ngraph::pass::PullTransposeThroughFQUp::pull_transpose_through_fq() {
+ngraph::pass::PullTransposeThroughFQUp::PullTransposeThroughFQUp() {
     auto data1 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
     auto data2 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
     auto data3 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
@@ -70,5 +70,5 @@ void ngraph::pass::PullTransposeThroughFQUp::pull_transpose_through_fq() {
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(transpose, "PullTransposeThroughFQUp");
-    this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+    this->register_matcher(m, callback);
 }
index 1ac3b79..21ef1f1 100644 (file)
@@ -385,9 +385,12 @@ ModelPtr FrontEnd::runCommonPasses(ie::ICNNNetwork& network, const UnsupportedLa
             // Disable shape inference (WA for generic operations)
             ngraph::op::GenericIE::DisableReshape noReshape(nGraphFunc);
 
-            ngraph::pass::ConvertOpSet3ToOpSet2(transformationsPredicate).run_on_function(nGraphFunc);
-            ngraph::pass::ConvertOpSet2ToOpSet1(transformationsPredicate).run_on_function(nGraphFunc);
-            ngraph::pass::ConvertOpSet1ToLegacy(transformationsPredicate).run_on_function(nGraphFunc);
+            ngraph::pass::Manager manager;
+            manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();
+            manager.register_pass<ngraph::pass::ConvertOpSet2ToOpSet1>();
+            manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
+            manager.set_callback(transformationsPredicate);
+            manager.run_passes(nGraphFunc);
 
             vpu::MergeSubsequentDSROperations().run_on_function(nGraphFunc);
 
index 2bde0d3..516cf85 100644 (file)
@@ -20,6 +20,7 @@
 #include <ngraph_ops/gru_cell_ie.hpp>
 #include <ngraph_ops/rnn_cell_ie.hpp>
 #include <ngraph_ops/lstm_cell_ie.hpp>
+#include <ngraph/pass/manager.hpp>
 #include "common_test_utils/ngraph_test_utils.hpp"
 
 using namespace testing;
@@ -44,8 +45,10 @@ TEST(TransformationTests, GRUCellConversionTest) {
         cell->set_friendly_name("test_cell");
 
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{cell}, ngraph::ParameterVector{X, H_t});
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::ConvertCellsToCellsIE().run_on_function(f);
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::ConvertGRUCellMatcher>();
+        manager.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
 
@@ -93,8 +96,10 @@ TEST(TransformationTests, RNNCellConversionTest) {
         cell->set_friendly_name("test_cell");
 
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{cell}, ngraph::ParameterVector{X, H});
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::ConvertCellsToCellsIE().run_on_function(f);
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::ConvertRNNCellMatcher>();
+        manager.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
 
@@ -146,8 +151,10 @@ TEST(TransformationTests, LSTMCellConversionTest) {
         cell->set_friendly_name("test_cell");
 
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{cell}, ngraph::ParameterVector{X, H_t, C_t});
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::ConvertCellsToCellsIE().run_on_function(f);
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::ConvertLSTMCellMatcher>();
+        manager.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
 
index 6995b66..655c8ab 100644 (file)
@@ -13,6 +13,7 @@
 #include <transformations/convert_divide.hpp>
 #include <transformations/init_node_info.hpp>
 #include <transformations/utils/utils.hpp>
+#include <ngraph/pass/manager.hpp>
 
 #include "common_test_utils/ngraph_test_utils.hpp"
 
@@ -27,8 +28,10 @@ TEST(TransformationTests, ConvertDivide) {
 
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{data});
 
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::ConvertDivide().run_on_function(f);
+        ngraph::pass::Manager m;
+        m.register_pass<ngraph::pass::InitNodeInfo>();
+        m.register_pass<ngraph::pass::ConvertDivide>();
+        m.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
 
@@ -55,8 +58,10 @@ TEST(TransformationTests, ConvertDivideNegative) {
 
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{data});
 
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::ConvertDivide().run_on_function(f);
+        ngraph::pass::Manager m;
+        m.register_pass<ngraph::pass::InitNodeInfo>();
+        m.register_pass<ngraph::pass::ConvertDivide>();
+        m.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
 
index 8edd748..d13c2b6 100644 (file)
@@ -14,6 +14,7 @@
 #include <transformations/init_node_info.hpp>
 #include <transformations/utils/utils.hpp>
 #include <ngraph_ops/gather_ie.hpp>
+#include <ngraph/pass/manager.hpp>
 
 #include "common_test_utils/ngraph_test_utils.hpp"
 
@@ -30,8 +31,10 @@ TEST(TransformationTests, ConvertGatherToGatherIEStatic1) {
 
         f = std::make_shared<Function>(NodeVector{gather}, ParameterVector{input, indices});
 
-        pass::InitNodeInfo().run_on_function(f);
-        pass::ConvertGatherToGatherIE().run_on_function(f);
+        pass::Manager manager;
+        manager.register_pass<pass::InitNodeInfo>();
+        manager.register_pass<pass::ConvertGatherToGatherIEMatcher>();
+        manager.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
         ASSERT_TRUE(f->get_output_partial_shape(0).is_static()) << "Shape " << f->get_output_partial_shape(0) << " should be static";
     }
@@ -58,8 +61,10 @@ TEST(TransformationTests, ConvertGatherToGatherIEStatic2) {
 
         f = std::make_shared<Function>(NodeVector{gather}, ParameterVector{input, indices});
 
-        pass::InitNodeInfo().run_on_function(f);
-        pass::ConvertGatherToGatherIE().run_on_function(f);
+        pass::Manager manager;
+        manager.register_pass<pass::InitNodeInfo>();
+        manager.register_pass<pass::ConvertGatherToGatherIEMatcher>();
+        manager.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
         ASSERT_TRUE(f->get_output_partial_shape(0).is_static()) << "Shape " << f->get_output_partial_shape(0) << " should be static";
     }
@@ -88,8 +93,10 @@ TEST(TransformationTests, ConvertGatherToGatherIEDynamic1) {
 
         f = std::make_shared<Function>(NodeVector{gather}, ParameterVector{input, indices});
 
-        pass::InitNodeInfo().run_on_function(f);
-        pass::ConvertGatherToGatherIE().run_on_function(f);
+        pass::Manager manager;
+        manager.register_pass<pass::InitNodeInfo>();
+        manager.register_pass<pass::ConvertGatherToGatherIEMatcher>();
+        manager.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
 
@@ -115,8 +122,10 @@ TEST(TransformationTests, ConvertGatherToGatherIEDynamic2) {
 
         f = std::make_shared<Function>(NodeVector{gather}, ParameterVector{input, indices});
 
-        pass::InitNodeInfo().run_on_function(f);
-        pass::ConvertGatherToGatherIE().run_on_function(f);
+        pass::Manager manager;
+        manager.register_pass<pass::InitNodeInfo>();
+        manager.register_pass<pass::ConvertGatherToGatherIEMatcher>();
+        manager.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
 
index d006fda..fb2847d 100644 (file)
@@ -21,6 +21,7 @@
 #include <transformations/convert_opset1_to_legacy/reshape_fully_connected.hpp>
 #include <transformations/init_node_info.hpp>
 #include <transformations/utils/utils.hpp>
+#include <ngraph/pass/manager.hpp>
 
 #include "common_test_utils/ngraph_test_utils.hpp"
 
@@ -35,8 +36,10 @@ TEST(TransformationTests, ConvertMatMulTest1) {
 
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1, input2});
 
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::ConvertMatMulToFCorGemm().run_on_function(f);
+        ngraph::pass::Manager m;
+        m.register_pass<ngraph::pass::InitNodeInfo>();
+        m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
+        m.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
 
@@ -64,8 +67,10 @@ TEST(TransformationTests, ConvertMatMulTest2) {
 
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1, input2});
 
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::ConvertMatMulToFCorGemm().run_on_function(f);
+        ngraph::pass::Manager m;
+        m.register_pass<ngraph::pass::InitNodeInfo>();
+        m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
+        m.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
 
@@ -92,8 +97,10 @@ TEST(TransformationTests, ConvertMatMulTest3) {
         auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, false);
 
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1, input2});
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::ConvertMatMulToFCorGemm().run_on_function(f);
+        ngraph::pass::Manager m;
+        m.register_pass<ngraph::pass::InitNodeInfo>();
+        m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
+        m.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
 
@@ -120,8 +127,10 @@ TEST(TransformationTests, ConvertMatMulTest4) {
         auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, false);
 
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1, input2});
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::ConvertMatMulToFCorGemm().run_on_function(f);
+        ngraph::pass::Manager m;
+        m.register_pass<ngraph::pass::InitNodeInfo>();
+        m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
+        m.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
 
@@ -145,8 +154,10 @@ TEST(TransformationTests, ConvertMatMulTest5) {
         auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
 
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::ConvertMatMulToFCorGemm().run_on_function(f);
+        ngraph::pass::Manager m;
+        m.register_pass<ngraph::pass::InitNodeInfo>();
+        m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
+        m.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
 
@@ -171,8 +182,12 @@ TEST(TransformationTests, ConvertMatMulTest6) {
         auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
 
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
-        ngraph::pass::ConvertMatMulToFCorGemm().run_on_function(f);
-        ngraph::pass::ReshapeFullyConnected().run_on_function(f);
+        ngraph::pass::Manager m;
+        m.register_pass<ngraph::pass::InitNodeInfo>();
+        m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
+        m.register_pass<ngraph::pass::ReshapeFullyConnected>();
+        m.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
     }
 
     {
@@ -199,8 +214,10 @@ TEST(TransformationTests, ConvertMatMulTest7) {
 
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
 
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::ConvertMatMulToFCorGemm().run_on_function(f);
+        ngraph::pass::Manager m;
+        m.register_pass<ngraph::pass::InitNodeInfo>();
+        m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
+        m.register_pass<ngraph::pass::ReshapeFullyConnected>();
 
         auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
             if (auto fc_op = std::dynamic_pointer_cast<const ngraph::op::FullyConnected>(node)) {
@@ -210,9 +227,9 @@ TEST(TransformationTests, ConvertMatMulTest7) {
             }
             return false;
         };
-        auto p = ngraph::pass::ReshapeFullyConnected();
-        p.setCallback(callback);
-        p.run_on_function(f);
+
+        m.set_callback(callback);
+        m.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
 
index 035e6b9..0a41482 100644 (file)
@@ -13,6 +13,7 @@
 #include <transformations/convert_opset1_to_legacy/convert_nms_4_to_legacy.hpp>
 #include <transformations/init_node_info.hpp>
 #include <transformations/utils/utils.hpp>
+#include <ngraph/pass/manager.hpp>
 
 #include "common_test_utils/ngraph_test_utils.hpp"
 
@@ -33,8 +34,10 @@ TEST(TransformationTests, ConvertNMS4ToNMSIEStatic) {
         f = std::make_shared<Function>(NodeVector{nms}, ParameterVector{boxes, scores});
 
         const auto &orig_shape = f->get_output_partial_shape(0);
-        pass::InitNodeInfo().run_on_function(f);
-        pass::ConvertNMS4ToLegacy().run_on_function(f);
+        pass::Manager manager;
+        manager.register_pass<pass::InitNodeInfo>();
+        manager.register_pass<pass::ConvertNMS4ToLegacyMatcher>();
+        manager.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
         ASSERT_TRUE(f->get_output_partial_shape(0).is_static()) << "Shape " << f->get_output_partial_shape(0) << " should be static";
     }
@@ -75,8 +78,10 @@ TEST(TransformationTests, ConvertNMS4ToNMSIEDynamic1) {
 
         f = std::make_shared<Function>(NodeVector{nms}, ParameterVector{boxes, scores});
 
-        pass::InitNodeInfo().run_on_function(f);
-        pass::ConvertNMS4ToLegacy().run_on_function(f);
+        pass::Manager manager;
+        manager.register_pass<pass::InitNodeInfo>();
+        manager.register_pass<pass::ConvertNMS4ToLegacyMatcher>();
+        manager.run_passes(f);
         f->validate_nodes_and_infer_types();
         ASSERT_NO_THROW(check_rt_info(f));
     }
@@ -114,8 +119,11 @@ TEST(TransformationTests, ConvertNMS4ToNMSIEDynamic2) {
 
         f = std::make_shared<Function>(NodeVector{nms}, ParameterVector{boxes, scores});
 
-        pass::InitNodeInfo().run_on_function(f);
-        pass::ConvertNMS4ToLegacy().run_on_function(f);
+        pass::Manager manager;
+        manager.register_pass<pass::InitNodeInfo>();
+        manager.register_pass<pass::ConvertNMS4ToLegacyMatcher>();
+        manager.run_passes(f);
+
         f->validate_nodes_and_infer_types();
         ASSERT_NO_THROW(check_rt_info(f));
     }
index eec1aa2..244d8b4 100644 (file)
@@ -15,6 +15,7 @@
 #include <transformations/utils/utils.hpp>
 #include <ngraph_ops/nms_ie.hpp>
 #include <ngraph/pass/constant_folding.hpp>
+#include <ngraph/pass/manager.hpp>
 
 #include "common_test_utils/ngraph_test_utils.hpp"
 
@@ -35,8 +36,10 @@ TEST(TransformationTests, ConvertNMSToNMSIEStatic) {
         f = std::make_shared<Function>(NodeVector{nms}, ParameterVector{boxes, scores});
 
         const auto & orig_shape = f->get_output_partial_shape(0);
-        pass::InitNodeInfo().run_on_function(f);
-        pass::ConvertNMSToNMSIE().run_on_function(f);
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::ConvertNMSToNMSIEMatcher>();
+        manager.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
         ASSERT_TRUE(f->get_output_partial_shape(0).is_static()) << "Shape " << f->get_output_partial_shape(0) << " should be static";
     }
@@ -73,8 +76,10 @@ TEST(TransformationTests, ConvertNMSToNMSIEDynamic1) {
 
         f = std::make_shared<Function>(NodeVector{nms}, ParameterVector{boxes, scores});
 
-        pass::InitNodeInfo().run_on_function(f);
-        pass::ConvertNMSToNMSIE().run_on_function(f);
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::ConvertNMSToNMSIEMatcher>();
+        manager.run_passes(f);
         f->validate_nodes_and_infer_types();
         ASSERT_NO_THROW(check_rt_info(f));
     }
@@ -110,8 +115,10 @@ TEST(TransformationTests, ConvertNMSToNMSIEDynamic2) {
 
         f = std::make_shared<Function>(NodeVector{nms}, ParameterVector{boxes, scores});
 
-        pass::InitNodeInfo().run_on_function(f);
-        pass::ConvertNMSToNMSIE().run_on_function(f);
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::ConvertNMSToNMSIEMatcher>();
+        manager.run_passes(f);
         f->validate_nodes_and_infer_types();
         ASSERT_NO_THROW(check_rt_info(f));
     }
index d7fb40e..9a09144 100644 (file)
@@ -23,6 +23,7 @@
 #include <ngraph/op/reshape.hpp>
 #include <transformations/utils/utils.hpp>
 #include <transformations/init_node_info.hpp>
+#include <ngraph/pass/manager.hpp>
 
 #include "common_test_utils/ngraph_test_utils.hpp"
 
@@ -48,8 +49,10 @@ TEST(TransformationTests, ConvertStridedSliceToCropTests1) {
         sslice->set_friendly_name("strided_slice");
 
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{sslice}, ngraph::ParameterVector{input});
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::ConvertStridedSliceToCrop().run_on_function(f);
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::ConvertStridedSliceToCropMatcher>();
+        manager.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
 
@@ -100,8 +103,10 @@ TEST(TransformationTests, ConvertStridedSliceToCropTests2) {
         sslice->set_friendly_name("strided_slice");
 
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{sslice}, ngraph::ParameterVector{input});
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::ConvertStridedSliceToCrop().run_on_function(f);
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::ConvertStridedSliceToCropMatcher>();
+        manager.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
 
@@ -152,8 +157,10 @@ TEST(TransformationTests, ConvertStridedSliceToCropNegative) {
         sslice->set_friendly_name("strided_slice");
 
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{sslice}, ngraph::ParameterVector{input});
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::ConvertStridedSliceToCrop().run_on_function(f);
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::ConvertStridedSliceToCropMatcher>();
+        manager.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
 
@@ -202,8 +209,10 @@ TEST(TransformationTests, ConvertStridedSliceToCropNegative2) {
         sslice->set_friendly_name("strided_slice");
 
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{sslice}, ngraph::ParameterVector{input});
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::ConvertStridedSliceToCrop().run_on_function(f);
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::ConvertStridedSliceToCropMatcher>();
+        manager.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
 
index f8bfeae..37a89d4 100644 (file)
@@ -15,6 +15,7 @@
 #include <transformations/utils/utils.hpp>
 #include <ngraph_ops/topk_ie.hpp>
 #include <ngraph/pass/constant_folding.hpp>
+#include <ngraph/pass/manager.hpp>
 
 #include "common_test_utils/ngraph_test_utils.hpp"
 
@@ -29,10 +30,14 @@ TEST(TransformationTests, ConvertTopKToTopKIEStatic) {
         // due to the 'compare_functions' limitation we will check only one output
         f = std::make_shared<ngraph::Function>(ngraph::OutputVector{topk->output(0)}, ngraph::ParameterVector{input});
 
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::ConvertTopKToTopKIE().run_on_function(f);
-        ASSERT_NO_THROW(check_rt_info(f));
-        ngraph::pass::ConstantFolding().run_on_function(f);
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::ConvertTopKToTopKIEMatcher>();
+        manager.register_pass<ngraph::pass::InjectionPass>([](std::shared_ptr<ngraph::Function> f) {
+            check_rt_info(f);
+        });
+        manager.register_pass<ngraph::pass::ConstantFolding>();
+        ASSERT_NO_THROW(manager.run_passes(f));
         ASSERT_TRUE(f->get_output_partial_shape(0).is_static()) << "Shape " << f->get_output_partial_shape(0) << " should be static";
     }
 
@@ -59,8 +64,10 @@ TEST(TransformationTests, ConvertTopKToTopKIEDynamic1) {
         // due to the 'compare_functions' limitation we will check only one output
         f = std::make_shared<ngraph::Function>(ngraph::OutputVector{topk->output(0)}, ngraph::ParameterVector{input});
 
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::ConvertTopKToTopKIE().run_on_function(f);
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::ConvertTopKToTopKIEMatcher>();
+        manager.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
         ngraph::pass::ConstantFolding().run_on_function(f);
     }
@@ -88,8 +95,10 @@ TEST(TransformationTests, ConvertTopKToTopKIEDynamic2) {
         // due to the 'compare_functions' limitation we will check only one output
         f = std::make_shared<ngraph::Function>(ngraph::OutputVector{topk->output(0)}, ngraph::ParameterVector{input});
 
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::ConvertTopKToTopKIE().run_on_function(f);
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::ConvertTopKToTopKIEMatcher>();
+        manager.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
         ngraph::pass::ConstantFolding().run_on_function(f);
     }
@@ -117,8 +126,10 @@ TEST(TransformationTests, ConvertTopKToTopKIEDynamic3) {
         // due to the 'compare_functions' limitation we will check only one output
         f = std::make_shared<ngraph::Function>(ngraph::OutputVector{topk->output(0)}, ngraph::ParameterVector{input});
 
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::ConvertTopKToTopKIE().run_on_function(f);
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::ConvertTopKToTopKIEMatcher>();
+        manager.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
         ngraph::pass::ConstantFolding().run_on_function(f);
     }
@@ -146,8 +157,10 @@ TEST(TransformationTests, ConvertTopKToTopKIENegative) {
         // due to the 'compare_functions' limitation we will check only one output
         f = std::make_shared<ngraph::Function>(ngraph::OutputVector{topk->output(0)}, ngraph::ParameterVector{input, k});
 
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::ConvertTopKToTopKIE().run_on_function(f);
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::ConvertTopKToTopKIEMatcher>();
+        manager.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
         ngraph::pass::ConstantFolding().run_on_function(f);
     }
index c9edda8..1b42d69 100644 (file)
@@ -43,7 +43,7 @@ TEST(TransformationTests, DepthToSpaceFusionDepthFirst) {
         };
 
         auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion();
-        depth_to_space_transform.setCallback(callback);
+        depth_to_space_transform.set_callback(callback);
         depth_to_space_transform.run_on_function(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
@@ -77,7 +77,7 @@ TEST(TransformationTests, DepthToSpaceFusionBlockFirst) {
         };
 
         auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion();
-        depth_to_space_transform.setCallback(callback);
+        depth_to_space_transform.set_callback(callback);
         depth_to_space_transform.run_on_function(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
@@ -112,7 +112,7 @@ TEST(TransformationTests, DepthToSpaceFusionDynamicShape) {
 
         // transformation won't be applied because of shape_reshape_before is dynamic, the graph will remain the same
         auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion();
-        depth_to_space_transform.setCallback(callback);
+        depth_to_space_transform.set_callback(callback);
         depth_to_space_transform.run_on_function(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
@@ -157,7 +157,7 @@ TEST(TransformationTests, DepthToSpaceFusionSeveralConsumers) {
 
         // transformation won't be applied because of reshape_before has several consumers, the graph will remain the same
         auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion();
-        depth_to_space_transform.setCallback(callback);
+        depth_to_space_transform.set_callback(callback);
         depth_to_space_transform.run_on_function(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
index cebab50..9192179 100644 (file)
@@ -36,10 +36,15 @@ TEST(TransformationTests, FullyConnectedBiasFusionTest3D) {
         auto add = std::make_shared<ngraph::opset1::Add>(fc, const_bias);
 
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{add}, ngraph::ParameterVector{input1});
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::FullyConnectedBiasFusion().run_on_function(f);
-        ASSERT_NO_THROW(check_rt_info(f));
-        ngraph::pass::ConstantFolding().run_on_function(f);
+
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::FullyConnectedBiasFusion>();
+        manager.register_pass<ngraph::pass::InjectionPass>([](std::shared_ptr<ngraph::Function> f) {
+            check_rt_info(f);
+        });
+        manager.register_pass<ngraph::pass::ConstantFolding>();
+        ASSERT_NO_THROW(manager.run_passes(f));
     }
 
     {
@@ -67,10 +72,14 @@ TEST(TransformationTests, FullyConnectedBiasFusionTest2D) {
         auto add = std::make_shared<ngraph::opset1::Add>(fc, const_bias);
 
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{add}, ngraph::ParameterVector{input1});
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::FullyConnectedBiasFusion().run_on_function(f);
-        ASSERT_NO_THROW(check_rt_info(f));
-        ngraph::pass::ConstantFolding().run_on_function(f);
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::FullyConnectedBiasFusion>();
+        manager.register_pass<ngraph::pass::InjectionPass>([](std::shared_ptr<ngraph::Function> f) {
+            check_rt_info(f);
+        });
+        manager.register_pass<ngraph::pass::ConstantFolding>();
+        ASSERT_NO_THROW(manager.run_passes(f));
     }
 
     {
diff --git a/inference-engine/tests/functional/inference_engine/transformations/lin_op_sequence_fusion_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/lin_op_sequence_fusion_test.cpp
new file mode 100644 (file)
index 0000000..c7f84d1
--- /dev/null
@@ -0,0 +1,174 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include "common_test_utils/test_common.hpp"
+#include <string>
+#include <sstream>
+#include <fstream>
+#include <memory>
+#include <queue>
+#include <map>
+
+#include <ngraph/function.hpp>
+#include <ngraph/opsets/opset3.hpp>
+#include <ngraph/pass/manager.hpp>
+#include <transformations/lin_op_sequence_fusoin.hpp>
+#include <transformations/utils/utils.hpp>
+#include <transformations/init_node_info.hpp>
+#include <ngraph/pass/visualize_tree.hpp>
+
+#include "common_test_utils/ngraph_test_utils.hpp"
+
+using namespace testing;
+using namespace ngraph;
+
+TEST(TransformationTests, MulAddMulAddFusion) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+
+    {
+        auto input = std::make_shared<opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 3072});
+        auto mul1_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {2});
+        auto mul2_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {3});
+        auto add1_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {4});
+        auto add2_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {5});
+
+        auto mul1 = std::make_shared<opset3::Multiply>(input, mul1_const);
+        auto add1 = std::make_shared<opset3::Add>(mul1, add1_const);
+        auto mul2 = std::make_shared<opset3::Multiply>(add1, mul2_const);
+        auto add2 = std::make_shared<opset3::Add>(mul2, add2_const);
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{add2}, ngraph::ParameterVector{input});
+    }
+
+    pass::Manager manager;
+    manager.register_pass<ngraph::pass::InitNodeInfo>();
+    manager.register_pass<ngraph::pass::LinOpSequenceFusion>();
+    manager.run_passes(f);
+    ASSERT_NO_THROW(check_rt_info(f));
+
+    {
+        auto input = std::make_shared<opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 3072});
+        auto mul1_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {6});
+        auto add1_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {17});
+
+        auto mul1 = std::make_shared<opset3::Multiply>(input, mul1_const);
+        auto add1 = std::make_shared<opset3::Add>(mul1, add1_const);
+
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{add1}, ngraph::ParameterVector{input});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, MulMulMulFusion) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+
+    {
+        auto input = std::make_shared<opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 3072});
+        auto mul1_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {2});
+        auto mul2_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {3});
+        auto mul3_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {3});
+
+        auto mul1 = std::make_shared<opset3::Multiply>(input, mul1_const);
+        auto mul2 = std::make_shared<opset3::Multiply>(mul1, mul2_const);
+        auto mul3 = std::make_shared<opset3::Multiply>(mul2, mul3_const);
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul2}, ngraph::ParameterVector{input});
+    }
+
+    pass::Manager manager;
+    manager.register_pass<ngraph::pass::InitNodeInfo>();
+    manager.register_pass<ngraph::pass::LinOpSequenceFusion>();
+    manager.run_passes(f);
+    ASSERT_NO_THROW(check_rt_info(f));
+
+    {
+        auto input = std::make_shared<opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 3072});
+        auto mul1_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {12});
+
+        auto mul1 = std::make_shared<opset3::Multiply>(input, mul1_const);
+
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul1}, ngraph::ParameterVector{input});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, AddAddAddFusion) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+
+    {
+        auto input = std::make_shared<opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 3072});
+        auto add1_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {2});
+        auto add2_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {3});
+        auto add3_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {3});
+
+        auto add1 = std::make_shared<opset3::Add>(input, add1_const);
+        auto add2 = std::make_shared<opset3::Add>(add1, add2_const);
+        auto add3 = std::make_shared<opset3::Add>(add2, add3_const);
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{add3}, ngraph::ParameterVector{input});
+    }
+
+    pass::Manager manager;
+    manager.register_pass<ngraph::pass::InitNodeInfo>();
+    manager.register_pass<ngraph::pass::LinOpSequenceFusion>();
+    manager.run_passes(f);
+    ASSERT_NO_THROW(check_rt_info(f));
+
+    {
+        auto input = std::make_shared<opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 3072});
+        auto add1_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {8});
+
+        auto add1 = std::make_shared<opset3::Add>(input, add1_const);
+
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{add1}, ngraph::ParameterVector{input});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, MulAddAddMulFusion) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+
+    {
+        auto input = std::make_shared<opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 3072});
+        auto mul1_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {2});
+        auto mul2_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {3});
+        auto add1_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {4});
+        auto add2_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {5});
+
+        auto mul1 = std::make_shared<opset3::Multiply>(input, mul1_const);
+        auto add1 = std::make_shared<opset3::Add>(mul1, add1_const);
+        auto add2 = std::make_shared<opset3::Add>(add1, add2_const);
+        auto mul2 = std::make_shared<opset3::Multiply>(add2, mul2_const);
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul2}, ngraph::ParameterVector{input});
+    }
+
+    pass::Manager manager;
+    manager.register_pass<ngraph::pass::InitNodeInfo>();
+    manager.register_pass<ngraph::pass::LinOpSequenceFusion>();
+    manager.run_passes(f);
+    ASSERT_NO_THROW(check_rt_info(f));
+
+    {
+        auto input = std::make_shared<opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 3072});
+        auto mul1_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {10});
+        auto add1_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {40});
+
+        auto mul1 = std::make_shared<opset3::Multiply>(input, mul1_const);
+        auto add1 = std::make_shared<opset3::Add>(mul1, add1_const);
+
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{add1}, ngraph::ParameterVector{input});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
\ No newline at end of file
index 57542bf..691e19d 100644 (file)
@@ -17,6 +17,7 @@
 #include <transformations/convert_depth_to_space.hpp>
 #include <transformations/convert_space_to_depth.hpp>
 #include <transformations/init_node_info.hpp>
+#include <ngraph/pass/manager.hpp>
 #include "common_test_utils/ngraph_test_utils.hpp"
 
 using namespace testing;
@@ -28,8 +29,10 @@ TEST(TransformationTests, TestDepthToSpaceTransformBlockFirst) {
     {
         auto depth_to_space = std::make_shared<ngraph::op::DepthToSpace>(input, ngraph::op::DepthToSpace::DepthToSpaceMode::BLOCKS_FIRST, 2);
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{depth_to_space}, ngraph::ParameterVector{input});
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::ConvertDepthToSpace().run_on_function(f);
+        ngraph::pass::Manager m;
+        m.register_pass<ngraph::pass::InitNodeInfo>();
+        m.register_pass<ngraph::pass::ConvertDepthToSpace>();
+        m.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
 
@@ -66,8 +69,10 @@ TEST(TransformationTests, TestDepthToSpaceTransformDepthFirst) {
     {
         auto depth_to_space = std::make_shared<ngraph::op::DepthToSpace>(input, ngraph::op::DepthToSpace::DepthToSpaceMode::DEPTH_FIRST, 2);
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{depth_to_space}, ngraph::ParameterVector{input});
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::ConvertDepthToSpace().run_on_function(f);
+        ngraph::pass::Manager m;
+        m.register_pass<ngraph::pass::InitNodeInfo>();
+        m.register_pass<ngraph::pass::ConvertDepthToSpace>();
+        m.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
 
@@ -104,8 +109,10 @@ TEST(TransformationTests, TestSpaceToDepthTransformBlockFirst) {
     {
         auto space_to_depth = std::make_shared<ngraph::op::SpaceToDepth>(input, ngraph::op::SpaceToDepth::SpaceToDepthMode::BLOCKS_FIRST, 2);
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{space_to_depth}, ngraph::ParameterVector{input});
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::ConvertSpaceToDepth().run_on_function(f);
+        ngraph::pass::Manager m;
+        m.register_pass<ngraph::pass::InitNodeInfo>();
+        m.register_pass<ngraph::pass::ConvertSpaceToDepth>();
+        m.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
 
@@ -142,8 +149,10 @@ TEST(TransformationTests, TestSpaceToDepthTransformDepthFirst) {
     {
         auto space_to_depth = std::make_shared<ngraph::op::SpaceToDepth>(input, ngraph::op::SpaceToDepth::SpaceToDepthMode::DEPTH_FIRST, 2);
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{space_to_depth}, ngraph::ParameterVector{input});
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::ConvertSpaceToDepth().run_on_function(f);
+        ngraph::pass::Manager m;
+        m.register_pass<ngraph::pass::InitNodeInfo>();
+        m.register_pass<ngraph::pass::ConvertSpaceToDepth>();
+        m.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
 
index 7100601..5bbe606 100644 (file)
@@ -18,6 +18,7 @@
 #include <transformations/pull_transpose_through_fq.hpp>
 #include <ngraph/pass/constant_folding.hpp>
 #include <transformations/init_node_info.hpp>
+#include <ngraph/pass/manager.hpp>
 #include "common_test_utils/ngraph_test_utils.hpp"
 
 using namespace testing;
@@ -36,10 +37,15 @@ TEST(TransformationTests, FQTransposeTest1) {
         auto transpose = std::make_shared<ngraph::op::Transpose>(fq, transpose_order);
 
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{transpose}, ngraph::ParameterVector{});
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::PullTransposeThroughFQUp().run_on_function(f);
-        ASSERT_NO_THROW(check_rt_info(f));
-        ngraph::pass::ConstantFolding().run_on_function(f);
+
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::PullTransposeThroughFQUp>();
+        manager.register_pass<ngraph::pass::InjectionPass>([](std::shared_ptr<ngraph::Function> f) {
+            check_rt_info(f);
+        });
+        manager.register_pass<ngraph::pass::ConstantFolding>();
+        ASSERT_NO_THROW(manager.run_passes(f));
     }
     std::vector<size_t> ref_shape{1, 3, 1};
     for (auto op : f->get_ops()) {
index dae7581..277b708 100644 (file)
@@ -17,6 +17,7 @@
 #include <transformations/convert_mod.hpp>
 #include <ngraph/pass/constant_folding.hpp>
 #include <transformations/init_node_info.hpp>
+#include <ngraph/pass/manager.hpp>
 #include "common_test_utils/ngraph_test_utils.hpp"
 
 using namespace testing;
@@ -30,8 +31,10 @@ TEST(TransformationTests, ModDecompositionTests) {
         auto mod = std::make_shared<ngraph::op::v1::Mod>(data1, data2);
 
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mod}, ngraph::ParameterVector{});
-        ngraph::pass::InitNodeInfo().run_on_function(f);
-        ngraph::pass::ConvertMod().run_on_function(f);
+        ngraph::pass::Manager m;
+        m.register_pass<ngraph::pass::InitNodeInfo>();
+        m.register_pass<ngraph::pass::ConvertMod>();
+        m.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
     ASSERT_EQ(f->get_ops().size(), 12);
index 8718d95..26ca459 100644 (file)
@@ -23,6 +23,7 @@
 #include <ngraph/opsets/opset2.hpp>
 #include <ngraph/opsets/opset3.hpp>
 #include <ngraph/op/fused/gelu.hpp>
+#include <ngraph/pass/manager.hpp>
 #include "ngraph_functions/pass/convert_prc.hpp"
 
 #include "common_test_utils/common_utils.hpp"
@@ -77,10 +78,14 @@ InferenceEngine::CNNNetwork LayerTransformation::transform(InferenceEngine::deta
         ::ngraph::op::GenericIE::DisableReshape noReshape(nGraphFunc);
 
         // Note: instead of running all Conversion Transformations you can make up your own transformation pipeline
-        ngraph::pass::CommonOptimizations(transformations_callback).run_on_function(nGraphFunc);
-        ngraph::pass::ConvertOpSet3ToOpSet2(transformations_callback).run_on_function(nGraphFunc);
-        ngraph::pass::ConvertOpSet2ToOpSet1(transformations_callback).run_on_function(nGraphFunc);
-        ngraph::pass::ConvertOpSet1ToLegacy(transformations_callback).run_on_function(nGraphFunc);
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::CommonOptimizations>();
+        manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();
+        manager.register_pass<ngraph::pass::ConvertOpSet2ToOpSet1>();
+        manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
+
+        manager.set_callback(transformations_callback);
+        manager.run_passes(nGraphFunc);
         clonedNetwork = InferenceEngine::details::convertFunctionToICNNNetwork(nGraphFunc, *clonedNetwork);
     }
 
index b555b96..2a105b4 100644 (file)
@@ -23,6 +23,7 @@
 #include <ngraph/opsets/opset2.hpp>
 #include <ngraph/opsets/opset3.hpp>
 #include <ngraph/op/fused/gelu.hpp>
+#include <ngraph/pass/manager.hpp>
 #include "ngraph_functions/pass/convert_prc.hpp"
 
 #include "common_test_utils/common_utils.hpp"
@@ -71,10 +72,14 @@ InferenceEngine::CNNNetwork LayerTransformation::transform(InferenceEngine::deta
         ::ngraph::op::GenericIE::DisableReshape noReshape(nGraphFunc);
 
         // Note: instead of running all Conversion Transformations you can make up your own transformation pipeline
-        ngraph::pass::CommonOptimizations(transformations_callback).run_on_function(nGraphFunc);
-        ngraph::pass::ConvertOpSet3ToOpSet2(transformations_callback).run_on_function(nGraphFunc);
-        ngraph::pass::ConvertOpSet2ToOpSet1(transformations_callback).run_on_function(nGraphFunc);
-        ngraph::pass::ConvertOpSet1ToLegacy(transformations_callback).run_on_function(nGraphFunc);
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::CommonOptimizations>();
+        manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();
+        manager.register_pass<ngraph::pass::ConvertOpSet2ToOpSet1>();
+        manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
+
+        manager.set_callback(transformations_callback);
+        manager.run_passes(nGraphFunc);
         clonedNetwork = InferenceEngine::details::convertFunctionToICNNNetwork(nGraphFunc, *clonedNetwork);
     }
 
index 0c21e37..cfcf2b7 100644 (file)
@@ -8,6 +8,7 @@
 
 #include <ngraph/function.hpp>
 #include <ngraph/dimension.hpp>
+#include <ngraph/pass/pass.hpp>
 
 #include "test_common.hpp"
 
@@ -19,4 +20,27 @@ std::pair<bool, std::string> compare_functions(const std::shared_ptr<ngraph::Fun
 
 void check_rt_info(const std::shared_ptr<ngraph::Function> & f);
 
-void visualize_function(std::shared_ptr<ngraph::Function> f, const std::string & file_name);
\ No newline at end of file
+void visualize_function(std::shared_ptr<ngraph::Function> f, const std::string & file_name);
+
+namespace ngraph {
+namespace pass {
+
+class InjectionPass;
+
+} // namespace pass
+} // namespace ngraph
+
+class ngraph::pass::InjectionPass : public ngraph::pass::FunctionPass {
+public:
+    using injection_callback = std::function<void(std::shared_ptr<ngraph::Function>)>;
+
+    explicit InjectionPass(injection_callback callback) : FunctionPass(), m_callback(std::move(callback)) {}
+
+    bool run_on_function(std::shared_ptr<ngraph::Function> f) override {
+        m_callback(f);
+        return false;
+    }
+
+private:
+    injection_callback m_callback;
+};
index 051dc00..9bf8990 100644 (file)
@@ -17,6 +17,7 @@ addIeTargetTest(
         ROOT ${CMAKE_CURRENT_SOURCE_DIR}
         LINK_LIBRARIES
             unitTestUtils
+            inference_engine_transformations
             ${OpenCV_LIBRARIES}
         ADD_CPPLINT
         DEPENDENCIES
index 33dfe80..655b53a 100644 (file)
@@ -560,6 +560,8 @@ set (SRC
     pattern/op/skip.hpp
     pattern/op/true.cpp
     pattern/op/true.hpp
+    pattern/op/wrap_type.hpp
+    pattern/op/wrap_type.cpp
     provenance.cpp
     provenance.hpp
     rank.hpp
index dc371b2..e640ba9 100644 (file)
@@ -800,8 +800,14 @@ bool Node::match_value(pattern::Matcher* matcher,
 bool Node::match_node(pattern::Matcher* matcher, const Output<Node>& graph_value)
 {
     matcher->add_node(graph_value);
-    return graph_value.get_node_shared_ptr()->get_type_info() == get_type_info() &&
-           matcher->match_arguments(this, graph_value.get_node_shared_ptr());
+    if (graph_value.get_node_shared_ptr()->get_type_info() == get_type_info() &&
+        matcher->match_arguments(this, graph_value.get_node_shared_ptr()))
+    {
+        auto& pattern_map = matcher->get_pattern_value_map();
+        pattern_map[shared_from_this()] = graph_value;
+        return true;
+    }
+    return false;
 }
 
 // default implementation for the node to check if it contains partial shape
index 3f9b9ac..57adf4b 100644 (file)
@@ -34,29 +34,30 @@ bool ngraph::pass::revalidate_and_ensure_static(shared_ptr<Node> n)
 
 void ngraph::pass::ConstantFolding::construct_constant_default()
 {
-    add_handler("Constant folding defaults",
-                [](const std::shared_ptr<Node>& node) -> bool {
-                    OutputVector replacements(node->get_output_size());
-                    if (!node->constant_fold(replacements, node->input_values()))
-                    {
-                        return false;
-                    }
-                    NGRAPH_CHECK(
-                        replacements.size() == node->get_output_size(),
-                        "constant_fold_default returned incorrect number of replacements for ",
-                        node);
-                    bool result{false};
-                    for (size_t i = 0; i < replacements.size(); ++i)
-                    {
-                        auto node_output = node->output(i);
-                        auto replacement = replacements.at(i);
-                        if (replacement.get_node_shared_ptr() && (node_output != replacement))
-                        {
-                            node_output.replace(replacement);
-                            result = true;
-                        }
-                    }
-                    return result;
-                },
-                PassProperty::CHANGE_DYNAMIC_STATE);
+    m_matchers.push_back(std::make_shared<MatcherPass>(
+        "Constant folding defaults",
+        nullptr,
+        [](const std::shared_ptr<Node>& node) -> bool {
+            OutputVector replacements(node->get_output_size());
+            if (!node->constant_fold(replacements, node->input_values()))
+            {
+                return false;
+            }
+            NGRAPH_CHECK(replacements.size() == node->get_output_size(),
+                         "constant_fold_default returned incorrect number of replacements for ",
+                         node);
+            bool result{false};
+            for (size_t i = 0; i < replacements.size(); ++i)
+            {
+                auto node_output = node->output(i);
+                auto replacement = replacements.at(i);
+                if (replacement.get_node_shared_ptr() && (node_output != replacement))
+                {
+                    node_output.replace(replacement);
+                    result = true;
+                }
+            }
+            return result;
+        },
+        PassProperty::CHANGE_DYNAMIC_STATE));
 }
index d1acf84..68d1cbc 100644 (file)
 #include "constant_folding.hpp"
 #include "ngraph/builder/split.hpp"
 #include "ngraph/op/constant.hpp"
+#include "ngraph/op/slice.hpp"
 #include "ngraph/op/split.hpp"
+#include "ngraph/runtime/reference/slice.hpp"
 #include "ngraph/validation_util.hpp"
 
 using namespace std;
 using namespace ngraph;
 
+template <class T>
+shared_ptr<op::Constant> fold_constant_slice(shared_ptr<op::Constant> constant,
+                                             shared_ptr<op::Slice> slice)
+{
+    const Shape& out_shape = slice->get_shape();
+    runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
+    T* data_ptr = buffer.get_ptr<T>();
+
+    runtime::reference::slice<T>(constant->get_data_ptr<T>(),
+                                 data_ptr,
+                                 constant->get_shape(),
+                                 slice->get_lower_bounds(),
+                                 slice->get_upper_bounds(),
+                                 slice->get_strides(),
+                                 out_shape);
+
+    return make_shared<op::Constant>(constant->get_element_type(), out_shape, data_ptr);
+}
+
 void pass::ConstantFolding::construct_constant_split()
 {
     auto data_label = make_shared<pattern::op::Label>(
@@ -51,7 +72,69 @@ void pass::ConstantFolding::construct_constant_split()
             output.replace(slices[index++]->output(0));
         }
         split->outputs().clear();
-        construct_constant_slice();
+
+        for (auto& slice : slices)
+        {
+            auto const_data = std::dynamic_pointer_cast<op::Constant>(
+                slice->input_value(0).get_node_shared_ptr());
+            auto slice_node = std::dynamic_pointer_cast<op::Slice>(slice);
+            if (!const_data || !slice_node)
+                continue;
+
+            std::shared_ptr<op::Constant> replacement;
+            switch (slice->get_output_element_type(0))
+            {
+            case element::Type_t::undefined:
+                NGRAPH_CHECK(false, "Encountered 'undefined' element type in fold_constant_slice");
+                break;
+            case element::Type_t::dynamic:
+                NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_slice");
+                break;
+            case element::Type_t::u1:
+                NGRAPH_CHECK(false, "Encountered 'u1' element type in fold_constant_slice");
+                break;
+            case element::Type_t::boolean:
+                replacement = fold_constant_slice<char>(const_data, slice_node);
+                break;
+            case element::Type_t::bf16:
+                replacement = fold_constant_slice<bfloat16>(const_data, slice_node);
+                break;
+            case element::Type_t::f16:
+                replacement = fold_constant_slice<float16>(const_data, slice_node);
+                break;
+            case element::Type_t::f32:
+                replacement = fold_constant_slice<float>(const_data, slice_node);
+                break;
+            case element::Type_t::f64:
+                replacement = fold_constant_slice<double>(const_data, slice_node);
+                break;
+            case element::Type_t::i8:
+                replacement = fold_constant_slice<int8_t>(const_data, slice_node);
+                break;
+            case element::Type_t::i16:
+                replacement = fold_constant_slice<int16_t>(const_data, slice_node);
+                break;
+            case element::Type_t::i32:
+                replacement = fold_constant_slice<int32_t>(const_data, slice_node);
+                break;
+            case element::Type_t::i64:
+                replacement = fold_constant_slice<int64_t>(const_data, slice_node);
+                break;
+            case element::Type_t::u8:
+                replacement = fold_constant_slice<uint8_t>(const_data, slice_node);
+                break;
+            case element::Type_t::u16:
+                replacement = fold_constant_slice<uint16_t>(const_data, slice_node);
+                break;
+            case element::Type_t::u32:
+                replacement = fold_constant_slice<uint32_t>(const_data, slice_node);
+                break;
+            case element::Type_t::u64:
+                replacement = fold_constant_slice<uint64_t>(const_data, slice_node);
+                break;
+            }
+            replace_node(slice_node, replacement);
+        }
 
         return true;
     };
index 3782253..1428768 100644 (file)
 #include "constant_folding.hpp"
 #include "ngraph/builder/split.hpp"
 #include "ngraph/op/constant.hpp"
+#include "ngraph/op/slice.hpp"
 #include "ngraph/op/variadic_split.hpp"
+#include "ngraph/runtime/reference/slice.hpp"
 #include "ngraph/validation_util.hpp"
 
 using namespace std;
 using namespace ngraph;
 
+template <class T>
+shared_ptr<op::Constant> fold_constant_slice(shared_ptr<op::Constant> constant,
+                                             shared_ptr<op::Slice> slice)
+{
+    const Shape& out_shape = slice->get_shape();
+    runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
+    T* data_ptr = buffer.get_ptr<T>();
+
+    runtime::reference::slice<T>(constant->get_data_ptr<T>(),
+                                 data_ptr,
+                                 constant->get_shape(),
+                                 slice->get_lower_bounds(),
+                                 slice->get_upper_bounds(),
+                                 slice->get_strides(),
+                                 out_shape);
+
+    return make_shared<op::Constant>(constant->get_element_type(), out_shape, data_ptr);
+}
+
 void pass::ConstantFolding::construct_constant_variadic_split()
 {
     auto data_label = make_shared<pattern::op::Label>(
@@ -82,7 +103,69 @@ void pass::ConstantFolding::construct_constant_variadic_split()
             }
         }
         variadic_split->outputs().clear();
-        construct_constant_slice();
+
+        for (auto& slice : slices)
+        {
+            auto const_data = std::dynamic_pointer_cast<op::Constant>(
+                slice->input_value(0).get_node_shared_ptr());
+            auto slice_node = std::dynamic_pointer_cast<op::Slice>(slice);
+            if (!const_data || !slice_node)
+                continue;
+
+            std::shared_ptr<op::Constant> replacement;
+            switch (slice->get_output_element_type(0))
+            {
+            case element::Type_t::undefined:
+                NGRAPH_CHECK(false, "Encountered 'undefined' element type in fold_constant_slice");
+                break;
+            case element::Type_t::dynamic:
+                NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_slice");
+                break;
+            case element::Type_t::u1:
+                NGRAPH_CHECK(false, "Encountered 'u1' element type in fold_constant_slice");
+                break;
+            case element::Type_t::boolean:
+                replacement = fold_constant_slice<char>(const_data, slice_node);
+                break;
+            case element::Type_t::bf16:
+                replacement = fold_constant_slice<bfloat16>(const_data, slice_node);
+                break;
+            case element::Type_t::f16:
+                replacement = fold_constant_slice<float16>(const_data, slice_node);
+                break;
+            case element::Type_t::f32:
+                replacement = fold_constant_slice<float>(const_data, slice_node);
+                break;
+            case element::Type_t::f64:
+                replacement = fold_constant_slice<double>(const_data, slice_node);
+                break;
+            case element::Type_t::i8:
+                replacement = fold_constant_slice<int8_t>(const_data, slice_node);
+                break;
+            case element::Type_t::i16:
+                replacement = fold_constant_slice<int16_t>(const_data, slice_node);
+                break;
+            case element::Type_t::i32:
+                replacement = fold_constant_slice<int32_t>(const_data, slice_node);
+                break;
+            case element::Type_t::i64:
+                replacement = fold_constant_slice<int64_t>(const_data, slice_node);
+                break;
+            case element::Type_t::u8:
+                replacement = fold_constant_slice<uint8_t>(const_data, slice_node);
+                break;
+            case element::Type_t::u16:
+                replacement = fold_constant_slice<uint16_t>(const_data, slice_node);
+                break;
+            case element::Type_t::u32:
+                replacement = fold_constant_slice<uint32_t>(const_data, slice_node);
+                break;
+            case element::Type_t::u64:
+                replacement = fold_constant_slice<uint64_t>(const_data, slice_node);
+                break;
+            }
+            replace_node(slice_node, replacement);
+        }
 
         return true;
     };
index 2c07151..2b6dc11 100644 (file)
@@ -15,7 +15,9 @@
 //*****************************************************************************
 
 #include <algorithm>
+#include <deque>
 #include <iostream>
+#include <pattern/op/wrap_type.hpp>
 #include <regex>
 #include <unordered_set>
 #include <vector>
@@ -37,9 +39,13 @@ using namespace ngraph;
 //
 // The topological order would be : `Constant1`, `Abs2`, `Neg3`, `Add4`, `Result5`
 // Note, `Abs2` comes before `Neg3` as `Abs2`'s id = 2 is *less* than `Neg3`'s one (id = 3)
-// Next, GraphRewrite will invoke matchers registered in an order registered in a c-tor
-// i.e. if a c-tor calls `construct_m1()`; `construct_m2()`; `construct_m3()`;
-// Matchers will be called as follows: `m1`, `m2`, `m3`
+// Next, GraphRewrite will invoke matchers passes registered in add_matcher order.
+// For example:
+//     ngraph::pass::GraphRewrite pass;
+//     pass.add_matcher<m1>();
+//     pass.add_matcher<m2>();
+//     pass.add_matcher<m3>();
+// Matcher passes will be called as follows: `m1`, `m2`, `m3`
 // Matchers should only replace nodes in the graph that come before the current root
 // node in the topological order. For example, if Matcher matches Neg3, it should only
 // replace nodes `Abs2` and `Constant1` if needed
@@ -47,137 +53,169 @@ using namespace ngraph;
 // and `m2` folds `Neg3(Constant1)` when `m3` is called on `Add4` it will discover that
 // both `Abs2` and `Neg3` were already replaced by constants, so `Add4` will also be folded into
 // one.
-// If any Matcher succeeds the rest of the matchers will **not** be called.
+// If any matcher passes succeeds the rest of the matchers will **not** be called.
 // E.g. if `m1` succeeds and replaces `Abs2` with a new constant, nor `m2` or `m3` will be called
 // However, sometimes, you will need more than one fusion occur on the same node.
-// In this case, you should be able to request another pass of GraphRewrite.
-// To request another pass, you will need to register fusions in a callback:
-// i.e. you will need to pass `this` into a callback and then call `this->construct_X`
-// This will schedule another pass of GraphRewrite with the following fusion.
-// This approach should only be used if you are either:
-// a) need more than one fusion occur on the same node
-// b) you are modifying nodes after the current node in the topological order
-// c) there's no linear order of fusions which will give
-//    the correct final fusion. i.e. the same fusion needs to occur before and after some other
-//    fusion
+// In this case, you need to register nodes in MatcherPass manually using register_new_node method.
+// GraphRewrite will automatically add this nodes in the beginning of execution queue.
+// If MatcherPass register more than one node make sure that this nodes are registered in
+// topological order.
 
 bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
 {
     bool rewritten = false;
-    const size_t NUM_TRIES = 10;
-    size_t tries = NUM_TRIES;
-    vector<MatchClosure> original_matchers{m_matchers};
-    // This check is very expensive and is only needed for experimental features, so we will hide
-    // it behind an environment variable for now. TODO: Find a less expensive way to handle this.
-    static bool s_rerun_dynamic_check = getenv_bool("NGRAPH_GRAPH_REWRITE_RERUN_DYNAMIC_CHECK");
-    bool is_dyn_func = s_rerun_dynamic_check && f->is_dynamic();
-    do
+
+    // Initialize execution queue with nodes in topological order
+    deque<std::shared_ptr<Node>> nodes_to_run;
+    for (auto& node : f->get_ordered_ops())
     {
-        rewritten = false;
-        // m_matchers may contain newly constructed matchers for matchers
-        // that need multiple passes. See comments above.
-        vector<MatchClosure> matchers_to_run{m_matchers};
-        m_matchers.clear();
-        for (auto node : f->get_ordered_ops())
+        nodes_to_run.emplace_back(node);
+    }
+
+    // Check that all Matchers in MatcherPasses has type bases root node
+    bool all_roots_has_type = true;
+    std::unordered_map<NodeTypeInfo, std::vector<std::shared_ptr<MatcherPass>>> type_to_matcher;
+    for (auto& m : m_matchers)
+    {
+        auto matcher = m->get_matcher();
+        if (!matcher)
+        {
+            all_roots_has_type = false;
+            break;
+        }
+
+        auto root = matcher->get_pattern_value().get_node_shared_ptr();
+        // pattern::op::AnyOutput operation automatically appends for multi output operations inside
+        // Matcher and to gen actual root node we need to take it's parent.
+        if (auto any_type = dynamic_pointer_cast<pattern::op::AnyOutput>(root))
         {
-            if (m_enable_shape_inference)
+            root = any_type->input_value(0).get_node_shared_ptr();
+        }
+
+        // if root is an operation from opset or has pattern::op::WrapType type then we can extract
+        // it's type
+        // and use it in unordered_map as key for fast MatcherPass search. Otherwise type is unknown
+        // and default algorithm is used.
+        NodeTypeInfo root_type_info = root->get_type_info();
+        if (auto p = dynamic_pointer_cast<pattern::op::Pattern>(root))
+        {
+            if (auto any_type = dynamic_pointer_cast<pattern::op::WrapType>(p))
             {
-                node->revalidate_and_infer_types();
+                root_type_info = any_type->get_wrapped_type();
             }
-            for (auto& closure : matchers_to_run)
+            else
             {
-                if (is_dyn_func && closure.property[PassProperty::REQUIRE_STATIC_SHAPE])
-                {
-                    NGRAPH_DEBUG << "matcher callback requires static shape but the "
-                                    "function is dynamic, skipping this "
-                                    "optimization till the shapes are fully "
-                                    "materialized";
-                    continue;
-                }
-                if (closure.handler(node))
-                {
-                    rewritten = true;
-                    // If call back may change function's is_dynamic state, we need to
-                    // update the cached value.
-                    if (closure.property.is_set(PassProperty::CHANGE_DYNAMIC_STATE))
-                    {
-                        is_dyn_func = s_rerun_dynamic_check && f->is_dynamic();
-                    }
-                    break;
-                }
+                all_roots_has_type = false;
+                break;
             }
         }
-
-    } while (rewritten && m_matchers.size() > 0 && tries--);
-
-    m_matchers.assign(original_matchers.begin(), original_matchers.end());
-    return (NUM_TRIES - tries) > 1; // this means a graph was transformed
-}
-
-static vector<regex> initialize_fusion_regexes()
-{
-    static const string nsf = getenv_string("NGRAPH_DISABLED_FUSIONS");
-    vector<regex> regexes;
-    if (!nsf.empty())
-    {
-        const auto sregexes = split(nsf, ';');
-
-        transform(sregexes.begin(),
-                  sregexes.end(),
-                  back_inserter(regexes),
-                  [](const string& c) -> regex { return regex(c); });
+        type_to_matcher[root_type_info].push_back(m);
     }
-    return regexes;
-}
 
-bool pass::GraphRewriteBase::is_enabled(const std::string& name) const
-{
-    // note, regexes are static to avoid re-initialization
-    static const auto regexes = initialize_fusion_regexes();
-
-    for (const auto& regex : regexes)
-    {
-        if (regex_match(name, regex))
+    // This lambda preforms execution of particular MatcherPass on given node.
+    // It automatically handles nodes registered by MatcherPass during transformation and set
+    // transformation callback.
+    auto run_matcher_pass = [&](std::shared_ptr<MatcherPass> m_pass,
+                                std::shared_ptr<Node> node) -> bool {
+        // Keep this property check for backward compatibility. In future transformation property
+        // will be deprecated and removed.
+        if (m_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) && f->is_dynamic())
         {
-            NGRAPH_DEBUG << "Disabling matcher " << name;
+            NGRAPH_DEBUG << "matcher callback requires static shape but the "
+                            "function is dynamic, skipping this "
+                            "optimization till the shapes are fully "
+                            "materialized";
             return false;
         }
-    }
 
-    return true;
-}
+        if (!m_has_default_callback)
+        {
+            m_pass->set_callback(m_transformation_callback);
+        }
 
-void pass::GraphRewriteBase::add_handler(const std::string& name,
-                                         function<bool(const std::shared_ptr<Node>&)> handler,
-                                         const PassPropertyMask& property)
-{
-    if (is_enabled(name))
+        // Apply MatcherPass. In case if it returns true no other MatcherPasses will apply
+        // to this node
+        bool status = m_pass->apply(node);
+
+        // In case if MatcherPass registered nodes they will be added to the beginning of execution
+        // queue
+        const auto& new_nodes = m_pass->get_new_nodes();
+        if (!new_nodes.empty())
+        {
+            // Need to push nodes in reverse order as we expect that nodes in new_nodes
+            // vector are in topological order
+            for (auto it = new_nodes.rbegin(); it != new_nodes.rend(); it++)
+            {
+                nodes_to_run.emplace_front(*it);
+            }
+            m_pass->clear_new_nodes();
+        }
+        return status;
+    };
+
+    while (!nodes_to_run.empty())
     {
-        m_matchers.push_back({name, handler, property});
-        // If any matcher call back may change dynamic state, we need to
-        // update the pass property.
-        if (property.is_set(PassProperty::CHANGE_DYNAMIC_STATE))
+        auto node = nodes_to_run.front();
+        nodes_to_run.pop_front();
+        // Temporary keep this GraphRewrite property for backward compatibility
+        if (m_enable_shape_inference)
+        {
+            node->revalidate_and_infer_types();
+        }
+        // If all Matchers in MatcherPasses has type based root node then we apply efficient
+        // algorithm for finding matchers
+        if (all_roots_has_type)
         {
-            set_property(PassProperty::CHANGE_DYNAMIC_STATE, true);
+            auto node_type_info = node->get_type_info();
+            if (type_to_matcher.count(node_type_info))
+            {
+                for (auto& m_pass : type_to_matcher[node_type_info])
+                {
+                    if (run_matcher_pass(m_pass, node))
+                    {
+                        rewritten = true;
+                        break;
+                    }
+                }
+            }
+        }
+        // Otherwise we use default algorithm that iterates over all registered matcher passes
+        else
+        {
+            for (auto& m_pass : m_matchers)
+            {
+                if (run_matcher_pass(m_pass, node))
+                {
+                    rewritten = true;
+                    break;
+                }
+            }
         }
     }
+    return rewritten;
 }
 
 void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m,
                                      const graph_rewrite_callback& callback,
                                      const PassPropertyMask& property)
 {
-    add_handler(m->get_name(),
-                [m, callback](const std::shared_ptr<Node>& node) -> bool {
-                    NGRAPH_DEBUG << "Running matcher " << m->get_name() << " on " << node;
-                    if (m->match(node->output(0)))
-                    {
-                        NGRAPH_DEBUG << "Matcher " << m->get_name() << " matched " << node;
-                        return callback(*m.get());
-                    }
-                    return false;
-                },
-                property);
+    m_matchers.push_back(std::make_shared<MatcherPass>(
+        m->get_name(),
+        m,
+        [m, callback](const std::shared_ptr<Node>& node) -> bool {
+            NGRAPH_DEBUG << "Running matcher " << m->get_name() << " on " << node;
+            if (m->match(node->output(0)))
+            {
+                NGRAPH_DEBUG << "Matcher " << m->get_name() << " matched " << node;
+                bool status = callback(*m.get());
+                // explicitly clear Matcher state because it holds pointers to matched nodes
+                m->clear_state();
+                return status;
+            }
+            m->clear_state();
+            return false;
+        },
+        property));
 }
 
 void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m,
@@ -193,17 +231,19 @@ void pass::RecurrentGraphRewrite::add_matcher(
     const ngraph::recurrent_graph_rewrite_callback& callback,
     const PassPropertyMask& property)
 {
-    add_handler("Reurrent matcher",
-                [m, callback](const std::shared_ptr<Node>& node) {
-                    NGRAPH_DEBUG << "Running recurrent matcher on " << node;
-                    if (m->match(node->output(0)))
-                    {
-                        NGRAPH_DEBUG << "Recurrent matcher matched " << m.get();
-                        return callback(*m.get());
-                    }
-                    return false;
-                },
-                property);
+    m_matchers.push_back(std::make_shared<MatcherPass>(
+        "Recurrent matcher",
+        nullptr,
+        [m, callback](const std::shared_ptr<Node>& node) {
+            NGRAPH_DEBUG << "Running recurrent matcher on " << node;
+            if (m->match(node->output(0)))
+            {
+                NGRAPH_DEBUG << "Recurrent matcher matched " << m.get();
+                return callback(*m.get());
+            }
+            return false;
+        },
+        property));
 }
 
 void pass::RecurrentGraphRewrite::add_matcher(
@@ -228,9 +268,9 @@ bool pass::RecurrentGraphRewrite::run_on_function(shared_ptr<Function> f)
         bool is_dyn_func = s_rerun_dynamic_check && f->is_dynamic();
         for (auto node : f->get_ops())
         {
-            for (auto& closure : m_matchers)
+            for (auto& m_pass : m_matchers)
             {
-                if (is_dyn_func && closure.property[PassProperty::REQUIRE_STATIC_SHAPE])
+                if (is_dyn_func && m_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE))
                 {
                     NGRAPH_DEBUG << "matcher callback requires static shape but the "
                                     "function is dynamic, skipping this "
@@ -238,11 +278,11 @@ bool pass::RecurrentGraphRewrite::run_on_function(shared_ptr<Function> f)
                                     "materialized";
                     continue;
                 }
-                if (closure.handler(node))
+                if (m_pass->apply(node))
                 {
                     // If call back may change function's is_dynamic state, we need to
                     // update the cached value.
-                    if (closure.property.is_set(PassProperty::CHANGE_DYNAMIC_STATE))
+                    if (m_pass->get_property(PassProperty::CHANGE_DYNAMIC_STATE))
                     {
                         is_dyn_func = s_rerun_dynamic_check && f->is_dynamic();
                     }
@@ -260,3 +300,30 @@ bool pass::RecurrentGraphRewrite::run_on_function(shared_ptr<Function> f)
     } while (changed && i < m_num_iters);
     return changed;
 }
+
+void ngraph::pass::MatcherPass::register_matcher(const std::shared_ptr<ngraph::pattern::Matcher>& m,
+                                                 const ngraph::graph_rewrite_callback& callback,
+                                                 const PassPropertyMask& property)
+{
+    set_name(m->get_name());
+    set_property(property, true);
+    m_matcher = m;
+    m_handler = [m, callback](const std::shared_ptr<Node>& node) -> bool {
+        if (m->match(node->output(0)))
+        {
+            NGRAPH_DEBUG << "Matcher " << m->get_name() << " matched " << node;
+            bool status = callback(*m.get());
+            // explicitly clear Matcher state because it holds pointers to matched nodes
+            m->clear_state();
+            return status;
+        }
+        m->clear_state();
+        return false;
+    };
+}
+
+bool ngraph::pass::MatcherPass::apply(std::shared_ptr<ngraph::Node> node)
+{
+    m_new_nodes.clear();
+    return m_handler(node);
+}
\ No newline at end of file
index 0e2ecde..821a8a9 100644 (file)
@@ -27,80 +27,132 @@ namespace ngraph
 {
     namespace pass
     {
-        class GraphRewriteBase;
         class GraphRewrite;
         class RecurrentGraphRewrite;
+        class MatcherPass;
     }
 
+    using matcher_pass_callback = std::function<bool(ngraph::pattern::Matcher& m)>;
     using graph_rewrite_callback = std::function<bool(ngraph::pattern::Matcher& m)>;
     using recurrent_graph_rewrite_callback =
         std::function<bool(ngraph::pattern::RecurrentMatcher& m)>;
+    using handler_callback = std::function<bool(const std::shared_ptr<Node>& node)>;
 }
 
-class NGRAPH_API ngraph::pass::GraphRewriteBase : public ngraph::pass::FunctionPass
+/// \brief MatcherPass is a basic block for pattern based transformations. It describes pattern and
+/// action that is applied if pattern is matched.
+///
+/// MatcherPass consists of Matcher and matcher_pass_callback that needs to be implemented and
+/// finally registered by using \sa register_matcher. MatcherPass can be executed on node within
+/// \sa apply method. To run matcher pass on Function use GraphRewrite.
+/// In addition MatcherPass provides a way for adding new operations into GraphRewrite execution
+/// queue. That means that operations that were created inside transformation callback can be added
+/// for matching. To register node use \sa register_new_node method. GraphRewrite automatically
+/// takes registered nodes and put them to execution queue. If multiple nodes were register make
+/// sure that they were registered in topological order.
+/// Note: when implementing pattern for Matcher make sure that root node is an operation from opset
+/// or has ngraph::pattern::op::WrapType. That will help GraphRewrite to execute matcher passes more
+/// efficient.
+
+class NGRAPH_API ngraph::pass::MatcherPass : public ngraph::pass::PassBase
 {
 public:
-    /// \brief Add an arbitrary handler for nodes
-    /// \param name The name of the handler
-    /// \param handler Function responsible for deciding if the graph should be changed and making
-    /// the changes. Returns true if changes are made.
-    void add_handler(const std::string& name,
-                     std::function<bool(const std::shared_ptr<Node>& node)> handler,
-                     const PassPropertyMask& property);
-
-protected:
-    GraphRewriteBase()
-        : FunctionPass()
+    MatcherPass() = default;
+
+    MatcherPass(const MatcherPass&) = delete;
+    MatcherPass& operator=(const MatcherPass&) = delete;
+
+    explicit MatcherPass(const std::string& name,
+                         const std::shared_ptr<pattern::Matcher>& m,
+                         const handler_callback& handler,
+                         const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE)
+        : PassBase()
+        , m_handler(handler)
+        , m_matcher(m)
     {
-        // Being explicit:
-        // Setting REQUIRE_STATIC_SHAPE to false because we will check if each
-        // callback needs static shape during run_on_function().
-        set_property(PassProperty::REQUIRE_STATIC_SHAPE, false);
+        set_name(name);
+        set_property(property, true);
     }
 
-    bool is_enabled(const std::string& name) const;
+    bool apply(std::shared_ptr<ngraph::Node> node);
 
-    struct MatchClosure
+    template <typename T, class... Args>
+    std::shared_ptr<T> register_new_node(Args&&... args)
     {
-        std::string name;
-        std::function<bool(const std::shared_ptr<Node>& node)> handler;
-        PassPropertyMask property;
-    };
-    std::vector<MatchClosure> m_matchers;
+        auto node = std::make_shared<T>(std::forward<Args>(args)...);
+        m_new_nodes.push_back(node);
+        return node;
+    }
+
+    const std::vector<std::shared_ptr<ngraph::Node>>& get_new_nodes() { return m_new_nodes; }
+    void clear_new_nodes() { m_new_nodes.clear(); }
+    std::shared_ptr<pattern::Matcher> get_matcher() { return m_matcher; }
+protected:
+    void register_matcher(const std::shared_ptr<pattern::Matcher>& m,
+                          const ngraph::graph_rewrite_callback& callback,
+                          const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE);
+
+private:
+    handler_callback m_handler;
+    std::shared_ptr<pattern::Matcher> m_matcher;
+    std::vector<std::shared_ptr<ngraph::Node>> m_new_nodes;
 };
 
-/// \brief GraphRewrite (in tandem with \sa Matcher) performs transformations on specified patterns
+/// \brief GraphRewrite is a container for MatcherPasses that allows to run them on Function in
+/// efficient way
 ///
-/// Graph rewrite pass essentially allows pass users to rewrite parts of the
-/// input graph in any way they want. Fusion is one example of graph rewrite that
-/// fuses multiple ops together. At a high-level users of the pass need to
-/// specify 2 things: 1) which ops to fuse (via \sa Matcher, and 2) how to create new op(s) from
-/// the existing ops by providing a callback to \p Matcher object
-/// Patterns can be added by using \sa add_matcher
-/// Callbacks should use \sa replace_node to transform matched sub graphs
-
-class NGRAPH_API ngraph::pass::GraphRewrite : public ngraph::pass::GraphRewriteBase
+/// Graph rewrite pass is used for matcher passes execution on Function.
+/// To register MatcherPass use \sa add_matcher<T>(args) method where T is a MatcherPass class.
+/// As a default algorithm graph rewrite pass traverse Function in topological order and applies
+/// registered matcher passes for each node. But if all registered matcher passes have type based
+/// root node in Matcher pattern then efficient mechanism is used to execute them.
+/// Matcher pattern root is type based if it's operation from opset or pattern::op::WrapType.
+/// Note: when implementing pattern for Matcher make sure that root node is an operation from opset
+/// or has ngraph::pattern::op::WrapType. That will help GraphRewrite to execute matcher passes more
+/// efficient.
+
+class NGRAPH_API ngraph::pass::GraphRewrite : public ngraph::pass::FunctionPass
 {
 public:
+    GraphRewrite() = default;
+
+    explicit GraphRewrite(const std::shared_ptr<MatcherPass>& pass)
+        : FunctionPass()
+    {
+        m_matchers.push_back(pass);
+    }
+
+    template <typename T, class... Args>
+    std::shared_ptr<T> add_matcher(Args&&... args)
+    {
+        static_assert(std::is_base_of<pass::MatcherPass, T>::value,
+                      "pass not derived from MatcherPass");
+        auto pass = std::make_shared<T>(std::forward<Args>(args)...);
+        m_matchers.push_back(pass);
+        return pass;
+    }
+
     void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
                      const ngraph::graph_rewrite_callback& callback,
-                     const PassPropertyMask& property);
+                     const PassPropertyMask& property) NGRAPH_DEPRECATED("Use MatcherPass instead");
 
-    // TODO: This interface may deprecate after all passes are refactored.
     void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
-                     const ngraph::graph_rewrite_callback& callback);
+                     const ngraph::graph_rewrite_callback& callback)
+        NGRAPH_DEPRECATED("Use MatcherPass instead");
 
-    virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
+    bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
 
 protected:
     bool m_enable_shape_inference = false;
+
+    std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
 };
 
-class NGRAPH_API ngraph::pass::RecurrentGraphRewrite : public ngraph::pass::GraphRewriteBase
+class NGRAPH_API ngraph::pass::RecurrentGraphRewrite : public ngraph::pass::FunctionPass
 {
 public:
     RecurrentGraphRewrite(size_t num_iters = 10)
-        : GraphRewriteBase()
+        : FunctionPass()
         , m_num_iters(num_iters)
     {
     }
@@ -117,4 +169,6 @@ public:
 
 private:
     size_t m_num_iters;
+
+    std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
 };
index 3b8cdbe..818b3a4 100644 (file)
 //*****************************************************************************
 
 #include <algorithm>
-#ifdef _WIN32
-#else
-#include <cxxabi.h>
-#endif
 #include <iomanip>
 #include <iostream>
 #include <memory>
@@ -26,7 +22,9 @@
 #include "ngraph/env_util.hpp"
 #include "ngraph/function.hpp"
 #include "ngraph/graph_util.hpp"
+#include "ngraph/log.hpp"
 #include "ngraph/node.hpp"
+#include "ngraph/pass/graph_rewrite.hpp"
 #include "ngraph/pass/manager.hpp"
 #include "ngraph/pass/pass.hpp"
 #include "ngraph/pass/serialize.hpp"
@@ -52,80 +50,93 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool /* transitive */)
     static bool profile_enabled = getenv_bool("NGRAPH_PROFILE_PASS_ENABLE");
 
     get_state().set_function(func);
-    vector<std::pair<shared_ptr<Function>, bool>> fs{std::make_pair(func, func->is_dynamic())};
     vector<shared_ptr<Function>> f_array{func};
 
     size_t index = 0;
     stopwatch pass_timer;
     stopwatch overall_timer;
     overall_timer.start();
-    for (shared_ptr<PassBase> pass : m_pass_list)
+    bool function_changed = false;
+    for (auto& pass : m_pass_list)
     {
         pass_timer.start();
         pass->set_state(get_state());
-        auto module_pass = dynamic_pointer_cast<ModulePass>(pass);
-        auto function_pass = dynamic_pointer_cast<FunctionPass>(pass);
-        auto node_pass = dynamic_pointer_cast<NodePass>(pass);
-        auto call_graph_pass = dynamic_pointer_cast<CallGraphPass>(pass);
-        if (module_pass)
+        if (!m_has_default_callback)
+        {
+            pass->set_callback(m_transformation_callback);
+        }
+
+        if (auto module_pass = dynamic_pointer_cast<ModulePass>(pass))
         {
             if (auto vt_pass = dynamic_pointer_cast<pass::VisualizeTree>(module_pass))
             {
                 vt_pass->set_ops_to_details(get_state().get_visualize_tree_ops_map());
             }
-            module_pass->run_on_module(f_array);
+            function_changed = module_pass->run_on_module(f_array);
         }
-        else if (function_pass)
+        else if (auto matcher_pass = dynamic_pointer_cast<MatcherPass>(pass))
         {
-            for (auto f_pair : fs)
+            // This checks is to skip the graph transformation when the graph pass relies on
+            // static shape but the function state is dynamic.
+            if (matcher_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) &&
+                func->is_dynamic())
             {
-                shared_ptr<Function> f = f_pair.first;
-                // This checks is to skip the graph optimization when the graph pass relies on
-                // static shape but the function state is dynamic.
-                // we update the function dynamic state only if we run the graph pass successfully.
-                if (function_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) &&
-                    f_pair.second)
-                {
-                    continue;
-                }
-                bool function_modified = function_pass->run_on_function(f);
-                // If the pass may change the function's is_dynamic property, we need to
-                // update the cached value.
-                if (function_modified &&
-                    function_pass->get_property(PassProperty::CHANGE_DYNAMIC_STATE))
-                {
-                    f_pair.second = f->is_dynamic();
-                }
+                NGRAPH_DEBUG << "Pass " << pass->get_name() << " requires static shape but the "
+                             << "function is dynamic. Skipping this transformation";
+                continue;
             }
+            // GraphRewrite is a temporary container for MatcherPass to make execution
+            // on on entire ngraph::Function
+            function_changed = GraphRewrite(matcher_pass).run_on_function(func);
         }
-        else if (node_pass)
+        else if (auto function_pass = dynamic_pointer_cast<FunctionPass>(pass))
         {
-            for (auto f_pair : fs)
+            // This checks is to skip the graph transformation when the graph pass relies on
+            // static shape but the function state is dynamic.
+            if (function_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) &&
+                func->is_dynamic())
             {
-                shared_ptr<Function> f = f_pair.first;
-                if (node_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) && f_pair.second)
-                {
-                    continue;
-                }
-                for (shared_ptr<Node> n : f->get_ops())
+                NGRAPH_DEBUG << "Pass " << pass->get_name() << " requires static shape but the "
+                             << "function is dynamic. Skipping this transformation";
+                continue;
+            }
+
+            if (dynamic_pointer_cast<Validate>(pass))
+            {
+                if (function_changed)
                 {
-                    node_pass->run_on_node(n);
+                    function_pass->run_on_function(func);
+                    function_changed = false;
                 }
             }
+            else
+            {
+                function_changed = function_pass->run_on_function(func);
+            }
         }
-        else if (call_graph_pass)
+        else if (auto node_pass = dynamic_pointer_cast<NodePass>(pass))
         {
-            for (auto f_pair : fs)
+            if (node_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) && func->is_dynamic())
             {
-                shared_ptr<Function> f = f_pair.first;
-                if (call_graph_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) &&
-                    f_pair.second)
-                {
-                    continue;
-                }
-                bool function_modified = call_graph_pass->run_on_call_graph(f->get_ordered_ops());
-                f_pair.second = (function_modified == true) ? f->is_dynamic() : f_pair.second;
+                NGRAPH_DEBUG << "Pass " << pass->get_name() << " requires static shape but the "
+                             << "function is dynamic. Skipping this transformation";
+                continue;
+            }
+            for (shared_ptr<Node> n : func->get_ops())
+            {
+                function_changed |= node_pass->run_on_node(n);
+            }
+        }
+        else if (auto call_graph_pass = dynamic_pointer_cast<CallGraphPass>(pass))
+        {
+            if (call_graph_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) &&
+                func->is_dynamic())
+            {
+                NGRAPH_DEBUG << "Pass " << pass->get_name() << " requires static shape but the "
+                             << "function is dynamic. Skipping this transformation";
+                continue;
             }
+            function_changed = call_graph_pass->run_on_call_graph(func->get_ordered_ops());
         }
 
         if (m_visualize || m_serialize)
@@ -135,7 +146,7 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool /* transitive */)
             std::string index_str = std::to_string(index);
             index_str = std::string(num_digits_in_pass_index - index_str.length(), '0') + index_str;
             auto base_filename = f_array.at(0)->get_name() + std::string("_") + index_str +
-                                 std::string("_") + m_pass_names.at(index);
+                                 std::string("_") + pass->get_name();
 
             if (m_visualize)
             {
@@ -156,13 +167,7 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool /* transitive */)
         pass_timer.stop();
         if (profile_enabled)
         {
-            PassBase* p = pass.get();
-            string name = typeid(*p).name();
-#ifndef _WIN32
-            int status;
-            name = abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status);
-#endif
-            cout << setw(7) << pass_timer.get_milliseconds() << "ms " << name << "\n";
+            cout << setw(7) << pass_timer.get_milliseconds() << "ms " << pass->get_name() << "\n";
         }
     }
     if (profile_enabled)
index 565600a..98323c6 100644 (file)
@@ -63,6 +63,29 @@ public:
     /// each registered pass
     /// \param new_state Value "true" enables Validate pass run; "false", otherwise
     void set_per_pass_validation(bool new_state) { m_per_pass_validation = new_state; }
+    /// \brief Callback is a lambda function that can be used by registered transformations.
+    /// The main purpose of this callback is to provide a way for plugins to disable/enable
+    /// transformations. In some cases plugins may want not to execute some transformations.
+    /// For example plugin can disable unpleasant decompositions because of performance reasons.
+    /// Callback example:
+    /// auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
+    ///     return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
+    /// };
+    /// This callback returns true in case of DepthToSpace operation. So when execution DepthToSpace
+    /// decomposition pass will check is this decomposition needed or plugin can execute this
+    /// operation directly. And of course on transformation side we need to have a response for this
+    /// callback.
+    /// if (m_transformation_callback(batch_to_space)) {
+    ///     return false;
+    /// }
+    /// \param callback lamda function that returns true in case if node is supported by plugin and
+    /// transformation is not needed
+    void set_callback(param_callback callback)
+    {
+        m_transformation_callback = callback;
+        m_has_default_callback = false;
+    }
+
 private:
     template <typename T, class... Args>
     std::shared_ptr<T> push_pass(Args&&... args)
@@ -71,23 +94,14 @@ private:
         auto pass = std::make_shared<T>(std::forward<Args>(args)...);
         auto pass_base = std::static_pointer_cast<PassBase>(pass);
         m_pass_list.push_back(pass_base);
-        if (m_visualize || m_serialize)
-        {
-#ifdef _WIN32
-            // MSVC produce a human-readable type name like class ngraph::pass::LikeReplacement
-            // by typeid(T).name(). Later ofstream doesn't accept it as a valid file name.
-            //
-            std::string str = typeid(T).name();
-            auto pos = str.find_last_of(":");
-            m_pass_names.push_back(str.substr(pos + 1));
-#elif defined(__linux) || defined(__APPLE__)
-            m_pass_names.push_back(typeid(T).name());
-#endif
-        }
         return pass;
     }
 
-    std::vector<std::string> m_pass_names;
+    param_callback m_transformation_callback = [](const std::shared_ptr<const Node>&) -> bool {
+        return false;
+    };
+    bool m_has_default_callback = true;
+
     std::vector<std::shared_ptr<PassBase>> m_pass_list;
     ManagerState m_state;
     PassConfig m_pass_config;
index da0720e..0d796db 100644 (file)
 // limitations under the License.
 //*****************************************************************************
 
-#include "ngraph/pass/pass.hpp"
+#ifdef _WIN32
+#else
+#include <cxxabi.h>
+#endif
+
 #include "ngraph/pass/manager.hpp"
+#include "ngraph/pass/pass.hpp"
 
 using namespace std;
 using namespace ngraph;
@@ -52,6 +57,30 @@ void pass::PassBase::set_property(const PassPropertyMask& prop, bool value)
     }
 }
 
+std::string pass::PassBase::get_name() const
+{
+    if (m_name.empty())
+    {
+        const PassBase* p = this;
+        std::string pass_name = typeid(*p).name();
+#ifndef _WIN32
+        int status;
+        pass_name = abi::__cxa_demangle(pass_name.c_str(), nullptr, nullptr, &status);
+#endif
+        return pass_name;
+    }
+    else
+    {
+        return m_name;
+    }
+}
+
+void pass::PassBase::set_callback(const param_callback& callback)
+{
+    m_transformation_callback = callback;
+    m_has_default_callback = false;
+}
+
 // The symbols are requiered to be in cpp file to workaround RTTI issue on Android LLVM
 
 pass::ModulePass::~ModulePass()
index a55d275..b9bd6e6 100644 (file)
@@ -20,6 +20,7 @@
 #include <memory>
 #include <vector>
 
+#include "ngraph/deprecated.hpp"
 #include "ngraph/function.hpp"
 #include "ngraph/node.hpp"
 #include "ngraph/pass/manager_state.hpp"
@@ -32,8 +33,8 @@ namespace ngraph
         class PassBase;
         class ModulePass;
         class FunctionPass;
-        class NodePass;
-        class CallGraphPass;
+        class NodePass NGRAPH_DEPRECATED("Use MatcherPass or FunctionPass instead.");
+        class CallGraphPass NGRAPH_DEPRECATED("Use MatcherPass or FunctionPass instead.");
         class Manager;
         enum class FusionType : uint32_t
         {
@@ -53,8 +54,10 @@ namespace ngraph
             // Pass requires node shapes to be static
             REQUIRE_STATIC_SHAPE = 0x1,
             // Pass transformation will change the function's dynamic state
-            CHANGE_DYNAMIC_STATE = 1 << 1
+            CHANGE_DYNAMIC_STATE = 1 << 1,
         };
+
+        using param_callback = std::function<bool(const std::shared_ptr<const ::ngraph::Node>)>;
     }
 }
 
@@ -79,14 +82,25 @@ public:
     /// Check if this pass has all the pass properties.
     bool get_property(const PassPropertyMask& prop_mask) const;
 
+    void set_name(const std::string& name) { m_name = name; }
+    std::string get_name() const;
+
+    void set_callback(const param_callback& callback);
+
 protected:
     ManagerState& get_state();
     void set_state(ManagerState&);
     void set_property(const PassPropertyMask& prop, bool value);
 
+    param_callback m_transformation_callback = [](const std::shared_ptr<const Node>&) -> bool {
+        return false;
+    };
+    bool m_has_default_callback = true;
+
 private:
     PassPropertyMask m_property;
     ManagerState* m_state{nullptr};
+    std::string m_name;
 };
 
 class NGRAPH_API ngraph::pass::ModulePass : public PassBase
index 0d10f2d..af43176 100644 (file)
@@ -227,6 +227,9 @@ bool pass::VisualizeTree::run_on_module(vector<shared_ptr<Function>>& functions)
 
     render();
 
+    // Clean up local variable not to hold node pointers
+    m_nodes_with_attributes.clear();
+
     return false;
 }
 
index 381ef44..d4a2422 100644 (file)
@@ -221,8 +221,6 @@ namespace ngraph
 
         bool Matcher::match(const Output<Node>& graph_value)
         {
-            // clear our state
-            m_matched_list.clear();
             return match(graph_value, PatternValueMap{});
         }
 
@@ -230,10 +228,7 @@ namespace ngraph
         bool Matcher::match(const Output<Node>& graph_value,
                             const PatternValueMap& previous_matches)
         {
-            // clear our state
-            m_match_root.reset();
-            m_pattern_map.clear();
-            m_matched_list.clear();
+            clear_state();
 
             // insert previous matches
             m_pattern_map.insert(previous_matches.cbegin(), previous_matches.cend());
@@ -251,6 +246,14 @@ namespace ngraph
             return match(graph_value, as_pattern_value_map(previous_matches));
         }
 
+        void Matcher::clear_state()
+        {
+            m_match_root.reset();
+            m_pattern_map.clear();
+            m_pattern_value_maps.clear();
+            m_matched_list.clear();
+        }
+
         namespace
         {
             std::set<std::shared_ptr<Node>>
index a0ce67e..9ba73f8 100644 (file)
@@ -176,6 +176,8 @@ namespace ngraph
 
             void capture(const std::set<Node*>& static_nodes);
 
+            void clear_state();
+
             size_t get_number_of_recurrent_matches() const { return m_pattern_value_maps.size(); }
             NodeVector get_bound_nodes_for_pattern(const Output<Node>& pattern) const;
             size_t get_number_of_bound_labels() const;
index a42a34d..fbc9d8a 100644 (file)
@@ -60,3 +60,8 @@ bool pattern::op::Label::match_value(Matcher* matcher,
     }
     return false;
 }
+
+std::shared_ptr<Node> pattern::any_input()
+{
+    return std::make_shared<pattern::op::Label>();
+}
\ No newline at end of file
index 85d4f2e..e621ad0 100644 (file)
@@ -61,7 +61,8 @@ namespace ngraph
                     set_output_type(0, type, s);
                 }
 
-                Label(const element::Type& type, const PartialShape& s)
+                explicit Label(const element::Type& type = element::dynamic,
+                               const PartialShape& s = PartialShape::dynamic())
                     : Label(type, s, [](const Output<Node>&) { return true; }, OutputVector())
                 {
                 }
@@ -141,5 +142,8 @@ namespace ngraph
                 static Output<Node> wrap_values(const OutputVector& wrapped_values);
             };
         }
+
+        NGRAPH_API
+        std::shared_ptr<Node> any_input();
     }
 }
index 6f6223f..87e078c 100644 (file)
@@ -61,5 +61,11 @@ namespace ngraph
             }
             return result;
         }
+
+        std::function<bool(Output<Node>)> consumers_count(size_t n)
+        {
+            return
+                [=](Output<Node> output) -> bool { return output.get_target_inputs().size() == n; };
+        }
     }
 }
index 0ec1e51..ca23d39 100644 (file)
@@ -49,6 +49,9 @@ namespace ngraph
             return pred;
         }
 
+        NGRAPH_API
+        std::function<bool(Output<Node>)> consumers_count(size_t n);
+
         namespace op
         {
             using NodePredicate = std::function<bool(std::shared_ptr<Node>)>;
diff --git a/ngraph/src/ngraph/pattern/op/wrap_type.cpp b/ngraph/src/ngraph/pattern/op/wrap_type.cpp
new file mode 100644 (file)
index 0000000..bb639d0
--- /dev/null
@@ -0,0 +1,46 @@
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#include "ngraph/pattern/op/wrap_type.hpp"
+#include "ngraph/pattern/matcher.hpp"
+
+using namespace std;
+using namespace ngraph;
+
+constexpr NodeTypeInfo pattern::op::WrapType::type_info;
+
+const NodeTypeInfo& pattern::op::WrapType::get_type_info() const
+{
+    return type_info;
+}
+
+bool pattern::op::WrapType::match_value(Matcher* matcher,
+                                        const Output<Node>& pattern_value,
+                                        const Output<Node>& graph_value)
+{
+    if (graph_value.get_node_shared_ptr()->get_type_info() == get_wrapped_type() &&
+        m_predicate(graph_value))
+    {
+        auto& pattern_map = matcher->get_pattern_value_map();
+        pattern_map[shared_from_this()] = graph_value;
+        matcher->add_node(graph_value);
+        return (get_input_size() == 0
+                    ? true
+                    : matcher->match_arguments(pattern_value.get_node(),
+                                               graph_value.get_node_shared_ptr()));
+    }
+    return false;
+}
diff --git a/ngraph/src/ngraph/pattern/op/wrap_type.hpp b/ngraph/src/ngraph/pattern/op/wrap_type.hpp
new file mode 100644 (file)
index 0000000..c65f8a0
--- /dev/null
@@ -0,0 +1,74 @@
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#pragma once
+
+#include "ngraph/node.hpp"
+#include "ngraph/pattern/op/pattern.hpp"
+
+namespace ngraph
+{
+    namespace pattern
+    {
+        namespace op
+        {
+            class NGRAPH_API WrapType : public Pattern
+            {
+            public:
+                static constexpr NodeTypeInfo type_info{"patternAnyType", 0};
+                const NodeTypeInfo& get_type_info() const override;
+
+                explicit WrapType(NodeTypeInfo wrapped_type,
+                                  const ValuePredicate& pred =
+                                      [](const Output<Node>& output) { return true; },
+                                  const OutputVector& input_values = {})
+                    : Pattern(input_values, pred)
+                    , m_wrapped_type(wrapped_type)
+                {
+                    set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic());
+                }
+
+                bool match_value(pattern::Matcher* matcher,
+                                 const Output<Node>& pattern_value,
+                                 const Output<Node>& graph_value) override;
+
+                NodeTypeInfo get_wrapped_type() const { return m_wrapped_type; }
+            private:
+                NodeTypeInfo m_wrapped_type;
+            };
+        }
+
+        template <class T>
+        std::shared_ptr<Node> wrap_type(const OutputVector& inputs,
+                                        const pattern::op::ValuePredicate& pred)
+        {
+            static_assert(std::is_base_of<Node, T>::value, "Unexpected template type");
+            return std::make_shared<op::WrapType>(T::type_info, pred, inputs);
+        }
+
+        template <class T>
+        std::shared_ptr<Node> wrap_type(const OutputVector& inputs = {})
+        {
+            return wrap_type<T>(inputs, [](const Output<Node>& output) { return true; });
+        }
+
+        template <class T>
+        std::shared_ptr<Node> wrap_type(const pattern::op::ValuePredicate& pred)
+        {
+            return wrap_type<T>({}, pred);
+        }
+    }
+}
index 537443b..0b01eb0 100644 (file)
@@ -64,10 +64,12 @@ set(SRC
     eval.cpp
     file_util.cpp
     float16.cpp
+    graph_rewrite.cpp
     includes.cpp
     input_output_assign.cpp
     intervals.cpp
     main.cpp
+    matcher_pass.cpp
     misc.cpp
     ngraph_api.cpp
     node_input_output.cpp
index 4ff3c63..5c8b022 100644 (file)
@@ -2789,13 +2789,6 @@ TEST(constant_folding, constant_tile_0_rank_data)
     ASSERT_EQ(values_expected, values_out);
 }
 
-TEST(constant_folding, pass_property)
-{
-    auto pass = std::make_shared<ngraph::pass::ConstantFolding>();
-    ASSERT_FALSE(pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
-    ASSERT_TRUE(pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
-}
-
 TEST(constant_folding, constant_non_zero_0D)
 {
     auto data = op::Constant::create(element::i32, Shape{}, {1});
diff --git a/ngraph/test/graph_rewrite.cpp b/ngraph/test/graph_rewrite.cpp
new file mode 100644 (file)
index 0000000..2f12e0e
--- /dev/null
@@ -0,0 +1,119 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+#include <ngraph/opsets/opset3.hpp>
+#include <ngraph/pass/graph_rewrite.hpp>
+#include <ngraph/pass/manager.hpp>
+#include <util/test_tools.hpp>
+
+using namespace ::testing;
+using namespace std;
+using namespace ngraph;
+
+class TestPass : public ngraph::pass::MatcherPass
+{
+public:
+    TestPass()
+        : MatcherPass()
+    {
+        auto divide = std::make_shared<ngraph::pattern::op::Label>(
+            element::f32, Shape{}, pattern::has_class<opset3::Divide>());
+        ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
+            if (m_transformation_callback(m.get_match_root()))
+            {
+                auto relu =
+                    std::make_shared<ngraph::opset3::Relu>(m.get_match_root()->input_value(0));
+                ngraph::replace_node(m.get_match_root(), relu);
+                return true;
+            }
+            return false;
+        };
+
+        auto m = std::make_shared<ngraph::pattern::Matcher>(divide, "TestMatcher");
+        this->register_matcher(m, callback);
+    }
+};
+
+class Anchor : public ngraph::pass::GraphRewrite
+{
+public:
+    Anchor()
+        : GraphRewrite()
+    {
+    }
+};
+
+std::shared_ptr<Function> get_function()
+{
+    auto data =
+        std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
+    auto divide_constant =
+        ngraph::opset3::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1.5});
+    auto divide = std::make_shared<ngraph::opset3::Divide>(data, divide_constant);
+    return std::make_shared<ngraph::Function>(ngraph::NodeVector{divide},
+                                              ngraph::ParameterVector{data});
+}
+
+ngraph::pass::param_callback get_callback()
+{
+    return [](const std::shared_ptr<const Node>& node) -> bool {
+        if (std::dynamic_pointer_cast<const opset3::Divide>(node))
+        {
+            return true;
+        }
+        else
+        {
+            return false;
+        }
+    };
+}
+
+TEST(GraphRewriteTest, MatcherPassCallback)
+{
+    auto f = get_function();
+
+    Anchor anchor;
+    anchor.add_matcher<TestPass>()->set_callback(get_callback());
+    anchor.run_on_function(f);
+
+    ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
+}
+
+TEST(GraphRewriteTest, GraphRewriteCallback)
+{
+    auto f = get_function();
+
+    Anchor anchor;
+    anchor.add_matcher<TestPass>();
+    anchor.set_callback(get_callback());
+    anchor.run_on_function(f);
+
+    ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
+}
+
+TEST(GraphRewriteTest, ManagerCallback)
+{
+    auto f = get_function();
+
+    pass::Manager manager;
+    auto anchor = manager.register_pass<Anchor>();
+    anchor->add_matcher<TestPass>();
+    manager.set_callback(get_callback());
+    manager.run_passes(f);
+
+    ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
+}
+
+TEST(GraphRewriteTest, ManagerCallback2)
+{
+    auto f = get_function();
+
+    pass::Manager manager;
+    auto anchor = manager.register_pass<TestPass>();
+    manager.set_callback(get_callback());
+    manager.run_passes(f);
+
+    ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
+}
diff --git a/ngraph/test/matcher_pass.cpp b/ngraph/test/matcher_pass.cpp
new file mode 100644 (file)
index 0000000..be2dadf
--- /dev/null
@@ -0,0 +1,120 @@
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#include <algorithm>
+#include <cstdio>
+#include <iostream>
+#include <list>
+#include <memory>
+#include <ngraph/pattern/op/wrap_type.hpp>
+#include <ngraph/rt_info.hpp>
+
+#include "gtest/gtest.h"
+#include "ngraph/graph_util.hpp"
+#include "ngraph/log.hpp"
+#include "ngraph/ngraph.hpp"
+#include "ngraph/opsets/opset3.hpp"
+#include "ngraph/pass/graph_rewrite.hpp"
+#include "ngraph/pass/manager.hpp"
+
+using namespace ngraph;
+using namespace std;
+
+class TestMatcherPass : public pass::MatcherPass
+{
+public:
+    TestMatcherPass()
+    {
+        auto m_relu1 =
+            ngraph::pattern::wrap_type<ngraph::opset3::Relu>(pattern::consumers_count(1));
+        auto m_relu2 = ngraph::pattern::wrap_type<ngraph::opset3::Relu>({m_relu1});
+
+        ngraph::graph_rewrite_callback callback = [=](pattern::Matcher& m) {
+            // Map that helps to connect labels with matched outputs
+            auto& node_to_output = m.get_pattern_value_map();
+
+            // Create new Relu operation and add register it for additional execution
+            auto new_relu = register_new_node<ngraph::opset3::Relu>(
+                node_to_output.at(m_relu1).get_node_shared_ptr()->input_value(0));
+
+            // Copy runtime info attributes to newly created operation
+            ngraph::copy_runtime_info(m.get_matched_nodes(), new_relu);
+
+            // Save last Relu name to new Relu operation
+            new_relu->set_friendly_name(m.get_match_root()->get_friendly_name());
+
+            // Replace Relu->Relu with Relu
+            ngraph::replace_node(m.get_match_root(), new_relu);
+
+            // Return true as the root node was changed
+            return true;
+        };
+
+        // Register pattern with Divide operation as a pattern root node
+        auto m = std::make_shared<ngraph::pattern::Matcher>(m_relu2, "ReluReluFusion");
+        // Register Matcher
+        this->register_matcher(m, callback);
+    }
+};
+
+TEST(pattern, matcher_pass)
+{
+    {
+        TestMatcherPass test_matcher;
+        auto a = make_shared<opset3::Parameter>(element::f32, Shape{1});
+        auto b = make_shared<opset3::Relu>(a);
+        auto c = make_shared<opset3::Relu>(b);
+        auto f = std::make_shared<Function>(ngraph::NodeVector{c}, ParameterVector{a});
+
+        ASSERT_TRUE(test_matcher.get_matcher()->match(c->output(0)));
+        ASSERT_TRUE(test_matcher.get_matcher()->get_matched_nodes().size() == 2);
+        test_matcher.get_matcher()->clear_state();
+        ASSERT_TRUE(test_matcher.get_matcher()->get_matched_nodes().empty());
+
+        test_matcher.apply(c);
+        ASSERT_TRUE(test_matcher.get_new_nodes().size() == 1);
+        test_matcher.apply(test_matcher.get_new_nodes()[0]);
+        ASSERT_TRUE(test_matcher.get_new_nodes().empty());
+    }
+
+    {
+        TestMatcherPass test_matcher;
+        auto a = make_shared<opset3::Parameter>(element::f32, Shape{1});
+        auto b = make_shared<opset3::Relu>(a);
+        auto c = make_shared<opset3::Relu>(b);
+        auto f = std::make_shared<Function>(ngraph::NodeVector{b, c}, ParameterVector{a});
+
+        ASSERT_FALSE(test_matcher.get_matcher()->match(c->output(0)));
+    }
+
+    {
+        std::shared_ptr<Function> f;
+        {
+            auto a = make_shared<opset3::Parameter>(element::f32, Shape{1});
+            auto b = make_shared<opset3::Relu>(a);
+            auto c = make_shared<opset3::Relu>(b);
+            auto d = make_shared<opset3::Relu>(c);
+            f = std::make_shared<Function>(ngraph::NodeVector{d}, ParameterVector{a});
+        }
+
+        pass::GraphRewrite pass;
+        pass.add_matcher<TestMatcherPass>();
+        pass.run_on_function(f);
+
+        // Parameter->Relu->Result
+        ASSERT_TRUE(f->get_ops().size() == 3);
+    }
+}
\ No newline at end of file
index d63c887..b542775 100644 (file)
@@ -19,6 +19,7 @@
 #include <iostream>
 #include <list>
 #include <memory>
+#include <ngraph/pattern/op/wrap_type.hpp>
 
 #include "gtest/gtest.h"
 #include "ngraph/file_util.hpp"
@@ -576,7 +577,7 @@ TEST(pattern, recurrent_pattern)
     std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
     RecurrentMatcher rm(padd, rpattern, empty_correlated_matches);
     ASSERT_TRUE(rm.match(add3));
-    ASSERT_EQ(rm.get_number_of_bound_labels(), 1);
+    ASSERT_EQ(rm.get_number_of_bound_labels(), 3);
     auto recurrent_matches = rm.get_bound_nodes_for_pattern(rpattern);
     ASSERT_EQ(recurrent_matches.at(0), add2);
     ASSERT_EQ(recurrent_matches.at(1), add1);
@@ -590,7 +591,7 @@ TEST(pattern, recurrent_pattern)
     auto padd2 = iconst_label + rpattern;
     RecurrentMatcher rm2(padd2, rpattern, empty_correlated_matches);
     ASSERT_TRUE(rm2.match(add3_2));
-    ASSERT_EQ(rm2.get_number_of_bound_labels(), 2);
+    ASSERT_EQ(rm2.get_number_of_bound_labels(), 4);
     recurrent_matches = rm2.get_bound_nodes_for_pattern(rpattern);
     ASSERT_EQ(recurrent_matches.at(0), add2_2);
     ASSERT_EQ(recurrent_matches.at(1), add1);
@@ -605,7 +606,7 @@ TEST(pattern, recurrent_pattern)
     correlated_matches.insert(iconst_label);
     RecurrentMatcher rm3(padd2, rpattern, correlated_matches);
     ASSERT_TRUE(rm3.match(add3_2));
-    ASSERT_EQ(rm3.get_number_of_bound_labels(), 2);
+    ASSERT_EQ(rm3.get_number_of_bound_labels(), 4);
     iconst_matches = rm3.get_bound_nodes_for_pattern(iconst_label);
     ASSERT_EQ(iconst_matches.size(), 1);
     ASSERT_EQ(iconst_matches.at(0), iconst0);
@@ -613,7 +614,7 @@ TEST(pattern, recurrent_pattern)
     // Matching correlated labels and
     // testing if RecurrentMatcher can be reused for different nodes
     ASSERT_TRUE(rm3.match(add3));
-    ASSERT_EQ(rm3.get_number_of_bound_labels(), 2);
+    ASSERT_EQ(rm3.get_number_of_bound_labels(), 4);
     recurrent_matches = rm3.get_bound_nodes_for_pattern(rpattern);
     ASSERT_EQ(recurrent_matches.at(0), add2);
     ASSERT_EQ(recurrent_matches.at(1), add1);
@@ -763,3 +764,46 @@ TEST(pattern, is_contained_match)
     ASSERT_TRUE(n.match(label_abs2, absn2));
     ASSERT_FALSE(n.is_contained_match());
 }
+
+TEST(pattern, wrap_type)
+{
+    auto a = make_shared<op::Parameter>(element::f32, Shape{1, 3, 64, 64});
+    auto b = make_shared<op::Abs>(a);
+    auto c = make_shared<op::Relu>(a);
+    auto mul1 = make_shared<op::v1::Multiply>(a, op::Constant::create(element::f32, Shape{}, {1}));
+    auto mul2 = make_shared<op::v1::Multiply>(op::Constant::create(element::f32, Shape{}, {1}), a);
+
+    {
+        auto m = pattern::wrap_type<op::Abs>();
+        auto matcher = std::make_shared<pattern::Matcher>(m, "AbsMatcher");
+        ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(b)));
+        ASSERT_EQ(matcher->get_matched_nodes().size(), 1);
+        ASSERT_EQ(matcher->get_matched_nodes()[0], b);
+        ASSERT_EQ(matcher->get_pattern_map().count(m), 1);
+        ASSERT_FALSE(matcher->match(static_pointer_cast<Node>(c)));
+    }
+    {
+        auto m1 = pattern::wrap_type<op::Parameter>();
+        auto m2 = pattern::wrap_type<op::Abs>({m1});
+        auto matcher = std::make_shared<pattern::Matcher>(m2, "ParamAbsMatcher");
+        ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(b)));
+        ASSERT_EQ(matcher->get_matched_nodes().size(), 2);
+        ASSERT_EQ(matcher->get_pattern_map().count(m1), 1);
+        ASSERT_EQ(matcher->get_pattern_map().count(m2), 1);
+        ASSERT_FALSE(matcher->match(static_pointer_cast<Node>(c)));
+    }
+    {
+        auto m1 = pattern::wrap_type<op::v1::Multiply>(
+            {pattern::any_input(), pattern::wrap_type<op::Constant>()});
+        auto matcher = std::make_shared<pattern::Matcher>(m1, "MultiplyMatcher");
+        ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(mul1)));
+        ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(mul2)));
+    }
+    {
+        auto m1 = pattern::wrap_type<op::v1::Multiply>(
+            {pattern::wrap_type<op::Constant>(), pattern::any_input()});
+        auto matcher = std::make_shared<pattern::Matcher>(m1, "MultiplyMatcher");
+        ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(mul1)));
+        ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(mul2)));
+    }
+}
\ No newline at end of file