This guide contains all the necessary information to help you start writing nGraph transformations.
-First of all before writing transformation make sure that there is no transformation with the same functionality in [Transformation Library](group__ie__transformation__api.html).
-To start writing transformation it's good to know how [Transformation Library](group__ie__transformation__api.html) is structured, how transformations are organized and where to put your transformation code.
+First of all, before writing a transformation, make sure that there is no transformation with the same functionality
+in the [Transformation Library](group__ie__transformation__api.html). To start writing a transformation, it is good to know
+how the [Transformation Library](group__ie__transformation__api.html) is structured, how transformations are organized,
+and where to put your transformation code.
Let's start by reviewing the transformation library structure.
-Transformations library is independent from InferenceEngine target library named as `inference_engine_transformations` and located in `inference-engine/src/transformations` directory.
+The transformation library is a separate target library, independent from InferenceEngine. It is named `inference_engine_transformations`
+and is located in the `inference-engine/src/transformations` directory.
+
The transformations root directory contains two folders:
1. ngraph_ops - legacy opset operations needed for nGraph to CNNNetwork conversion.
> **Note**: these operations are prohibited from use inside new plugins until they are moved to a separate directory with allowed operations.
> **Note**: do not use transformations that belong to `ngraph::pass::ConvertOpSet1ToLegacy` until they are moved to a separate directory with allowed transformations.
The transformation flow in the transformation library has several layers:
-1. Pass managers - executes list of transformations using `*_tbl.hpp` file. For example conversion form OpSetX to OpSetY.
-2. Transformations - performs particular transformation algorithm on `ngraph::Function`. Find more about transformations in [Transformations types](#transformations_types).
-3. Low level functions that takes set of nodes and performs some transformation action. They are not mandatory and all transformation code can be located inside transformation. But if some transformation parts can potentially be reused in other transformations we suggest to keep them as a separate functions.
+1. Pass managers - execute any type of transformation and provide additional debug capabilities.
+2. Transformations - perform a particular transformation algorithm on `ngraph::Function`.
+3. Low-level functions - take a set of nodes and perform some transformation action.
+They are not mandatory, and all transformation code can be located inside a transformation.
+But if some parts of a transformation can potentially be reused in other transformations, we suggest keeping them as separate functions.
To decide where to store your transformation code, please follow these rules:
1. If it's a plugin-specific transformation and can't be reused by other plugins, keep the source code inside the plugin.
## Table of Contents:
-1. [`ngraph::Function` and graph representation](#ngraph_function)
-2. [Transformations types](#transformations_types)
-3. [Pattern matching](#pattern_matching)
-4. [Working with ngraph::Function](#working_with_ngraph_function)
-5. [Transformation writing essentials](#transformation_writing_essentials)
-6. [Common mistakes in transformations](#common_mistakes)
-7. [Using pass manager](#using_pass_manager)
-8. [How to debug transformations](#how_to_debug_transformations)
-9. [Disabling/Enabling specific transformations for plugin X](#disabling_transformation)
-10. [Transformations testing](#transformations_testing)
+### 1. [`ngraph::Function` and graph representation](#ngraph_function)
+### 2. [Transformations types](#transformations_types)
+### 2.1 [Function pass](#function_pass)
+### 2.2 [Matcher pass](#matcher_pass)
+### 2.3 [GraphRewrite pass](#graph_rewrite_pass)
+### 3. [Pattern matching](#pattern_matching)
+### 4. [Working with ngraph::Function](#working_with_ngraph_function)
+### 5. [Transformation writing essentials](#transformation_writing_essentials)
+### 6. [Common mistakes in transformations](#common_mistakes)
+### 7. [Using pass manager](#using_pass_manager)
+### 8. [How to debug transformations](#how_to_debug_transformations)
+### 9. [Disabling/Enabling specific transformations for plugin X](#disabling_transformation)
+### 10. [Transformations testing](#transformations_testing)
## ngraph::Function and graph representation <a name="ngraph_function"></a>
## Transformations types <a name="transformations_types"></a>
-There are two main transformation types:
+nGraph has three main transformation types: `ngraph::pass::FunctionPass` - a straightforward way to work with `ngraph::Function` directly;
+`ngraph::pass::MatcherPass` - a pattern-based transformation approach; `ngraph::pass::GraphRewrite` - a container for matcher passes.
-###1. ngraph::pass::FunctionalPass
+### 1. ngraph::pass::FunctionPass <a name="function_pass"></a>
-ngraph::pass::FunctionalPass is used for transformations that take entire `ngraph::Function` as input and process it.
+`ngraph::pass::FunctionPass` is used for transformations that take entire `ngraph::Function` as input and process it.
Template for FunctionPass transformation class
@snippet src/template_function_transformation.cpp function_pass:template_transformation_cpp
-Using `ngraph::FunctionPass` you need to override `run_on_function` method where you will write transformation code. Return value must be `true` if original function has changed during transformation (new operation were added or operations replacement was made or node attributes were changed) otherwise it must be `false`. For transformation API please follow [working with ngraph::Function](#working_with_ngraph_function) section.
+Using `ngraph::FunctionPass`, you need to override the `run_on_function` method where you will write the transformation code.
+The return value must be `true` if the original function has changed during the transformation (new operations were added, operations were replaced, or node attributes were changed); otherwise it must be `false`.
+For the transformation API, please follow the [working with ngraph::Function](#working_with_ngraph_function) section.
+Also, `ngraph::FunctionPass`-based transformations can be executed via `pass::Manager`. See examples in the [Using pass manager](#using_pass_manager) section.
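+
+To illustrate the return-value contract described above, here is a minimal sketch of a hypothetical `ngraph::pass::FunctionPass` (the class name and the renaming logic are made up for illustration and are not part of the template above):
+
+```cpp
+#include <string>
+
+#include <ngraph/ngraph.hpp>
+#include <ngraph/pass/pass.hpp>
+
+// Hypothetical pass: renames every operation and reports that the function changed
+class RenameAllNodes : public ngraph::pass::FunctionPass {
+public:
+    bool run_on_function(std::shared_ptr<ngraph::Function> f) override {
+        size_t id = 0;
+        for (const auto& node : f->get_ordered_ops()) {
+            node->set_friendly_name("node_" + std::to_string(id++));
+        }
+        // Node attributes were changed, so report the modification
+        return true;
+    }
+};
+```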
-###2. ngraph::pass::GraphRewrite
+### 2. ngraph::pass::MatcherPass <a name="matcher_pass"></a>
-`ngraph::pass::GraphRewrite` is used for pattern based transformations.
+`ngraph::pass::MatcherPass` is used for pattern based transformations.
-Template for GraphRewrite transformation class
+Template for MatcherPass transformation class
@snippet src/template_pattern_transformation.hpp graph_rewrite:template_transformation_hpp
@snippet src/template_pattern_transformation.cpp graph_rewrite:template_transformation_cpp
-Using `ngraph::GraphRewrite` you need to complete three steps:
-1. Create pattern using nGraph operations.
-2. Implement callback.
-3. Register pattern and Matcher.
+Using `ngraph::pass::MatcherPass`, you need to complete these steps:
+1. Create pattern
+2. Implement callback
+3. Register pattern and Matcher
+4. MatcherPass execution
So let's go through each of these steps.
-Pattern is a single root `ngraph::Function`. But the only difference is that you don't need to create function object, you just create and connect nGraph operations then take the last one and put it as a root of the pattern.
+### Create pattern
+A pattern is essentially a single-root `ngraph::Function`. The only difference is that you don't need to create a function object; you just create and connect nGraph or special pattern operations.
+Then you take the last created operation and use it as the root of the pattern. This root node will be used as the root node in pattern matching.
+> **Note**: any nodes in a pattern that have no consumers and are not registered as the root won't be used in pattern matching.
@snippet example_ngraph_utils.cpp pattern:simple_example
-You may have noticed that `Parameter` operation in example has type and shape specified. These attributes are needed only to create Parameter operation class and not used in pattern matching.
-But what if we want to match pattern where `ShapeOf` takes any operation as input? To find an answer to this question please follow [pattern matching](#pattern_matching) section.
+You may have noticed that the `Parameter` operation in the example has its type and shape specified. These attributes are needed only to create the Parameter operation class and won't be used in pattern matching.
+
+But what if we want to match a pattern where `ShapeOf` takes any operation as input? To find the answer, please follow the [pattern matching](#pattern_matching) section.
-What is callback? Callback is an action applied to every pattern entrance. In general callback is lambda function that takes Matcher object with detected sub-graph.
+### Implement callback
+A callback is an action applied to every pattern match. In general, a callback is a lambda function that takes a Matcher object holding the detected sub-graph.
@snippet example_ngraph_utils.cpp pattern:callback_example
The example above shows the callback structure and how the Matcher can be used to access nodes detected by the pattern.
-Callback return value must be `true` if root node was replaced and next pattern can't be applied to the same root node otherwise it must be `false`.
+The callback return value must be `true` if the root node was replaced and another pattern can't be applied to the same root node; otherwise it must be `false`.
+> **Note**: it's not recommended to manipulate nodes that come after the root node. This may affect GraphRewrite execution, as it is expected that all nodes coming after the root node in topological order are valid and can be used in pattern matching.
+
+MatcherPass also provides functionality that allows you to report which newly created nodes can be used in additional pattern matching.
+If the MatcherPass was registered in `pass::Manager` or `pass::GraphRewrite`, these registered nodes will be added for additional pattern matching.
+That means that matcher passes registered in `pass::GraphRewrite` will be applied to these nodes.
+
+The example below shows how a single MatcherPass can fuse a sequence of operations using the `register_new_node` method.
+
+@snippet src/template_pattern_transformation.cpp matcher_pass:relu_fusion
+
+> **Note**: if you register multiple nodes, please add them in topological order. We do not topologically sort these nodes, as it is a time-consuming operation.
-And the last step is to register Matcher and callback inside GraphRewrite pass. And to do this you need to call `add_matcher` method.
+### Register pattern and Matcher
+The last step is to register the Matcher and callback inside the MatcherPass. To do this, you need to call the `register_matcher` method.
+> **Note**: Only one matcher can be registered for a single MatcherPass class.
```cpp
// Register matcher and callback
-this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+this->register_matcher(m, callback);
```
+### Matcher pass execution
+MatcherPass can be executed in multiple ways:
+1. Run on a single node - useful if you want to run a MatcherPass inside another transformation.
+@snippet src/template_pattern_transformation.cpp matcher_pass:run_on_node
+2. Run on `ngraph::Function` using GraphRewrite - this approach gives the ability to run a MatcherPass on the whole `ngraph::Function`. Moreover, multiple MatcherPass transformations can be registered in a single GraphRewrite to be executed in a single graph traversal.
+@snippet src/template_pattern_transformation.cpp matcher_pass:graph_rewrite
+3. Run on `ngraph::Function` using `pass::Manager` - this approach lets you register a MatcherPass for execution on `ngraph::Function` like any other transformation type.
+@snippet src/template_pattern_transformation.cpp matcher_pass:manager
-Also you can have multiple matchers and callbacks and they can be registered in single Graphrewrite pass. In this case all registered patterns will be applied in a singe graph traversal.
-```cpp
-// Multiple matchers example
-this->add_matcher(m1, callback1, PassProperty::CHANGE_DYNAMIC_STATE);
-this->add_matcher(m2, callback2, PassProperty::CHANGE_DYNAMIC_STATE);
-```
+### 3. ngraph::pass::GraphRewrite <a name="graph_rewrite_pass"></a>
-The last argument `PassProperty::CHANGE_DYNAMIC_STATE` says that callback can be applied for ngraph::Function with dynamic shapes. In case if callback does not support dynamic shapes `PassProperty::REQUIRE_STATIC_SHAPE` can be used.
-> **Note**: property mechanism will be deprecated soon and PassProperty::CHANGE_DYNAMIC_STATE is suggested to be used by default.
+The GraphRewrite pass is used to run multiple matcher passes on `ngraph::Function` in a single graph traversal.
+Example:
+
+@snippet src/template_pattern_transformation.cpp matcher_pass:graph_rewrite
+
+In addition, GraphRewrite handles nodes that were registered by MatcherPasses during their execution. These nodes will be added to the beginning of the sequence of nodes used for pattern matching.
+
+> **Note**: when using `pass::Manager`, a temporary GraphRewrite is used to execute a single MatcherPass.
-To run any transformation you need to call `un_on_function(f)` method where `f` is `ngraph::Function`.
-```cpp
-ngraph::pass::MyTransformationClass().run_on_function(f);
-```
-
## Pattern matching <a name="pattern_matching"></a>
Sometimes patterns can't be expressed via regular nGraph operations. For example, you may want to detect a Convolution->Add sub-graph without specifying a particular input type for the Convolution operation, or you may want to create a pattern where some operations can have different types.
For these cases nGraph provides additional helpers to construct patterns for GraphRewrite transformations.
+
There are two main helpers:
1. `ngraph::pattern::op::Label` - helps to express inputs if their type is undefined.
2. `ngraph::pattern::op::Any` - helps to express intermediate nodes of pattern if their type is unknown.
@snippet example_ngraph_utils.cpp pattern:concat_example
-This example shows how to use predicate to construct pattern where operation has two different types.
+This example shows how to use a predicate to construct a pattern where an operation can have one of two different types. It also shows how to match a pattern manually on a given node.
@snippet example_ngraph_utils.cpp pattern:predicate_example
-TODO: add examples for ngraph::pattern::op::Any
+> **Note**: be careful with manual matching, because the Matcher object holds the matched nodes. To clear a match, use the `m->clear_state()` method.
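+
+Below is a minimal sketch of manual matching with explicit state cleanup (the helper function name is made up for illustration):
+
+```cpp
+#include <memory>
+
+#include <ngraph/ngraph.hpp>
+#include <ngraph/pattern/matcher.hpp>
+
+// Hypothetical helper: match a pattern rooted at `pattern_root` against `node`
+void match_manually(const std::shared_ptr<ngraph::Node>& pattern_root,
+                    const std::shared_ptr<ngraph::Node>& node) {
+    auto m = std::make_shared<ngraph::pattern::Matcher>(pattern_root, "ManualMatcher");
+    if (m->match(node->output(0))) {
+        // Matched nodes can be inspected via m->get_matched_nodes()
+    }
+    // The Matcher keeps the matched nodes alive until its state is cleared
+    m->clear_state();
+}
+```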
## Working with ngraph::Function <a name="working_with_ngraph_function"></a>
* If you replace a node with another node that produces a different shape, remember that the new shape won't be propagated until the first `validate_nodes_and_infer_types` call for `ngraph::Function`. If you are using `pass::Manager`, it will automatically call this method after each transformation execution (see the sketch after this list).
* Do not forget to call the `ngraph::pass::ConstantFolding` pass if your transformation creates constant sub-graphs.
* Use the latest OpSet if you are not developing a downgrade transformation pass.
+* When developing a callback for `ngraph::pass::MatcherPass`, do not change nodes that come after the root node in topological order.
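+
+A minimal sketch that combines the points above (the helper name and node arguments are placeholders; in real code `pass::Manager` usually performs the validation for you):
+
+```cpp
+#include <memory>
+
+#include <ngraph/ngraph.hpp>
+#include <ngraph/pass/constant_folding.hpp>
+#include <ngraph/pass/manager.hpp>
+
+void replace_and_cleanup(const std::shared_ptr<ngraph::Function>& f,
+                         const std::shared_ptr<ngraph::Node>& old_node,
+                         const std::shared_ptr<ngraph::Node>& new_node) {
+    // Keep the original name and runtime info on the replacement node
+    new_node->set_friendly_name(old_node->get_friendly_name());
+    ngraph::copy_runtime_info(old_node, new_node);
+    ngraph::replace_node(old_node, new_node);
+
+    // Propagate the new shapes explicitly (pass::Manager does this automatically)
+    f->validate_nodes_and_infer_types();
+
+    // Fold constant sub-graphs that the replacement may have created
+    ngraph::pass::Manager manager;
+    manager.register_pass<ngraph::pass::ConstantFolding>();
+    manager.run_passes(f);
+}
+```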
## Using pass manager <a name="using_pass_manager"></a>
`ngraph::pass::Manager` is a container class that can store a list of transformations and execute them. The main idea of this class is to provide a high-level representation for a grouped list of transformations.
-For example `ngraph::pass::CommonOptimizations` pass manager register list of transformation related to common optimizations. Also `ngraph::pass::Manager` after each transformation executes `f->validate_nodes_and_infer_types()` that help to keep function synchronized.
+It can register and apply any [transformation type](#transformations_types) to a function.
In addition, `ngraph::pass::Manager` has extended debug capabilities (find more information in the [how to debug transformations](#how_to_debug_transformations) section).
The example below shows basic usage of `ngraph::pass::Manager`:
-```cpp
-ngraph::pass::Manager pass_manager;
-pass_manager.register_pass<pass::MyTransformationA>();
-pass_manager.register_pass<pass::MyTransformationB>();
-pass_manager.run_passes(f);
-```
-TODO: Advanced pass manager usage.
+@snippet src/template_pattern_transformation.cpp matcher_pass:manager3
+
+Another example shows how multiple matcher passes can be united into a single GraphRewrite:
+
+@snippet src/template_pattern_transformation.cpp matcher_pass:manager2
## How to debug transformations <a name="how_to_debug_transformations"></a>
This topic is mostly related to conversion to the legacy opset and plugins that are based on CNNNetwork, but this mechanism can still be applied in other cases.
Let's suppose that plugin X enabled `opset3::StridedSlice` operation support and you want to disable the `ngraph::pass::ConvertStridedSliceToCrop` transformation for plugin X.
-To do this you need to extend transformation class with `ngraph::pass::PassParam` class. This class extends transformations class with `transformation_callback` that can be set by plugin that uses legacy conversion.
+To do this, you need to create a callback on the plugin side and pass it to the transformation. You also need to update the particular transformation to use this callback.
```cpp
-// Extend transformation class with PassParam
-class ngraph::pass::ConvertStridedSliceToCrop: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam {
- ...
-}
-
-// Update callback to be able to use transformation_callback if this transformation based on GraphRewrite.
+// Update callback to be able to use m_transformation_callback if this transformation is based on GraphRewrite.
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher &m) {
...
}
-// Use transformation_callback not to execute transformation
-if (transformation_callback(node)) {
+// Use m_transformation_callback to skip the transformation if the callback returns true for the given node
+if (m_transformation_callback(node)) {
return false;
}
-```
-TODO: link to existing example
+// Implement transformation callback and pass it directly to transformation or pass::Manager
+const auto transformations_callback = [](const std::shared_ptr<const ::ngraph::Node> &node) -> bool {
+ return std::dynamic_pointer_cast<const ::ngraph::opset3::StridedSlice>(node) != nullptr;
+};
+
+// Register transformation and pass callback to pass::Manager
+ngraph::pass::Manager manager;
+manager.register_pass<ngraph::pass::ConvertStridedSliceToCrop>();
+// pass::Manager will set the callback for all registered transformations automatically
+manager.set_callback(transformations_callback);
+manager.run_passes(f);
+```
## Transformations testing <a name="transformations_testing"></a>
@snippet tests/functional/transformations/template_transformations_test.cpp transformation:test
-TODO: insert advanced transformation tests
[ngraph_replace_node]: ../images/ngraph_replace_node.png
[ngraph_insert_node]: ../images/ngraph_insert_node.png
\ No newline at end of file
}
// ! [ngraph_utils:advanced_function]
-void pattern_matcher_examples() {
+void pattern_matcher_examples(std::shared_ptr<Node> node) {
{
// ! [pattern:simple_example]
// Pattern example
auto lin_op = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32, ngraph::Shape{},
[](const std::shared_ptr<ngraph::Node> & node) -> bool {
return std::dynamic_pointer_cast<ngraph::opset3::Multiply>(node) ||
- std::dynamic_pointer_cast<ngraph::opset3::Add>(node);
+ std::dynamic_pointer_cast<ngraph::opset3::Add>(node);
});
auto m = std::make_shared<ngraph::pattern::Matcher>(lin_op, "MultiplyOrAddMatcher");
+// Matcher can be used to match pattern manually on given node
+if (m->match(node->output(0))) {
+ // Successfully matched
+}
// ! [pattern:predicate_example]
}
// ! [ngraph:visualize]
void visualization_example(std::shared_ptr<ngraph::Function> f) {
- std::vector<std::shared_ptr<ngraph::Function> > g{f};
+ ngraph::pass::Manager manager;
// Serialize ngraph::Function to before.svg file before transformation
- ngraph::pass::VisualizeTree("/path/to/file/before.svg").run_on_module(g);
+ manager.register_pass<ngraph::pass::VisualizeTree>("/path/to/file/before.svg");
// Run your transformation
- // ngraph::pass::MyTransformation().run_on_function();
+ // manager.register_pass<ngraph::pass::MyTransformation>();
// Serialize ngraph::Function to after.svg file after transformation
- ngraph::pass::VisualizeTree("/path/to/file/after.svg").run_on_module(g);
+ manager.register_pass<ngraph::pass::VisualizeTree>("/path/to/file/after.svg");
+
+ manager.run_passes(f);
}
// ! [ngraph:visualize]
// Example: register CommonOptimizations transformation from transformations library
passManager.register_pass<ngraph::pass::CommonOptimizations>();
// Example: register plugin specific transformation
- passManager.register_pass<ngraph::pass::MyPatternBasedTransformation>();
+ passManager.register_pass<ngraph::pass::DecomposeDivideMatcher>();
+ passManager.register_pass<ngraph::pass::ReluReluFusionMatcher>();
// Register any other transformations
// ..
// ! [function_pass:template_transformation_cpp]
// template_function_transformation.cpp
-bool MyFunctionTransformation::run_on_function(std::shared_ptr<ngraph::Function> f) {
+bool pass::MyFunctionTransformation::run_on_function(std::shared_ptr<ngraph::Function> f) {
// Example transformation code
std::vector<std::shared_ptr<Node> > nodes;
#include <ngraph/ngraph.hpp>
+namespace ngraph {
+namespace pass {
+
+class MyFunctionTransformation;
+
+} // namespace pass
+} // namespace ngraph
+
// ! [function_pass:template_transformation_hpp]
// template_function_transformation.hpp
-class MyFunctionTransformation: public ngraph::pass::FunctionPass {
+class ngraph::pass::MyFunctionTransformation: public ngraph::pass::FunctionPass {
public:
MyFunctionTransformation() : FunctionPass() {}
//
#include "template_pattern_transformation.hpp"
+#include "template_function_transformation.hpp"
-#include <ngraph/opsets/opset3.hpp>
#include <ngraph/ngraph.hpp>
+#include <ngraph/opsets/opset3.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
using namespace ngraph;
// ! [graph_rewrite:template_transformation_cpp]
// template_pattern_transformation.cpp
-void ngraph::pass::MyPatternBasedTransformation::transform() {
+ngraph::pass::DecomposeDivideMatcher::DecomposeDivideMatcher() {
// Pattern example
- auto input0 = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1, 1});
- auto input1 = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1, 1});
+ auto input0 = std::make_shared<pattern::op::Label>(element::f32, Shape{});
+ auto input1 = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto div = std::make_shared<ngraph::opset3::Divide>(input0, input1);
- ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+ ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto div = std::dynamic_pointer_cast<ngraph::opset3::Divide> (m.get_match_root());
// We can not apply this transformation in case with integer input data type
if (!div || div->input(0).get_element_type().is_integral()) {
return true;
};
- // Register pattern with divide operaiton as a pattern root node
+ // Register pattern with Divide operation as a pattern root node
auto m = std::make_shared<ngraph::pattern::Matcher>(div, "ConvertDivide");
// Register Matcher
- this->add_matcher(m, callback, ngraph::pass::PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
// ! [graph_rewrite:template_transformation_cpp]
+
+// ! [matcher_pass:relu_fusion]
+ngraph::pass::ReluReluFusionMatcher::ReluReluFusionMatcher() {
+ auto m_relu1 = ngraph::pattern::wrap_type<ngraph::opset3::Relu>(pattern::consumers_count(1));
+ auto m_relu2 = ngraph::pattern::wrap_type<ngraph::opset3::Relu>({m_relu1});
+
+ ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
+ // Map that helps to connect labels with matched outputs
+ auto& node_to_output = m.get_pattern_value_map();
+
+        // Create new Relu operation and register it for additional execution
+ auto new_relu = register_new_node<ngraph::opset3::Relu>(
+ node_to_output.at(m_relu1).get_node_shared_ptr()->input_value(0));
+
+ // Copy runtime info attributes to newly created operation
+ ngraph::copy_runtime_info(m.get_matched_nodes(), new_relu);
+
+ // Save last Relu name to new Relu operation
+ new_relu->set_friendly_name(m.get_match_root()->get_friendly_name());
+
+ // Replace Relu->Relu with Relu
+ ngraph::replace_node(m.get_match_root(), new_relu);
+
+ // Return true as the root node was changed
+ return true;
+ };
+
+ // Register pattern with Relu operation as a pattern root node
+ auto m = std::make_shared<ngraph::pattern::Matcher>(m_relu2, "ReluReluFusion");
+ // Register Matcher
+ this->register_matcher(m, callback);
+}
+// ! [matcher_pass:relu_fusion]
+
+void run_matcher_on_node(std::shared_ptr<ngraph::Node> node) {
+// ! [matcher_pass:run_on_node]
+if (ngraph::pass::DecomposeDivideMatcher().apply(node)) {
+ // successful execution (root node was replaced)
+}
+// ! [matcher_pass:run_on_node]
+}
+
+void run_matcher_with_manager(std::shared_ptr<ngraph::Function> f) {
+// ! [matcher_pass:manager]
+// Two matchers will run independently (two independent graph traversals)
+// pass::Manager automatically creates GraphRewrite container for each MatcherPass
+pass::Manager manager;
+manager.register_pass<ngraph::pass::DecomposeDivideMatcher>();
+manager.register_pass<ngraph::pass::ReluReluFusionMatcher>();
+manager.run_passes(f);
+// ! [matcher_pass:manager]
+}
+
+void run_matcher_with_manager2(std::shared_ptr<ngraph::Function> f) {
+// ! [matcher_pass:manager2]
+// Register anchor GraphRewrite pass inside manager that will execute two matchers simultaneously
+pass::Manager manager;
+auto anchor = manager.register_pass<ngraph::pass::GraphRewrite>();
+anchor->add_matcher<ngraph::pass::DecomposeDivideMatcher>();
+anchor->add_matcher<ngraph::pass::ReluReluFusionMatcher>();
+manager.run_passes(f);
+// ! [matcher_pass:manager2]
+}
+
+void run_matcher_with_manager3(std::shared_ptr<ngraph::Function> f) {
+// ! [matcher_pass:manager3]
+pass::Manager manager;
+manager.register_pass<ngraph::pass::MyFunctionTransformation>();
+// Two matchers will run independently (two independent graph traversals)
+// pass::Manager automatically creates GraphRewrite container for each MatcherPass
+manager.register_pass<ngraph::pass::DecomposeDivideMatcher>();
+manager.register_pass<ngraph::pass::ReluReluFusionMatcher>();
+manager.run_passes(f);
+// ! [matcher_pass:manager3]
+}
+
+void run_matcher_with_gr(std::shared_ptr<ngraph::Function> f) {
+// ! [matcher_pass:graph_rewrite]
+// Two matcher passes will run simultaneously in a single graph traversal
+ngraph::pass::GraphRewrite pass;
+pass.add_matcher<ngraph::pass::DecomposeDivideMatcher>();
+pass.add_matcher<ngraph::pass::ReluReluFusionMatcher>();
+pass.run_on_function(f);
+// ! [matcher_pass:graph_rewrite]
+}
namespace ngraph {
namespace pass {
-class MyPatternBasedTransformation;
+class DecomposeDivideMatcher;
+class ReluReluFusionMatcher;
} // namespace pass
} // namespace ngraph
// ! [graph_rewrite:template_transformation_hpp]
// template_pattern_transformation.hpp
-class ngraph::pass::MyPatternBasedTransformation: public ngraph::pass::GraphRewrite {
+class ngraph::pass::DecomposeDivideMatcher: public ngraph::pass::MatcherPass {
public:
- MyPatternBasedTransformation() : GraphRewrite() {
- transform();
- }
-
-private:
- void transform();
+ DecomposeDivideMatcher();
};
// ! [graph_rewrite:template_transformation_hpp]
+
+class ngraph::pass::ReluReluFusionMatcher: public ngraph::pass::MatcherPass {
+public:
+ ReluReluFusionMatcher();
+};
#include <ngraph/opsets/opset2.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/op/fused/gelu.hpp>
+#include <ngraph/pass/manager.hpp>
#include <generic_ie.hpp>
#include <transformations/common_optimizations/common_optimizations.hpp>
#include <transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.hpp>
::ngraph::op::GenericIE::DisableReshape noReshape(nGraphFunc);
// Note: instead of running all Conversion Transformations you can make up your own transformation pipeline
- ngraph::pass::CommonOptimizations(transformations_callback).run_on_function(nGraphFunc);
- ngraph::pass::ConvertOpSet3ToOpSet2(transformations_callback).run_on_function(nGraphFunc);
- ngraph::pass::ConvertOpSet2ToOpSet1(transformations_callback).run_on_function(nGraphFunc);
- ngraph::pass::ConvertOpSet1ToLegacy(transformations_callback).run_on_function(nGraphFunc);
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::CommonOptimizations>();
+ manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();
+ manager.register_pass<ngraph::pass::ConvertOpSet2ToOpSet1>();
+ manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
+
+ manager.set_callback(transformations_callback);
+ manager.run_passes(nGraphFunc);
clonedNetwork = InferenceEngine::details::convertFunctionToICNNNetwork(nGraphFunc, *clonedNetwork);
}
#include <vector>
#include <unordered_set>
#include <ngraph/ngraph.hpp>
+#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/get_output_element_elimination.hpp>
#include <set>
#include <string>
// Call this transformation because OneHot IE and nGraph have different output precisions
{
IE_PROFILING_AUTO_SCOPE(ConvertOneHot);
- ::ngraph::pass::ConvertOneHotToOneHotIE().run_on_function(specialized_ngraph_function);
+ ::ngraph::pass::Manager manager;
+ manager.register_pass<::ngraph::pass::ConvertOneHotToOneHotIEMatcher>()->detect_output_type(specialized_ngraph_function);
+ manager.run_passes(specialized_ngraph_function);
}
specialized_ngraph_function->validate_nodes_and_infer_types();
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/op/fused/gelu.hpp>
#include <ngraph/op/util/op_types.hpp>
+#include <ngraph/pass/manager.hpp>
#include "ngraph_ops/fully_connected.hpp"
#if !defined(__arm__) && !defined(_M_ARM) && !defined(__aarch64__) && !defined(_M_ARM64)
::ngraph::op::GenericIE::DisableReshape noReshape(nGraphFunc);
// Note: instead of running all Conversion Transformations you can make up your own transformation pipeline
- ngraph::pass::CommonOptimizations(transformations_callback).run_on_function(nGraphFunc);
- ngraph::pass::ConvertOpSet3ToOpSet2(transformations_callback).run_on_function(nGraphFunc);
- ngraph::pass::ConvertOpSet2ToOpSet1(transformations_callback).run_on_function(nGraphFunc);
- ngraph::pass::ConvertOpSet1ToLegacy(transformations_callback).run_on_function(nGraphFunc);
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::CommonOptimizations>();
+ manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();
+ manager.register_pass<ngraph::pass::ConvertOpSet2ToOpSet1>();
+ manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
+
+ manager.set_callback(transformations_callback);
+ manager.run_passes(nGraphFunc);
+
clonedNetwork = InferenceEngine::details::convertFunctionToICNNNetwork(nGraphFunc, *clonedNetwork);
}
} // namespace pass
} // namespace ngraph
-class ngraph::pass::BatchNormDecomposition: public ngraph::pass::GraphRewrite {
+class ngraph::pass::BatchNormDecomposition: public ngraph::pass::MatcherPass {
public:
- BatchNormDecomposition() : GraphRewrite() {
- batch_norm_decomposition();
- }
-
-private:
- void batch_norm_decomposition();
+ BatchNormDecomposition();
};
#include <ngraph/pass/graph_rewrite.hpp>
-#include "transformations/utils/pass_param.hpp"
namespace ngraph {
namespace pass {
} // namespace pass
} // namespace ngraph
-class ngraph::pass::CommonOptimizations: public ngraph::pass::FunctionPass, public ngraph::pass::PassParam {
+class ngraph::pass::CommonOptimizations: public ngraph::pass::FunctionPass {
public:
- explicit CommonOptimizations(const PassParam::param_callback & callback = PassParam::getDefaultCallback())
- : FunctionPass(), PassParam(callback) {}
-
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};
+++ /dev/null
-// Copyright (C) 2018-2020 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#pragma once
-
-#include <transformations_visibility.hpp>
-
-#include <ngraph/pass/graph_rewrite.hpp>
-
-namespace ngraph {
-namespace pass {
-
-class TRANSFORMATIONS_API ConstantEltwiseReduction;
-
-} // namespace pass
-} // namespace ngraph
-
-class ngraph::pass::ConstantEltwiseReduction: public ngraph::pass::GraphRewrite {
-public:
- ConstantEltwiseReduction() : GraphRewrite() {
- constant_multiply_reduction();
- constant_add_reduction();
- }
-
-private:
- void constant_multiply_reduction();
- void constant_add_reduction();
-};
#include <ngraph/ops.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
-#include "transformations/utils/pass_param.hpp"
namespace ngraph {
namespace pass {
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertBatchToSpace: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam {
+class ngraph::pass::ConvertBatchToSpace: public ngraph::pass::GraphRewrite {
public:
- ConvertBatchToSpace() : GraphRewrite(), PassParam() {
+ ConvertBatchToSpace() : GraphRewrite() {
// convert_batch_to_space();
convert_batch_to_space_ie_side();
}
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertBroadcastToTiles: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertBroadcastToTiles: public ngraph::pass::MatcherPass {
public:
- ConvertBroadcastToTiles() : GraphRewrite() {
- convert_broadcast_to_tiles();
- }
-
-private:
- void convert_broadcast_to_tiles();
+ ConvertBroadcastToTiles();
};
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
-#include "transformations/utils/pass_param.hpp"
namespace ngraph {
namespace pass {
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertDepthToSpace: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam {
+class ngraph::pass::ConvertDepthToSpace: public ngraph::pass::MatcherPass {
public:
- ConvertDepthToSpace() : GraphRewrite(), PassParam() {
- convert_depth_to_space();
- }
-
-private:
- void convert_depth_to_space();
+ ConvertDepthToSpace();
};
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertDivide: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertDivide: public ngraph::pass::MatcherPass {
public:
- ConvertDivide() : GraphRewrite() {
- convert_divide();
- }
-
-private:
- void convert_divide();
+ ConvertDivide();
};
#include <ngraph/pass/graph_rewrite.hpp>
#include "ngraph/op/fused/gelu.hpp"
-#include "transformations/utils/pass_param.hpp"
namespace ngraph {
namespace pass {
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertGELU: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam {
+class ngraph::pass::ConvertGELU: public ngraph::pass::GraphRewrite {
public:
- ConvertGELU() : GraphRewrite(), PassParam() {
+ ConvertGELU() : GraphRewrite() {
convert_gelu();
}
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertMinimum: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertMinimum: public ngraph::pass::MatcherPass {
public:
- ConvertMinimum() : GraphRewrite() {
- convert_minimum();
- }
-
-private:
- void convert_minimum();
+ ConvertMinimum();
};
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertMod: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertMod: public ngraph::pass::MatcherPass {
public:
- ConvertMod() : GraphRewrite() {
- convert_mod();
- }
-
-private:
- void convert_mod();
+ ConvertMod();
};
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertNegative: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertNegative: public ngraph::pass::MatcherPass {
public:
- ConvertNegative() : GraphRewrite() {
- convert_negative();
- }
-
-private:
- void convert_negative();
+ ConvertNegative();
};
namespace pass {
class TRANSFORMATIONS_API ConvFusion;
+class TRANSFORMATIONS_API ConvAddFusion;
+class TRANSFORMATIONS_API ConvMultiplyFusion;
+class TRANSFORMATIONS_API DeconvAddFusion;
} // namespace pass
} // namespace ngraph
class ngraph::pass::ConvFusion: public ngraph::pass::GraphRewrite {
public:
ConvFusion() : GraphRewrite() {
- fuse_convolution_with<op::ConvolutionIE, opset1::Multiply>();
- fuse_convolution_with<op::ConvolutionIE, opset1::Add>();
- fuse_convolution_with<op::DeconvolutionIE, opset1::Add>();
+ add_matcher<ngraph::pass::ConvAddFusion>();
+ add_matcher<ngraph::pass::ConvMultiplyFusion>();
+ add_matcher<ngraph::pass::DeconvAddFusion>();
}
-
-private:
- template <class Conv, class Eltwise>
- void fuse_convolution_with();
-
- template <class Conv>
- ngraph::graph_rewrite_callback get_callback();
};
-template <class Conv, class Eltwise>
-void ngraph::pass::ConvFusion::fuse_convolution_with() {
- static_assert(std::is_same<Eltwise, ngraph::opset1::Multiply>() || std::is_same<Eltwise, ngraph::opset1::Add>(),
- "This transformation works only with ngraph::opset1::Add and ngraph::opset1::Multiply");
-
- static_assert(std::is_same<Conv, ngraph::op::ConvolutionIE>() || std::is_same<Conv, ngraph::op::DeconvolutionIE>(),
- "This transformation works only with ngraph::op::ConvolutionIE and ngraph::op::DeconvolutionIE");
-
- auto conv = std::make_shared<pattern::op::Label>(element::f32, Shape{},
- [](const std::shared_ptr<Node> & node) -> bool {
- return std::dynamic_pointer_cast<ngraph::op::ConvolutionIE>(node) ||
- std::dynamic_pointer_cast<ngraph::op::DeconvolutionIE>(node);
- });
-
- auto last = std::make_shared<Eltwise>(conv, std::make_shared<pattern::op::Label>(element::f32, Shape{1}));
-
- auto m = std::make_shared<ngraph::pattern::Matcher>(last, "ConvFusion");
- this->add_matcher(m, get_callback<Conv>(), PassProperty::CHANGE_DYNAMIC_STATE);
-}
-
-template <class Conv>
-ngraph::graph_rewrite_callback ngraph::pass::ConvFusion::get_callback() {
- ngraph::graph_rewrite_callback callback = [](ngraph::pattern::Matcher &m) {
- auto eltwise = m.get_match_root();
-
- std::shared_ptr<op::Constant> m_const;
- std::shared_ptr<Conv> m_conv;
- // FIXME: use auto [m_conv, m_const] when C++17 is available
- std::tie(m_conv, m_const) = parse_eltwise_inputs<Conv, op::Constant>(eltwise);
- if (!m_conv || !m_const) {
- return false;
- }
-
- // TODO: check that constant can be scalar and do not match [1, C, 1, 1] layout
- const auto constant_shape = m_const->get_shape();
- const auto output_pshape = m_conv->get_output_partial_shape(0);
-
- if (output_pshape.rank().is_dynamic() || output_pshape[1].is_dynamic()) {
- return false;
- }
-
- const auto channel_dim = output_pshape[1].get_length();
-
- size_t constant_size = std::accumulate(constant_shape.begin(), constant_shape.end(), 1, std::multiplies<size_t>());
- if (constant_size != channel_dim) {
- return false;
- }
-
- Output<Node> constant(m_const);
-
- if (constant_shape.size() > 1) {
- constant = std::make_shared<opset1::Reshape>(constant, op::Constant::create(element::i64, Shape{1}, {channel_dim}), true);
- }
-
- if (m_conv->output(0).get_target_inputs().size() != 1) {
- return false;
- }
-
- Output<Node> new_conv, new_weights, new_bias;
- if (std::dynamic_pointer_cast<opset1::Add>(eltwise)) {
- // Fuse: ConvolutionIE/DeconvolutionIE->Add
- if (m_conv->inputs().size() == 2) {
- new_bias = constant;
- } else {
- new_bias = std::make_shared<opset1::Add>(constant, m_conv->input_value(2));
- }
- new_conv = m_conv->clone_with_new_inputs({m_conv->input_value(0), m_conv->input_value(1), new_bias});
- } else if (std::is_same<Conv, op::ConvolutionIE>() && std::dynamic_pointer_cast<opset1::Multiply>(eltwise)) {
- // Fuse: ConvolutionIE->Mul
- auto weights_shape = m_conv->input(1).get_shape();
-
- Shape const_shape(weights_shape.size(), 1);
- const_shape[0] = weights_shape[0];
-
- auto const_reshape = std::make_shared<opset1::Reshape>(constant,
- op::Constant::create(element::i64, Shape{const_shape.size()}, const_shape), true);
- new_weights = std::make_shared<opset1::Multiply> (m_conv->input_value(1), const_reshape);
- if (m_conv->inputs().size() == 2) {
- new_conv = m_conv->clone_with_new_inputs({m_conv->input_value(0), new_weights});
- } else {
- auto bias_reshape = std::make_shared<opset1::Reshape>(constant, op::Constant::create(element::i64, Shape{1}, {weights_shape[0]}), true);
- new_bias = std::make_shared<opset1::Multiply>(bias_reshape, constant);
- new_conv = m_conv->clone_with_new_inputs({m_conv->input_value(0), new_weights, new_bias});
- }
- } else {
- return false;
- }
+class ngraph::pass::ConvAddFusion: public ngraph::pass::MatcherPass {
+public:
+ ConvAddFusion();
+};
- ngraph::copy_runtime_info({m_conv, eltwise}, new_conv.get_node_shared_ptr());
- new_conv.get_node_shared_ptr()->set_friendly_name(m.get_match_root()->get_friendly_name());
- ngraph::replace_node(m.get_match_root(), new_conv.get_node_shared_ptr());
- return true;
- };
- return callback;
-}
+class ngraph::pass::ConvMultiplyFusion: public ngraph::pass::MatcherPass {
+public:
+ ConvMultiplyFusion();
+};
+class ngraph::pass::DeconvAddFusion: public ngraph::pass::MatcherPass {
+public:
+ DeconvAddFusion();
+};
\ No newline at end of file
namespace ngraph {
namespace pass {
-class TRANSFORMATIONS_API ConvertCellsToCellsIE;
+class TRANSFORMATIONS_API ConvertLSTMCellMatcher;
+class TRANSFORMATIONS_API ConvertGRUCellMatcher;
+class TRANSFORMATIONS_API ConvertRNNCellMatcher;
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertCellsToCellsIE: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertLSTMCellMatcher : public ngraph::pass::MatcherPass {
public:
- ConvertCellsToCellsIE() : GraphRewrite() {
- convert_lstm_cell();
- convert_gru_cell();
- convert_rnn_cell();
- }
-
-private:
- void convert_lstm_cell();
- void convert_gru_cell();
- void convert_rnn_cell();
+ ConvertLSTMCellMatcher();
+};
+
+class ngraph::pass::ConvertGRUCellMatcher : public ngraph::pass::MatcherPass {
+public:
+ ConvertGRUCellMatcher();
+};
+
+class ngraph::pass::ConvertRNNCellMatcher : public ngraph::pass::MatcherPass {
+public:
+ ConvertRNNCellMatcher();
};
class TRANSFORMATIONS_API ConvertConvolutions;
+class TRANSFORMATIONS_API ConvertConvolution;
+class TRANSFORMATIONS_API ConvertGroupConvolution;
+class TRANSFORMATIONS_API ConvertDeconvolution;
+class TRANSFORMATIONS_API ConvertGroupDeconvolution;
+
} // namespace pass
} // namespace ngraph
class ngraph::pass::ConvertConvolutions: public ngraph::pass::GraphRewrite {
public:
- ConvertConvolutions() : GraphRewrite() {
- convert_convolution();
- convert_group_convolution();
- convert_convolution_backprop_data();
- convert_group_convolution_backprop_data();
+ ConvertConvolutions() {
+ add_matcher<ngraph::pass::ConvertConvolution>();
+ add_matcher<ngraph::pass::ConvertGroupConvolution>();
+ add_matcher<ngraph::pass::ConvertDeconvolution>();
+ add_matcher<ngraph::pass::ConvertGroupDeconvolution>();
}
+};
-private:
- void convert_convolution();
-
- void convert_group_convolution();
+class ngraph::pass::ConvertConvolution: public ngraph::pass::MatcherPass {
+public:
+ ConvertConvolution();
+};
- void convert_convolution_backprop_data();
+class ngraph::pass::ConvertGroupConvolution: public ngraph::pass::MatcherPass {
+public:
+ ConvertGroupConvolution();
+};
- void convert_group_convolution_backprop_data();
+class ngraph::pass::ConvertDeconvolution: public ngraph::pass::MatcherPass {
+public:
+ ConvertDeconvolution();
};
+
+class ngraph::pass::ConvertGroupDeconvolution: public ngraph::pass::MatcherPass {
+public:
+ ConvertGroupDeconvolution();
+};
\ No newline at end of file
namespace ngraph {
namespace pass {
-class TRANSFORMATIONS_API ConvertGatherToGatherIE;
+class TRANSFORMATIONS_API ConvertGatherToGatherIEMatcher;
} // namespace pass
} // namespace ngraph
* we unsqueeze indices input and squeeze GatherIE output.
*/
-class ngraph::pass::ConvertGatherToGatherIE : public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertGatherToGatherIEMatcher : public ngraph::pass::MatcherPass {
public:
- ConvertGatherToGatherIE() : GraphRewrite() {
- convert_gather_to_gather_ie();
- }
-
-private:
- void convert_gather_to_gather_ie();
+ ConvertGatherToGatherIEMatcher();
};
namespace ngraph {
namespace pass {
-class TRANSFORMATIONS_API ConvertGatherTreeToGatherTreeIE;
+class TRANSFORMATIONS_API ConvertGatherTreeToGatherTreeIEMatcher;
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertGatherTreeToGatherTreeIE: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertGatherTreeToGatherTreeIEMatcher: public ngraph::pass::MatcherPass {
public:
- ConvertGatherTreeToGatherTreeIE() : GraphRewrite() {
- convert();
- }
-
-private:
- void convert();
+ ConvertGatherTreeToGatherTreeIEMatcher();
};
namespace ngraph {
namespace pass {
-class TRANSFORMATIONS_API ConvertHardSigmoidToHardSigmoidIE;
+class TRANSFORMATIONS_API ConvertHardSigmoidToLegacyMatcher;
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertHardSigmoidToHardSigmoidIE : public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertHardSigmoidToLegacyMatcher : public ngraph::pass::MatcherPass {
public:
- ConvertHardSigmoidToHardSigmoidIE() : GraphRewrite() {
- convert_hard_sigmoid();
- }
-
-private:
- void convert_hard_sigmoid();
+ ConvertHardSigmoidToLegacyMatcher();
};
namespace ngraph {
namespace pass {
-class TRANSFORMATIONS_API ConvertInterpolateToInterpOrResample;
+class TRANSFORMATIONS_API ConvertInterpolateToInterpOrResampleMatcher;
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertInterpolateToInterpOrResample: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertInterpolateToInterpOrResampleMatcher: public ngraph::pass::MatcherPass {
public:
- ConvertInterpolateToInterpOrResample() : GraphRewrite() {
- convert_interpolate_to_interp_or_resample();
- }
-
-private:
- void convert_interpolate_to_interp_or_resample();
+ ConvertInterpolateToInterpOrResampleMatcher();
};
namespace ngraph {
namespace pass {
-class TRANSFORMATIONS_API ConvertLRNToLRNIE;
+class TRANSFORMATIONS_API ConvertLRNToLegacyMatcher;
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertLRNToLRNIE: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertLRNToLegacyMatcher: public ngraph::pass::MatcherPass {
public:
- ConvertLRNToLRNIE() : GraphRewrite() {
- convert_lrn();
- }
-
-private:
- void convert_lrn();
+ ConvertLRNToLegacyMatcher();
};
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertMatMulToFCorGemm: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertMatMulToFCorGemm: public ngraph::pass::MatcherPass {
public:
- ConvertMatMulToFCorGemm(): GraphRewrite() {
- convert_matmul();
- }
-
-private:
- void convert_matmul();
+ ConvertMatMulToFCorGemm();
};
namespace ngraph {
namespace pass {
- class TRANSFORMATIONS_API ConvertNMS4ToLegacy;
+ class TRANSFORMATIONS_API ConvertNMS4ToLegacyMatcher;
} // namespace pass
} // namespace ngraph
*/
-class ngraph::pass::ConvertNMS4ToLegacy: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertNMS4ToLegacyMatcher: public ngraph::pass::MatcherPass {
public:
- ConvertNMS4ToLegacy() : GraphRewrite() {
- convert_nms4_to_legacy();
- }
-private:
- void convert_nms4_to_legacy();
+ ConvertNMS4ToLegacyMatcher();
};
namespace ngraph {
namespace pass {
-class TRANSFORMATIONS_API ConvertNMSToNMSIE;
+class TRANSFORMATIONS_API ConvertNMSToNMSIEMatcher;
} // namespace pass
} // namespace ngraph
* we insert Unsqueeze operations.
*/
-class ngraph::pass::ConvertNMSToNMSIE : public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertNMSToNMSIEMatcher : public ngraph::pass::MatcherPass {
public:
- ConvertNMSToNMSIE() : GraphRewrite() {
- convert_nms_to_nms_ie();
- }
-
-private:
- void convert_nms_to_nms_ie();
+ ConvertNMSToNMSIEMatcher();
};
namespace pass {
class TRANSFORMATIONS_API ConvertNormalizeL2WithMulToNormalizeIE;
-class TRANSFORMATIONS_API ConvertNormalizeL2ToNormalizeIE;
+class TRANSFORMATIONS_API ConvertNormalizeL2ToLegacyMatcher;
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertNormalizeL2WithMulToNormalizeIE: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertNormalizeL2WithMulToNormalizeIE: public ngraph::pass::MatcherPass {
public:
- ConvertNormalizeL2WithMulToNormalizeIE() : GraphRewrite() {
- convert_normalize_l2_with_mul();
- }
-
-private:
- void convert_normalize_l2_with_mul();
+ ConvertNormalizeL2WithMulToNormalizeIE();
};
-class ngraph::pass::ConvertNormalizeL2ToNormalizeIE: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertNormalizeL2ToLegacyMatcher: public ngraph::pass::MatcherPass {
public:
- ConvertNormalizeL2ToNormalizeIE() : GraphRewrite() {
- convert_normalize_l2();
- }
-
-private:
- void convert_normalize_l2();
+ ConvertNormalizeL2ToLegacyMatcher();
};
namespace ngraph {
namespace pass {
-class TRANSFORMATIONS_API ConvertOneHotToOneHotIE;
+class TRANSFORMATIONS_API ConvertOneHotToOneHotIEMatcher;
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertOneHotToOneHotIE: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertOneHotToOneHotIEMatcher: public ngraph::pass::MatcherPass {
public:
- ConvertOneHotToOneHotIE() : GraphRewrite(), is_f16(false) {
- convert_one_hot();
- }
+ ConvertOneHotToOneHotIEMatcher();
- bool run_on_function(std::shared_ptr<ngraph::Function> f) final;
+ void detect_output_type(const std::shared_ptr<Function> & f);
private:
- void convert_one_hot();
- bool is_f16;
-};
+ element::Type m_output_type = element::Type_t::f32;
+};
\ No newline at end of file
#include <ngraph/pass/graph_rewrite.hpp>
-#include "transformations/utils/pass_param.hpp"
namespace ngraph {
namespace pass {
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertOpSet1ToLegacy: public ngraph::pass::FunctionPass, public ngraph::pass::PassParam {
+class ngraph::pass::ConvertOpSet1ToLegacy: public ngraph::pass::FunctionPass {
public:
- explicit ConvertOpSet1ToLegacy(const PassParam::param_callback & callback = PassParam::getDefaultCallback())
- : FunctionPass(), PassParam(callback) {}
-
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};
+++ /dev/null
-// Copyright (C) 2020 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#ifndef NGRAPH_PASS
-#warning "NGRAPH_PASS is not defined"
-#define NGRAPH_PASS(A, B)
-#endif
-
-// To register new pass you need to define NGRAPH_PASS
-// Usage example:
-// ngraph::pass:Manager pm;
-// #define NGRAPH_PASS(NAME, NAMESPACE) pm.register_pass<NAMESPACE::NAME>();
-// #include <transformations/transformations_tbl.hpp>
-// #undef NGRAPH_PASS
-
-NGRAPH_PASS(ConstantFolding, ::ngraph::pass)
-NGRAPH_PASS(ConvertReduceToPooling, ::ngraph::pass)
-NGRAPH_PASS(ConvertMod, ::ngraph::pass)
-NGRAPH_PASS(ConvertMinimum, ::ngraph::pass)
-NGRAPH_PASS(ConvertSubtract, ::ngraph::pass)
-NGRAPH_PASS(ConvertDivide, ::ngraph::pass)
-NGRAPH_PASS(ConvertNegative, ::ngraph::pass)
-NGRAPH_PASS(ConvertDepthToSpace, ::ngraph::pass)
-NGRAPH_PASS(ConvertSpaceToDepth, ::ngraph::pass)
-NGRAPH_PASS(ConvertConvolutions, ::ngraph::pass)
-NGRAPH_PASS(BatchNormDecomposition, ::ngraph::pass)
-NGRAPH_PASS(ConstantFolding, ::ngraph::pass)
-NGRAPH_PASS(MulAddVerification, ::ngraph::pass)
-NGRAPH_PASS(MulAddFusion, ::ngraph::pass)
-NGRAPH_PASS(ConstantFolding, ::ngraph::pass)
-NGRAPH_PASS(ConvertMatMulToFCorGemm, ::ngraph::pass)
-NGRAPH_PASS(PullTransposeThroughFQUp, ::ngraph::pass)
-NGRAPH_PASS(ConstantFolding, ::ngraph::pass)
-NGRAPH_PASS(ConvFusion, ::ngraph::pass)
-NGRAPH_PASS(FullyConnectedBiasFusion, ::ngraph::pass)
-NGRAPH_PASS(ConstantFolding, ::ngraph::pass)
-NGRAPH_PASS(ReshapeFullyConnected, ::ngraph::pass)
-NGRAPH_PASS(ReshapeFullyConnectedFusion, ::ngraph::pass)
-NGRAPH_PASS(Reshape1DOps, ::ngraph::pass)
-NGRAPH_PASS(ConvertNormalizeL2WithMulToNormalizeIE, ::ngraph::pass)
-NGRAPH_PASS(ConvertNormalizeL2ToNormalizeIE, ::ngraph::pass)
-NGRAPH_PASS(ConstantEltwiseReduction, ::ngraph::pass)
-NGRAPH_PASS(ConvertMulAddToScaleShiftOrPower, ::ngraph::pass)
-NGRAPH_PASS(ConvertMulOrAddFinally, ::ngraph::pass)
-NGRAPH_PASS(ConstantFolding, ::ngraph::pass)
-NGRAPH_PASS(ConvertBroadcastToTiles, ::ngraph::pass)
-NGRAPH_PASS(ConvertTileToIETile, ::ngraph::pass)
-NGRAPH_PASS(ConvertProposalToProposalIE, ::ngraph::pass)
-NGRAPH_PASS(ConvertLRNToLRNIE, ::ngraph::pass)
-NGRAPH_PASS(ConvertPadToPadIE, ::ngraph::pass)
-NGRAPH_PASS(ConvertHardSigmoidToHardSigmoidIE, ::ngraph::pass)
-NGRAPH_PASS(ConvertCellsToCellsIE, ::ngraph::pass)
-NGRAPH_PASS(ConvertInterpolateToInterpOrResample, ::ngraph::pass)
-NGRAPH_PASS(ConvertStridedSliceToCrop, ::ngraph::pass)
-NGRAPH_PASS(ConvertPowerToPowerIE, ::ngraph::pass)
-NGRAPH_PASS(ConvertSqrtToPowerIE, ::ngraph::pass)
-NGRAPH_PASS(ConvertPReLUToReLUIE, ::ngraph::pass)
-NGRAPH_PASS(ConvertGatherToGatherIE, ::ngraph::pass)
-NGRAPH_PASS(ConvertSeluToSeluIE, ::ngraph::pass)
-NGRAPH_PASS(ConvertOneHotToOneHotIE, ::ngraph::pass)
-NGRAPH_PASS(ConvertGatherTreeToGatherTreeIE, ::ngraph::pass)
-NGRAPH_PASS(ConvertTopKToTopKIE, ::ngraph::pass)
-NGRAPH_PASS(ConvertNMSToNMSIE, ::ngraph::pass)
-NGRAPH_PASS(ConstantFolding, ::ngraph::pass)
-NGRAPH_PASS(ConvertNMS4ToLegacy, ::ngraph::pass)
namespace ngraph {
namespace pass {
-class TRANSFORMATIONS_API ConvertPadToPadIE;
+class TRANSFORMATIONS_API ConvertPadToLegacyMatcher;
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertPadToPadIE: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertPadToLegacyMatcher: public ngraph::pass::MatcherPass {
public:
- ConvertPadToPadIE() : GraphRewrite() {
- convert_pad();
- }
-
-private:
- void convert_pad();
+ ConvertPadToLegacyMatcher();
};
namespace ngraph {
namespace pass {
-class TRANSFORMATIONS_API ConvertPowerToPowerIE;
+class TRANSFORMATIONS_API ConvertPowerToPowerIEMatcher;
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertPowerToPowerIE: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertPowerToPowerIEMatcher: public ngraph::pass::MatcherPass {
public:
- ConvertPowerToPowerIE() : GraphRewrite() {
- convert_power();
- }
-
-private:
- void convert_power();
+ ConvertPowerToPowerIEMatcher();
};
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertPReLUToReLUIE: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertPReLUToReLUIE: public ngraph::pass::MatcherPass {
public:
- ConvertPReLUToReLUIE() : GraphRewrite() {
- convert_prelu();
- }
-
-private:
- void convert_prelu();
+ ConvertPReLUToReLUIE();
};
namespace ngraph {
namespace pass {
-class TRANSFORMATIONS_API ConvertProposalToProposalIE;
+class TRANSFORMATIONS_API ConvertProposalToLegacyMatcher;
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertProposalToProposalIE: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertProposalToLegacyMatcher: public ngraph::pass::MatcherPass {
public:
- ConvertProposalToProposalIE() : GraphRewrite() {
- convert_proposal();
- }
-
-private:
- void convert_proposal();
+ ConvertProposalToLegacyMatcher();
};
namespace ngraph {
namespace pass {
-class TRANSFORMATIONS_API ConvertSeluToSeluIE;
+class TRANSFORMATIONS_API ConvertSeluToSeluIEMatcher;
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertSeluToSeluIE: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertSeluToSeluIEMatcher: public ngraph::pass::MatcherPass {
public:
- ConvertSeluToSeluIE() : GraphRewrite() {
- convert_selu();
- }
-
-private:
- void convert_selu();
+ ConvertSeluToSeluIEMatcher();
};
namespace ngraph {
namespace pass {
-class TRANSFORMATIONS_API ConvertSqrtToPowerIE;
+class TRANSFORMATIONS_API ConvertSqrtToPowerIEMatcher;
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertSqrtToPowerIE: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertSqrtToPowerIEMatcher: public ngraph::pass::MatcherPass {
public:
- ConvertSqrtToPowerIE() : GraphRewrite() {
- convert_sqrt();
- }
-
-private:
- void convert_sqrt();
+ ConvertSqrtToPowerIEMatcher();
};
namespace ngraph {
namespace pass {
-class TRANSFORMATIONS_API ConvertStridedSliceToCrop;
+class TRANSFORMATIONS_API ConvertStridedSliceToCropMatcher;
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertStridedSliceToCrop: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertStridedSliceToCropMatcher: public ngraph::pass::MatcherPass {
public:
- ConvertStridedSliceToCrop() : GraphRewrite() {
- convert_strided_slice_to_crop();
- }
-
-private:
- void convert_strided_slice_to_crop();
+ ConvertStridedSliceToCropMatcher();
};
namespace ngraph {
namespace pass {
-class TRANSFORMATIONS_API ConvertTileToIETile;
+class TRANSFORMATIONS_API ConvertTileToLegacyMatcher;
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertTileToIETile: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertTileToLegacyMatcher: public ngraph::pass::MatcherPass {
public:
- ConvertTileToIETile() : GraphRewrite() {
- convert_tile();
- }
-
-private:
- void convert_tile();
+ ConvertTileToLegacyMatcher();
};
namespace ngraph {
namespace pass {
-class TRANSFORMATIONS_API ConvertTopKToTopKIE;
+class TRANSFORMATIONS_API ConvertTopKToTopKIEMatcher;
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertTopKToTopKIE : public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertTopKToTopKIEMatcher : public ngraph::pass::MatcherPass {
public:
- ConvertTopKToTopKIE() : GraphRewrite() {
- convert_topk_to_topk_ie();
- }
-
-private:
- void convert_topk_to_topk_ie();
+ ConvertTopKToTopKIEMatcher();
};
} // namespace pass
} // namespace ngraph
-class ngraph::pass::FullyConnectedBiasFusion : public ngraph::pass::GraphRewrite {
+class ngraph::pass::FullyConnectedBiasFusion : public ngraph::pass::MatcherPass {
public:
- FullyConnectedBiasFusion() : GraphRewrite() {
- construct_fcbias();
- }
-
-private:
- void construct_fcbias() {
- Shape shape_w{2, 4};
- Shape shape_x{2, 4};
- Shape shape_b{2, 2};
- auto input = std::make_shared<pattern::op::Label>(element::f32, shape_w);
- auto weights = std::make_shared<pattern::op::Label>(element::f32, shape_x);
- auto fc_bias = std::make_shared<pattern::op::Label>(element::f32, shape_b);
- auto bias = std::make_shared<pattern::op::Label>(element::f32, shape_b);
-
- auto fc = std::make_shared<op::FullyConnected>(input, weights, fc_bias, Shape{1, 2});
- auto add = std::make_shared<opset1::Add>(fc, bias);
-
- ngraph::graph_rewrite_callback callback = [](pattern::Matcher &m) {
- auto add = m.get_match_root();
- auto add_input_0 = add->input(0).get_source_output().get_node_shared_ptr();
- auto add_input_1 = add->input(1).get_source_output().get_node_shared_ptr();
-
- auto m_fc = std::dynamic_pointer_cast<op::FullyConnected>(add_input_0);
- auto m_bias = add_input_1;
-
- if (m_fc == nullptr) {
- m_fc = std::dynamic_pointer_cast<op::FullyConnected>(add_input_1);
- m_bias = add_input_0;
- }
-
- if (auto bcast_m = std::dynamic_pointer_cast<opset1::Broadcast>(m_bias)) {
- m_bias = bcast_m->input(0).get_source_output().get_node_shared_ptr();
- }
-
- if (!std::dynamic_pointer_cast<opset1::Constant>(m_bias)) {
- return false;
- }
- Shape bias_shape(m_bias->get_shape());
-
- if (m_fc->output(0).get_target_inputs().size() != 1) {
- return false;
- }
-
- Shape output_shape(m_fc->get_shape());
- size_t bias_size = std::accumulate(bias_shape.begin(), bias_shape.end(), 1, std::multiplies<int64_t>());
- if (bias_shape.empty() || bias_shape.back() != output_shape.back() || bias_shape.back() != bias_size) {
- return false;
- }
-
- NodeVector new_ops;
-
- auto new_bias = std::make_shared<opset1::Add>(m_fc->input(2).get_source_output(), m_bias);
- new_ops.push_back(new_bias);
- std::shared_ptr<Node> final_bias = new_bias;
- if (new_bias->get_shape().size() >= 2) {
- final_bias = std::make_shared<opset1::Reshape>(final_bias, opset1::Constant::create(element::i64, Shape{1}, {-1}), true);
- new_ops.push_back(final_bias);
- }
-
- auto new_fc = std::make_shared<op::FullyConnected>(m_fc->input(0).get_source_output(),
- m_fc->input(1).get_source_output(),
- final_bias,
- m_fc->get_shape());
- new_ops.push_back(new_fc);
-
- new_fc->set_friendly_name(add->get_friendly_name());
- ngraph::copy_runtime_info({m_fc, add}, new_ops);
- ngraph::replace_node(add, new_fc);
- return true;
- };
-
- auto m = std::make_shared<ngraph::pattern::Matcher>(add, "FullyConnectedBiasFusion");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
- }
+ FullyConnectedBiasFusion();
};
namespace pass {
class TRANSFORMATIONS_API Reshape1DOps;
+class TRANSFORMATIONS_API Reshape1DConvolution;
+class TRANSFORMATIONS_API Reshape1DAvgPool;
+class TRANSFORMATIONS_API Reshape1DMaxPool;
} // namespace pass
} // namespace ngraph
class ngraph::pass::Reshape1DOps: public ngraph::pass::GraphRewrite {
public:
Reshape1DOps() : GraphRewrite() {
- reshape_ops();
+ add_matcher<ngraph::pass::Reshape1DConvolution>();
+ add_matcher<ngraph::pass::Reshape1DAvgPool>();
+ add_matcher<ngraph::pass::Reshape1DMaxPool>();
}
+};
-private:
- void reshape_ops();
+class ngraph::pass::Reshape1DConvolution: public ngraph::pass::MatcherPass {
+public:
+ Reshape1DConvolution();
};
+
+class ngraph::pass::Reshape1DAvgPool: public ngraph::pass::MatcherPass {
+public:
+ Reshape1DAvgPool();
+};
+
+class ngraph::pass::Reshape1DMaxPool: public ngraph::pass::MatcherPass {
+public:
+ Reshape1DMaxPool();
+};
\ No newline at end of file
#include <ngraph/pass/graph_rewrite.hpp>
-#include "transformations/utils/pass_param.hpp"
namespace ngraph {
namespace pass {
* }
* };
*
- * auto p = ngraph::pass::ReshapeFullyConnected();
- * p.setCallback(callback);
- * p.run_on_function(f);
- *
*/
-class ngraph::pass::ReshapeFullyConnected: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam {
+class ngraph::pass::ReshapeFullyConnected: public ngraph::pass::MatcherPass {
public:
- ReshapeFullyConnected() : GraphRewrite(), PassParam() {
- reshape_fully_connected();
- }
-
-private:
- void reshape_fully_connected();
+ ReshapeFullyConnected();
};
#include <memory>
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
-#include "transformations/utils/pass_param.hpp"
namespace ngraph {
namespace pass {
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertOpSet2ToOpSet1: public ngraph::pass::FunctionPass, public ngraph::pass::PassParam {
+class ngraph::pass::ConvertOpSet2ToOpSet1: public ngraph::pass::FunctionPass {
public:
- explicit ConvertOpSet2ToOpSet1(const PassParam::param_callback & callback = PassParam::getDefaultCallback())
- : FunctionPass(), PassParam(callback) {}
-
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};
+++ /dev/null
-// Copyright (C) 2020 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#ifndef NGRAPH_PASS
-#warning "NGRAPH_PASS is not defined"
-#define NGRAPH_PASS(A, B)
-#endif
-
-// To register new pass you need to define NGRAPH_PASS
-// Usage example:
-// ngraph::pass:Manager pm;
-// #define NGRAPH_PASS(NAME, NAMESPACE) pm.register_pass<NAMESPACE::NAME>();
-// #include <transformations/transformations_tbl.hpp>
-// #undef NGRAPH_PASS
-
-NGRAPH_PASS(ConvertGELU, ::ngraph::pass)
-NGRAPH_PASS(ConvertSpaceToBatch, ::ngraph::pass)
-NGRAPH_PASS(ConvertBatchToSpace, ::ngraph::pass)
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
-#include "transformations/utils/pass_param.hpp"
namespace ngraph {
namespace pass {
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertBroadcast3: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam {
+class ngraph::pass::ConvertBroadcast3: public ngraph::pass::GraphRewrite {
public:
ConvertBroadcast3() : GraphRewrite() {
convert_broadcast3();
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
-#include "transformations/utils/pass_param.hpp"
namespace ngraph {
namespace pass {
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertNMS3: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam {
+class ngraph::pass::ConvertNMS3: public ngraph::pass::GraphRewrite {
public:
ConvertNMS3() : GraphRewrite() {
convert_nms3();
#include <memory>
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
-#include "transformations/utils/pass_param.hpp"
namespace ngraph {
namespace pass {
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertOpSet3ToOpSet2: public ngraph::pass::FunctionPass, public ngraph::pass::PassParam {
+class ngraph::pass::ConvertOpSet3ToOpSet2: public ngraph::pass::FunctionPass {
public:
- explicit ConvertOpSet3ToOpSet2(const PassParam::param_callback & callback = PassParam::getDefaultCallback())
- : FunctionPass(), PassParam(callback) {}
-
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};
+++ /dev/null
-// Copyright (C) 2020 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#ifndef NGRAPH_PASS
-#warning "NGRAPH_PASS is not defined"
-#define NGRAPH_PASS(A, B)
-#endif
-
-// To register new pass you need to define NGRAPH_PASS
-// Usage example:
-// ngraph::pass:Manager pm;
-// #define NGRAPH_PASS(NAME, NAMESPACE) pm.register_pass<NAMESPACE::NAME>();
-// #include <transformations/transformations_tbl.hpp>
-// #undef NGRAPH_PASS
-
-NGRAPH_PASS(ConvertBroadcast3, ::ngraph::pass)
-NGRAPH_PASS(ConvertNMS3, ::ngraph::pass)
-NGRAPH_PASS(ConvertShapeOf3, ::ngraph::pass)
-NGRAPH_PASS(ConvertShuffleChannels3, ::ngraph::pass)
-NGRAPH_PASS(ConvertTopK3, ::ngraph::pass)
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
-#include "transformations/utils/pass_param.hpp"
namespace ngraph {
namespace pass {
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertShuffleChannels3: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam {
+class ngraph::pass::ConvertShuffleChannels3: public ngraph::pass::GraphRewrite {
public:
- ConvertShuffleChannels3() : GraphRewrite(), PassParam() {
+ ConvertShuffleChannels3() : GraphRewrite() {
convert_shuffle_channels3();
}
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
-#include "transformations/utils/pass_param.hpp"
namespace ngraph {
namespace pass {
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertTopK3: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam {
+class ngraph::pass::ConvertTopK3: public ngraph::pass::GraphRewrite {
public:
ConvertTopK3() : GraphRewrite() {
convert_topk3();
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/validation_util.hpp>
#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API ConvertReduceToPooling;
+class TRANSFORMATIONS_API ConvertReduceMeanToPooling;
+class TRANSFORMATIONS_API ConvertReduceMaxToPooling;
+class TRANSFORMATIONS_API ConvertReduceSumToPooling;
} // namespace pass
} // namespace ngraph
class ngraph::pass::ConvertReduceToPooling: public ngraph::pass::GraphRewrite {
public:
- ConvertReduceToPooling() : GraphRewrite() {
- convert_reduce_to_pooling<ngraph::opset1::ReduceMean>();
- convert_reduce_to_pooling<ngraph::opset1::ReduceMax>();
- convert_reduce_to_pooling<ngraph::opset1::ReduceSum>();
+ ConvertReduceToPooling() {
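+        // Container GraphRewrite: each Reduce type is handled by its own MatcherPass registered below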
+ add_matcher<ConvertReduceMeanToPooling>();
+ add_matcher<ConvertReduceMaxToPooling>();
+ add_matcher<ConvertReduceSumToPooling>();
}
-
-private:
- template <class T>
- void convert_reduce_to_pooling();
};
template <class T>
-void ngraph::pass::ConvertReduceToPooling::convert_reduce_to_pooling() {
- {
- static_assert(std::is_same<T, ngraph::opset1::ReduceMean>() ||
- std::is_same<T, ngraph::opset1::ReduceMax>() ||
- std::is_same<T, ngraph::opset1::ReduceSum>(),
- "This callback works only with ngraph::opset1::ReduceMean/Max/Sum");
-
- auto data = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
- auto axes = std::make_shared<pattern::op::Label>(element::i64, Shape{4});
- auto reduce = std::make_shared<T>(data, axes);
-
- ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
- auto reduce = std::dynamic_pointer_cast<T>(m.get_match_root());
- if (!reduce) {
- return false;
- }
+ngraph::matcher_pass_callback convert_reduce_to_pooling();
- auto input = reduce->input_value(0);
+class ngraph::pass::ConvertReduceMeanToPooling: public ngraph::pass::MatcherPass {
+public:
+ ConvertReduceMeanToPooling() {
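+        // wrap_type<T>() matches any node of type T, so no explicit input Labels with fixed shapes are needed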
+ auto m = std::make_shared<ngraph::pattern::Matcher>(ngraph::pattern::wrap_type<opset1::ReduceMean>(), "ConvertReduceMean");
+ register_matcher(m, convert_reduce_to_pooling<opset1::ReduceMean>());
+ }
+};
- auto axes_node = reduce->input_value(1).get_node_shared_ptr();
- if (!ngraph::op::is_constant(axes_node)) {
- return false;
- }
+class ngraph::pass::ConvertReduceMaxToPooling: public ngraph::pass::MatcherPass {
+public:
+ ConvertReduceMaxToPooling() {
+ auto m = std::make_shared<ngraph::pattern::Matcher>(ngraph::pattern::wrap_type<opset1::ReduceMax>(), "ConvertReduceMax");
+ register_matcher(m, convert_reduce_to_pooling<opset1::ReduceMax>());
+ }
+};
- auto axes_vector = std::dynamic_pointer_cast<ngraph::opset1::Constant>(axes_node)->template cast_vector<int64_t>();
- const auto input_rank = input.get_partial_shape().rank().get_length();
- // Transform negative axes into non-negative ones
- for (size_t i = 0; i < axes_vector.size(); ++i) {
- if (axes_vector[i] < 0) {
- axes_vector[i] += input_rank;
- }
- }
- std::sort(axes_vector.begin(), axes_vector.end());
+class ngraph::pass::ConvertReduceSumToPooling: public ngraph::pass::MatcherPass {
+public:
+ ConvertReduceSumToPooling() {
+ auto m = std::make_shared<ngraph::pattern::Matcher>(ngraph::pattern::wrap_type<opset1::ReduceSum>(), "ConvertReduceSum");
+ register_matcher(m, convert_reduce_to_pooling<opset1::ReduceSum>());
+ }
+};
- // If axes are empty we just remove Reduction operation
- if (axes_vector.empty()) {
- return replace_output_update_name(reduce->output(0), input);
+template <class T>
+ngraph::matcher_pass_callback convert_reduce_to_pooling() {
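+    // Shared callback for the ReduceMean/ReduceMax/ReduceSum matchers below:
+    // T selects which opset1 Reduce operation is matched and which Pooling op replaces it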
+ return [](ngraph::pattern::Matcher& m) {
+ auto reduce = std::dynamic_pointer_cast<T>(m.get_match_root());
+ if (!reduce) {
+ return false;
+ }
+
+ auto input = reduce->input_value(0);
+
+ auto axes_node = reduce->input_value(1).get_node_shared_ptr();
+ if (!ngraph::op::is_constant(axes_node)) {
+ return false;
+ }
+
+ auto axes_vector = std::dynamic_pointer_cast<ngraph::opset1::Constant>(axes_node)->template cast_vector<int64_t>();
+ const auto input_rank = input.get_partial_shape().rank().get_length();
+ // Transform negative axes into non-negative ones
+ for (size_t i = 0; i < axes_vector.size(); ++i) {
+ if (axes_vector[i] < 0) {
+ axes_vector[i] += input_rank;
}
+ }
+ std::sort(axes_vector.begin(), axes_vector.end());
+
+    // If axes are empty we just remove the Reduction operation
+ if (axes_vector.empty()) {
+ return replace_output_update_name(reduce->output(0), input);
+ }
+
+    // As this transformation requires a static input shape, we have to check for it
+ if (input.get_partial_shape().is_dynamic()) {
+ return false;
+ }
+ auto input_shape = input.get_shape();
+
+    // If the Reduce op reduces only dimensions of size 1, we replace it with a Reshape
+ if (std::all_of(axes_vector.begin(), axes_vector.end(),
+ [&input_shape](const int64_t & axis) { return input_shape[axis] == 1; })) {
+ const auto reshape_shape = reduce->output(0).get_shape();
+ auto reshape = std::make_shared<ngraph::opset1::Reshape>(input,
+ ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{reshape_shape.size()}, reshape_shape), true);
+
+ reshape->set_friendly_name(reduce->get_friendly_name());
+ copy_runtime_info(reduce, reshape);
+ replace_node(reduce, reshape);
+ return true;
+ }
- // As this transformation requires static input shape we should guaranty it
- if (input.get_partial_shape().is_dynamic()) {
+    // Check that axes are consecutive; otherwise this transformation is not applicable
+ for (size_t i = 1; i < axes_vector.size(); ++i) {
+ if (axes_vector[i] - axes_vector[i-1] != 1) {
return false;
}
- auto input_shape = input.get_shape();
-
- // If Reduce op reduces only 1 dims we replace it with Reshape
- if (std::all_of(axes_vector.begin(), axes_vector.end(),
- [&input_shape](const int64_t & axis) { return input_shape[axis] == 1; })) {
- const auto reshape_shape = reduce->output(0).get_shape();
- auto reshape = std::make_shared<ngraph::opset1::Reshape>(input,
- ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{reshape_shape.size()}, reshape_shape), true);
-
- reshape->set_friendly_name(reduce->get_friendly_name());
- copy_runtime_info(reduce, reshape);
- replace_node(reduce, reshape);
- return true;
+ }
+
+    // Check whether the reduction applies only to spatial dimensions
+ bool spatial_dims_reduction(true);
+ size_t reduction_dims_count = 1;
+ for (auto& axis : axes_vector) {
+ reduction_dims_count *= input_shape[axis];
+ if (axis <= 1) {
+ spatial_dims_reduction = false;
}
-
- // Check that axes are consecutive otherwise this transformation is not applicable
- for (size_t i = 1; i < axes_vector.size(); ++i) {
- if (axes_vector[i] - axes_vector[i-1] != 1) {
- return false;
+ }
+
+ /*
+ * Prepare default attributes for Pooling operation
+ * pads_begin/pads_end - should be zeros as we don't need any padding
+ * stride - should be filled with ones
+ * kernel - depends on Reduction operation axes
+ *
+     * Also here we decide whether we should use Reshapes before and after Pooling
+ * shape_begin - if not empty indicates that we need a Reshape before Pooling
+ * shape_end - if not empty indicates that we need a Reshape after Pooling
+ */
+
+ ngraph::Strides strides;
+ ngraph::Shape pads_begin, pads_end, kernel, shape_begin, shape_end;
+
+ if (!spatial_dims_reduction || input_shape.size() != 4) {
+        // If the reduction does not apply to spatial dimensions (or the input is not 4D)
+        // we have to fit it into a 4D Pooling
+ size_t dims_prod = 1, dims_begin = 1, dims_end = 1;
+ for (size_t i = 0; i < input_shape.size(); ++i) {
+ if (i < *axes_vector.begin()) {
+ dims_begin *= input_shape[i];
+ } else if (i >= axes_vector.front() && i <= axes_vector.back()) {
+ dims_prod *= input_shape[i];
+ } else {
+ dims_end *= input_shape[i];
}
}
-
- // Check either reduction applies to spatial dimensions or not
- bool spatial_dims_reduction(true);
- size_t reduction_dims_count = 1;
+        // The batch dimension is repositioned in the shape
+        // only in the case of batch dimension reduction
+ shape_begin.assign({dims_begin, 1, dims_prod, dims_end});
+ shape_end = reduce->output(0).get_shape();
+ strides.assign({1, 1});
+ pads_begin.assign({0, 0});
+ pads_end.assign({0, 0});
+ kernel.assign({dims_prod, 1});
+ } else {
+ for (size_t i = 0; i < input_shape.size() - 2; ++i) {
+ strides.push_back(1);
+ pads_begin.push_back(0);
+ pads_end.push_back(0);
+ kernel.push_back(1);
+ }
for (auto& axis : axes_vector) {
- reduction_dims_count *= input_shape[axis];
- if (axis <= 1) {
- spatial_dims_reduction = false;
- }
+ kernel[axis-2] = input_shape[axis];
}
-
- /*
- * Prepare default attributes for Pooling operation
- * pads_begin/pads_end - should be zeros as we don't need any padding
- * stride - should be filled with ones
- * kernel - depends on Reduction operation axes
- *
- * Also here we decide should we use Reshapes before and after Pooling
- * shape_begin - if not empty indicates that we need a Reshape before Pooling
- * shape_end - if not empty indicates that we need a Reshape after Pooling
- */
-
- Strides strides;
- Shape pads_begin, pads_end, kernel, shape_begin, shape_end;
-
- if (!spatial_dims_reduction || input_shape.size() != 4) {
- // In case if reduction applies not to spatial dimensions
- // we have to fit it into 4D Pooling
- size_t dims_prod = 1, dims_begin = 1, dims_end = 1;
- for (size_t i = 0; i < input_shape.size(); ++i) {
- if (i < *axes_vector.begin()) {
- dims_begin *= input_shape[i];
- } else if (i >= axes_vector.front() && i <= axes_vector.back()) {
- dims_prod *= input_shape[i];
- } else {
- dims_end *= input_shape[i];
- }
- }
- // The batch dimenstion is repositioned in the shape
- // only in case of batch dimension reduction
- shape_begin.assign({dims_begin, 1, dims_prod, dims_end});
+ if (!reduce->get_keep_dims()) {
shape_end = reduce->output(0).get_shape();
- strides.assign({1, 1});
- pads_begin.assign({0, 0});
- pads_end.assign({0, 0});
- kernel.assign({dims_prod, 1});
- } else {
- for (size_t i = 0; i < input_shape.size() - 2; ++i) {
- strides.push_back(1);
- pads_begin.push_back(0);
- pads_end.push_back(0);
- kernel.push_back(1);
- }
- for (auto& axis : axes_vector) {
- kernel[axis-2] = input_shape[axis];
- }
- if (!reduce->get_keep_dims()) {
- shape_end = reduce->output(0).get_shape();
- }
- }
-
- /*
- * ReduceMean => AvgPool
- * AvgPool->Reshape (in case if keep_dims=False)
- * Reshape->AvgPool->Reshape (in case if axes doesn't match spatial dims)
-
- * ReduceMax => MaxPool
- * MaxPool->Reshape (in case if keep_dims=False)
- * Reshape->MaxPool->Reshape (in case if axes doesn't match spatial dims)
- *
- * ReduceSum => AvgPool->Multiply
- * AvgPool->Multiply->Reshape (in case if keep_dims=False)
- * Reshape->AvgPool->Multiply->Reshape (in case if axes doesn't match spatial dims)
- *
- * Note: some of reshape nodes can be optimized if they do nothing.
- */
- NodeVector new_ops;
-
- if (!shape_begin.empty() && shape_begin != input.get_shape()) {
- input = std::make_shared<ngraph::opset1::Reshape>(input, opset1::Constant::create(element::i64, Shape{shape_begin.size()}, shape_begin), true);
- input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/reshape_begin");
- new_ops.push_back(input.get_node_shared_ptr());
}
-
- if (std::is_same<T, ngraph::opset1::ReduceMean>()) {
- input = std::make_shared<ngraph::opset1::AvgPool>(input,
- strides,
- pads_begin,
- pads_end,
- kernel,
- true,
- op::RoundingType::FLOOR);
-
- input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/pool");
- new_ops.push_back(input.get_node_shared_ptr());
- } else if (std::is_same<T, ngraph::opset1::ReduceMax>()) {
- input = std::make_shared<ngraph::opset1::MaxPool>(input,
- strides,
- pads_begin,
- pads_end,
- kernel,
- op::RoundingType::FLOOR);
-
- input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/pool");
- new_ops.push_back(input.get_node_shared_ptr());
- } else if (std::is_same<T, ngraph::opset1::ReduceSum>()) {
- input = std::make_shared<ngraph::opset1::AvgPool>(input,
- strides,
- pads_begin,
- pads_end,
- kernel,
- true,
- op::RoundingType::FLOOR);
-
- input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/pool");
- new_ops.push_back(input.get_node_shared_ptr());
-
- input = std::make_shared<ngraph::opset1::Multiply>(input,
- opset1::Constant::create(reduce->input(0).get_element_type(), Shape{1}, {reduction_dims_count}));
- input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/mul");
- new_ops.push_back(input.get_node_shared_ptr());
- } else {
- return false;
- }
-
- if (!shape_end.empty() && shape_end != input.get_shape()) {
- input = std::make_shared<ngraph::opset1::Reshape>(input, opset1::Constant::create(element::i64, Shape{shape_end.size()}, shape_end), true);
- new_ops.push_back(input.get_node_shared_ptr());
- }
- input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name());
- copy_runtime_info(reduce, new_ops);
- reduce->output(0).replace(input);
- return true;
- };
-
- auto m = std::make_shared<ngraph::pattern::Matcher>(reduce, "ConvertReduceToPooling");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
- }
+ }
+
+ /*
+ * ReduceMean => AvgPool
+ * AvgPool->Reshape (in case if keep_dims=False)
+ * Reshape->AvgPool->Reshape (in case if axes doesn't match spatial dims)
+
+ * ReduceMax => MaxPool
+ * MaxPool->Reshape (in case if keep_dims=False)
+ * Reshape->MaxPool->Reshape (in case if axes doesn't match spatial dims)
+ *
+ * ReduceSum => AvgPool->Multiply
+ * AvgPool->Multiply->Reshape (in case if keep_dims=False)
+ * Reshape->AvgPool->Multiply->Reshape (in case if axes doesn't match spatial dims)
+ *
+     * Note: some of the Reshape nodes can be optimized out if they do nothing.
+ */
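+    // new_ops collects every node created below so that runtime info from the original Reduce
+    // can be copied onto the whole replacement subgraph at the end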
+ ngraph::NodeVector new_ops;
+
+ if (!shape_begin.empty() && shape_begin != input.get_shape()) {
+ input = std::make_shared<ngraph::opset1::Reshape>(input,
+ ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{shape_begin.size()}, shape_begin), true);
+ input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/reshape_begin");
+ new_ops.push_back(input.get_node_shared_ptr());
+ }
+
+ if (std::is_same<T, ngraph::opset1::ReduceMean>()) {
+ input = std::make_shared<ngraph::opset1::AvgPool>(input,
+ strides,
+ pads_begin,
+ pads_end,
+ kernel,
+ true,
+ ngraph::op::RoundingType::FLOOR);
+
+ input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/pool");
+ new_ops.push_back(input.get_node_shared_ptr());
+ } else if (std::is_same<T, ngraph::opset1::ReduceMax>()) {
+ input = std::make_shared<ngraph::opset1::MaxPool>(input,
+ strides,
+ pads_begin,
+ pads_end,
+ kernel,
+ ngraph::op::RoundingType::FLOOR);
+
+ input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/pool");
+ new_ops.push_back(input.get_node_shared_ptr());
+ } else if (std::is_same<T, ngraph::opset1::ReduceSum>()) {
+ input = std::make_shared<ngraph::opset1::AvgPool>(input,
+ strides,
+ pads_begin,
+ pads_end,
+ kernel,
+ true,
+ ngraph::op::RoundingType::FLOOR);
+
+ input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/pool");
+ new_ops.push_back(input.get_node_shared_ptr());
+
+ input = std::make_shared<ngraph::opset1::Multiply>(input,
+ ngraph::opset1::Constant::create(reduce->input(0).get_element_type(), ngraph::Shape{1}, {reduction_dims_count}));
+ input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/mul");
+ new_ops.push_back(input.get_node_shared_ptr());
+ } else {
+ return false;
+ }
+
+ if (!shape_end.empty() && shape_end != input.get_shape()) {
+ input = std::make_shared<ngraph::opset1::Reshape>(input,
+ ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{shape_end.size()}, shape_end), true);
+ new_ops.push_back(input.get_node_shared_ptr());
+ }
+ input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name());
+ copy_runtime_info(reduce, new_ops);
+ reduce->output(0).replace(input);
+ return true;
+ };
}
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
-#include "transformations/utils/pass_param.hpp"
namespace ngraph {
namespace pass {
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertSpaceToBatch: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam {
+class ngraph::pass::ConvertSpaceToBatch: public ngraph::pass::GraphRewrite {
public:
- ConvertSpaceToBatch() : GraphRewrite(), PassParam() {
+ ConvertSpaceToBatch() : GraphRewrite() {
// convert_space_to_batch();
convert_space_to_batch_by_elements();
}
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
-#include "transformations/utils/pass_param.hpp"
namespace ngraph {
namespace pass {
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertSpaceToDepth: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam {
+class ngraph::pass::ConvertSpaceToDepth: public ngraph::pass::MatcherPass {
public:
- ConvertSpaceToDepth() : GraphRewrite(), PassParam() {
- convert();
- }
-
-private:
- void convert();
+ ConvertSpaceToDepth();
};
} // namespace pass
} // namespace ngraph
-class ngraph::pass::ConvertSubtract: public ngraph::pass::GraphRewrite {
+class ngraph::pass::ConvertSubtract: public ngraph::pass::MatcherPass {
public:
- ConvertSubtract() : GraphRewrite() {
- convert_subtract();
- }
-
-private:
- void convert_subtract();
+ ConvertSubtract();
};
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
-#include "transformations/utils/pass_param.hpp"
namespace ngraph {
namespace pass {
* p.run_on_function(f);
*
*/
-class ngraph::pass::DepthToSpaceFusion: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam {
+
+class ngraph::pass::DepthToSpaceFusion: public ngraph::pass::GraphRewrite {
public:
- DepthToSpaceFusion() : GraphRewrite(), PassParam() {
+ DepthToSpaceFusion() : GraphRewrite() {
depth_to_space_fusion();
}
--- /dev/null
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <memory>
+#include <utility>
+
+#include <transformations_visibility.hpp>
+#include <ngraph/pass/graph_rewrite.hpp>
+
+namespace ngraph {
+namespace pass {
+
+class TRANSFORMATIONS_API LinOpSequenceFusion;
+class TRANSFORMATIONS_API AddMultiplyFusion;
+class TRANSFORMATIONS_API AddAddFusion;
+class TRANSFORMATIONS_API MultiplyMultiplyFusion;
+
+} // namespace pass
+} // namespace ngraph
+
+class ngraph::pass::LinOpSequenceFusion: public ngraph::pass::GraphRewrite {
+public:
+ LinOpSequenceFusion() {
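+        // Fuses sequences of linear Add/Multiply operations by running the matchers below
+        // inside a single GraphRewrite. A sketch of how such a pass is typically executed
+        // (assuming an ngraph::Function `f` is already prepared):
+        //   ngraph::pass::Manager manager;
+        //   manager.register_pass<ngraph::pass::LinOpSequenceFusion>();
+        //   manager.run_passes(f);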
+ add_matcher<ngraph::pass::AddMultiplyFusion>();
+ add_matcher<ngraph::pass::AddAddFusion>();
+ add_matcher<ngraph::pass::MultiplyMultiplyFusion>();
+ }
+};
+
+class ngraph::pass::AddMultiplyFusion: public ngraph::pass::MatcherPass {
+public:
+ AddMultiplyFusion();
+};
+
+class ngraph::pass::AddAddFusion: public ngraph::pass::MatcherPass {
+public:
+ AddAddFusion();
+};
+
+class ngraph::pass::MultiplyMultiplyFusion: public ngraph::pass::MatcherPass {
+public:
+ MultiplyMultiplyFusion();
+};
} // namespace pass
} // namespace ngraph
-class ngraph::pass::PullTransposeThroughFQUp: public ngraph::pass::GraphRewrite {
+class ngraph::pass::PullTransposeThroughFQUp: public ngraph::pass::MatcherPass {
public:
- PullTransposeThroughFQUp() : GraphRewrite() {
- pull_transpose_through_fq();
- }
-
-private:
- void pull_transpose_through_fq();
+ PullTransposeThroughFQUp();
};
+++ /dev/null
-// Copyright (C) 2020 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#pragma once
-
-#include <memory>
-#include <transformations_visibility.hpp>
-#include <ngraph/pass/graph_rewrite.hpp>
-
-namespace ngraph {
-namespace pass {
-
-class TRANSFORMATIONS_API PassParam;
-
-} // namespace pass
-} // namespace ngraph
-
-class ngraph::pass::PassParam {
-public:
- using param_callback = std::function<bool(const std::shared_ptr<const ::ngraph::Node>)>;
-
- explicit PassParam(const param_callback & callback = getDefaultCallback()) : transformation_callback(callback) {}
-
- void setCallback(const param_callback & callback) {
- transformation_callback = callback;
- }
-
- static param_callback getDefaultCallback() {
- return [](const std::shared_ptr<const Node> &) -> bool {
- return false;
- };
- }
-
-protected:
- param_callback transformation_callback;
-};
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/rt_info.hpp>
-void ngraph::pass::BatchNormDecomposition::batch_norm_decomposition() {
+ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() {
Shape shape{2, 2, 1, 1};
auto input = make_shared<pattern::op::Label>(element::f32, shape);
auto mean_shape = Shape{2};
};
auto m = std::make_shared<ngraph::pattern::Matcher>(bn, "BatchNormDecomposition");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
\ No newline at end of file
bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::Function> f) {
- ngraph::pass::Manager CommonOptimizations;
- std::vector<std::shared_ptr<ngraph::pass::PassBase> > transforms;
+ ngraph::pass::Manager manager;
-#define NGRAPH_PASS(NAME, NAMESPACE) transforms.push_back(CommonOptimizations.register_pass<NAMESPACE::NAME>());
-#include <transformations/common_optimizations/common_optimizations_tbl.hpp>
-#undef NGRAPH_PASS
+    // This pass must be called first in the pipeline
+ manager.register_pass<ngraph::pass::InitNodeInfo>();
+ manager.register_pass<ngraph::pass::RemoveFilteringBoxesBySize>(); // Resolves dynamism (replaces NonZero), CF needed
+ manager.register_pass<ngraph::pass::ConstantFolding>();
+ manager.register_pass<ngraph::pass::StridedSliceOptimization>(); // depends on CF
+ manager.register_pass<ngraph::pass::NopElimination>(); // may introduce fake dynamism
+ manager.register_pass<ngraph::pass::AlgebraicSimplification>(); // may introduce fake dynamism
+ manager.register_pass<ngraph::pass::ConstantFolding>();
+ manager.register_pass<ngraph::pass::ConvertScatterElementsToScatter>(); // partially depends on CF
+ manager.register_pass<ngraph::pass::DepthToSpaceFusion>();
- for (auto & t : transforms) {
- if (auto t_param = std::dynamic_pointer_cast<PassParam>(t)) {
- t_param->setCallback(transformation_callback);
- }
- }
- CommonOptimizations.run_passes(f);
+ manager.set_callback(m_transformation_callback);
+ manager.run_passes(f);
return true;
}
+++ /dev/null
-// Copyright (C) 2018-2020 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#include "transformations/constant_eltwise_reduction.hpp"
-
-#include <memory>
-#include <vector>
-
-#include <ngraph/opsets/opset1.hpp>
-
-template <typename T>
-std::shared_ptr<ngraph::op::Constant> constant_reduction(const std::shared_ptr<ngraph::op::Constant>& const_node) {
- std::vector<T> data = const_node->get_vector<T>();
- // TODO: implement this function after eltwise broadcast support will be added
- return nullptr;
-}
-
-ngraph::graph_rewrite_callback callback = [](ngraph::pattern::Matcher& m) {
- // Check that eltwise operation either Add or Multiply
- auto eltwise_node = m.get_match_root();
- if (!std::dynamic_pointer_cast<ngraph::opset1::Add>(eltwise_node) &&
- !std::dynamic_pointer_cast<ngraph::opset1::Multiply>(eltwise_node)) {
- return false;
- }
-
- for (const auto& input : eltwise_node->inputs()) {
- const auto& inputLayer = input.get_source_output().get_node_shared_ptr();
- auto const_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(inputLayer);
- if (!const_node) continue;
-
- std::shared_ptr<ngraph::opset1::Constant> result = nullptr;
-
- ngraph::element::Type elem_type = const_node->get_element_type();
- switch (elem_type) {
- case ngraph::element::Type_t::f32:
- result = constant_reduction<float>(const_node);
- break;
- case ngraph::element::Type_t::i16:
- case ngraph::element::Type_t::f16:
- result = constant_reduction<short>(const_node);
- break;
- case ngraph::element::Type_t::u8:
- result = constant_reduction<uint8_t>(const_node);
- break;
- case ngraph::element::Type_t::i8:
- result = constant_reduction<int8_t>(const_node);
- break;
- case ngraph::element::Type_t::i32:
- result = constant_reduction<int32_t>(const_node);
- break;
- default:
- return false;
- }
- if (result) {
- ngraph::replace_node(inputLayer, std::dynamic_pointer_cast<ngraph::Node>(result));
- std::cout << "Successful constant_eltwise reduction" << std::endl;
- }
- }
-
- return true;
-};
-
-void ngraph::pass::ConstantEltwiseReduction::constant_multiply_reduction() {
- Shape shape {2, 2, 1, 1};
- auto constant1 = std::make_shared<pattern::op::Label>(element::f32, shape);
- auto constant2 = std::make_shared<pattern::op::Label>(element::f32, shape);
- auto eltwise = std::make_shared<ngraph::opset1::Multiply>(constant1, constant2);
-
- auto m = std::make_shared<ngraph::pattern::Matcher>(eltwise, "CPUFusion.ConstantMultiplyReduction");
- this->add_matcher(m, callback);
-}
-
-void ngraph::pass::ConstantEltwiseReduction::constant_add_reduction() {
- Shape shape {2, 2, 1, 1};
- auto constant1 = std::make_shared<pattern::op::Label>(element::f32, shape);
- auto constant2 = std::make_shared<pattern::op::Label>(element::f32, shape);
- auto eltwise = std::make_shared<ngraph::opset1::Add>(constant1, constant2);
-
- auto m = std::make_shared<ngraph::pattern::Matcher>(eltwise, "CPUFusion.ConstantAddReduction");
- this->add_matcher(m, callback);
-}
auto data = batch_to_space->input_value(0);
auto data_shape = data.get_shape();
- if (transformation_callback(batch_to_space) && (data_shape.size() == 4 || data_shape.size() == 5)) {
+ if (m_transformation_callback(batch_to_space) && (data_shape.size() == 4 || data_shape.size() == 5)) {
return false;
}
auto block = batch_to_space->input_value(1);
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
-void ngraph::pass::ConvertBroadcastToTiles::convert_broadcast_to_tiles() {
- auto weights = std::make_shared<pattern::op::Label>(element::f32, Shape {1});
- auto shp = std::make_shared<pattern::op::Label>(element::i64, Shape {1});
- auto axs = std::make_shared<pattern::op::Label>(element::i64, Shape {1});
- auto broadcast = std::make_shared<ngraph::opset1::Broadcast>(weights, shp, axs);
+ngraph::pass::ConvertBroadcastToTiles::ConvertBroadcastToTiles() {
+ auto broadcast = ngraph::pattern::wrap_type<ngraph::opset1::Broadcast>();
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto broadcast = std::dynamic_pointer_cast<ngraph::opset1::Broadcast>(m.get_match_root());
};
auto m = std::make_shared<ngraph::pattern::Matcher>(broadcast, "ConvertBroadcastToTile");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
-void ngraph::pass::ConvertDepthToSpace::convert_depth_to_space() {
- auto input0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
- auto dts_node = std::make_shared<ngraph::opset1::DepthToSpace>(input0, ngraph::op::DepthToSpace::DepthToSpaceMode::DEPTH_FIRST);
+ngraph::pass::ConvertDepthToSpace::ConvertDepthToSpace() {
+ auto dts_node = ngraph::pattern::wrap_type<ngraph::opset1::DepthToSpace>();
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
auto dts_node = std::dynamic_pointer_cast<ngraph::opset1::DepthToSpace> (m.get_match_root());
- if (!dts_node || transformation_callback(dts_node)) {
+ if (!dts_node || m_transformation_callback(dts_node)) {
return false;
}
};
auto m = std::make_shared<ngraph::pattern::Matcher>(dts_node, "ConvertDepthToSpace");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
-void ngraph::pass::ConvertDivide::convert_divide() {
- auto input0 = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1, 1});
- auto input1 = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1, 1});
- auto div = std::make_shared<ngraph::opset1::Divide>(input0, input1);
+ngraph::pass::ConvertDivide::ConvertDivide() {
+ auto div = ngraph::pattern::wrap_type<ngraph::opset1::Divide>();
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto div = std::dynamic_pointer_cast<ngraph::opset1::Divide> (m.get_match_root());
};
auto m = std::make_shared<ngraph::pattern::Matcher>(div, "ConvertDivide");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
\ No newline at end of file
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
-void ngraph::pass::ConvertMinimum::convert_minimum() {
- auto input0 = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1, 1});
- auto input1 = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1, 1});
- auto minimum = std::make_shared<ngraph::opset1::Minimum>(input0, input1);
+ngraph::pass::ConvertMinimum::ConvertMinimum() {
+ auto minimum = ngraph::pattern::wrap_type<opset1::Minimum>();
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto minimum = std::dynamic_pointer_cast<ngraph::opset1::Minimum> (m.get_match_root());
};
auto m = std::make_shared<ngraph::pattern::Matcher>(minimum, "ConvertMinimum");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
\ No newline at end of file
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
-void ngraph::pass::ConvertMod::convert_mod() {
- auto input0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
- auto input1 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
- auto mod = std::make_shared<ngraph::opset1::Mod>(input0, input1);
+ngraph::pass::ConvertMod::ConvertMod() {
+ auto mod = ngraph::pattern::wrap_type<opset1::Mod>();
- ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+ ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
auto mod = std::dynamic_pointer_cast<ngraph::opset1::Mod> (m.get_match_root());
if (!mod) {
return false;
const auto divisor = std::make_shared<opset1::Abs>(mod->input_value(1));
// truncated(a / b)
- auto div = std::make_shared<opset1::Divide>(dividend, divisor);
+ auto div = register_new_node<opset1::Divide>(dividend, divisor);
auto convert_to_i64 = std::make_shared<opset1::Convert>(div, ngraph::element::i64);
auto convert = std::make_shared<opset1::Convert>(convert_to_i64, dividend_et);
// truncated(a / b) * b
auto multiplication = std::make_shared<opset1::Multiply>(convert, divisor);
// a mod b = a - truncated(a / b) * b
- auto sub = std::make_shared<opset1::Subtract>(dividend, multiplication);
+ auto sub = register_new_node<opset1::Subtract>(dividend, multiplication);
// apply sign of dividend
auto mul = std::make_shared<opset1::Multiply>(dividend_sign, sub);
};
auto m = std::make_shared<ngraph::pattern::Matcher>(mod, "ConvertMod");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
}
\ No newline at end of file
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
-void ngraph::pass::ConvertNegative::convert_negative() {
- auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{1});
- auto neg = std::make_shared<ngraph::opset1::Negative>(input);
+ngraph::pass::ConvertNegative::ConvertNegative() {
+ auto neg = ngraph::pattern::wrap_type<ngraph::opset1::Negative>();
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto neg = std::dynamic_pointer_cast<ngraph::opset1::Negative> (m.get_match_root());
};
auto m = std::make_shared<ngraph::pattern::Matcher>(neg, "ConvertNegative");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
\ No newline at end of file
--- /dev/null
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/convert_opset1_to_legacy/conv_bias_fusion.hpp"
+
+#include <memory>
+#include <vector>
+
+#include <ngraph/opsets/opset1.hpp>
+#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
+
+#include <ngraph_ops/convolution_ie.hpp>
+#include <ngraph_ops/deconvolution_ie.hpp>
+
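+// Shared fusion callback for the matchers below: folds a constant Add into the convolution bias
+// and a constant Multiply into the convolution weights (and bias, if present)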
+template <class Conv>
+ngraph::graph_rewrite_callback get_callback() {
+ ngraph::graph_rewrite_callback callback = [](ngraph::pattern::Matcher &m) {
+ auto eltwise = m.get_match_root();
+
+ std::shared_ptr<ngraph::opset1::Constant> m_const;
+ std::shared_ptr<Conv> m_conv;
+ // FIXME: use auto [m_conv, m_const] when C++17 is available
+ std::tie(m_conv, m_const) = parse_eltwise_inputs<Conv, ngraph::opset1::Constant>(eltwise);
+ if (!m_conv || !m_const) {
+ return false;
+ }
+
+        // TODO: check that the constant can be a scalar and does not match the [1, C, 1, 1] layout
+ const auto constant_shape = m_const->get_shape();
+ const auto output_pshape = m_conv->get_output_partial_shape(0);
+
+ if (output_pshape.rank().is_dynamic() || output_pshape[1].is_dynamic()) {
+ return false;
+ }
+
+ const auto channel_dim = output_pshape[1].get_length();
+
+ size_t constant_size = std::accumulate(constant_shape.begin(), constant_shape.end(), 1, std::multiplies<size_t>());
+ if (constant_size != channel_dim) {
+ return false;
+ }
+
+ ngraph::Output<ngraph::Node> constant(m_const);
+
+ if (constant_shape.size() > 1) {
+ constant = std::make_shared<ngraph::opset1::Reshape>(constant,
+ ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {channel_dim}), true);
+ }
+
+ if (m_conv->output(0).get_target_inputs().size() != 1) {
+ return false;
+ }
+
+ ngraph::Output<ngraph::Node> new_conv, new_weights, new_bias;
+ if (std::dynamic_pointer_cast<ngraph::opset1::Add>(eltwise)) {
+ // Fuse: ConvolutionIE/DeconvolutionIE->Add
+ if (m_conv->inputs().size() == 2) {
+ new_bias = constant;
+ } else {
+ new_bias = std::make_shared<ngraph::opset1::Add>(constant, m_conv->input_value(2));
+ }
+ new_conv = m_conv->clone_with_new_inputs({m_conv->input_value(0), m_conv->input_value(1), new_bias});
+ } else if (std::is_same<Conv, ngraph::op::ConvolutionIE>() && std::dynamic_pointer_cast<ngraph::opset1::Multiply>(eltwise)) {
+ // Fuse: ConvolutionIE->Mul
+ auto weights_shape = m_conv->input(1).get_shape();
+
+ ngraph::Shape const_shape(weights_shape.size(), 1);
+ const_shape[0] = weights_shape[0];
+
+ auto const_reshape = std::make_shared<ngraph::opset1::Reshape>(constant,
+ ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{const_shape.size()}, const_shape), true);
+ new_weights = std::make_shared<ngraph::opset1::Multiply> (m_conv->input_value(1), const_reshape);
+ if (m_conv->inputs().size() == 2) {
+ new_conv = m_conv->clone_with_new_inputs({m_conv->input_value(0), new_weights});
+ } else {
+ auto bias_reshape = std::make_shared<ngraph::opset1::Reshape>(constant,
+ ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {weights_shape[0]}), true);
+                new_bias = std::make_shared<ngraph::opset1::Multiply>(bias_reshape, m_conv->input_value(2));
+ new_conv = m_conv->clone_with_new_inputs({m_conv->input_value(0), new_weights, new_bias});
+ }
+ } else {
+ return false;
+ }
+
+ ngraph::copy_runtime_info({m_conv, eltwise}, new_conv.get_node_shared_ptr());
+ new_conv.get_node_shared_ptr()->set_friendly_name(m.get_match_root()->get_friendly_name());
+ ngraph::replace_node(m.get_match_root(), new_conv.get_node_shared_ptr());
+ return true;
+ };
+ return callback;
+}
+
+ngraph::pass::ConvAddFusion::ConvAddFusion() {
+ auto conv = ngraph::pattern::wrap_type<op::ConvolutionIE>();
+ auto add = ngraph::pattern::wrap_type<opset1::Add>({conv, std::make_shared<pattern::op::Label>()});
+
+ matcher_pass_callback callback = get_callback<op::ConvolutionIE>();
+
+ auto m = std::make_shared<ngraph::pattern::Matcher>(add, "ConvAddFusion");
+ register_matcher(m, callback);
+}
+
+ngraph::pass::ConvMultiplyFusion::ConvMultiplyFusion() {
+ auto conv = ngraph::pattern::wrap_type<op::ConvolutionIE>();
+ auto add = ngraph::pattern::wrap_type<opset1::Multiply>({conv, std::make_shared<pattern::op::Label>()});
+
+ matcher_pass_callback callback = get_callback<op::ConvolutionIE>();
+
+ auto m = std::make_shared<ngraph::pattern::Matcher>(add, "ConvMultiplyFusion");
+ register_matcher(m, callback);
+}
+
+ngraph::pass::DeconvAddFusion::DeconvAddFusion() {
+ auto conv = ngraph::pattern::wrap_type<op::DeconvolutionIE>();
+ auto add = ngraph::pattern::wrap_type<opset1::Add>({conv, std::make_shared<pattern::op::Label>()});
+
+ matcher_pass_callback callback = get_callback<op::DeconvolutionIE>();
+
+ auto m = std::make_shared<ngraph::pattern::Matcher>(add, "DeconvAddFusion");
+ register_matcher(m, callback);
+}
\ No newline at end of file
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph_ops/lstm_cell_ie.hpp>
#include <ngraph_ops/gru_cell_ie.hpp>
#include <ngraph_ops/rnn_cell_ie.hpp>
-void ngraph::pass::ConvertCellsToCellsIE::convert_lstm_cell() {
- // placeholders
- auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1}); // X
- auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1}); // initial_hidden_state
- auto input_2 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1}); // initial_cell_state
- auto input_3 = std::make_shared<pattern::op::Label>(element::f32, Shape{4, 1}); // W
- auto input_4 = std::make_shared<pattern::op::Label>(element::f32, Shape{4, 1}); // R
- auto input_5 = std::make_shared<pattern::op::Label>(element::f32, Shape{4}); // B
+ngraph::pass::ConvertLSTMCellMatcher::ConvertLSTMCellMatcher() {
+ auto lstm_cell_ngraph = ngraph::pattern::wrap_type<ngraph::opset1::LSTMCell>();
- auto lstm_cell_ngraph = std::make_shared<ngraph::opset1::LSTMCell>(input_0, input_1, input_2, input_3, input_4, input_5, 1);
-
- ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+ ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto lstm_cell = std::dynamic_pointer_cast<ngraph::opset1::LSTMCell> (m.get_match_root());
if (!lstm_cell) {
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(lstm_cell_ngraph, "ConvertLSTMCellToLSTMCellIE");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
-void ngraph::pass::ConvertCellsToCellsIE::convert_gru_cell() {
- // placeholders
- auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1}); // X
- auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1}); // initial_hidden_state
- auto input_2 = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 1}); // W
- auto input_3 = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 1}); // R
- auto input_4 = std::make_shared<pattern::op::Label>(element::f32, Shape{3}); // B
-
- auto gru_cell_ngraph = std::make_shared<ngraph::opset3::GRUCell>(input_0, input_1, input_2, input_3, input_4, 1);
+ngraph::pass::ConvertGRUCellMatcher::ConvertGRUCellMatcher() {
+ auto gru_cell_ngraph = ngraph::pattern::wrap_type<ngraph::opset3::GRUCell>();
- ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+ ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto gru_cell = std::dynamic_pointer_cast<ngraph::opset3::GRUCell> (m.get_match_root());
if (!gru_cell) {
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(gru_cell_ngraph, "ConvertGRUCellToGRUCellIE");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
-void ngraph::pass::ConvertCellsToCellsIE::convert_rnn_cell() {
- // placeholders
- auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1}); // X
- auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1}); // initial_hidden_state
- auto input_2 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1}); // W
- auto input_3 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1}); // R
- auto input_4 = std::make_shared<pattern::op::Label>(element::f32, Shape{1}); // B
-
- auto rnn_cell_ngraph = std::make_shared<ngraph::opset3::RNNCell>(input_0, input_1, input_2, input_3, input_4, 1);
+ngraph::pass::ConvertRNNCellMatcher::ConvertRNNCellMatcher() {
+ auto rnn_cell_ngraph = ngraph::pattern::wrap_type<ngraph::opset3::RNNCell>();
- ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+ ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto rnn_cell = std::dynamic_pointer_cast<ngraph::opset3::RNNCell> (m.get_match_root());
if (!rnn_cell) {
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(rnn_cell_ngraph, "ConvertRNNCellToRNNCellIE");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
\ No newline at end of file
#include <ngraph_ops/convolution_ie.hpp>
#include <ngraph_ops/deconvolution_ie.hpp>
-void ngraph::pass::ConvertConvolutions::convert_convolution() {
- auto conv = std::make_shared<pattern::op::Label>(element::f32, Shape{},
- pattern::has_class<opset1::Convolution>());
+#include <ngraph/pattern/op/wrap_type.hpp>
- ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+ngraph::pass::ConvertConvolution::ConvertConvolution() {
+ auto conv = ngraph::pattern::wrap_type<opset1::Convolution>();
+
+ ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto conv = std::dynamic_pointer_cast<ngraph::opset1::Convolution> (m.get_match_root());
if (!conv) {
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(conv, "ConvertConvolution");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
-void ngraph::pass::ConvertConvolutions::convert_group_convolution() {
- auto gconv = std::make_shared<pattern::op::Label>(element::f32, Shape{},
- pattern::has_class<opset1::GroupConvolution>());
+ngraph::pass::ConvertGroupConvolution::ConvertGroupConvolution() {
+ auto gconv = ngraph::pattern::wrap_type<opset1::GroupConvolution>();
- ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+ ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto gconv = std::dynamic_pointer_cast<ngraph::opset1::GroupConvolution> (m.get_match_root());
if (!gconv) {
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(gconv, "ConvertGroupConvolution");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
-void ngraph::pass::ConvertConvolutions::convert_convolution_backprop_data() {
- auto conv = std::make_shared<pattern::op::Label>(element::f32, Shape{},
- pattern::has_class<opset1::ConvolutionBackpropData>());
+ngraph::pass::ConvertDeconvolution::ConvertDeconvolution() {
+ auto conv = ngraph::pattern::wrap_type<opset1::ConvolutionBackpropData>();
- ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+ ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto deconv = std::dynamic_pointer_cast<ngraph::opset1::ConvolutionBackpropData> (m.get_match_root());
if (!deconv) {
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(conv, "ConvertConvolutionBackpropData");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
-void ngraph::pass::ConvertConvolutions::convert_group_convolution_backprop_data() {
- auto gconv = std::make_shared<pattern::op::Label>(element::f32, Shape{},
- pattern::has_class<opset1::GroupConvolutionBackpropData>());
+ngraph::pass::ConvertGroupDeconvolution::ConvertGroupDeconvolution() {
+ auto gconv = ngraph::pattern::wrap_type<opset1::GroupConvolutionBackpropData>();
- ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+ ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto gconv = std::dynamic_pointer_cast<ngraph::opset1::GroupConvolutionBackpropData> (m.get_match_root());
if (!gconv) {
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(gconv, "ConvertGroupConvolutionBackpropData");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
-void ngraph::pass::ConvertGatherToGatherIE::convert_gather_to_gather_ie() {
- auto gather = std::make_shared<pattern::op::Label>(element::f32, Shape{1}, pattern::has_class<opset1::Gather>());
+ngraph::pass::ConvertGatherToGatherIEMatcher::ConvertGatherToGatherIEMatcher() {
+ auto gather = ngraph::pattern::wrap_type<opset1::Gather>();
- ngraph::graph_rewrite_callback callback = [](pattern::Matcher &m) {
+ ngraph::matcher_pass_callback callback = [](pattern::Matcher &m) {
auto gather = std::dynamic_pointer_cast<ngraph::opset1::Gather>(m.get_match_root());
if (!gather) {
return false;
};
auto m1 = std::make_shared<ngraph::pattern::Matcher>(gather, "ConvertGatherToGatherIE");
- this->add_matcher(m1, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m1, callback);
}
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/rt_info.hpp>
-void ngraph::pass::ConvertGatherTreeToGatherTreeIE::convert() {
+ngraph::pass::ConvertGatherTreeToGatherTreeIEMatcher::ConvertGatherTreeToGatherTreeIEMatcher() {
auto input0 = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1});
auto input1 = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1});
auto input2 = std::make_shared<pattern::op::Label>(element::i64, Shape{1});
auto input3 = std::make_shared<pattern::op::Label>(element::i64, Shape{});
auto gt = std::make_shared<ngraph::opset1::GatherTree>(input0, input1, input2, input3);
- ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+ ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto gt = std::dynamic_pointer_cast<ngraph::opset1::GatherTree> (m.get_match_root());
if (!gt) {
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(gt, "ConvertGatherTreeToGatherTreeIE");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
\ No newline at end of file
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
auto gelu = std::dynamic_pointer_cast<ngraph::opset2::Gelu>(m.get_match_root());
- if (!gelu || transformation_callback(gelu))
+ if (!gelu || m_transformation_callback(gelu))
return false;
auto input = gelu->input_value(0);
auto input_type = input.get_element_type();
#include <ngraph_ops/hard_sigmoid_ie.hpp>
-void ngraph::pass::ConvertHardSigmoidToHardSigmoidIE::convert_hard_sigmoid() {
+ngraph::pass::ConvertHardSigmoidToLegacyMatcher::ConvertHardSigmoidToLegacyMatcher() {
auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto input_2 = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto node = std::make_shared<ngraph::opset1::HardSigmoid>(input_0, input_1, input_2);
- ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+ ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto hard_sigmoid = std::dynamic_pointer_cast<ngraph::opset1::HardSigmoid> (m.get_match_root());
if (!hard_sigmoid) {
return false;
return true;
};
- auto m = std::make_shared<ngraph::pattern::Matcher>(node, "ConvertHardSigmoidToHardSigmoidIE");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ auto m = std::make_shared<ngraph::pattern::Matcher>(node, "ConvertHardSigmoidToLegacy");
+ this->register_matcher(m, callback);
}
\ No newline at end of file
#include <ngraph_ops/interp.hpp>
-void ngraph::pass::ConvertInterpolateToInterpOrResample::convert_interpolate_to_interp_or_resample() {
+ngraph::pass::ConvertInterpolateToInterpOrResampleMatcher::ConvertInterpolateToInterpOrResampleMatcher() {
auto data = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
auto shp = std::make_shared<pattern::op::Label>(element::i64, Shape{2});
auto interpolate = std::make_shared<ngraph::opset1::Interpolate>(data, shp, ngraph::op::InterpolateAttrs());
- ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+ ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto interpolate = std::dynamic_pointer_cast<ngraph::opset1::Interpolate> (m.get_match_root());
if (!interpolate)
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(interpolate, "ConvertInterpolateToInterpOrResample");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
\ No newline at end of file
#include <ngraph_ops/lrn_ie.hpp>
-void ngraph::pass::ConvertLRNToLRNIE::convert_lrn() {
+ngraph::pass::ConvertLRNToLegacyMatcher::ConvertLRNToLegacyMatcher() {
auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
auto input_1 = std::make_shared<pattern::op::Label>(element::i64, Shape{1});
auto lrn = std::make_shared<ngraph::opset1::LRN>(input_0, input_1, 1, 1, 1, 1);
- ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+ ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto lrn = std::dynamic_pointer_cast<ngraph::opset1::LRN> (m.get_match_root());
if (!lrn) {
return false;
return true;
};
- auto m = std::make_shared<ngraph::pattern::Matcher>(lrn, "ConvertLRNToLRNIE");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ auto m = std::make_shared<ngraph::pattern::Matcher>(lrn, "ConvertLRNToLegacy");
+ this->register_matcher(m, callback);
}
#include <transformations/utils/utils.hpp>
-void ngraph::pass::ConvertMatMulToFCorGemm::convert_matmul() {
+ngraph::pass::ConvertMatMulToFCorGemm::ConvertMatMulToFCorGemm() {
auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape {1, 1});
auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape {1, 1});
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input_0, input_1);
- ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+ ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
auto matmul = std::dynamic_pointer_cast<ngraph::opset1::MatMul>(m.get_match_root());
if (!matmul) {
return false;
* order will be [0, 1, 3, 2] that emulates transpose_a or transpose_b attribute.
*/
- auto create_transpose = [](Output<Node> node, const std::string& transpose_name) -> std::shared_ptr<Node> {
+ auto create_transpose = [this](Output<Node> node, const std::string& transpose_name) -> std::shared_ptr<Node> {
Shape output_shape = node.get_node_shared_ptr()->get_shape();
std::vector<size_t> transpose_order(output_shape.size());
std::iota(transpose_order.begin(), transpose_order.end(), 0);
std::swap(*(transpose_order.end() - 1), *(transpose_order.end() - 2));
- auto transpose = std::make_shared<ngraph::opset1::Transpose>(
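+ // register_new_node records the created Transpose inside the MatcherPass, so follow-up
+ // matchers in the same GraphRewrite can be applied to it as well.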
+ auto transpose = register_new_node<ngraph::opset1::Transpose>(
node, opset1::Constant::create(element::i64, Shape {transpose_order.size()}, transpose_order));
transpose->set_friendly_name(transpose_name);
return transpose;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(matmul, "ConvertMatMulToFCorGemm");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
#include "transformations/convert_opset1_to_legacy/convert_nms_4_to_legacy.hpp"
-void ngraph::pass::ConvertNMS4ToLegacy::convert_nms4_to_legacy() {
+ngraph::pass::ConvertNMS4ToLegacyMatcher::ConvertNMS4ToLegacyMatcher() {
auto boxes = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1000, 4});
auto scores = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1000});
auto max_output_boxes_per_class = ngraph::opset4::Constant::create(element::i64, Shape{}, {10});
auto nms = std::make_shared<ngraph::opset4::NonMaxSuppression>(boxes, scores, max_output_boxes_per_class,
iou_threshold, score_threshold);
- ngraph::graph_rewrite_callback callback = [](pattern::Matcher &m) {
+ ngraph::matcher_pass_callback callback = [](pattern::Matcher &m) {
auto nms_4 = std::dynamic_pointer_cast<ngraph::opset4::NonMaxSuppression>(m.get_match_root());
if (!nms_4) {
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(nms, "ConvertNMS4ToNMSLegacy");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
#include <ngraph_ops/nms_ie.hpp>
#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
-void ngraph::pass::ConvertNMSToNMSIE::convert_nms_to_nms_ie() {
- auto nms = std::make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<opset1::NonMaxSuppression>());
+ngraph::pass::ConvertNMSToNMSIEMatcher::ConvertNMSToNMSIEMatcher() {
+ auto nms = ngraph::pattern::wrap_type<opset1::NonMaxSuppression>();
- ngraph::graph_rewrite_callback callback = [](pattern::Matcher &m) {
+ ngraph::matcher_pass_callback callback = [](pattern::Matcher &m) {
auto nms = std::dynamic_pointer_cast<opset1::NonMaxSuppression>(m.get_match_root());
if (!nms) {
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(nms, "ConvertNMSToNMSIE");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
\ No newline at end of file
#include "ngraph_ops/normalize_ie.hpp"
-void ngraph::pass::ConvertNormalizeL2WithMulToNormalizeIE::convert_normalize_l2_with_mul() {
+ngraph::pass::ConvertNormalizeL2WithMulToNormalizeIE::ConvertNormalizeL2WithMulToNormalizeIE() {
auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
auto axis = std::make_shared<ngraph::opset1::Constant>(element::i64, Shape{1}, std::vector<int64_t>{0});
};
auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "CPUFusion.ConvertNormalizeL2WithMulToNormalizeIE");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
-void ngraph::pass::ConvertNormalizeL2ToNormalizeIE::convert_normalize_l2() {
+ngraph::pass::ConvertNormalizeL2ToLegacyMatcher::ConvertNormalizeL2ToLegacyMatcher() {
auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
auto axis = std::make_shared<ngraph::opset1::Constant>(element::i64, Shape{1}, std::vector<int64_t>{0});
auto normalize = std::make_shared<ngraph::op::NormalizeL2>(input_0, axis, 0, ngraph::op::EpsMode::ADD);
- ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+ ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto normalize = std::dynamic_pointer_cast<ngraph::op::NormalizeL2> (m.get_match_root());
if (!normalize) return false;
return true;
};
- auto m = std::make_shared<ngraph::pattern::Matcher>(normalize, "CPUFusion.ConvertNormalizeL2ToNormalizeIE");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ auto m = std::make_shared<ngraph::pattern::Matcher>(normalize, "ConvertNormalizeL2ToNormalizeIE");
+ this->register_matcher(m, callback);
}
#include <transformations/utils/utils.hpp>
#include <ngraph/rt_info.hpp>
-void ngraph::pass::ConvertOneHotToOneHotIE::convert_one_hot() {
+ngraph::pass::ConvertOneHotToOneHotIEMatcher::ConvertOneHotToOneHotIEMatcher() {
auto input = std::make_shared<pattern::op::Label>(element::i32, Shape{1, 1, 1, 1});
auto depth = std::make_shared<pattern::op::Label>(element::i64, Shape{});
auto on_value = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto off_value = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto one_hot = std::make_shared<ngraph::opset1::OneHot>(input, depth, on_value, off_value, 1);
- ngraph::graph_rewrite_callback callback = [=](pattern::Matcher& m) {
+ ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto one_hot = std::dynamic_pointer_cast<ngraph::opset1::OneHot> (m.get_match_root());
if (!one_hot) {
return false;
}
- element::Type output_type = is_f16 ? element::f16 : element::f32;
-
- const auto depth_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(one_hot->input(1).get_source_output().get_node_shared_ptr());
- const auto on_value_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(one_hot->input(2).get_source_output().get_node_shared_ptr());
- const auto off_value_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(one_hot->input(3).get_source_output().get_node_shared_ptr());
+ const auto depth_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(one_hot->input_value(1).get_node_shared_ptr());
+ const auto on_value_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(one_hot->input_value(2).get_node_shared_ptr());
+ const auto off_value_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(one_hot->input_value(3).get_node_shared_ptr());
// can be converted only if the depth and on/off value inputs are constants
if (depth_node == nullptr || on_value_node == nullptr || off_value_node == nullptr) return false;
auto off_value = std::stof(off_value_node->convert_value_to_string(0));
auto one_hot_ie = std::make_shared<ngraph::op::OneHotIE>(one_hot->input_value(0),
- one_hot->get_axis(), depth_value, on_value, off_value, output_type);
+ one_hot->get_axis(), depth_value, on_value, off_value, m_output_type);
one_hot_ie->set_friendly_name(one_hot->get_friendly_name());
// insert a Convert layer to cast the output to the data type defined by the on/off values
- if (on_value_node->get_element_type() != output_type) {
+ if (on_value_node->get_element_type() != m_output_type) {
auto convert = std::make_shared<ngraph::opset1::Convert>(one_hot_ie, on_value_node->get_element_type());
convert->set_friendly_name(one_hot->get_friendly_name() + "/Convert");
ngraph::copy_runtime_info(one_hot, {one_hot_ie, convert});
};
auto m = std::make_shared<ngraph::pattern::Matcher>(one_hot, "ConvertOneHotToOneHotIE");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
-bool ngraph::pass::ConvertOneHotToOneHotIE::run_on_function(std::shared_ptr<ngraph::Function> f) {
- is_f16 = ngraph::op::util::has_f16_constants(f);
- return GraphRewrite::run_on_function(f);
-}
+void ngraph::pass::ConvertOneHotToOneHotIEMatcher::detect_output_type(const std::shared_ptr<ngraph::Function> &f) {
+ m_output_type = ngraph::op::util::has_f16_constants(f) ? element::Type_t::f16 : element::Type_t::f32;
+}
\ No newline at end of file
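Since `ConvertOneHotToOneHotIEMatcher` no longer overrides `run_on_function`, the output element type has to be detected explicitly before the passes are executed. A minimal wiring sketch, mirroring the registration inside `ConvertOpSet1ToLegacy` below (`f` is the `ngraph::Function` being converted):

ngraph::pass::Manager manager;
auto rewrite = manager.register_pass<ngraph::pass::GraphRewrite>();
rewrite->add_matcher<ngraph::pass::ConvertOneHotToOneHotIEMatcher>()->detect_output_type(f);
manager.run_passes(f);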
#include "transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.hpp"
-#include <transformations/constant_eltwise_reduction.hpp>
#include <transformations/convert_broadcast_to_tiles.hpp>
#include <transformations/convert_opset1_to_legacy/convert_convolutions.hpp>
#include <transformations/convert_divide.hpp>
#include <transformations/convert_opset1_to_legacy/reshape_fully_connected.hpp>
#include <transformations/pull_transpose_through_fq.hpp>
#include <transformations/convert_opset1_to_legacy/convert_hard_sigmoid_to_hard_sigmoid_ie.hpp>
+#include <transformations/lin_op_sequence_fusoin.hpp>
#include <ngraph/pass/constant_folding.hpp>
#include <ngraph/pass/manager.hpp>
+#include <ngraph/pass/graph_rewrite.hpp>
+
#include <memory>
#include <vector>
bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph::Function> f) {
- ngraph::pass::Manager OpSet1ToLegacy;
+ ngraph::pass::Manager manager;
std::vector<std::shared_ptr<ngraph::pass::PassBase> > transforms;
-#define NGRAPH_PASS(NAME, NAMESPACE) transforms.push_back(OpSet1ToLegacy.register_pass<NAMESPACE::NAME>());
-#include <transformations/convert_opset1_to_legacy/convert_opset1_to_legacy_tbl.hpp>
-#undef NGRAPH_PASS
+ manager.register_pass<ngraph::pass::ConstantFolding>();
+
+ // List of Decomposition and Conversion transformations that can be
+ // applied simultaneously in a single graph traversal
+ auto decomp = manager.register_pass<ngraph::pass::GraphRewrite>();
+ decomp->add_matcher<ngraph::pass::ConvertBroadcastToTiles>();
+ decomp->add_matcher<ngraph::pass::ConvertReduceMeanToPooling>();
+ decomp->add_matcher<ngraph::pass::ConvertReduceMaxToPooling>();
+ decomp->add_matcher<ngraph::pass::ConvertReduceSumToPooling>();
+ decomp->add_matcher<ngraph::pass::ConvertMod>();
+ decomp->add_matcher<ngraph::pass::ConvertMinimum>();
+ decomp->add_matcher<ngraph::pass::ConvertSubtract>();
+ decomp->add_matcher<ngraph::pass::ConvertDivide>();
+ decomp->add_matcher<ngraph::pass::ConvertNegative>();
+ decomp->add_matcher<ngraph::pass::ConvertDepthToSpace>();
+ decomp->add_matcher<ngraph::pass::ConvertSpaceToDepth>();
+ decomp->add_matcher<ngraph::pass::ConvertConvolution>();
+ decomp->add_matcher<ngraph::pass::ConvertGroupConvolution>();
+ decomp->add_matcher<ngraph::pass::ConvertDeconvolution>();
+ decomp->add_matcher<ngraph::pass::ConvertGroupDeconvolution>();
+ decomp->add_matcher<ngraph::pass::BatchNormDecomposition>();
+ decomp->add_matcher<ngraph::pass::ConvertMatMulToFCorGemm>();
+ decomp->add_matcher<ngraph::pass::PullTransposeThroughFQUp>();
+ decomp->set_name("ngraph::pass::Decompositions");
+
+ // CF is required after all decompositions
+ manager.register_pass<ngraph::pass::ConstantFolding>();
+
+ // LinOpSequenceFusion must be executed after all decompositions
+ manager.register_pass<ngraph::pass::LinOpSequenceFusion>();
+
+ // Convolution/Deconvolution/FullyConnected fusions
+ auto fusion = manager.register_pass<ngraph::pass::GraphRewrite>();
+ fusion->add_matcher<ngraph::pass::ConvAddFusion>();
+ fusion->add_matcher<ngraph::pass::ConvMultiplyFusion>();
+ fusion->add_matcher<ngraph::pass::DeconvAddFusion>();
+ fusion->add_matcher<ngraph::pass::FullyConnectedBiasFusion>();
+ fusion->set_name("ngraph::pass::Fusions");
+
+ // CF is required after fusions
+ manager.register_pass<ngraph::pass::ConstantFolding>();
+
+ // List of passes that convert opset1 operations to legacy ones,
+ // plus transformations that are required by InferenceEngine.
+ // All these transformations can be executed simultaneously
+ auto anchor = manager.register_pass<ngraph::pass::GraphRewrite>();
+ anchor->add_matcher<ngraph::pass::ReshapeFullyConnected>();
+ anchor->add_matcher<ngraph::pass::Reshape1DConvolution>();
+ anchor->add_matcher<ngraph::pass::Reshape1DAvgPool>();
+ anchor->add_matcher<ngraph::pass::Reshape1DMaxPool>();
+ anchor->add_matcher<ngraph::pass::ConvertNormalizeL2WithMulToNormalizeIE>();
+ anchor->add_matcher<ngraph::pass::ConvertHardSigmoidToLegacyMatcher>();
+ anchor->add_matcher<ngraph::pass::ConvertProposalToLegacyMatcher>();
+ anchor->add_matcher<ngraph::pass::ConvertTileToLegacyMatcher>();
+ anchor->add_matcher<ngraph::pass::ConvertLRNToLegacyMatcher>();
+ anchor->add_matcher<ngraph::pass::ConvertPadToLegacyMatcher>();
+ anchor->add_matcher<ngraph::pass::ConvertLSTMCellMatcher>();
+ anchor->add_matcher<ngraph::pass::ConvertRNNCellMatcher>();
+ anchor->add_matcher<ngraph::pass::ConvertGRUCellMatcher>();
+ anchor->add_matcher<ngraph::pass::ConvertInterpolateToInterpOrResampleMatcher>();
+ anchor->add_matcher<ngraph::pass::ConvertStridedSliceToCropMatcher>();
+ anchor->add_matcher<ngraph::pass::ConvertPowerToPowerIEMatcher>();
+ anchor->add_matcher<ngraph::pass::ConvertSqrtToPowerIEMatcher>();
+ anchor->add_matcher<ngraph::pass::ConvertPReLUToReLUIE>();
+ anchor->add_matcher<ngraph::pass::ConvertGatherToGatherIEMatcher>();
+ anchor->add_matcher<ngraph::pass::ConvertSeluToSeluIEMatcher>();
+ anchor->add_matcher<ngraph::pass::ConvertOneHotToOneHotIEMatcher>()->detect_output_type(f);
+ anchor->add_matcher<ngraph::pass::ConvertGatherTreeToGatherTreeIEMatcher>();
+ anchor->add_matcher<ngraph::pass::ConvertTopKToTopKIEMatcher>();
+ anchor->add_matcher<ngraph::pass::ConvertNMSToNMSIEMatcher>();
+ anchor->add_matcher<ngraph::pass::ConvertNMS4ToLegacyMatcher>();
+ anchor->set_name("ngraph::pass::ConvertOpSet1ToLegacy");
+
+ // List of final conversion transformations that must be executed
+ // after the previous group of transformations
+ manager.register_pass<ngraph::pass::ReshapeFullyConnectedFusion>();
+ manager.register_pass<ngraph::pass::ConvertNormalizeL2ToLegacyMatcher>();
+ manager.register_pass<ngraph::pass::ConvertMulAddToScaleShiftOrPower>();
+ manager.register_pass<ngraph::pass::ConvertMulOrAddFinally>();
+
+ manager.register_pass<ngraph::pass::ConstantFolding>();
- for (auto & t : transforms) {
- if (auto t_param = std::dynamic_pointer_cast<PassParam>(t)) {
- t_param->setCallback(transformation_callback);
- }
- }
- OpSet1ToLegacy.run_passes(f);
+ manager.set_callback(m_transformation_callback);
+ manager.run_passes(f);
return true;
}
\ No newline at end of file
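For reference, a minimal sketch of how a plugin feeds its callback into this pipeline (based on the `Manager` usage shown elsewhere in this change; `f` and the lambda body are illustrative). Passes that consult `m_transformation_callback` inside their matcher callbacks typically skip nodes for which the callback returns true, leaving them to the plugin:

ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
manager.set_callback([](const std::shared_ptr<const ngraph::Node> & node) -> bool {
    // Illustrative: keep SpaceToDepth operations as-is so the plugin can execute them natively
    return std::dynamic_pointer_cast<const ngraph::opset1::SpaceToDepth>(node) != nullptr;
});
manager.run_passes(f);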
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
-void ngraph::pass::ConvertPadToPadIE::convert_pad() {
- auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
- auto input_1 = std::make_shared<pattern::op::Label>(element::i64, Shape{4});
- auto input_2 = std::make_shared<pattern::op::Label>(element::i64, Shape{4});
- auto input_3 = std::make_shared<pattern::op::Label>(element::f32, Shape{});
- auto pad_1 = std::make_shared<ngraph::opset1::Pad>(input_0, input_1, input_2, op::PadMode::SYMMETRIC);
- auto pad_2 = std::make_shared<ngraph::opset1::Pad>(input_0, input_1, input_2, input_3, op::PadMode::CONSTANT);
+ngraph::pass::ConvertPadToLegacyMatcher::ConvertPadToLegacyMatcher() {
+ auto m_pad = ngraph::pattern::wrap_type<ngraph::opset1::Pad>();
-
- ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+ ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto pad = std::dynamic_pointer_cast<ngraph::opset1::Pad> (m.get_match_root());
if (!pad) {
return false;
}
auto pad_ie = std::make_shared<ngraph::op::PadIE>(pad);
- if (pad_ie == nullptr)
- return false;
pad_ie->set_friendly_name(pad->get_friendly_name());
ngraph::copy_runtime_info(pad, pad_ie);
ngraph::replace_node(pad, pad_ie);
return true;
};
- auto m1 = std::make_shared<ngraph::pattern::Matcher>(pad_1, "ConvertPadToPadIE");
- this->add_matcher(m1, callback, PassProperty::CHANGE_DYNAMIC_STATE);
-
- auto m2 = std::make_shared<ngraph::pattern::Matcher>(pad_2, "ConvertPadToPadIE");
- this->add_matcher(m2, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ auto m = std::make_shared<ngraph::pattern::Matcher>(m_pad, "ConvertPadToLegacy");
+ this->register_matcher(m, callback);
}
\ No newline at end of file
#include <transformations/utils/utils.hpp>
#include <ngraph/rt_info.hpp>
-void ngraph::pass::ConvertPowerToPowerIE::convert_power() {
+ngraph::pass::ConvertPowerToPowerIEMatcher::ConvertPowerToPowerIEMatcher() {
auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1});
auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{1});
auto power = std::make_shared<ngraph::opset1::Power>(input_0, input_1);
- ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+ ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto power = std::dynamic_pointer_cast<ngraph::opset1::Power> (m.get_match_root());
if (!power) {
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(power, "ConvertPowerToPowerIE");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
\ No newline at end of file
#include <transformations/utils/utils.hpp>
#include <ngraph/rt_info.hpp>
-void ngraph::pass::ConvertPReLUToReLUIE::convert_prelu() {
+ngraph::pass::ConvertPReLUToReLUIE::ConvertPReLUToReLUIE() {
auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1});
auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{1});
auto prelu = std::make_shared<ngraph::opset1::PRelu>(input_0, input_1);
- ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+ ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto prelu = std::dynamic_pointer_cast<ngraph::opset1::PRelu> (m.get_match_root());
if (!prelu) {
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(prelu, "ConvertPReLUToReLUIE");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
\ No newline at end of file
#include <ngraph_ops/proposal_ie.hpp>
#include <ngraph/rt_info.hpp>
-void ngraph::pass::ConvertProposalToProposalIE::convert_proposal() {
+ngraph::pass::ConvertProposalToLegacyMatcher::ConvertProposalToLegacyMatcher() {
auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
auto input_2 = std::make_shared<pattern::op::Label>(element::f32, Shape{3});
auto proposal = std::make_shared<ngraph::opset1::Proposal>(input_0, input_1, input_2, attr);
- ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+ ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto proposal = std::dynamic_pointer_cast<ngraph::opset1::Proposal> (m.get_match_root());
if (!proposal) {
return true;
};
- auto m = std::make_shared<ngraph::pattern::Matcher>(proposal, "CPUFusion.ConvertProposalToProposalIE");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ auto m = std::make_shared<ngraph::pattern::Matcher>(proposal, "ConvertProposalToProposalIE");
+ this->register_matcher(m, callback);
}
\ No newline at end of file
#include <transformations/utils/utils.hpp>
#include <ngraph/rt_info.hpp>
-void ngraph::pass::ConvertSeluToSeluIE::convert_selu() {
+ngraph::pass::ConvertSeluToSeluIEMatcher::ConvertSeluToSeluIEMatcher() {
auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1});
auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{1});
auto input_2 = std::make_shared<pattern::op::Label>(element::f32, Shape{1});
auto selu = std::make_shared<ngraph::opset1::Selu>(input_0, input_1, input_2);
- ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+ ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto selu = std::dynamic_pointer_cast<ngraph::opset1::Selu> (m.get_match_root());
if (!selu) {
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(selu, "ConvertSeluToSeluIE");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
\ No newline at end of file
#include <transformations/utils/utils.hpp>
#include <ngraph/rt_info.hpp>
-void ngraph::pass::ConvertSqrtToPowerIE::convert_sqrt() {
+ngraph::pass::ConvertSqrtToPowerIEMatcher::ConvertSqrtToPowerIEMatcher() {
auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1});
auto sqrt = std::make_shared<ngraph::opset1::Sqrt>(input_0);
- ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+ ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto sqrt = std::dynamic_pointer_cast<ngraph::opset1::Sqrt>(m.get_match_root());
if (!sqrt) {
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(sqrt, "ConvertSqrtToPowerIE");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
#include <ngraph_ops/crop_ie.hpp>
#include <ngraph/rt_info.hpp>
-void ngraph::pass::ConvertStridedSliceToCrop::convert_strided_slice_to_crop() {
+ngraph::pass::ConvertStridedSliceToCropMatcher::ConvertStridedSliceToCropMatcher() {
auto data = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
auto m_begin = std::make_shared<pattern::op::Label>(element::i64, Shape{2});
auto m_end = std::make_shared<pattern::op::Label>(element::i64, Shape{2});
std::vector<int64_t> end_mask = {0, 0, 0, 0};
auto m_slice = std::make_shared<ngraph::opset1::StridedSlice>(data, m_begin, m_end, m_stride, begin_mask, end_mask);
- ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+ ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto slice = std::dynamic_pointer_cast<ngraph::opset1::StridedSlice> (m.get_match_root());
if (!slice) {
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(m_slice, "ConvertStridedSliceToCrop");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
#include <ngraph_ops/tile_ie.hpp>
#include <ngraph/rt_info.hpp>
-void ngraph::pass::ConvertTileToIETile::convert_tile() {
+ngraph::pass::ConvertTileToLegacyMatcher::ConvertTileToLegacyMatcher() {
auto data = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
auto shp = std::make_shared<pattern::op::Label>(element::i64, Shape{4});
auto tile = std::make_shared<ngraph::opset1::Tile>(data, shp);
- ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+ ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto tile = std::dynamic_pointer_cast<ngraph::opset1::Tile> (m.get_match_root());
if (!tile) {
return false;
return true;
};
- auto m = std::make_shared<ngraph::pattern::Matcher>(tile, "CPUFusion.ConvertTileToIETiles");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ auto m = std::make_shared<ngraph::pattern::Matcher>(tile, "ConvertTileToIETiles");
+ this->register_matcher(m, callback);
}
#include <ngraph_ops/topk_ie.hpp>
#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
-void ngraph::pass::ConvertTopKToTopKIE::convert_topk_to_topk_ie() {
- auto topk = std::make_shared<pattern::op::Label>(element::f32, Shape{1}, pattern::has_class<opset1::TopK>());
+ngraph::pass::ConvertTopKToTopKIEMatcher::ConvertTopKToTopKIEMatcher() {
+ auto topk = ngraph::pattern::wrap_type<opset1::TopK>();
- ngraph::graph_rewrite_callback callback = [](pattern::Matcher &m) {
+ ngraph::matcher_pass_callback callback = [](pattern::Matcher &m) {
auto topk = std::dynamic_pointer_cast<opset1::TopK>(m.get_match_root());
if (!topk || topk->input(1).get_partial_shape().rank().is_dynamic()) {
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(topk, "ConvertTopKToTopKIE");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
--- /dev/null
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/convert_opset1_to_legacy/fc_bias_fusion.hpp"
+
+#include <memory>
+#include <vector>
+
+#include <ngraph/opsets/opset1.hpp>
+#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
+
+ngraph::pass::FullyConnectedBiasFusion::FullyConnectedBiasFusion() {
+ auto fc = ngraph::pattern::wrap_type<op::FullyConnected>();
+ auto add = ngraph::pattern::wrap_type<opset1::Add>({fc, std::make_shared<pattern::op::Label>()});
+
+ ngraph::graph_rewrite_callback callback = [](pattern::Matcher &m) {
+ auto add = m.get_match_root();
+ auto add_input_0 = add->input(0).get_source_output().get_node_shared_ptr();
+ auto add_input_1 = add->input(1).get_source_output().get_node_shared_ptr();
+
+ auto m_fc = std::dynamic_pointer_cast<op::FullyConnected>(add_input_0);
+ auto m_bias = add_input_1;
+
+ if (m_fc == nullptr) {
+ m_fc = std::dynamic_pointer_cast<op::FullyConnected>(add_input_1);
+ m_bias = add_input_0;
+ }
+
+ if (auto bcast_m = std::dynamic_pointer_cast<opset1::Broadcast>(m_bias)) {
+ m_bias = bcast_m->input(0).get_source_output().get_node_shared_ptr();
+ }
+
+ if (!std::dynamic_pointer_cast<opset1::Constant>(m_bias)) {
+ return false;
+ }
+ Shape bias_shape(m_bias->get_shape());
+
+ if (m_fc->output(0).get_target_inputs().size() != 1) {
+ return false;
+ }
+
+ Shape output_shape(m_fc->get_shape());
+ size_t bias_size = std::accumulate(bias_shape.begin(), bias_shape.end(), 1, std::multiplies<int64_t>());
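+ // The bias must be effectively one-dimensional with one value per output channel:
+ // e.g. for an FC output shape [1, 256], bias shapes [256] and [1, 256] can be fused,
+ // while [256, 1] or a per-element bias is rejected.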
+ if (bias_shape.empty() || bias_shape.back() != output_shape.back() || bias_shape.back() != bias_size) {
+ return false;
+ }
+
+ NodeVector new_ops;
+
+ auto new_bias = std::make_shared<opset1::Add>(m_fc->input(2).get_source_output(), m_bias);
+ new_ops.push_back(new_bias);
+ std::shared_ptr<Node> final_bias = new_bias;
+ if (new_bias->get_shape().size() >= 2) {
+ final_bias = std::make_shared<opset1::Reshape>(final_bias, opset1::Constant::create(element::i64, Shape{1}, {-1}), true);
+ new_ops.push_back(final_bias);
+ }
+
+ auto new_fc = std::make_shared<op::FullyConnected>(m_fc->input(0).get_source_output(),
+ m_fc->input(1).get_source_output(),
+ final_bias,
+ m_fc->get_shape());
+ new_ops.push_back(new_fc);
+
+ new_fc->set_friendly_name(add->get_friendly_name());
+ ngraph::copy_runtime_info({m_fc, add}, new_ops);
+ ngraph::replace_node(add, new_fc);
+ return true;
+ };
+
+ auto m = std::make_shared<ngraph::pattern::Matcher>(add, "FullyConnectedBiasFusion");
+ this->register_matcher(m, callback);
+}
\ No newline at end of file
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
#include "ngraph_ops/convolution_ie.hpp"
#include "transformations/utils/utils.hpp"
node->get_auto_pad());
}
-void ngraph::pass::Reshape1DOps::reshape_ops() {
- auto node = std::make_shared<pattern::op::Label>(element::f32, Shape{},
- [](const std::shared_ptr<Node> & node) {
- return std::dynamic_pointer_cast<op::ConvolutionIE>(node) ||
- std::dynamic_pointer_cast<opset1::MaxPool>(node) ||
- std::dynamic_pointer_cast<opset1::AvgPool>(node);});
-
- ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+matcher_pass_callback get_callback() {
+ return [](pattern::Matcher& m) {
auto node = m.get_match_root();
if (!node || node->input(0).get_partial_shape().rank().get_length() != 3) {
return false;
node->output(0).replace(last);
return true;
};
+}
+
+ngraph::pass::Reshape1DConvolution::Reshape1DConvolution() {
+ auto conv = ngraph::pattern::wrap_type<op::ConvolutionIE>();
+ auto m = std::make_shared<ngraph::pattern::Matcher>(conv, "Reshape1DConvolution");
+ this->register_matcher(m, get_callback());
+}
+
+ngraph::pass::Reshape1DAvgPool::Reshape1DAvgPool() {
+ auto pool = ngraph::pattern::wrap_type<opset1::AvgPool>();
+ auto m = std::make_shared<ngraph::pattern::Matcher>(pool, "Reshape1DAvgPool");
+ this->register_matcher(m, get_callback());
+}
- auto m = std::make_shared<ngraph::pattern::Matcher>(node, "Reshape1DOps");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ngraph::pass::Reshape1DMaxPool::Reshape1DMaxPool() {
+ auto pool = ngraph::pattern::wrap_type<opset1::MaxPool>();
+ auto m = std::make_shared<ngraph::pattern::Matcher>(pool, "Reshape1DMaxPool");
+ this->register_matcher(m, get_callback());
}
\ No newline at end of file
#include "ngraph_ops/fully_connected.hpp"
#include "transformations/utils/utils.hpp"
-void ngraph::pass::ReshapeFullyConnected::reshape_fully_connected() {
+ngraph::pass::ReshapeFullyConnected::ReshapeFullyConnected() {
auto input0 = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1});
auto input1 = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1});
auto input2 = std::make_shared<pattern::op::Label>(element::i64, Shape{1});
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
auto fc = std::dynamic_pointer_cast<ngraph::op::FullyConnected> (m.get_match_root());
- if (!fc || transformation_callback(fc)) {
+ if (!fc || m_transformation_callback(fc)) {
return false;
}
};
auto m = std::make_shared<ngraph::pattern::Matcher>(fc, "ReshapeFullyConnected");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
\ No newline at end of file
#include <ngraph/pass/manager.hpp>
bool ngraph::pass::ConvertOpSet2ToOpSet1::run_on_function(std::shared_ptr<ngraph::Function> f) {
- ngraph::pass::Manager OpSet2ToOpSet1;
- std::vector<std::shared_ptr<ngraph::pass::PassBase> > transforms;
+ ngraph::pass::Manager manager;
-#define NGRAPH_PASS(NAME, NAMESPACE) transforms.push_back(OpSet2ToOpSet1.register_pass<NAMESPACE::NAME>());
-#include <transformations/convert_opset2_to_opset1/convert_opset2_to_opset1_tbl.hpp>
-#undef NGRAPH_PASS
+ manager.register_pass<ngraph::pass::ConvertGELU>();
+ manager.register_pass<ngraph::pass::ConvertSpaceToBatch>();
+ manager.register_pass<ngraph::pass::ConvertBatchToSpace>();
- for (auto & t : transforms) {
- if (auto t_param = std::dynamic_pointer_cast<PassParam>(t)) {
- t_param->setCallback(transformation_callback);
- }
- }
- OpSet2ToOpSet1.run_passes(f);
+ manager.set_callback(m_transformation_callback);
+ manager.run_passes(f);
return true;
}
\ No newline at end of file
#include <ngraph/pass/manager.hpp>
bool ngraph::pass::ConvertOpSet3ToOpSet2::run_on_function(std::shared_ptr<ngraph::Function> f) {
- ngraph::pass::Manager OpSet3ToOpSet2;
- std::vector<std::shared_ptr<ngraph::pass::PassBase> > transforms;
+ ngraph::pass::Manager manager;
-#define NGRAPH_PASS(NAME, NAMESPACE) transforms.push_back(OpSet3ToOpSet2.register_pass<NAMESPACE::NAME>());
-#include <transformations/convert_opset3_to_opset2/convert_opset3_to_opset2_tbl.hpp>
-#undef NGRAPH_PASS
+ manager.register_pass<ngraph::pass::ConvertBroadcast3>();
+ manager.register_pass<ngraph::pass::ConvertNMS3>();
+ manager.register_pass<ngraph::pass::ConvertShapeOf3>();
+ manager.register_pass<ngraph::pass::ConvertShuffleChannels3>();
+ manager.register_pass<ngraph::pass::ConvertTopK3>();
- for (auto & t : transforms) {
- if (auto t_param = std::dynamic_pointer_cast<PassParam>(t)) {
- t_param->setCallback(transformation_callback);
- }
- }
- OpSet3ToOpSet2.run_passes(f);
+ manager.set_callback(m_transformation_callback);
+ manager.run_passes(f);
return true;
}
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher &m) {
auto shuffle_channels = std::dynamic_pointer_cast<::opset3::ShuffleChannels>(m.get_match_root());
- if (!shuffle_channels || transformation_callback(shuffle_channels)) {
+ if (!shuffle_channels || m_transformation_callback(shuffle_channels)) {
return false;
}
if (shuffle_channels->input_value(0).get_partial_shape().rank().is_dynamic()) {
auto data = space_to_batch->input_value(0);
const auto& data_shape = data.get_shape();
- if (transformation_callback(space_to_batch) && (data_shape.size() == 4 || data_shape.size() == 5)) {
+ if (m_transformation_callback(space_to_batch) && (data_shape.size() == 4 || data_shape.size() == 5)) {
return false;
}
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
-void ngraph::pass::ConvertSpaceToDepth::convert() {
- auto input0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
- auto dts = std::make_shared<ngraph::opset1::SpaceToDepth>(input0, ngraph::opset1::SpaceToDepth::SpaceToDepthMode::DEPTH_FIRST);
+ngraph::pass::ConvertSpaceToDepth::ConvertSpaceToDepth() {
+ auto dts = ngraph::pattern::wrap_type<ngraph::opset1::SpaceToDepth>();
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
auto std_node = std::dynamic_pointer_cast<ngraph::opset1::SpaceToDepth> (m.get_match_root());
- if (!std_node || transformation_callback(std_node)) {
+ if (!std_node || m_transformation_callback(std_node)) {
return false;
}
};
auto m = std::make_shared<ngraph::pattern::Matcher>(dts, "ConvertSpaceToDepth");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
\ No newline at end of file
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
-void ngraph::pass::ConvertSubtract::convert_subtract() {
- auto input0 = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1, 1});
- auto input1 = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1, 1});
- auto sub = std::make_shared<ngraph::opset1::Subtract>(input0, input1);
+ngraph::pass::ConvertSubtract::ConvertSubtract() {
+ auto sub = ngraph::pattern::wrap_type<ngraph::opset1::Subtract>();
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto sub = std::dynamic_pointer_cast<ngraph::opset1::Subtract> (m.get_match_root());
};
auto m = std::make_shared<ngraph::pattern::Matcher>(sub, "ConvertSubtract");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
\ No newline at end of file
depth_to_space->set_friendly_name(reshape_after->get_friendly_name());
ngraph::copy_runtime_info({reshape_before, permute, reshape_after}, depth_to_space);
- if (!transformation_callback(depth_to_space)) {
+ if (!m_transformation_callback(depth_to_space)) {
return false;
}
--- /dev/null
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/lin_op_sequence_fusoin.hpp"
+#include "transformations/mul_add_squence_fusion.hpp"
+
+#include <memory>
+#include <vector>
+
+#include <ngraph/opsets/opset3.hpp>
+#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
+
+using namespace ngraph;
+
+template <class T>
+Output<Node> eltwise_fold(const Output<Node> & input0, const Output<Node> & input1) {
+ auto eltwise = std::make_shared<T>(input0, input1);
+ OutputVector output(eltwise->get_output_size());
+ if (!eltwise->constant_fold(output, {input0, input1})) {
+ throw ngraph_error("Can not constant fold eltwise node");
+ }
+ if (output.size() != 1) {
+ throw ngraph_error("Eltwise constant fold has unexpected number of outputs: " + std::to_string(output.size()));
+ }
+ return output[0];
+}
+
+ngraph::pass::AddMultiplyFusion::AddMultiplyFusion() {
+ // Create Add->Multiply pattern where Add has exactly one consumer
+ auto m_data = ngraph::pattern::any_input();
+ auto m_add_constant = ngraph::pattern::wrap_type<opset3::Constant>();
+ auto m_add = ngraph::pattern::wrap_type<opset3::Add>({m_data, m_add_constant}, pattern::consumers_count(1));
+ auto m_mul_constant = ngraph::pattern::wrap_type<opset3::Constant>();
+ auto m_mul = ngraph::pattern::wrap_type<opset3::Multiply>({m_add, m_mul_constant});
+
+ ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher & m) -> bool {
+ auto & label_to_output = m.get_pattern_value_map();
+
+ auto mul = label_to_output[m_mul].get_node_shared_ptr();
+ auto add = label_to_output[m_add].get_node_shared_ptr();
+
+ Output<Node> input = label_to_output[m_data];
+ Output<Node> mul_const = label_to_output[m_mul_constant];
+ Output<Node> add_const = label_to_output[m_add_constant];
+
+ // Replace Add->Multiply with Multiply->Add
+ // As the new Multiply can be fused with the operation above it, we add this Multiply
+ // to the list of operations that will be used in additional matching.
+ auto new_mul = register_new_node<opset3::Multiply>(input, mul_const);
+
+ // Multiply the two constants ((x + a) * m == x * m + a * m) and create the new Add operation
+ auto new_add = std::make_shared<opset3::Add>(new_mul, eltwise_fold<opset3::Multiply>(add_const, mul_const));
+
+ copy_runtime_info({add, mul}, {new_mul, new_add});
+ new_add->set_friendly_name(mul->get_friendly_name());
+ replace_node(mul, new_add);
+ return true;
+ };
+
+ auto m = std::make_shared<ngraph::pattern::Matcher>(m_mul, "AddMultiplyFusion");
+ this->register_matcher(m, callback);
+}
+
+ngraph::pass::AddAddFusion::AddAddFusion() {
+ // Create Add->Add pattern where first Add has exactly one consumer
+ auto m_data = ngraph::pattern::any_input();
+ auto m_add1_constant = ngraph::pattern::wrap_type<opset3::Constant>();
+ auto m_add1 = ngraph::pattern::wrap_type<opset3::Add>({m_data, m_add1_constant}, pattern::consumers_count(1));
+ auto m_add2_constant = ngraph::pattern::wrap_type<opset3::Constant>();
+ auto m_add2 = ngraph::pattern::wrap_type<opset3::Add>({m_add1, m_add2_constant});
+
+ ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher & m) -> bool {
+ auto & label_to_output = m.get_pattern_value_map();
+
+ auto add1 = label_to_output[m_add1].get_node_shared_ptr();
+ auto add2 = label_to_output[m_add2].get_node_shared_ptr();
+
+ Output<Node> input = label_to_output[m_data];
+ Output<Node> add1_const = label_to_output[m_add1_constant];
+ Output<Node> add2_const = label_to_output[m_add2_constant];
+
+ // Replace Add->Add with a single Add
+ // The new Add operation will be added to the list of ops requested for pattern matching
+ auto new_add = register_new_node<opset3::Add>(input, eltwise_fold<opset3::Add>(add1_const, add2_const));
+
+ copy_runtime_info({add1, add2}, new_add);
+ new_add->set_friendly_name(add2->get_friendly_name());
+ replace_node(add2, new_add);
+ return true;
+ };
+
+ auto m = std::make_shared<ngraph::pattern::Matcher>(m_add2, "AddAddFusion");
+ this->register_matcher(m, callback);
+}
+
+ngraph::pass::MultiplyMultiplyFusion::MultiplyMultiplyFusion() {
+ // Create Multiply->Multiply pattern where first Multiply has exactly one consumer
+ auto m_data = ngraph::pattern::any_input();
+ auto m_mul1_constant = ngraph::pattern::wrap_type<opset3::Constant>();
+ auto m_mul1 = ngraph::pattern::wrap_type<opset3::Multiply>({m_data, m_mul1_constant}, pattern::consumers_count(1));
+ auto m_mul2_constant = ngraph::pattern::wrap_type<opset3::Constant>();
+ auto m_mul2 = ngraph::pattern::wrap_type<ngraph::opset3::Multiply>({m_mul1, m_mul2_constant});
+
+ ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher & m) -> bool {
+ auto & label_to_output = m.get_pattern_value_map();
+
+ auto mul1 = label_to_output[m_mul1].get_node_shared_ptr();
+ auto mul2 = label_to_output[m_mul2].get_node_shared_ptr();
+
+ Output<Node> input = label_to_output[m_data];
+ Output<Node> mul1_const = label_to_output[m_mul1_constant];
+ Output<Node> mul2_const = label_to_output[m_mul2_constant];
+
+ // Replace Multiply->Multiply with a single Multiply
+ // The new Multiply operation will be added to the list of ops requested for pattern matching
+ auto new_mul = register_new_node<opset3::Multiply>(input, eltwise_fold<opset3::Multiply>(mul1_const, mul2_const));
+
+ copy_runtime_info({mul1, mul2}, new_mul);
+ new_mul->set_friendly_name(mul2->get_friendly_name());
+ replace_node(mul2, new_mul);
+ return true;
+ };
+
+ auto m = std::make_shared<ngraph::pattern::Matcher>(m_mul2, "MultiplyMultiplyFusion");
+ this->register_matcher(m, callback);
+}
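The three fusions above fold two consecutive linear operations with constant operands into one. A sketch of how they are presumably grouped into the `LinOpSequenceFusion` pass that `ConvertOpSet1ToLegacy` registers (an assumption, since `lin_op_sequence_fusoin.hpp` itself is not shown in this change):

class LinOpSequenceFusion : public ngraph::pass::GraphRewrite {
public:
    LinOpSequenceFusion() {
        // Run all three matchers in a single graph traversal
        add_matcher<ngraph::pass::AddMultiplyFusion>();
        add_matcher<ngraph::pass::AddAddFusion>();
        add_matcher<ngraph::pass::MultiplyMultiplyFusion>();
    }
};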
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/rt_info.hpp>
-void ngraph::pass::PullTransposeThroughFQUp::pull_transpose_through_fq() {
+ngraph::pass::PullTransposeThroughFQUp::PullTransposeThroughFQUp() {
auto data1 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
auto data2 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
auto data3 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
};
auto m = std::make_shared<ngraph::pattern::Matcher>(transpose, "PullTransposeThroughFQUp");
- this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
+ this->register_matcher(m, callback);
}
// Disable shape inference (WA for generic operations)
ngraph::op::GenericIE::DisableReshape noReshape(nGraphFunc);
- ngraph::pass::ConvertOpSet3ToOpSet2(transformationsPredicate).run_on_function(nGraphFunc);
- ngraph::pass::ConvertOpSet2ToOpSet1(transformationsPredicate).run_on_function(nGraphFunc);
- ngraph::pass::ConvertOpSet1ToLegacy(transformationsPredicate).run_on_function(nGraphFunc);
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();
+ manager.register_pass<ngraph::pass::ConvertOpSet2ToOpSet1>();
+ manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
+ manager.set_callback(transformationsPredicate);
+ manager.run_passes(nGraphFunc);
vpu::MergeSubsequentDSROperations().run_on_function(nGraphFunc);
#include <ngraph_ops/gru_cell_ie.hpp>
#include <ngraph_ops/rnn_cell_ie.hpp>
#include <ngraph_ops/lstm_cell_ie.hpp>
+#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
cell->set_friendly_name("test_cell");
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{cell}, ngraph::ParameterVector{X, H_t});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::ConvertCellsToCellsIE().run_on_function(f);
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::InitNodeInfo>();
+ manager.register_pass<ngraph::pass::ConvertGRUCellMatcher>();
+ manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
cell->set_friendly_name("test_cell");
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{cell}, ngraph::ParameterVector{X, H});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::ConvertCellsToCellsIE().run_on_function(f);
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::InitNodeInfo>();
+ manager.register_pass<ngraph::pass::ConvertRNNCellMatcher>();
+ manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
cell->set_friendly_name("test_cell");
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{cell}, ngraph::ParameterVector{X, H_t, C_t});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::ConvertCellsToCellsIE().run_on_function(f);
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::InitNodeInfo>();
+ manager.register_pass<ngraph::pass::ConvertLSTMCellMatcher>();
+ manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
#include <transformations/convert_divide.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
+#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{data});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::ConvertDivide().run_on_function(f);
+ ngraph::pass::Manager m;
+ m.register_pass<ngraph::pass::InitNodeInfo>();
+ m.register_pass<ngraph::pass::ConvertDivide>();
+ m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{data});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::ConvertDivide().run_on_function(f);
+ ngraph::pass::Manager m;
+ m.register_pass<ngraph::pass::InitNodeInfo>();
+ m.register_pass<ngraph::pass::ConvertDivide>();
+ m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include <ngraph_ops/gather_ie.hpp>
+#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
f = std::make_shared<Function>(NodeVector{gather}, ParameterVector{input, indices});
- pass::InitNodeInfo().run_on_function(f);
- pass::ConvertGatherToGatherIE().run_on_function(f);
+ pass::Manager manager;
+ manager.register_pass<pass::InitNodeInfo>();
+ manager.register_pass<pass::ConvertGatherToGatherIEMatcher>();
+ manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_TRUE(f->get_output_partial_shape(0).is_static()) << "Shape " << f->get_output_partial_shape(0) << " should be static";
}
f = std::make_shared<Function>(NodeVector{gather}, ParameterVector{input, indices});
- pass::InitNodeInfo().run_on_function(f);
- pass::ConvertGatherToGatherIE().run_on_function(f);
+ pass::Manager manager;
+ manager.register_pass<pass::InitNodeInfo>();
+ manager.register_pass<pass::ConvertGatherToGatherIEMatcher>();
+ manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_TRUE(f->get_output_partial_shape(0).is_static()) << "Shape " << f->get_output_partial_shape(0) << " should be static";
}
f = std::make_shared<Function>(NodeVector{gather}, ParameterVector{input, indices});
- pass::InitNodeInfo().run_on_function(f);
- pass::ConvertGatherToGatherIE().run_on_function(f);
+ pass::Manager manager;
+ manager.register_pass<pass::InitNodeInfo>();
+ manager.register_pass<pass::ConvertGatherToGatherIEMatcher>();
+ manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
f = std::make_shared<Function>(NodeVector{gather}, ParameterVector{input, indices});
- pass::InitNodeInfo().run_on_function(f);
- pass::ConvertGatherToGatherIE().run_on_function(f);
+ pass::Manager manager;
+ manager.register_pass<pass::InitNodeInfo>();
+ manager.register_pass<pass::ConvertGatherToGatherIEMatcher>();
+ manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
#include <transformations/convert_opset1_to_legacy/reshape_fully_connected.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
+#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1, input2});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::ConvertMatMulToFCorGemm().run_on_function(f);
+ ngraph::pass::Manager m;
+ m.register_pass<ngraph::pass::InitNodeInfo>();
+ m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
+ m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1, input2});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::ConvertMatMulToFCorGemm().run_on_function(f);
+ ngraph::pass::Manager m;
+ m.register_pass<ngraph::pass::InitNodeInfo>();
+ m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
+ m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, false);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1, input2});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::ConvertMatMulToFCorGemm().run_on_function(f);
+ ngraph::pass::Manager m;
+ m.register_pass<ngraph::pass::InitNodeInfo>();
+ m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
+ m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, false);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1, input2});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::ConvertMatMulToFCorGemm().run_on_function(f);
+ ngraph::pass::Manager m;
+ m.register_pass<ngraph::pass::InitNodeInfo>();
+ m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
+ m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::ConvertMatMulToFCorGemm().run_on_function(f);
+ ngraph::pass::Manager m;
+ m.register_pass<ngraph::pass::InitNodeInfo>();
+ m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
+ m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
- ngraph::pass::ConvertMatMulToFCorGemm().run_on_function(f);
- ngraph::pass::ReshapeFullyConnected().run_on_function(f);
+ ngraph::pass::Manager m;
+ m.register_pass<ngraph::pass::InitNodeInfo>();
+ m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
+ m.register_pass<ngraph::pass::ReshapeFullyConnected>();
+ m.run_passes(f);
+ ASSERT_NO_THROW(check_rt_info(f));
}
{
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::ConvertMatMulToFCorGemm().run_on_function(f);
+ ngraph::pass::Manager m;
+ m.register_pass<ngraph::pass::InitNodeInfo>();
+ m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
+ m.register_pass<ngraph::pass::ReshapeFullyConnected>();
auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
if (auto fc_op = std::dynamic_pointer_cast<const ngraph::op::FullyConnected>(node)) {
}
return false;
};
- auto p = ngraph::pass::ReshapeFullyConnected();
- p.setCallback(callback);
- p.run_on_function(f);
+
+ m.set_callback(callback);
+ m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
#include <transformations/convert_opset1_to_legacy/convert_nms_4_to_legacy.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
+#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
f = std::make_shared<Function>(NodeVector{nms}, ParameterVector{boxes, scores});
const auto &orig_shape = f->get_output_partial_shape(0);
- pass::InitNodeInfo().run_on_function(f);
- pass::ConvertNMS4ToLegacy().run_on_function(f);
+ pass::Manager manager;
+ manager.register_pass<pass::InitNodeInfo>();
+ manager.register_pass<pass::ConvertNMS4ToLegacyMatcher>();
+ manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_TRUE(f->get_output_partial_shape(0).is_static()) << "Shape " << f->get_output_partial_shape(0) << " should be static";
}
f = std::make_shared<Function>(NodeVector{nms}, ParameterVector{boxes, scores});
- pass::InitNodeInfo().run_on_function(f);
- pass::ConvertNMS4ToLegacy().run_on_function(f);
+ pass::Manager manager;
+ manager.register_pass<pass::InitNodeInfo>();
+ manager.register_pass<pass::ConvertNMS4ToLegacyMatcher>();
+ manager.run_passes(f);
f->validate_nodes_and_infer_types();
ASSERT_NO_THROW(check_rt_info(f));
}
f = std::make_shared<Function>(NodeVector{nms}, ParameterVector{boxes, scores});
- pass::InitNodeInfo().run_on_function(f);
- pass::ConvertNMS4ToLegacy().run_on_function(f);
+ pass::Manager manager;
+ manager.register_pass<pass::InitNodeInfo>();
+ manager.register_pass<pass::ConvertNMS4ToLegacyMatcher>();
+ manager.run_passes(f);
+
f->validate_nodes_and_infer_types();
ASSERT_NO_THROW(check_rt_info(f));
}
#include <transformations/utils/utils.hpp>
#include <ngraph_ops/nms_ie.hpp>
#include <ngraph/pass/constant_folding.hpp>
+#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
f = std::make_shared<Function>(NodeVector{nms}, ParameterVector{boxes, scores});
const auto & orig_shape = f->get_output_partial_shape(0);
- pass::InitNodeInfo().run_on_function(f);
- pass::ConvertNMSToNMSIE().run_on_function(f);
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::InitNodeInfo>();
+ manager.register_pass<ngraph::pass::ConvertNMSToNMSIEMatcher>();
+ manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_TRUE(f->get_output_partial_shape(0).is_static()) << "Shape " << f->get_output_partial_shape(0) << " should be static";
}
f = std::make_shared<Function>(NodeVector{nms}, ParameterVector{boxes, scores});
- pass::InitNodeInfo().run_on_function(f);
- pass::ConvertNMSToNMSIE().run_on_function(f);
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::InitNodeInfo>();
+ manager.register_pass<ngraph::pass::ConvertNMSToNMSIEMatcher>();
+ manager.run_passes(f);
f->validate_nodes_and_infer_types();
ASSERT_NO_THROW(check_rt_info(f));
}
f = std::make_shared<Function>(NodeVector{nms}, ParameterVector{boxes, scores});
- pass::InitNodeInfo().run_on_function(f);
- pass::ConvertNMSToNMSIE().run_on_function(f);
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::InitNodeInfo>();
+ manager.register_pass<ngraph::pass::ConvertNMSToNMSIEMatcher>();
+ manager.run_passes(f);
f->validate_nodes_and_infer_types();
ASSERT_NO_THROW(check_rt_info(f));
}
#include <ngraph/op/reshape.hpp>
#include <transformations/utils/utils.hpp>
#include <transformations/init_node_info.hpp>
+#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
sslice->set_friendly_name("strided_slice");
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{sslice}, ngraph::ParameterVector{input});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::ConvertStridedSliceToCrop().run_on_function(f);
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::InitNodeInfo>();
+ manager.register_pass<ngraph::pass::ConvertStridedSliceToCropMatcher>();
+ manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
sslice->set_friendly_name("strided_slice");
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{sslice}, ngraph::ParameterVector{input});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::ConvertStridedSliceToCrop().run_on_function(f);
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::InitNodeInfo>();
+ manager.register_pass<ngraph::pass::ConvertStridedSliceToCropMatcher>();
+ manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
sslice->set_friendly_name("strided_slice");
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{sslice}, ngraph::ParameterVector{input});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::ConvertStridedSliceToCrop().run_on_function(f);
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::InitNodeInfo>();
+ manager.register_pass<ngraph::pass::ConvertStridedSliceToCropMatcher>();
+ manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
sslice->set_friendly_name("strided_slice");
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{sslice}, ngraph::ParameterVector{input});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::ConvertStridedSliceToCrop().run_on_function(f);
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::InitNodeInfo>();
+ manager.register_pass<ngraph::pass::ConvertStridedSliceToCropMatcher>();
+ manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
#include <transformations/utils/utils.hpp>
#include <ngraph_ops/topk_ie.hpp>
#include <ngraph/pass/constant_folding.hpp>
+#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
// due to the 'compare_functions' limitation we will check only one output
f = std::make_shared<ngraph::Function>(ngraph::OutputVector{topk->output(0)}, ngraph::ParameterVector{input});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::ConvertTopKToTopKIE().run_on_function(f);
- ASSERT_NO_THROW(check_rt_info(f));
- ngraph::pass::ConstantFolding().run_on_function(f);
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::InitNodeInfo>();
+ manager.register_pass<ngraph::pass::ConvertTopKToTopKIEMatcher>();
+ manager.register_pass<ngraph::pass::InjectionPass>([](std::shared_ptr<ngraph::Function> f) {
+ check_rt_info(f);
+ });
+ manager.register_pass<ngraph::pass::ConstantFolding>();
+ ASSERT_NO_THROW(manager.run_passes(f));
ASSERT_TRUE(f->get_output_partial_shape(0).is_static()) << "Shape " << f->get_output_partial_shape(0) << " should be static";
}
// due to the 'compare_functions' limitation we will check only one output
f = std::make_shared<ngraph::Function>(ngraph::OutputVector{topk->output(0)}, ngraph::ParameterVector{input});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::ConvertTopKToTopKIE().run_on_function(f);
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::InitNodeInfo>();
+ manager.register_pass<ngraph::pass::ConvertTopKToTopKIEMatcher>();
+ manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ngraph::pass::ConstantFolding().run_on_function(f);
}
// due to the 'compare_functions' limitation we will check only one output
f = std::make_shared<ngraph::Function>(ngraph::OutputVector{topk->output(0)}, ngraph::ParameterVector{input});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::ConvertTopKToTopKIE().run_on_function(f);
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::InitNodeInfo>();
+ manager.register_pass<ngraph::pass::ConvertTopKToTopKIEMatcher>();
+ manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ngraph::pass::ConstantFolding().run_on_function(f);
}
// due to the 'compare_functions' limitation we will check only one output
f = std::make_shared<ngraph::Function>(ngraph::OutputVector{topk->output(0)}, ngraph::ParameterVector{input});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::ConvertTopKToTopKIE().run_on_function(f);
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::InitNodeInfo>();
+ manager.register_pass<ngraph::pass::ConvertTopKToTopKIEMatcher>();
+ manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ngraph::pass::ConstantFolding().run_on_function(f);
}
// due to the 'compare_functions' limitation we will check only one output
f = std::make_shared<ngraph::Function>(ngraph::OutputVector{topk->output(0)}, ngraph::ParameterVector{input, k});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::ConvertTopKToTopKIE().run_on_function(f);
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::InitNodeInfo>();
+ manager.register_pass<ngraph::pass::ConvertTopKToTopKIEMatcher>();
+ manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ngraph::pass::ConstantFolding().run_on_function(f);
}
};
auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion();
- depth_to_space_transform.setCallback(callback);
+ depth_to_space_transform.set_callback(callback);
depth_to_space_transform.run_on_function(f);
ASSERT_NO_THROW(check_rt_info(f));
}
};
auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion();
- depth_to_space_transform.setCallback(callback);
+ depth_to_space_transform.set_callback(callback);
depth_to_space_transform.run_on_function(f);
ASSERT_NO_THROW(check_rt_info(f));
}
// transformation won't be applied because shape_reshape_before is dynamic, the graph will remain the same
auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion();
- depth_to_space_transform.setCallback(callback);
+ depth_to_space_transform.set_callback(callback);
depth_to_space_transform.run_on_function(f);
ASSERT_NO_THROW(check_rt_info(f));
}
// transformation won't be applied because reshape_before has several consumers, the graph will remain the same
auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion();
- depth_to_space_transform.setCallback(callback);
+ depth_to_space_transform.set_callback(callback);
depth_to_space_transform.run_on_function(f);
ASSERT_NO_THROW(check_rt_info(f));
}
auto add = std::make_shared<ngraph::opset1::Add>(fc, const_bias);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{add}, ngraph::ParameterVector{input1});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::FullyConnectedBiasFusion().run_on_function(f);
- ASSERT_NO_THROW(check_rt_info(f));
- ngraph::pass::ConstantFolding().run_on_function(f);
+
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::InitNodeInfo>();
+ manager.register_pass<ngraph::pass::FullyConnectedBiasFusion>();
+ manager.register_pass<ngraph::pass::InjectionPass>([](std::shared_ptr<ngraph::Function> f) {
+ check_rt_info(f);
+ });
+ manager.register_pass<ngraph::pass::ConstantFolding>();
+ ASSERT_NO_THROW(manager.run_passes(f));
}
{
auto add = std::make_shared<ngraph::opset1::Add>(fc, const_bias);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{add}, ngraph::ParameterVector{input1});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::FullyConnectedBiasFusion().run_on_function(f);
- ASSERT_NO_THROW(check_rt_info(f));
- ngraph::pass::ConstantFolding().run_on_function(f);
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::InitNodeInfo>();
+ manager.register_pass<ngraph::pass::FullyConnectedBiasFusion>();
+ manager.register_pass<ngraph::pass::InjectionPass>([](std::shared_ptr<ngraph::Function> f) {
+ check_rt_info(f);
+ });
+ manager.register_pass<ngraph::pass::ConstantFolding>();
+ ASSERT_NO_THROW(manager.run_passes(f));
}
{
--- /dev/null
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include "common_test_utils/test_common.hpp"
+#include <string>
+#include <sstream>
+#include <fstream>
+#include <memory>
+#include <queue>
+#include <map>
+
+#include <ngraph/function.hpp>
+#include <ngraph/opsets/opset3.hpp>
+#include <ngraph/pass/manager.hpp>
+#include <transformations/lin_op_sequence_fusoin.hpp>
+#include <transformations/utils/utils.hpp>
+#include <transformations/init_node_info.hpp>
+#include <ngraph/pass/visualize_tree.hpp>
+
+#include "common_test_utils/ngraph_test_utils.hpp"
+
+using namespace testing;
+using namespace ngraph;
+
+TEST(TransformationTests, MulAddMulAddFusion) {
+ std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+
+ {
+ auto input = std::make_shared<opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 3072});
+ auto mul1_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {2});
+ auto mul2_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {3});
+ auto add1_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {4});
+ auto add2_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {5});
+
+ auto mul1 = std::make_shared<opset3::Multiply>(input, mul1_const);
+ auto add1 = std::make_shared<opset3::Add>(mul1, add1_const);
+ auto mul2 = std::make_shared<opset3::Multiply>(add1, mul2_const);
+ auto add2 = std::make_shared<opset3::Add>(mul2, add2_const);
+
+ f = std::make_shared<ngraph::Function>(ngraph::NodeVector{add2}, ngraph::ParameterVector{input});
+ }
+
+ pass::Manager manager;
+ manager.register_pass<ngraph::pass::InitNodeInfo>();
+ manager.register_pass<ngraph::pass::LinOpSequenceFusion>();
+ manager.run_passes(f);
+ ASSERT_NO_THROW(check_rt_info(f));
+
+ {
+ auto input = std::make_shared<opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 3072});
+ auto mul1_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {6});
+ auto add1_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {17});
+
+ auto mul1 = std::make_shared<opset3::Multiply>(input, mul1_const);
+ auto add1 = std::make_shared<opset3::Add>(mul1, add1_const);
+
+ f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{add1}, ngraph::ParameterVector{input});
+ }
+
+ auto res = compare_functions(f, f_ref);
+ ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, MulMulMulFusion) {
+ std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+
+ {
+ auto input = std::make_shared<opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 3072});
+ auto mul1_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {2});
+ auto mul2_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {3});
+ auto mul3_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {3});
+
+ auto mul1 = std::make_shared<opset3::Multiply>(input, mul1_const);
+ auto mul2 = std::make_shared<opset3::Multiply>(mul1, mul2_const);
+ auto mul3 = std::make_shared<opset3::Multiply>(mul2, mul3_const);
+
+ f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul2}, ngraph::ParameterVector{input});
+ }
+
+ pass::Manager manager;
+ manager.register_pass<ngraph::pass::InitNodeInfo>();
+ manager.register_pass<ngraph::pass::LinOpSequenceFusion>();
+ manager.run_passes(f);
+ ASSERT_NO_THROW(check_rt_info(f));
+
+ {
+ auto input = std::make_shared<opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 3072});
+ auto mul1_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {12});
+
+ auto mul1 = std::make_shared<opset3::Multiply>(input, mul1_const);
+
+ f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul1}, ngraph::ParameterVector{input});
+ }
+
+ auto res = compare_functions(f, f_ref);
+ ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, AddAddAddFusion) {
+ std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+
+ {
+ auto input = std::make_shared<opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 3072});
+ auto add1_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {2});
+ auto add2_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {3});
+ auto add3_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {3});
+
+ auto add1 = std::make_shared<opset3::Add>(input, add1_const);
+ auto add2 = std::make_shared<opset3::Add>(add1, add2_const);
+ auto add3 = std::make_shared<opset3::Add>(add2, add3_const);
+
+ f = std::make_shared<ngraph::Function>(ngraph::NodeVector{add3}, ngraph::ParameterVector{input});
+ }
+
+ pass::Manager manager;
+ manager.register_pass<ngraph::pass::InitNodeInfo>();
+ manager.register_pass<ngraph::pass::LinOpSequenceFusion>();
+ manager.run_passes(f);
+ ASSERT_NO_THROW(check_rt_info(f));
+
+ {
+ auto input = std::make_shared<opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 3072});
+ auto add1_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {8});
+
+ auto add1 = std::make_shared<opset3::Add>(input, add1_const);
+
+ f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{add1}, ngraph::ParameterVector{input});
+ }
+
+ auto res = compare_functions(f, f_ref);
+ ASSERT_TRUE(res.first) << res.second;
+}
+
+TEST(TransformationTests, MulAddAddMulFusion) {
+ std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+
+ {
+ auto input = std::make_shared<opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 3072});
+ auto mul1_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {2});
+ auto mul2_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {3});
+ auto add1_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {4});
+ auto add2_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {5});
+
+ auto mul1 = std::make_shared<opset3::Multiply>(input, mul1_const);
+ auto add1 = std::make_shared<opset3::Add>(mul1, add1_const);
+ auto add2 = std::make_shared<opset3::Add>(add1, add2_const);
+ auto mul2 = std::make_shared<opset3::Multiply>(add2, mul2_const);
+
+ f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul2}, ngraph::ParameterVector{input});
+ }
+
+ pass::Manager manager;
+ manager.register_pass<ngraph::pass::InitNodeInfo>();
+ manager.register_pass<ngraph::pass::LinOpSequenceFusion>();
+ manager.run_passes(f);
+ ASSERT_NO_THROW(check_rt_info(f));
+
+ {
+ auto input = std::make_shared<opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 3072});
+ auto mul1_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {10});
+ auto add1_const = opset3::Constant::create(ngraph::element::f32, ngraph::Shape{128, 1}, {40});
+
+ auto mul1 = std::make_shared<opset3::Multiply>(input, mul1_const);
+ auto add1 = std::make_shared<opset3::Add>(mul1, add1_const);
+
+ f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{add1}, ngraph::ParameterVector{input});
+ }
+
+ auto res = compare_functions(f, f_ref);
+ ASSERT_TRUE(res.first) << res.second;
+}
\ No newline at end of file
#include <transformations/convert_depth_to_space.hpp>
#include <transformations/convert_space_to_depth.hpp>
#include <transformations/init_node_info.hpp>
+#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
{
auto depth_to_space = std::make_shared<ngraph::op::DepthToSpace>(input, ngraph::op::DepthToSpace::DepthToSpaceMode::BLOCKS_FIRST, 2);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{depth_to_space}, ngraph::ParameterVector{input});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::ConvertDepthToSpace().run_on_function(f);
+ ngraph::pass::Manager m;
+ m.register_pass<ngraph::pass::InitNodeInfo>();
+ m.register_pass<ngraph::pass::ConvertDepthToSpace>();
+ m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto depth_to_space = std::make_shared<ngraph::op::DepthToSpace>(input, ngraph::op::DepthToSpace::DepthToSpaceMode::DEPTH_FIRST, 2);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{depth_to_space}, ngraph::ParameterVector{input});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::ConvertDepthToSpace().run_on_function(f);
+ ngraph::pass::Manager m;
+ m.register_pass<ngraph::pass::InitNodeInfo>();
+ m.register_pass<ngraph::pass::ConvertDepthToSpace>();
+ m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto space_to_depth = std::make_shared<ngraph::op::SpaceToDepth>(input, ngraph::op::SpaceToDepth::SpaceToDepthMode::BLOCKS_FIRST, 2);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{space_to_depth}, ngraph::ParameterVector{input});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::ConvertSpaceToDepth().run_on_function(f);
+ ngraph::pass::Manager m;
+ m.register_pass<ngraph::pass::InitNodeInfo>();
+ m.register_pass<ngraph::pass::ConvertSpaceToDepth>();
+ m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto space_to_depth = std::make_shared<ngraph::op::SpaceToDepth>(input, ngraph::op::SpaceToDepth::SpaceToDepthMode::DEPTH_FIRST, 2);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{space_to_depth}, ngraph::ParameterVector{input});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::ConvertSpaceToDepth().run_on_function(f);
+ ngraph::pass::Manager m;
+ m.register_pass<ngraph::pass::InitNodeInfo>();
+ m.register_pass<ngraph::pass::ConvertSpaceToDepth>();
+ m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
#include <transformations/pull_transpose_through_fq.hpp>
#include <ngraph/pass/constant_folding.hpp>
#include <transformations/init_node_info.hpp>
+#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
auto transpose = std::make_shared<ngraph::op::Transpose>(fq, transpose_order);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{transpose}, ngraph::ParameterVector{});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::PullTransposeThroughFQUp().run_on_function(f);
- ASSERT_NO_THROW(check_rt_info(f));
- ngraph::pass::ConstantFolding().run_on_function(f);
+
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::InitNodeInfo>();
+ manager.register_pass<ngraph::pass::PullTransposeThroughFQUp>();
+ manager.register_pass<ngraph::pass::InjectionPass>([](std::shared_ptr<ngraph::Function> f) {
+ check_rt_info(f);
+ });
+ manager.register_pass<ngraph::pass::ConstantFolding>();
+ ASSERT_NO_THROW(manager.run_passes(f));
}
std::vector<size_t> ref_shape{1, 3, 1};
for (auto op : f->get_ops()) {
#include <transformations/convert_mod.hpp>
#include <ngraph/pass/constant_folding.hpp>
#include <transformations/init_node_info.hpp>
+#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
auto mod = std::make_shared<ngraph::op::v1::Mod>(data1, data2);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mod}, ngraph::ParameterVector{});
- ngraph::pass::InitNodeInfo().run_on_function(f);
- ngraph::pass::ConvertMod().run_on_function(f);
+ ngraph::pass::Manager m;
+ m.register_pass<ngraph::pass::InitNodeInfo>();
+ m.register_pass<ngraph::pass::ConvertMod>();
+ m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
ASSERT_EQ(f->get_ops().size(), 12);
#include <ngraph/opsets/opset2.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/op/fused/gelu.hpp>
+#include <ngraph/pass/manager.hpp>
#include "ngraph_functions/pass/convert_prc.hpp"
#include "common_test_utils/common_utils.hpp"
::ngraph::op::GenericIE::DisableReshape noReshape(nGraphFunc);
// Note: instead of running all Conversion Transformations you can make up your own transformation pipeline
- ngraph::pass::CommonOptimizations(transformations_callback).run_on_function(nGraphFunc);
- ngraph::pass::ConvertOpSet3ToOpSet2(transformations_callback).run_on_function(nGraphFunc);
- ngraph::pass::ConvertOpSet2ToOpSet1(transformations_callback).run_on_function(nGraphFunc);
- ngraph::pass::ConvertOpSet1ToLegacy(transformations_callback).run_on_function(nGraphFunc);
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::CommonOptimizations>();
+ manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();
+ manager.register_pass<ngraph::pass::ConvertOpSet2ToOpSet1>();
+ manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
+
+ manager.set_callback(transformations_callback);
+ manager.run_passes(nGraphFunc);
clonedNetwork = InferenceEngine::details::convertFunctionToICNNNetwork(nGraphFunc, *clonedNetwork);
}
#include <ngraph/opsets/opset2.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/op/fused/gelu.hpp>
+#include <ngraph/pass/manager.hpp>
#include "ngraph_functions/pass/convert_prc.hpp"
#include "common_test_utils/common_utils.hpp"
::ngraph::op::GenericIE::DisableReshape noReshape(nGraphFunc);
// Note: instead of running all Conversion Transformations you can make up your own transformation pipeline
- ngraph::pass::CommonOptimizations(transformations_callback).run_on_function(nGraphFunc);
- ngraph::pass::ConvertOpSet3ToOpSet2(transformations_callback).run_on_function(nGraphFunc);
- ngraph::pass::ConvertOpSet2ToOpSet1(transformations_callback).run_on_function(nGraphFunc);
- ngraph::pass::ConvertOpSet1ToLegacy(transformations_callback).run_on_function(nGraphFunc);
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::CommonOptimizations>();
+ manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();
+ manager.register_pass<ngraph::pass::ConvertOpSet2ToOpSet1>();
+ manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
+
+ manager.set_callback(transformations_callback);
+ manager.run_passes(nGraphFunc);
clonedNetwork = InferenceEngine::details::convertFunctionToICNNNetwork(nGraphFunc, *clonedNetwork);
}
#include <ngraph/function.hpp>
#include <ngraph/dimension.hpp>
+#include <ngraph/pass/pass.hpp>
#include "test_common.hpp"
void check_rt_info(const std::shared_ptr<ngraph::Function> & f);
-void visualize_function(std::shared_ptr<ngraph::Function> f, const std::string & file_name);
\ No newline at end of file
+void visualize_function(std::shared_ptr<ngraph::Function> f, const std::string & file_name);
+
+namespace ngraph {
+namespace pass {
+
+class InjectionPass;
+
+} // namespace pass
+} // namespace ngraph
+
+class ngraph::pass::InjectionPass : public ngraph::pass::FunctionPass {
+public:
+ using injection_callback = std::function<void(std::shared_ptr<ngraph::Function>)>;
+
+ explicit InjectionPass(injection_callback callback) : FunctionPass(), m_callback(std::move(callback)) {}
+
+ bool run_on_function(std::shared_ptr<ngraph::Function> f) override {
+ m_callback(f);
+ return false;
+ }
+
+private:
+ injection_callback m_callback;
+};
ROOT ${CMAKE_CURRENT_SOURCE_DIR}
LINK_LIBRARIES
unitTestUtils
+ inference_engine_transformations
${OpenCV_LIBRARIES}
ADD_CPPLINT
DEPENDENCIES
pattern/op/skip.hpp
pattern/op/true.cpp
pattern/op/true.hpp
+ pattern/op/wrap_type.hpp
+ pattern/op/wrap_type.cpp
provenance.cpp
provenance.hpp
rank.hpp
bool Node::match_node(pattern::Matcher* matcher, const Output<Node>& graph_value)
{
matcher->add_node(graph_value);
- return graph_value.get_node_shared_ptr()->get_type_info() == get_type_info() &&
- matcher->match_arguments(this, graph_value.get_node_shared_ptr());
+ if (graph_value.get_node_shared_ptr()->get_type_info() == get_type_info() &&
+ matcher->match_arguments(this, graph_value.get_node_shared_ptr()))
+ {
+ auto& pattern_map = matcher->get_pattern_value_map();
+ pattern_map[shared_from_this()] = graph_value;
+ return true;
+ }
+ return false;
}
// default implementation for the node to check if it contains partial shape
void ngraph::pass::ConstantFolding::construct_constant_default()
{
- add_handler("Constant folding defaults",
- [](const std::shared_ptr<Node>& node) -> bool {
- OutputVector replacements(node->get_output_size());
- if (!node->constant_fold(replacements, node->input_values()))
- {
- return false;
- }
- NGRAPH_CHECK(
- replacements.size() == node->get_output_size(),
- "constant_fold_default returned incorrect number of replacements for ",
- node);
- bool result{false};
- for (size_t i = 0; i < replacements.size(); ++i)
- {
- auto node_output = node->output(i);
- auto replacement = replacements.at(i);
- if (replacement.get_node_shared_ptr() && (node_output != replacement))
- {
- node_output.replace(replacement);
- result = true;
- }
- }
- return result;
- },
- PassProperty::CHANGE_DYNAMIC_STATE);
+ m_matchers.push_back(std::make_shared<MatcherPass>(
+ "Constant folding defaults",
+ nullptr,
+ [](const std::shared_ptr<Node>& node) -> bool {
+ OutputVector replacements(node->get_output_size());
+ if (!node->constant_fold(replacements, node->input_values()))
+ {
+ return false;
+ }
+ NGRAPH_CHECK(replacements.size() == node->get_output_size(),
+ "constant_fold_default returned incorrect number of replacements for ",
+ node);
+ bool result{false};
+ for (size_t i = 0; i < replacements.size(); ++i)
+ {
+ auto node_output = node->output(i);
+ auto replacement = replacements.at(i);
+ if (replacement.get_node_shared_ptr() && (node_output != replacement))
+ {
+ node_output.replace(replacement);
+ result = true;
+ }
+ }
+ return result;
+ },
+ PassProperty::CHANGE_DYNAMIC_STATE));
}
#include "constant_folding.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/op/constant.hpp"
+#include "ngraph/op/slice.hpp"
#include "ngraph/op/split.hpp"
+#include "ngraph/runtime/reference/slice.hpp"
#include "ngraph/validation_util.hpp"
using namespace std;
using namespace ngraph;
+template <class T>
+shared_ptr<op::Constant> fold_constant_slice(shared_ptr<op::Constant> constant,
+ shared_ptr<op::Slice> slice)
+{
+ const Shape& out_shape = slice->get_shape();
+ runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
+ T* data_ptr = buffer.get_ptr<T>();
+
+ runtime::reference::slice<T>(constant->get_data_ptr<T>(),
+ data_ptr,
+ constant->get_shape(),
+ slice->get_lower_bounds(),
+ slice->get_upper_bounds(),
+ slice->get_strides(),
+ out_shape);
+
+ return make_shared<op::Constant>(constant->get_element_type(), out_shape, data_ptr);
+}
+
void pass::ConstantFolding::construct_constant_split()
{
auto data_label = make_shared<pattern::op::Label>(
output.replace(slices[index++]->output(0));
}
split->outputs().clear();
- construct_constant_slice();
+
+ for (auto& slice : slices)
+ {
+ auto const_data = std::dynamic_pointer_cast<op::Constant>(
+ slice->input_value(0).get_node_shared_ptr());
+ auto slice_node = std::dynamic_pointer_cast<op::Slice>(slice);
+ if (!const_data || !slice_node)
+ continue;
+
+ std::shared_ptr<op::Constant> replacement;
+ switch (slice->get_output_element_type(0))
+ {
+ case element::Type_t::undefined:
+ NGRAPH_CHECK(false, "Encountered 'undefined' element type in fold_constant_slice");
+ break;
+ case element::Type_t::dynamic:
+ NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_slice");
+ break;
+ case element::Type_t::u1:
+ NGRAPH_CHECK(false, "Encountered 'u1' element type in fold_constant_slice");
+ break;
+ case element::Type_t::boolean:
+ replacement = fold_constant_slice<char>(const_data, slice_node);
+ break;
+ case element::Type_t::bf16:
+ replacement = fold_constant_slice<bfloat16>(const_data, slice_node);
+ break;
+ case element::Type_t::f16:
+ replacement = fold_constant_slice<float16>(const_data, slice_node);
+ break;
+ case element::Type_t::f32:
+ replacement = fold_constant_slice<float>(const_data, slice_node);
+ break;
+ case element::Type_t::f64:
+ replacement = fold_constant_slice<double>(const_data, slice_node);
+ break;
+ case element::Type_t::i8:
+ replacement = fold_constant_slice<int8_t>(const_data, slice_node);
+ break;
+ case element::Type_t::i16:
+ replacement = fold_constant_slice<int16_t>(const_data, slice_node);
+ break;
+ case element::Type_t::i32:
+ replacement = fold_constant_slice<int32_t>(const_data, slice_node);
+ break;
+ case element::Type_t::i64:
+ replacement = fold_constant_slice<int64_t>(const_data, slice_node);
+ break;
+ case element::Type_t::u8:
+ replacement = fold_constant_slice<uint8_t>(const_data, slice_node);
+ break;
+ case element::Type_t::u16:
+ replacement = fold_constant_slice<uint16_t>(const_data, slice_node);
+ break;
+ case element::Type_t::u32:
+ replacement = fold_constant_slice<uint32_t>(const_data, slice_node);
+ break;
+ case element::Type_t::u64:
+ replacement = fold_constant_slice<uint64_t>(const_data, slice_node);
+ break;
+ }
+ replace_node(slice_node, replacement);
+ }
return true;
};
#include "constant_folding.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/op/constant.hpp"
+#include "ngraph/op/slice.hpp"
#include "ngraph/op/variadic_split.hpp"
+#include "ngraph/runtime/reference/slice.hpp"
#include "ngraph/validation_util.hpp"
using namespace std;
using namespace ngraph;
+template <class T>
+shared_ptr<op::Constant> fold_constant_slice(shared_ptr<op::Constant> constant,
+ shared_ptr<op::Slice> slice)
+{
+ const Shape& out_shape = slice->get_shape();
+ runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
+ T* data_ptr = buffer.get_ptr<T>();
+
+ runtime::reference::slice<T>(constant->get_data_ptr<T>(),
+ data_ptr,
+ constant->get_shape(),
+ slice->get_lower_bounds(),
+ slice->get_upper_bounds(),
+ slice->get_strides(),
+ out_shape);
+
+ return make_shared<op::Constant>(constant->get_element_type(), out_shape, data_ptr);
+}
+
void pass::ConstantFolding::construct_constant_variadic_split()
{
auto data_label = make_shared<pattern::op::Label>(
}
}
variadic_split->outputs().clear();
- construct_constant_slice();
+
+ for (auto& slice : slices)
+ {
+ auto const_data = std::dynamic_pointer_cast<op::Constant>(
+ slice->input_value(0).get_node_shared_ptr());
+ auto slice_node = std::dynamic_pointer_cast<op::Slice>(slice);
+ if (!const_data || !slice_node)
+ continue;
+
+ std::shared_ptr<op::Constant> replacement;
+ switch (slice->get_output_element_type(0))
+ {
+ case element::Type_t::undefined:
+ NGRAPH_CHECK(false, "Encountered 'undefined' element type in fold_constant_slice");
+ break;
+ case element::Type_t::dynamic:
+ NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_slice");
+ break;
+ case element::Type_t::u1:
+ NGRAPH_CHECK(false, "Encountered 'u1' element type in fold_constant_slice");
+ break;
+ case element::Type_t::boolean:
+ replacement = fold_constant_slice<char>(const_data, slice_node);
+ break;
+ case element::Type_t::bf16:
+ replacement = fold_constant_slice<bfloat16>(const_data, slice_node);
+ break;
+ case element::Type_t::f16:
+ replacement = fold_constant_slice<float16>(const_data, slice_node);
+ break;
+ case element::Type_t::f32:
+ replacement = fold_constant_slice<float>(const_data, slice_node);
+ break;
+ case element::Type_t::f64:
+ replacement = fold_constant_slice<double>(const_data, slice_node);
+ break;
+ case element::Type_t::i8:
+ replacement = fold_constant_slice<int8_t>(const_data, slice_node);
+ break;
+ case element::Type_t::i16:
+ replacement = fold_constant_slice<int16_t>(const_data, slice_node);
+ break;
+ case element::Type_t::i32:
+ replacement = fold_constant_slice<int32_t>(const_data, slice_node);
+ break;
+ case element::Type_t::i64:
+ replacement = fold_constant_slice<int64_t>(const_data, slice_node);
+ break;
+ case element::Type_t::u8:
+ replacement = fold_constant_slice<uint8_t>(const_data, slice_node);
+ break;
+ case element::Type_t::u16:
+ replacement = fold_constant_slice<uint16_t>(const_data, slice_node);
+ break;
+ case element::Type_t::u32:
+ replacement = fold_constant_slice<uint32_t>(const_data, slice_node);
+ break;
+ case element::Type_t::u64:
+ replacement = fold_constant_slice<uint64_t>(const_data, slice_node);
+ break;
+ }
+ replace_node(slice_node, replacement);
+ }
return true;
};
//*****************************************************************************
#include <algorithm>
+#include <deque>
#include <iostream>
+#include <pattern/op/wrap_type.hpp>
#include <regex>
#include <unordered_set>
#include <vector>
//
// The topological order would be : `Constant1`, `Abs2`, `Neg3`, `Add4`, `Result5`
// Note, `Abs2` comes before `Neg3` as `Abs2`'s id = 2 is *less* than `Neg3`'s one (id = 3)
-// Next, GraphRewrite will invoke matchers registered in an order registered in a c-tor
-// i.e. if a c-tor calls `construct_m1()`; `construct_m2()`; `construct_m3()`;
-// Matchers will be called as follows: `m1`, `m2`, `m3`
+// Next, GraphRewrite will invoke matcher passes in the order they were registered with add_matcher.
+// For example:
+// ngraph::pass::GraphRewrite pass;
+// pass.add_matcher<m1>();
+// pass.add_matcher<m2>();
+// pass.add_matcher<m3>();
+// Matcher passes will be called as follows: `m1`, `m2`, `m3`
// Matchers should only replace nodes in the graph that come before the current root
// node in the topological order. For example, if Matcher matches Neg3, it should only
// replace nodes `Abs2` and `Constant1` if needed
// and `m2` folds `Neg3(Constant1)` when `m3` is called on `Add4` it will discover that
// both `Abs2` and `Neg3` were already replaced by constants, so `Add4` will also be folded into
// one.
-// If any Matcher succeeds the rest of the matchers will **not** be called.
+// If any matcher pass succeeds, the rest of the matcher passes will **not** be called.
// E.g. if `m1` succeeds and replaces `Abs2` with a new constant, neither `m2` nor `m3` will be called
// However, sometimes you will need more than one fusion to occur on the same node.
-// In this case, you should be able to request another pass of GraphRewrite.
-// To request another pass, you will need to register fusions in a callback:
-// i.e. you will need to pass `this` into a callback and then call `this->construct_X`
-// This will schedule another pass of GraphRewrite with the following fusion.
-// This approach should only be used if you are either:
-// a) need more than one fusion occur on the same node
-// b) you are modifying nodes after the current node in the topological order
-// c) there's no linear order of fusions which will give
-// the correct final fusion. i.e. the same fusion needs to occur before and after some other
-// fusion
+// In this case, you need to register the new nodes in the MatcherPass manually using the
+// register_new_node method. GraphRewrite will automatically add these nodes to the beginning of the
+// execution queue. If a MatcherPass registers more than one node, make sure that the nodes are
+// registered in topological order.
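+// For example (an illustrative sketch only; it assumes the matcher pass callback captures `this`,
+// and `a`, `b`, `c` are outputs taken from the matched sub-graph):
+//      auto new_mul = register_new_node<ngraph::opset3::Multiply>(a, b);
+//      auto new_add = register_new_node<ngraph::opset3::Add>(new_mul, c);  // producer registered first
+//      ngraph::replace_node(m.get_match_root(), new_add);
+//      return true;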
bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
{
bool rewritten = false;
- const size_t NUM_TRIES = 10;
- size_t tries = NUM_TRIES;
- vector<MatchClosure> original_matchers{m_matchers};
- // This check is very expensive and is only needed for experimental features, so we will hide
- // it behind an environment variable for now. TODO: Find a less expensive way to handle this.
- static bool s_rerun_dynamic_check = getenv_bool("NGRAPH_GRAPH_REWRITE_RERUN_DYNAMIC_CHECK");
- bool is_dyn_func = s_rerun_dynamic_check && f->is_dynamic();
- do
+
+ // Initialize execution queue with nodes in topological order
+ deque<std::shared_ptr<Node>> nodes_to_run;
+ for (auto& node : f->get_ordered_ops())
{
- rewritten = false;
- // m_matchers may contain newly constructed matchers for matchers
- // that need multiple passes. See comments above.
- vector<MatchClosure> matchers_to_run{m_matchers};
- m_matchers.clear();
- for (auto node : f->get_ordered_ops())
+ nodes_to_run.emplace_back(node);
+ }
+
+    // Check that all Matchers in MatcherPasses have a type-based root node
+ bool all_roots_has_type = true;
+ std::unordered_map<NodeTypeInfo, std::vector<std::shared_ptr<MatcherPass>>> type_to_matcher;
+ for (auto& m : m_matchers)
+ {
+ auto matcher = m->get_matcher();
+ if (!matcher)
+ {
+ all_roots_has_type = false;
+ break;
+ }
+
+ auto root = matcher->get_pattern_value().get_node_shared_ptr();
+        // A pattern::op::AnyOutput operation is automatically appended for multi-output operations
+        // inside Matcher, so to get the actual root node we need to take its parent.
+ if (auto any_type = dynamic_pointer_cast<pattern::op::AnyOutput>(root))
{
- if (m_enable_shape_inference)
+ root = any_type->input_value(0).get_node_shared_ptr();
+ }
+
+        // If the root is an operation from an opset or has the pattern::op::WrapType type, then we
+        // can extract its type and use it as a key in the unordered_map for fast MatcherPass lookup.
+        // Otherwise the type is unknown and the default algorithm is used.
+ NodeTypeInfo root_type_info = root->get_type_info();
+ if (auto p = dynamic_pointer_cast<pattern::op::Pattern>(root))
+ {
+ if (auto any_type = dynamic_pointer_cast<pattern::op::WrapType>(p))
{
- node->revalidate_and_infer_types();
+ root_type_info = any_type->get_wrapped_type();
}
- for (auto& closure : matchers_to_run)
+ else
{
- if (is_dyn_func && closure.property[PassProperty::REQUIRE_STATIC_SHAPE])
- {
- NGRAPH_DEBUG << "matcher callback requires static shape but the "
- "function is dynamic, skipping this "
- "optimization till the shapes are fully "
- "materialized";
- continue;
- }
- if (closure.handler(node))
- {
- rewritten = true;
- // If call back may change function's is_dynamic state, we need to
- // update the cached value.
- if (closure.property.is_set(PassProperty::CHANGE_DYNAMIC_STATE))
- {
- is_dyn_func = s_rerun_dynamic_check && f->is_dynamic();
- }
- break;
- }
+ all_roots_has_type = false;
+ break;
}
}
-
- } while (rewritten && m_matchers.size() > 0 && tries--);
-
- m_matchers.assign(original_matchers.begin(), original_matchers.end());
- return (NUM_TRIES - tries) > 1; // this means a graph was transformed
-}
-
-static vector<regex> initialize_fusion_regexes()
-{
- static const string nsf = getenv_string("NGRAPH_DISABLED_FUSIONS");
- vector<regex> regexes;
- if (!nsf.empty())
- {
- const auto sregexes = split(nsf, ';');
-
- transform(sregexes.begin(),
- sregexes.end(),
- back_inserter(regexes),
- [](const string& c) -> regex { return regex(c); });
+ type_to_matcher[root_type_info].push_back(m);
}
- return regexes;
-}
-bool pass::GraphRewriteBase::is_enabled(const std::string& name) const
-{
- // note, regexes are static to avoid re-initialization
- static const auto regexes = initialize_fusion_regexes();
-
- for (const auto& regex : regexes)
- {
- if (regex_match(name, regex))
+    // This lambda performs execution of a particular MatcherPass on a given node.
+    // It automatically handles nodes registered by the MatcherPass during the transformation and
+    // sets the transformation callback.
+ auto run_matcher_pass = [&](std::shared_ptr<MatcherPass> m_pass,
+ std::shared_ptr<Node> node) -> bool {
+        // Keep this property check for backward compatibility. In the future the transformation
+        // property will be deprecated and removed.
+ if (m_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) && f->is_dynamic())
{
- NGRAPH_DEBUG << "Disabling matcher " << name;
+ NGRAPH_DEBUG << "matcher callback requires static shape but the "
+ "function is dynamic, skipping this "
+ "optimization till the shapes are fully "
+ "materialized";
return false;
}
- }
- return true;
-}
+ if (!m_has_default_callback)
+ {
+ m_pass->set_callback(m_transformation_callback);
+ }
-void pass::GraphRewriteBase::add_handler(const std::string& name,
- function<bool(const std::shared_ptr<Node>&)> handler,
- const PassPropertyMask& property)
-{
- if (is_enabled(name))
+        // Apply MatcherPass. If it returns true, no other MatcherPasses will be applied
+        // to this node.
+ bool status = m_pass->apply(node);
+
+        // If the MatcherPass registered new nodes, they will be added to the beginning of the
+        // execution queue.
+ const auto& new_nodes = m_pass->get_new_nodes();
+ if (!new_nodes.empty())
+ {
+ // Need to push nodes in reverse order as we expect that nodes in new_nodes
+ // vector are in topological order
+ for (auto it = new_nodes.rbegin(); it != new_nodes.rend(); it++)
+ {
+ nodes_to_run.emplace_front(*it);
+ }
+ m_pass->clear_new_nodes();
+ }
+ return status;
+ };
+
+ while (!nodes_to_run.empty())
{
- m_matchers.push_back({name, handler, property});
- // If any matcher call back may change dynamic state, we need to
- // update the pass property.
- if (property.is_set(PassProperty::CHANGE_DYNAMIC_STATE))
+ auto node = nodes_to_run.front();
+ nodes_to_run.pop_front();
+ // Temporary keep this GraphRewrite property for backward compatibility
+ if (m_enable_shape_inference)
+ {
+ node->revalidate_and_infer_types();
+ }
+        // If all Matchers in MatcherPasses have a type-based root node, then we apply the efficient
+        // algorithm for finding matcher passes.
+ if (all_roots_has_type)
{
- set_property(PassProperty::CHANGE_DYNAMIC_STATE, true);
+ auto node_type_info = node->get_type_info();
+ if (type_to_matcher.count(node_type_info))
+ {
+ for (auto& m_pass : type_to_matcher[node_type_info])
+ {
+ if (run_matcher_pass(m_pass, node))
+ {
+ rewritten = true;
+ break;
+ }
+ }
+ }
+ }
+        // Otherwise we use the default algorithm that iterates over all registered matcher passes.
+ else
+ {
+ for (auto& m_pass : m_matchers)
+ {
+ if (run_matcher_pass(m_pass, node))
+ {
+ rewritten = true;
+ break;
+ }
+ }
}
}
+ return rewritten;
}
void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m,
const graph_rewrite_callback& callback,
const PassPropertyMask& property)
{
- add_handler(m->get_name(),
- [m, callback](const std::shared_ptr<Node>& node) -> bool {
- NGRAPH_DEBUG << "Running matcher " << m->get_name() << " on " << node;
- if (m->match(node->output(0)))
- {
- NGRAPH_DEBUG << "Matcher " << m->get_name() << " matched " << node;
- return callback(*m.get());
- }
- return false;
- },
- property);
+ m_matchers.push_back(std::make_shared<MatcherPass>(
+ m->get_name(),
+ m,
+ [m, callback](const std::shared_ptr<Node>& node) -> bool {
+ NGRAPH_DEBUG << "Running matcher " << m->get_name() << " on " << node;
+ if (m->match(node->output(0)))
+ {
+ NGRAPH_DEBUG << "Matcher " << m->get_name() << " matched " << node;
+ bool status = callback(*m.get());
+ // explicitly clear Matcher state because it holds pointers to matched nodes
+ m->clear_state();
+ return status;
+ }
+ m->clear_state();
+ return false;
+ },
+ property));
}
void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m,
const ngraph::recurrent_graph_rewrite_callback& callback,
const PassPropertyMask& property)
{
- add_handler("Reurrent matcher",
- [m, callback](const std::shared_ptr<Node>& node) {
- NGRAPH_DEBUG << "Running recurrent matcher on " << node;
- if (m->match(node->output(0)))
- {
- NGRAPH_DEBUG << "Recurrent matcher matched " << m.get();
- return callback(*m.get());
- }
- return false;
- },
- property);
+ m_matchers.push_back(std::make_shared<MatcherPass>(
+ "Recurrent matcher",
+ nullptr,
+ [m, callback](const std::shared_ptr<Node>& node) {
+ NGRAPH_DEBUG << "Running recurrent matcher on " << node;
+ if (m->match(node->output(0)))
+ {
+ NGRAPH_DEBUG << "Recurrent matcher matched " << m.get();
+ return callback(*m.get());
+ }
+ return false;
+ },
+ property));
}
void pass::RecurrentGraphRewrite::add_matcher(
bool is_dyn_func = s_rerun_dynamic_check && f->is_dynamic();
for (auto node : f->get_ops())
{
- for (auto& closure : m_matchers)
+ for (auto& m_pass : m_matchers)
{
- if (is_dyn_func && closure.property[PassProperty::REQUIRE_STATIC_SHAPE])
+ if (is_dyn_func && m_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE))
{
NGRAPH_DEBUG << "matcher callback requires static shape but the "
"function is dynamic, skipping this "
"materialized";
continue;
}
- if (closure.handler(node))
+ if (m_pass->apply(node))
{
// If call back may change function's is_dynamic state, we need to
// update the cached value.
- if (closure.property.is_set(PassProperty::CHANGE_DYNAMIC_STATE))
+ if (m_pass->get_property(PassProperty::CHANGE_DYNAMIC_STATE))
{
is_dyn_func = s_rerun_dynamic_check && f->is_dynamic();
}
} while (changed && i < m_num_iters);
return changed;
}
+
+void ngraph::pass::MatcherPass::register_matcher(const std::shared_ptr<ngraph::pattern::Matcher>& m,
+ const ngraph::graph_rewrite_callback& callback,
+ const PassPropertyMask& property)
+{
+ set_name(m->get_name());
+ set_property(property, true);
+ m_matcher = m;
+ m_handler = [m, callback](const std::shared_ptr<Node>& node) -> bool {
+ if (m->match(node->output(0)))
+ {
+ NGRAPH_DEBUG << "Matcher " << m->get_name() << " matched " << node;
+ bool status = callback(*m.get());
+ // explicitly clear Matcher state because it holds pointers to matched nodes
+ m->clear_state();
+ return status;
+ }
+ m->clear_state();
+ return false;
+ };
+}
+
+bool ngraph::pass::MatcherPass::apply(std::shared_ptr<ngraph::Node> node)
+{
+ m_new_nodes.clear();
+ return m_handler(node);
+}
\ No newline at end of file
{
namespace pass
{
- class GraphRewriteBase;
class GraphRewrite;
class RecurrentGraphRewrite;
+ class MatcherPass;
}
+ using matcher_pass_callback = std::function<bool(ngraph::pattern::Matcher& m)>;
using graph_rewrite_callback = std::function<bool(ngraph::pattern::Matcher& m)>;
using recurrent_graph_rewrite_callback =
std::function<bool(ngraph::pattern::RecurrentMatcher& m)>;
+ using handler_callback = std::function<bool(const std::shared_ptr<Node>& node)>;
}
-class NGRAPH_API ngraph::pass::GraphRewriteBase : public ngraph::pass::FunctionPass
+/// \brief MatcherPass is a basic building block for pattern-based transformations. It describes a
+/// pattern and an action that is applied if the pattern is matched.
+///
+/// MatcherPass consists of a Matcher and a matcher_pass_callback that needs to be implemented and
+/// finally registered by using \sa register_matcher. A MatcherPass can be executed on a node via the
+/// \sa apply method. To run a matcher pass on a Function use GraphRewrite.
+/// In addition, MatcherPass provides a way to add new operations into the GraphRewrite execution
+/// queue. That means that operations created inside a transformation callback can be added for
+/// matching. To register a node use the \sa register_new_node method. GraphRewrite automatically
+/// takes the registered nodes and puts them into the execution queue. If multiple nodes were
+/// registered, make sure that they were registered in topological order.
+/// Note: when implementing a pattern for Matcher make sure that the root node is an operation from
+/// an opset or has the ngraph::pattern::op::WrapType type. That will help GraphRewrite to execute
+/// matcher passes more efficiently.
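+///
+/// A minimal sketch of a MatcherPass (the class name, pattern and callback body are hypothetical
+/// and only illustrate the expected structure):
+///     class MyReluDecomposition : public ngraph::pass::MatcherPass {
+///     public:
+///         MyReluDecomposition() {
+///             auto data = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32, ngraph::Shape{});
+///             auto relu = std::make_shared<ngraph::opset3::Relu>(data);  // root is an opset operation
+///             auto m = std::make_shared<ngraph::pattern::Matcher>(relu, "MyReluDecomposition");
+///             register_matcher(m, [=](ngraph::pattern::Matcher& m) {
+///                 // inspect m.get_match_root(), rewrite the graph and return true if it was changed
+///                 return false;
+///             });
+///         }
+///     };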
+
+class NGRAPH_API ngraph::pass::MatcherPass : public ngraph::pass::PassBase
{
public:
- /// \brief Add an arbitrary handler for nodes
- /// \param name The name of the handler
- /// \param handler Function responsible for deciding if the graph should be changed and making
- /// the changes. Returns true if changes are made.
- void add_handler(const std::string& name,
- std::function<bool(const std::shared_ptr<Node>& node)> handler,
- const PassPropertyMask& property);
-
-protected:
- GraphRewriteBase()
- : FunctionPass()
+ MatcherPass() = default;
+
+ MatcherPass(const MatcherPass&) = delete;
+ MatcherPass& operator=(const MatcherPass&) = delete;
+
+ explicit MatcherPass(const std::string& name,
+ const std::shared_ptr<pattern::Matcher>& m,
+ const handler_callback& handler,
+ const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE)
+ : PassBase()
+ , m_handler(handler)
+ , m_matcher(m)
{
- // Being explicit:
- // Setting REQUIRE_STATIC_SHAPE to false because we will check if each
- // callback needs static shape during run_on_function().
- set_property(PassProperty::REQUIRE_STATIC_SHAPE, false);
+ set_name(name);
+ set_property(property, true);
}
- bool is_enabled(const std::string& name) const;
+ bool apply(std::shared_ptr<ngraph::Node> node);
- struct MatchClosure
+ template <typename T, class... Args>
+ std::shared_ptr<T> register_new_node(Args&&... args)
{
- std::string name;
- std::function<bool(const std::shared_ptr<Node>& node)> handler;
- PassPropertyMask property;
- };
- std::vector<MatchClosure> m_matchers;
+ auto node = std::make_shared<T>(std::forward<Args>(args)...);
+ m_new_nodes.push_back(node);
+ return node;
+ }
+
+ const std::vector<std::shared_ptr<ngraph::Node>>& get_new_nodes() { return m_new_nodes; }
+ void clear_new_nodes() { m_new_nodes.clear(); }
+ std::shared_ptr<pattern::Matcher> get_matcher() { return m_matcher; }
+protected:
+ void register_matcher(const std::shared_ptr<pattern::Matcher>& m,
+ const ngraph::graph_rewrite_callback& callback,
+ const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE);
+
+private:
+ handler_callback m_handler;
+ std::shared_ptr<pattern::Matcher> m_matcher;
+ std::vector<std::shared_ptr<ngraph::Node>> m_new_nodes;
};
-/// \brief GraphRewrite (in tandem with \sa Matcher) performs transformations on specified patterns
+/// \brief GraphRewrite is a container for MatcherPasses that allows running them on a Function in
+/// an efficient way.
///
-/// Graph rewrite pass essentially allows pass users to rewrite parts of the
-/// input graph in any way they want. Fusion is one example of graph rewrite that
-/// fuses multiple ops together. At a high-level users of the pass need to
-/// specify 2 things: 1) which ops to fuse (via \sa Matcher, and 2) how to create new op(s) from
-/// the existing ops by providing a callback to \p Matcher object
-/// Patterns can be added by using \sa add_matcher
-/// Callbacks should use \sa replace_node to transform matched sub graphs
-
-class NGRAPH_API ngraph::pass::GraphRewrite : public ngraph::pass::GraphRewriteBase
+/// The graph rewrite pass is used to execute matcher passes on a Function.
+/// To register a MatcherPass use the \sa add_matcher<T>(args) method where T is a MatcherPass class.
+/// By default, the graph rewrite pass traverses the Function in topological order and applies the
+/// registered matcher passes to each node. But if all registered matcher passes have a type-based
+/// root node in their Matcher pattern, then an efficient mechanism is used to execute them.
+/// A Matcher pattern root is type based if it is an operation from an opset or a pattern::op::WrapType.
+/// Note: when implementing a pattern for Matcher make sure that the root node is an operation from
+/// an opset or has the ngraph::pattern::op::WrapType type. That will help GraphRewrite to execute
+/// matcher passes more efficiently.
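+///
+/// Usage sketch (MyFusionPass stands for any user-defined MatcherPass and is hypothetical):
+///     ngraph::pass::GraphRewrite rewrite;
+///     rewrite.add_matcher<MyFusionPass>();
+///     rewrite.run_on_function(f);
+/// A MatcherPass can also be registered directly in ngraph::pass::Manager via register_pass<T>();
+/// in that case the Manager wraps it into a GraphRewrite internally.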
+
+class NGRAPH_API ngraph::pass::GraphRewrite : public ngraph::pass::FunctionPass
{
public:
+ GraphRewrite() = default;
+
+ explicit GraphRewrite(const std::shared_ptr<MatcherPass>& pass)
+ : FunctionPass()
+ {
+ m_matchers.push_back(pass);
+ }
+
+ template <typename T, class... Args>
+ std::shared_ptr<T> add_matcher(Args&&... args)
+ {
+ static_assert(std::is_base_of<pass::MatcherPass, T>::value,
+ "pass not derived from MatcherPass");
+ auto pass = std::make_shared<T>(std::forward<Args>(args)...);
+ m_matchers.push_back(pass);
+ return pass;
+ }
+
void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
const ngraph::graph_rewrite_callback& callback,
- const PassPropertyMask& property);
+ const PassPropertyMask& property) NGRAPH_DEPRECATED("Use MatcherPass instead");
- // TODO: This interface may deprecate after all passes are refactored.
void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
- const ngraph::graph_rewrite_callback& callback);
+ const ngraph::graph_rewrite_callback& callback)
+ NGRAPH_DEPRECATED("Use MatcherPass instead");
- virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
+ bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
protected:
bool m_enable_shape_inference = false;
+
+ std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
};
-class NGRAPH_API ngraph::pass::RecurrentGraphRewrite : public ngraph::pass::GraphRewriteBase
+class NGRAPH_API ngraph::pass::RecurrentGraphRewrite : public ngraph::pass::FunctionPass
{
public:
RecurrentGraphRewrite(size_t num_iters = 10)
- : GraphRewriteBase()
+ : FunctionPass()
, m_num_iters(num_iters)
{
}
private:
size_t m_num_iters;
+
+ std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
};
//*****************************************************************************
#include <algorithm>
-#ifdef _WIN32
-#else
-#include <cxxabi.h>
-#endif
#include <iomanip>
#include <iostream>
#include <memory>
#include "ngraph/env_util.hpp"
#include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
+#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
+#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/pass.hpp"
#include "ngraph/pass/serialize.hpp"
static bool profile_enabled = getenv_bool("NGRAPH_PROFILE_PASS_ENABLE");
get_state().set_function(func);
- vector<std::pair<shared_ptr<Function>, bool>> fs{std::make_pair(func, func->is_dynamic())};
vector<shared_ptr<Function>> f_array{func};
size_t index = 0;
stopwatch pass_timer;
stopwatch overall_timer;
overall_timer.start();
- for (shared_ptr<PassBase> pass : m_pass_list)
+ bool function_changed = false;
+ for (auto& pass : m_pass_list)
{
pass_timer.start();
pass->set_state(get_state());
- auto module_pass = dynamic_pointer_cast<ModulePass>(pass);
- auto function_pass = dynamic_pointer_cast<FunctionPass>(pass);
- auto node_pass = dynamic_pointer_cast<NodePass>(pass);
- auto call_graph_pass = dynamic_pointer_cast<CallGraphPass>(pass);
- if (module_pass)
+ if (!m_has_default_callback)
+ {
+ pass->set_callback(m_transformation_callback);
+ }
+
+ if (auto module_pass = dynamic_pointer_cast<ModulePass>(pass))
{
if (auto vt_pass = dynamic_pointer_cast<pass::VisualizeTree>(module_pass))
{
vt_pass->set_ops_to_details(get_state().get_visualize_tree_ops_map());
}
- module_pass->run_on_module(f_array);
+ function_changed = module_pass->run_on_module(f_array);
}
- else if (function_pass)
+ else if (auto matcher_pass = dynamic_pointer_cast<MatcherPass>(pass))
{
- for (auto f_pair : fs)
+            // This check is to skip the graph transformation when the graph pass relies on
+            // static shape but the function state is dynamic.
+ if (matcher_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) &&
+ func->is_dynamic())
{
- shared_ptr<Function> f = f_pair.first;
- // This checks is to skip the graph optimization when the graph pass relies on
- // static shape but the function state is dynamic.
- // we update the function dynamic state only if we run the graph pass successfully.
- if (function_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) &&
- f_pair.second)
- {
- continue;
- }
- bool function_modified = function_pass->run_on_function(f);
- // If the pass may change the function's is_dynamic property, we need to
- // update the cached value.
- if (function_modified &&
- function_pass->get_property(PassProperty::CHANGE_DYNAMIC_STATE))
- {
- f_pair.second = f->is_dynamic();
- }
+ NGRAPH_DEBUG << "Pass " << pass->get_name() << " requires static shape but the "
+ << "function is dynamic. Skipping this transformation";
+ continue;
}
+            // GraphRewrite is used as a temporary container for the MatcherPass to enable
+            // execution on the entire ngraph::Function
+ function_changed = GraphRewrite(matcher_pass).run_on_function(func);
}
- else if (node_pass)
+ else if (auto function_pass = dynamic_pointer_cast<FunctionPass>(pass))
{
- for (auto f_pair : fs)
+            // This check is to skip the graph transformation when the graph pass relies on
+            // static shape but the function state is dynamic.
+ if (function_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) &&
+ func->is_dynamic())
{
- shared_ptr<Function> f = f_pair.first;
- if (node_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) && f_pair.second)
- {
- continue;
- }
- for (shared_ptr<Node> n : f->get_ops())
+ NGRAPH_DEBUG << "Pass " << pass->get_name() << " requires static shape but the "
+ << "function is dynamic. Skipping this transformation";
+ continue;
+ }
+
+ if (dynamic_pointer_cast<Validate>(pass))
+ {
+ if (function_changed)
{
- node_pass->run_on_node(n);
+ function_pass->run_on_function(func);
+ function_changed = false;
}
}
+ else
+ {
+ function_changed = function_pass->run_on_function(func);
+ }
}
- else if (call_graph_pass)
+ else if (auto node_pass = dynamic_pointer_cast<NodePass>(pass))
{
- for (auto f_pair : fs)
+ if (node_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) && func->is_dynamic())
{
- shared_ptr<Function> f = f_pair.first;
- if (call_graph_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) &&
- f_pair.second)
- {
- continue;
- }
- bool function_modified = call_graph_pass->run_on_call_graph(f->get_ordered_ops());
- f_pair.second = (function_modified == true) ? f->is_dynamic() : f_pair.second;
+ NGRAPH_DEBUG << "Pass " << pass->get_name() << " requires static shape but the "
+ << "function is dynamic. Skipping this transformation";
+ continue;
+ }
+ for (shared_ptr<Node> n : func->get_ops())
+ {
+ function_changed |= node_pass->run_on_node(n);
+ }
+ }
+ else if (auto call_graph_pass = dynamic_pointer_cast<CallGraphPass>(pass))
+ {
+ if (call_graph_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) &&
+ func->is_dynamic())
+ {
+ NGRAPH_DEBUG << "Pass " << pass->get_name() << " requires static shape but the "
+ << "function is dynamic. Skipping this transformation";
+ continue;
}
+ function_changed = call_graph_pass->run_on_call_graph(func->get_ordered_ops());
}
if (m_visualize || m_serialize)
std::string index_str = std::to_string(index);
index_str = std::string(num_digits_in_pass_index - index_str.length(), '0') + index_str;
auto base_filename = f_array.at(0)->get_name() + std::string("_") + index_str +
- std::string("_") + m_pass_names.at(index);
+ std::string("_") + pass->get_name();
if (m_visualize)
{
pass_timer.stop();
if (profile_enabled)
{
- PassBase* p = pass.get();
- string name = typeid(*p).name();
-#ifndef _WIN32
- int status;
- name = abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status);
-#endif
- cout << setw(7) << pass_timer.get_milliseconds() << "ms " << name << "\n";
+ cout << setw(7) << pass_timer.get_milliseconds() << "ms " << pass->get_name() << "\n";
}
}
if (profile_enabled)
/// each registered pass
/// \param new_state Value "true" enables Validate pass run; "false", otherwise
void set_per_pass_validation(bool new_state) { m_per_pass_validation = new_state; }
+    /// \brief Callback is a lambda function that can be used by registered transformations.
+    /// Its main purpose is to give plugins a way to enable or disable particular transformations.
+    /// For example, a plugin may disable a decomposition that it can execute natively for
+    /// performance reasons.
+    /// Callback example:
+    /// auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
+    ///     return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
+    /// };
+    /// This callback returns true for DepthToSpace operations. The DepthToSpace decomposition
+    /// pass then uses it to decide whether the decomposition is needed or the plugin can execute
+    /// the operation directly. The transformation itself must query the callback, for example:
+    /// if (m_transformation_callback(depth_to_space)) {
+    ///     return false;
+    /// }
+    /// \param callback lambda function that returns true if the node is supported by the plugin
+    /// and the transformation is not needed
+ void set_callback(param_callback callback)
+ {
+ m_transformation_callback = callback;
+ m_has_default_callback = false;
+ }
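+
+    // A minimal usage sketch (illustrative only; the registered pass name is hypothetical):
+    // a plugin creates a pass::Manager, registers its transformations and installs the
+    // callback once before running them.
+    //     ngraph::pass::Manager manager;
+    //     manager.register_pass<MyDecompositionPass>(); // hypothetical plugin pass
+    //     manager.set_callback([](const std::shared_ptr<const ngraph::Node>& node) -> bool {
+    //         return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
+    //     });
+    //     manager.run_passes(f);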
+
private:
template <typename T, class... Args>
std::shared_ptr<T> push_pass(Args&&... args)
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
auto pass_base = std::static_pointer_cast<PassBase>(pass);
m_pass_list.push_back(pass_base);
- if (m_visualize || m_serialize)
- {
-#ifdef _WIN32
- // MSVC produce a human-readable type name like class ngraph::pass::LikeReplacement
- // by typeid(T).name(). Later ofstream doesn't accept it as a valid file name.
- //
- std::string str = typeid(T).name();
- auto pos = str.find_last_of(":");
- m_pass_names.push_back(str.substr(pos + 1));
-#elif defined(__linux) || defined(__APPLE__)
- m_pass_names.push_back(typeid(T).name());
-#endif
- }
return pass;
}
- std::vector<std::string> m_pass_names;
+ param_callback m_transformation_callback = [](const std::shared_ptr<const Node>&) -> bool {
+ return false;
+ };
+ bool m_has_default_callback = true;
+
std::vector<std::shared_ptr<PassBase>> m_pass_list;
ManagerState m_state;
PassConfig m_pass_config;
// limitations under the License.
//*****************************************************************************
-#include "ngraph/pass/pass.hpp"
+#ifdef _WIN32
+#else
+#include <cxxabi.h>
+#endif
+
#include "ngraph/pass/manager.hpp"
+#include "ngraph/pass/pass.hpp"
using namespace std;
using namespace ngraph;
}
}
+std::string pass::PassBase::get_name() const
+{
+ if (m_name.empty())
+ {
+ const PassBase* p = this;
+ std::string pass_name = typeid(*p).name();
+#ifndef _WIN32
+ int status;
+ pass_name = abi::__cxa_demangle(pass_name.c_str(), nullptr, nullptr, &status);
+#endif
+ return pass_name;
+ }
+ else
+ {
+ return m_name;
+ }
+}
+
+void pass::PassBase::set_callback(const param_callback& callback)
+{
+ m_transformation_callback = callback;
+ m_has_default_callback = false;
+}
+
// The symbols are required to be in the cpp file to work around an RTTI issue on Android LLVM
pass::ModulePass::~ModulePass()
#include <memory>
#include <vector>
+#include "ngraph/deprecated.hpp"
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/manager_state.hpp"
class PassBase;
class ModulePass;
class FunctionPass;
- class NodePass;
- class CallGraphPass;
+ class NodePass NGRAPH_DEPRECATED("Use MatcherPass or FunctionPass instead.");
+ class CallGraphPass NGRAPH_DEPRECATED("Use MatcherPass or FunctionPass instead.");
class Manager;
enum class FusionType : uint32_t
{
// Pass requires node shapes to be static
REQUIRE_STATIC_SHAPE = 0x1,
// Pass transformation will change the function's dynamic state
- CHANGE_DYNAMIC_STATE = 1 << 1
+ CHANGE_DYNAMIC_STATE = 1 << 1,
};
+
+ using param_callback = std::function<bool(const std::shared_ptr<const ::ngraph::Node>)>;
}
}
/// Check if this pass has all the pass properties.
bool get_property(const PassPropertyMask& prop_mask) const;
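+    /// \brief Sets an explicit, human-readable pass name. get_name() returns this name if it
+    /// was set, otherwise a name derived from the pass type via RTTI (demangled on non-Windows).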
+ void set_name(const std::string& name) { m_name = name; }
+ std::string get_name() const;
+
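+    /// \brief Sets the transformation callback for this pass and marks the default callback
+    /// as overridden. See Manager::set_callback for the callback semantics.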
+ void set_callback(const param_callback& callback);
+
protected:
ManagerState& get_state();
void set_state(ManagerState&);
void set_property(const PassPropertyMask& prop, bool value);
+ param_callback m_transformation_callback = [](const std::shared_ptr<const Node>&) -> bool {
+ return false;
+ };
+ bool m_has_default_callback = true;
+
private:
PassPropertyMask m_property;
ManagerState* m_state{nullptr};
+ std::string m_name;
};
class NGRAPH_API ngraph::pass::ModulePass : public PassBase
render();
+    // Clear m_nodes_with_attributes so it does not keep node pointers alive
+ m_nodes_with_attributes.clear();
+
return false;
}
bool Matcher::match(const Output<Node>& graph_value)
{
- // clear our state
- m_matched_list.clear();
return match(graph_value, PatternValueMap{});
}
bool Matcher::match(const Output<Node>& graph_value,
const PatternValueMap& previous_matches)
{
- // clear our state
- m_match_root.reset();
- m_pattern_map.clear();
- m_matched_list.clear();
+ clear_state();
// insert previous matches
m_pattern_map.insert(previous_matches.cbegin(), previous_matches.cend());
return match(graph_value, as_pattern_value_map(previous_matches));
}
+ void Matcher::clear_state()
+ {
+ m_match_root.reset();
+ m_pattern_map.clear();
+ m_pattern_value_maps.clear();
+ m_matched_list.clear();
+ }
+
namespace
{
std::set<std::shared_ptr<Node>>
void capture(const std::set<Node*>& static_nodes);
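+        /// \brief Resets the matcher state (matched root, pattern maps and the list of matched
+        /// nodes) so the same Matcher instance can be reused for a new match attempt.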
+ void clear_state();
+
size_t get_number_of_recurrent_matches() const { return m_pattern_value_maps.size(); }
NodeVector get_bound_nodes_for_pattern(const Output<Node>& pattern) const;
size_t get_number_of_bound_labels() const;
}
return false;
}
+
+std::shared_ptr<Node> pattern::any_input()
+{
+ return std::make_shared<pattern::op::Label>();
+}
\ No newline at end of file
set_output_type(0, type, s);
}
- Label(const element::Type& type, const PartialShape& s)
+ explicit Label(const element::Type& type = element::dynamic,
+ const PartialShape& s = PartialShape::dynamic())
: Label(type, s, [](const Output<Node>&) { return true; }, OutputVector())
{
}
static Output<Node> wrap_values(const OutputVector& wrapped_values);
};
}
+
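+    /// \brief Creates a pattern Label that matches any input: dynamic element type,
+    /// dynamic shape and an always-true predicate.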
+ NGRAPH_API
+ std::shared_ptr<Node> any_input();
}
}
}
return result;
}
+
+ std::function<bool(Output<Node>)> consumers_count(size_t n)
+ {
+ return
+ [=](Output<Node> output) -> bool { return output.get_target_inputs().size() == n; };
+ }
}
}
return pred;
}
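+    /// \brief Returns a value predicate that checks that an output has exactly `n` consumers.
+    /// Illustrative usage inside a matcher pattern:
+    ///     pattern::wrap_type<opset3::Relu>(pattern::consumers_count(1));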
+ NGRAPH_API
+ std::function<bool(Output<Node>)> consumers_count(size_t n);
+
namespace op
{
using NodePredicate = std::function<bool(std::shared_ptr<Node>)>;
--- /dev/null
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#include "ngraph/pattern/op/wrap_type.hpp"
+#include "ngraph/pattern/matcher.hpp"
+
+using namespace std;
+using namespace ngraph;
+
+constexpr NodeTypeInfo pattern::op::WrapType::type_info;
+
+const NodeTypeInfo& pattern::op::WrapType::get_type_info() const
+{
+ return type_info;
+}
+
+bool pattern::op::WrapType::match_value(Matcher* matcher,
+ const Output<Node>& pattern_value,
+ const Output<Node>& graph_value)
+{
+ if (graph_value.get_node_shared_ptr()->get_type_info() == get_wrapped_type() &&
+ m_predicate(graph_value))
+ {
+ auto& pattern_map = matcher->get_pattern_value_map();
+ pattern_map[shared_from_this()] = graph_value;
+ matcher->add_node(graph_value);
+ return (get_input_size() == 0
+ ? true
+ : matcher->match_arguments(pattern_value.get_node(),
+ graph_value.get_node_shared_ptr()));
+ }
+ return false;
+}
--- /dev/null
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#pragma once
+
+#include "ngraph/node.hpp"
+#include "ngraph/pattern/op/pattern.hpp"
+
+namespace ngraph
+{
+ namespace pattern
+ {
+ namespace op
+ {
+ class NGRAPH_API WrapType : public Pattern
+ {
+ public:
+ static constexpr NodeTypeInfo type_info{"patternAnyType", 0};
+ const NodeTypeInfo& get_type_info() const override;
+
+ explicit WrapType(NodeTypeInfo wrapped_type,
+ const ValuePredicate& pred =
+ [](const Output<Node>& output) { return true; },
+ const OutputVector& input_values = {})
+ : Pattern(input_values, pred)
+ , m_wrapped_type(wrapped_type)
+ {
+ set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic());
+ }
+
+ bool match_value(pattern::Matcher* matcher,
+ const Output<Node>& pattern_value,
+ const Output<Node>& graph_value) override;
+
+ NodeTypeInfo get_wrapped_type() const { return m_wrapped_type; }
+ private:
+ NodeTypeInfo m_wrapped_type;
+ };
+ }
+
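+        /// \brief Creates a WrapType pattern node that matches nodes whose type info equals
+        /// T::type_info and that satisfy the given value predicate; optional input patterns are
+        /// matched against the node arguments. Illustrative usage:
+        ///     auto relu = pattern::wrap_type<opset3::Relu>({pattern::any_input()});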
+ template <class T>
+ std::shared_ptr<Node> wrap_type(const OutputVector& inputs,
+ const pattern::op::ValuePredicate& pred)
+ {
+ static_assert(std::is_base_of<Node, T>::value, "Unexpected template type");
+ return std::make_shared<op::WrapType>(T::type_info, pred, inputs);
+ }
+
+ template <class T>
+ std::shared_ptr<Node> wrap_type(const OutputVector& inputs = {})
+ {
+ return wrap_type<T>(inputs, [](const Output<Node>& output) { return true; });
+ }
+
+ template <class T>
+ std::shared_ptr<Node> wrap_type(const pattern::op::ValuePredicate& pred)
+ {
+ return wrap_type<T>({}, pred);
+ }
+ }
+}
eval.cpp
file_util.cpp
float16.cpp
+ graph_rewrite.cpp
includes.cpp
input_output_assign.cpp
intervals.cpp
main.cpp
+ matcher_pass.cpp
misc.cpp
ngraph_api.cpp
node_input_output.cpp
ASSERT_EQ(values_expected, values_out);
}
-TEST(constant_folding, pass_property)
-{
- auto pass = std::make_shared<ngraph::pass::ConstantFolding>();
- ASSERT_FALSE(pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
- ASSERT_TRUE(pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
-}
-
TEST(constant_folding, constant_non_zero_0D)
{
auto data = op::Constant::create(element::i32, Shape{}, {1});
--- /dev/null
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+#include <ngraph/opsets/opset3.hpp>
+#include <ngraph/pass/graph_rewrite.hpp>
+#include <ngraph/pass/manager.hpp>
+#include <util/test_tools.hpp>
+
+using namespace ::testing;
+using namespace std;
+using namespace ngraph;
+
+class TestPass : public ngraph::pass::MatcherPass
+{
+public:
+ TestPass()
+ : MatcherPass()
+ {
+ auto divide = std::make_shared<ngraph::pattern::op::Label>(
+ element::f32, Shape{}, pattern::has_class<opset3::Divide>());
+ ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
+ if (m_transformation_callback(m.get_match_root()))
+ {
+ auto relu =
+ std::make_shared<ngraph::opset3::Relu>(m.get_match_root()->input_value(0));
+ ngraph::replace_node(m.get_match_root(), relu);
+ return true;
+ }
+ return false;
+ };
+
+ auto m = std::make_shared<ngraph::pattern::Matcher>(divide, "TestMatcher");
+ this->register_matcher(m, callback);
+ }
+};
+
+class Anchor : public ngraph::pass::GraphRewrite
+{
+public:
+ Anchor()
+ : GraphRewrite()
+ {
+ }
+};
+
+std::shared_ptr<Function> get_function()
+{
+ auto data =
+ std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
+ auto divide_constant =
+ ngraph::opset3::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1.5});
+ auto divide = std::make_shared<ngraph::opset3::Divide>(data, divide_constant);
+ return std::make_shared<ngraph::Function>(ngraph::NodeVector{divide},
+ ngraph::ParameterVector{data});
+}
+
+ngraph::pass::param_callback get_callback()
+{
+ return [](const std::shared_ptr<const Node>& node) -> bool {
+ if (std::dynamic_pointer_cast<const opset3::Divide>(node))
+ {
+ return true;
+ }
+ else
+ {
+ return false;
+ }
+ };
+}
+
+TEST(GraphRewriteTest, MatcherPassCallback)
+{
+ auto f = get_function();
+
+ Anchor anchor;
+ anchor.add_matcher<TestPass>()->set_callback(get_callback());
+ anchor.run_on_function(f);
+
+ ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
+}
+
+TEST(GraphRewriteTest, GraphRewriteCallback)
+{
+ auto f = get_function();
+
+ Anchor anchor;
+ anchor.add_matcher<TestPass>();
+ anchor.set_callback(get_callback());
+ anchor.run_on_function(f);
+
+ ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
+}
+
+TEST(GraphRewriteTest, ManagerCallback)
+{
+ auto f = get_function();
+
+ pass::Manager manager;
+ auto anchor = manager.register_pass<Anchor>();
+ anchor->add_matcher<TestPass>();
+ manager.set_callback(get_callback());
+ manager.run_passes(f);
+
+ ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
+}
+
+TEST(GraphRewriteTest, ManagerCallback2)
+{
+ auto f = get_function();
+
+ pass::Manager manager;
+ auto anchor = manager.register_pass<TestPass>();
+ manager.set_callback(get_callback());
+ manager.run_passes(f);
+
+ ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
+}
--- /dev/null
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#include <algorithm>
+#include <cstdio>
+#include <iostream>
+#include <list>
+#include <memory>
+#include <ngraph/pattern/op/wrap_type.hpp>
+#include <ngraph/rt_info.hpp>
+
+#include "gtest/gtest.h"
+#include "ngraph/graph_util.hpp"
+#include "ngraph/log.hpp"
+#include "ngraph/ngraph.hpp"
+#include "ngraph/opsets/opset3.hpp"
+#include "ngraph/pass/graph_rewrite.hpp"
+#include "ngraph/pass/manager.hpp"
+
+using namespace ngraph;
+using namespace std;
+
+class TestMatcherPass : public pass::MatcherPass
+{
+public:
+ TestMatcherPass()
+ {
+ auto m_relu1 =
+ ngraph::pattern::wrap_type<ngraph::opset3::Relu>(pattern::consumers_count(1));
+ auto m_relu2 = ngraph::pattern::wrap_type<ngraph::opset3::Relu>({m_relu1});
+
+ ngraph::graph_rewrite_callback callback = [=](pattern::Matcher& m) {
+ // Map that helps to connect labels with matched outputs
+ auto& node_to_output = m.get_pattern_value_map();
+
+            // Create a new Relu operation and register it for additional execution
+ auto new_relu = register_new_node<ngraph::opset3::Relu>(
+ node_to_output.at(m_relu1).get_node_shared_ptr()->input_value(0));
+
+ // Copy runtime info attributes to newly created operation
+ ngraph::copy_runtime_info(m.get_matched_nodes(), new_relu);
+
+            // Copy the friendly name of the last Relu to the new Relu operation
+ new_relu->set_friendly_name(m.get_match_root()->get_friendly_name());
+
+ // Replace Relu->Relu with Relu
+ ngraph::replace_node(m.get_match_root(), new_relu);
+
+ // Return true as the root node was changed
+ return true;
+ };
+
+        // Register pattern with the last Relu operation as the pattern root node
+ auto m = std::make_shared<ngraph::pattern::Matcher>(m_relu2, "ReluReluFusion");
+ // Register Matcher
+ this->register_matcher(m, callback);
+ }
+};
+
+TEST(pattern, matcher_pass)
+{
+ {
+ TestMatcherPass test_matcher;
+ auto a = make_shared<opset3::Parameter>(element::f32, Shape{1});
+ auto b = make_shared<opset3::Relu>(a);
+ auto c = make_shared<opset3::Relu>(b);
+ auto f = std::make_shared<Function>(ngraph::NodeVector{c}, ParameterVector{a});
+
+ ASSERT_TRUE(test_matcher.get_matcher()->match(c->output(0)));
+ ASSERT_TRUE(test_matcher.get_matcher()->get_matched_nodes().size() == 2);
+ test_matcher.get_matcher()->clear_state();
+ ASSERT_TRUE(test_matcher.get_matcher()->get_matched_nodes().empty());
+
+ test_matcher.apply(c);
+ ASSERT_TRUE(test_matcher.get_new_nodes().size() == 1);
+ test_matcher.apply(test_matcher.get_new_nodes()[0]);
+ ASSERT_TRUE(test_matcher.get_new_nodes().empty());
+ }
+
+ {
+ TestMatcherPass test_matcher;
+ auto a = make_shared<opset3::Parameter>(element::f32, Shape{1});
+ auto b = make_shared<opset3::Relu>(a);
+ auto c = make_shared<opset3::Relu>(b);
+ auto f = std::make_shared<Function>(ngraph::NodeVector{b, c}, ParameterVector{a});
+
+ ASSERT_FALSE(test_matcher.get_matcher()->match(c->output(0)));
+ }
+
+ {
+ std::shared_ptr<Function> f;
+ {
+ auto a = make_shared<opset3::Parameter>(element::f32, Shape{1});
+ auto b = make_shared<opset3::Relu>(a);
+ auto c = make_shared<opset3::Relu>(b);
+ auto d = make_shared<opset3::Relu>(c);
+ f = std::make_shared<Function>(ngraph::NodeVector{d}, ParameterVector{a});
+ }
+
+ pass::GraphRewrite pass;
+ pass.add_matcher<TestMatcherPass>();
+ pass.run_on_function(f);
+
+ // Parameter->Relu->Result
+ ASSERT_TRUE(f->get_ops().size() == 3);
+ }
+}
\ No newline at end of file
#include <iostream>
#include <list>
#include <memory>
+#include <ngraph/pattern/op/wrap_type.hpp>
#include "gtest/gtest.h"
#include "ngraph/file_util.hpp"
std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
RecurrentMatcher rm(padd, rpattern, empty_correlated_matches);
ASSERT_TRUE(rm.match(add3));
- ASSERT_EQ(rm.get_number_of_bound_labels(), 1);
+ ASSERT_EQ(rm.get_number_of_bound_labels(), 3);
auto recurrent_matches = rm.get_bound_nodes_for_pattern(rpattern);
ASSERT_EQ(recurrent_matches.at(0), add2);
ASSERT_EQ(recurrent_matches.at(1), add1);
auto padd2 = iconst_label + rpattern;
RecurrentMatcher rm2(padd2, rpattern, empty_correlated_matches);
ASSERT_TRUE(rm2.match(add3_2));
- ASSERT_EQ(rm2.get_number_of_bound_labels(), 2);
+ ASSERT_EQ(rm2.get_number_of_bound_labels(), 4);
recurrent_matches = rm2.get_bound_nodes_for_pattern(rpattern);
ASSERT_EQ(recurrent_matches.at(0), add2_2);
ASSERT_EQ(recurrent_matches.at(1), add1);
correlated_matches.insert(iconst_label);
RecurrentMatcher rm3(padd2, rpattern, correlated_matches);
ASSERT_TRUE(rm3.match(add3_2));
- ASSERT_EQ(rm3.get_number_of_bound_labels(), 2);
+ ASSERT_EQ(rm3.get_number_of_bound_labels(), 4);
iconst_matches = rm3.get_bound_nodes_for_pattern(iconst_label);
ASSERT_EQ(iconst_matches.size(), 1);
ASSERT_EQ(iconst_matches.at(0), iconst0);
// Matching correlated labels and
// testing if RecurrentMatcher can be reused for different nodes
ASSERT_TRUE(rm3.match(add3));
- ASSERT_EQ(rm3.get_number_of_bound_labels(), 2);
+ ASSERT_EQ(rm3.get_number_of_bound_labels(), 4);
recurrent_matches = rm3.get_bound_nodes_for_pattern(rpattern);
ASSERT_EQ(recurrent_matches.at(0), add2);
ASSERT_EQ(recurrent_matches.at(1), add1);
ASSERT_TRUE(n.match(label_abs2, absn2));
ASSERT_FALSE(n.is_contained_match());
}
+
+TEST(pattern, wrap_type)
+{
+ auto a = make_shared<op::Parameter>(element::f32, Shape{1, 3, 64, 64});
+ auto b = make_shared<op::Abs>(a);
+ auto c = make_shared<op::Relu>(a);
+ auto mul1 = make_shared<op::v1::Multiply>(a, op::Constant::create(element::f32, Shape{}, {1}));
+ auto mul2 = make_shared<op::v1::Multiply>(op::Constant::create(element::f32, Shape{}, {1}), a);
+
+ {
+ auto m = pattern::wrap_type<op::Abs>();
+ auto matcher = std::make_shared<pattern::Matcher>(m, "AbsMatcher");
+ ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(b)));
+ ASSERT_EQ(matcher->get_matched_nodes().size(), 1);
+ ASSERT_EQ(matcher->get_matched_nodes()[0], b);
+ ASSERT_EQ(matcher->get_pattern_map().count(m), 1);
+ ASSERT_FALSE(matcher->match(static_pointer_cast<Node>(c)));
+ }
+ {
+ auto m1 = pattern::wrap_type<op::Parameter>();
+ auto m2 = pattern::wrap_type<op::Abs>({m1});
+ auto matcher = std::make_shared<pattern::Matcher>(m2, "ParamAbsMatcher");
+ ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(b)));
+ ASSERT_EQ(matcher->get_matched_nodes().size(), 2);
+ ASSERT_EQ(matcher->get_pattern_map().count(m1), 1);
+ ASSERT_EQ(matcher->get_pattern_map().count(m2), 1);
+ ASSERT_FALSE(matcher->match(static_pointer_cast<Node>(c)));
+ }
+ {
+ auto m1 = pattern::wrap_type<op::v1::Multiply>(
+ {pattern::any_input(), pattern::wrap_type<op::Constant>()});
+ auto matcher = std::make_shared<pattern::Matcher>(m1, "MultiplyMatcher");
+ ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(mul1)));
+ ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(mul2)));
+ }
+ {
+ auto m1 = pattern::wrap_type<op::v1::Multiply>(
+ {pattern::wrap_type<op::Constant>(), pattern::any_input()});
+ auto matcher = std::make_shared<pattern::Matcher>(m1, "MultiplyMatcher");
+ ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(mul1)));
+ ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(mul2)));
+ }
+}
\ No newline at end of file